@@ -1081,25 +1081,15 @@ def forward(
                 A tuple of tensors that if specified are added to the residuals of down unet blocks.
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
             encoder_attention_mask (`torch.Tensor`):
                 A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                 `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                 which adds large negative values to the attention scores corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                 tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
-            added_cond_kwargs: (`dict`, *optional*):
-                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
-                are passed along to the UNet blocks.
-            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added to UNet long skip connections from down blocks to up blocks for
-                example from ControlNet side model(s)
-            mid_block_additional_residual (`torch.Tensor`, *optional*):
-                additional residual to be added to UNet mid block output, for example from ControlNet side model
-            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
-                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)

         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
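For context, the relocated `down_intrablock_additional_residuals` entry documents how T2I-Adapter features enter the UNet. Below is a minimal sketch of a direct `UNet2DConditionModel.forward` call, assuming a diffusers release that accepts this argument; the checkpoints, random inputs, and shapes are illustrative only and are not taken from this PR:

```python
import torch
from diffusers import T2IAdapter, UNet2DConditionModel

device, dtype = "cuda", torch.float16

# Illustrative checkpoints: an SD 1.4 UNet and a matching T2I-Adapter (color variant).
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=dtype
).to(device)
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2iadapter_color_sd14v1", torch_dtype=dtype
).to(device)

# Pixel-space conditioning image (random noise standing in for a real control image).
control = torch.randn(1, 3, 512, 512, device=device, dtype=dtype)
adapter_state = adapter(control)  # list of per-resolution feature maps

sample = torch.randn(1, 4, 64, 64, device=device, dtype=dtype)  # noisy latents
encoder_hidden_states = torch.randn(1, 77, 768, device=device, dtype=dtype)  # text embeddings

noise_pred = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    # Consumed *inside* the down blocks, unlike ControlNet's
    # `down_block_additional_residuals`, which feed the long skip connections.
    down_intrablock_additional_residuals=[s.clone() for s in adapter_state],
).sample
```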
@@ -1185,7 +1175,13 @@ def forward(
             cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

         # 3. down
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+        if cross_attention_kwargs is not None:
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
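The second hunk pops the LoRA `scale` out of `cross_attention_kwargs` before the down blocks run, so it never reaches the attention processors. At the call site this is the value users pass through a pipeline; a short sketch of that flow, assuming the PEFT backend and LoRA weights are available (the LoRA repo id is hypothetical and the base checkpoint is just an example):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-username/your-lora")  # hypothetical LoRA repo id

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    # Forwarded to `UNet2DConditionModel.forward`, where it is popped as `lora_scale`
    # and applied via `scale_lora_layers` rather than propagated to the inner blocks.
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```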