Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
Expand Down Expand Up @@ -329,18 +329,6 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

unet_state_dict = self.diffusion_model.state_dict()

if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata

unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

if self.model_type == ModelType.V_PREDICTION:
Expand Down
33 changes: 4 additions & 29 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,6 @@
import logging
import torch


def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None


def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
Expand Down Expand Up @@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)

scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True

# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config:
model_config.quant_config = quant_config
logging.info("Detected mixed precision quantization")

return model_config

Expand Down
20 changes: 2 additions & 18 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,27 +126,11 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.convert_func = convert_func # TODO: remove
self.set_func = set_func

def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight, inplace=False)

if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)

out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
Expand Down
145 changes: 56 additions & 89 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import comfy.float
import comfy.rmsnorm
import contextlib
import json

def run_every_op():
if torch.compiler.is_compiling():
Expand Down Expand Up @@ -422,22 +423,12 @@ def fp8_linear(self, input):

if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)

if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
Expand All @@ -458,7 +449,7 @@ def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
if not self.training:
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
try:
out = fp8_linear(self, input)
if out is not None:
Expand All @@ -471,59 +462,6 @@ def forward_comfy_cast_weights(self, input):
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)

def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

if not scale_input:
self.scale_input = None

if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None

def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out

weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

if weight.numel() < input.numel(): #TODO: optimize
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)

def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)

return scaled_fp8_op

CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
Expand All @@ -550,9 +488,9 @@ def forward(self, *args, **kwargs):
from .quant_ops import QuantizedTensor, QUANT_ALGOS


def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_layer_quant_config = layer_quant_config
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm

Expand Down Expand Up @@ -595,27 +533,36 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,

manually_loaded_keys = [weight_key]

if layer_name not in MixedPrecisionOps._layer_quant_config:
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())

if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)

if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")

qconfig = QUANT_ALGOS[quant_format]
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]

weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:

if scale is not None:
manually_loaded_keys.append(weight_scale_key)

self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)

Expand All @@ -624,7 +571,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)

super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
Expand All @@ -633,6 +580,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
if key in missing_keys:
missing_keys.remove(key)

def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
return sd

def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)

Expand All @@ -648,9 +605,8 @@ def forward(self, input, *args, **kwargs):
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)

def convert_weight(self, weight, inplace=False, **kwargs):
Expand All @@ -661,7 +617,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):

def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
Expand All @@ -670,17 +626,28 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=Fals
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)

def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)

for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self

return MixedPrecisionOps

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)

if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

if (
fp8_compute and
Expand Down
Loading