diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index bc86ba25f6..0c7ca0c755 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,4 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -94,8 +93,6 @@ class AWQModifier(Modifier, QuantizationMixin):
         - on_finalize
             - clear resolved mappings and captured activations
 
-    :param sequential_targets: list of module names to compress in
-        the same calibration pass
     :param mappings: list activation layers to smooth, and which layers to
         scale the output such that activations are smoothed.
         Each entry of the mapping list should be a list itself, in which the first
@@ -118,27 +115,30 @@ class AWQModifier(Modifier, QuantizationMixin):
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
 
     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
     duo_scaling: bool = True
 
     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)
 
     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )
 
+    # NOTE: in case a user wants to run both AWQ and GPTQ before quantizing,
+    # this is set to True
+    _supports_disabling_quantization: bool = PrivateAttr(True)
+
     # NOTE: different name chosen to avoid collision with
     # QuantizationMixin.validate_model_after, which must be called first
     @model_validator(mode="after")
@@ -389,7 +389,7 @@ def _setup_activation_cache_hooks(self) -> None:
 
         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
@@ -398,7 +398,7 @@ def cache_parent_kwargs_hook(
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -559,13 +559,13 @@ def _smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]
 
@@ -574,8 +574,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -667,8 +667,8 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
@@ -746,8 +746,8 @@ def _pseudo_quantize_tensor(
 
 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
@@ -761,7 +761,7 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.
 
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index 907bca4880..9bec035df4 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
 
 from loguru import logger
 from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
     #     ["re:.*dense$"]
     # ),
 ]
-AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
+AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
     "BloomForCausalLM": _bloom_mappings,
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:
 
     smooth_name: str
     smooth_layer: Module
-    balance_layers: List[Module]
-    balance_names: Optional[List[str]] = None
-    parent: Optional[Module] = None
-    parent_name: Optional[str] = None
+    balance_layers: list[Module]
+    balance_names: list[str]
+    parent: Module
+    parent_name: str
 
 
-def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
     """
     :param architecture: str: The architecture of the model
     :return: list: The layer mappings for the given architecture
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index 463806ce16..718f3347ca 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import Any
 
 import torch
 from compressed_tensors.modeling import (
@@ -89,20 +89,26 @@ class QuantizationMixin(HooksMixin):
         and kv_cache_scheme != None, the quantization of kv cache will fail
     """
 
-    config_groups: Optional[Dict[str, QuantizationScheme]] = None
+    config_groups: dict[str, QuantizationScheme] | None = None
     # NOTE: targets is not the sole source of truth for finding all matching target
     # layers in a model. Additional information can be stored in `config_groups`
     # Use self.resolved_targets as source of truth.
-    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
-    ignore: List[str] = Field(default_factory=list)
-    scheme: Optional[Union[str, Dict[str, Any]]] = None
-    kv_cache_scheme: Optional[QuantizationArgs] = None
+    targets: str | list[str] | None = Field(default_factory=lambda: ["Linear"])
+    ignore: list[str] = Field(default_factory=list)
+    scheme: str | dict[str, Any] | None = None
+    kv_cache_scheme: QuantizationArgs | None = None
 
-    _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
-    _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None)
+    _calibration_hooks: set[RemovableHandle] = PrivateAttr(default_factory=set)
+    _resolved_config: QuantizationConfig | None = PrivateAttr(None)
+
+    # NOTE: in some cases, we need to allow users to run instances of the
+    # QuantizationMixin without quantizing modules, e.g. when a user wants
+    # to run both AWQ and GPTQ before quantizing. Set this field to True
+    # on classes that subclass QuantizationMixin to allow for this.
+    _supports_disabling_quantization: bool = PrivateAttr(False)
 
     @field_validator("targets", mode="before")
-    def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
+    def validate_targets(cls, value: str | list[str]) -> list[str]:
         if isinstance(value, str):
             return [value]
 
@@ -110,8 +116,8 @@ def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
 
     @field_validator("scheme", mode="before")
     def validate_scheme(
-        cls, value: Optional[Union[str, Dict[str, Any]]]
-    ) -> Optional[Union[str, Dict[str, Any]]]:
+        cls, value: str | dict[str, Any] | None
+    ) -> str | dict[str, Any] | None:
         if isinstance(value, str) and not is_preset_scheme(value):
             raise ValueError(
                 "`scheme` must either be a preset scheme name or a dictionary "
@@ -138,7 +144,7 @@ def resolved_config(self) -> QuantizationConfig:
         return self._resolved_config
 
     @property
-    def resolved_targets(self) -> Set[str]:
+    def resolved_targets(self) -> set[str]:
         """
         Set of all resolved targets, i.e. all unique targets listed in
         resolved quantization config.
@@ -221,6 +227,12 @@ def resolve_quantization_config(self) -> QuantizationConfig:
         kv_cache_scheme = self.kv_cache_scheme
         ignore = self.ignore
 
+        # NOTE: this will only happen if the user explicitly sets targets=None
+        if targets is None and config_groups is None:
+            if self._supports_disabling_quantization:
+                return QuantizationConfig({})
+            raise ValueError("Please specify either `targets` or `config_groups`")
+
         if scheme is not None and config_groups is not None:
             raise ValueError("Please specify either `scheme` or `config_groups`")
 
@@ -286,7 +298,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         if output:
             initialize_observer(module, base_name="output")
 
-    def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
+    def _initialize_hooks(self, module: torch.nn.Module) -> set[RemovableHandle]:
         hooks = set()
         if not hasattr(module, "quantization_scheme"):
             return hooks
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index a66a278f32..2a13dabbdb 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -228,3 +228,9 @@ def test_get_lowest_common_parent():
         ["embed_tokens", "decoder.self_attn.v_proj"], model
     )
     assert parent_name == "" and parent == model
+
+
+def test_awq_supports_disabling_quantization():
+    awq = AWQModifier(scheme="W4A16", targets=None)
+
+    assert len(awq.resolved_config.config_groups) == 0
diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py
index 51ea28f13b..de2c6c7f97 100644
--- a/tests/llmcompressor/modifiers/quantization/test_base.py
+++ b/tests/llmcompressor/modifiers/quantization/test_base.py
@@ -3,7 +3,7 @@
 import pytest
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
 
-from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
 
 
 @pytest.fixture
@@ -211,3 +211,11 @@ def test_resolved_targets(
     )
 
     assert modifier.resolved_targets == resolved_targets
+
+
+def test_does_not_support_disabling_quantization():
+    with pytest.raises(ValueError):
+        GPTQModifier(scheme="W4A16", targets=None).resolve_quantization_config()
+
+    with pytest.raises(ValueError):
+        QuantizationModifier(scheme="W4A16", targets=None).resolve_quantization_config()
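
Reviewer note: a minimal usage sketch of the behavior the new tests exercise. With targets=None, AWQModifier resolves to an empty quantization config (no config_groups), so it only computes and applies smoothing scales and leaves the actual quantization to a later modifier such as GPTQModifier. The oneshot call at the end is illustrative only; the model and calibration-dataset arguments are placeholders and not part of this diff.

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

# AWQ smoothing only: targets=None + _supports_disabling_quantization=True means
# resolve_quantization_config() returns an empty config instead of raising.
awq = AWQModifier(scheme="W4A16", targets=None)
assert len(awq.resolved_config.config_groups) == 0

# GPTQ then owns the actual W4A16 quantization of the Linear layers.
gptq = GPTQModifier(scheme="W4A16", targets="Linear")

recipe = [awq, gptq]
# e.g. oneshot(model=model, dataset=calibration_ds, recipe=recipe)  # placeholder wiring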