
Commit 354d35a

DefTruth and DN6 authored
bugfix: fix chrono-edit context parallel (#12660)
* bugfix: fix chrono-edit context parallel
* bugfix: fix chrono-edit context parallel
* Update src/diffusers/models/transformers/transformer_chronoedit.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* Update src/diffusers/models/transformers/transformer_chronoedit.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* Clean up comments in transformer_chronoedit.py
  Removed unnecessary comments regarding parallelization in cross-attention.
* fix style
* fix qc
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 544ba67 commit 354d35a

File tree

1 file changed: +10 -6 lines changed


src/diffusers/models/transformers/transformer_chronoedit.py

Lines changed: 10 additions & 6 deletions
@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img


-# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -137,7 +137,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12660
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -150,7 +151,8 @@ def apply_rotary_emb(
             dropout_p=0.0,
             is_causal=False,
             backend=self._attention_backend,
-            parallel_config=self._parallel_config,
+            # Reference: https://github.com/huggingface/diffusers/pull/12660
+            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # Reference: https://github.com/huggingface/diffusers/pull/12660
+        # We need to disable the splitting of encoder_hidden_states because
+        # the image_encoder consistently generates 257 tokens for image_embed. This causes
+        # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
+        # after concatenation—to be indivisible by the number of devices in the CP.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
     }
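Note on the two attention hunks above: context parallelism is kept only for self-attention. The image cross-attention call now passes parallel_config=None outright, and the main call keeps self._parallel_config only when encoder_hidden_states is None, i.e. when query, key, and value all come from the sharded video sequence. Below is a minimal sketch of that dispatch pattern; it is not the diffusers implementation, dispatch_attention and attention_forward are hypothetical names, and the sharded branch is elided.

import torch.nn.functional as F


def dispatch_attention(query, key, value, parallel_config=None):
    # Stand-in for a context-parallel dispatcher: a sharded (ring/ulysses-style) attention
    # would run when a parallel config is present; this sketch only covers the plain path.
    if parallel_config is not None:
        raise NotImplementedError("sharded attention path omitted in this sketch")
    return F.scaled_dot_product_attention(query, key, value)


def attention_forward(query, key, value, encoder_hidden_states=None, parallel_config=None):
    # Same conditional as the PR: cross-attention (encoder_hidden_states provided) runs
    # without context parallelism, because its key/value sequence is not sharded.
    effective_config = parallel_config if encoder_hidden_states is None else None
    return dispatch_attention(query, key, value, parallel_config=effective_config)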

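Note on the _cp_plan hunk: the per-block split of encoder_hidden_states is dropped because, after the 512 text tokens are concatenated with the 257 image-encoder tokens, the context length is fixed at 769, which no even context-parallel world size divides; hidden_states (the video latent sequence) is still split at blocks.0 and proj_out is still gathered. A quick check of that arithmetic, assuming typical world sizes of 2, 4, and 8:

# 769 = 512 text tokens + 257 image tokens; splitting along dim 1 requires the token
# count to be divisible by the context-parallel world size.
seq_len = 512 + 257
for world_size in (2, 4, 8):
    print(f"world_size={world_size}: divisible={seq_len % world_size == 0}")
# 769 is odd, so every even world size fails; encoder_hidden_states therefore stays
# replicated on each rank and only the video sequence is sharded.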