From c4126634d24d4633848e073c435f727363e79f29 Mon Sep 17 00:00:00 2001 From: junjzhang Date: Mon, 3 Mar 2025 22:25:10 +0800 Subject: [PATCH 1/5] finished training with hf models --- torchtitan/experiments/__init__.py | 5 + .../experiments/train_llama_hf/README.md | 23 + .../experiments/train_llama_hf/__init__.py | 10 + .../experiments/train_llama_hf/dataset.py | 127 +++++ .../train_llama_hf/extra_requirements.txt | 2 + .../train_llama_hf/hf_weights_utils.py | 188 +++++++ torchtitan/experiments/train_llama_hf/loss.py | 14 + .../train_llama_hf/model/__init__.py | 32 ++ .../train_llama_hf/model/parallelize_llama.py | 382 ++++++++++++++ .../train_llama_hf/model/pipeline_llama.py | 335 +++++++++++++ .../model/train_configs/llama3_8b_hf.toml | 60 +++ .../experiments/train_llama_hf/run_train.sh | 28 ++ .../train_llama_hf/test_loading_hf_weights.py | 46 ++ .../test_loading_hf_weights_helper.py | 102 ++++ .../train_llama_hf/train_llama_hf.py | 472 ++++++++++++++++++ 15 files changed, 1826 insertions(+) create mode 100644 torchtitan/experiments/__init__.py create mode 100644 torchtitan/experiments/train_llama_hf/README.md create mode 100644 torchtitan/experiments/train_llama_hf/__init__.py create mode 100644 torchtitan/experiments/train_llama_hf/dataset.py create mode 100644 torchtitan/experiments/train_llama_hf/extra_requirements.txt create mode 100644 torchtitan/experiments/train_llama_hf/hf_weights_utils.py create mode 100644 torchtitan/experiments/train_llama_hf/loss.py create mode 100644 torchtitan/experiments/train_llama_hf/model/__init__.py create mode 100644 torchtitan/experiments/train_llama_hf/model/parallelize_llama.py create mode 100644 torchtitan/experiments/train_llama_hf/model/pipeline_llama.py create mode 100644 torchtitan/experiments/train_llama_hf/model/train_configs/llama3_8b_hf.toml create mode 100644 torchtitan/experiments/train_llama_hf/run_train.sh create mode 100644 torchtitan/experiments/train_llama_hf/test_loading_hf_weights.py create mode 100644 torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py create mode 100644 torchtitan/experiments/train_llama_hf/train_llama_hf.py diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/train_llama_hf/README.md b/torchtitan/experiments/train_llama_hf/README.md new file mode 100644 index 0000000000..08b0f081d5 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/README.md @@ -0,0 +1,23 @@ +# Training LLAMA with HF weights + +This directory contains scripts and configs for training LLAMA with HF weights using TorchTitan. + +## Usage + +### Install extra dependencies + +```bash +pip install -r extra_requirements.txt +``` + +### Test loading HF weights + +```bash +pytest test_loading_hf_weights.py +``` + +### Run training + +```bash +LOG_RANK=7 bash run_train.sh +``` diff --git a/torchtitan/experiments/train_llama_hf/__init__.py b/torchtitan/experiments/train_llama_hf/__init__.py new file mode 100644 index 0000000000..9819d6e6e3 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Llama 3 is licensed under the LLAMA 3 Community License, +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import torchtitan.experiments.train_llama_hf.model # noqa: F401 diff --git a/torchtitan/experiments/train_llama_hf/dataset.py b/torchtitan/experiments/train_llama_hf/dataset.py new file mode 100644 index 0000000000..a77e5102e8 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/dataset.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from datasets import Dataset +from datasets.distributed import split_dataset_by_node +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset +from transformers import PreTrainedTokenizerBase + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.config_manager import JobConfig +from torchtitan.datasets.hf_datasets import _validate_dataset +from torchtitan.tools.logging import logger + + +class HuggingFaceDataset(IterableDataset, Stateful): + def __init__( + self, + dataset_name: str, + dataset_path: Optional[str], + tokenizer: PreTrainedTokenizerBase, + seq_len: int = 2048, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, text_processor = _validate_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + + self.dataset_name = dataset_name + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + self._tokenizer = tokenizer + self.seq_len = seq_len + self.infinite = infinite + self._text_processor = text_processor + + # Variables for checkpointing + self._sample_idx = 0 + self._all_tokens: list[int] = [] + + def _get_data_iter(self): + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + it = iter(self._data) + for _ in range(self._sample_idx): + next(it) + return it + + def __iter__(self): + max_buffer_token_len = 1 + self.seq_len + + while True: + for sample in self._get_data_iter(): + # Use the dataset-specific text processor + sample_text = self._text_processor(sample) + sample_tokens = self._tokenizer.encode(sample_text) + self._all_tokens.extend(sample_tokens) + self._sample_idx += 1 + + while len(self._all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) + # update tokens to the remaining tokens + self._all_tokens = self._all_tokens[max_buffer_token_len:] + input = x[:-1] + label = x[1:] + # Add position IDs (0 to seq_len-1) + position_ids = torch.arange(len(input), dtype=torch.long) + yield input, label, position_ids + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_idx = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + self._all_tokens = state_dict["token_buffer"] + + def state_dict(self): + return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} + + +def build_hf_dataloader( + 
dp_world_size: int, + dp_rank: int, + tokenizer, + job_config: JobConfig, + infinite: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for HuggingFace datasets.""" + dataset_name = job_config.training.dataset + dataset_path = job_config.training.dataset_path + batch_size = job_config.training.batch_size + seq_len = job_config.training.seq_len + + hf_ds = HuggingFaceDataset( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + seq_len=seq_len, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=infinite, + ) + + return ParallelAwareDataloader( + dataset=hf_ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, + ) diff --git a/torchtitan/experiments/train_llama_hf/extra_requirements.txt b/torchtitan/experiments/train_llama_hf/extra_requirements.txt new file mode 100644 index 0000000000..eef7d90264 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/extra_requirements.txt @@ -0,0 +1,2 @@ +transformers >=4.49.0 +sentencepiece >=0.2.0 diff --git a/torchtitan/experiments/train_llama_hf/hf_weights_utils.py b/torchtitan/experiments/train_llama_hf/hf_weights_utils.py new file mode 100644 index 0000000000..c8423a6292 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/hf_weights_utils.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gc +import json +from collections import defaultdict +from pathlib import Path + +import torch.nn as nn +from huggingface_hub import repo_exists, snapshot_download +from safetensors import safe_open +from torch.distributed.tensor import distribute_tensor, DTensor +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from torchtitan.tools.logging import logger + +INDEX_NAME_MAPPING = { + "safetensors": SAFE_WEIGHTS_INDEX_NAME, +} + +PATTERNS_TO_REMOVE = [ + "._orig_mod", # Some optimizers add suffixes + "._fsdp_wrapped_module", # FSDP wrapper + "._checkpoint_wrapped_module", # checkpoint wrapper + ".module", # DataParallel/DistributedDataParallel + "_module.", # Some wrappers add prefix +] + + +def normalize_state_dict_key( + key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE +) -> str: + """ + Normalize the state dict key, remove the prefix or suffix added by various wrappers. + Args: + key: The original state dict key + Returns: + The normalized key + """ + normalized_key = key + for pattern in patterns_to_remove: + normalized_key = normalized_key.replace(pattern, "") + + return normalized_key + + +def get_weight_map(pretrained_model_path: Path) -> dict[str, str]: + """ + Get the weight map from the pretrained model. + Args: + pretrained_model_path: The path to the pretrained model. + Returns: + weight_map: A dictionary mapping from the path to the weight map to the list of state dict keys. + """ + index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"] + if not index_file.exists(): + return None + with open(index_file, "r") as f: + metadata = json.load(f) + return metadata["weight_map"] + + +def group_state_dict_keys_and_st_partition_paths( + pretrained_model_path: Path, + state_dict_keys, + weight_map, + state_dict_map: dict[str, str] = None, +): + """ + Group state dict keys and save them to a file. + Args: + pretrained_model_path: The path to the pretrained model. + state_dict_keys: The state dict keys to group. + weight_map: The weight map. 
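For reference, the `weight_map` returned by `get_weight_map` above follows the standard Hugging Face safetensors index layout: it maps each checkpoint key to the shard file that stores it. A minimal sketch of that structure, and of how keys get grouped per shard, is shown below (the keys and shard filenames are only illustrative):

```python
# Illustrative safetensors index content (checkpoint key -> shard filename);
# the actual keys and filenames depend on the pretrained model being loaded.
weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}

# Grouping keys by shard file (as the function below does) lets each rank open
# every shard at most once and read only the tensors it actually needs.
from collections import defaultdict
from pathlib import Path

st_partition_map = defaultdict(list)
for key, shard_file in weight_map.items():
    st_partition_map[Path("/path/to/model") / shard_file].append(key)
```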
+ state_dict_map: A dictionary mapping from the state dict key to the weight path. + Returns: + st_partition_map: A dictionary mapping from the weight path to the list of state dict keys. + """ + st_partition_map = defaultdict(list) + for state_dict_key in state_dict_keys: + ckpt_state_dict_key = ( + state_dict_map[state_dict_key] + if state_dict_map is not None + else state_dict_key + ) + if weight_map is None: + partition_path = pretrained_model_path / "model.safetensors" + else: + partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key] + st_partition_map[partition_path].append(state_dict_key) + return st_partition_map + + +def load_sharded_state_dict_for_model_from_path( + pretrained_model_path: Path, + model: nn.Module, + mapping_dict: dict[str, str] = None, + **kwargs, +): + """ + Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank. + Args: + pretrained_model_path: The path to the pretrained model, it could be a local path or an s3 path. + model: The model to load the state dict into. + **kwargs: other arguments for torch.nn.Module.load_state_dict + """ + # check exceptions + if not pretrained_model_path.exists(): + raise ValueError( + f"The pretrained model path {pretrained_model_path} does not exist." + ) + if not pretrained_model_path.is_dir(): + raise ValueError( + f"The pretrained model path {pretrained_model_path} is not a directory." + ) + # get the weight map + weight_map = get_weight_map(pretrained_model_path) + model_state_dict = model.state_dict() + model_state_dict_keys = list(model_state_dict.keys()) + + # create a mapping_dict between the original state_dict_key and the weight_map_key if not provided + mapping_dict = ( + mapping_dict + if mapping_dict is not None + else {key: normalize_state_dict_key(key) for key in model_state_dict_keys} + ) + st_partition_map = group_state_dict_keys_and_st_partition_paths( + pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict + ) + + # get the sharded state dict + state_dict = {} + for safetensor_partition_path, state_dict_keys in st_partition_map.items(): + with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f: + for state_dict_key in state_dict_keys: + model_tensor = model_state_dict[state_dict_key] + ckpt_state_dict_key = mapping_dict[state_dict_key] + if isinstance(model_tensor, DTensor): + local_tensor = f.get_tensor(ckpt_state_dict_key) + state_dict[state_dict_key] = distribute_tensor( + local_tensor, + model_tensor.device_mesh, + model_tensor.placements, + ) + else: + state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key) + model.load_state_dict(state_dict, **kwargs) + del state_dict + gc.collect() + + +def load_sharded_state_dict_for_model_from_hf( + pretrained_model_id_or_path: str, + model: nn.Module, + **kwargs, +): + """ + Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank. + Args: + pretrained_model_id_or_path: The id or path to the pretrained model, it could be a repo id in huggingface, + or a local path + model: The model to load the state dict into. 
+ **kwargs: other arguments for torch.nn.Module.load_state_dict + """ + logger.info(f"Loading the state dict from {pretrained_model_id_or_path}") + pretrained_model_id_or_path = Path(pretrained_model_id_or_path) + if not pretrained_model_id_or_path.exists(): + if not repo_exists(str(pretrained_model_id_or_path)): + raise ValueError( + f"The pretrained model {pretrained_model_id_or_path} does not exist" + ) + logger.info( + f"Try to download the model from huggingface: {pretrained_model_id_or_path}" + ) + pretrained_model_path = Path( + snapshot_download(str(pretrained_model_id_or_path)) + ) + elif not pretrained_model_id_or_path.is_dir(): + raise ValueError( + f"The pretrained model path {pretrained_model_id_or_path} is not a directory." + ) + else: + pretrained_model_path = pretrained_model_id_or_path + + load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs) diff --git a/torchtitan/experiments/train_llama_hf/loss.py b/torchtitan/experiments/train_llama_hf/loss.py new file mode 100644 index 0000000000..b13de33635 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/loss.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def cross_entropy_loss_hf(preds, labels): + loss = torch.nn.functional.cross_entropy( + preds[0].flatten(0, 1).float(), labels.flatten(0, 1) + ) + return loss diff --git a/torchtitan/experiments/train_llama_hf/model/__init__.py b/torchtitan/experiments/train_llama_hf/model/__init__.py new file mode 100644 index 0000000000..5126818b05 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/model/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from loss import cross_entropy_loss_hf +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers +from torchtitan.experiments.train_llama_hf.dataset import build_hf_dataloader +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .parallelize_llama import parallelize_llama +from .pipeline_llama import pipeline_llama + +register_train_spec( + TrainSpec( + name="llama3_hf", + cls=AutoModelForCausalLM, + config=AutoConfig, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + tokenizer_cls=AutoTokenizer, + loss_fn=cross_entropy_loss_hf, + ) +) diff --git a/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py new file mode 100644 index 0000000000..160edd983a --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py @@ -0,0 +1,382 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. 
activation checkpointing and compile) to the Llama model. + +from collections import defaultdict + +import torch +import torch.nn as nn + +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + fully_shard, + MixedPrecisionPolicy, +) +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + if ( + job_config.experimental.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + enable_float8_linear = "float8" in job_config.model.converters + apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if ( + parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. 
Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "model.norm": SequenceParallel(), + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.model.layers: + layer_plan = { + "input_layernorm": SequenceParallel(), + "self_attn": prepare_module_input( + input_kwarg_layouts={ + "hidden_states": Shard(1), + }, + desired_input_kwarg_layouts={ + "hidden_states": Replicate(), + }, + ), + "self_attn.q_proj": colwise_parallel(), + "self_attn.k_proj": colwise_parallel(), + "self_attn.v_proj": colwise_parallel(), + "self_attn.o_proj": rowwise_parallel(output_layouts=Shard(1)), + "post_attention_layernorm": SequenceParallel(), + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.gate_proj": colwise_parallel(), + "mlp.up_proj": colwise_parallel(), + "mlp.down_proj": rowwise_parallel(output_layouts=Shard(1)), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. 
Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. " + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + for layer_id, transformer_block in model.model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) + model.model.layers.register_module(layer_id, transformer_block) + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. 
+ reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in enumerate(model.model.layers): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py b/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py new file mode 100644 index 0000000000..6e6a815134 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py @@ -0,0 +1,335 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D pipeline parallelism to the Llama model. + +import copy +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + ScheduleZBVZeroBubble, +) +from transformers import PretrainedConfig + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline import ( + build_pipeline_schedule, + generate_split_points, + stage_ids_this_rank, +) +from torchtitan.tools.logging import logger + + +DeviceType = Union[int, str, torch.device] + + +def patch_llama_forward(model: nn.Module): + """ + patch forward method for LlamaModel instance. 
+ Revised from transformers.models.llama.modeling_llama.LlamaModel.forward + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs, + ) -> torch.Tensor: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, False + ) + + hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + hidden_states = layer_outputs[0] + + # adapt for last stage / other stage + if self.norm is not None: + hidden_states = self.norm(hidden_states) + + return hidden_states + + # bind to model instance + model.forward = forward.__get__(model, model.__class__) + + +def patch_llamaforcasuallm_forward(model: nn.Module): + """ + patch forward method for LlamaForCausalLM instance + """ + + def forward( + self, + input_ids_or_input_embeds: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> torch.Tensor: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + 
return_dict if return_dict is not None else self.config.use_return_dict + ) + if return_dict is True: + raise ValueError( + "return_dict must be False while using pipeline parallelism" + ) + + # adapt for first staget / other stage + if self.model.embed_tokens is not None: + input_ids = input_ids_or_input_embeds + inputs_embeds = None + else: + input_ids = None + inputs_embeds = input_ids_or_input_embeds + + hidden_states = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + # adapt for last stage / other stage + if self.lm_head is not None: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + return (logits,) + else: + return (hidden_states[:, slice_indices, :],) + + # bind to model instance + model.forward = forward.__get__(model, model.__class__) + + # patch for LlamaModel instance + patch_llama_forward(model.model) + + +def pipeline_llama( + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: PretrainedConfig, + loss_fn: Callable[..., torch.Tensor], +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + logger.info( + "Patching Llama forward method for pipeline parallelism, it will disable some features of orignal HF model" + ) + patch_llamaforcasuallm_forward(model) + stages, models = pipeline_llama_manual_split( + model, pp_mesh, parallel_dims, job_config, device, model_config + ) + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, models, has_first_stage, has_last_stage + + +def pipeline_llama_manual_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: PretrainedConfig, +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. + + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. 
+ """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + splits = ( + job_config.experimental.pipeline_parallel_split_points + or generate_split_points( + job_config, parallel_dims.pp, model_config.num_hidden_layers + ) + ) + + def _build_stage( + stage_idx: int, + start_layer: Optional[str], + stop_layer: Optional[str], + is_first: bool = False, + is_last: bool = False, + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + if not is_first: + model.model.embed_tokens = None + + drop_layers = start_layer is not None + del_indexes = [] + for i in range(len(model.model.layers)): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if f"layers.{i}" == start_layer: + drop_layers = False + if f"layers.{i}" == stop_layer: + drop_layers = True + if drop_layers: + del_indexes.append(i) + + # delete layers in reverse order to avoid index shifting + del_indexes.reverse() + for i in del_indexes: + del model.model.layers[i] + + if not is_last: + model.model.norm = None + model.lm_head = None + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(splits) + 1 + stage_idx = pp_rank + + stages = [] + models = [] + + schedule_class = get_schedule_class( + job_config.experimental.pipeline_parallel_schedule + ) + style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None + stage, model_chunk = _build_stage( + stage_idx, + start_layer, + stop_layer, + is_first=stage_idx == 0, + is_last=stage_idx == num_stages - 1, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx}" + f" with start_layer {start_layer}, stop_layer {stop_layer}" + ) + stages.append(stage) + models.append(model_chunk) + return stages, models diff --git a/torchtitan/experiments/train_llama_hf/model/train_configs/llama3_8b_hf.toml b/torchtitan/experiments/train_llama_hf/model/train_configs/llama3_8b_hf.toml new file mode 100644 index 0000000000..fac392ece5 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/model/train_configs/llama3_8b_hf.toml @@ -0,0 +1,60 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Llama 3 8B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3_hf" +flavor = "/data/models/Meta-Llama-3.1-8B" +tokenizer_path = "/data/models/Meta-Llama-3.1-8B" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 3e-4 + +[training] +batch_size = 2 +seq_len = 8192 +warmup_steps = 200 # lr scheduler warm up +max_norm = 1.0 # grad norm clipping +steps = 1000 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 2 +compile = true +dataset = "c4_test" +dataset_path = "../../../tests/assets/c4_test" + +[experimental] +context_parallel_degree = 1 +pipeline_parallel_degree = 2 +enable_async_tensor_parallel = true + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/train_llama_hf/run_train.sh b/torchtitan/experiments/train_llama_hf/run_train.sh new file mode 100644 index 0000000000..1f7a2ab39f --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/run_train.sh @@ -0,0 +1,28 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_train.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./model/train_configs/llama3_8b_hf.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +train_llama_hf.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/train_llama_hf/test_loading_hf_weights.py b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights.py new file mode 100644 index 0000000000..48d7dc640f --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
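The test below drives `torchrun` subprocesses against the helper script, which exercises the same load path the trainer uses: build the model on the meta device, apply the parallelisms, materialize storage, then pull in the HF weights shard by shard. A condensed sketch of that flow follows (the parallelism calls are elided; the model id matches the tiny test checkpoint used here):

```python
# Condensed sketch of the weight-loading flow verified by this test.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from torchtitan.experiments.train_llama_hf.hf_weights_utils import (
    load_sharded_state_dict_for_model_from_hf,
)

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
config = AutoConfig.from_pretrained(model_id)

# Meta-device init: parameters have shapes/dtypes but no storage yet.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# ... apply TP / FSDP / pipeline splitting here, as in the helper script ...

model.to_empty(device="cuda")
with torch.no_grad():
    # Each rank reads only the safetensors shards holding its (D)Tensor pieces.
    load_sharded_state_dict_for_model_from_hf(model_id, model)
```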
+ +import subprocess + +from pathlib import Path + +import pytest + + +PRETRAINED_MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM" + + +@pytest.mark.parametrize("dp_shard_degree", [-1]) +@pytest.mark.parametrize("tp_degree", [2, 4]) +@pytest.mark.parametrize("pp_degree", [2]) +@pytest.mark.parametrize("world_size", [8]) +def test_load_sharded_state_dict_for_model_from_hf( + dp_shard_degree, tp_degree, pp_degree, world_size +): + test_file_path = Path(__file__).parent / "test_loading_hf_weights_helper.py" + cmd = [ + "torchrun", + "--local-ranks-filter", + "0", + "--nproc_per_node", + str(world_size), + str(test_file_path), + "--experimental.pipeline_parallel_degree", + str(pp_degree), + "--training.tensor_parallel_degree", + str(tp_degree), + "--training.data_parallel_shard_degree", + str(dp_shard_degree), + "--model.name", + PRETRAINED_MODEL_ID, + "--model.flavor", + PRETRAINED_MODEL_ID, + "--model.tokenizer_path", + PRETRAINED_MODEL_ID, + ] + result = subprocess.run(cmd, check=True) + assert result.returncode == 0 diff --git a/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py new file mode 100644 index 0000000000..25aca1cf38 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.train_llama_hf.hf_weights_utils import ( + load_sharded_state_dict_for_model_from_hf, + normalize_state_dict_key, +) + +from torchtitan.experiments.train_llama_hf.model.parallelize_llama import ( + apply_fsdp, + apply_tp, +) +from torchtitan.experiments.train_llama_hf.model.pipeline_llama import ( + pipeline_llama_manual_split, +) + + +def main(job_config: JobConfig): + world_size = int(os.environ["WORLD_SIZE"]) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + parallel_dims = ParallelDims( + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, + cp=job_config.experimental.context_parallel_degree, + tp=job_config.training.tensor_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, + world_size=world_size, + enable_loss_parallel=not job_config.training.disable_loss_parallel, + ) + world_mesh = parallel_dims.build_mesh(device_type="cuda") + + model_config = AutoConfig.from_pretrained(job_config.model.flavor) + + # load model + with torch.device("meta"): + model = AutoModelForCausalLM.from_config(model_config) + + with torch.device("cpu"): + gold_model_state_dict = AutoModelForCausalLM.from_pretrained( + job_config.model.flavor + ).state_dict() + # apply parallelisms + if parallel_dims.pp_enabled: + # apply PT-D Pipeline Parallel + _, model_parts = pipeline_llama_manual_split( + model, + world_mesh["pp"], + parallel_dims, + job_config, + device, + model_config, + ) + else: + model_parts = [model] + for m in model_parts: + if parallel_dims.tp_enabled: + apply_tp( + m, + world_mesh["tp"], + loss_parallel=False, + enable_float8=False, + enable_async_tp=False, + ) + if parallel_dims.dp_shard_enabled: + apply_fsdp( + m, + 
world_mesh["dp_shard"], + param_dtype=torch.float32, + reduce_dtype=torch.float32, + pp_enabled=False, + cpu_offload=False, + reshard_after_forward_policy="default", + ) + + m.to_empty(device="cuda") + # load weights + with torch.no_grad(): + load_sharded_state_dict_for_model_from_hf(job_config.model.flavor, m) + for k, v in m.state_dict().items(): + if isinstance(v, torch.distributed.tensor.DTensor): + full_tensor = v.full_tensor().to("cpu") + else: + full_tensor = v.to("cpu") + k = normalize_state_dict_key(k) + gt_value = gold_model_state_dict[k] + assert torch.allclose(full_tensor, gt_value), f"tensor mismatch for {k}" + + +if __name__ == "__main__": + job_config = JobConfig() + job_config.parse_args() + main(job_config) diff --git a/torchtitan/experiments/train_llama_hf/train_llama_hf.py b/torchtitan/experiments/train_llama_hf/train_llama_hf.py new file mode 100644 index 0000000000..5dc1176b46 --- /dev/null +++ b/torchtitan/experiments/train_llama_hf/train_llama_hf.py @@ -0,0 +1,472 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +from datetime import timedelta + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan.components.checkpoint import CheckpointManager, TrainState +from torchtitan.components.ft import FTParallelDims, init_ft_manager +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims, utils as dist_utils + +from torchtitan.experiments.train_llama_hf.hf_weights_utils import ( + load_sharded_state_dict_for_model_from_hf, +) + +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.protocols.train_spec import get_train_spec + +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) + + +# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html +@record +def main(job_config: JobConfig): + logger.info(f"Starting job: {job_config.job.description}") + + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + + if job_config.training.compile: + # raise cache size limit to 4GB + logger.info("Raising dynamo cache size limit to 4GB") + import torch._dynamo + + torch._dynamo.config.cache_size_limit = 4096 + + # used for colorful printing + color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color + + # take control of garbage collection to avoid stragglers + gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) + + device_module, device_type = utils.device_module, utils.device_type + device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + # Device has to be set before creating TorchFT manager. 
+ device_module.set_device(device) + ft_manager = init_ft_manager(job_config) + + # init distributed + world_size = int(os.environ["WORLD_SIZE"]) + if not ft_manager.enabled: + parallel_dims = ParallelDims( + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, + cp=job_config.experimental.context_parallel_degree, + tp=job_config.training.tensor_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, + world_size=world_size, + enable_loss_parallel=not job_config.training.disable_loss_parallel, + ) + else: + parallel_dims = FTParallelDims( + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, + cp=job_config.experimental.context_parallel_degree, + tp=job_config.training.tensor_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, + world_size=world_size, + enable_loss_parallel=not job_config.training.disable_loss_parallel, + ft_manager=ft_manager, + ) + dist_utils.init_distributed(job_config) + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = build_device_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + + # build meshes + world_mesh = parallel_dims.build_mesh(device_type=device_type) + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + + # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss) + dist_utils.set_determinism( + world_mesh, device, job_config.training.seed, job_config.training.deterministic + ) + train_spec = get_train_spec(job_config.model.name) + + # build dataloader + tokenizer = train_spec.tokenizer_cls.from_pretrained( + job_config.model.tokenizer_path + ) + + # If TorchFT is enabled, the dp_rank and dp_degree, which are used for + # dataloader must be changed. + if ft_manager.enabled: + dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank) + dataloader = train_spec.build_dataloader_fn( + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=tokenizer, + job_config=job_config, + ) + + # build model (using meta init) + model_cls = train_spec.cls + model_config = train_spec.config.from_pretrained(job_config.model.flavor) + model_config.return_dict = False # for compatibility with pipeline parallel + # adapt util function + model_config.n_layers = model_config.num_hidden_layers + model_config.n_heads = model_config.num_attention_heads + model_config.dim = model_config.hidden_size + + logger.info( + f"Building {train_spec.name} {job_config.model.flavor} with {model_config}" + ) + with torch.device("meta"): + model = model_cls.from_config(model_config) + + # Build the collection of model converters. 
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) + + # log model size + model_param_count = utils.get_num_params(model) + num_flop_per_token = utils.get_num_flop_per_token( + utils.get_num_params(model, exclude_embedding=True), + model_config, + job_config.training.seq_len, + ) + logger.info( + f"{color.blue}Model {train_spec.name} {job_config.model.flavor} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) + + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.checkpoint.create_seed_checkpoint: + init_device = "cpu" + buffer_device = None + elif job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = device_type + else: + init_device = device_type + buffer_device = None + + # apply parallelisms and initialization + if parallel_dims.pp_enabled: + # apply PT-D Pipeline Parallel + ( + pp_schedule, + model_parts, + has_first_stage, + has_last_stage, + ) = train_spec.pipelining_fn( + model, + pp_mesh, + parallel_dims, + job_config, + device, + model_config, + train_spec.loss_fn, + ) + # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead + del model + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for m in model_parts: + # apply SPMD-style PT-D techniques + train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config) + m.to_empty(device=init_device) + with torch.no_grad(): + load_sharded_state_dict_for_model_from_hf(job_config.model.flavor, m) + m.train() + else: + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + model.to_empty(device=init_device) + with torch.no_grad(): + load_sharded_state_dict_for_model_from_hf(job_config.model.flavor, model) + model.train() + + model_parts = [model] + + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + # build optimizer after applying parallelisms to the model + optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager) + lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config) + # Post optimizer step model converters hook. + # e.g. 
calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # where it issues a single all-reduce for all parameters at once for better performance + optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts) + ) + + train_state = TrainState() + + # load initial checkpoint + checkpoint = CheckpointManager( + dataloader=dataloader, + model_parts=model_parts, + optimizers=optimizers, + lr_schedulers=lr_schedulers, + states={"train_state": train_state}, + job_config=job_config, + ft_manager=ft_manager, + ) + + if job_config.checkpoint.create_seed_checkpoint: + assert ( + world_size == 1 + ), "Must create seed checkpoint using a single device, to disable sharding" + assert ( + job_config.checkpoint.enable_checkpoint + ), "Must enable checkpointing when creating a seed checkpoint" + checkpoint.save(curr_step=0, force=True) + logger.info("Created seed checkpoint") + return + + checkpoint.load(step=job_config.checkpoint.load_step) + metric_logger = build_metric_logger(job_config, parallel_dims) + + # plot losses loaded from checkpoint (if any) to TensorBoard + # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. + # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq + if train_state.step > 0: + for idx, step in enumerate(train_state.log_steps): + metrics = { + "loss_metrics/global_avg_loss": train_state.global_avg_losses[idx], + "loss_metrics/global_max_loss": train_state.global_max_losses[idx], + } + metric_logger.log(metrics, step=step) + + data_iterator = iter(dataloader) + + train_context = dist_utils.get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, + ) + + # variables used to keep info for metrics logging + ntokens_since_last_log = 0 + data_loading_times = [] + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() + + # train loop + logger.info( + f"Training starts at step {train_state.step + 1}, " + f"with local batch size {job_config.training.batch_size}, " + f"global batch size {job_config.training.batch_size * dp_degree}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.training.warmup_steps})" + ) + with ( + maybe_enable_profiling( + job_config, global_step=train_state.step + ) as torch_profiler, + maybe_enable_memory_snapshot( + job_config, global_step=train_state.step + ) as memory_profiler, + ): + while train_state.step < job_config.training.steps: + train_state.step += 1 + gc_handler.run(train_state.step) + + # get batch + data_load_start = time.perf_counter() + batch = next(data_iterator) + input_ids, labels, position_ids = batch + ntokens_since_last_log += labels.numel() + data_loading_times.append(time.perf_counter() - data_load_start) + + input_ids = input_ids.to(device_type) + labels = labels.to(device_type) + position_ids = position_ids.to(device_type) + optimizers.zero_grad() + + # apply context parallelism if cp is enabled + # ensure CP handles the separate freqs_cis buffer for each pp stage + optional_context_parallel_ctx = ( + utils.create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels, position_ids], + cp_seq_dims=[1, 1, 1], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method=job_config.experimental.context_parallel_rotate_method, + ) + if parallel_dims.cp_enabled + else None + ) + + if parallel_dims.pp_enabled: + # Pipeline Parallel 
forward / backward inside step() call + with train_context(optional_context_parallel_ctx): + targets, losses = (labels, []) if has_last_stage else (None, None) + if has_first_stage: + pp_schedule.step( + input_ids, + position_ids=position_ids, + target=targets, + losses=losses, + ) + else: + pp_schedule.step( + position_ids=position_ids, target=targets, losses=losses + ) + + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + loss = ( + torch.mean(torch.stack(losses)).to(device) + if has_last_stage + else torch.tensor([-1.0], device=device) + ) + else: + # Non-PP forward / backward + with train_context(optional_context_parallel_ctx): + pred = model(input_ids) + loss = train_spec.loss_fn(pred, labels) + # pred.shape=(bs, seq_len, vocab_size) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() + + # clip gradients + dist_utils.clip_grad_norm_( + [p for m in model_parts for p in m.parameters()], + job_config.training.max_norm, + foreach=True, + pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, + ) + + # optimizer step + checkpoint.maybe_wait_for_staging() + optimizers.step() + lr_schedulers.step() + + # log metrics + if ( + train_state.step == 1 + or train_state.step % job_config.metrics.log_freq == 0 + ): + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + loss = loss.detach() + global_avg_loss, global_max_loss = ( + dist_utils.dist_mean(loss, world_mesh["dp_cp"]), + dist_utils.dist_max(loss, world_mesh["dp_cp"]), + ) + else: + global_avg_loss = global_max_loss = loss.item() + + # update train state + train_state.log_steps.append(train_state.step) + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + + time_delta = time.perf_counter() - time_last_log + + # tokens per second per device, abbreviated as tps + tps = ntokens_since_last_log / ( + time_delta * parallel_dims.non_data_parallel_size + ) + # model FLOPS utilization + # For its definition and calculation, please refer to the PaLM paper: + # https://arxiv.org/abs/2204.02311 + mfu = 100 * num_flop_per_token * tps / gpu_peak_flops + tflops = num_flop_per_token * tps / 1e12 + + time_end_to_end = time_delta / job_config.metrics.log_freq + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta + + device_mem_stats = device_memory_monitor.get_peak_stats() + + metrics = { + "loss_metrics/global_avg_loss": global_avg_loss, + "loss_metrics/global_max_loss": global_max_loss, + "throughput(tps)": tps, + "tflops": tflops, + "mfu(%)": mfu, + "time_metrics/end_to_end(s)": time_end_to_end, + "time_metrics/data_loading(s)": time_data_loading, + "time_metrics/data_loading(%)": time_data_loading_pct, + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + "memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, + } + metric_logger.log(metrics, step=train_state.step) + + logger.info( + f"{color.red}step: {train_state.step:2} " + f"{color.green}loss: {global_avg_loss:7.4f} " + f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: 
{round(tps):,} " + f"{color.cyan}tflops: {tflops:,.2f} " + f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" + ) + + ntokens_since_last_log = 0 + data_loading_times.clear() + time_last_log = time.perf_counter() + device_memory_monitor.reset_peak_stats() + + checkpoint.save( + train_state.step, force=(train_state.step == job_config.training.steps) + ) + + # signal the profiler that the next profiling step has started + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) + if train_state.step == 1: + dist_utils.set_pg_timeouts( + timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), + world_mesh=world_mesh, + ) + + if torch.distributed.get_rank() == 0: + logger.info("Sleeping 2 seconds for other ranks to complete") + time.sleep(2) + + metric_logger.close() + logger.info("Training completed") + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parse_args() + main(config) + torch.distributed.destroy_process_group() From ce37765d4bc7709ac526ce9371776bafccf891c7 Mon Sep 17 00:00:00 2001 From: junjzhang Date: Tue, 4 Mar 2025 16:51:10 +0800 Subject: [PATCH 2/5] Tried to reuse codes --- .../experiments/train_llama_hf/dataset.py | 53 ++------- .../train_llama_hf/model/__init__.py | 6 +- .../train_llama_hf/model/parallelize_llama.py | 103 +----------------- 3 files changed, 20 insertions(+), 142 deletions(-) diff --git a/torchtitan/experiments/train_llama_hf/dataset.py b/torchtitan/experiments/train_llama_hf/dataset.py index a77e5102e8..b8bcbcd01c 100644 --- a/torchtitan/experiments/train_llama_hf/dataset.py +++ b/torchtitan/experiments/train_llama_hf/dataset.py @@ -8,19 +8,15 @@ import torch -from datasets import Dataset -from datasets.distributed import split_dataset_by_node -from torch.distributed.checkpoint.stateful import Stateful -from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase from torchtitan.components.dataloader import ParallelAwareDataloader from torchtitan.config_manager import JobConfig -from torchtitan.datasets.hf_datasets import _validate_dataset +from torchtitan.datasets.hf_datasets import HuggingFaceDataset from torchtitan.tools.logging import logger -class HuggingFaceDataset(IterableDataset, Stateful): +class HuggingFaceDatasetWithPos(HuggingFaceDataset): def __init__( self, dataset_name: str, @@ -31,33 +27,15 @@ def __init__( dp_world_size: int = 1, infinite: bool = False, ) -> None: - # Force lowercase for consistent comparison - dataset_name = dataset_name.lower() - - path, dataset_loader, text_processor = _validate_dataset( - dataset_name, dataset_path + super().__init__( + dataset_name, + dataset_path, + tokenizer, + seq_len, + dp_rank, + dp_world_size, + infinite, ) - ds = dataset_loader(path) - - self.dataset_name = dataset_name - self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) - self._tokenizer = tokenizer - self.seq_len = seq_len - self.infinite = infinite - self._text_processor = text_processor - - # Variables for checkpointing - self._sample_idx = 0 - self._all_tokens: list[int] = [] - - def _get_data_iter(self): - if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): - return iter([]) - - it = iter(self._data) - for _ in range(self._sample_idx): - next(it) - return it def __iter__(self): max_buffer_token_len = 1 + self.seq_len @@ -88,15 +66,8 @@ def __iter__(self): self._sample_idx = 0 
logger.warning(f"Dataset {self.dataset_name} is being re-looped") - def load_state_dict(self, state_dict): - self._sample_idx = state_dict["sample_idx"] - self._all_tokens = state_dict["token_buffer"] - - def state_dict(self): - return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} - -def build_hf_dataloader( +def build_pos_included_hf_dataloader( dp_world_size: int, dp_rank: int, tokenizer, @@ -109,7 +80,7 @@ def build_hf_dataloader( batch_size = job_config.training.batch_size seq_len = job_config.training.seq_len - hf_ds = HuggingFaceDataset( + hf_ds = HuggingFaceDatasetWithPos( dataset_name=dataset_name, dataset_path=dataset_path, tokenizer=tokenizer, diff --git a/torchtitan/experiments/train_llama_hf/model/__init__.py b/torchtitan/experiments/train_llama_hf/model/__init__.py index 5126818b05..4213dcc096 100644 --- a/torchtitan/experiments/train_llama_hf/model/__init__.py +++ b/torchtitan/experiments/train_llama_hf/model/__init__.py @@ -10,7 +10,9 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.experiments.train_llama_hf.dataset import build_hf_dataloader +from torchtitan.experiments.train_llama_hf.dataset import ( + build_pos_included_hf_dataloader, +) from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .parallelize_llama import parallelize_llama @@ -25,7 +27,7 @@ pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, - build_dataloader_fn=build_hf_dataloader, + build_dataloader_fn=build_pos_included_hf_dataloader, tokenizer_cls=AutoTokenizer, loss_fn=cross_entropy_loss_hf, ) diff --git a/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py index 160edd983a..0bb0186393 100644 --- a/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py +++ b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py @@ -7,7 +7,6 @@ # This file applies the PT-D parallelisms (except pipeline parallelism) and various # training techniques (e.g. activation checkpointing and compile) to the Llama model. -from collections import defaultdict import torch import torch.nn as nn @@ -18,11 +17,7 @@ fully_shard, MixedPrecisionPolicy, ) -from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, -) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -33,6 +28,10 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.models.llama.parallelize_llama import ( + _apply_ac_to_transformer_block, + apply_ddp, +) from torchtitan.tools.logging import logger @@ -211,81 +210,6 @@ def apply_tp( ) -# for selective op activation checkpointing -_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. 
- torch.ops.aten.max.default, -} - - -def _apply_ac_to_transformer_block(module: nn.Module, ac_config): - valid_ac_modes = ("full", "selective") - if ac_config.mode not in valid_ac_modes: - raise ValueError( - f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" - ) - - if ac_config.mode == "full": - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - - assert ac_config.mode == "selective", f"{ac_config.mode}" - use_op_sac = ac_config.selective_ac_option == "op" - use_layer_sac = ac_config.selective_ac_option.isdigit() - if not use_op_sac and not use_layer_sac: - raise ValueError( - f"Invalid selective AC option: {ac_config.selective_ac_option}. " - f"Valid options: 'op' or a positive int representing layer frequency" - ) - if use_op_sac: - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in _save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, - ) - elif use_layer_sac: - # Checkpoint every `ac_freq` of the modules passed to this function - ac_freq = int(ac_config.selective_ac_option) - ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) - ptd_checkpoint_wrapper._count += 1 - if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - else: - return module - - def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" for layer_id, transformer_block in model.model.layers.named_children(): @@ -361,22 +285,3 @@ def apply_fsdp( reshard_after_forward=reshard_after_forward, ) fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - - -def apply_ddp( - model: nn.Module, - dp_mesh: DeviceMesh, - enable_compile: bool, - enable_compiled_autograd: bool, -): - if enable_compile: - if enable_compiled_autograd: - torch._dynamo.config.optimize_ddp = ( - "python_reducer_without_compiled_forward" - ) - else: - torch._dynamo.config.optimize_ddp = "ddp_optimizer" - - replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) - - logger.info("Applied DDP to the model") From 91838de487f807bacf4a38fb211d1a8133016fca Mon Sep 17 00:00:00 2001 From: junjzhang Date: Wed, 5 Mar 2025 19:43:28 +0800 Subject: [PATCH 3/5] update readme --- torchtitan/experiments/train_llama_hf/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/train_llama_hf/README.md b/torchtitan/experiments/train_llama_hf/README.md index 08b0f081d5..08bbfed0d5 100644 --- a/torchtitan/experiments/train_llama_hf/README.md +++ b/torchtitan/experiments/train_llama_hf/README.md @@ -12,6 +12,7 @@ pip install -r extra_requirements.txt ### Test loading HF weights +NOTE: you need to have internet connection to download the weights. 
```bash pytest test_loading_hf_weights.py ``` From accfa1f31834372323bbfd14104753ead8905a8d Mon Sep 17 00:00:00 2001 From: junjzhang Date: Thu, 6 Mar 2025 13:44:29 +0800 Subject: [PATCH 4/5] fix bugs in PP --- .../train_llama_hf/model/__init__.py | 2 +- .../train_llama_hf/model/parallelize_llama.py | 14 ++++++++--- .../train_llama_hf/model/pipeline_llama.py | 25 +++++++++---------- .../test_loading_hf_weights_helper.py | 9 ++++--- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/torchtitan/experiments/train_llama_hf/model/__init__.py b/torchtitan/experiments/train_llama_hf/model/__init__.py index 4213dcc096..ed19d620b7 100644 --- a/torchtitan/experiments/train_llama_hf/model/__init__.py +++ b/torchtitan/experiments/train_llama_hf/model/__init__.py @@ -6,13 +6,13 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from loss import cross_entropy_loss_hf from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers from torchtitan.experiments.train_llama_hf.dataset import ( build_pos_included_hf_dataloader, ) +from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .parallelize_llama import parallelize_llama diff --git a/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py index 0bb0186393..2e632233cb 100644 --- a/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py +++ b/torchtitan/experiments/train_llama_hf/model/parallelize_llama.py @@ -167,7 +167,11 @@ def apply_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 - for transformer_block in model.model.layers: + if isinstance(model.model.layers, nn.ModuleDict): + transformer_blocks = model.model.layers.values() + else: + transformer_blocks = model.model.layers + for transformer_block in transformer_blocks: layer_plan = { "input_layernorm": SequenceParallel(), "self_attn": prepare_module_input( @@ -260,8 +264,12 @@ def apply_fsdp( fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() + if isinstance(model.model.layers, nn.ModuleDict): + layer_items = [(int(k), v) for (k, v) in model.model.layers.items()] + else: + layer_items = list(enumerate(model.model.layers)) - for layer_id, transformer_block in enumerate(model.model.layers): + for layer_id, transformer_block in layer_items: if reshard_after_forward_policy == "always": reshard_after_forward = True elif reshard_after_forward_policy == "never": @@ -274,7 +282,7 @@ def apply_fsdp( else: # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.model.layers) - 1 + reshard_after_forward = layer_id < len(layer_items) - 1 else: raise ValueError( f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
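A note on the container branching introduced in `apply_tp` and `apply_fsdp` above: the `pipeline_llama.py` change that follows rewraps `model.model.layers` as an `nn.ModuleDict` keyed by the stringified original layer index, so the SPMD passes can no longer rely on positional enumeration alone. The shared pattern, reduced to a minimal standalone sketch (the `iter_transformer_blocks` helper is hypothetical and not part of this patch):

```python
import torch.nn as nn


def iter_transformer_blocks(layers: nn.Module) -> list[tuple[int, nn.Module]]:
    # After the pipeline split, `layers` is an nn.ModuleDict whose keys are the
    # stringified original indices ("0", "1", ...); before the split it is still
    # an nn.ModuleList. Normalize both cases to (layer_id, block) pairs.
    if isinstance(layers, nn.ModuleDict):
        return [(int(name), block) for name, block in layers.items()]
    return list(enumerate(layers))


# Usage mirroring the non-PP branch of the FSDP "default" policy above: with
# zero-based layer ids, every block except the last is resharded after forward,
# since FSDP would prefetch the last block again immediately anyway.
layers = nn.ModuleDict({str(i): nn.Linear(4, 4) for i in range(3)})
blocks = iter_transformer_blocks(layers)
reshard_after_forward = {i: i < len(blocks) - 1 for i, _ in blocks}
assert reshard_after_forward == {0: True, 1: True, 2: False}
```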
diff --git a/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py b/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py index 6e6a815134..bb0092b859 100644 --- a/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py +++ b/torchtitan/experiments/train_llama_hf/model/pipeline_llama.py @@ -29,7 +29,6 @@ ) from torchtitan.tools.logging import logger - DeviceType = Union[int, str, torch.device] @@ -87,8 +86,10 @@ def forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + # decoder layers, ok since ModuleDict is ordered + for decoder_layer in list(self.layers.values())[ + : self.config.num_hidden_layers + ]: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -217,6 +218,10 @@ def pipeline_llama( model_config: PretrainedConfig, loss_fn: Callable[..., torch.Tensor], ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + logger.info("Changing model.model.layers to nn.ModuleDict") + model.model.layers = nn.ModuleDict( + {str(i): layer for i, layer in enumerate(model.model.layers)} + ) logger.info( "Patching Llama forward method for pipeline parallelism, it will disable some features of orignal HF model" ) @@ -277,20 +282,14 @@ def _build_stage( model.model.embed_tokens = None drop_layers = start_layer is not None - del_indexes = [] - for i in range(len(model.model.layers)): + for name in list(model.model.layers.keys()): # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if f"layers.{i}" == start_layer: + if f"layers.{name}" == start_layer: drop_layers = False - if f"layers.{i}" == stop_layer: + if f"layers.{name}" == stop_layer: drop_layers = True if drop_layers: - del_indexes.append(i) - - # delete layers in reverse order to avoid index shifting - del_indexes.reverse() - for i in del_indexes: - del model.model.layers[i] + del model.model.layers[name] if not is_last: model.model.norm = None diff --git a/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py index 25aca1cf38..6f3a913ea1 100644 --- a/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py +++ b/torchtitan/experiments/train_llama_hf/test_loading_hf_weights_helper.py @@ -16,13 +16,13 @@ normalize_state_dict_key, ) +from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf + from torchtitan.experiments.train_llama_hf.model.parallelize_llama import ( apply_fsdp, apply_tp, ) -from torchtitan.experiments.train_llama_hf.model.pipeline_llama import ( - pipeline_llama_manual_split, -) +from torchtitan.experiments.train_llama_hf.model.pipeline_llama import pipeline_llama def main(job_config: JobConfig): @@ -52,13 +52,14 @@ def main(job_config: JobConfig): # apply parallelisms if parallel_dims.pp_enabled: # apply PT-D Pipeline Parallel - _, model_parts = pipeline_llama_manual_split( + _, model_parts, _, _ = pipeline_llama( model, world_mesh["pp"], parallel_dims, job_config, device, model_config, + loss_fn=cross_entropy_loss_hf, ) else: model_parts = [model] From 97153154d68e51feb3d082cb343bd6b1a7fecae4 Mon Sep 17 00:00:00 2001 From: junjzhang Date: Thu, 6 Mar 2025 14:32:23 +0800 Subject: [PATCH 5/5] fix bugs in CP --- torchtitan/experiments/train_llama_hf/train_llama_hf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/torchtitan/experiments/train_llama_hf/train_llama_hf.py b/torchtitan/experiments/train_llama_hf/train_llama_hf.py index 5dc1176b46..0bbd603fa2 100644 --- a/torchtitan/experiments/train_llama_hf/train_llama_hf.py +++ b/torchtitan/experiments/train_llama_hf/train_llama_hf.py @@ -307,11 +307,11 @@ def main(job_config: JobConfig): # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage optional_context_parallel_ctx = ( - utils.create_context_parallel_ctx( + dist_utils.create_context_parallel_ctx( cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels, position_ids], + cp_buffers=[input_ids, position_ids, labels], cp_seq_dims=[1, 1, 1], - cp_no_restore_buffers={input_ids, labels}, + cp_no_restore_buffers={input_ids, position_ids, labels}, cp_rotate_method=job_config.experimental.context_parallel_rotate_method, ) if parallel_dims.cp_enabled @@ -344,7 +344,7 @@ def main(job_config: JobConfig): else: # Non-PP forward / backward with train_context(optional_context_parallel_ctx): - pred = model(input_ids) + pred = model(input_ids, position_ids=position_ids) loss = train_spec.loss_fn(pred, labels) # pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory
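
To make the intent of the final context-parallel fix concrete: every tensor whose sequence dimension is split across CP ranks must appear in `cp_buffers` (with a matching entry in `cp_seq_dims`), and `position_ids` is no exception, because each rank's rotary embeddings are computed from the positions of the tokens it actually holds. A minimal single-process sketch of that invariant, using a toy `shard_along_seq` helper rather than the real DTensor-based sharding:

```python
import torch


def shard_along_seq(t: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Toy stand-in for what context parallelism does to each entry of
    # cp_buffers: split along the sequence dimension (dim 1 here, matching
    # cp_seq_dims=[1, 1, 1]) and keep only the local shard.
    return t.chunk(cp_world_size, dim=1)[cp_rank]


batch, seq_len = 2, 8
input_ids = torch.randint(0, 100, (batch, seq_len))
# Absolute positions must travel with their tokens; if position_ids were left
# unsharded, rank 1 would apply rotary embeddings for positions 0..3 to tokens
# that actually sit at positions 4..7.
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

cp_rank, cp_world_size = 1, 2
local_ids = shard_along_seq(input_ids, cp_rank, cp_world_size)
local_pos = shard_along_seq(position_ids, cp_rank, cp_world_size)
assert local_pos.tolist() == [[4, 5, 6, 7], [4, 5, 6, 7]]
```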