@@ -1484,59 +1484,71 @@ def enable_parallelism(
         config: Union[ParallelConfig, ContextParallelConfig],
         cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
     ):
-        from ..hooks.context_parallel import apply_context_parallel
-        from .attention import AttentionModuleMixin
-        from .attention_processor import Attention, MochiAttention
-
         logger.warning(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )
 
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+            raise RuntimeError(
+                "torch.distributed must be available and initialized before calling `enable_parallelism`."
+            )
+
+        from ..hooks.context_parallel import apply_context_parallel
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+        from .attention_processor import Attention, MochiAttention
+
         if isinstance(config, ContextParallelConfig):
             config = ParallelConfig(context_parallel_config=config)
 
-        if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
-
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())
 
-        cp_mesh = None
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
         if config.context_parallel_config is not None:
-            cp_config = config.context_parallel_config
-            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
-                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
-                raise ValueError(
-                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-                )
-            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
-                raise ValueError(
-                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
-                )
-            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
-                device_type=device_type,
-                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
-                mesh_dim_names=("ring", "ulysses"),
-            )
+            for module in self.modules():
+                if not isinstance(module, attention_classes):
+                    continue
 
-        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+                processor = module.processor
+                if processor is None or not hasattr(processor, "_attention_backend"):
+                    continue
 
-        if cp_plan is None and self._cp_plan is None:
-            raise ValueError(
-                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
-            )
-        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+                attention_backend = processor._attention_backend
+                if attention_backend is None:
+                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+                else:
+                    attention_backend = AttentionBackendName(attention_backend)
+
+                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+                    raise ValueError(
+                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                        f"calling `enable_parallelism()`."
+                    )
 
+                # All modules use the same attention processor and backend, so there is
+                # no need to keep iterating once the first processor has been checked.
+                break
+
+        mesh = None
         if config.context_parallel_config is not None:
-            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+            cp_config = config.context_parallel_config
+            mesh = torch.distributed.device_mesh.init_device_mesh(
+                device_type=device_type,
+                mesh_shape=cp_config.mesh_shape,
+                mesh_dim_names=cp_config.mesh_dim_names,
+            )
 
+        config.setup(rank, world_size, device, mesh=mesh)
         self._parallel_config = config
 
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -1545,6 +1557,14 @@ def enable_parallelism(
                 continue
             processor._parallel_config = config
 
+        if config.context_parallel_config is not None:
+            if cp_plan is None and self._cp_plan is None:
+                raise ValueError(
+                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+                )
+            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
     @classmethod
     def _load_pretrained_model(
         cls,
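
For context, a minimal usage sketch of the API this diff touches is shown below. It is not part of the change: the model class (`SomeTransformerModel`) and checkpoint path are placeholders, the `ring_degree`/`ulysses_degree` keyword names follow the validation code removed above, and the top-level `ContextParallelConfig` import path and the "flash" backend name are assumptions that should be checked against `AttentionBackendName` and `_AttentionBackendRegistry._supports_context_parallel`.

# Hypothetical usage sketch; launch with e.g. `torchrun --nproc-per-node=2 example.py`.
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig  # assumed import path
from my_package import SomeTransformerModel  # placeholder for a model that defines `_cp_plan`

dist.init_process_group(backend="nccl")  # must be initialized before `enable_parallelism`
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

model = SomeTransformerModel.from_pretrained("org/checkpoint", torch_dtype=torch.bfloat16).to("cuda")

# The attention backend must support context parallelism; "flash" is an assumption here.
model.set_attention_backend("flash")

# Ring attention across two ranks; a bare ContextParallelConfig is accepted and
# wrapped into a ParallelConfig by `enable_parallelism` itself.
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ulysses_degree=1))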