
Commit 9e82674

Merge branch 'main' into 02_moe_e2e

2 parents: 2901390 + 15b29be

File tree: 2 files changed (+25 -24 lines)

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 22 additions & 24 deletions

@@ -9,7 +9,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import TorchDtype
+from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel

@@ -129,16 +129,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = get_head_dim(state.model.config)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -223,24 +224,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
-        if self.transform_block_size:
-            if head_dim % self.transform_block_size != 0:
-                raise ValueError(
-                    f"transform_block_size {self.transform_block_size} must be set "
-                    f"such that model's head_dim {head_dim} is evenly divisible by it"
-                )
-            head_dim = self.transform_block_size
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
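
The removed block also let transform_block_size override head_dim after checking divisibility. A worked illustration of that constraint, with assumed example values:

# Illustration only, with assumed values (e.g. a Llama-style head_dim of 128).
head_dim, transform_block_size = 128, 32
assert head_dim % transform_block_size == 0  # 128 = 4 * 32, so 32 is valid
head_dim = transform_block_size              # rotation then acts on 32-dim blocks
# A transform_block_size of 48 would have raised ValueError (128 % 48 != 0).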
@@ -257,9 +241,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
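
With _create_r3_scheme implemented, R3 rotations can be requested like the others. A minimal usage sketch, assuming the import path and the rotations field exposed by this modifier (the surrounding recipe wiring is not shown in this diff):

from llmcompressor.modifiers.transform import SpinQuantModifier

# R3 rotates the query/key space (the q_attn and k_cache locations above).
modifier = SpinQuantModifier(rotations=["R1", "R3"])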

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,7 @@ class SpinQuantMapping(BaseModel):
     layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
 
     :param embedding: name or regex of embedding layer
+    :param attn: name or regex of attention block in decoder layer
     :param attn_q: name or regex of q_proj layer in attention block
     :param attn_k: name or regex of k_proj layer in attention block
     :param attn_v: name or regex of v_proj layer in attention block
@@ -29,6 +30,7 @@ class SpinQuantMapping(BaseModel):
 
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -50,6 +52,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
