 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
 import torch
 from compressed_tensors.quantization import (
-    QuantizationConfig,
-    QuantizationStatus,
-    apply_quantization_config,
+    QuantizationArgs,
+    QuantizationScheme,
+    initialize_module_for_quantization,
 )
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from llmcompressor.modifiers.quantization.calibration import (
-    calibrate_input_hook,
-    initialize_observer,
-)
-from llmcompressor.observers.helpers import get_observer_token_count
-
-
-def _prep_for_input_quant_calibration(module: torch.nn.Module):
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        return
-
-    module.register_forward_pre_hook(calibrate_input_hook)
-    module.quantization_status = QuantizationStatus.CALIBRATION
 
+from llmcompressor.observers.helpers import flatten_for_calibration
 
-def test_get_observer_token_count():
-    model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    model.eval()
-    config = QuantizationConfig(
-        format="fakequant",
-        quantization_status="calibration",
-        config_groups={
-            "group_1": {
-                "input_activations": {
-                    "num_bits": 8,
-                    "type": "int",
-                    "symmetric": False,
-                    "strategy": "tensor",
-                },
-                "targets": ["Linear"],
-            },
-        },
-    )
-    apply_quantization_config(model, config)
-    model.apply(lambda module: initialize_observer(module, base_name="input"))
-    model.apply(_prep_for_input_quant_calibration)
-
-    # start calibration
-    calib_list = [
-        "I am a string that",
-        "is used for calibration so",
-        "that your model is",
-        "quantized properly.",
-    ]
 
-    total_num_tokens_observed = 0
-    for calib_sample in calib_list:
-        calib_tensor = tokenizer(calib_sample, return_tensors="pt")
-        _ = model(**calib_tensor)
-        total_num_tokens_observed += len(calib_tensor.input_ids.flatten())
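+# Map each column to its group index, then shuffle so the tests also cover a
+# non-trivial (permuted) g_idx ordering.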
+def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
+    perm = torch.randperm(columns)
+    return torch.tensor([index // group_size for index in range(columns)])[perm]
 
-    counter = get_observer_token_count(model)
 
-    # filter out the None values
-    # (tokens, in the appropriate format, that were not observed by the model)
-    counter = {k: v for k, v in counter.items() if v is not None}
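+# Flattening inputs: cover per-tensor and tensor_group activation strategies.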
+@pytest.mark.parametrize(
+    "args",
+    [
+        QuantizationArgs(strategy="tensor"),
+        QuantizationArgs(strategy="tensor_group", group_size=4),
+    ],
+)
+def test_flatten_for_calibration_input(args):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # iterate over all the layers in the model where the token count in the proper
-    # format is has been observed
-    for i in range(model.config.num_hidden_layers):
-        # fetch the tokens observed by the router
-        tokens_observed_by_router = counter.pop(
-            f"model.layers.{i}.block_sparse_moe.gate"
-        )
-        assert tokens_observed_by_router == total_num_tokens_observed
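+    # dims 1:-1 of the flattened tensor must line up with the shapes of the
+    # initialized scale / zero-point parameters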
+    input = torch.empty((3, 5, 8))
+    input_flattened = flatten_for_calibration(input, "input", scheme.input_activations)
+    assert input_flattened.shape[1:-1] == module.input_scale.shape
+    assert input_flattened.shape[1:-1] == module.input_zero_point.shape
 
-        # fetch the sum of tokens observed by all the experts
-        sum_tokens_observed_by_experts = 0
-        keys_for_this_layer = [
-            k
-            for k in counter.keys()
-            if f"model.layers.{i}.block_sparse_moe.experts" in k
-        ]
-        for key in keys_for_this_layer:
-            sum_tokens_observed_by_experts += counter.pop(key)
 
-        # each Mixtral expert is comprised of 3 linear layers,
-        # so we need to multiply by 3
-        assert (
-            sum_tokens_observed_by_experts
-            == total_num_tokens_observed * model.config.num_experts_per_tok * 3
-        )
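+# Flattening weights: cover tensor, channel, group (with and without a
+# permuted g_idx), tensor_group, and block strategies.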
+@pytest.mark.parametrize(
+    "args,g_idx",
+    [
+        (QuantizationArgs(strategy="tensor"), None),
+        (QuantizationArgs(strategy="channel"), None),
+        (QuantizationArgs(strategy="group", group_size=4), None),
+        (QuantizationArgs(strategy="group", group_size=4), make_dummy_g_idx(8, 4)),
+        (QuantizationArgs(strategy="tensor_group", group_size=4), None),
+        (QuantizationArgs(strategy="block", block_structure=[5, 4]), None),
+    ],
+)
+def test_flatten_for_calibration_weights(args, g_idx):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # there are no more information in the counter
-    assert len(counter) == 0
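+    # as with inputs, the flattened middle dims must match the qparam shapes,
+    # including when columns are reordered by g_idx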
+    weight_flattened = flatten_for_calibration(
+        module.weight,
+        "weight",
+        scheme.weights,
+        g_idx=g_idx,
+    )
+    assert weight_flattened.shape[1:-1] == module.weight_scale.shape
+    assert weight_flattened.shape[1:-1] == module.weight_zero_point.shape