Skip to content

Commit bb308da

Browse files
authored
fix set_determinism on single gpu (#1983)
**Summary** Currently, running `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=1 CUDA_VISIBLE_DEVICES=0 ./run_train.sh` returns ``` dim for dim in distinct_seed_mesh_dims if dim in world_mesh.mesh_dim_names ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: argument of type 'NoneType' is not iterable ``` This PR fixes the case for a single GPU or when world_mesh.mesh_dim_names is None **Testing** Added unit test to `tests/unit_tests/test_set_determinism.py`
1 parent 8659543 commit bb308da

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

tests/unit_tests/test_set_determinism.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,27 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size):
208208
f"Expected {mesh_sizes[0] * mesh_sizes[1]} unique seeds for (dp_shard, dp_replicate) combinations",
209209
)
210210

211+
@patch("torch.distributed.distributed_c10d.get_world_size")
212+
@patch("torch.distributed.distributed_c10d.get_rank")
213+
def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size):
214+
"""Test set_determinism for single GPU (empty mesh)"""
215+
mock_get_world_size.return_value = 1
216+
mock_get_rank.return_value = 0
217+
218+
base_seed = 42
219+
220+
fake_mesh = MagicMock()
221+
fake_mesh.mesh_dim_names = None
222+
fake_mesh.get_coordinate.return_value = None
223+
224+
debug_config = DebugConfig(seed=base_seed, deterministic=False)
225+
set_determinism(
226+
world_mesh=fake_mesh,
227+
device=self.device,
228+
debug_config=debug_config,
229+
distinct_seed_mesh_dims=["pp"],
230+
)
231+
211232

212233
if __name__ == "__main__":
213234
unittest.main()

torchtitan/distributed/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def set_determinism(
145145
# and choose a unique seed for each rank on the PP mesh.
146146
# We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed.
147147
distinct_dims_in_mesh = [
148-
dim for dim in distinct_seed_mesh_dims if dim in world_mesh.mesh_dim_names
148+
dim
149+
for dim in distinct_seed_mesh_dims
150+
if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names
149151
]
150152

151153
if c10d.get_world_size() > 1 and distinct_dims_in_mesh:

0 commit comments

Comments
 (0)