
Commit 757d716

[Local Tensor] Replace dry_run.py with local tensor mode implementation
Replaces the `dry_run.py` implementation with a local tensor mode for DRY_RUN configuration validation. Local tensor mode provides deeper validation coverage, including `ParallelDims` creation, which the previous implementation could not verify.

**Note:** Currently returns early before `init_weights()` due to a known limitation in local tensor mode. This still validates more of the pipeline than the previous approach.

ghstack-source-id: 5ea1d46
Pull-Request: #2057
1 parent 22e959a commit 757d716
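
For example, with the updated run_train.sh below, COMM_MODE="fake_backend" ./run_train.sh validates a config without GPUs, and COMM_MODE="local_tensor" NGPU=8 ./run_train.sh additionally exercises the local tensor debug path; both invocations come straight from the script's usage comments.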


5 files changed: +82 −169 lines changed


run_train.sh

Lines changed: 8 additions & 6 deletions
@@ -10,19 +10,21 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-# DRY_RUN=1 ./run_train.sh # for config validation without GPU
+# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU
+# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
-DRY_RUN=${DRY_RUN:-0}
+# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training
+COMM_MODE=${COMM_MODE:-""}
 
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 
-if [ "$DRY_RUN" = "1" ]; then
-    # Dry run mode: validate configuration without GPU/distributed setup
-    echo "Running in DRY RUN mode - configuration validation only"
-    python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
+if [ -n "$COMM_MODE" ]; then
+    # Communication mode specified: validate configuration or run in debug mode
+    echo "Running with comm_mode=${COMM_MODE}"
+    NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.comm_mode=${COMM_MODE} --training.steps=1
 else
     # Normal training with torchrun
     PYTORCH_ALLOC_CONF="expandable_segments:True" \
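
In the COMM_MODE branch the trainer is launched as a single plain python3 process (LOCAL_RANK=0, no torchrun) and --training.steps=1 is appended, so both comm modes act as a quick one-step validation pass; NGPU is forwarded in the environment so the training code knows how many ranks to simulate.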

scripts/dry_run.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

torchtitan/config/job_config.py

Lines changed: 16 additions & 0 deletions
@@ -791,6 +791,22 @@ class Comm:
     save_traces_file_prefix: str = "rank_"
     """Flight recorder trace files prefix"""
 
+    comm_mode: Literal["default", "fake_backend", "local_tensor"] = "default"
+    """
+    Communication mode for distributed training.
+
+    Options:
+    - "default": Normal distributed training with real communication
+    - "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU)
+    - "local_tensor": Local tensor mode for debugging purposes. There will be only one process
+      regardless of the number of GPUs. LocalTensor will simulate the computation by running one
+      rank after another. While the performance will be slow, the numerics should be the same.
+      This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D
+      parallelisms within a single node to reduce the combinations we need to use in integration tests.
+
+    NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally.
+    """
+
 
 @dataclass
 class MemoryEstimation:
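
Since comm_mode is a field on the Comm dataclass, it maps to the --comm.comm_mode CLI flag that run_train.sh passes above; it can presumably also be set in the job TOML under the corresponding [comm] section, though the COMM_MODE shell variable is the entry point this PR wires up.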

torchtitan/distributed/utils.py

Lines changed: 44 additions & 1 deletion
@@ -258,12 +258,53 @@ def maybe_enable_amp(
     )
 
 
+def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"):
+    """Initialize fake backend
+
+    Args:
+        world_size: The number of GPUs to simulate
+        comm_mode: Communication mode ("fake_backend" or "local_tensor")
+
+    Returns:
+        The world size
+    """
+    torch.distributed.init_process_group(
+        "fake",
+        rank=0,
+        world_size=world_size,
+    )
+
+    # If local_tensor mode is enabled, initialize LocalTensorMode context
+    if comm_mode == "local_tensor":
+        from torch.distributed import _local_tensor
+
+        lm = _local_tensor.LocalTensorMode(world_size)
+        lm.__enter__()
+
+    return world_size
+
+
 def init_distributed(
     comm_config: CommConfig,
     enable_cpu_backend: bool = False,
     base_folder: str = "",
     ranks: list[int] | None = None,
-):
+) -> int:
+    if comm_config.comm_mode in ("fake_backend", "local_tensor"):
+        ngpu_str = os.environ.get("NGPU")
+        if ngpu_str is None:
+            raise ValueError(
+                f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}"
+            )
+        try:
+            world_size = int(ngpu_str)
+        except ValueError as e:
+            raise ValueError(
+                f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
+            ) from e
+        init_fake_mode(world_size, comm_config.comm_mode)
+        return world_size
+
     def _warn_overwrite_env(env, val):
         if env in os.environ:
             logger.warning(

@@ -309,6 +350,8 @@ def _get_distributed_backend(enable_cpu_backend):
         _ranks=ranks if ranks is not None else [],
     )
 
+    return torch.distributed.get_world_size()
+
 
 def set_pg_timeouts(timeout, world_mesh):
     """

torchtitan/train.py

Lines changed: 14 additions & 3 deletions
@@ -208,6 +208,12 @@ def __init__(self, job_config: JobConfig):
             self.loss_fn, self.gradient_accumulation_steps
         )
 
+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights(). Currently it fails occasionally.
+        if job_config.comm.comm_mode == "local_tensor":
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:

@@ -360,15 +366,13 @@ def __init__(self, job_config: JobConfig):
 
     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
-        dist_utils.init_distributed(
+        world_size = dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )
 
-        world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
-
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,

@@ -718,6 +722,13 @@ def main(trainer_class: type[Trainer]) -> None:
     try:
         trainer = trainer_class(config)
 
+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights() and foreach_allgather. In local tensor mode, skip
+        # training/checkpointing as the model is not fully initialized
+        if config.comm.comm_mode == "local_tensor":
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1
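
A note on the init_distributed return value: the fake_backend/local_tensor path launches a plain python3 process, so WORLD_SIZE is never set by torchrun and the old int(os.environ["WORLD_SIZE"]) read would fail. Having init_distributed return the world size (NGPU in the fake modes, torch.distributed.get_world_size() otherwise) covers both launch paths.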
