@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict


-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)

-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+    # Check for DoRA-enabled LoRAs.
+    if any(
+        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
+        for k in state_dict
+    ):
         if is_peft_version("<", "0.9.0"):
             raise ValueError(
                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
             )

-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"

+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+            diffusers_name = _convert_unet_lora_key(key)

156- if "middle.block" in diffusers_name :
157- diffusers_name = diffusers_name .replace ("middle.block" , "mid_block" )
158- else :
159- diffusers_name = diffusers_name .replace ("mid.block" , "mid_block" )
160- if "output.blocks" in diffusers_name :
161- diffusers_name = diffusers_name .replace ("output.blocks" , "up_blocks" )
162- else :
163- diffusers_name = diffusers_name .replace ("up.blocks" , "up_blocks" )
164-
165- diffusers_name = diffusers_name .replace ("transformer.blocks" , "transformer_blocks" )
166- diffusers_name = diffusers_name .replace ("to.q.lora" , "to_q_lora" )
167- diffusers_name = diffusers_name .replace ("to.k.lora" , "to_k_lora" )
168- diffusers_name = diffusers_name .replace ("to.v.lora" , "to_v_lora" )
169- diffusers_name = diffusers_name .replace ("to.out.0.lora" , "to_out_lora" )
170- diffusers_name = diffusers_name .replace ("proj.in" , "proj_in" )
171- diffusers_name = diffusers_name .replace ("proj.out" , "proj_out" )
172- diffusers_name = diffusers_name .replace ("emb.layers" , "time_emb_proj" )
173-
174- # SDXL specificity.
175- if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name :
176- pattern = r"\.\d+(?=\D*$)"
177- diffusers_name = re .sub (pattern , "" , diffusers_name , count = 1 )
178- if ".in." in diffusers_name :
179- diffusers_name = diffusers_name .replace ("in.layers.2" , "conv1" )
180- if ".out." in diffusers_name :
181- diffusers_name = diffusers_name .replace ("out.layers.3" , "conv2" )
182- if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name :
183- diffusers_name = diffusers_name .replace ("op" , "conv" )
184- if "skip" in diffusers_name :
185- diffusers_name = diffusers_name .replace ("skip.connection" , "conv_shortcut" )
186-
187- # LyCORIS specificity.
188- if "time.emb.proj" in diffusers_name :
189- diffusers_name = diffusers_name .replace ("time.emb.proj" , "time_emb_proj" )
190- if "conv.shortcut" in diffusers_name :
191- diffusers_name = diffusers_name .replace ("conv.shortcut" , "conv_shortcut" )
192-
193- # General coverage.
194- if "transformer_blocks" in diffusers_name :
195- if "attn1" in diffusers_name or "attn2" in diffusers_name :
196- diffusers_name = diffusers_name .replace ("attn1" , "attn1.processor" )
197- diffusers_name = diffusers_name .replace ("attn2" , "attn2.processor" )
198- unet_state_dict [diffusers_name ] = state_dict .pop (key )
199- unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
200- elif "ff" in diffusers_name :
201- unet_state_dict [diffusers_name ] = state_dict .pop (key )
202- unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
203- elif any (key in diffusers_name for key in ("proj_in" , "proj_out" )):
204- unet_state_dict [diffusers_name ] = state_dict .pop (key )
205- unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
206- else :
207- unet_state_dict [diffusers_name ] = state_dict .pop (key )
208- unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
171+ # Store down and up weights.
172+ unet_state_dict [diffusers_name ] = state_dict .pop (key )
173+ unet_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
209174
-            if is_unet_dora_lora:
+            # Store DoRA scale if present. Note: membership must be tested against the
+            # per-layer "<lora_name>.dora_scale" key, not the literal string "dora_scale".
+            if key.replace("lora_down.weight", "dora_scale") in state_dict:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                 unet_state_dict[
                     diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

+        # Handle text encoder LoRAs.
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
             if lora_name.startswith(("lora_te_", "lora_te1_")):
-                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+                te_state_dict[diffusers_name] = state_dict.pop(key)
+                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             else:
-                key_to_replace = "lora_te2_"
-
-            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
-
-            if "self_attn" in diffusers_name:
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            # OneTrainer specificity
-            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

-            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            # Store DoRA scale if present (same per-layer membership test as the U-Net branch).
+            if key.replace("lora_down.weight", "dora_scale") in state_dict:
                 dora_scale_key_to_replace_te = (
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                     diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

-        # Rename the alphas so that they can be mapped appropriately.
+        # Store alpha if present.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))

+    # Check if any keys remain.
     if len(state_dict) > 0:
         raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")

     logger.info("Kohya-style checkpoint detected.")
+
+    # Construct final state dict.
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
     te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
     te2_state_dict = (
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_

     new_state_dict = {**unet_state_dict, **te_state_dict}
     return new_state_dict, network_alphas
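As a quick, informal check of the new entry point, the sketch below runs it on a toy checkpoint. The keys are hypothetical Kohya-style names invented for illustration, and the snippet assumes the surrounding module (with its `re`, `logger`, and `is_peft_version` globals) is importable; it is a sketch for reviewers, not part of this diff.

import torch

# A single hypothetical U-Net LoRA pair plus its alpha, in Kohya naming.
base = "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj"
toy_state_dict = {
    f"{base}.lora_down.weight": torch.zeros(4, 320),
    f"{base}.lora_up.weight": torch.zeros(320, 4),
    f"{base}.alpha": torch.tensor(4.0),
}
converted, alphas = _convert_non_diffusers_lora_to_diffusers(toy_state_dict)
# The weights come back namespaced under "unet." with Diffusers-style dotted names:
#   unet.up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.lora.down.weight
#   unet.up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.lora.up.weight
# and the alpha is renamed to match:
#   {"unet.up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.alpha": 4.0}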
+
+
+def _convert_unet_lora_key(key):
+    """
+    Converts a U-Net LoRA key to a Diffusers compatible key.
+    """
+    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+    # Replace common U-Net naming patterns.
+    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+    # SDXL specific conversions.
+    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+        pattern = r"\.\d+(?=\D*$)"
+        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+    if ".in." in diffusers_name:
+        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+    if ".out." in diffusers_name:
+        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+        diffusers_name = diffusers_name.replace("op", "conv")
+    if "skip" in diffusers_name:
+        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+    # LyCORIS specific conversions.
+    if "time.emb.proj" in diffusers_name:
+        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+    if "conv.shortcut" in diffusers_name:
+        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+    # General conversions. The remaining key families ("ff", "proj_in"/"proj_out", etc.)
+    # need no further renaming, so only the attention processor infix is added here.
+    if "transformer_blocks" in diffusers_name and ("attn1" in diffusers_name or "attn2" in diffusers_name):
+        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+
+    return diffusers_name
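A small sanity check of the mapping above on a hypothetical Kohya-style attention key (illustrative only, not part of this diff):

key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
# "down.blocks" becomes "down_blocks", "to.q.lora" becomes "to_q_lora",
# and attention layers gain the ".processor" infix.
assert _convert_unet_lora_key(key) == (
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
)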
+
+
+def _convert_text_encoder_lora_key(key, lora_name):
+    """
+    Converts a text encoder LoRA key to a Diffusers compatible key.
+    """
+    if lora_name.startswith(("lora_te_", "lora_te1_")):
+        key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+    else:
+        key_to_replace = "lora_te2_"
+
+    diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+    diffusers_name = diffusers_name.replace("text.model", "text_model")
+    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
+    if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+        pass
+    elif "mlp" in diffusers_name:
+        # Be aware that this is the new diffusers convention and the rest of the code might
+        # not utilize it yet.
+        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+    return diffusers_name
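The same kind of check for the text encoder helper, again with hypothetical keys; note that only the MLP path picks up the newer `.lora_linear_layer.` convention:

attn_key = "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
assert _convert_text_encoder_lora_key(attn_key, attn_key.split(".")[0]) == (
    "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"
)

mlp_key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
assert _convert_text_encoder_lora_key(mlp_key, mlp_key.split(".")[0]) == (
    "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight"
)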
+
+
+def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
+    """
+    Gets the correct alpha name for the Diffusers model.
+    """
+    if lora_name_alpha.startswith("lora_unet_"):
+        prefix = "unet."
+    elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+        prefix = "text_encoder."
+    else:
+        prefix = "text_encoder_2."
+    new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+    return {new_name: alpha}
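Finally, the alpha renaming, continuing the hypothetical U-Net example used above (a sketch, assuming the helper is in scope):

diffusers_name = "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.lora.down.weight"
alpha_entry = _get_alpha_name(
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.alpha", diffusers_name, 4.0
)
# The ".lora.down.weight" suffix is split off and ".alpha" is appended under the "unet." prefix.
assert alpha_entry == {"unet.up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.alpha": 4.0}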