Commit 4bd18f5

[Local Tensor] Replace dry_run.py with local tensor mode implementation
Replaces the `dry_run.py` implementation with local tensor mode for DRY_RUN configuration validation. Local tensor mode provides deeper validation coverage, including `ParallelDims` creation, which the previous implementation could not verify.

**Note:** The trainer currently returns early, before `init_weights()`, due to a known limitation in local tensor mode. This still validates more of the pipeline than the previous approach.

ghstack-source-id: c37e849
Pull-Request: #2057
1 parent 22e959a commit 4bd18f5
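
At its core, the new dry-run path initializes a fake process group and enters the experimental LocalTensorMode context, simulating NGPU ranks inside a single process; that is enough to exercise `ParallelDims` creation without GPUs or torchrun. A minimal sketch mirroring the helper added in torchtitan/distributed/utils.py below (it assumes a PyTorch build that ships the experimental torch.distributed._local_tensor module and the "fake" backend; the world size of 8 is only illustrative):

import torch
from torch.distributed import _local_tensor

world_size = 8  # stands in for the NGPU value exported by run_train.sh
torch.distributed.init_process_group("fake", rank=0, world_size=world_size)
_local_tensor.LocalTensorMode(world_size).__enter__()

print(torch.distributed.get_world_size())  # 8, with no GPUs or torchrun involved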

5 files changed: +55 -163 lines

run_train.sh

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 if [ "$DRY_RUN" = "1" ]; then
     # Dry run mode: validate configuration without GPU/distributed setup
     echo "Running in DRY RUN mode - configuration validation only"
-    python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
+    NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.local_tensor_mode
 else
     # Normal training with torchrun
     PYTORCH_ALLOC_CONF="expandable_segments:True" \

scripts/dry_run.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

torchtitan/config/job_config.py

Lines changed: 3 additions & 0 deletions
@@ -791,6 +791,9 @@ class Comm:
     save_traces_file_prefix: str = "rank_"
     """Flight recorder trace files prefix"""

+    local_tensor_mode: bool = False
+    """Local tensor mode, for debugging purposes. This is an experimental feature."""
+

 @dataclass
 class MemoryEstimation:
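
Because torchtitan's config system maps dataclass fields to CLI flags, this new field is what run_train.sh toggles via --comm.local_tensor_mode. A small illustration in Python (assuming the remaining Comm fields keep their defaults):

from torchtitan.config.job_config import Comm

comm = Comm(local_tensor_mode=True)  # defaults to False; experimental, debugging only
# init_distributed() in torchtitan/distributed/utils.py checks this flag and,
# when set, takes the fake-backend + LocalTensorMode path instead of creating
# a real process group.
print(comm.local_tensor_mode)  # True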

torchtitan/distributed/utils.py

Lines changed: 37 additions & 1 deletion
@@ -14,6 +14,7 @@
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
+from torch.distributed import _local_tensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

@@ -258,12 +259,45 @@ def maybe_enable_amp(
     )


+def init_local_tensor_mode(world_size: int) -> int:
+    """Initialize local tensor mode for debugging purposes.
+
+    Args:
+        world_size: The number of GPUs to simulate
+
+    Returns:
+        The world size
+    """
+    torch.distributed.init_process_group(
+        "fake",
+        rank=0,
+        world_size=world_size,
+    )
+    lm = _local_tensor.LocalTensorMode(world_size)
+    lm.__enter__()
+    return world_size
+
+
 def init_distributed(
     comm_config: CommConfig,
     enable_cpu_backend: bool = False,
     base_folder: str = "",
     ranks: list[int] | None = None,
-):
+) -> int:
+    if comm_config.local_tensor_mode:
+        ngpu_str = os.environ.get("NGPU")
+        if ngpu_str is None:
+            raise ValueError(
+                "NGPU environment variable must be set when using local_tensor_mode"
+            )
+        try:
+            world_size = int(ngpu_str)
+        except ValueError as e:
+            raise ValueError(
+                f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
+            ) from e
+        return init_local_tensor_mode(world_size)
+
     def _warn_overwrite_env(env, val):
         if env in os.environ:
             logger.warning(

@@ -309,6 +343,8 @@ def _get_distributed_backend(enable_cpu_backend):
         _ranks=ranks if ranks is not None else [],
     )

+    return torch.distributed.get_world_size()
+

 def set_pg_timeouts(timeout, world_mesh):
     """

torchtitan/train.py

Lines changed: 14 additions & 2 deletions
@@ -208,6 +208,12 @@ def __init__(self, job_config: JobConfig):
             self.loss_fn, self.gradient_accumulation_steps
         )

+        # TODO(local_tensor): Remove this early return once LocalTensor supports
+        # init_weights(). Currently skipping parallelism setup and model initialization
+        # in local tensor mode.
+        if job_config.comm.local_tensor_mode:
+            return
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:

@@ -360,13 +366,12 @@ def __init__(self, job_config: JobConfig):

     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
-        dist_utils.init_distributed(
+        world_size = dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )

-        world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism

         return ParallelDims(

@@ -718,6 +723,13 @@ def main(trainer_class: type[Trainer]) -> None:
     try:
         trainer = trainer_class(config)

+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights(). In local tensor mode, skip training/checkpointing as the
+        # model is not fully initialized
+        if config.comm.local_tensor_mode:
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1
