1 file changed (+11, -7 lines)
@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transforms nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'

     # TBD, support horovod?
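For context, a minimal runnable sketch of the per-device-type backend defaulting this hunk introduces; the device value and the print are illustrative, everything else mirrors the diff:

    # Pick a torch.distributed backend from the device type prefix.
    device = 'cuda:0'  # example input; any 'type[:index]' device string works
    device_type, *device_idx = device.split(':', maxsplit=1)

    dist_backends = {
        "xpu": "ccl",    # Intel oneCCL
        "hpu": "hccl",   # Habana HCCL
        "cuda": "nccl",  # NCCL; the FIXME above flags the ROCm nccl->rccl mapping as unverified
    }
    dist_backend = dist_backends.get(device_type, 'gloo')  # cpu and unknown types fall back to gloo
    print(dist_backend)  # -> 'nccl'

Splitting the device string once up front also makes device_type available to the later checks in the second hunk.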
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True

-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'

     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
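Condensed into a self-contained helper, the revised flow looks roughly like the sketch below. resolve_device is a hypothetical name used only for illustration; the real init_distributed_device_so() also resolves ranks, world size, and the backend. Note that parsing device_type up front lets the CUDA check use an exact type match instead of the old 'cuda' in device substring test:

    import logging
    import torch

    _logger = logging.getLogger(__name__)

    def resolve_device(device: str = 'cuda', distributed: bool = False, local_rank: int = 0) -> str:
        # Hypothetical condensation of the diff's flow, for illustration only.
        device_type, *device_idx = device.split(':', maxsplit=1)

        if device_type == 'cuda':
            assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

        if distributed and device != 'cpu':
            # Ignore a manually specified device index in distributed mode and
            # override with the resolved local rank, as the hunk above does.
            if device_idx:
                _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
            device = f'{device_type}:{local_rank}'

        if device.startswith('cuda:'):
            torch.cuda.set_device(device)
        return device

For example, resolve_device('cuda:0', distributed=True, local_rank=3) logs a warning about the dropped index and returns 'cuda:3', set as the current CUDA device.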