Commit b5c3db4

standardize get_fused_names
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent beff73b commit b5c3db4

4 files changed: +101, -116 lines


src/llmcompressor/entrypoints/model_free/__init__.py (17 additions, 9 deletions)
@@ -11,6 +11,7 @@
 from compressed_tensors.utils.match import _match_name
 from loguru import logger
 from safetensors.torch import load_file, save_file
+from torch.nn import Module
 
 from llmcompressor.entrypoints.model_free.helpers import gpu_if_available
 from llmcompressor.entrypoints.model_free.lifecycle import (
@@ -169,11 +170,18 @@ def _process_file_microscale_scheme(
     """
     assert is_microscale_scheme(scheme), "Use `_process_file` for non microscale scheme"
     tensors = load_file(file_path)
-    fused_names = get_fused_names(tensors)
-    fused_names_to_parent = {
-        name: prefix for prefix, names in fused_names.items() for name in names
+    fused_sets, unmatched_sets = get_fused_names(tensors)
+    assert len(unmatched_sets) <= 0  # should be caught by `validate_safetensors_index`
+
+    fused_name_to_fused_index: dict[str, int]  # fused_name -> fused_index
+    fused_modules: dict[int, dict[str, Module]]  # fused_index -> named_modules
+
+    fused_name_to_fused_index = {
+        name: index
+        for index, matched_set in enumerate(fused_sets)
+        for name in matched_set.values()
     }
-    fused_parent_submodules = defaultdict(dict)
+    fused_modules = defaultdict(dict)
 
     for name in list(tensors.keys()):
         module_name, param_name = name.rsplit(".", 1)
@@ -187,9 +195,9 @@ def _process_file_microscale_scheme(
 
         # 2. calibrate weight qparams. Delay scale/zp calibration for fused modules
         calibrate_global_scale(module)
-        if name in fused_names_to_parent:
-            fused_parent = fused_names_to_parent[name]
-            fused_parent_submodules[fused_parent][name] = module
+        if name in fused_name_to_fused_index:
+            fused_index = fused_name_to_fused_index[name]
+            fused_modules[fused_index][name] = module
             continue
 
         calibrate_scale_zp(module)
@@ -204,7 +212,7 @@ def _process_file_microscale_scheme(
         tensors[key] = value.to("cpu")
 
     # compress and save microscale fused modules
-    for parent_name, named_modules in fused_parent_submodules.items():
+    for named_modules in fused_modules.values():
         # 2.1. fuse global scales
         global_scales = [m.weight_global_scale for m in named_modules.values()]
         fused_global_scale = torch.min(torch.cat(global_scales, dim=0))
@@ -216,7 +224,7 @@ def _process_file_microscale_scheme(
             # 2.2. finish calibration with fused global scales
             calibrate_scale_zp(module)
 
-            # 3. compress module using qparams
+            # 3. compress module using microscale qparams
             compress_module(module)
 
             # 4. save compressed data (on cpu)
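
For reference, a minimal standalone sketch of the new index-based grouping that replaces the parent-name keyed dict. The fused-set contents below are placeholders, not values from the commit; the comprehension itself mirrors the one added in _process_file_microscale_scheme:

from collections import defaultdict

# each MatchedNamesSet maps a target pattern to the tensor name it matched
fused_sets = [
    {"q": "l0.attn.q_proj.weight", "k": "l0.attn.k_proj.weight", "v": "l0.attn.v_proj.weight"},
    {"q": "l1.attn.q_proj.weight", "k": "l1.attn.k_proj.weight", "v": "l1.attn.v_proj.weight"},
]

# fused_name -> fused_index, as built in the diff above
fused_name_to_fused_index = {
    name: index
    for index, matched_set in enumerate(fused_sets)
    for name in matched_set.values()
}

# fused_index -> named_modules; modules sharing an index get one fused global scale
fused_modules = defaultdict(dict)
for name, index in fused_name_to_fused_index.items():
    fused_modules[index][name] = object()  # stand-in for the real torch module

assert len(fused_modules[0]) == 3  # layer 0's q/k/v end up grouped together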

src/llmcompressor/entrypoints/model_free/helpers.py (39 additions, 2 deletions)
@@ -3,6 +3,7 @@
 from typing import Mapping, TypeVar
 
 import torch
+from compressed_tensors.utils.match import _match_name
 from loguru import logger
 from transformers.file_utils import CONFIG_NAME
 
@@ -11,9 +12,15 @@
     "find_safetensors_index_path",
     "find_config_path",
     "find_safetensors_index_file",
+    "match_names_set_eager",
+    "MatchedNamesSet",
     "invert_mapping",
 ]
 
+KeyType = TypeVar("K")
+ValueType = TypeVar("V")
+MatchedNamesSet = dict[str, str | None]
+
 
 def gpu_if_available(device: torch.device | str | None) -> torch.device:
     if device is not None:
@@ -54,8 +61,38 @@ def find_safetensors_index_file(model_files: dict[str, str]) -> str | None:
     return None
 
 
-KeyType = TypeVar("K")
-ValueType = TypeVar("V")
+def match_names_set_eager(
+    names: set[str] | list[str],
+    targets: set[str] | list[str],
+    return_unmatched: bool = True,
+) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
+    matched_sets = []
+    matches = dict.fromkeys(targets, None)
+
+    for name in names:
+        # match until we get a full set
+        for target in targets:
+            if _match_name(name, target):
+                if matches[target] is None:
+                    matches[target] = name
+                else:
+                    # matched target twice without completing a set
+                    raise ValueError(
+                        f"Matched a {target} twice before "
+                        f"completing set ({matches[target]}, {name})"
+                    )
+
+        # once we have a full set, yield and reset
+        if all((matches[target] is not None for target in targets)):
+            matched_sets.append(matches)
+            matches = dict.fromkeys(targets, None)
+
+    unmatched_set = matches if any((v is not None for v in matches.values())) else None
+
+    if return_unmatched:
+        return matched_sets, unmatched_set
+    else:
+        return matched_sets
 
 
 def invert_mapping(
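
A hedged usage sketch of the relocated helper, assuming this commit's llmcompressor is installed; the tensor names and targets below are illustrative, not from the commit:

from llmcompressor.entrypoints.model_free.helpers import match_names_set_eager

qkv_targets = [
    "re:.*(attn|attention)\.q_proj\.weight$",
    "re:.*(attn|attention)\.k_proj\.weight$",
    "re:.*(attn|attention)\.v_proj\.weight$",
]
names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",  # this set never completes
]

matched_sets, unmatched_set = match_names_set_eager(names, qkv_targets)
# matched_sets holds one complete target -> name set for layer 0;
# unmatched_set holds layer 1's partial set (k and v targets still None).
# Matching is eager: a second q_proj before k/v complete would raise ValueError.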
src/llmcompressor/entrypoints/model_free/microscale.py (29 additions, 73 deletions)
@@ -1,87 +1,43 @@
-import torch
 from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
-from compressed_tensors.utils.match import _match_name
 
-__all__ = ["get_fused_names", "is_microscale_scheme", "match_names_set_eager"]
+from llmcompressor.entrypoints.model_free.helpers import (
+    MatchedNamesSet,
+    match_names_set_eager,
+)
 
+__all__ = ["is_microscale_scheme", "get_fused_names", "DEFAULT_FUSED_MAPPINGS"]
 
-MatchedNamesSet = dict[str, str | None]
+
+DEFAULT_FUSED_MAPPINGS = [
+    [
+        "re:.*(attn|attention)\.q_proj\.weight$",
+        "re:.*(attn|attention)\.k_proj\.weight$",
+        "re:.*(attn|attention)\.v_proj\.weight$",
+    ],
+    [
+        "re:.*(attn|attention)\.wq_a\.weight$",
+        "re:.*(attn|attention)\.wkv_a_with_mqa\.weight$",
+    ],
+    ["re:.*mlp\.gate_proj\.weight$", "re:.*attn\.up_proj\.weight$"],
+    ["re:.*w1\.weight$", "re:.*w3\.weight$"],
+]
 
 
 def is_microscale_scheme(scheme: QuantizationScheme) -> bool:
     assert scheme.weights is not None
     return scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP
 
 
-def match_names_set_eager(
+def get_fused_names(
     tensor_names: set[str] | list[str],
-    targets: set[str] | list[str],
-    return_unmatched: bool = True,
-) -> list[MatchedNamesSet] | tuple[list[MatchedNamesSet], MatchedNamesSet]:
-    matched_sets = []
-    matches = dict.fromkeys(targets, None)
-
-    for name in tensor_names:
-        # match until we get a full set
-        for target in targets:
-            if _match_name(name, target):
-                if matches[target] is None:
-                    matches[target] = name
-                else:
-                    # matched target twice without completing a set
-                    raise ValueError(
-                        f"Matched a {target} twice before "
-                        f"completing set ({matches[target]}, {name})"
-                    )
-
-        # once we have a full set, yield and reset
-        if all((matches[target] is not None for target in targets)):
-            matched_sets.append(matches)
-            matches = dict.fromkeys(targets, None)
-
-    unmatched_set = matches if any((v is not None for v in matches.values())) else None
-
-    if return_unmatched:
-        return matched_sets, unmatched_set
-    else:
-        return matched_sets
-
-
-def get_fused_names(tensors: dict[str, torch.Tensor]) -> dict[str, list[str]]:
-    fused_names = {}
-
-    for name in tensors:
-        parts = name.rsplit(".")
-        if len(parts) < 3:
-            continue
-
-        parent, module, param = parts[-3:]
-
-        if (
-            ("attn" in parent or "attention" in parent)
-            and module == "q_proj"
-            and param == "weight"
-        ):
-            parent_name = ".".join((*parts[:-3], parent))
-            q_name = ".".join((parent_name, "q_proj", param))
-            k_name = ".".join((parent_name, "k_proj", param))
-            v_name = ".".join((parent_name, "v_proj", param))
-
-            submodule_names = [q_name, k_name, v_name]
-
-            if all(name in tensors for name in submodule_names):
-                assert parent_name not in fused_names
-                fused_names[parent_name] = submodule_names
-
-        if "mlp" in parent and module == "gate_proj" and param == "weight":
-            parent_name = ".".join((*parts[:-3], parent))
-            gate_name = ".".join((parent_name, "gate_proj", param))
-            up_name = ".".join((parent_name, "up_proj", param))
-
-            submodule_names = [gate_name, up_name]
+) -> tuple[list[MatchedNamesSet], list[MatchedNamesSet]]:
+    matched = []
+    unmatched = []
+    for mapping in DEFAULT_FUSED_MAPPINGS:
+        _matched, _unmatched = match_names_set_eager(tensor_names, mapping)
 
-            if all(name in tensors for name in submodule_names):
-                assert parent_name not in fused_names
-                fused_names[parent_name] = submodule_names
+        matched.extend(_matched)
+        if _unmatched is not None:
+            unmatched.append(_unmatched)
 
-    return fused_names
+    return matched, unmatched
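
With the default mappings above, a sketch of the standardized interface; names are illustrative and assume this commit's llmcompressor is installed:

from llmcompressor.entrypoints.model_free.microscale import get_fused_names

tensor_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",  # gate with no matching up_proj target
]

matched, unmatched = get_fused_names(tensor_names)
# matched:   one complete q/k/v MatchedNamesSet for layer 0
# unmatched: one partial gate/up set, for callers to report or carry over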

src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py (16 additions, 32 deletions)
@@ -13,7 +13,10 @@
     find_safetensors_index_file,
     invert_mapping,
 )
-from llmcompressor.entrypoints.model_free.microscale import match_names_set_eager
+from llmcompressor.entrypoints.model_free.microscale import (
+    DEFAULT_FUSED_MAPPINGS,
+    get_fused_names,
+)
 from llmcompressor.entrypoints.model_free.model_utils import (
     get_checkpoint_files,
     is_weights_file,
@@ -25,23 +28,6 @@
 # 1. the incomplete set is the last set of weights (sorted alphabetically)
 # 2. the remainder of the incomplete set is the next file (sorted alphabetically)
 
-model_stub = ""
-fused_mappings: list[list[str]] = []
-
-DEFAULT_FUSED_MAPPINGS = [
-    [
-        "re:.*(attn|attention)\.q_proj\.weight$",
-        "re:.*(attn|attention)\.k_proj\.weight$",
-        "re:.*(attn|attention)\.v_proj\.weight$",
-    ],
-    [
-        "re:.*(attn|attention)\.wq_a\.weight$",
-        "re:.*(attn|attention)\.wkv_a_with_mqa\.weight$",
-    ],
-    ["re:.*mlp\.gate_proj\.weight$", "re:.*attn\.up_proj\.weight$"],
-    ["re:.*w1\.weight$", "re:.*w3\.weight$"],
-]
-
 
 def main(
     model_stub: str,
@@ -96,20 +82,18 @@ def main(
     carry_over_tensors = {}
 
     tensor_names = sorted(list(tensors.keys()))
-    for mapping in fused_mappings:
-        _matches, unmatched = match_names_set_eager(tensor_names, mapping)
-
-        if unmatched is not None:
-            # move to carry over
-            unmatched_tensors = {
-                key: tensors[key] for key in unmatched.values() if key is not None
-            }
-            carry_over_tensors.update(unmatched_tensors)
-
-            # delete from current file
-            for key in unmatched_tensors:
-                tensor_names.remove(key)
-                del tensors[key]
+    _matches, unmatched_sets = get_fused_names(tensor_names)
+    for unmatched in unmatched_sets:
+        # move to carry over
+        unmatched_tensors = {
+            key: tensors[key] for key in unmatched.values() if key is not None
+        }
+        carry_over_tensors.update(unmatched_tensors)
+
+        # delete from current file
+        for key in unmatched_tensors:
+            tensor_names.remove(key)
+            del tensors[key]
 
     # save tensors after modification
     executor.submit(with_progress(save_file, tensors, save_path, progress=progress))
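
A toy sketch of the carry-over step with fabricated shard contents (strings stand in for tensors; the real loop gets unmatched_sets from get_fused_names):

# a shard whose q/k weights arrive without their v counterpart
tensors = {"l9.attn.q_proj.weight": "Q", "l9.attn.k_proj.weight": "K"}
tensor_names = sorted(tensors)
carry_over_tensors = {}

unmatched_sets = [{"q": "l9.attn.q_proj.weight", "k": "l9.attn.k_proj.weight", "v": None}]

for unmatched in unmatched_sets:
    # move to carry over, skipping targets that never matched
    unmatched_tensors = {k: tensors[k] for k in unmatched.values() if k is not None}
    carry_over_tensors.update(unmatched_tensors)
    for key in unmatched_tensors:
        tensor_names.remove(key)
        del tensors[key]

assert not tensors  # the partial set leaves this file and is re-saved with the next shard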
