
Commit 88669aa

[Local Tensor] Replace dry_run.py with local tensor mode implementation
Replaces `dry_run.py` implementation with local tensor mode for DRY_RUN configuration validation. Local tensor mode provides deeper validation coverage, including `ParallelDims` creation, which the previous implementation could not verify.

**Note:** Currently returns early before `init_weights()` due to a known limitation in local tensor mode. This still validates more of the pipeline than the previous approach.

ghstack-source-id: 27b8bad
Pull-Request: #2057

1 parent: 22e959a

5 files changed: +74 -164 lines

run_train.sh

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 if [ "$DRY_RUN" = "1" ]; then
     # Dry run mode: validate configuration without GPU/distributed setup
     echo "Running in DRY RUN mode - configuration validation only"
-    python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
+    NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.fake_backend --training.steps=1
 else
     # Normal training with torchrun
     PYTORCH_ALLOC_CONF="expandable_segments:True" \

scripts/dry_run.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

torchtitan/config/job_config.py

Lines changed: 18 additions & 0 deletions

@@ -791,6 +791,24 @@ class Comm:
     save_traces_file_prefix: str = "rank_"
     """Flight recorder trace files prefix"""

+    fake_backend: bool = False
+    """Fake comm backend for dry run mode only"""
+
+    local_tensor_mode: bool = False
+    """
+    Local tensor mode for debugging purposes. There will be only one process
+    regardless of the number of GPUs. LocalTensor will simulate the
+    computation by running one rank after another. While the performance will
+    be slow, the numerics should be the same. This enables us to verify
+    numerics with fewer GPUs. For example, we can directly run 5D
+    parallelisms within a single node to reduce the combinations we need to
+    use in integration tests.
+
+    NOTE: This is an experimental feature.
+
+    NOTE: fake_backend should be set to True when local_tensor_mode is True.
+    """
+

 @dataclass
 class MemoryEstimation:
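
The two flags are meant to be used together: `local_tensor_mode` piggybacks on the fake backend, and `train.py` (below) rejects `local_tensor_mode=True` without `fake_backend=True`. A minimal sketch of that coupling, assuming the `Comm` fields added above and the `torchtitan.config.job_config` import path:

# Sketch only: mirrors the constraint that train.py enforces for the new flags.
from torchtitan.config.job_config import Comm

comm = Comm(fake_backend=True, local_tensor_mode=True)  # the dry-run combination

# Same check Trainer.init_distributed() performs (see torchtitan/train.py below).
if comm.local_tensor_mode and not comm.fake_backend:
    raise ValueError("LocalTensor can only be used with fake backend.")

On the command line both fields surface through torchtitan's `--<section>.<field>` overrides; this commit's run_train.sh passes `--comm.fake_backend`, and `--comm.local_tensor_mode` would presumably be passed the same way when local tensor debugging is wanted.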

torchtitan/distributed/utils.py

Lines changed: 34 additions & 1 deletion

@@ -258,12 +258,43 @@ def maybe_enable_amp(
     )


+def init_fake_mode(world_size: int) -> int:
+    """Initialize fake backend
+
+    Args:
+        world_size: The number of GPUs to simulate
+
+    Returns:
+        The world size
+    """
+    torch.distributed.init_process_group(
+        "fake",
+        rank=0,
+        world_size=world_size,
+    )
+    return world_size
+
+
 def init_distributed(
     comm_config: CommConfig,
     enable_cpu_backend: bool = False,
     base_folder: str = "",
     ranks: list[int] | None = None,
-):
+) -> int:
+    if comm_config.fake_backend:
+        ngpu_str = os.environ.get("NGPU")
+        if ngpu_str is None:
+            raise ValueError(
+                "NGPU environment variable must be set when using local_tensor_mode"
+            )
+        try:
+            world_size = int(ngpu_str)
+        except ValueError as e:
+            raise ValueError(
+                f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
+            ) from e
+        return init_fake_mode(world_size)
+
     def _warn_overwrite_env(env, val):
         if env in os.environ:
             logger.warning(

@@ -309,6 +340,8 @@ def _get_distributed_backend(enable_cpu_backend):
         _ranks=ranks if ranks is not None else [],
     )

+    return torch.distributed.get_world_size()
+

 def set_pg_timeouts(timeout, world_mesh):
     """

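For orientation, the fake-backend path above reduces to a single-process `init_process_group("fake", ...)` call. A standalone sketch, with the init call taken from the diff (the world size of 4 is arbitrary; in the dry run it comes from the NGPU variable that run_train.sh forwards):

import torch.distributed as dist

world_size = 4
dist.init_process_group("fake", rank=0, world_size=world_size)

# This is the value the reworked init_distributed() now returns to the trainer.
assert dist.get_world_size() == world_size
assert dist.get_rank() == 0

dist.destroy_process_group()

No real NCCL/Gloo communicators are created, which is what lets the dry run "validate configuration without GPU/distributed setup", per the run_train.sh comment.
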
torchtitan/train.py

Lines changed: 21 additions & 3 deletions

@@ -11,6 +11,7 @@
 from typing import Any, Generator, Iterable

 import torch
+from torch.distributed import _local_tensor

 from torch.distributed.elastic.multiprocessing.errors import record

@@ -208,6 +209,12 @@ def __init__(self, job_config: JobConfig):
             self.loss_fn, self.gradient_accumulation_steps
         )

+        # TODO(local_tensor): Remove this early return once LocalTensor supports
+        # init_weights(). Currently skipping parallelism setup and model
+        # initialization in local tensor mode.
+        if job_config.comm.local_tensor_mode:
+            return
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:

@@ -360,15 +367,19 @@ def __init__(self, job_config: JobConfig):

     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
-        dist_utils.init_distributed(
+        world_size = dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )

-        world_size = int(os.environ["WORLD_SIZE"])
-        parallelism_config = job_config.parallelism
+        if job_config.comm.local_tensor_mode:
+            if not job_config.comm.fake_backend:
+                raise ValueError("LocalTensor can only be used with fake backend.")
+            lm = _local_tensor.LocalTensorMode(world_size)
+            lm.__enter__()

+        parallelism_config = job_config.parallelism
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,

@@ -718,6 +729,13 @@ def main(trainer_class: type[Trainer]) -> None:
     try:
         trainer = trainer_class(config)

+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights(). In local tensor mode, skip training/checkpointing as
+        # the model is not fully initialized.
+        if config.comm.local_tensor_mode:
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1
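
Taken together, the trainer's local tensor path amounts to a fake process group plus an open `LocalTensorMode`. A standalone sketch under the same assumptions; `_local_tensor` is a private, experimental torch.distributed module, only the constructor call is taken from the diff, and the with-block usage is an assumption based on the `__enter__()` call above:

import torch.distributed as dist
from torch.distributed import _local_tensor

world_size = 4  # in torchtitan this is NGPU, via the fake backend
dist.init_process_group("fake", rank=0, world_size=world_size)

# train.py keeps the mode open for the rest of __init__ via lm.__enter__();
# a with-block is the scoped equivalent for a standalone experiment. Work done
# inside it (e.g. ParallelDims creation in the dry run) is simulated rank by
# rank in this single process, per the Comm.local_tensor_mode docstring.
with _local_tensor.LocalTensorMode(world_size):
    pass  # construct ParallelDims / meshes here, as Trainer.init_distributed() does

dist.destroy_process_group()

The with-block is only for the standalone sketch; train.py enters the mode once and relies on the early returns above to end the dry run.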
