diff --git a/.vscode/settings.json b/.vscode/settings.json
index 97fafffbe..81f679c3e 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,11 +1,14 @@
{
- "python.linting.pylintEnabled": true,
- "python.linting.enabled": true,
- "python.testing.pytestArgs": ["."],
- "python.testing.unittestEnabled": false,
- "python.testing.pytestEnabled": true,
- "python.analysis.typeCheckingMode": "basic",
- "python.formatting.provider": "black",
- "python.languageServer": "Pylance",
- "rust-analyzer.linkedProjects": ["./manager/Cargo.toml"]
+ "python.linting.pylintEnabled": true,
+ "python.linting.enabled": true,
+ "python.testing.pytestArgs": ["."],
+ "python.testing.unittestEnabled": false,
+ "python.testing.pytestEnabled": true,
+ "python.analysis.typeCheckingMode": "basic",
+ "python.formatting.provider": "none",
+ "python.languageServer": "Pylance",
+ "rust-analyzer.linkedProjects": ["./manager/Cargo.toml"],
+ "[python]": {
+ "editor.defaultFormatter": "ms-python.black-formatter"
+ }
}
diff --git a/api/routes/settings.py b/api/routes/settings.py
index 76805a060..27f8ed403 100644
--- a/api/routes/settings.py
+++ b/api/routes/settings.py
@@ -4,7 +4,7 @@
from fastapi import APIRouter
from core import config
-from core.config.config import update_config
+from core.config._config import update_config
router = APIRouter(tags=["settings"])
diff --git a/core/config/__init__.py b/core/config/__init__.py
index 2d52a925a..780069840 100644
--- a/core/config/__init__.py
+++ b/core/config/__init__.py
@@ -2,7 +2,7 @@
from diffusers.utils.constants import DIFFUSERS_CACHE
-from .config import (
+from ._config import (
Configuration,
Img2ImgConfig,
Txt2ImgConfig,
diff --git a/core/config/_config.py b/core/config/_config.py
new file mode 100644
index 000000000..b01d704ff
--- /dev/null
+++ b/core/config/_config.py
@@ -0,0 +1,75 @@
+import logging
+from dataclasses import Field, dataclass, field, fields
+
+from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json
+
+from .api_settings import APIConfig
+from .bot_settings import BotConfig
+from .default_settings import (
+ Txt2ImgConfig,
+ Img2ImgConfig,
+ InpaintingConfig,
+ ControlNetConfig,
+ UpscaleConfig,
+ AITemplateConfig,
+ ONNXConfig,
+)
+from .frontend_settings import FrontendConfig
+from .interrogator_settings import InterrogatorConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass_json(undefined=Undefined.INCLUDE)
+@dataclass
+class Configuration(DataClassJsonMixin):
+ "Main configuration class for the application"
+
+ txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig)
+ img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig)
+ inpainting: InpaintingConfig = field(default_factory=InpaintingConfig)
+ controlnet: ControlNetConfig = field(default_factory=ControlNetConfig)
+ upscale: UpscaleConfig = field(default_factory=UpscaleConfig)
+ api: APIConfig = field(default_factory=APIConfig)
+ interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig)
+ aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig)
+ onnx: ONNXConfig = field(default_factory=ONNXConfig)
+ bot: BotConfig = field(default_factory=BotConfig)
+ frontend: FrontendConfig = field(default_factory=FrontendConfig)
+ extra: CatchAll = field(default_factory=dict)
+
+
+def save_config(config: Configuration):
+ "Save the configuration to a file"
+
+ logger.info("Saving configuration to data/settings.json")
+
+ with open("data/settings.json", "w", encoding="utf-8") as f:
+ f.write(config.to_json(ensure_ascii=False, indent=4))
+
+
+def update_config(config: Configuration, new_config: Configuration):
+ "Update the configuration with new values instead of overwriting the pointer"
+
+ for cls_field in fields(new_config):
+ assert isinstance(cls_field, Field)
+ setattr(config, cls_field.name, getattr(new_config, cls_field.name))
+
+
+def load_config():
+ "Load the configuration from a file"
+
+ logger.info("Loading configuration from data/settings.json")
+
+ try:
+ with open("data/settings.json", "r", encoding="utf-8") as f:
+ config = Configuration.from_json(f.read())
+ logger.info("Configuration loaded from data/settings.json")
+ return config
+
+ except FileNotFoundError:
+ logger.info("data/settings.json not found, creating a new one")
+ config = Configuration()
+ save_config(config)
+ logger.info("Configuration saved to data/settings.json")
+ return config
\ No newline at end of file
diff --git a/core/config/api_settings.py b/core/config/api_settings.py
new file mode 100644
index 000000000..bb7491a09
--- /dev/null
+++ b/core/config/api_settings.py
@@ -0,0 +1,110 @@
+from dataclasses import dataclass, field
+from typing import List, Literal, Union
+
+import torch
+
+
+@dataclass
+class APIConfig:
+ "Configuration for the API"
+
+ # Autoload
+ autoloaded_textual_inversions: List[str] = field(default_factory=list)
+
+ # Websockets and intervals
+ websocket_sync_interval: float = 0.02
+ websocket_perf_interval: float = 1.0
+
+ # TomeSD
+ use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit
+ tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts
+ tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1
+
+ image_preview_delay: float = 2.0
+
+ # General optimizations
+ autocast: bool = False
+ attention_processor: Literal[
+ "xformers", "sdpa", "cross-attention", "subquadratic", "multihead"
+ ] = "sdpa"
+ subquadratic_size: int = 512
+ attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled"
+ channels_last: bool = True
+ vae_slicing: bool = True
+ vae_tiling: bool = False
+ trace_model: bool = False
+ clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always"
+ offload: bool = False
+ data_type: Literal["float32", "float16", "bfloat16"] = "float16"
+
+ # CUDA specific optimizations
+ reduced_precision: bool = False
+ cudnn_benchmark: bool = False
+ deterministic_generation: bool = False
+
+ # Device settings
+ device_id: int = 0
+ device_type: Literal["cpu", "cuda", "mps", "directml", "intel", "vulkan"] = "cuda"
+
+ # Critical
+ enable_shutdown: bool = True
+
+ # VAE
+ upcast_vae: bool = False
+
+ # CLIP
+ clip_skip: int = 1
+ clip_quantization: Literal["full", "int8", "int4"] = "full"
+
+ huggingface_style_parsing: bool = False
+
+ # Saving
+ save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}"
+ image_extension: Literal["png", "webp", "jpeg"] = "png"
+ image_quality: int = 95
+ image_return_format: Literal["bytes", "base64"] = "base64"
+
+ # Grid
+ disable_grid: bool = False
+
+ # Torch compile
+ torch_compile: bool = False
+ torch_compile_fullgraph: bool = False
+ torch_compile_dynamic: bool = False
+ torch_compile_backend: str = "inductor"
+ torch_compile_mode: Literal[
+ "default",
+ "reduce-overhead",
+ "max-autotune",
+ ] = "reduce-overhead"
+
+ @property
+ def dtype(self):
+ "Return selected data type"
+ if self.data_type == "bfloat16":
+ return torch.bfloat16
+ if self.data_type == "float16":
+ return torch.float16
+ return torch.float32
+
+ @property
+ def device(self):
+ "Return the device"
+
+ if self.device_type == "intel":
+ from core.inference.functions import is_ipex_available
+
+ return torch.device("xpu" if is_ipex_available() else "cpu")
+
+ if self.device_type in ["cpu", "mps"]:
+ return torch.device(self.device_type)
+
+ if self.device_type in ["vulkan", "cuda"]:
+ return torch.device(f"{self.device_type}:{self.device_id}")
+
+ if self.device_type == "directml":
+ import torch_directml # pylint: disable=import-error
+
+ return torch_directml.device()
+ else:
+ raise ValueError(f"Device type {self.device_type} not supported")
\ No newline at end of file
diff --git a/core/config/bot_settings.py b/core/config/bot_settings.py
new file mode 100644
index 000000000..3db429d58
--- /dev/null
+++ b/core/config/bot_settings.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
+
+
+@dataclass
+class BotConfig:
+ "Configuration for the bot"
+
+ default_scheduler: KarrasDiffusionSchedulers = (
+ KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler
+ )
+ verbose: bool = False
+ use_default_negative_prompt: bool = True
\ No newline at end of file
diff --git a/core/config/config.py b/core/config/config.py
deleted file mode 100644
index d2dad6baf..000000000
--- a/core/config/config.py
+++ /dev/null
@@ -1,316 +0,0 @@
-import logging
-import multiprocessing
-from dataclasses import Field, dataclass, field, fields
-from typing import List, Literal, Optional, Union
-
-import torch
-from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json
-from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class QuantDict:
- vae_decoder: Optional[bool] = None
- vae_encoder: Optional[bool] = None
- unet: Optional[bool] = None
- text_encoder: Optional[bool] = None
-
-
-@dataclass
-class Txt2ImgConfig:
- "Configuration for the text to image pipeline"
-
- width: int = 512
- height: int = 512
- seed: int = -1
- cfg_scale: int = 7
- sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
- prompt: str = ""
- negative_prompt: str = ""
- steps: int = 40
- batch_count: int = 1
- batch_size: int = 1
- self_attention_scale: float = 0.0
-
-
-@dataclass
-class Img2ImgConfig:
- "Configuration for the image to image pipeline"
-
- width: int = 512
- height: int = 512
- seed: int = -1
- cfg_scale: int = 7
- sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
- prompt: str = ""
- negative_prompt: str = ""
- steps: int = 40
- batch_count: int = 1
- batch_size: int = 1
- resize_method: int = 0
- denoising_strength: float = 0.6
- self_attention_scale: float = 0.0
-
-
-@dataclass
-class InpaintingConfig:
- "Configuration for the inpainting pipeline"
-
- prompt: str = ""
- negative_prompt: str = ""
- width: int = 512
- height: int = 512
- steps: int = 40
- cfg_scale: int = 7
- seed: int = -1
- batch_count: int = 1
- batch_size: int = 1
- sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
- self_attention_scale: float = 0.0
-
-
-@dataclass
-class ControlNetConfig:
- "Configuration for the inpainting pipeline"
-
- prompt: str = ""
- negative_prompt: str = ""
- width: int = 512
- height: int = 512
- seed: int = -1
- cfg_scale: int = 7
- steps: int = 40
- batch_count: int = 1
- batch_size: int = 1
- sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
- controlnet: str = "lllyasviel/sd-controlnet-canny"
- controlnet_conditioning_scale: float = 1.0
- detection_resolution: int = 512
- is_preprocessed: bool = False
- save_preprocessed: bool = False
- return_preprocessed: bool = True
-
-
-@dataclass
-class UpscaleConfig:
- "Configuration for the RealESRGAN upscaler"
-
- model: str = "RealESRGAN_x4plus_anime_6B"
- upscale_factor: int = 4
- tile_size: int = field(default=128)
- tile_padding: int = field(default=10)
-
-
-@dataclass
-class APIConfig:
- "Configuration for the API"
-
- # Websockets and intervals
- websocket_sync_interval: float = 0.02
- websocket_perf_interval: float = 1.0
-
- # TomeSD
- use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit
- tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts
- tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1
-
- image_preview_delay: float = 2.0
-
- # General optimizations
- autocast: bool = False
- attention_processor: Literal[
- "xformers", "sdpa", "cross-attention", "subquadratic", "multihead"
- ] = "sdpa"
- subquadratic_size: int = 512
- attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled"
- channels_last: bool = True
- vae_slicing: bool = True
- vae_tiling: bool = False
- trace_model: bool = False
- clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always"
- offload: Literal["module", "model", "disabled"] = "disabled"
- data_type: Literal["float32", "float16", "bfloat16"] = "float16"
-
- # CUDA specific optimizations
- reduced_precision: bool = False
- cudnn_benchmark: bool = False
- deterministic_generation: bool = False
-
- # Device settings
- device_id: int = 0
- device_type: Literal["cpu", "cuda", "mps", "directml", "intel", "vulkan"] = "cuda"
-
- # Critical
- enable_shutdown: bool = True
-
- # CLIP
- clip_skip: int = 1
- clip_quantization: Literal["full", "int8", "int4"] = "full"
-
- # Autoload
- autoloaded_textual_inversions: List[str] = field(default_factory=list)
-
- huggingface_style_parsing: bool = False
-
- # Saving
- save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}"
- image_extension: Literal["png", "webp", "jpeg"] = "png"
- image_quality: int = 95
- image_return_format: Literal["bytes", "base64"] = "base64"
-
- # Grid
- disable_grid: bool = False
-
- # Torch compile
- torch_compile: bool = False
- torch_compile_fullgraph: bool = False
- torch_compile_dynamic: bool = False
- torch_compile_backend: str = "inductor"
- torch_compile_mode: Literal[
- "default",
- "reduce-overhead",
- "max-autotune",
- ] = "reduce-overhead"
-
- @property
- def dtype(self):
- "Return selected data type"
- if self.data_type == "bfloat16":
- return torch.bfloat16
- if self.data_type == "float16":
- return torch.float16
- return torch.float32
-
- @property
- def device(self):
- "Return the device"
-
- if self.device_type == "intel":
- from core.inference.functions import is_ipex_available
-
- return torch.device("xpu" if is_ipex_available() else "cpu")
-
- if self.device_type in ["cpu", "mps"]:
- return torch.device(self.device_type)
-
- if self.device_type in ["vulkan", "cuda"]:
- return torch.device(f"{self.device_type}:{self.device_id}")
-
- if self.device_type == "directml":
- import torch_directml # pylint: disable=import-error
-
- return torch_directml.device()
- else:
- raise ValueError(f"Device type {self.device_type} not supported")
-
-
-@dataclass
-class AITemplateConfig:
- "Configuration for model inference and acceleration"
-
- num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8))
-
-
-@dataclass
-class ONNXConfig:
- "Configuration for ONNX acceleration"
-
- quant_dict: QuantDict = field(default_factory=QuantDict)
-
-
-@dataclass
-class BotConfig:
- "Configuration for the bot"
-
- default_scheduler: KarrasDiffusionSchedulers = (
- KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler
- )
- verbose: bool = False
- use_default_negative_prompt: bool = True
-
-
-@dataclass
-class InterrogatorConfig:
- "Configuration for interrogation models"
-
- # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram
- caption_model: str = "Salesforce/blip-image-captioning-large"
- visualizer_model: str = "ViT-L-14/openai"
-
- offload_captioner: bool = False
- offload_visualizer: bool = False
-
- chunk_size: int = 2048 # set to 1024 for lower vram usage
- flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage
-
- flamingo_model: str = "dhansmair/flamingo-mini"
-
- caption_max_length: int = 32
-
-
-@dataclass
-class FrontendConfig:
- "Configuration for the frontend"
-
- theme: Literal["dark", "light"] = "dark"
- enable_theme_editor: bool = False
- image_browser_columns: int = 5
- on_change_timer: int = 0
- nsfw_ok_threshold: int = 0
-
-
-@dataclass_json(undefined=Undefined.INCLUDE)
-@dataclass
-class Configuration(DataClassJsonMixin):
- "Main configuration class for the application"
-
- txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig)
- img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig)
- inpainting: InpaintingConfig = field(default_factory=InpaintingConfig)
- controlnet: ControlNetConfig = field(default_factory=ControlNetConfig)
- upscale: UpscaleConfig = field(default_factory=UpscaleConfig)
- api: APIConfig = field(default_factory=APIConfig)
- interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig)
- aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig)
- onnx: ONNXConfig = field(default_factory=ONNXConfig)
- bot: BotConfig = field(default_factory=BotConfig)
- frontend: FrontendConfig = field(default_factory=FrontendConfig)
- extra: CatchAll = field(default_factory=dict)
-
-
-def save_config(config: Configuration):
- "Save the configuration to a file"
-
- logger.info("Saving configuration to data/settings.json")
-
- with open("data/settings.json", "w", encoding="utf-8") as f:
- f.write(config.to_json(ensure_ascii=False, indent=4))
-
-
-def update_config(config: Configuration, new_config: Configuration):
- "Update the configuration with new values instead of overwriting the pointer"
-
- for cls_field in fields(new_config):
- assert isinstance(cls_field, Field)
- setattr(config, cls_field.name, getattr(new_config, cls_field.name))
-
-
-def load_config():
- "Load the configuration from a file"
-
- logger.info("Loading configuration from data/settings.json")
-
- try:
- with open("data/settings.json", "r", encoding="utf-8") as f:
- config = Configuration.from_json(f.read())
- logger.info("Configuration loaded from data/settings.json")
- return config
-
- except FileNotFoundError:
- logger.info("data/settings.json not found, creating a new one")
- config = Configuration()
- save_config(config)
- logger.info("Configuration saved to data/settings.json")
- return config
diff --git a/core/config/default_settings.py b/core/config/default_settings.py
new file mode 100644
index 000000000..3f9b51209
--- /dev/null
+++ b/core/config/default_settings.py
@@ -0,0 +1,114 @@
+from dataclasses import dataclass, field
+import multiprocessing
+from typing import Optional
+
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
+
+
+@dataclass
+class QuantDict:
+ "Configuration for ONNX quantization"
+
+ vae_decoder: Optional[bool] = None
+ vae_encoder: Optional[bool] = None
+ unet: Optional[bool] = None
+ text_encoder: Optional[bool] = None
+
+
+@dataclass
+class Txt2ImgConfig:
+ "Configuration for the text to image pipeline"
+
+ width: int = 512
+ height: int = 512
+ seed: int = -1
+ cfg_scale: int = 7
+ sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
+ prompt: str = ""
+ negative_prompt: str = ""
+ steps: int = 40
+ batch_count: int = 1
+ batch_size: int = 1
+ self_attention_scale: float = 0.0
+
+
+@dataclass
+class Img2ImgConfig:
+ "Configuration for the image to image pipeline"
+
+ width: int = 512
+ height: int = 512
+ seed: int = -1
+ cfg_scale: int = 7
+ sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
+ prompt: str = ""
+ negative_prompt: str = ""
+ steps: int = 40
+ batch_count: int = 1
+ batch_size: int = 1
+ resize_method: int = 0
+ denoising_strength: float = 0.6
+ self_attention_scale: float = 0.0
+
+
+@dataclass
+class InpaintingConfig:
+ "Configuration for the inpainting pipeline"
+
+ prompt: str = ""
+ negative_prompt: str = ""
+ width: int = 512
+ height: int = 512
+ steps: int = 40
+ cfg_scale: int = 7
+ seed: int = -1
+ batch_count: int = 1
+ batch_size: int = 1
+ sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
+ self_attention_scale: float = 0.0
+
+
+@dataclass
+class ControlNetConfig:
+ "Configuration for the inpainting pipeline"
+
+ prompt: str = ""
+ negative_prompt: str = ""
+ width: int = 512
+ height: int = 512
+ seed: int = -1
+ cfg_scale: int = 7
+ steps: int = 40
+ batch_count: int = 1
+ batch_size: int = 1
+ sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value
+ controlnet: str = "lllyasviel/sd-controlnet-canny"
+ controlnet_conditioning_scale: float = 1.0
+ detection_resolution: int = 512
+ is_preprocessed: bool = False
+ save_preprocessed: bool = False
+ return_preprocessed: bool = True
+
+
+@dataclass
+class UpscaleConfig:
+ "Configuration for the RealESRGAN upscaler"
+
+ model: str = "RealESRGAN_x4plus_anime_6B"
+ upscale_factor: int = 4
+ tile_size: int = field(default=128)
+ tile_padding: int = field(default=10)
+
+
+@dataclass
+class AITemplateConfig:
+ "Configuration for model inference and acceleration"
+
+ num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8))
+
+
+@dataclass
+class ONNXConfig:
+ "Configuration for ONNX acceleration"
+
+ quant_dict: QuantDict = field(default_factory=QuantDict)
\ No newline at end of file
diff --git a/core/config/frontend_settings.py b/core/config/frontend_settings.py
new file mode 100644
index 000000000..20ceb13b4
--- /dev/null
+++ b/core/config/frontend_settings.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from typing import Literal
+
+
+@dataclass
+class FrontendConfig:
+ "Configuration for the frontend"
+
+ theme: Literal["dark", "light"] = "dark"
+ enable_theme_editor: bool = False
+ image_browser_columns: int = 5
+ on_change_timer: int = 0
+ nsfw_ok_threshold: int = 0
\ No newline at end of file
diff --git a/core/config/interrogator_settings.py b/core/config/interrogator_settings.py
new file mode 100644
index 000000000..eb94aa4fd
--- /dev/null
+++ b/core/config/interrogator_settings.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class InterrogatorConfig:
+ "Configuration for interrogation models"
+
+ # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram
+ caption_model: str = "Salesforce/blip-image-captioning-large"
+ visualizer_model: str = "ViT-L-14/openai"
+
+ offload_captioner: bool = False
+ offload_visualizer: bool = False
+
+ chunk_size: int = 2048 # set to 1024 for lower vram usage
+ flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage
+
+ flamingo_model: str = "dhansmair/flamingo-mini"
+
+ caption_max_length: int = 32
\ No newline at end of file
diff --git a/core/files.py b/core/files.py
index 71d401690..5134bcb98 100644
--- a/core/files.py
+++ b/core/files.py
@@ -82,7 +82,10 @@ def pytorch(self) -> List[ModelResponse]:
state="not loaded",
)
)
- elif ".safetensors" in model_name or ".ckpt" in model_name:
+ elif (self.checkpoint_converted_path / model_name).suffix in [
+ ".ckpt",
+ ".safetensors",
+ ]:
# Assuming that model is in Checkpoint / Safetensors format
models.append(
ModelResponse(
diff --git a/core/flags.py b/core/flags.py
index 0ce4506ec..690091fc9 100644
--- a/core/flags.py
+++ b/core/flags.py
@@ -7,8 +7,7 @@
"nearest",
"area",
"bilinear",
- "bislerp-original",
- "bislerp-tortured",
+ "bislerp",
"bicubic",
"nearest-exact",
]
diff --git a/core/inference/pytorch/pipeline.py b/core/inference/pytorch/pipeline.py
index 3ce27f49c..c77d74158 100644
--- a/core/inference/pytorch/pipeline.py
+++ b/core/inference/pytorch/pipeline.py
@@ -24,7 +24,7 @@
prepare_mask_latents,
preprocess_image,
)
-from core.optimizations import autocast
+from core.optimizations import autocast, upcast_vae, ensure_correct_device, unload_all
from .sag import CrossAttnStoreProcessor, pred_epsilon, pred_x0, sag_masking
@@ -108,27 +108,13 @@ def _execution_device(self):
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): # type: ignore
+ if self.device != torch.device("meta") and not hasattr(self.unet, "offload_device"): # type: ignore
return self.device
- for module in self.unet.modules(): # type: ignore
- if (
- hasattr(module, "_hf_hook")
- and hasattr(
- module._hf_hook, # pylint: disable=protected-access
- "execution_device",
- )
- and module._hf_hook.execution_device # pylint: disable=protected-access # type: ignore
- is not None
- ):
- return torch.device(
- module._hf_hook.execution_device # pylint: disable=protected-access # type: ignore
- )
- return self.device
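+ # With sequential offload the UNet may live on the CPU; its target accelerator is recorded in "offload_device".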
+ return getattr(self.unet, "offload_device", self.device)
def _encode_prompt(
self,
prompt,
- _device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
@@ -171,6 +157,7 @@ def _encode_prompt(
" the batch size of `prompt`."
)
+ ensure_correct_device(self.text_encoder)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self.parent,
prompt=prompt,
@@ -214,7 +201,14 @@ def _check_inputs(self, prompt, strength, callback_steps):
)
def _decode_latents(self, latents, height, width):
+ if config.api.upcast_vae:
+ upcast_vae(self.vae)
+ latents = latents.to(
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
+ )
+
latents = 1 / 0.18215 * latents
+ ensure_correct_device(self.vae)
image = self.vae.decode(latents).sample # type: ignore
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -356,7 +350,6 @@ def __call__(
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
- device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
@@ -454,6 +447,7 @@ def get_map_size(_, __, output):
] # output.sample.shape[-2:] in older diffusers
# 8. Denoising loop
+ ensure_correct_device(self.unet) # type: ignore
with ExitStack() as gs:
if do_self_attention_guidance:
gs.enter_context(self.unet.mid_block.attentions[0].register_forward_hook(get_map_size)) # type: ignore
@@ -598,18 +592,18 @@ def get_map_size(_, __, output):
# 9. Post-processing
if output_type == "latent":
+ unload_all()
return latents, False
# TODO: maybe implement asymmetric vqgan?
image = self._decode_latents(latents, height=height, width=width)
+ unload_all()
+
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
- if hasattr(self, "final_offload_hook"):
- self.final_offload_hook.offload() # type: ignore
-
if not return_dict:
return image, False
diff --git a/core/inference/pytorch/pytorch.py b/core/inference/pytorch/pytorch.py
index 6a4726ab6..84e80261e 100755
--- a/core/inference/pytorch/pytorch.py
+++ b/core/inference/pytorch/pytorch.py
@@ -67,10 +67,6 @@ def __init__(
self.text_encoder: CLIPTextModel
self.tokenizer: CLIPTokenizer
self.scheduler: Any
- self.feature_extractor: Any
- self.requires_safety_checker: bool
- self.safety_checker: Any
- self.image_encoder: Any
self.controlnet: Optional[ControlNetModel] = None
self.current_controlnet: str = ""
@@ -99,9 +95,6 @@ def load(self):
self.text_encoder = pipe.text_encoder # type: ignore
self.tokenizer = pipe.tokenizer # type: ignore
self.scheduler = pipe.scheduler # type: ignore
- self.feature_extractor = pipe.feature_extractor # type: ignore
- self.requires_safety_checker = False # type: ignore
- self.safety_checker = pipe.safety_checker # type: ignore
if not self.bare:
# Autoload textual inversions
@@ -131,18 +124,19 @@ def change_vae(self, vae: str) -> None:
setattr(self, "original_vae", self.vae)
old_vae = getattr(self, "original_vae")
+ dtype = self.unet.dtype
+ device = old_vae.device
if vae == "default":
self.vae = old_vae
else:
- # Why the fuck do you think that's constant pylint?
- # Are you mentally insane?
- if Path(vae).is_dir():
- self.vae = AutoencoderKL.from_pretrained(vae) # type: ignore
+ if "/" in vae or Path(vae).is_dir():
+ self.vae = AutoencoderKL.from_pretrained(vae).to( # type: ignore
+ device=device, dtype=dtype
+ )
else:
self.vae = convert_vaept_to_diffusers(vae).to(
- device=old_vae.device, dtype=old_vae.dtype
+ device=device, dtype=dtype
)
- # This is at the end 'cause I've read horror stories about pythons prefetch system
self.vae_path = vae
def unload(self) -> None:
@@ -154,9 +148,6 @@ def unload(self) -> None:
self.text_encoder,
self.tokenizer,
self.scheduler,
- self.feature_extractor,
- self.requires_safety_checker,
- self.safety_checker,
)
if hasattr(self, "image_encoder"):
@@ -573,9 +564,9 @@ def save(self, path: str = "converted", safetensors: bool = False):
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=self.scheduler,
- feature_extractor=self.feature_extractor,
- requires_safety_checker=self.requires_safety_checker,
- safety_checker=self.safety_checker,
+ feature_extractor=None, # type: ignore
+ requires_safety_checker=False,
+ safety_checker=None, # type: ignore
)
pipe.save_pretrained(path, safe_serialization=safetensors)
diff --git a/core/inference/pytorch/tiled_diffusion/__init__.py b/core/inference/pytorch/tiled_diffusion/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/core/inference/pytorch/tiled_diffusion/canvas.py b/core/inference/pytorch/tiled_diffusion/canvas.py
new file mode 100644
index 000000000..b99f13813
--- /dev/null
+++ b/core/inference/pytorch/tiled_diffusion/canvas.py
@@ -0,0 +1,473 @@
+# pylint: disable=attribute-defined-outside-init
+
+from dataclasses import dataclass
+from typing import Literal, Union, List, Optional, Callable
+from time import time
+import math
+
+from diffusers import (
+ DiffusionPipeline,
+ AutoencoderKL,
+ SchedulerMixin,
+ UNet2DConditionModel,
+)
+from transformers import CLIPTextModel, CLIPTokenizer
+import torch
+from tqdm import tqdm
+import numpy as np
+from numpy import pi, exp, sqrt
+from PIL.Image import Image
+
+from core.config import config
+from core.optimizations import autocast, upcast_vae, ensure_correct_device, unload_all
+from core.inference.utilities import preprocess_image, get_weighted_text_embeddings
+from core.utils import resize
+
+mask = Literal["constant", "gaussian", "quartic"]
+reroll = Literal["reset", "epsilon"]
+
+
+@dataclass
+class CanvasRegion:
+ """Class defining a region in the canvas."""
+
+ row_init: int
+ row_end: int
+ col_init: int
+ col_end: int
+ region_seed: int = None # type: ignore
+ noise_eps: float = 0.0
+
+ def __post_init__(self):
+ if self.region_seed is None:
+ self.region_seed = math.ceil(time())
+ coords = [self.row_init, self.row_end, self.col_init, self.col_end]
+ for coord in coords:
+ if coord < 0:
+ raise ValueError(
+ f"Region coordinates must be non-negative, found {coords}."
+ )
+ if coord % 8 != 0:
+ raise ValueError(
+ f"Region coordinates must be multiples of 8, found {coords}."
+ )
+ if self.noise_eps < 0:
+ raise ValueError(
+ f"Region epsilon must be non-negative, found {self.noise_eps}."
+ )
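+ # Pixel-space coordinates map to latent space at 1/8 resolution (the VAE downsampling factor).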
+ self.latent_row_init = self.row_init // 8
+ self.latent_row_end = self.row_end // 8
+ self.latent_col_init = self.col_init // 8
+ self.latent_col_end = self.col_end // 8
+
+ @property
+ def width(self):
+ "col_end - col_init"
+ return self.col_end - self.col_init
+
+ @property
+ def height(self):
+ "row_end - row_init"
+ return self.row_end - self.row_init
+
+ def get_region_generator(self, device: Union[torch.device, str] = "cpu"):
+ """Creates a generator for the region."""
+ return torch.Generator(device).manual_seed(self.region_seed)
+
+
+@dataclass
+class DiffusionRegion(CanvasRegion):
+ """Abstract class for places where diffusion is taking place."""
+
+
+@dataclass
+class RerollRegion(CanvasRegion):
+ """Class defining a region where latent rerolling will be taking place."""
+
+ reroll_mode: reroll = "reset"
+
+
+@dataclass
+class Text2ImageRegion(DiffusionRegion):
+ """Class defining a region where text guided diffusion will be taking place."""
+
+ prompt: str = ""
+ negative_prompt: str = ""
+ guidance_scale: float = 7.5
+ mask_type: mask = "gaussian"
+ mask_weight: float = 1.0
+
+ text_embeddings = None
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.mask_weight < 0:
+ raise ValueError(
+ f"Mask weight must be non-negative, found {self.mask_weight}."
+ )
+
+ @property
+ def do_classifier_free_guidance(self) -> bool:
+ """Whether to do classifier-free guidance (guidance_scale > 1.0)"""
+ return self.guidance_scale > 1.0
+
+
+@dataclass
+class Image2ImageRegion(DiffusionRegion):
+ """Class defining a region where image guided diffusion will be taking place."""
+
+ reference_image: Image = None # type: ignore
+ strength: float = 0.8
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.reference_image is None:
+ raise ValueError("Reference image must be provided.")
+ if self.strength < 0 or self.strength > 1:
+ raise ValueError(
+ f"Strength must be between 0 and 1, found {self.strength}."
+ )
+ self.reference_image = resize(self.reference_image, self.width, self.height)
+
+ def encode_reference_image(
+ self,
+ encoder,
+ generator: torch.Generator,
+ device: Union[torch.device, str] = "cpu",
+ ):
+ """Encodes the reference image for this diffusion region."""
+ ensure_correct_device(encoder) # type: ignore
+ img = preprocess_image(self.reference_image)
+ self.reference_latents = encoder.encode(img.to(device)).latent_dist.sample(
+ generator=generator
+ )
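+ # 0.18215 is the Stable Diffusion latent scaling factor.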
+ self.reference_latents = 0.18215 * self.reference_latents
+
+
+@dataclass
+class MaskWeightsBuilder:
+ """Auxiliary class to compute a tensor of weights for a given diffusion region."""
+
+ latent_space_dim: int
+ nbatch: int = 1
+
+ def compute_mask_weights(self, region: Text2ImageRegion) -> torch.Tensor:
+ """Computes a tensor of weights for the given diffusion region."""
+ if region.mask_type == "gaussian":
+ mask_weights = self._gaussian_weights(region)
+ elif region.mask_type == "constant":
+ mask_weights = self._constant_weights(region)
+ else:
+ mask_weights = self._quartic_weights(region)
+ return mask_weights
+
+ def _constant_weights(self, region: Text2ImageRegion) -> torch.Tensor:
+ """Computes a tensor of constant weights for the given diffusion region."""
+ return (
+ torch.ones(
+ (
+ self.nbatch,
+ self.latent_space_dim,
+ region.latent_row_end - region.latent_row_init,
+ region.latent_col_end - region.latent_col_init,
+ )
+ )
+ * region.mask_weight
+ )
+
+ def _gaussian_weights(self, region: Text2ImageRegion) -> torch.Tensor:
+ """Computes a tensor of gaussian weights for the given diffusion region."""
+ latent_width = region.latent_col_end - region.latent_col_init
+ latent_height = region.latent_row_end - region.latent_row_init
+
+ var = 0.01
+ midpoint = (
+ latent_width - 1
+ ) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(
+ -(x - midpoint)
+ * (x - midpoint)
+ / (latent_width * latent_width)
+ / (2 * var)
+ )
+ / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = (latent_height - 1) / 2
+ y_probs = [
+ exp(
+ -(y - midpoint)
+ * (y - midpoint)
+ / (latent_height * latent_height)
+ / (2 * var)
+ )
+ / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(
+ torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)
+ )
+
+ def _quartic_weights(self, region: Text2ImageRegion) -> torch.Tensor:
+ """Generates a quartic mask of weights for tile contributions
+
+ The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
+ """
+ quartic_constant = 15.0 / 16.0
+
+ support = (
+ np.array(range(region.latent_col_init, region.latent_col_end))
+ - region.latent_col_init
+ ) / (region.latent_col_end - region.latent_col_init - 1) * 1.99 - (1.99 / 2.0)
+ x_probs = quartic_constant * np.square(1 - np.square(support))
+ support = (
+ np.array(range(region.latent_row_init, region.latent_row_end))
+ - region.latent_row_init
+ ) / (region.latent_row_end - region.latent_row_init - 1) * 1.99 - (1.99 / 2.0)
+ y_probs = quartic_constant * np.square(1 - np.square(support))
+
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(
+ torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)
+ )
+
+
+class StableDiffusionCanvasPipeline(DiffusionPipeline):
+ """Stable Diffusion pipeline with support for multiple diffusions on one canvas."""
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: SchedulerMixin,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+
+ def _decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ ensure_correct_device(self.vae)
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ return self.numpy_to_pil(image)
+
+ def _get_timesteps(self, num_inference_steps: int, strength: float) -> torch.Tensor:
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * (1 - strength)) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = min(
+ max(num_inference_steps - init_timestep + offset, 0),
+ num_inference_steps - 1,
+ )
+ latest_timestep = self.scheduler.timesteps[t_start]
+ return latest_timestep
+
+ @property
+ def _execution_device(self):
+ # TODO: implement this from the SDXL PR
+ return self.unet.device
+
+ def __call__(
+ self,
+ height: int,
+ width: int,
+ regions: List[DiffusionRegion],
+ generator: torch.Generator,
+ num_inference_steps: int = 50,
+ reroll_regions: List[RerollRegion] = [],
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ) -> List[Image]:
+ batch_size = 1
+ device = self._execution_device
+ unet_channels = self.unet.config.in_channels
+
+ txt_regions = [r for r in regions if isinstance(r, Text2ImageRegion)]
+ img_regions = [r for r in regions if isinstance(r, Image2ImageRegion)]
+
+ latents_shape = (
+ batch_size,
+ unet_channels,
+ math.ceil(height / 8),
+ math.ceil(width / 8),
+ )
+
+ all_eps_rerolls = regions + [
+ r for r in reroll_regions if r.reroll_mode == "epsilon"
+ ]
+
+ with autocast(
+ dtype=self.unet.dtype,
+ disable=not config.api.autocast,
+ ):
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ for region in txt_regions:
+ ensure_correct_device(self.text_encoder)
+ prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings(
+ self, # type: ignore
+ region.prompt,
+ region.negative_prompt,
+ 3,
+ False,
+ False,
+ False,
+ )
+ if region.do_classifier_free_guidance:
+ region.text_embeddings = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) # type: ignore
+ else:
+ region.text_embeddings = prompt_embeds
+
+ init_noise = torch.randn(latents_shape, device=device, generator=generator)
+
+ for region in reroll_regions:
+ if region.reroll_mode == "reset":
+ latent_height = region.latent_row_end - region.latent_row_init
+ latent_width = region.latent_col_end - region.latent_col_init
+ region_shape = (
+ latents_shape[0],
+ latents_shape[1],
+ latent_height,
+ latent_width,
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = torch.randn(region_shape, device=device, generator=generator)
+
+ for region in all_eps_rerolls:
+ if region.noise_eps > 0:
+ region_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ eps_noise = (
+ torch.randn(
+ region_noise.shape,
+ generator=region.get_region_generator(device=device),
+ device=device,
+ )
+ * region.noise_eps
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += eps_noise
+
+ latents = init_noise * self.scheduler.init_noise_sigma
+
+ for region in img_regions:
+ region.encode_reference_image(
+ self.vae, generator=generator, device=device
+ )
+
+ mask_builder = MaskWeightsBuilder(unet_channels, batch_size)
+ mask_weights = [
+ mask_builder.compute_mask_weights(r).to(device=device)
+ for r in txt_regions
+ ]
+ ensure_correct_device(self.unet) # type: ignore
+ for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+ noise_pred_regions = []
+
+ for region in txt_regions:
+ region_latents = latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ latent_model_input = region_latents
+ if region.do_classifier_free_guidance:
+ latent_model_input = torch.cat([latent_model_input] * 2)
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=region.text_embeddings,
+ ).sample
+ if region.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_region = (
+ noise_pred_uncond
+ + region.guidance_scale
+ * (noise_pred_text - noise_pred_uncond)
+ )
+ noise_pred_regions.append(noise_pred_region)
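+ # Merge the per-region predictions: accumulate weighted noise per latent element, then normalise by the summed mask weights.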
+ noise_pred = torch.zeros(latents.shape, device=device)
+ contributors = torch.zeros(latents.shape, device=device)
+
+ for region, noise_pred_region, mask_weights_region in zip(
+ txt_regions, noise_pred_regions, mask_weights
+ ):
+ noise_pred[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += (
+ noise_pred_region * mask_weights_region
+ )
+ contributors[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += mask_weights_region
+ noise_pred /= contributors
+ noise_pred = torch.nan_to_num(noise_pred)
+
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
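+ # Until each image region's strength threshold is reached, keep its latents anchored to the noised reference latents.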
+ for region in img_regions:
+ influence_step = self._get_timesteps(
+ num_inference_steps, region.strength
+ )
+ if t > influence_step:
+ timestep = t.repeat(batch_size)
+ region_init_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ region_latents = self.scheduler.add_noise(
+ region.reference_latents, region_init_noise, timestep
+ )
+ latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = region_latents
+
+ if callback is not None:
+ if i % callback_steps == 0:
+ callback(i, t, latents)
+ image = self._decode_latents(latents)
+
+ unload_all()
+
+ return image
diff --git a/core/inference/utilities/latents.py b/core/inference/utilities/latents.py
index 84f91c5d7..cb814d938 100644
--- a/core/inference/utilities/latents.py
+++ b/core/inference/utilities/latents.py
@@ -12,6 +12,7 @@
from core.config import config
from core.flags import LatentScaleModel
+from core.optimizations import upcast_vae
logger = logging.getLogger(__name__)
@@ -256,8 +257,16 @@ def prepare_latents(
return latents, None, None
else:
if image.shape[1] != 4:
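+ # Match the image dtype to the (possibly upcast) VAE before encoding it into latents.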
+ if config.api.upcast_vae:
+ upcast_vae(pipe.vae)
+ image = image.to(
+ next(iter(pipe.vae.post_quant_conv.parameters())).dtype
+ )
+ else:
+ image = image.to(config.api.device, dtype=pipe.vae.dtype)
+
image = pad_tensor(image, pipe.vae_scale_factor)
- init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist # type: ignore
+ init_latent_dist = pipe.vae.encode(image).latent_dist # type: ignore
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
init_latents = torch.cat([init_latents] * batch_size, dim=0)
@@ -281,95 +290,14 @@ def prepare_latents(
shape, generator=generator, device="cpu", dtype=dtype
).to(device)
else:
- # Retarded fix, but hey, if it works, it works
- if hasattr(pipe.vae, "main_device"):
- noise = torch.randn(
- shape,
- generator=torch.Generator("cpu").manual_seed(1),
- device="cpu",
- dtype=dtype,
- ).to(device)
- else:
- noise = torch.randn(
- shape, generator=generator, device=device, dtype=dtype
- )
- # Now this... I may have called the previous "hack" retarded, but this...
- # This just takes it to a whole new level
+ noise = torch.randn(
+ shape, generator=generator, device=device, dtype=dtype
+ )
latents = pipe.scheduler.add_noise(init_latents.to(device), noise.to(device), timestep.to(device)) # type: ignore
return latents, init_latents_orig, noise
-def bislerp_original(samples, width, height):
- shape = list(samples.shape)
- width_scale = (shape[3]) / (width)
- height_scale = (shape[2]) / (height)
-
- shape[3] = width
- shape[2] = height
- out1 = torch.empty(
- shape, dtype=samples.dtype, layout=samples.layout, device=samples.device
- )
-
- def algorithm(in1, in2, t):
- dims = in1.shape
- val = t
-
- # flatten to batches
- low = in1.reshape(dims[0], -1)
- high = in2.reshape(dims[0], -1)
-
- low_weight = torch.norm(low, dim=1, keepdim=True)
- low_weight[low_weight == 0] = 0.0000000001
- low_norm = low / low_weight
- high_weight = torch.norm(high, dim=1, keepdim=True)
- high_weight[high_weight == 0] = 0.0000000001
- high_norm = high / high_weight
-
- dot_prod = (low_norm * high_norm).sum(1)
- dot_prod[dot_prod > 0.9995] = 0.9995
- dot_prod[dot_prod < -0.9995] = -0.9995
- omega = torch.acos(dot_prod)
- so = torch.sin(omega)
- res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low_norm + (
- torch.sin(val * omega) / so
- ).unsqueeze(1) * high_norm
- res *= low_weight * (1.0 - val) + high_weight * val
- return res.reshape(dims)
-
- for x_dest in range(shape[3]):
- for y_dest in range(shape[2]):
- y = (y_dest + 0.5) * height_scale - 0.5
- x = (x_dest + 0.5) * width_scale - 0.5
-
- x1 = max(math.floor(x), 0)
- x2 = min(x1 + 1, samples.shape[3] - 1)
- wx = x - math.floor(x)
-
- y1 = max(math.floor(y), 0)
- y2 = min(y1 + 1, samples.shape[2] - 1)
- wy = y - math.floor(y)
-
- in1 = samples[:, :, y1, x1]
- in2 = samples[:, :, y1, x2]
- in3 = samples[:, :, y2, x1]
- in4 = samples[:, :, y2, x2]
-
- if (x1 == x2) and (y1 == y2):
- out_value = in1
- elif x1 == x2:
- out_value = algorithm(in1, in3, wy)
- elif y1 == y2:
- out_value = algorithm(in1, in2, wx)
- else:
- o1 = algorithm(in1, in2, wx)
- o2 = algorithm(in3, in4, wx)
- out_value = algorithm(o1, o2, wy)
-
- out1[:, :, y_dest, x_dest] = out_value
- return out1
-
-
-def bislerp_gabeified(samples, width, height):
+def bislerp(samples, width, height):
device = samples.device
def slerp(b1, b2, r):
@@ -488,9 +416,7 @@ def scale_latents(
):
"Interpolate the latents to the desired scale."
- align_to = (
- 32 if latent_scale_mode in ["bislerp-tortured", "bislerp-original"] else 8
- )
+ align_to = 32 if latent_scale_mode == "bislerp" else 8
s = time()
@@ -503,10 +429,8 @@ def scale_latents(
height_truncated = int(((latents.shape[3] * scale - 1) // align_to + 1) * align_to)
# Scale the latents
- if latent_scale_mode == "bislerp-tortured":
- interpolated = bislerp_gabeified(latents, height_truncated, width_truncated)
- elif latent_scale_mode == "bislerp-original":
- interpolated = bislerp_original(latents, height_truncated, width_truncated)
+ if latent_scale_mode == "bislerp":
+ interpolated = bislerp(latents, height_truncated, width_truncated)
else:
interpolated = F.interpolate(
latents,
diff --git a/core/inference/utilities/lwp.py b/core/inference/utilities/lwp.py
index 43515d4db..d46576912 100644
--- a/core/inference/utilities/lwp.py
+++ b/core/inference/utilities/lwp.py
@@ -8,6 +8,7 @@
from ...config import config
from ...files import get_full_model_path
+from ...optimizations import ensure_correct_device
logger = logging.getLogger(__name__)
@@ -253,6 +254,7 @@ def get_unweighted_text_embeddings(
if hasattr(pipe, "clip_inference"):
text_embedding = pipe.clip_inference(text_input_chunk)
else:
+ ensure_correct_device(pipe.text_encoder)
text_embedding = pipe.text_encoder(text_input_chunk)[0] # type: ignore
if no_boseos_middle:
@@ -272,6 +274,7 @@ def get_unweighted_text_embeddings(
if hasattr(pipe, "clip_inference"):
text_embeddings = pipe.clip_inference(text_input)
else:
+ ensure_correct_device(pipe.text_encoder)
text_embeddings = pipe.text_encoder(text_input)[0] # type: ignore
return text_embeddings
diff --git a/core/optimizations/__init__.py b/core/optimizations/__init__.py
index 431337bba..c5d8f5831 100644
--- a/core/optimizations/__init__.py
+++ b/core/optimizations/__init__.py
@@ -1,8 +1,13 @@
from .autocast_utils import autocast, without_autocast
from .pytorch_optimizations import optimize_model
+from .offload import ensure_correct_device, unload_all
+from .upcast import upcast_vae
__all__ = [
"optimize_model",
"without_autocast",
"autocast",
+ "upcast_vae",
+ "ensure_correct_device",
+ "unload_all",
]
diff --git a/core/optimizations/offload.py b/core/optimizations/offload.py
new file mode 100644
index 000000000..8444ab699
--- /dev/null
+++ b/core/optimizations/offload.py
@@ -0,0 +1,39 @@
+# pylint: disable=global-statement
+
+import logging
+
+import torch
+
+from core.config import config
+
+logger = logging.getLogger(__name__)
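+# Tracks the single offload-managed module currently resident on the accelerator; all other tagged modules stay on the CPU.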
+_module: torch.nn.Module = None # type: ignore
+
+
+def unload_all():
+ global _module
+ if _module is not None:
+ _module.cpu()
+ _module = None # type: ignore
+
+
+def ensure_correct_device(module: torch.nn.Module):
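+ # Only modules tagged via set_offload() carry an "offload_device" attribute and take part in offloading.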
+ if hasattr(module, "offload_device"):
+ global _module
+
+ device = getattr(module, "offload_device", config.api.device)
+ logger.debug(f"Transferring {module.__class__.__name__} to {str(device)}.")
+
+ if _module is not None:
+ logger.debug(f"Transferring {_module.__class__.__name__} to cpu.")
+ _module.cpu()
+
+ module.to(device=device)
+ _module = module
+ else:
+ logger.debug(f"Don't need to do anything with {module.__class__.__name__}.")
+
+
+def set_offload(module, device: torch.device):
+ logger.debug(f"Offloading {module.__class__.__name__} to {str(device)}.")
+ setattr(module, "offload_device", device)
\ No newline at end of file
diff --git a/core/optimizations/pytorch_optimizations.py b/core/optimizations/pytorch_optimizations.py
index a34472ba0..c2a5f0be5 100644
--- a/core/optimizations/pytorch_optimizations.py
+++ b/core/optimizations/pytorch_optimizations.py
@@ -7,13 +7,12 @@
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
-from diffusers.utils import is_accelerate_available
from core.config import config
-from core.files import get_full_model_path
from .attn import set_attention_processor
from .trace_utils import generate_inputs, trace_model
+from .offload import set_offload
logger = logging.getLogger(__name__)
@@ -48,22 +47,24 @@ def optimize_model(
'You can disable it by going inside Graphics Settings → "Default Graphics Settings" and disabling "Hardware-accelerated GPU Scheduling"'
)
- offload = (
- config.api.offload
- if (is_pytorch_pipe(pipe) and not is_for_aitemplate)
- else None
+ offload = config.api.offload and is_pytorch_pipe(pipe) and not is_for_aitemplate
+ can_offload = (
+ config.api.device_type
+ not in [
+ "cpu",
+ "vulkan",
+ "mps",
+ ]
+ and offload
)
- can_offload = config.api.device_type not in [
- "cpu",
- "vulkan",
- "mps",
- ] and (offload != "disabled" and offload is not None)
# Took me an hour to understand why CPU stopped working...
# Turns out AMD just lacks support for BF16...
# Not mad, not mad at all... to be fair, I'm just disappointed
if not can_offload and not is_for_aitemplate:
pipe.to(device, torch_dtype=config.api.dtype)
+ else:
+ pipe.to(torch_dtype=config.api.dtype)
if config.api.device_type == "cuda" and not is_for_aitemplate:
supports_tf = supports_tf32(device)
@@ -123,80 +124,27 @@ def optimize_model(
logger.info("Optimization: Enabled autocast")
if can_offload:
- if not is_accelerate_available():
- logger.warning(
- "Optimization: Offload is not available, because accelerate is not installed"
- )
- else:
- if offload == "model":
- # Offload to CPU
- from accelerate import cpu_offload_with_hook
-
- if config.api.device_type == "cuda":
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
-
- hook = None
-
- for cpu_offloaded_model in [
- pipe.text_encoder,
- pipe.unet,
- pipe.vae,
- ]:
- _, hook = cpu_offload_with_hook(
- cpu_offloaded_model, device, prev_module_hook=hook
- )
- pipe.final_offload_hook = hook
- setattr(pipe.vae, "main_device", True)
- setattr(pipe.unet, "main_device", True)
- logger.info("Optimization: Offloaded model parts to CPU.")
-
- elif offload == "module":
- # Enable sequential offload
- from accelerate import cpu_offload, disk_offload
-
- for m in [
- pipe.vae,
- pipe.unet,
- ]:
- if USE_DISK_OFFLOAD:
- # If USE_DISK_OFFLOAD toggle set (idk why anyone would do this, but it's nice to support stuff
- # like this in case anyone wants to try running this on fuck knows what)
- # then offload to disk.
- disk_offload(
- m,
- str(
- get_full_model_path("offload-dir", model_folder="temp")
- / m.__name__
- ),
- device,
- offload_buffers=True,
- )
- else:
- cpu_offload(m, device, offload_buffers=True)
-
- logger.info("Optimization: Enabled sequential offload")
+ # Offload to CPU
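+ # Tag each submodule with its target device; ensure_correct_device() moves it onto the accelerator on demand and unload_all() returns it to the CPU afterwards.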
+
+ for model_name in [
+ "text_encoder",
+ "text_encoder2",
+ "unet",
+ "vae",
+ ]:
+ cpu_offloaded_model = getattr(pipe, model_name, None)
+ if cpu_offloaded_model is not None:
+ set_offload(cpu_offloaded_model, device)
+ setattr(pipe, model_name, cpu_offloaded_model)
+ logger.info("Optimization: Offloaded model parts to CPU.")
if config.api.vae_slicing:
- if not (
- issubclass(pipe.__class__, StableDiffusionUpscalePipeline)
- or isinstance(pipe, StableDiffusionUpscalePipeline)
- ):
- pipe.enable_vae_slicing()
- logger.info("Optimization: Enabled VAE slicing")
- else:
- logger.debug(
- "Optimization: VAE slicing is not available for upscale models"
- )
+ pipe.enable_vae_slicing()
+ logger.info("Optimization: Enabled VAE slicing")
if config.api.vae_tiling:
- if not (
- issubclass(pipe.__class__, StableDiffusionUpscalePipeline)
- or isinstance(pipe, StableDiffusionUpscalePipeline)
- ):
- pipe.enable_vae_tiling()
- logger.info("Optimization: Enabled VAE tiling")
- else:
- logger.debug("Optimization: VAE tiling is not available for upscale models")
+ pipe.enable_vae_tiling()
+ logger.info("Optimization: Enabled VAE tiling")
if config.api.use_tomesd and not is_for_aitemplate:
try:
diff --git a/core/optimizations/upcast.py b/core/optimizations/upcast.py
new file mode 100644
index 000000000..dccd7d0b2
--- /dev/null
+++ b/core/optimizations/upcast.py
@@ -0,0 +1,26 @@
+import logging
+
+from diffusers.models import AutoencoderKL
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def upcast_vae(vae: AutoencoderKL):
+ dtype = vae.dtype
+ logger.debug('Upcasting VAE to FP32 (vae["force_upcast"] OR config.api.upcast_vae)')
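+ # Decoding in float16 can overflow to NaNs (black images) with some VAEs, so run the VAE itself in float32.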
+ vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ vae.decoder.mid_block.attentions[0].processor, # type: ignore
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ vae.post_quant_conv.to(dtype=dtype)
+ vae.decoder.conv_in.to(dtype=dtype)
+ vae.decoder.mid_block.to(dtype=dtype) # type: ignore
\ No newline at end of file
diff --git a/frontend/dist/assets/SettingsView.js b/frontend/dist/assets/SettingsView.js
index 5b51a3cda..9ecba43a0 100644
--- a/frontend/dist/assets/SettingsView.js
+++ b/frontend/dist/assets/SettingsView.js
@@ -2244,21 +2244,7 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({
}),
createVNode(unref(NFormItem), { label: "Offload" }, {
default: withCtx(() => [
- createVNode(unref(NSelect), {
- options: [
- {
- value: "disabled",
- label: "Disabled"
- },
- {
- value: "model",
- label: "Offload the whole model to RAM when not used"
- },
- {
- value: "module",
- label: "Offload individual modules to RAM when not used"
- }
- ],
+ createVNode(unref(NSwitch), {
value: unref(settings).defaultSettings.api.offload,
"onUpdate:value": _cache[24] || (_cache[24] = ($event) => unref(settings).defaultSettings.api.offload = $event)
}, null, 8, ["value"])
@@ -2495,12 +2481,8 @@ const _sfc_main$b = /* @__PURE__ */ defineComponent({
value: "bicubic"
},
{
- label: "Bislerp (Original, slow)",
- value: "bislerp-original"
- },
- {
- label: "Bislerp (Tortured, fast)",
- value: "bislerp-tortured"
+ label: "Bislerp",
+ value: "bislerp"
},
{
label: "Nearest Exact",
@@ -2509,7 +2491,7 @@ const _sfc_main$b = /* @__PURE__ */ defineComponent({
],
value: unref(settings).defaultSettings.extra.highres.latent_scale_mode,
"onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).defaultSettings.extra.highres.latent_scale_mode = $event)
- }, null, 8, ["options", "value"])
+ }, null, 8, ["value"])
]),
_: 1
}),
diff --git a/frontend/dist/assets/index.js b/frontend/dist/assets/index.js
index 2da292cc9..b16ef051d 100644
--- a/frontend/dist/assets/index.js
+++ b/frontend/dist/assets/index.js
@@ -40646,7 +40646,7 @@ const defaultSettings = {
vae_tiling: false,
trace_model: false,
cudnn_benchmark: false,
- offload: "disabled",
+ offload: false,
device_id: 0,
device_type: "cuda",
data_type: "float16",
diff --git a/frontend/src/components/settings/APISettings.vue b/frontend/src/components/settings/APISettings.vue
index ede89a6fb..9faf3c49c 100644
--- a/frontend/src/components/settings/APISettings.vue
+++ b/frontend/src/components/settings/APISettings.vue
@@ -209,24 +209,7 @@