@@ -146,7 +146,7 @@ def test_convert_ops():
146146def test_repeated_create_and_destroy ():
147147 collective = TorchCollective ()
148148 with mock .patch ("torch.distributed.init_process_group" ):
149- collective .setup (main_address = "foo" , main_port = 123 )
149+ collective .setup (main_address = "foo" , main_port = " 123" )
150150
151151 assert not os .environ
152152
@@ -157,7 +157,9 @@ def test_repeated_create_and_destroy():
157157 with pytest .raises (RuntimeError , match = "TorchCollective` already owns a group" ):
158158 collective .create_group ()
159159
160- with mock .patch ("torch.distributed.destroy_process_group" ) as destroy_mock :
160+ with mock .patch .dict ("torch.distributed.distributed_c10d._pg_map" , {collective .group : ("" , None )}), mock .patch (
161+ "torch.distributed.destroy_process_group"
162+ ) as destroy_mock :
161163 collective .teardown ()
162164 # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
163165 # group
@@ -269,3 +271,38 @@ def _test_two_groups(strategy, left_collective, right_collective):
@pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
def test_two_groups():
    """Launch ``_test_two_groups`` on three CPU devices split into two groups.

    The test is unconditionally skipped via the ``pytest.mark.skip`` marker above.
    """
    cpu_devices = [torch.device("cpu")] * 3
    collective_launch(_test_two_groups, cpu_devices, num_groups=2)
274+
275+
def _test_default_process_group(strategy, *collectives):
    """Check that each collective uses the default group and can all-reduce through it.

    Args:
        strategy: object exposing a ``world_size`` attribute.
        *collectives: objects exposing ``group`` and ``all_reduce(tensor)``.

    Every collective's ``group`` is first compared against ``torch.distributed.group.WORLD``.
    Afterwards a scalar tensor holding the world size is all-reduced through each collective
    and the result must equal ``world_size ** 2``.
    """
    # First pass: validate all groups before issuing any collective operation.
    for member in collectives:
        assert member.group == torch.distributed.group.WORLD

    world_size = strategy.world_size
    # NOTE(review): presumably the reduction sums one `world_size` value per rank - confirm.
    for member in collectives:
        reduced = member.all_reduce(torch.tensor(world_size))
        assert world_size ** 2 == reduced
284+
285+
@skip_distributed_unavailable
@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
def test_default_process_group():
    """Launch ``_test_default_process_group`` on three CPU devices with ``num_groups=2``.

    ``os.environ`` is patched with a copy of itself so that any variables set while the
    test runs are discarded when the patch is undone.
    """
    cpu_devices = [torch.device("cpu")] * 3
    collective_launch(_test_default_process_group, cpu_devices, num_groups=2)
291+
292+
@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
def test_collective_manages_default_group():
    """Check that ``TorchCollective.manages_default_group`` is set by ``setup`` and cleared by ``teardown``.

    ``init_process_group`` and ``destroy_process_group`` are patched out, so no real
    process group is created or destroyed. The test runs with an emptied ``os.environ``.
    """
    collective = TorchCollective()
    with mock.patch("torch.distributed.init_process_group"):
        collective.setup(main_address="foo", main_port="123")

    assert TorchCollective.manages_default_group

    # Swap in a mocked group and register it in torch's process-group map for the duration of teardown.
    with mock.patch.object(collective, "_group") as mock_group:
        pg_map_patch = mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)})
        with pg_map_patch, mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
            collective.teardown()
    destroy_mock.assert_called_once_with(mock_group)

    assert not TorchCollective.manages_default_group
0 commit comments