5 changes: 5 additions & 0 deletions torchtitan/experiments/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
23 changes: 23 additions & 0 deletions torchtitan/experiments/train_llama_hf/README.md
@@ -0,0 +1,23 @@
# Training Llama with HF weights

This directory contains scripts and configs for training Llama models from Hugging Face (HF) weights using torchtitan.

## Usage

### Install extra dependencies

```bash
pip install -r extra_requirements.txt
```

### Test loading HF weights

```bash
pytest test_loading_hf_weights.py
```

### Run training

```bash
LOG_RANK=7 bash run_train.sh
```
10 changes: 10 additions & 0 deletions torchtitan/experiments/train_llama_hf/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import torchtitan.experiments.train_llama_hf.model # noqa: F401
127 changes: 127 additions & 0 deletions torchtitan/experiments/train_llama_hf/dataset.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import _validate_dataset
from torchtitan.tools.logging import logger


class HuggingFaceDataset(IterableDataset, Stateful):
def __init__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer: PreTrainedTokenizerBase,
seq_len: int = 2048,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()

path, dataset_loader, text_processor = _validate_dataset(
dataset_name, dataset_path
)
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: list[int] = []

def _get_data_iter(self):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

it = iter(self._data)
for _ in range(self._sample_idx):
next(it)
return it

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

while True:
for sample in self._get_data_iter():
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# keep only the remaining tokens in the buffer
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
# Add position IDs (0 to seq_len-1)
position_ids = torch.arange(len(input), dtype=torch.long)
yield input, label, position_ids

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
dp_world_size: int,
dp_rank: int,
tokenizer,
job_config: JobConfig,
infinite: bool = True,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
seq_len = job_config.training.seq_len

hf_ds = HuggingFaceDataset(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
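
For clarity, the buffering in `__iter__` can be exercised in isolation. The sketch below (assuming a pre-tokenized stream of token-id lists, with toy values) mirrors how samples are concatenated into a token buffer and emitted as fixed-length, one-token-shifted `(input, label, position_ids)` windows:

```python
import torch


def pack_tokens(token_stream, seq_len):
    # Mirrors HuggingFaceDataset.__iter__: accumulate tokens, then emit
    # windows of seq_len + 1 tokens split into shifted input/label pairs
    # plus position ids 0..seq_len-1.
    buffer: list[int] = []
    max_buffer_token_len = 1 + seq_len
    for tokens in token_stream:
        buffer.extend(tokens)
        while len(buffer) >= max_buffer_token_len:
            x = torch.LongTensor(buffer[:max_buffer_token_len])
            buffer = buffer[max_buffer_token_len:]
            inputs, labels = x[:-1], x[1:]
            position_ids = torch.arange(seq_len, dtype=torch.long)
            yield inputs, labels, position_ids


# Toy usage: three pre-tokenized "documents", packed into windows of length 4.
stream = [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12]]
for inputs, labels, position_ids in pack_tokens(stream, seq_len=4):
    print(inputs.tolist(), labels.tolist(), position_ids.tolist())
```

Leftover tokens that do not fill a full window stay in the buffer, which is exactly what the dataset checkpoints as `token_buffer` in `state_dict()`.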
2 changes: 2 additions & 0 deletions torchtitan/experiments/train_llama_hf/extra_requirements.txt
@@ -0,0 +1,2 @@
transformers >=4.49.0
sentencepiece >=0.2.0
188 changes: 188 additions & 0 deletions torchtitan/experiments/train_llama_hf/hf_weights_utils.py
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
from collections import defaultdict
from pathlib import Path
from typing import Optional

import torch.nn as nn
from huggingface_hub import repo_exists, snapshot_download
from safetensors import safe_open
from torch.distributed.tensor import distribute_tensor, DTensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from torchtitan.tools.logging import logger

INDEX_NAME_MAPPING = {
"safetensors": SAFE_WEIGHTS_INDEX_NAME,
}

PATTERNS_TO_REMOVE = [
"._orig_mod", # Some optimizers add suffixes
"._fsdp_wrapped_module", # FSDP wrapper
"._checkpoint_wrapped_module", # checkpoint wrapper
".module", # DataParallel/DistributedDataParallel
"_module.", # Some wrappers add prefix
]


def normalize_state_dict_key(
key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE
) -> str:
"""
Normalize the state dict key, remove the prefix or suffix added by various wrappers.
Args:
key: The original state dict key
Returns:
The normalized key
"""
normalized_key = key
for pattern in patterns_to_remove:
normalized_key = normalized_key.replace(pattern, "")

return normalized_key
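
As a quick illustration of the normalization, a key produced by a `torch.compile` plus FSDP wrapped module (hypothetical key, for illustration only) maps back to the plain parameter name:

```python
# Example: strip wrapper-added fragments from a state dict key.
wrapped_key = "model._orig_mod.layers.0._fsdp_wrapped_module.self_attn.q_proj.weight"
print(normalize_state_dict_key(wrapped_key))
# -> model.layers.0.self_attn.q_proj.weight
```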


def get_weight_map(pretrained_model_path: Path) -> Optional[dict[str, str]]:
"""
Get the weight map from the pretrained model checkpoint.
Args:
pretrained_model_path: The path to the pretrained model.
Returns:
weight_map: A dictionary mapping each checkpoint state dict key to the safetensors shard file that stores it, or None if the checkpoint has no index file.
"""
index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"]
if not index_file.exists():
return None
with open(index_file, "r") as f:
metadata = json.load(f)
return metadata["weight_map"]
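
For reference, the safetensors index file (`model.safetensors.index.json`) has roughly the structure below, so the returned `weight_map` maps each checkpoint parameter name to the shard file that stores it. The shard names and size here are illustrative:

```python
# Illustrative contents of model.safetensors.index.json for a two-shard checkpoint.
index = {
    "metadata": {"total_size": 6425499648},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}
```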


def group_state_dict_keys_and_st_partition_paths(
pretrained_model_path: Path,
state_dict_keys,
weight_map,
state_dict_map: dict[str, str] = None,
):
"""
Group state dict keys and save them to a file.
Args:
pretrained_model_path: The path to the pretrained model.
state_dict_keys: The state dict keys to group.
weight_map: The weight map.
state_dict_map: A dictionary mapping from the state dict key to the weight path.
Returns:
st_partition_map: A dictionary mapping from the weight path to the list of state dict keys.
"""
st_partition_map = defaultdict(list)
for state_dict_key in state_dict_keys:
ckpt_state_dict_key = (
state_dict_map[state_dict_key]
if state_dict_map is not None
else state_dict_key
)
if weight_map is None:
partition_path = pretrained_model_path / "model.safetensors"
else:
partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]
Review comment:
A KeyError will be raised if a checkpoint key is missing from weight_map. For example, when tie_word_embeddings = True, model.lm_head.weight is not present in weight_map. So, how should we initialize or load the weight in this case? Currently, a KeyError occurs. If we ignore it, the loaded model will be incorrect—specifically, model.lm_head.weight will be set to zero.

Llama-3.2-3B can be used as a model for verification.

st_partition_map[partition_path].append(state_dict_key)
return st_partition_map
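
One possible way to address the review comment above, sketched here under the assumption that the checkpoint follows the HF Llama key layout (this is not part of the PR): when `weight_map` has no entry for `lm_head.weight` because the embeddings are tied, remap that model key to the embedding weight through the existing `mapping_dict` hook instead of letting the lookup raise a `KeyError`:

```python
# Hypothetical handling of tie_word_embeddings=True: lm_head.weight is absent
# from weight_map, so point the model key at the tied embedding weight.
def build_mapping_dict_with_tied_lm_head(model, weight_map):
    mapping_dict = {
        key: normalize_state_dict_key(key) for key in model.state_dict().keys()
    }
    lm_head_key = "lm_head.weight"  # assumed HF Llama key name
    embed_key = "model.embed_tokens.weight"  # assumed HF Llama key name
    if (
        weight_map is not None
        and lm_head_key in mapping_dict
        and mapping_dict[lm_head_key] not in weight_map
        and embed_key in weight_map
    ):
        mapping_dict[lm_head_key] = embed_key
    return mapping_dict


# Usage sketch with the loader defined below:
# mapping = build_mapping_dict_with_tied_lm_head(model, get_weight_map(path))
# load_sharded_state_dict_for_model_from_path(path, model, mapping_dict=mapping)
```

Because the tied tensors are identical, loading the embedding weight into `lm_head.weight` gives the same values the original model uses.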


def load_sharded_state_dict_for_model_from_path(
pretrained_model_path: Path,
model: nn.Module,
mapping_dict: dict[str, str] = None,
**kwargs,
):
"""
Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank.
Args:
pretrained_model_path: The path to the pretrained model, it could be a local path or an s3 path.
model: The model to load the state dict into.
**kwargs: other arguments for torch.nn.Module.load_state_dict
"""
# check exceptions
if not pretrained_model_path.exists():
raise ValueError(
f"The pretrained model path {pretrained_model_path} does not exist."
)
if not pretrained_model_path.is_dir():
raise ValueError(
f"The pretrained model path {pretrained_model_path} is not a directory."
)
# get the weight map
weight_map = get_weight_map(pretrained_model_path)
model_state_dict = model.state_dict()
model_state_dict_keys = list(model_state_dict.keys())

# create a mapping_dict between the original state_dict_key and the weight_map_key if not provided
mapping_dict = (
mapping_dict
if mapping_dict is not None
else {key: normalize_state_dict_key(key) for key in model_state_dict_keys}
)
st_partition_map = group_state_dict_keys_and_st_partition_paths(
pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict
)

# get the sharded state dict
state_dict = {}
for safetensor_partition_path, state_dict_keys in st_partition_map.items():
with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f:
for state_dict_key in state_dict_keys:
model_tensor = model_state_dict[state_dict_key]
ckpt_state_dict_key = mapping_dict[state_dict_key]
if isinstance(model_tensor, DTensor):
local_tensor = f.get_tensor(ckpt_state_dict_key)
state_dict[state_dict_key] = distribute_tensor(
local_tensor,
model_tensor.device_mesh,
model_tensor.placements,
)
else:
state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key)
model.load_state_dict(state_dict, **kwargs)
del state_dict
gc.collect()


def load_sharded_state_dict_for_model_from_hf(
pretrained_model_id_or_path: str,
model: nn.Module,
**kwargs,
):
"""
Load the state dict sharded (depends on DTensor) from the pretrained model path. It only load the weights for current rank.
Args:
pretrained_model_id_or_path: The id or path to the pretrained model, it could be a repo id in huggingface,
or a local path
model: The model to load the state dict into.
**kwargs: other arguments for torch.nn.Module.load_state_dict
"""
logger.info(f"Loading the state dict from {pretrained_model_id_or_path}")
pretrained_model_id_or_path = Path(pretrained_model_id_or_path)
if not pretrained_model_id_or_path.exists():
if not repo_exists(str(pretrained_model_id_or_path)):
raise ValueError(
f"The pretrained model {pretrained_model_id_or_path} does not exist"
)
logger.info(
f"Trying to download the model from the Hugging Face Hub: {pretrained_model_id_or_path}"
)
pretrained_model_path = Path(
snapshot_download(str(pretrained_model_id_or_path))
)
elif not pretrained_model_id_or_path.is_dir():
raise ValueError(
f"The pretrained model path {pretrained_model_id_or_path} is not a directory."
)
else:
pretrained_model_path = pretrained_model_id_or_path

load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs)
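
A minimal usage sketch, assuming `model` has already been constructed and parallelized with DTensor on every rank (the repo id is only an example):

```python
# Load HF weights into an already-parallelized model; a local checkpoint
# directory also works in place of the repo id.
load_sharded_state_dict_for_model_from_hf(
    "meta-llama/Llama-3.2-3B",  # example repo id
    model,
    strict=True,  # forwarded to torch.nn.Module.load_state_dict via **kwargs
)
```

Each tensor is read from its safetensors shard on CPU and then placed with `distribute_tensor` according to the parameter's existing placements, so each rank ends up holding only its own shards.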
14 changes: 14 additions & 0 deletions torchtitan/experiments/train_llama_hf/loss.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def cross_entropy_loss_hf(preds, labels):
loss = torch.nn.functional.cross_entropy(
preds[0].flatten(0, 1).float(), labels.flatten(0, 1)
)
return loss
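
A quick shape check of why the loss indexes `preds[0]`: the HF causal LM forward returns its logits as the first element of the output when no loss is computed inside the model, with shape `[batch, seq, vocab]`, while labels are `[batch, seq]`. The sizes below are toy values:

```python
import torch

batch, seq, vocab = 2, 8, 32  # toy sizes, for illustration only
preds = (torch.randn(batch, seq, vocab),)  # logits as the first output element
labels = torch.randint(0, vocab, (batch, seq))
loss = cross_entropy_loss_hf(preds, labels)
print(loss)  # a scalar tensor
```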