     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import TorchDtype
+from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel
 
@@ -129,16 +129,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = get_head_dim(state.model.config)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -223,24 +224,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
-        if self.transform_block_size:
-            if head_dim % self.transform_block_size != 0:
-                raise ValueError(
-                    f"transform_block_size {self.transform_block_size} must be set "
-                    f"such that model's head_dim {head_dim} is evenly divisible by it"
-                )
-            head_dim = self.transform_block_size
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -257,9 +241,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
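
Note: get_head_dim replaces the head-dim inference that _create_r2_scheme previously did inline. A minimal sketch of that behavior, assuming the helper mirrors the removed logic (the actual implementation in compressed_tensors.utils may differ, and where the removed transform_block_size override now lives is not shown in this diff):

# Sketch only: approximates get_head_dim from the inline logic removed above;
# the real helper in compressed_tensors.utils may behave differently.
from transformers import PretrainedConfig

def get_head_dim(config: PretrainedConfig) -> int:
    # Prefer an explicit head_dim attribute when the config defines one
    if hasattr(config, "head_dim"):
        return config.head_dim
    # Otherwise derive it from hidden size and attention head count
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    raise NotImplementedError()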
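
With _create_r3_scheme implemented, requesting the R3 rotation no longer raises NotImplementedError. A hedged usage sketch; the import path and string-to-enum coercion for rotations are assumptions based on llm-compressor's examples, not shown in this diff:

from llmcompressor.modifiers.transform import SpinQuantModifier

# R3 applies head_dim-sized rotations at the attention query ("q_attn")
# and key-cache ("k_cache") locations, per the scheme added above.
modifier = SpinQuantModifier(rotations=["R1", "R2", "R3"])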