import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
+import comfy.ldm.lumina.controlnet


class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ def _process_layer_features(

        return embedding

+# Convert a diffusers-style Z-Image controlnet state dict to the naming used here:
+# fuse the separate q/k/v projection weights into a single qkv weight and rename the attention keys.
+def z_image_convert(sd):
+    replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+                    ".attention.norm_k.weight": ".attention.k_norm.weight",
+                    ".attention.norm_q.weight": ".attention.q_norm.weight",
+                    ".attention.to_out.0.weight": ".attention.out.weight"
+                    }
+
+    out_sd = {}
+    for k in sorted(sd.keys()):
+        w = sd[k]
+
+        k_out = k
+        if k_out.endswith(".attention.to_k.weight"):
+            cc = [w]
+            continue
+        if k_out.endswith(".attention.to_q.weight"):
+            cc = [w] + cc
+            continue
+        if k_out.endswith(".attention.to_v.weight"):
+            cc = cc + [w]
+            w = torch.cat(cc, dim=0)
+            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+        for r, rr in replace_keys.items():
+            k_out = k_out.replace(r, rr)
+        out_sd[k_out] = w
+
+    return out_sd
+
class ModelPatchLoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ def load_model_patch(self, name):
        elif 'feature_embedder.mid_layer_norm.bias' in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
            model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'control_all_x_embedder.2-1.weight' in sd:  # alipai z image fun controlnet
+            sd = z_image_convert(sd)
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)

        model.load_state_dict(sd)
        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
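For reference, a minimal sketch of what z_image_convert does to a toy state dict, using a hypothetical "layers.0" key prefix and made-up 64-dim weights (illustrative only, not part of the commit):

# Illustrative only: hypothetical key prefix and tensor shapes.
toy_sd = {
    "layers.0.attention.to_q.weight": torch.zeros(64, 64),
    "layers.0.attention.to_k.weight": torch.zeros(64, 64),
    "layers.0.attention.to_v.weight": torch.zeros(64, 64),
    "layers.0.attention.to_out.0.weight": torch.zeros(64, 64),
    "layers.0.attention.norm_q.weight": torch.ones(64),
}
converted = z_image_convert(toy_sd)
# converted["layers.0.attention.qkv.weight"].shape == (192, 64)  (q, k, v stacked on dim 0)
# converted["layers.0.attention.out.weight"] and converted["layers.0.attention.q_norm.weight"]
# are the renamed to_out.0 / norm_q entries.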
@@ -263,6 +296,70 @@ def to(self, device_or_dtype):
    def models(self):
        return [self.model_patch]

+class ZImageControlPatch:
+    def __init__(self, model_patch, vae, image, strength):
+        self.model_patch = model_patch
+        self.vae = vae
+        self.image = image
+        self.strength = strength
+        self.encoded_image = self.encode_latent_cond(image)
+        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.temp_data = None
+
+    def encode_latent_cond(self, image):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+        return latent_image
+
+    def __call__(self, kwargs):
+        x = kwargs.get("x")
+        img = kwargs.get("img")
+        txt = kwargs.get("txt")
+        pe = kwargs.get("pe")
+        vec = kwargs.get("vec")
+        block_index = kwargs.get("block_index")
+        spacial_compression = self.vae.spacial_compression_encode()
+        # re-encode the reference image if the latent resolution no longer matches x
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            comfy.model_management.load_models_gpu(loaded_models)
+            print("encode")
+
+        cnet_index = (block_index // 5)
+        cnet_index_float = (block_index / 5)
+
+        kwargs.pop("img")  # we do ops in place
+        kwargs.pop("txt")
+
+        cnet_blocks = self.model_patch.model.n_control_layers
+        if cnet_index_float > (cnet_blocks - 1):
+            self.temp_data = None
+            return kwargs
+
+        # temp_data caches (control layer index, (control residual, control hidden state))
+        # so each control layer only runs once per pass over the transformer blocks
+        if self.temp_data is None or self.temp_data[0] > cnet_index:
+            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+            next_layer = self.temp_data[0] + 1
+            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+        if cnet_index_float == self.temp_data[0]:
+            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+            if cnet_blocks == self.temp_data[0] + 1:
+                self.temp_data = None
+
+        return kwargs
+
+    def to(self, device_or_dtype):
+        if isinstance(device_or_dtype, torch.device):
+            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            self.temp_data = None
+        return self
+
+    def models(self):
+        return [self.model_patch]
+
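As a rough reading of the indexing above (inferred from the code, not stated in the commit): block_index // 5 picks the control layer, and the cached residual is only added when block_index / 5 lands exactly on that layer's index, i.e. on every fifth transformer block:

# Illustrative only: assumes a hypothetical model with 6 control layers.
n_control_layers = 6
inject_at = [layer * 5 for layer in range(n_control_layers)]
# inject_at == [0, 5, 10, 15, 20, 25]; other block indices leave the cached temp_data untouched.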
class QwenImageDiffsynthControlnet:
    @classmethod
    def INPUT_TYPES(s):
@@ -289,7 +386,10 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
            mask = mask.unsqueeze(2)
            mask = 1.0 - mask

-        model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+        if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+        else:
+            model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
        return (model_patched,)