From 4a87e96a8dfc311df82f62e498a3861df80abc32 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 Dec 2025 19:28:33 -0500 Subject: [PATCH 01/11] Make old fp8 system use new mixed quant system. Upgrade the internal mixed quant format to be easier to deal with. --- comfy/model_base.py | 15 +- comfy/model_detection.py | 33 +---- comfy/model_patcher.py | 20 +-- comfy/ops.py | 129 ++++++------------ comfy/quant_ops.py | 23 +++- comfy/sd.py | 54 ++++++-- comfy/sd1_clip.py | 22 +-- comfy/supported_models_base.py | 3 +- comfy/text_encoders/cosmos.py | 12 +- comfy/text_encoders/flux.py | 12 +- comfy/text_encoders/genmo.py | 6 +- comfy/text_encoders/hidream.py | 10 +- comfy/text_encoders/hunyuan_image.py | 12 +- comfy/text_encoders/hunyuan_video.py | 23 ++-- comfy/text_encoders/lumina2.py | 6 +- comfy/text_encoders/omnigen2.py | 6 +- comfy/text_encoders/pixart_t5.py | 6 +- comfy/text_encoders/qwen_image.py | 6 +- comfy/text_encoders/sd3_clip.py | 19 +-- comfy/text_encoders/wan.py | 6 +- comfy/text_encoders/z_image.py | 5 +- comfy/utils.py | 66 +++++++++ .../comfy_quant/test_mixed_precision.py | 3 +- 23 files changed, 241 insertions(+), 256 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285ed1d..9cde2e051624 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -57,6 +57,7 @@ import comfy.latent_formats import comfy.model_sampling import math +import json from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher @@ -134,7 +135,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -329,18 +330,6 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - - # Save mixed precision metadata - if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: - metadata = { - "format_version": "1.0", - "layers": self.model_config.layer_quant_config - } - unet_state_dict["_quantization_metadata"] = metadata - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0517e611b1..fd1907627509 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,20 +6,6 @@ import logging import torch - -def detect_layer_quantization(metadata): - quant_key = "_quantization_metadata" - if metadata is not None and quant_key in metadata: - quant_metadata = metadata.pop(quant_key) - quant_metadata = json.loads(quant_metadata) - if isinstance(quant_metadata, dict) and "layers" in quant_metadata: - logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") - return quant_metadata["layers"] - else: - raise ValueError("Invalid quantization metadata format") - return None - - def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True - # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(metadata) - if layer_quant_config: - model_config.layer_quant_config = layer_quant_config - logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") return model_config diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index df2d8e827661..19ab8b4ea320 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -126,27 +126,11 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight, inplace=False) - - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - - out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 diff --git a/comfy/ops.py b/comfy/ops.py index eae434e68465..34f0a74d7ba8 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ import comfy.float import comfy.rmsnorm import contextlib +import json def run_every_op(): if torch.compiler.is_compiling(): @@ -422,22 +423,12 @@ def fp8_linear(self, input): if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) - - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - else: - scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! @@ -471,59 +462,6 @@ def forward_comfy_cast_weights(self, input): uncast_bias_weight(self, weight, bias, offload_stream) return x -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - - if weight.numel() < input.numel(): #TODO: optimize - x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -550,9 +488,9 @@ def forward(self, *args, **kwargs): from .quant_ops import QuantizedTensor, QUANT_ALGOS -def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): class MixedPrecisionOps(manual_cast): - _layer_quant_config = layer_quant_config + _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm @@ -595,27 +533,36 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, manually_loaded_keys = [weight_key] - if layer_name not in MixedPrecisionOps._layer_quant_config: + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - qconfig = QUANT_ALGOS[quant_format] + qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), + 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, 'block_size': qconfig.get("group_size", None), } - if layout_params['scale'] is not None: + + if scale is not None: manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), requires_grad=False ) @@ -624,7 +571,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -633,6 +580,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, if key in missing_keys: missing_keys.remove(key) + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + return sd + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -648,9 +605,8 @@ def forward(self, input, *args, **kwargs): if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -661,7 +617,7 @@ def convert_weight(self, weight, inplace=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -672,15 +628,12 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=Fals return MixedPrecisionOps -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular - if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860ca52..a565162db851 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -238,6 +238,12 @@ def is_pinned(self): def is_contiguous(self, *arg, **kwargs): return self._qdata.is_contiguous(*arg, **kwargs) + def storage(self): + return self._qdata.storage() + + def untyped_storage(self): + return self._qdata.untyped_storage() + # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ============================================================================== @@ -397,17 +403,20 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale is None: + if scale == "recalculate": scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) - if inplace_ops: - tensor *= (1.0 / scale).to(tensor.dtype) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) else: - tensor = tensor * (1.0 / scale).to(tensor.dtype) + scale = torch.ones((), device=tensor.device, dtype=torch.float32) if stochastic_rounding > 0: tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5001..440aed114e57 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -964,10 +964,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) - if metadata is not None: - quant_metadata = metadata.get("_quantization_metadata", None) - if quant_metadata is not None: - sd["_quantization_metadata"] = quant_metadata + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1084,7 +1081,7 @@ class EmptyClass: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1108,7 +1105,7 @@ class EmptyClass: tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1137,7 +1134,7 @@ class EmptyClass: tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) @@ -1165,7 +1162,7 @@ class EmptyClass: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1291,6 +1288,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1299,9 +1298,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None model_config.custom_operations = model_options.get("custom_operations", None) @@ -1310,7 +1308,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1328,6 +1329,26 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) @@ -1389,7 +1410,11 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: sd = temp_sd + quant_key = "{}_quantization_metadata".format(diffusion_model_prefix) + if metadata is not None and quant_key in metadata: + metadata["_quantization_metadata"] = metadata.pop(quant_key) + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1420,7 +1445,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1428,7 +1453,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if model_config.layer_quant_config is not None: + if model_config.quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) @@ -1472,6 +1497,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3dba58..174df66b28b9 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -107,29 +107,17 @@ def __init__(self, device="cpu", max_length=77, config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None - quantization_metadata = model_options.get("quantization_metadata", None) + quant_config = model_options.get("quantization_metadata", None) if operations is None: - layer_quant_config = None - if quantization_metadata is not None: - layer_quant_config = json.loads(quantization_metadata).get("layers", None) - - if layer_quant_config is not None: - operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) - logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") + if quant_config is not None: + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: - # Fallback to scaled_fp8_ops for backward compatibility - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) - else: - operations = comfy.ops.manual_cast + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e4bd7451429b..9fd84d329710 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,8 +49,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None - layer_quant_config = None # Per-layer quantization configuration for mixed precision + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242bc9..448381fa9f9c 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 99f4812bb57f..21d93d757603 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -63,12 +63,12 @@ def load_sd(self, sd): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ @@ -159,15 +159,13 @@ def encode_token_weights(self, token_weight_pairs): out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra -def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Flux2TEModel_(Flux2TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata if pruned: model_options = model_options.copy() diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2325..5daea81355c1 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784d63..600b344803dd 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ def load_sd(self, sd): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e11ba..cd198036c838 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -40,10 +40,10 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -91,12 +91,12 @@ def load_sd(self, sd): else: return super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 0110517bbdf3..a9a6c525e011 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -6,7 +6,7 @@ import torch import os import numbers - +import comfy.utils def llama_detect(state_dict, prefix=""): out = {} @@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype - - if "_quantization_metadata" in state_dict: - out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -31,10 +28,10 @@ def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -161,11 +158,11 @@ def load_sd(self, sd): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index fd986e2c1a63..7a6cfdab2ead 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -40,7 +40,7 @@ def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": @@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4304..50aa4121fbfb 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 5f383de076fa..e5e5f18bed0c 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index c0d32a6ef4ec..5c14dec23397 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -85,12 +85,12 @@ def encode_token_weights(self, token_weight_pairs, template_end=-1): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db148..8b153c72b9db 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -6,14 +6,15 @@ import os import comfy.model_management import logging +import comfy.utils class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["t5_quantization_metadata"] = quant return out @@ -156,11 +157,11 @@ def load_sd(self, sd): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28df3..164a57edd7ac 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index bb9273b201a6..19adde0b76bc 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -34,12 +34,9 @@ def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class ZImageTEModel_(ZImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/utils.py b/comfy/utils.py index 37485e497a99..89846bc95fb1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,6 +29,7 @@ from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +import json MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes): else: output_tensors = combined_latent return output_tensors + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + + return state_dict, metadata diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 63361309f63d..5f45bfeef17b 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -115,7 +115,8 @@ def test_mixed_precision_load(self): # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) + with torch.inference_mode(): + output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) From 8fa3d1281e85b0f8f973d6cc3d519f8387e36413 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:43:40 -0500 Subject: [PATCH 02/11] Remove unused json import --- comfy/model_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9cde2e051624..3cedd4f31c1f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -57,7 +57,6 @@ import comfy.latent_formats import comfy.model_sampling import math -import json from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher From 83a0fa107baf0f70549dc2363f8504fda32bb510 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 Dec 2025 20:11:28 -0500 Subject: [PATCH 03/11] Quick fix unit tests. --- tests-unit/comfy_quant/test_mixed_precision.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 5f45bfeef17b..3a54941e6bf3 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -2,6 +2,7 @@ import torch import sys import os +import json # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -15,6 +16,7 @@ def has_gpu(): from comfy import ops from comfy.quant_ops import QuantizedTensor +import comfy.utils class SimpleModel(torch.nn.Module): @@ -94,8 +96,9 @@ def test_mixed_precision_load(self): "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -142,7 +145,8 @@ def test_state_dict_quantized_preserved(self): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -179,7 +183,8 @@ def test_weight_function_compatibility(self): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -216,8 +221,10 @@ def test_error_handling_unknown_format(self): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False) From c217243b5669f6ab7884573899cbccb11b56bfb1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:09:35 -0500 Subject: [PATCH 04/11] Fix fp8 fast lora offload. --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 34f0a74d7ba8..cce68febef89 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -449,7 +449,7 @@ def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - if not self.training: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: try: out = fp8_linear(self, input) if out is not None: From 55366d9c0fcb31012b5a441e2164eb9b7870eb66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:17:06 -0500 Subject: [PATCH 05/11] Remove. --- comfy/text_encoders/ovis.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 81c9bd51c47c..5754424d25b1 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -55,12 +55,9 @@ def encode_token_weights(self, token_weight_pairs, template_end=-1): return out, pooled, {} -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class OvisTEModel_(OvisTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: From 295a0170d64b2544de5d22b3067eab93142dba65 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:38:01 -0500 Subject: [PATCH 06/11] Remove. --- comfy/quant_ops.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index a565162db851..622ecebf5a64 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -241,9 +241,6 @@ def is_contiguous(self, *arg, **kwargs): def storage(self): return self._qdata.storage() - def untyped_storage(self): - return self._qdata.untyped_storage() - # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ============================================================================== From b8afb60ee88f921a6331e631986fa608d1d66989 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 Dec 2025 22:17:13 -0500 Subject: [PATCH 07/11] torch compile fixes. --- comfy/quant_ops.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 622ecebf5a64..f75fea7b8766 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -252,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn): def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_dtype is not None and target_dtype != qt.dtype: - logging.warning( - f"QuantizedTensor: dtype conversion requested to {target_dtype}, " - f"but not supported for quantized tensors. Ignoring dtype." - ) - if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " @@ -277,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt @@ -400,7 +396,7 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale == "recalculate": + if isinstance(scale, str) and scale == "recalculate": scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max if scale is not None: From 88172a43390a2dc080121233ba92e04e76048d45 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:58:29 -0500 Subject: [PATCH 08/11] Fix torch compile issue. --- comfy/ops.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index cce68febef89..dc06709a1a90 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -626,6 +626,20 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=Fals assert inplace_update is False # TODO: eventually remove the inplace_update stuff self.weight = torch.nn.Parameter(weight, requires_grad=False) + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): From e12842fe316c87c9c77da8b20a3b42adb5896742 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 01:11:22 -0500 Subject: [PATCH 09/11] Remove --- comfy/sd.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 440aed114e57..e264c1ee543f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1217,8 +1217,6 @@ class EmptyClass: parameters = 0 for c in clip_data: - if "_quantization_metadata" in c: - c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) @@ -1410,9 +1408,6 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: sd = temp_sd - quant_key = "{}_quantization_metadata".format(diffusion_model_prefix) - if metadata is not None and quant_key in metadata: - metadata["_quantization_metadata"] = metadata.pop(quant_key) sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) From 129f3e7db1f8e6c642dd106bc5349891356903f9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 02:05:14 -0500 Subject: [PATCH 10/11] Fix loras. --- comfy/quant_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index f75fea7b8766..571d3f760624 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -338,7 +338,9 @@ def generic_copy_(func, args, kwargs): # Copy from another quantized tensor qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype else: # Copy from regular tensor - just copy raw data qt_dest._qdata.copy_(src) From 3193d3aa53fe6b3dd56b2d98d419f56296268880 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 02:16:34 -0500 Subject: [PATCH 11/11] Don't convert quants when custom ops are used. --- comfy/sd.py | 57 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e264c1ee543f..9eea7d63b4ac 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -964,7 +964,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) - sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) + if model_options.get("custom_operations", None) is None: + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1286,7 +1287,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() - sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: @@ -1300,7 +1303,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: @@ -1327,25 +1332,26 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: - scaled_fp8_list = [] - for k in list(sd.keys()): # Convert scaled fp8 to mixed ops - if k.endswith(".scaled_fp8"): - scaled_fp8_list.append(k[:-len("scaled_fp8")]) - - if len(scaled_fp8_list) > 0: - out_sd = {} - for k in sd: - skip = False - for pref in scaled_fp8_list: - skip = skip or k.startswith(pref) - if not skip: - out_sd[k] = sd[k] + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] - for pref in scaled_fp8_list: - quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) - for k in quant_sd: - out_sd[k] = quant_sd[k] - sd = out_sd + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: @@ -1409,7 +1415,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): if len(temp_sd) > 0: sd = temp_sd - sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1453,7 +1461,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True