File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -45,6 +45,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
4545 # DTensor tests
4646 time python test/run_test.py --verbose -i distributed/tensor/test_random_ops
4747 time python test/run_test.py --verbose -i distributed/tensor/test_dtensor_compile
48+ time python test/run_test.py --verbose -i distributed/tensor/test_utils.py
4849
4950 # DeviceMesh test
5051 time python test/run_test.py --verbose -i distributed/test_device_mesh
Original file line number Diff line number Diff line change @@ -284,12 +284,12 @@ def compute_global_tensor_shape(
284284 if isinstance (placements [0 ], Replicate ):
285285 return shape
286286 elif isinstance (placements [0 ], Shard ):
287- local_shape = torch .tensor (list (shape ))
287+ local_shape = torch .tensor (list (shape ), device = mesh . device_type )
288288 gathered_shaped_tensors = [
289289 torch .empty_like (local_shape , device = local_shape .device )
290290 for _ in range (mesh .size ())
291291 ]
292- funcol .all_gather_inplace (gathered_shaped_tensors , local_shape )
292+ funcol .all_gather_inplace (gathered_shaped_tensors , local_shape , mesh )
293293 sharded_dim_sum = 0
294294 shard_dim = placements [0 ].dim
295295 other_dims = [d for d in range (mesh .ndim ) if d != shard_dim ]
You can’t perform that action at this time.
0 commit comments