diff --git a/.vscode/settings.json b/.vscode/settings.json index 97fafffbe..81f679c3e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,11 +1,14 @@ { - "python.linting.pylintEnabled": true, - "python.linting.enabled": true, - "python.testing.pytestArgs": ["."], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.analysis.typeCheckingMode": "basic", - "python.formatting.provider": "black", - "python.languageServer": "Pylance", - "rust-analyzer.linkedProjects": ["./manager/Cargo.toml"] + "python.linting.pylintEnabled": true, + "python.linting.enabled": true, + "python.testing.pytestArgs": ["."], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.analysis.typeCheckingMode": "basic", + "python.formatting.provider": "none", + "python.languageServer": "Pylance", + "rust-analyzer.linkedProjects": ["./manager/Cargo.toml"], + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + } } diff --git a/api/routes/settings.py b/api/routes/settings.py index 76805a060..27f8ed403 100644 --- a/api/routes/settings.py +++ b/api/routes/settings.py @@ -4,7 +4,7 @@ from fastapi import APIRouter from core import config -from core.config.config import update_config +from core.config._config import update_config router = APIRouter(tags=["settings"]) diff --git a/core/config/__init__.py b/core/config/__init__.py index 2d52a925a..780069840 100644 --- a/core/config/__init__.py +++ b/core/config/__init__.py @@ -2,7 +2,7 @@ from diffusers.utils.constants import DIFFUSERS_CACHE -from .config import ( +from ._config import ( Configuration, Img2ImgConfig, Txt2ImgConfig, diff --git a/core/config/_config.py b/core/config/_config.py new file mode 100644 index 000000000..b01d704ff --- /dev/null +++ b/core/config/_config.py @@ -0,0 +1,75 @@ +import logging +from dataclasses import Field, dataclass, field, fields + +from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json + +from .api_settings import APIConfig +from .bot_settings import BotConfig +from .default_settings import ( + Txt2ImgConfig, + Img2ImgConfig, + InpaintingConfig, + ControlNetConfig, + UpscaleConfig, + AITemplateConfig, + ONNXConfig, +) +from .frontend_settings import FrontendConfig +from .interrogator_settings import InterrogatorConfig + +logger = logging.getLogger(__name__) + + +@dataclass_json(undefined=Undefined.INCLUDE) +@dataclass +class Configuration(DataClassJsonMixin): + "Main configuration class for the application" + + txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig) + img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig) + inpainting: InpaintingConfig = field(default_factory=InpaintingConfig) + controlnet: ControlNetConfig = field(default_factory=ControlNetConfig) + upscale: UpscaleConfig = field(default_factory=UpscaleConfig) + api: APIConfig = field(default_factory=APIConfig) + interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig) + aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig) + onnx: ONNXConfig = field(default_factory=ONNXConfig) + bot: BotConfig = field(default_factory=BotConfig) + frontend: FrontendConfig = field(default_factory=FrontendConfig) + extra: CatchAll = field(default_factory=dict) + + +def save_config(config: Configuration): + "Save the configuration to a file" + + logger.info("Saving configuration to data/settings.json") + + with open("data/settings.json", "w", encoding="utf-8") as f: + 
f.write(config.to_json(ensure_ascii=False, indent=4)) + + +def update_config(config: Configuration, new_config: Configuration): + "Update the configuration with new values instead of overwriting the pointer" + + for cls_field in fields(new_config): + assert isinstance(cls_field, Field) + setattr(config, cls_field.name, getattr(new_config, cls_field.name)) + + +def load_config(): + "Load the configuration from a file" + + logger.info("Loading configuration from data/settings.json") + + try: + with open("data/settings.json", "r", encoding="utf-8") as f: + config = Configuration.from_json(f.read()) + logger.info("Configuration loaded from data/settings.json") + return config + + except FileNotFoundError: + logger.info("data/settings.json not found, creating a new one") + config = Configuration() + save_config(config) + logger.info("Configuration saved to data/settings.json") + return config \ No newline at end of file diff --git a/core/config/api_settings.py b/core/config/api_settings.py new file mode 100644 index 000000000..bb7491a09 --- /dev/null +++ b/core/config/api_settings.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass, field +from typing import List, Literal, Union + +import torch + + +@dataclass +class APIConfig: + "Configuration for the API" + + # Autoload + autoloaded_textual_inversions: List[str] = field(default_factory=list) + + # Websockets and intervals + websocket_sync_interval: float = 0.02 + websocket_perf_interval: float = 1.0 + + # TomeSD + use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit + tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts + tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1 + + image_preview_delay: float = 2.0 + + # General optimizations + autocast: bool = False + attention_processor: Literal[ + "xformers", "sdpa", "cross-attention", "subquadratic", "multihead" + ] = "sdpa" + subquadratic_size: int = 512 + attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled" + channels_last: bool = True + vae_slicing: bool = True + vae_tiling: bool = False + trace_model: bool = False + clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always" + offload: bool = False + data_type: Literal["float32", "float16", "bfloat16"] = "float16" + + # CUDA specific optimizations + reduced_precision: bool = False + cudnn_benchmark: bool = False + deterministic_generation: bool = False + + # Device settings + device_id: int = 0 + device_type: Literal["cpu", "cuda", "mps", "directml", "intel", "vulkan"] = "cuda" + + # Critical + enable_shutdown: bool = True + + # VAE + upcast_vae: bool = False + + # CLIP + clip_skip: int = 1 + clip_quantization: Literal["full", "int8", "int4"] = "full" + + huggingface_style_parsing: bool = False + + # Saving + save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}" + image_extension: Literal["png", "webp", "jpeg"] = "png" + image_quality: int = 95 + image_return_format: Literal["bytes", "base64"] = "base64" + + # Grid + disable_grid: bool = False + + # Torch compile + torch_compile: bool = False + torch_compile_fullgraph: bool = False + torch_compile_dynamic: bool = False + torch_compile_backend: str = "inductor" + torch_compile_mode: Literal[ + "default", + "reduce-overhead", + "max-autotune", + ] = "reduce-overhead" + + @property + def dtype(self): + "Return selected data type" + if self.data_type == "bfloat16": + return torch.bfloat16 + if self.data_type == "float16": + return 
torch.float16 + return torch.float32 + + @property + def device(self): + "Return the device" + + if self.device_type == "intel": + from core.inference.functions import is_ipex_available + + return torch.device("xpu" if is_ipex_available() else "cpu") + + if self.device_type in ["cpu", "mps"]: + return torch.device(self.device_type) + + if self.device_type in ["vulkan", "cuda"]: + return torch.device(f"{self.device_type}:{self.device_id}") + + if self.device_type == "directml": + import torch_directml # pylint: disable=import-error + + return torch_directml.device() + else: + raise ValueError(f"Device type {self.device_type} not supported") \ No newline at end of file diff --git a/core/config/bot_settings.py b/core/config/bot_settings.py new file mode 100644 index 000000000..3db429d58 --- /dev/null +++ b/core/config/bot_settings.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers + + +@dataclass +class BotConfig: + "Configuration for the bot" + + default_scheduler: KarrasDiffusionSchedulers = ( + KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler + ) + verbose: bool = False + use_default_negative_prompt: bool = True \ No newline at end of file diff --git a/core/config/config.py b/core/config/config.py deleted file mode 100644 index d2dad6baf..000000000 --- a/core/config/config.py +++ /dev/null @@ -1,316 +0,0 @@ -import logging -import multiprocessing -from dataclasses import Field, dataclass, field, fields -from typing import List, Literal, Optional, Union - -import torch -from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers - -logger = logging.getLogger(__name__) - - -@dataclass -class QuantDict: - vae_decoder: Optional[bool] = None - vae_encoder: Optional[bool] = None - unet: Optional[bool] = None - text_encoder: Optional[bool] = None - - -@dataclass -class Txt2ImgConfig: - "Configuration for the text to image pipeline" - - width: int = 512 - height: int = 512 - seed: int = -1 - cfg_scale: int = 7 - sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value - prompt: str = "" - negative_prompt: str = "" - steps: int = 40 - batch_count: int = 1 - batch_size: int = 1 - self_attention_scale: float = 0.0 - - -@dataclass -class Img2ImgConfig: - "Configuration for the image to image pipeline" - - width: int = 512 - height: int = 512 - seed: int = -1 - cfg_scale: int = 7 - sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value - prompt: str = "" - negative_prompt: str = "" - steps: int = 40 - batch_count: int = 1 - batch_size: int = 1 - resize_method: int = 0 - denoising_strength: float = 0.6 - self_attention_scale: float = 0.0 - - -@dataclass -class InpaintingConfig: - "Configuration for the inpainting pipeline" - - prompt: str = "" - negative_prompt: str = "" - width: int = 512 - height: int = 512 - steps: int = 40 - cfg_scale: int = 7 - seed: int = -1 - batch_count: int = 1 - batch_size: int = 1 - sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value - self_attention_scale: float = 0.0 - - -@dataclass -class ControlNetConfig: - "Configuration for the inpainting pipeline" - - prompt: str = "" - negative_prompt: str = "" - width: int = 512 - height: int = 512 - seed: int = -1 - cfg_scale: int = 7 - steps: int = 40 - batch_count: int = 1 - batch_size: int = 1 - sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value - 
controlnet: str = "lllyasviel/sd-controlnet-canny" - controlnet_conditioning_scale: float = 1.0 - detection_resolution: int = 512 - is_preprocessed: bool = False - save_preprocessed: bool = False - return_preprocessed: bool = True - - -@dataclass -class UpscaleConfig: - "Configuration for the RealESRGAN upscaler" - - model: str = "RealESRGAN_x4plus_anime_6B" - upscale_factor: int = 4 - tile_size: int = field(default=128) - tile_padding: int = field(default=10) - - -@dataclass -class APIConfig: - "Configuration for the API" - - # Websockets and intervals - websocket_sync_interval: float = 0.02 - websocket_perf_interval: float = 1.0 - - # TomeSD - use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit - tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts - tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1 - - image_preview_delay: float = 2.0 - - # General optimizations - autocast: bool = False - attention_processor: Literal[ - "xformers", "sdpa", "cross-attention", "subquadratic", "multihead" - ] = "sdpa" - subquadratic_size: int = 512 - attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled" - channels_last: bool = True - vae_slicing: bool = True - vae_tiling: bool = False - trace_model: bool = False - clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always" - offload: Literal["module", "model", "disabled"] = "disabled" - data_type: Literal["float32", "float16", "bfloat16"] = "float16" - - # CUDA specific optimizations - reduced_precision: bool = False - cudnn_benchmark: bool = False - deterministic_generation: bool = False - - # Device settings - device_id: int = 0 - device_type: Literal["cpu", "cuda", "mps", "directml", "intel", "vulkan"] = "cuda" - - # Critical - enable_shutdown: bool = True - - # CLIP - clip_skip: int = 1 - clip_quantization: Literal["full", "int8", "int4"] = "full" - - # Autoload - autoloaded_textual_inversions: List[str] = field(default_factory=list) - - huggingface_style_parsing: bool = False - - # Saving - save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}" - image_extension: Literal["png", "webp", "jpeg"] = "png" - image_quality: int = 95 - image_return_format: Literal["bytes", "base64"] = "base64" - - # Grid - disable_grid: bool = False - - # Torch compile - torch_compile: bool = False - torch_compile_fullgraph: bool = False - torch_compile_dynamic: bool = False - torch_compile_backend: str = "inductor" - torch_compile_mode: Literal[ - "default", - "reduce-overhead", - "max-autotune", - ] = "reduce-overhead" - - @property - def dtype(self): - "Return selected data type" - if self.data_type == "bfloat16": - return torch.bfloat16 - if self.data_type == "float16": - return torch.float16 - return torch.float32 - - @property - def device(self): - "Return the device" - - if self.device_type == "intel": - from core.inference.functions import is_ipex_available - - return torch.device("xpu" if is_ipex_available() else "cpu") - - if self.device_type in ["cpu", "mps"]: - return torch.device(self.device_type) - - if self.device_type in ["vulkan", "cuda"]: - return torch.device(f"{self.device_type}:{self.device_id}") - - if self.device_type == "directml": - import torch_directml # pylint: disable=import-error - - return torch_directml.device() - else: - raise ValueError(f"Device type {self.device_type} not supported") - - -@dataclass -class AITemplateConfig: - "Configuration for model inference and acceleration" - - 
num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8)) - - -@dataclass -class ONNXConfig: - "Configuration for ONNX acceleration" - - quant_dict: QuantDict = field(default_factory=QuantDict) - - -@dataclass -class BotConfig: - "Configuration for the bot" - - default_scheduler: KarrasDiffusionSchedulers = ( - KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler - ) - verbose: bool = False - use_default_negative_prompt: bool = True - - -@dataclass -class InterrogatorConfig: - "Configuration for interrogation models" - - # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram - caption_model: str = "Salesforce/blip-image-captioning-large" - visualizer_model: str = "ViT-L-14/openai" - - offload_captioner: bool = False - offload_visualizer: bool = False - - chunk_size: int = 2048 # set to 1024 for lower vram usage - flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage - - flamingo_model: str = "dhansmair/flamingo-mini" - - caption_max_length: int = 32 - - -@dataclass -class FrontendConfig: - "Configuration for the frontend" - - theme: Literal["dark", "light"] = "dark" - enable_theme_editor: bool = False - image_browser_columns: int = 5 - on_change_timer: int = 0 - nsfw_ok_threshold: int = 0 - - -@dataclass_json(undefined=Undefined.INCLUDE) -@dataclass -class Configuration(DataClassJsonMixin): - "Main configuration class for the application" - - txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig) - img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig) - inpainting: InpaintingConfig = field(default_factory=InpaintingConfig) - controlnet: ControlNetConfig = field(default_factory=ControlNetConfig) - upscale: UpscaleConfig = field(default_factory=UpscaleConfig) - api: APIConfig = field(default_factory=APIConfig) - interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig) - aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig) - onnx: ONNXConfig = field(default_factory=ONNXConfig) - bot: BotConfig = field(default_factory=BotConfig) - frontend: FrontendConfig = field(default_factory=FrontendConfig) - extra: CatchAll = field(default_factory=dict) - - -def save_config(config: Configuration): - "Save the configuration to a file" - - logger.info("Saving configuration to data/settings.json") - - with open("data/settings.json", "w", encoding="utf-8") as f: - f.write(config.to_json(ensure_ascii=False, indent=4)) - - -def update_config(config: Configuration, new_config: Configuration): - "Update the configuration with new values instead of overwriting the pointer" - - for cls_field in fields(new_config): - assert isinstance(cls_field, Field) - setattr(config, cls_field.name, getattr(new_config, cls_field.name)) - - -def load_config(): - "Load the configuration from a file" - - logger.info("Loading configuration from data/settings.json") - - try: - with open("data/settings.json", "r", encoding="utf-8") as f: - config = Configuration.from_json(f.read()) - logger.info("Configuration loaded from data/settings.json") - return config - - except FileNotFoundError: - logger.info("data/settings.json not found, creating a new one") - config = Configuration() - save_config(config) - logger.info("Configuration saved to data/settings.json") - return config diff --git a/core/config/default_settings.py b/core/config/default_settings.py new file mode 100644 index 000000000..3f9b51209 --- /dev/null +++ b/core/config/default_settings.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass, field +import 
multiprocessing +from typing import Optional + +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers + + +@dataclass +class QuantDict: + "Configuration for ONNX quantization" + + vae_decoder: Optional[bool] = None + vae_encoder: Optional[bool] = None + unet: Optional[bool] = None + text_encoder: Optional[bool] = None + + +@dataclass +class Txt2ImgConfig: + "Configuration for the text to image pipeline" + + width: int = 512 + height: int = 512 + seed: int = -1 + cfg_scale: int = 7 + sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value + prompt: str = "" + negative_prompt: str = "" + steps: int = 40 + batch_count: int = 1 + batch_size: int = 1 + self_attention_scale: float = 0.0 + + +@dataclass +class Img2ImgConfig: + "Configuration for the image to image pipeline" + + width: int = 512 + height: int = 512 + seed: int = -1 + cfg_scale: int = 7 + sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value + prompt: str = "" + negative_prompt: str = "" + steps: int = 40 + batch_count: int = 1 + batch_size: int = 1 + resize_method: int = 0 + denoising_strength: float = 0.6 + self_attention_scale: float = 0.0 + + +@dataclass +class InpaintingConfig: + "Configuration for the inpainting pipeline" + + prompt: str = "" + negative_prompt: str = "" + width: int = 512 + height: int = 512 + steps: int = 40 + cfg_scale: int = 7 + seed: int = -1 + batch_count: int = 1 + batch_size: int = 1 + sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value + self_attention_scale: float = 0.0 + + +@dataclass +class ControlNetConfig: + "Configuration for the inpainting pipeline" + + prompt: str = "" + negative_prompt: str = "" + width: int = 512 + height: int = 512 + seed: int = -1 + cfg_scale: int = 7 + steps: int = 40 + batch_count: int = 1 + batch_size: int = 1 + sampler: int = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value + controlnet: str = "lllyasviel/sd-controlnet-canny" + controlnet_conditioning_scale: float = 1.0 + detection_resolution: int = 512 + is_preprocessed: bool = False + save_preprocessed: bool = False + return_preprocessed: bool = True + + +@dataclass +class UpscaleConfig: + "Configuration for the RealESRGAN upscaler" + + model: str = "RealESRGAN_x4plus_anime_6B" + upscale_factor: int = 4 + tile_size: int = field(default=128) + tile_padding: int = field(default=10) + + +@dataclass +class AITemplateConfig: + "Configuration for model inference and acceleration" + + num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8)) + + +@dataclass +class ONNXConfig: + "Configuration for ONNX acceleration" + + quant_dict: QuantDict = field(default_factory=QuantDict) \ No newline at end of file diff --git a/core/config/frontend_settings.py b/core/config/frontend_settings.py new file mode 100644 index 000000000..20ceb13b4 --- /dev/null +++ b/core/config/frontend_settings.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class FrontendConfig: + "Configuration for the frontend" + + theme: Literal["dark", "light"] = "dark" + enable_theme_editor: bool = False + image_browser_columns: int = 5 + on_change_timer: int = 0 + nsfw_ok_threshold: int = 0 \ No newline at end of file diff --git a/core/config/interrogator_settings.py b/core/config/interrogator_settings.py new file mode 100644 index 000000000..eb94aa4fd --- /dev/null +++ b/core/config/interrogator_settings.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + + +@dataclass +class InterrogatorConfig: 
+ "Configuration for interrogation models" + + # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram + caption_model: str = "Salesforce/blip-image-captioning-large" + visualizer_model: str = "ViT-L-14/openai" + + offload_captioner: bool = False + offload_visualizer: bool = False + + chunk_size: int = 2048 # set to 1024 for lower vram usage + flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage + + flamingo_model: str = "dhansmair/flamingo-mini" + + caption_max_length: int = 32 \ No newline at end of file diff --git a/core/files.py b/core/files.py index 71d401690..5134bcb98 100644 --- a/core/files.py +++ b/core/files.py @@ -82,7 +82,10 @@ def pytorch(self) -> List[ModelResponse]: state="not loaded", ) ) - elif ".safetensors" in model_name or ".ckpt" in model_name: + elif (self.checkpoint_converted_path / model_name).suffix in [ + ".ckpt", + ".safetensors", + ]: # Assuming that model is in Checkpoint / Safetensors format models.append( ModelResponse( diff --git a/core/flags.py b/core/flags.py index 0ce4506ec..690091fc9 100644 --- a/core/flags.py +++ b/core/flags.py @@ -7,8 +7,7 @@ "nearest", "area", "bilinear", - "bislerp-original", - "bislerp-tortured", + "bislerp", "bicubic", "nearest-exact", ] diff --git a/core/inference/pytorch/pipeline.py b/core/inference/pytorch/pipeline.py index 3ce27f49c..c77d74158 100644 --- a/core/inference/pytorch/pipeline.py +++ b/core/inference/pytorch/pipeline.py @@ -24,7 +24,7 @@ prepare_mask_latents, preprocess_image, ) -from core.optimizations import autocast +from core.optimizations import autocast, upcast_vae, ensure_correct_device, unload_all from .sag import CrossAttnStoreProcessor, pred_epsilon, pred_x0, sag_masking @@ -108,27 +108,13 @@ def _execution_device(self): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): # type: ignore + if self.device != torch.device("meta") and not hasattr(self.unet, "offload_device"): # type: ignore return self.device - for module in self.unet.modules(): # type: ignore - if ( - hasattr(module, "_hf_hook") - and hasattr( - module._hf_hook, # pylint: disable=protected-access - "execution_device", - ) - and module._hf_hook.execution_device # pylint: disable=protected-access # type: ignore - is not None - ): - return torch.device( - module._hf_hook.execution_device # pylint: disable=protected-access # type: ignore - ) - return self.device + return getattr(self.unet, "offload_device", self.device) def _encode_prompt( self, prompt, - _device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, @@ -171,6 +157,7 @@ def _encode_prompt( " the batch size of `prompt`." ) + ensure_correct_device(self.text_encoder) text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self.parent, prompt=prompt, @@ -214,7 +201,14 @@ def _check_inputs(self, prompt, strength, callback_steps): ) def _decode_latents(self, latents, height, width): + if config.api.upcast_vae: + upcast_vae(self.vae) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) + latents = 1 / 0.18215 * latents + ensure_correct_device(self.vae) image = self.vae.decode(latents).sample # type: ignore image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -356,7 +350,6 @@ def __call__( # 3. 
Encode input prompt text_embeddings = self._encode_prompt( prompt, - device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, @@ -454,6 +447,7 @@ def get_map_size(_, __, output): ] # output.sample.shape[-2:] in older diffusers # 8. Denoising loop + ensure_correct_device(self.unet) # type: ignore with ExitStack() as gs: if do_self_attention_guidance: gs.enter_context(self.unet.mid_block.attentions[0].register_forward_hook(get_map_size)) # type: ignore @@ -598,18 +592,18 @@ def get_map_size(_, __, output): # 9. Post-processing if output_type == "latent": + unload_all() return latents, False # TODO: maybe implement asymmetric vqgan? image = self._decode_latents(latents, height=height, width=width) + unload_all() + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) - if hasattr(self, "final_offload_hook"): - self.final_offload_hook.offload() # type: ignore - if not return_dict: return image, False diff --git a/core/inference/pytorch/pytorch.py b/core/inference/pytorch/pytorch.py index 6a4726ab6..84e80261e 100755 --- a/core/inference/pytorch/pytorch.py +++ b/core/inference/pytorch/pytorch.py @@ -67,10 +67,6 @@ def __init__( self.text_encoder: CLIPTextModel self.tokenizer: CLIPTokenizer self.scheduler: Any - self.feature_extractor: Any - self.requires_safety_checker: bool - self.safety_checker: Any - self.image_encoder: Any self.controlnet: Optional[ControlNetModel] = None self.current_controlnet: str = "" @@ -99,9 +95,6 @@ def load(self): self.text_encoder = pipe.text_encoder # type: ignore self.tokenizer = pipe.tokenizer # type: ignore self.scheduler = pipe.scheduler # type: ignore - self.feature_extractor = pipe.feature_extractor # type: ignore - self.requires_safety_checker = False # type: ignore - self.safety_checker = pipe.safety_checker # type: ignore if not self.bare: # Autoload textual inversions @@ -131,18 +124,19 @@ def change_vae(self, vae: str) -> None: setattr(self, "original_vae", self.vae) old_vae = getattr(self, "original_vae") + dtype = self.unet.dtype + device = old_vae.device if vae == "default": self.vae = old_vae else: - # Why the fuck do you think that's constant pylint? - # Are you mentally insane? 
- if Path(vae).is_dir(): - self.vae = AutoencoderKL.from_pretrained(vae) # type: ignore + if "/" in vae or Path(vae).is_dir(): + self.vae = AutoencoderKL.from_pretrained(vae).to( # type: ignore + device=device, dtype=dtype + ) else: self.vae = convert_vaept_to_diffusers(vae).to( - device=old_vae.device, dtype=old_vae.dtype + device=device, dtype=dtype ) - # This is at the end 'cause I've read horror stories about pythons prefetch system self.vae_path = vae def unload(self) -> None: @@ -154,9 +148,6 @@ def unload(self) -> None: self.text_encoder, self.tokenizer, self.scheduler, - self.feature_extractor, - self.requires_safety_checker, - self.safety_checker, ) if hasattr(self, "image_encoder"): @@ -573,9 +564,9 @@ def save(self, path: str = "converted", safetensors: bool = False): text_encoder=self.text_encoder, tokenizer=self.tokenizer, scheduler=self.scheduler, - feature_extractor=self.feature_extractor, - requires_safety_checker=self.requires_safety_checker, - safety_checker=self.safety_checker, + feature_extractor=None, # type: ignore + requires_safety_checker=False, + safety_checker=None, # type: ignore ) pipe.save_pretrained(path, safe_serialization=safetensors) diff --git a/core/inference/pytorch/tiled_diffusion/__init__.py b/core/inference/pytorch/tiled_diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/inference/pytorch/tiled_diffusion/canvas.py b/core/inference/pytorch/tiled_diffusion/canvas.py new file mode 100644 index 000000000..b99f13813 --- /dev/null +++ b/core/inference/pytorch/tiled_diffusion/canvas.py @@ -0,0 +1,473 @@ +# pylint: disable=attribute-defined-outside-init + +from dataclasses import dataclass +from typing import Literal, Union, List, Optional, Callable +from time import time +import math + +from diffusers import ( + DiffusionPipeline, + AutoencoderKL, + SchedulerMixin, + UNet2DConditionModel, +) +from transformers import CLIPTextModel, CLIPTokenizer +import torch +from tqdm import tqdm +import numpy as np +from numpy import pi, exp, sqrt +from PIL.Image import Image + +from core.config import config +from core.optimizations import autocast, upcast_vae, ensure_correct_device, unload_all +from core.inference.utilities import preprocess_image, get_weighted_text_embeddings +from core.utils import resize + +mask = Literal["constant", "gaussian", "quartic"] +reroll = Literal["reset", "epsilon"] + + +@dataclass +class CanvasRegion: + """Class defining a region in the canvas.""" + + row_init: int + row_end: int + col_init: int + col_end: int + region_seed: int = None # type: ignore + noise_eps: float = 0.0 + + def __post_init__(self): + if self.region_seed is None: + self.region_seed = math.ceil(time()) + coords = [self.row_init, self.row_end, self.col_init, self.col_end] + for coord in coords: + if coord < 0: + raise ValueError( + f"Region coordinates must be non-negative, found {coords}." + ) + if coord % 8 != 0: + raise ValueError( + f"Region coordinates must be multiples of 8, found {coords}." + ) + if self.noise_eps < 0: + raise ValueError( + f"Region epsilon must be non-negative, found {self.noise_eps}." 
+ ) + self.latent_row_init = self.row_init // 8 + self.latent_row_end = self.row_end // 8 + self.latent_col_init = self.col_init // 8 + self.latent_col_end = self.col_end // 8 + + @property + def width(self): + "col_end - col_init" + return self.col_end - self.col_init + + @property + def height(self): + "row_end - row_init" + return self.row_end - self.row_init + + def get_region_generator(self, device: Union[torch.device, str] = "cpu"): + """Creates a generator for the region.""" + return torch.Generator(device).manual_seed(self.region_seed) + + +@dataclass +class DiffusionRegion(CanvasRegion): + """Abstract class for places where diffusion is taking place.""" + + +@dataclass +class RerollRegion(CanvasRegion): + """Class defining a region where latent rerolling will be taking place.""" + + reroll_mode: reroll = "reset" + + +@dataclass +class Text2ImageRegion(DiffusionRegion): + """Class defining a region where text guided diffusion will be taking place.""" + + prompt: str = "" + negative_prompt: str = "" + guidance_scale: float = 7.5 + mask_type: mask = "gaussian" + mask_weight: float = 1.0 + + text_embeddings = None + + def __post_init__(self): + super().__post_init__() + if self.mask_weight < 0: + raise ValueError( + f"Mask weight must be non-negative, found {self.mask_weight}." + ) + + @property + def do_classifier_free_guidance(self) -> bool: + """Whether to do classifier-free guidance (guidance_scale > 1.0)""" + return self.guidance_scale > 1.0 + + +@dataclass +class Image2ImageRegion(DiffusionRegion): + """Class defining a region where image guided diffusion will be taking place.""" + + reference_image: Image = None # type: ignore + strength: float = 0.8 + + def __post_init__(self): + super().__post_init__() + if self.reference_image is None: + raise ValueError("Reference image must be provided.") + if self.strength < 0 or self.strength > 1: + raise ValueError( + f"Strength must be between 0 and 1, found {self.strength}." 
+ ) + self.reference_image = resize(self.reference_image, self.width, self.height) + + def encode_reference_image( + self, + encoder, + generator: torch.Generator, + device: Union[torch.device, str] = "cpu", + ): + """Encodes the reference image for this diffusion region.""" + ensure_correct_device(encoder) # type: ignore + img = preprocess_image(self.reference_image) + self.reference_latents = encoder.encode(img.to(device)).latent_dist.sample( + generator=generator + ) + self.reference_latents = 0.18215 * self.reference_latents + + +@dataclass +class MaskWeightsBuilder: + """Auxiliary class to compute a tensor of weights for a given diffusion region.""" + + latent_space_dim: int + nbatch: int = 1 + + def compute_mask_weights(self, region: Text2ImageRegion) -> torch.Tensor: + """Computes a tensor of weights for the given diffusion region.""" + if region.mask_type == "gaussian": + mask_weights = self._gaussian_weights(region) + elif region.mask_type == "constant": + mask_weights = self._constant_weights(region) + else: + mask_weights = self._quartic_weights(region) + return mask_weights + + def _constant_weights(self, region: Text2ImageRegion) -> torch.Tensor: + """Computes a tensor of constant weights for the given diffusion region.""" + return ( + torch.ones( + ( + self.nbatch, + self.latent_space_dim, + region.latent_row_end - region.latent_row_init, + region.latent_col_end - region.latent_col_init, + ) + ) + * region.mask_weight + ) + + def _gaussian_weights(self, region: Text2ImageRegion) -> torch.Tensor: + """Computes a tensor of gaussian weights for the given diffusion region.""" + latent_width = region.latent_col_end - region.latent_col_init + latent_height = region.latent_row_end - region.latent_row_init + + var = 0.01 + midpoint = ( + latent_width - 1 + ) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [ + exp( + -(x - midpoint) + * (x - midpoint) + / (latent_width * latent_width) + / (2 * var) + ) + / sqrt(2 * pi * var) + for x in range(latent_width) + ] + midpoint = (latent_height - 1) / 2 + y_probs = [ + exp( + -(y - midpoint) + * (y - midpoint) + / (latent_height * latent_height) + / (2 * var) + ) + / sqrt(2 * pi * var) + for y in range(latent_height) + ] + weights = np.outer(y_probs, x_probs) * region.mask_weight + return torch.tile( + torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1) + ) + + def _quartic_weights(self, region: Text2ImageRegion) -> torch.Tensor: + """Generates a quartic mask of weights for tile contributions + + The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits. 
+ """ + quartic_constant = 15.0 / 16.0 + + support = ( + np.array(range(region.latent_col_init, region.latent_col_end)) + - region.latent_col_init + ) / (region.latent_col_end - region.latent_col_init - 1) * 1.99 - (1.99 / 2.0) + x_probs = quartic_constant * np.square(1 - np.square(support)) + support = ( + np.array(range(region.latent_row_init, region.latent_row_end)) + - region.latent_row_init + ) / (region.latent_row_end - region.latent_row_init - 1) * 1.99 - (1.99 / 2.0) + y_probs = quartic_constant * np.square(1 - np.square(support)) + + weights = np.outer(y_probs, x_probs) * region.mask_weight + return torch.tile( + torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1) + ) + + +class StableDiffusionCanvasPipeline(DiffusionPipeline): + """Stable Diffusion pipeline with support for multiple diffusions on one canvas.""" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + + def _decode_latents(self, latents): + latents = 1 / 0.18215 * latents + ensure_correct_device(self.vae) + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return self.numpy_to_pil(image) + + def _get_timesteps(self, num_inference_steps: int, strength: float) -> torch.Tensor: + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * (1 - strength)) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = min( + max(num_inference_steps - init_timestep + offset, 0), + num_inference_steps - 1, + ) + latest_timestep = self.scheduler.timesteps[t_start] + return latest_timestep + + @property + def _execution_device(self): + # TODO: implement this from the SDXL PR + return self.unet.device + + def __call__( + self, + height: int, + width: int, + regions: List[DiffusionRegion], + generator: torch.Generator, + num_inference_steps: int = 50, + reroll_regions: List[RerollRegion] = [], + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ) -> List[Image]: + batch_size = 1 + device = self._execution_device + unet_channels = self.unet.config.in_channels + + txt_regions = [r for r in regions if isinstance(r, Text2ImageRegion)] + img_regions = [r for r in regions if isinstance(r, Image2ImageRegion)] + + latents_shape = ( + batch_size, + unet_channels, + math.ceil(height / 8), + math.ceil(width / 8), + ) + + all_eps_rerolls = regions + [ + r for r in reroll_regions if r.reroll_mode == "epsilon" + ] + + with autocast( + dtype=self.unet.dtype, + disable=not config.api.autocast, + ): + self.scheduler.set_timesteps(num_inference_steps, device=device) + + for region in txt_regions: + ensure_correct_device(self.text_encoder) + prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings( + self, # type: ignore + region.prompt, + region.negative_prompt, + 3, + False, + False, + False, + ) + if region.do_classifier_free_guidance: + region.text_embeddings = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) # type: ignore + else: + region.text_embeddings = prompt_embeds + + init_noise = torch.randn(latents_shape, device=device, generator=generator) + + for region in reroll_regions: + if region.reroll_mode == 
"reset": + latent_height = region.latent_row_end - region.latent_row_init + latent_width = region.latent_col_end - region.latent_col_init + region_shape = ( + latents_shape[0], + latents_shape[1], + latent_height, + latent_width, + ) + init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] = torch.randn(region_shape, device=device, generator=generator) + + for region in all_eps_rerolls: + if region.noise_eps > 0: + region_noise = init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] + eps_noise = ( + torch.randn( + region_noise.shape, + generator=region.get_region_generator(device=device), + device=device, + ) + * region.noise_eps + ) + init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] += eps_noise + + latents = init_noise * self.scheduler.init_noise_sigma + + for region in img_regions: + region.encode_reference_image( + self.vae, generator=generator, device=device + ) + + mask_builder = MaskWeightsBuilder(unet_channels, batch_size) + mask_weights = [ + mask_builder.compute_mask_weights(r).to(device=device) + for r in txt_regions + ] + ensure_correct_device(self.unet) # type: ignore + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + noise_pred_regions = [] + + for region in txt_regions: + region_latents = latents[ + :, + :, + region.latent_row_end : region.latent_row_init, + region.latent_col_init : region.latent_col_end, + ] + latent_model_input = region_latents + if region.do_classifier_free_guidance: + latent_model_input = torch.cat([latent_model_input] * 2) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=region.text_embeddings, + ).sample + if region.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_region = ( + noise_pred_uncond + + region.guidance_scale + * (noise_pred_text - noise_pred_uncond) + ) + noise_pred_regions.append(noise_pred_region) + noise_pred = torch.zeros(latents.shape, device=device) + contributors = torch.zeros(latents.shape, device=device) + + for region, noise_pred_region, mask_weights_region in zip( + txt_regions, noise_pred_regions, mask_weights + ): + noise_pred[ + :, + :, + region.latent_row_end : region.latent_row_init, + region.latent_col_init : region.latent_col_end, + ] += ( + noise_pred_region * mask_weights_region + ) + contributors[ + :, + :, + region.latent_row_end : region.latent_row_init, + region.latent_col_init : region.latent_col_end, + ] += mask_weights_region + noise_pred /= contributors + noise_pred = torch.nan_to_num(noise_pred) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + for region in img_regions: + influence_step = self._get_timesteps( + num_inference_steps, region.strength + ) + if t > influence_step: + timestep = t.repeat(batch_size) + region_init_noise = init_noise[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] + region_latents = self.scheduler.add_noise( + region.reference_latents, region_init_noise, timestep + ) + latents[ + :, + :, + region.latent_row_init : region.latent_row_end, + region.latent_col_init : region.latent_col_end, + ] = region_latents + + if callback is not None: + if i % callback_steps == 0: + callback(i, t, latents) + image = 
self._decode_latents(latents) + + unload_all() + + return image diff --git a/core/inference/utilities/latents.py b/core/inference/utilities/latents.py index 84f91c5d7..cb814d938 100644 --- a/core/inference/utilities/latents.py +++ b/core/inference/utilities/latents.py @@ -12,6 +12,7 @@ from core.config import config from core.flags import LatentScaleModel +from core.optimizations import upcast_vae logger = logging.getLogger(__name__) @@ -256,8 +257,16 @@ def prepare_latents( return latents, None, None else: if image.shape[1] != 4: + if config.api.upcast_vae: + upcast_vae(pipe.vae) + image = image.to( + next(iter(pipe.vae.post_quant_conv.parameters())).dtype + ) + else: + image = image.to(config.api.device, dtype=pipe.vae.dtype) + image = pad_tensor(image, pipe.vae_scale_factor) - init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist # type: ignore + init_latent_dist = pipe.vae.encode(image).latent_dist # type: ignore init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents init_latents = torch.cat([init_latents] * batch_size, dim=0) @@ -281,95 +290,14 @@ def prepare_latents( shape, generator=generator, device="cpu", dtype=dtype ).to(device) else: - # Retarded fix, but hey, if it works, it works - if hasattr(pipe.vae, "main_device"): - noise = torch.randn( - shape, - generator=torch.Generator("cpu").manual_seed(1), - device="cpu", - dtype=dtype, - ).to(device) - else: - noise = torch.randn( - shape, generator=generator, device=device, dtype=dtype - ) - # Now this... I may have called the previous "hack" retarded, but this... - # This just takes it to a whole new level + noise = torch.randn( + shape, generator=generator, device=device, dtype=dtype + ) latents = pipe.scheduler.add_noise(init_latents.to(device), noise.to(device), timestep.to(device)) # type: ignore return latents, init_latents_orig, noise -def bislerp_original(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width) - height_scale = (shape[2]) / (height) - - shape[3] = width - shape[2] = height - out1 = torch.empty( - shape, dtype=samples.dtype, layout=samples.layout, device=samples.device - ) - - def algorithm(in1, in2, t): - dims = in1.shape - val = t - - # flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) - - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low / low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high / high_weight - - dot_prod = (low_norm * high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low_norm + ( - torch.sin(val * omega) / so - ).unsqueeze(1) * high_norm - res *= low_weight * (1.0 - val) + high_weight * val - return res.reshape(dims) - - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 - - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) - - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) - - in1 = samples[:, :, y1, x1] - in2 = samples[:, :, y1, x2] - in3 = samples[:, :, y2, x1] - in4 = samples[:, :, y2, x2] - - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif x1 == 
x2: - out_value = algorithm(in1, in3, wy) - elif y1 == y2: - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) - - out1[:, :, y_dest, x_dest] = out_value - return out1 - - -def bislerp_gabeified(samples, width, height): +def bislerp(samples, width, height): device = samples.device def slerp(b1, b2, r): @@ -488,9 +416,7 @@ def scale_latents( ): "Interpolate the latents to the desired scale." - align_to = ( - 32 if latent_scale_mode in ["bislerp-tortured", "bislerp-original"] else 8 - ) + align_to = 32 if latent_scale_mode == "bislerp" else 8 s = time() @@ -503,10 +429,8 @@ def scale_latents( height_truncated = int(((latents.shape[3] * scale - 1) // align_to + 1) * align_to) # Scale the latents - if latent_scale_mode == "bislerp-tortured": - interpolated = bislerp_gabeified(latents, height_truncated, width_truncated) - elif latent_scale_mode == "bislerp-original": - interpolated = bislerp_original(latents, height_truncated, width_truncated) + if latent_scale_mode == "bislerp": + interpolated = bislerp(latents, height_truncated, width_truncated) else: interpolated = F.interpolate( latents, diff --git a/core/inference/utilities/lwp.py b/core/inference/utilities/lwp.py index 43515d4db..d46576912 100644 --- a/core/inference/utilities/lwp.py +++ b/core/inference/utilities/lwp.py @@ -8,6 +8,7 @@ from ...config import config from ...files import get_full_model_path +from ...optimizations import ensure_correct_device logger = logging.getLogger(__name__) @@ -253,6 +254,7 @@ def get_unweighted_text_embeddings( if hasattr(pipe, "clip_inference"): text_embedding = pipe.clip_inference(text_input_chunk) else: + ensure_correct_device(pipe.text_encoder) text_embedding = pipe.text_encoder(text_input_chunk)[0] # type: ignore if no_boseos_middle: @@ -272,6 +274,7 @@ def get_unweighted_text_embeddings( if hasattr(pipe, "clip_inference"): text_embeddings = pipe.clip_inference(text_input) else: + ensure_correct_device(pipe.text_encoder) text_embeddings = pipe.text_encoder(text_input)[0] # type: ignore return text_embeddings diff --git a/core/optimizations/__init__.py b/core/optimizations/__init__.py index 431337bba..c5d8f5831 100644 --- a/core/optimizations/__init__.py +++ b/core/optimizations/__init__.py @@ -1,8 +1,13 @@ from .autocast_utils import autocast, without_autocast from .pytorch_optimizations import optimize_model +from .offload import ensure_correct_device, unload_all +from .upcast import upcast_vae __all__ = [ "optimize_model", "without_autocast", "autocast", + "upcast_vae", + "ensure_correct_device", + "unload_all", ] diff --git a/core/optimizations/offload.py b/core/optimizations/offload.py new file mode 100644 index 000000000..8444ab699 --- /dev/null +++ b/core/optimizations/offload.py @@ -0,0 +1,39 @@ +# pylint: disable=global-statement + +import logging + +import torch + +from core.config import config + +logger = logging.getLogger(__name__) +_module: torch.nn.Module = None # type: ignore + + +def unload_all(): + global _module + if _module is not None: + _module.cpu() + _module = None # type: ignore + + +def ensure_correct_device(module: torch.nn.Module): + if hasattr(module, "offload_device"): + global _module + + device = getattr(module, "offload_device", config.api.device) + logger.debug(f"Transferring {module.__class__.__name__} to {str(device)}.") + + if _module is not None: + logger.debug(f"Transferring {_module.__class__.__name__} to cpu.") + _module.cpu() + + module.to(device=device) + _module = 
module + else: + logger.debug(f"Don't need to do anything with {module.__class__.__name__}.") + + +def set_offload(module, device: torch.device): + logger.debug(f"Offloading {module.__class__.__name__} to {str(device)}.") + setattr(module, "offload_device", device) \ No newline at end of file diff --git a/core/optimizations/pytorch_optimizations.py b/core/optimizations/pytorch_optimizations.py index a34472ba0..c2a5f0be5 100644 --- a/core/optimizations/pytorch_optimizations.py +++ b/core/optimizations/pytorch_optimizations.py @@ -7,13 +7,12 @@ StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) -from diffusers.utils import is_accelerate_available from core.config import config -from core.files import get_full_model_path from .attn import set_attention_processor from .trace_utils import generate_inputs, trace_model +from .offload import set_offload logger = logging.getLogger(__name__) @@ -48,22 +47,24 @@ def optimize_model( 'You can disable it by going inside Graphics Settings → "Default Graphics Settings" and disabling "Hardware-accelerated GPU Scheduling"' ) - offload = ( - config.api.offload - if (is_pytorch_pipe(pipe) and not is_for_aitemplate) - else None + offload = config.api.offload and is_pytorch_pipe(pipe) and not is_for_aitemplate + can_offload = ( + config.api.device_type + not in [ + "cpu", + "vulkan", + "mps", + ] + and offload ) - can_offload = config.api.device_type not in [ - "cpu", - "vulkan", - "mps", - ] and (offload != "disabled" and offload is not None) # Took me an hour to understand why CPU stopped working... # Turns out AMD just lacks support for BF16... # Not mad, not mad at all... to be fair, I'm just disappointed if not can_offload and not is_for_aitemplate: pipe.to(device, torch_dtype=config.api.dtype) + else: + pipe.to(torch_dtype=config.api.dtype) if config.api.device_type == "cuda" and not is_for_aitemplate: supports_tf = supports_tf32(device) @@ -123,80 +124,27 @@ def optimize_model( logger.info("Optimization: Enabled autocast") if can_offload: - if not is_accelerate_available(): - logger.warning( - "Optimization: Offload is not available, because accelerate is not installed" - ) - else: - if offload == "model": - # Offload to CPU - from accelerate import cpu_offload_with_hook - - if config.api.device_type == "cuda": - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - - for cpu_offloaded_model in [ - pipe.text_encoder, - pipe.unet, - pipe.vae, - ]: - _, hook = cpu_offload_with_hook( - cpu_offloaded_model, device, prev_module_hook=hook - ) - pipe.final_offload_hook = hook - setattr(pipe.vae, "main_device", True) - setattr(pipe.unet, "main_device", True) - logger.info("Optimization: Offloaded model parts to CPU.") - - elif offload == "module": - # Enable sequential offload - from accelerate import cpu_offload, disk_offload - - for m in [ - pipe.vae, - pipe.unet, - ]: - if USE_DISK_OFFLOAD: - # If USE_DISK_OFFLOAD toggle set (idk why anyone would do this, but it's nice to support stuff - # like this in case anyone wants to try running this on fuck knows what) - # then offload to disk. 
- disk_offload( - m, - str( - get_full_model_path("offload-dir", model_folder="temp") - / m.__name__ - ), - device, - offload_buffers=True, - ) - else: - cpu_offload(m, device, offload_buffers=True) - - logger.info("Optimization: Enabled sequential offload") + # Offload to CPU + + for model_name in [ + "text_encoder", + "text_encoder2", + "unet", + "vae", + ]: + cpu_offloaded_model = getattr(pipe, model_name, None) + if cpu_offloaded_model is not None: + set_offload(cpu_offloaded_model, device) + setattr(pipe, model_name, cpu_offloaded_model) + logger.info("Optimization: Offloaded model parts to CPU.") if config.api.vae_slicing: - if not ( - issubclass(pipe.__class__, StableDiffusionUpscalePipeline) - or isinstance(pipe, StableDiffusionUpscalePipeline) - ): - pipe.enable_vae_slicing() - logger.info("Optimization: Enabled VAE slicing") - else: - logger.debug( - "Optimization: VAE slicing is not available for upscale models" - ) + pipe.enable_vae_slicing() + logger.info("Optimization: Enabled VAE slicing") if config.api.vae_tiling: - if not ( - issubclass(pipe.__class__, StableDiffusionUpscalePipeline) - or isinstance(pipe, StableDiffusionUpscalePipeline) - ): - pipe.enable_vae_tiling() - logger.info("Optimization: Enabled VAE tiling") - else: - logger.debug("Optimization: VAE tiling is not available for upscale models") + pipe.enable_vae_tiling() + logger.info("Optimization: Enabled VAE tiling") if config.api.use_tomesd and not is_for_aitemplate: try: diff --git a/core/optimizations/upcast.py b/core/optimizations/upcast.py new file mode 100644 index 000000000..dccd7d0b2 --- /dev/null +++ b/core/optimizations/upcast.py @@ -0,0 +1,26 @@ +import logging + +from diffusers.models import AutoencoderKL +from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +import torch + +logger = logging.getLogger(__name__) + + +def upcast_vae(vae: AutoencoderKL): + dtype = vae.dtype + logger.debug('Upcasting VAE to FP32 (vae["force_upcast"] OR config.api.upcast_vae)') + vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + vae.decoder.mid_block.attentions[0].processor, # type: ignore + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + vae.post_quant_conv.to(dtype=dtype) + vae.decoder.conv_in.to(dtype=dtype) + vae.decoder.mid_block.to(dtype=dtype) # type: ignore \ No newline at end of file diff --git a/frontend/dist/assets/SettingsView.js b/frontend/dist/assets/SettingsView.js index 5b51a3cda..9ecba43a0 100644 --- a/frontend/dist/assets/SettingsView.js +++ b/frontend/dist/assets/SettingsView.js @@ -2244,21 +2244,7 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({ }), createVNode(unref(NFormItem), { label: "Offload" }, { default: withCtx(() => [ - createVNode(unref(NSelect), { - options: [ - { - value: "disabled", - label: "Disabled" - }, - { - value: "model", - label: "Offload the whole model to RAM when not used" - }, - { - value: "module", - label: "Offload individual modules to RAM when not used" - } - ], + createVNode(unref(NSwitch), { value: unref(settings).defaultSettings.api.offload, "onUpdate:value": _cache[24] || (_cache[24] = ($event) => unref(settings).defaultSettings.api.offload = $event) }, null, 8, ["value"]) @@ -2495,12 +2481,8 @@ const _sfc_main$b = /* @__PURE__ */ defineComponent({ value: "bicubic" }, { - label: "Bislerp (Original, slow)", - value: "bislerp-original" - }, - { - 
label: "Bislerp (Tortured, fast)", - value: "bislerp-tortured" + label: "Bislerp", + value: "bislerp" }, { label: "Nearest Exact", @@ -2509,7 +2491,7 @@ const _sfc_main$b = /* @__PURE__ */ defineComponent({ ], value: unref(settings).defaultSettings.extra.highres.latent_scale_mode, "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).defaultSettings.extra.highres.latent_scale_mode = $event) - }, null, 8, ["options", "value"]) + }, null, 8, ["value"]) ]), _: 1 }), diff --git a/frontend/dist/assets/index.js b/frontend/dist/assets/index.js index 2da292cc9..b16ef051d 100644 --- a/frontend/dist/assets/index.js +++ b/frontend/dist/assets/index.js @@ -40646,7 +40646,7 @@ const defaultSettings = { vae_tiling: false, trace_model: false, cudnn_benchmark: false, - offload: "disabled", + offload: false, device_id: 0, device_type: "cuda", data_type: "float16", diff --git a/frontend/src/components/settings/APISettings.vue b/frontend/src/components/settings/APISettings.vue index ede89a6fb..9faf3c49c 100644 --- a/frontend/src/components/settings/APISettings.vue +++ b/frontend/src/components/settings/APISettings.vue @@ -209,24 +209,7 @@ - - +

Device

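The accelerate-based offload hooks removed above are replaced by the three small helpers in core/optimizations/offload.py. A minimal sketch of how they are expected to interact, not part of the diff; the two Linear layers are stand-ins for pipe.text_encoder / pipe.unet, and the CUDA target is an assumption (optimize_model passes config.api.device):

import torch

from core.optimizations.offload import ensure_correct_device, set_offload, unload_all

# Stand-ins for the pipeline submodules; any torch.nn.Module works.
text_encoder = torch.nn.Linear(4, 4)
unet = torch.nn.Linear(4, 4)

# optimize_model() tags every offloadable submodule like this.
device = torch.device("cuda:0")  # assumed; normally config.api.device
set_offload(text_encoder, device)  # records the target device, moves nothing yet
set_offload(unet, device)

ensure_correct_device(text_encoder)  # text_encoder -> cuda:0
ensure_correct_device(unet)  # text_encoder back to cpu, unet -> cuda:0 (one module resident at a time)
unload_all()  # after generation, the last resident module returns to cpu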
diff --git a/frontend/src/components/settings/ExtraSettings/HiresSettings.vue b/frontend/src/components/settings/ExtraSettings/HiresSettings.vue
index e859b0095..3d95d09df 100644
--- a/frontend/src/components/settings/ExtraSettings/HiresSettings.vue
+++ b/frontend/src/components/settings/ExtraSettings/HiresSettings.vue
@@ -26,12 +26,8 @@
             value: 'bicubic',
           },
           {
-            label: 'Bislerp (Original, slow)',
-            value: 'bislerp-original',
-          },
-          {
-            label: 'Bislerp (Tortured, fast)',
-            value: 'bislerp-tortured',
+            label: 'Bislerp',
+            value: 'bislerp',
           },
           {
             label: 'Nearest Exact',
diff --git a/frontend/src/settings.ts b/frontend/src/settings.ts
index 664a35568..873a80a6c 100644
--- a/frontend/src/settings.ts
+++ b/frontend/src/settings.ts
@@ -35,8 +35,7 @@ export interface ISettings {
        | "nearest"
        | "area"
        | "bilinear"
-        | "bislerp-original"
-        | "bislerp-tortured"
+        | "bislerp"
        | "bicubic"
        | "nearest-exact";
      strength: number;
@@ -156,7 +155,7 @@
     vae_slicing: boolean;
     vae_tiling: boolean;
     trace_model: boolean;
-    offload: "module" | "model" | "disabled";
+    offload: boolean;
     image_preview_delay: number;
     device_id: number;
     device_type: "cpu" | "cuda" | "mps" | "directml";
@@ -176,6 +175,8 @@

     disable_grid: boolean;

+    upcast_vae: boolean;
+
     torch_compile: boolean;
     torch_compile_fullgraph: boolean;
     torch_compile_dynamic: boolean;
@@ -317,7 +318,7 @@
     vae_tiling: false,
     trace_model: false,
     cudnn_benchmark: false,
-    offload: "disabled",
+    offload: false,
     device_id: 0,
     device_type: "cuda",

@@ -340,6 +341,8 @@

     disable_grid: false,

+    upcast_vae: false,
+
     torch_compile: false,
     torch_compile_fullgraph: false,
     torch_compile_dynamic: false,
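A usage sketch for the relocated configuration helpers, again not part of the diff: update_config copies fields onto the existing Configuration instance, so modules that imported the shared config object keep seeing the new values, and the Undefined.INCLUDE / CatchAll setup in core/config/_config.py routes keys the backend does not recognise into extra. The JSON payload below is hypothetical:

from core.config._config import Configuration, load_config, update_config

config = load_config()  # reads data/settings.json, creating it on first run

# Hypothetical body, e.g. what the settings route receives from the frontend.
payload = '{"api": {"offload": true, "upcast_vae": false}, "unknown_key": 1}'
incoming = Configuration.from_json(payload)

print(incoming.extra)  # {'unknown_key': 1} -- unhandled keys land in the CatchAll field
update_config(config, incoming)  # copy field by field instead of rebinding the pointer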