Skip to content

Commit 62f62d0

Browse files
authored
Deprecate config functions like int4_weight_only (#2994)
* Deprecate config functions like `int4_weight_only` **Summary:** These have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for another release before breaking BC and removing them. **Test Plan:** ``` python test/quantization/test_quant_api.py -k test_config_deprecation ``` [ghstack-poisoned] * Update on "Deprecate config functions like `int4_weight_only`" **Summary:** These have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for another release before breaking BC and removing them. **Test Plan:** ``` python test/quantization/test_quant_api.py -k test_config_deprecation ``` [ghstack-poisoned] * Update on "Deprecate config functions like `int4_weight_only`" **Summary:** These have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for another release before breaking BC and removing them. **Test Plan:** ``` python test/quantization/test_quant_api.py -k test_config_deprecation ``` [ghstack-poisoned]
1 parent 067b273 commit 62f62d0

File tree

3 files changed

+103
-14
lines changed

3 files changed

+103
-14
lines changed

test/quantization/test_quant_api.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import gc
1111
import tempfile
1212
import unittest
13+
import warnings
1314
from pathlib import Path
1415

1516
import torch
@@ -37,6 +38,8 @@
3738
PerGroup,
3839
)
3940
from torchao.quantization.quant_api import (
41+
Float8DynamicActivationFloat8WeightConfig,
42+
Float8StaticActivationFloat8WeightConfig,
4043
Int4WeightOnlyConfig,
4144
Int8DynamicActivationIntxWeightConfig,
4245
Int8WeightOnlyConfig,
@@ -623,8 +626,8 @@ def test_workflow_e2e_numerics(self, config):
623626
isinstance(
624627
config,
625628
(
626-
float8_dynamic_activation_float8_weight,
627-
float8_static_activation_float8_weight,
629+
Float8DynamicActivationFloat8WeightConfig,
630+
Float8StaticActivationFloat8WeightConfig,
628631
),
629632
)
630633
and not is_sm_at_least_89()
@@ -755,6 +758,56 @@ def test_int4wo_cuda_serialization(self):
755758
# load state_dict in cuda
756759
model.load_state_dict(sd, assign=True)
757760

761+
def test_config_deprecation(self):
    """
    Test that old config functions like `int4_weight_only` trigger deprecation warnings.

    Each deprecated callable is invoked twice; the warning machinery should
    deduplicate per call site, so exactly one warning per API is expected.
    """
    from torchao.quantization import (
        float8_dynamic_activation_float8_weight,
        float8_static_activation_float8_weight,
        float8_weight_only,
        fpx_weight_only,
        gemlite_uintx_weight_only,
        int4_dynamic_activation_int4_weight,
        int4_weight_only,
        int8_dynamic_activation_int4_weight,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        uintx_weight_only,
    )

    # Reset deprecation warning state, otherwise we won't log warnings here
    warnings.resetwarnings()

    # Map from deprecated API to the positional args needed to instantiate it.
    # NOTE: every value must be a tuple so it can be unpacked with `*args`.
    # `(torch.randn(3))` without the trailing comma is just a tensor, which
    # `*args` would unpack element-wise into three scalar arguments instead
    # of passing a single `scale` tensor — hence the explicit 1-tuple below.
    deprecated_apis_to_args = {
        float8_dynamic_activation_float8_weight: (),
        float8_static_activation_float8_weight: (torch.randn(3),),
        float8_weight_only: (),
        fpx_weight_only: (3, 2),
        gemlite_uintx_weight_only: (),
        int4_dynamic_activation_int4_weight: (),
        int4_weight_only: (),
        int8_dynamic_activation_int4_weight: (),
        int8_dynamic_activation_int8_weight: (),
        int8_weight_only: (),
        uintx_weight_only: (torch.uint4,),
    }

    with warnings.catch_warnings(record=True) as _warnings:
        # Call each deprecated API twice
        for cls, args in deprecated_apis_to_args.items():
            cls(*args)
            cls(*args)

        # Each call should trigger the warning only once
        self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
        for w in _warnings:
            self.assertIn(
                "is deprecated and will be removed in a future release",
                str(w.message),
            )
758811

759812
common_utils.instantiate_parametrized_tests(TestQuantFlow)
760813

torchao/quantization/quant_api.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
to_weight_tensor_with_linear_activation_quantization_metadata,
9393
)
9494
from torchao.utils import (
95+
_ConfigDeprecationWrapper,
9596
_is_fbgemm_genai_gpu_available,
9697
is_MI300,
9798
is_sm_at_least_89,
@@ -639,7 +640,9 @@ def __post_init__(self):
639640

640641

641642
# for BC
642-
int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
643+
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
644+
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
645+
)
643646

644647

645648
@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
@@ -972,7 +975,9 @@ def __post_init__(self):
972975

973976

974977
# for bc
975-
int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
978+
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
979+
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
980+
)
976981

977982

978983
@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
@@ -1033,7 +1038,9 @@ def __post_init__(self):
10331038

10341039

10351040
# for BC
1036-
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
1041+
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
1042+
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
1043+
)
10371044

10381045

10391046
@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
@@ -1115,7 +1122,7 @@ def __post_init__(self):
11151122

11161123
# for BC
11171124
# TODO maybe change other callsites
1118-
int4_weight_only = Int4WeightOnlyConfig
1125+
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
11191126

11201127

11211128
def _int4_weight_only_quantize_tensor(weight, config):
@@ -1325,7 +1332,7 @@ def __post_init__(self):
13251332

13261333

13271334
# for BC
1328-
int8_weight_only = Int8WeightOnlyConfig
1335+
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
13291336

13301337

13311338
def _int8_weight_only_quantize_tensor(weight, config):
@@ -1486,7 +1493,9 @@ def __post_init__(self):
14861493

14871494

14881495
# for BC
1489-
int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
1496+
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
1497+
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
1498+
)
14901499

14911500

14921501
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -1595,7 +1604,9 @@ def __post_init__(self):
15951604

15961605

15971606
# for BC
1598-
float8_weight_only = Float8WeightOnlyConfig
1607+
float8_weight_only = _ConfigDeprecationWrapper(
1608+
"float8_weight_only", Float8WeightOnlyConfig
1609+
)
15991610

16001611

16011612
def _float8_weight_only_quant_tensor(weight, config):
@@ -1753,7 +1764,9 @@ def __post_init__(self):
17531764

17541765

17551766
# for bc
1756-
float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
1767+
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
1768+
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
1769+
)
17571770

17581771

17591772
def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
@@ -1926,7 +1939,9 @@ def __post_init__(self):
19261939

19271940

19281941
# for bc
1929-
float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
1942+
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
1943+
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
1944+
)
19301945

19311946

19321947
@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
@@ -2009,7 +2024,9 @@ def __post_init__(self):
20092024

20102025

20112026
# for BC
2012-
uintx_weight_only = UIntXWeightOnlyConfig
2027+
uintx_weight_only = _ConfigDeprecationWrapper(
2028+
"uintx_weight_only", UIntXWeightOnlyConfig
2029+
)
20132030

20142031

20152032
@register_quantize_module_handler(UIntXWeightOnlyConfig)
@@ -2262,7 +2279,7 @@ def __post_init__(self):
22622279

22632280

22642281
# for BC
2265-
fpx_weight_only = FPXWeightOnlyConfig
2282+
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
22662283

22672284

22682285
@register_quantize_module_handler(FPXWeightOnlyConfig)

torchao/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import reduce
1313
from importlib.metadata import version
1414
from math import gcd
15-
from typing import Any, Callable, Optional
15+
from typing import Any, Callable, Optional, Type
1616

1717
import torch
1818
import torch.nn.utils.parametrize as parametrize
@@ -433,6 +433,25 @@ def __eq__(self, other):
433433
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
434434

435435

436+
class _ConfigDeprecationWrapper:
    """
    A deprecation wrapper that directs users from a deprecated "config function"
    (e.g. `int4_weight_only`) to the replacement config class.

    Calling the wrapper emits a deprecation warning, then constructs and
    returns an instance of the replacement config class, forwarding all
    positional and keyword arguments unchanged.
    """

    def __init__(self, deprecated_name: str, config_cls: Type):
        # Name of the deprecated callable, used verbatim in the warning text.
        self.deprecated_name = deprecated_name
        # Replacement `AOBaseConfig` subclass that actually gets instantiated.
        self.config_cls = config_cls

    def __call__(self, *args, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line rather
        # than this wrapper, so users can find the deprecated call site.
        warnings.warn(
            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
            f"    quantize_(model, {self.config_cls.__name__}(...))",
            stacklevel=2,
        )
        return self.config_cls(*args, **kwargs)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}({self.deprecated_name!r}, "
            f"{self.config_cls.__name__})"
        )
454+
436455
"""
437456
Helper function for implementing aten op or torch function dispatch
438457
and dispatching to these implementations.

0 commit comments

Comments
 (0)