Skip to content

Commit f111334

Browse files
committed
finished training with hf models
1 parent b291ad6 commit f111334

17 files changed

+10209
-5
lines changed

pixi.lock

Lines changed: 8277 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,17 @@ authors = [
1010
]
1111
keywords = ["pytorch", "training", "llm"]
1212
dependencies = [
13-
# Stateful Dataloader
14-
"torchdata>=0.8.0",
15-
1613
# Hugging Face integrations
17-
"datasets>=2.21.0",
14+
"datasets>2.21",
1815

1916
# Tokenization
2017
"blobfile",
2118
"tiktoken",
2219

2320
# Miscellaneous
24-
"tomli>=1.1.0"
21+
"tomli>=1.1.0",
22+
23+
"torchdata>=0.10.1,<0.11",
2524
]
2625
dynamic = ["version"]
2726

@@ -51,5 +50,58 @@ build-backend = "setuptools.build_meta"
5150
where = [""]
5251
include = ["torchtitan*"]
5352

53+
[tool.setuptools.package-data]
54+
recipes = ["train_configs/*.toml"]
55+
5456
[tool.pytest.ini_options]
5557
addopts = ["--showlocals"] # show local variables in tracebacks
58+
59+
[tool.pixi.project]
60+
channels = ["https://prefix.dev/meta-forge", "conda-forge", "https://prefix.dev/meta-pty"]
61+
platforms = ["linux-64"]
62+
63+
[tool.pixi.pypi-dependencies]
64+
torchtitan = { path = ".", editable = true }
65+
66+
[tool.pixi.environments]
67+
default = { solve-group = "default" }
68+
dev = { features = ["dev"], solve-group = "default" }
69+
70+
[tool.pixi.tasks]
71+
72+
[tool.pixi.dependencies]
73+
pytorch = { version = "==2.6.0", build = "*cuda*" }
74+
wandb = ">=0.19.6,<0.20"
75+
tensorboard = ">=2.18.0,<3"
76+
blobfile = ">=3.0.0,<4"
77+
tabulate = ">=0.9.0,<0.10"
78+
fsspec = "<=2024.12.0"
79+
torchao = ">=0.8.0,<0.9"
80+
gcc = "<=13.2.0"
81+
gxx = "<=13.2.0"
82+
cuda = "12.6.*"
83+
transformers = ">=4.49.0,<5"
84+
universal_pathlib = ">=0.2.6,<0.3"
85+
s3fs = ">=2024.12.0,<2025"
86+
sentencepiece = ">=0.2.0,<0.3"
87+
88+
[tool.pixi.feature.dev.dependencies]
89+
pytest = ">=8.3.4,<9"
90+
pytest-cov = ">=6.0.0,<7"
91+
pre-commit = "*"
92+
expecttest = ">=0.3.0,<0.4"
93+
ipykernel = ">=6.29.5,<7"
94+
pytz = ">=2025.1,<2026"
95+
parakernel = ">=0.1.3,<0.2"
96+
97+
[tool.pixi.system-requirements]
98+
cuda = "12.7"
99+
100+
[tool.ruff.lint]
# Rule selection belongs to the linter table; the top-level `select` key is
# deprecated in favour of `lint.select`.
select = ["E", "W", "F", "B"]

[tool.ruff.format]
# The formatter has no `select` option (rule selection only applies to the
# linter), and Ruff rejects unknown keys, so this table keeps the defaults.

torchtitan/experiments/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Training LLAMA with HF weights
2+
3+
This directory contains scripts and configs for training LLAMA with HF weights using TorchTitan.
4+
5+
## Usage
6+
7+
### Install extra dependencies
8+
9+
```bash
10+
pip install -r extra_requirements.txt
11+
```
12+
13+
### Test loading HF weights
14+
15+
```bash
16+
pytest test_loading_hf_weights.py
17+
```
18+
19+
### Run training
20+
21+
```bash
22+
LOG_RANK=7 bash run_train.sh
23+
```
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
# Llama 3 is licensed under the LLAMA 3 Community License,
8+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
9+
10+
import torchtitan.experiments.train_llama_hf.model # noqa: F401
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
from typing import Any, Callable, Optional
9+
10+
import torch
11+
12+
from datasets import Dataset, load_dataset
13+
from datasets.distributed import split_dataset_by_node
14+
from torch.distributed.checkpoint.stateful import Stateful
15+
from torch.utils.data import IterableDataset
16+
from transformers import PreTrainedTokenizerBase
17+
18+
from torchtitan.components.dataloader import ParallelAwareDataloader
19+
20+
from torchtitan.config_manager import JobConfig
21+
from torchtitan.tools.logging import logger
22+
23+
24+
def _load_c4_dataset(dataset_path: str):
    """Open the "en" configuration's train split of C4 in streaming mode.

    Args:
        dataset_path: Hub id or path handed straight to ``load_dataset``.

    Returns:
        Whatever ``load_dataset`` returns for a streaming train split.
    """
    c4_options = {"name": "en", "split": "train", "streaming": True}
    return load_dataset(dataset_path, **c4_options)
27+
28+
29+
def _process_c4_text(sample: dict[str, Any]) -> str:
30+
"""Process C4 dataset sample text."""
31+
return sample["text"]
32+
33+
34+
@dataclass
class DatasetConfig:
    """Describes how one supported dataset is located, loaded and read.

    Instances are registered in ``DATASETS`` and unpacked by
    ``_validate_dataset`` into a ``(path, loader, text_processor)`` tuple.
    """

    # Default location of the dataset; callers may override it with their own path.
    path: str
    # Called with the resolved path and expected to return the loaded dataset.
    loader: Callable
    # Called with one dataset sample; its result is passed to the tokenizer.
    text_processor: Callable
39+
40+
41+
# Add your dataset here - more information at docs/datasets.md
# Registry of supported datasets. Keys are compared against the lowercased
# dataset name, so they should be written in lowercase.
DATASETS = {
    # C4 via `_load_c4_dataset` (streaming, "en" configuration, train split).
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    # C4 test asset; loaded without streaming, so the train split is
    # materialised rather than iterated lazily.
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}
54+
55+
56+
def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Look up a registered dataset and resolve the path it should be loaded from.

    Args:
        dataset_name: Key into ``DATASETS``. The lookup is exact, so callers
            are expected to normalise the case beforehand.
        dataset_path: Optional override for the registered default path. Any
            falsy value (``None`` or an empty string) selects the default.

    Returns:
        A ``(path, loader, text_processor)`` tuple taken from the registry
        entry, with ``path`` replaced by ``dataset_path`` when one is given.

    Raises:
        ValueError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor
70+
71+
72+
class HuggingFaceDataset(IterableDataset, Stateful):
    """Checkpointable iterable dataset that packs tokenized text into fixed-length samples.

    Samples from the underlying dataset are tokenized and appended to a
    running token buffer. Whenever the buffer holds at least ``seq_len + 1``
    tokens, one training example is cut from its front and yielded as
    ``(input_ids, labels, position_ids)``.

    The number of consumed samples and the leftover token buffer form the
    checkpoint state exposed through ``state_dict`` / ``load_state_dict``.

    Args:
        dataset_name: Name of a dataset registered in ``DATASETS``
            (case-insensitive; it is lowercased before lookup).
        dataset_path: Optional override of the registered dataset path.
        tokenizer: Tokenizer whose ``encode`` method turns text into token ids.
        seq_len: Length of each yielded input/label sequence.
        dp_rank: Rank passed to ``split_dataset_by_node``.
        dp_world_size: World size passed to ``split_dataset_by_node``.
        infinite: When true, restart from the first sample once the data is
            exhausted instead of stopping.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: PreTrainedTokenizerBase,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        """Return an iterator over the data, skipping samples already consumed."""
        # A map-style dataset that was fully consumed has nothing left to yield.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Fast-forward past the samples consumed before the last checkpoint.
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        """Yield ``(input_ids, labels, position_ids)`` tensors of length ``seq_len``."""
        # One extra token is needed so that labels can be the inputs shifted by one.
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text)
                self._all_tokens.extend(sample_tokens)
                # Count the sample before yielding so that a checkpoint taken
                # between yields does not replay it on resume.
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    # `input_ids` rather than `input`, which would shadow the builtin.
                    input_ids = x[:-1]
                    label = x[1:]
                    # Add position IDs (0 to seq_len-1)
                    position_ids = torch.arange(len(input_ids), dtype=torch.long)
                    yield input_ids, label, position_ids

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        """Restore the sample offset and leftover token buffer from ``state_dict``."""
        self._sample_idx = state_dict["sample_idx"]
        # Copy so that later in-place `extend` calls do not mutate the caller's state.
        self._all_tokens = list(state_dict["token_buffer"])

    def state_dict(self):
        """Return a snapshot of the leftover token buffer and the sample offset."""
        # Copy the buffer: `__iter__` extends it in place, which would otherwise
        # silently alter a snapshot that the caller is still holding.
        return {"token_buffer": list(self._all_tokens), "sample_idx": self._sample_idx}
146+
147+
148+
def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Create the dataloader that feeds a HuggingFace dataset to training.

    Args:
        dp_world_size: World size forwarded to the dataset and the dataloader.
        dp_rank: Rank forwarded to the dataset and the dataloader.
        tokenizer: Tokenizer handed to ``HuggingFaceDataset``.
        job_config: Job configuration; ``training.dataset``,
            ``training.dataset_path``, ``training.batch_size`` and
            ``training.seq_len`` are read from it.
        infinite: Whether the dataset restarts once exhausted. Defaults to
            ``True`` here, unlike ``HuggingFaceDataset`` itself.

    Returns:
        A ``ParallelAwareDataloader`` wrapping the constructed dataset.
    """
    # Read every setting up front, before the dataset is constructed.
    name = job_config.training.dataset
    path = job_config.training.dataset_path
    batch = job_config.training.batch_size
    sequence_length = job_config.training.seq_len

    dataset = HuggingFaceDataset(
        dataset_name=name,
        dataset_path=path,
        tokenizer=tokenizer,
        seq_len=sequence_length,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    loader = ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch,
    )
    return loader
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
transformers >=4.49.0
2+
sentencepiece >=0.2.0

0 commit comments

Comments
 (0)