
Commit dd4b942

Merge branch 'main' into fix_qwen3
2 parents: 13ac6ee + c254c19

10 files changed: +422 −35 lines


src/llmcompressor/args/model_arguments.py

Lines changed: 3 additions & 2 deletions
@@ -64,10 +64,11 @@ class ModelArguments:
     )

     tie_word_embeddings: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "Whether the model's input and output word embeddings "
-            "should be tied. Note that this is only relevant if the "
+            "should be left tied if possible. False means always untie. "
+            "Note that this is only relevant if the "
             "model has an output word embedding layer."
         },
     )

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 2 additions & 2 deletions
@@ -233,7 +233,7 @@ def oneshot(
     processor: Optional[Union[str, ProcessorMixin]] = None,
     use_auth_token: bool = False,
     precision: str = "auto",
-    tie_word_embeddings: bool = False,
+    tie_word_embeddings: bool = True,
     trust_remote_code_model: bool = False,
     save_compressed: bool = True,
     model_revision: str = "main",
@@ -282,7 +282,7 @@ def oneshot(
         models.
     :param precision: Precision to cast model weights to, default to auto.
     :param tie_word_embeddings: Whether the model's input and output word embeddings
-        should be tied.
+        should be left tied if possible. False means always untie.
     :param trust_remote_code_model: Whether to allow for custom models to execute
         their own modeling files.
     :param save_compressed: Whether to compress sparse models during save.
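
For orientation, a minimal usage sketch of the new default (hypothetical model id and recipe path; assumes llmcompressor's top-level oneshot export):

from llmcompressor import oneshot

# tie_word_embeddings now defaults to True: embeddings are left tied when
# possible, and modifiers untie them on demand. Passing False still forces
# an unconditional untie.
oneshot(
    model="Qwen/Qwen3-8B",     # hypothetical model id
    recipe="recipe.yaml",      # hypothetical recipe path
    tie_word_embeddings=True,  # the new default, shown explicitly
)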

src/llmcompressor/entrypoints/utils.py

Lines changed: 0 additions & 16 deletions
@@ -59,7 +59,6 @@ def pre_process(
     Raises:
         FileNotFoundError: If the model or processor path is invalid.
     """
-    _warn_tied_embeddings(model_args.tie_word_embeddings)

     # Initialize model
     if isinstance(model_args.model, (str, PosixPath)):
@@ -150,21 +149,6 @@ def post_process(
     reset_session()


-def _warn_tied_embeddings(tie_word_embeddings: bool = False):
-    """
-    Logs a warning if the model has tied word embeddings.
-    The `tie_word_embeddings` flag may cause issues during saving in the one-shot
-    calibration workflow due to shared tensor addresses.
-    """
-    if tie_word_embeddings:
-        logger.debug(
-            "The tie_word_embeddings flag is by default set to False. "
-            "This guarantees that the one-shot algorithm saves the final "
-            "weights without errors. Detected tie_word_embeddings=True. "
-            "This may cause issues with the one-shot algorithm on save."
-        )
-
-
 def initialize_model_from_path(
     model_args: ModelArguments,
     training_args: Optional[TrainingArguments] = None,

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 9 additions & 0 deletions
@@ -34,6 +34,9 @@
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_if_target_shared_embedding,
+)

 __all__ = ["QuantizationMixin"]

@@ -179,6 +182,12 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
+
+        matched_module_generator = (
+            x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore)
+        )
+        untie_if_target_shared_embedding(model, matched_module_generator)
+
         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
             self._calibration_hooks |= self._initialize_hooks(module)
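
The generator expression passes modules lazily; a condensed sketch of the pattern (hypothetical standalone helper, using the same calls as the diff above): untie_if_target_shared_embedding consumes the generator only until it finds the shared embedding, and it runs before any observers or hooks are attached, so calibration starts on already-untied weights.

from compressed_tensors.utils import match_named_modules

from llmcompressor.transformers.compression.compressed_tensors_utils import (
    untie_if_target_shared_embedding,
)

def untie_targeted_embeddings(model, targets, ignore):
    # Hypothetical helper mirroring start_calibration's first step: yield
    # only the matched modules (dropping the names) and untie if the shared
    # input/output embedding is among them.
    matched = (module for _, module in match_named_modules(model, targets, ignore))
    untie_if_target_shared_embedding(model, matched)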

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 14 additions & 1 deletion
@@ -7,11 +7,14 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import TorchDtype
+from compressed_tensors.utils import TorchDtype, match_named_modules
 from pydantic import Field, ValidationInfo, field_validator

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_if_target_shared_embedding,
+)

 __all__ = ["QuIPModifier"]

@@ -100,6 +103,16 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True

+        def matched_module_generator():
+            for scheme in self.transform_config.config_groups.values():
+                for arg in scheme.apply:
+                    gen = match_named_modules(state.model, arg.targets, arg.ignore)
+                    for _, module in gen:
+                        yield module
+
+        # Untie embeddings if they will be targeted by transforms
+        untie_if_target_shared_embedding(state.model, matched_module_generator())
+
         apply_transform_config(state.model, self.transform_config)

     def on_event(self, state: State, event: Event, **kwargs):

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,9 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modeling import center_embeddings, fuse_norm_linears
 from llmcompressor.modifiers import Modifier
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_word_embeddings,
+)

 from .mappings import SpinQuantMapping, infer_mapping_from_model
 from .norm_mappings import NormMapping, infer_norm_mapping_from_model
@@ -148,6 +151,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True

+        # needed any time embeddings/lm_head is modified
+        untie_word_embeddings(state.model)
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
         self._center_embeddings(state.model)

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 84 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 import weakref
+from collections.abc import Generator
 from functools import wraps
 from typing import Optional

@@ -126,8 +127,15 @@ def untie_word_embeddings(model: PreTrainedModel):

     :param model: model to fix
     """
-    input_embed = model.get_input_embeddings()
-    output_embed = model.get_output_embeddings()
+    try:
+        input_embed = model.get_input_embeddings()
+        output_embed = model.get_output_embeddings()
+    except NotImplementedError as e:
+        logger.warning(
+            f"cannot untie model of type {model.__class__} which doesn't have "
+            f"get_input_embeddings and get_output_embeddings implemented\n{e}"
+        )
+        return

     for module in (input_embed, output_embed):
         if module is None or not hasattr(module, "weight"):
@@ -149,6 +157,80 @@ def untie_word_embeddings(model: PreTrainedModel):
     model.config.tie_word_embeddings = False


+def _get_embeddings_or_warn(
+    model: torch.nn.Module,
+) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
+    if not (
+        hasattr(model, "get_input_embeddings")
+        and hasattr(model, "get_output_embeddings")
+    ):
+        logger.warning(
+            f"{model.__class__} doesn't have get_input_embeddings and"
+            " get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+
+    try:
+        input_embeddings, output_embeddings = (
+            model.get_input_embeddings(),
+            model.get_output_embeddings(),
+        )
+    except NotImplementedError as e:
+        logger.warning(
+            f"{model.__class__} doesn't have get_input_embeddings and "
+            "get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+            f"\n{e}"
+        )
+        return None, None
+
+    if not (
+        isinstance(input_embeddings, torch.nn.Module)
+        and isinstance(output_embeddings, torch.nn.Module)
+    ):
+        logger.warning(
+            f"expected modules from {model.__class__} get_input_embeddings and"
+            f" get_output_embeddings but got {type(input_embeddings)}"
+            f" and {type(output_embeddings)}."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+    return input_embeddings, output_embeddings
+
+
+def untie_if_target_shared_embedding(
+    model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module]
+):
+    """
+    Helper method that checks for shared input/output embeddings and unties them
+    if either shows up in the matched_module_generator.
+
+    :param model: model to untie if embeddings are shared and targeted by
+        matched_module_generator
+    :param matched_module_generator: generator of all modules (not names) which
+        will be modified by quantization or transformation
+    """
+    input_embeddings, output_embeddings = _get_embeddings_or_warn(model)
+
+    if None in (input_embeddings, output_embeddings):  # if couldn't find embeddings
+        return
+
+    if (
+        input_embeddings.weight is not output_embeddings.weight
+    ):  # if not shared, can ignore
+        return
+
+    # if shared, check whether either embedding is targeted
+    for module in matched_module_generator:
+        if module in (input_embeddings, output_embeddings):
+            untie_word_embeddings(model)
+            return
+
+
 def get_model_compressor(
     model: torch.nn.Module,
     sparsity_config: Optional[SparsityCompressionConfig] = None,
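
To make the identity check concrete, a minimal self-contained sketch (plain torch with made-up shapes, not the library helper) of what tied embeddings look like and what untying amounts to:

import torch

# Tied setup: the embedding and the LM head share one Parameter object,
# as in checkpoints saved with tie_word_embeddings=True.
embed = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight
assert lm_head.weight is embed.weight  # the "is" check used above

# Untying: give the head its own copy so quantizing or transforming one
# side no longer mutates the other through the shared tensor.
lm_head.weight = torch.nn.Parameter(embed.weight.detach().clone())
assert lm_head.weight is not embed.weight

The real untie_word_embeddings additionally sets model.config.tie_word_embeddings = False, as shown in the hunk above, so the config reflects the untied state.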
