5 changes: 5 additions & 0 deletions torchtitan/experiments/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
24 changes: 24 additions & 0 deletions torchtitan/experiments/train_llama_hf/README.md
@@ -0,0 +1,24 @@
# Training Llama with HF weights

This directory contains scripts and configs for training Llama with Hugging Face (HF) weights using torchtitan.

## Usage

### Install extra dependencies

```bash
pip install -r extra_requirements.txt
```

### Test loading HF weights

NOTE: you need an internet connection to download the weights.
```bash
pytest test_loading_hf_weights.py
```

### Run training

```bash
LOG_RANK=7 bash run_train.sh
```
10 changes: 10 additions & 0 deletions torchtitan/experiments/train_llama_hf/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import torchtitan.experiments.train_llama_hf.model # noqa: F401
98 changes: 98 additions & 0 deletions torchtitan/experiments/train_llama_hf/dataset.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from transformers import PreTrainedTokenizerBase

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import HuggingFaceDataset
from torchtitan.tools.logging import logger


class HuggingFaceDatasetWithPos(HuggingFaceDataset):
def __init__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer: PreTrainedTokenizerBase,
seq_len: int = 2048,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
super().__init__(
dataset_name,
dataset_path,
tokenizer,
seq_len,
dp_rank,
dp_world_size,
infinite,
)

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

while True:
for sample in self._get_data_iter():
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
# Add position IDs (0 to seq_len-1)
position_ids = torch.arange(len(input), dtype=torch.long)
yield input, label, position_ids

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")


def build_pos_included_hf_dataloader(
dp_world_size: int,
dp_rank: int,
tokenizer,
job_config: JobConfig,
infinite: bool = True,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
seq_len = job_config.training.seq_len

hf_ds = HuggingFaceDatasetWithPos(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
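
For orientation, a minimal usage sketch of the position-aware dataset on its own (not part of the PR). The dataset name `c4_test` and the tokenizer id are assumptions; any Hugging Face tokenizer with an `encode` method should work here:

```python
# Minimal sketch: iterate HuggingFaceDatasetWithPos directly and inspect one sample.
# Assumptions: the "c4_test" dataset registered in torchtitan.datasets.hf_datasets
# is available, and the tokenizer id below is a placeholder you have access to.
from transformers import AutoTokenizer

from torchtitan.experiments.train_llama_hf.dataset import HuggingFaceDatasetWithPos

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder id
ds = HuggingFaceDatasetWithPos(
    dataset_name="c4_test",
    dataset_path=None,
    tokenizer=tokenizer,
    seq_len=8,
)

input_ids, labels, position_ids = next(iter(ds))
# input_ids and labels are the same token window shifted by one position;
# position_ids is simply arange(seq_len) for every sample.
assert position_ids.tolist() == list(range(8))
```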
2 changes: 2 additions & 0 deletions torchtitan/experiments/train_llama_hf/extra_requirements.txt
@@ -0,0 +1,2 @@
transformers >=4.49.0
sentencepiece >=0.2.0
188 changes: 188 additions & 0 deletions torchtitan/experiments/train_llama_hf/hf_weights_utils.py
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
from collections import defaultdict
from pathlib import Path

import torch.nn as nn
from huggingface_hub import repo_exists, snapshot_download
from safetensors import safe_open
from torch.distributed.tensor import distribute_tensor, DTensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from torchtitan.tools.logging import logger

INDEX_NAME_MAPPING = {
"safetensors": SAFE_WEIGHTS_INDEX_NAME,
}

PATTERNS_TO_REMOVE = [
"._orig_mod", # Some optimizers add suffixes
"._fsdp_wrapped_module", # FSDP wrapper
"._checkpoint_wrapped_module", # checkpoint wrapper
".module", # DataParallel/DistributedDataParallel
"_module.", # Some wrappers add prefix
]


def normalize_state_dict_key(
key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE
) -> str:
"""
Normalize a state dict key by removing the prefixes or suffixes added by various wrappers.
Args:
key: The original state dict key
patterns_to_remove: Substrings to strip from the key
Returns:
The normalized key
"""
normalized_key = key
for pattern in patterns_to_remove:
normalized_key = normalized_key.replace(pattern, "")

return normalized_key


def get_weight_map(pretrained_model_path: Path) -> dict[str, str] | None:
"""
Get the weight map from the pretrained model.
Args:
pretrained_model_path: The path to the pretrained model.
Returns:
weight_map: A dictionary mapping each checkpoint state dict key to the safetensors file that stores it, or None if the checkpoint has no index file (single-shard checkpoint).
"""
index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"]
if not index_file.exists():
return None
with open(index_file, "r") as f:
metadata = json.load(f)
return metadata["weight_map"]
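
For reference, the `weight_map` in the safetensors index file has roughly this shape; a minimal sketch with made-up entries, not taken from any real checkpoint:

```python
# Illustrative weight_map from model.safetensors.index.json (entries are made up):
weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
```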


def group_state_dict_keys_and_st_partition_paths(
pretrained_model_path: Path,
state_dict_keys,
weight_map,
state_dict_map: dict[str, str] = None,
):
"""
Group state dict keys by the safetensors partition file that stores them.
Args:
pretrained_model_path: The path to the pretrained model.
state_dict_keys: The state dict keys to group.
weight_map: The weight map.
state_dict_map: A dictionary mapping from the state dict key to the weight path.
Returns:
st_partition_map: A dictionary mapping each safetensors partition path to the list of state dict keys stored in that file.
"""
st_partition_map = defaultdict(list)
for state_dict_key in state_dict_keys:
ckpt_state_dict_key = (
state_dict_map[state_dict_key]
if state_dict_map is not None
else state_dict_key
)
if weight_map is None:
partition_path = pretrained_model_path / "model.safetensors"
else:
partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]
Review comment (on the weight_map lookup above): A KeyError will be raised if a checkpoint key is missing from weight_map. For example, when tie_word_embeddings = True, model.lm_head.weight is not present in weight_map. So, how should we initialize or load the weight in this case? Currently, a KeyError occurs. If we ignore it, the loaded model will be incorrect: specifically, model.lm_head.weight will be set to zero.

Llama-3.2-3B can be used as a model for verification. (One possible workaround via mapping_dict is sketched after this file's diff.)

st_partition_map[partition_path].append(state_dict_key)
return st_partition_map


def load_sharded_state_dict_for_model_from_path(
pretrained_model_path: Path,
model: nn.Module,
mapping_dict: dict[str, str] = None,
**kwargs,
):
"""
Load the state dict in a sharded fashion (relies on DTensor) from the pretrained model path. Only the weights needed by the current rank are loaded.
Args:
pretrained_model_path: The local path to the pretrained model directory.
model: The model to load the state dict into.
mapping_dict: Optional mapping from model state dict keys to checkpoint state dict keys.
**kwargs: other arguments for torch.nn.Module.load_state_dict
"""
# check exceptions
if not pretrained_model_path.exists():
raise ValueError(
f"The pretrained model path {pretrained_model_path} does not exist."
)
if not pretrained_model_path.is_dir():
raise ValueError(
f"The pretrained model path {pretrained_model_path} is not a directory."
)
# get the weight map
weight_map = get_weight_map(pretrained_model_path)
model_state_dict = model.state_dict()
model_state_dict_keys = list(model_state_dict.keys())

# create a mapping_dict between the original state_dict_key and the weight_map_key if not provided
mapping_dict = (
mapping_dict
if mapping_dict is not None
else {key: normalize_state_dict_key(key) for key in model_state_dict_keys}
)
st_partition_map = group_state_dict_keys_and_st_partition_paths(
pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict
)

# get the sharded state dict
state_dict = {}
for safetensor_partition_path, state_dict_keys in st_partition_map.items():
with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f:
for state_dict_key in state_dict_keys:
model_tensor = model_state_dict[state_dict_key]
ckpt_state_dict_key = mapping_dict[state_dict_key]
if isinstance(model_tensor, DTensor):
local_tensor = f.get_tensor(ckpt_state_dict_key)
state_dict[state_dict_key] = distribute_tensor(
local_tensor,
model_tensor.device_mesh,
model_tensor.placements,
)
else:
state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key)
model.load_state_dict(state_dict, **kwargs)
del state_dict
gc.collect()


def load_sharded_state_dict_for_model_from_hf(
pretrained_model_id_or_path: str,
model: nn.Module,
**kwargs,
):
"""
Load the state dict in a sharded fashion (relies on DTensor) from a Hugging Face repo id or a local path. Only the weights needed by the current rank are loaded.
Args:
pretrained_model_id_or_path: A Hugging Face repo id or a local path to the pretrained model.
model: The model to load the state dict into.
**kwargs: other arguments for torch.nn.Module.load_state_dict
"""
logger.info(f"Loading the state dict from {pretrained_model_id_or_path}")
pretrained_model_id_or_path = Path(pretrained_model_id_or_path)
if not pretrained_model_id_or_path.exists():
if not repo_exists(str(pretrained_model_id_or_path)):
raise ValueError(
f"The pretrained model {pretrained_model_id_or_path} does not exist"
)
logger.info(
f"Try to download the model from huggingface: {pretrained_model_id_or_path}"
)
pretrained_model_path = Path(
snapshot_download(str(pretrained_model_id_or_path))
)
elif not pretrained_model_id_or_path.is_dir():
raise ValueError(
f"The pretrained model path {pretrained_model_id_or_path} is not a directory."
)
else:
pretrained_model_path = pretrained_model_id_or_path

load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs)
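
Regarding the tied-embeddings review comment above: one possible workaround, offered here as a sketch rather than as the PR's fix, is to pass a `mapping_dict` that points `lm_head.weight` at the stored embedding weight when `tie_word_embeddings` is set. The key names assume Hugging Face's Llama layout (`lm_head.weight`, `model.embed_tokens.weight`):

```python
# Sketch of a possible tie_word_embeddings workaround (assumes HF Llama key names).
# Both the weight_map lookup and the safetensors read go through mapping_dict, so
# redirecting lm_head.weight to model.embed_tokens.weight loads the tied weight.
from pathlib import Path

import torch.nn as nn

from torchtitan.experiments.train_llama_hf.hf_weights_utils import (
    load_sharded_state_dict_for_model_from_path,
    normalize_state_dict_key,
)


def load_with_tied_embeddings(pretrained_model_path: Path, model: nn.Module) -> None:
    mapping_dict = {
        key: normalize_state_dict_key(key) for key in model.state_dict().keys()
    }
    if getattr(model.config, "tie_word_embeddings", False):
        # The checkpoint stores only the embedding; reuse it for the LM head.
        mapping_dict["lm_head.weight"] = "model.embed_tokens.weight"
    load_sharded_state_dict_for_model_from_path(
        pretrained_model_path, model, mapping_dict=mapping_dict
    )
```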
14 changes: 14 additions & 0 deletions torchtitan/experiments/train_llama_hf/loss.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def cross_entropy_loss_hf(preds, labels):
loss = torch.nn.functional.cross_entropy(
preds[0].flatten(0, 1).float(), labels.flatten(0, 1)
)
return loss
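
A quick shape check for the loss, assuming the model is called with `return_dict=False` so that `preds` is a tuple whose first element is the logits (this calling convention is an assumption, not stated in the PR):

```python
# Shape sketch for cross_entropy_loss_hf with stand-in tensors.
import torch

from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf

batch, seq_len, vocab = 2, 8, 128
preds = (torch.randn(batch, seq_len, vocab),)       # stand-in for HF model output
labels = torch.randint(0, vocab, (batch, seq_len))  # shifted targets from the dataloader
loss = cross_entropy_loss_hf(preds, labels)
print(loss)  # scalar tensor
```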
34 changes: 34 additions & 0 deletions torchtitan/experiments/train_llama_hf/model/__init__.py
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.experiments.train_llama_hf.dataset import (
build_pos_included_hf_dataloader,
)
from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

register_train_spec(
TrainSpec(
name="llama3_hf",
cls=AutoModelForCausalLM,
config=AutoConfig,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_pos_included_hf_dataloader,
tokenizer_cls=AutoTokenizer,
loss_fn=cross_entropy_loss_hf,
)
)
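
Once this module is imported (which `torchtitan/experiments/train_llama_hf/__init__.py` does), the spec can be looked up by name. A minimal sketch, assuming the registry exposes `get_train_spec` as the counterpart of `register_train_spec`:

```python
# Minimal sketch: resolve the registered spec by name (get_train_spec is assumed
# to be the lookup counterpart of register_train_spec in torchtitan.protocols.train_spec).
import torchtitan.experiments.train_llama_hf  # noqa: F401  # triggers registration

from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_hf")
print(spec.cls)               # transformers AutoModelForCausalLM
print(spec.loss_fn.__name__)  # cross_entropy_loss_hf
```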