 import os
 import weakref
+from collections.abc import Generator
 from functools import wraps
 from typing import Optional
 
@@ -126,8 +127,15 @@ def untie_word_embeddings(model: PreTrainedModel):
 
     :param model: model to fix
     """
-    input_embed = model.get_input_embeddings()
-    output_embed = model.get_output_embeddings()
+    try:
+        input_embed = model.get_input_embeddings()
+        output_embed = model.get_output_embeddings()
+    except NotImplementedError as e:
+        logger.warning(
+            f"cannot untie model of type {model.__class__} which doesn't have "
+            f"get_input_embeddings and get_output_embeddings implemented\n{e}"
+        )
+        return
 
     for module in (input_embed, output_embed):
         if module is None or not hasattr(module, "weight"):
@@ -149,6 +157,80 @@ def untie_word_embeddings(model: PreTrainedModel):
     model.config.tie_word_embeddings = False
 
 
+def _get_embeddings_or_warn(
+    model: torch.nn.Module,
+) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
+    if not (
+        hasattr(model, "get_input_embeddings")
+        and hasattr(model, "get_output_embeddings")
+    ):
+        logger.warning(
+            f"{model.__class__} doesn't have get_input_embeddings and"
+            " get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+
+    try:
+        input_embeddings, output_embeddings = (
+            model.get_input_embeddings(),
+            model.get_output_embeddings(),
+        )
+    except NotImplementedError as e:
+        logger.warning(
+            f"{model.__class__} doesn't have get_input_embeddings and "
+            "get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+            f"\n{e}"
+        )
+        return None, None
+
+    if not (
+        isinstance(input_embeddings, torch.nn.Module)
+        and isinstance(output_embeddings, torch.nn.Module)
+    ):
+        logger.warning(
+            f"expected modules from {model.__class__} get_input_embeddings and"
+            f" get_output_embeddings but got {type(input_embeddings)}"
+            f" and {type(output_embeddings)}."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+    return input_embeddings, output_embeddings
+
+
+def untie_if_target_shared_embedding(
+    model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module]
+):
+    """
+    Helper that checks whether the input and output embeddings are shared and
+    unties them if either one shows up in the matched_module_generator
+
+    :param model: model to untie if embeddings are shared and targeted by
+        matched_module_generator
+    :param matched_module_generator: generator of all modules (not names) which
+        will be modified by quantization or transformation
+    """
+    input_embeddings, output_embeddings = _get_embeddings_or_warn(model)
+
+    if None in (input_embeddings, output_embeddings):  # couldn't find embeddings
+        return
+
+    if (
+        input_embeddings.weight is not output_embeddings.weight
+    ):  # not shared, nothing to do
+        return
+
+    # if shared, check whether either embedding is targeted
+    for module in matched_module_generator:
+        if module in (input_embeddings, output_embeddings):
+            untie_word_embeddings(model)
+            return
+
+
 def get_model_compressor(
     model: torch.nn.Module,
     sparsity_config: Optional[SparsityCompressionConfig] = None,
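
For context, a minimal self-contained sketch (not part of this diff) of the weight-identity and targeting checks that untie_if_target_shared_embedding builds on; TinyTiedLM and the sizes below are made up for illustration, and only torch is assumed:

import torch


class TinyTiedLM(torch.nn.Module):
    """Toy model whose input and output embeddings share one weight tensor."""

    def __init__(self, vocab_size: int = 8, hidden: int = 4):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # tie weights

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.embed

    def get_output_embeddings(self) -> torch.nn.Module:
        return self.lm_head


model = TinyTiedLM()
input_embed = model.get_input_embeddings()
output_embed = model.get_output_embeddings()

# Tied embeddings share the same tensor object, so identity (`is`) detects sharing
assert input_embed.weight is output_embed.weight

# A generator of "matched" modules, as quantization/transform matching would yield;
# here only lm_head is targeted, which is already enough to require untying
matched = (m for m in [model.get_output_embeddings()])
needs_untie = any(m in (input_embed, output_embed) for m in matched)
print(needs_untie)  # True -> untie_word_embeddings(model) would be called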