
Commit 8d45f21

DN6 and sayakpaul authored
Fix Context Parallel validation checks (#12446)
* update * update * update * update * update * update * update * update * update * update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 0fd58c7 commit 8d45f21

3 files changed: +108 / -75 lines changed

src/diffusers/models/_modeling_parallel.py

Lines changed: 47 additions & 25 deletions
@@ -44,11 +44,16 @@ class ContextParallelConfig:
 
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. The sequence is split across devices, and each device
+            computes attention between its local Q and the KV chunks passed sequentially around the ring. Lower
+            memory (only 1/N of the KV is held at a time) and compute overlaps with communication, but N iterations
+            are needed to see all tokens. Best for long sequences with limited memory/bandwidth. Must be a divisor
+            of the total number of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. The sequence is split across devices; each device
+            computes local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory
+            (all KV is stored) and high-bandwidth all-to-all communication is required, but latency is lower. Best
+            for moderate sequences with good interconnect bandwidth.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
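
The trade-off the new docstring describes comes down to which degree is raised. A minimal sketch, assuming `ContextParallelConfig` is re-exported at the package root (it is defined in this file):

from diffusers import ContextParallelConfig  # import path is an assumption

# Ring Attention: KV chunks circulate around the ring, so each device holds only 1/N
# of the KV at a time -- suits very long sequences on memory/bandwidth-limited setups.
ring_config = ContextParallelConfig(ring_degree=4)

# Ulysses Attention: KV is all-gathered and attention runs in a single pass -- lower
# latency, but every device stores the full KV and needs fast all-to-all interconnect.
ulysses_config = ContextParallelConfig(ulysses_degree=4)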
@@ -79,29 +84,46 @@ def __post_init__(self):
         if self.ulysses_degree is None:
             self.ulysses_degree = 1
 
+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
+
+    @property
+    def mesh_shape(self) -> Tuple[int, int]:
+        return (self.ring_degree, self.ulysses_degree)
+
+    @property
+    def mesh_dim_names(self) -> Tuple[str, str]:
+        """Dimension names for the device mesh."""
+        return ("ring", "ulysses")
+
     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+
+        if self.ulysses_degree * self.ring_degree > world_size:
+            raise ValueError(
+                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
 
 
 @dataclass
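
With validation moved into `__post_init__`, a config that leaves both degrees at 1 now fails at construction time instead of deep inside `setup`, and the new properties expose the mesh layout directly. A small illustration (top-level import assumed):

from diffusers import ContextParallelConfig  # import path is an assumption

try:
    ContextParallelConfig()  # both degrees default to 1
except ValueError as err:
    print(err)  # "Either ring_degree or ulysses_degree must be greater than 1 ..."

cfg = ContextParallelConfig(ring_degree=2)
print(cfg.mesh_shape)      # (2, 1)
print(cfg.mesh_dim_names)  # ("ring", "ulysses")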
@@ -119,22 +141,22 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
 
     def setup(
         self,
         rank: int,
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)
 
 
 @dataclass(frozen=True)
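
The renamed `mesh` keyword is forwarded straight to the context-parallel config, which now builds its sub-meshes unconditionally. A hedged sketch of that wiring, meant to run under `torchrun` with at least two processes (import paths and the NCCL backend are assumptions):

import torch
import torch.distributed as dist
from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig  # module path assumed

dist.init_process_group("nccl")  # e.g. `torchrun --nproc-per-node=2 demo.py`
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

cp_config = ContextParallelConfig(ulysses_degree=world_size)
mesh = torch.distributed.device_mesh.init_device_mesh(
    device_type="cuda",
    mesh_shape=cp_config.mesh_shape,          # (1, world_size)
    mesh_dim_names=cp_config.mesh_dim_names,  # ("ring", "ulysses")
)

config = ParallelConfig(context_parallel_config=cp_config)
config.setup(rank, world_size, device, mesh=mesh)  # forwards `mesh` to cp_config.setup(...)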

src/diffusers/models/attention_dispatch.py

Lines changed: 9 additions & 18 deletions
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
     _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS

@@ -237,7 +237,9 @@ def decorator(func):
             cls._backends[backend] = func
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend.value)
+
             return func
 
         return decorator
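
Tracking support as a set of backend string values also fixes a subtle bug: with the old dict, `backend in cls._supports_context_parallel` was true for every registered backend, even one registered with `supports_context_parallel=False`. A standalone sketch of the difference (the backend names are hypothetical):

# Old approach: dict membership only means "registered", not "supported".
supports_cp_dict = {"flash": True, "native": False}
print("native" in supports_cp_dict)   # True -- misleading

# New approach: only backends registered with the flag end up in the set.
supports_cp_set = set()
for name, flag in [("flash", True), ("native", False)]:
    if flag:
        supports_cp_set.add(name)
print("native" in supports_cp_set)    # False -- matches the flag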
@@ -251,15 +253,12 @@ def list_backends(cls):
         return list(cls._backends.keys())
 
     @classmethod
-    def _is_context_parallel_enabled(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+    def _is_context_parallel_available(
+        cls,
+        backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        supports_context_parallel = backend.value in cls._supports_context_parallel
+        return supports_context_parallel
 
 
 @contextlib.contextmanager
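
The availability check now depends only on the backend; whether the degrees make sense is validated by `ContextParallelConfig` itself. One way to query it, sketched against the private registry API (which may change):

from diffusers.models.attention_dispatch import _AttentionBackendRegistry

active = _AttentionBackendRegistry._active_backend
print(active, _AttentionBackendRegistry._is_context_parallel_available(active))
print(sorted(_AttentionBackendRegistry._supports_context_parallel))  # all CP-capable backend names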
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
     kwargs = {
         "query": query,
         "key": key,

src/diffusers/models/modeling_utils.py

Lines changed: 52 additions & 32 deletions
@@ -1484,59 +1484,71 @@ def enable_parallelism(
         config: Union[ParallelConfig, ContextParallelConfig],
         cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
     ):
-        from ..hooks.context_parallel import apply_context_parallel
-        from .attention import AttentionModuleMixin
-        from .attention_processor import Attention, MochiAttention
-
         logger.warning(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )
 
+        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+            raise RuntimeError(
+                "torch.distributed must be available and initialized before calling `enable_parallelism`."
+            )
+
+        from ..hooks.context_parallel import apply_context_parallel
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+        from .attention_processor import Attention, MochiAttention
+
         if isinstance(config, ContextParallelConfig):
             config = ParallelConfig(context_parallel_config=config)
 
-        if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
-
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())
 
-        cp_mesh = None
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
         if config.context_parallel_config is not None:
-            cp_config = config.context_parallel_config
-            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
-                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
-                raise ValueError(
-                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-                )
-            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
-                raise ValueError(
-                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
-                )
-            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
-                device_type=device_type,
-                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
-                mesh_dim_names=("ring", "ulysses"),
-            )
+            for module in self.modules():
+                if not isinstance(module, attention_classes):
+                    continue
 
-        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+                processor = module.processor
+                if processor is None or not hasattr(processor, "_attention_backend"):
+                    continue
 
-        if cp_plan is None and self._cp_plan is None:
-            raise ValueError(
-                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
-            )
-        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+                attention_backend = processor._attention_backend
+                if attention_backend is None:
+                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+                else:
+                    attention_backend = AttentionBackendName(attention_backend)
+
+                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+                    raise ValueError(
+                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                        f"calling `enable_parallelism()`."
+                    )
 
+                # All modules use the same attention processor and backend. We don't need to
+                # iterate over all modules after checking the first processor
+                break
+
+        mesh = None
         if config.context_parallel_config is not None:
-            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+            cp_config = config.context_parallel_config
+            mesh = torch.distributed.device_mesh.init_device_mesh(
+                device_type=device_type,
+                mesh_shape=cp_config.mesh_shape,
+                mesh_dim_names=cp_config.mesh_dim_names,
+            )
 
+        config.setup(rank, world_size, device, mesh=mesh)
         self._parallel_config = config
 
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
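
Because the backend compatibility check now runs up front, the expected call order is: load the model, pick a context-parallel-capable attention backend, then enable parallelism. A hedged end-to-end sketch for a `torchrun` launch (the model class, checkpoint id, and backend string are placeholders, not taken from this diff; the model is assumed to define a default `_cp_plan`):

import torch
import torch.distributed as dist
from diffusers import AutoModel, ContextParallelConfig  # top-level exports assumed

dist.init_process_group("nccl")  # e.g. `torchrun --nproc-per-node=2 infer.py`
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

model = AutoModel.from_pretrained("org/some-model", subfolder="transformer", torch_dtype=torch.bfloat16).cuda()

# Must be a backend registered with supports_context_parallel=True; otherwise
# enable_parallelism() raises the new ValueError listing the compatible backends.
model.set_attention_backend("<cp-capable-backend>")  # placeholder name

model.enable_parallelism(ContextParallelConfig(ring_degree=dist.get_world_size()))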
@@ -1545,6 +1557,14 @@ def enable_parallelism(
                 continue
             processor._parallel_config = config
 
+        if config.context_parallel_config is not None:
+            if cp_plan is None and self._cp_plan is None:
+                raise ValueError(
+                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+                )
+            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
     @classmethod
     def _load_pretrained_model(
         cls,
0 commit comments
