@@ -102,8 +102,12 @@ be found in
102102 os.environ['MASTER_ADDR'] = 'localhost'
103103 os.environ['MASTER_PORT'] = '12355'
104104
105+ # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
106+ # such as CUDA, MPS, MTIA, or XPU.
107+ acc = torch.accelerator.current_accelerator()
108+ backend = torch.distributed.get_default_backend_for_device(acc)
105109 # initialize the process group
106- dist.init_process_group("gloo", rank=rank, world_size=world_size)
110+ dist.init_process_group(backend, rank=rank, world_size=world_size)
107111
108112 def cleanup():
109113 dist.destroy_process_group()
@@ -216,8 +220,11 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
216220 # Use a barrier() to make sure that process 1 loads the model after process
217221 # 0 saves it.
218222 dist.barrier()
223+ # We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
224+ # such as CUDA, MPS, MTIA, or XPU.
225+ acc = torch.accelerator.current_accelerator()
219226 # configure map_location properly
220- map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
227+ map_location = {f'{acc}:0': f'{acc}:{rank}'}
221228 ddp_model.load_state_dict(
222229 torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))
223230
@@ -295,7 +302,7 @@ either the application or the model ``forward()`` method.
295302
296303
297304 if __name__ == "__main__":
298- n_gpus = torch.cuda.device_count()
305+ n_gpus = torch.accelerator.device_count()
299306 assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
300307 world_size = n_gpus
301308 run_demo(demo_basic, world_size)
@@ -331,12 +338,14 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
331338
332339
333340 def demo_basic():
334- torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
335- dist.init_process_group("nccl")
341+ torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
342+ acc = torch.accelerator.current_accelerator()
343+ backend = torch.distributed.get_default_backend_for_device(acc)
344+ dist.init_process_group(backend)
336345 rank = dist.get_rank()
337346 print(f"Start running basic DDP example on rank {rank}.")
338347 # create model and move it to GPU with id rank
339- device_id = rank % torch.cuda.device_count()
348+ device_id = rank % torch.accelerator.device_count()
340349 model = ToyModel().to(device_id)
341350 ddp_model = DDP(model, device_ids=[device_id])
342351 loss_fn = nn.MSELoss()
0 commit comments