[Experimental Feature] Huggingface model training #919
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,24 @@
# Training LLAMA with HF weights

This directory contains scripts and configs for training LLAMA with HF weights using TorchTitan.

## Usage

### Install extra dependencies

```bash
pip install -r extra_requirements.txt
```

### Test loading HF weights

NOTE: you need an internet connection to download the weights.
```bash
pytest test_loading_hf_weights.py
```

### Run training

```bash
LOG_RANK=7 bash run_train.sh
```
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import torchtitan.experiments.train_llama_hf.model  # noqa: F401
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from transformers import PreTrainedTokenizerBase

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import HuggingFaceDataset
from torchtitan.tools.logging import logger


class HuggingFaceDatasetWithPos(HuggingFaceDataset):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: PreTrainedTokenizerBase,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        super().__init__(
            dataset_name,
            dataset_path,
            tokenizer,
            seq_len,
            dp_rank,
            dp_world_size,
            infinite,
        )

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    # Add position IDs (0 to seq_len-1)
                    position_ids = torch.arange(len(input), dtype=torch.long)
                    yield input, label, position_ids

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")


def build_pos_included_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDatasetWithPos(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
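For orientation, here is a minimal, hypothetical sketch of how this dataset could be exercised on its own. The tokenizer and the `c4_test` dataset name are placeholders assumed to be supported by the parent `HuggingFaceDataset`; they are not taken from this PR:

```python
# Illustrative sketch; dataset/tokenizer choices are placeholders, not from this PR.
from itertools import islice

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")  # placeholder tokenizer

ds = HuggingFaceDatasetWithPos(
    dataset_name="c4_test",  # assumes the small test dataset bundled with torchtitan
    dataset_path=None,       # run from the repo root so the bundled asset resolves
    tokenizer=tokenizer,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
    infinite=False,
)

# Each element now carries explicit position ids alongside inputs and labels.
for input_ids, labels, position_ids in islice(iter(ds), 2):
    assert input_ids.shape == labels.shape == position_ids.shape == (2048,)
```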
@@ -0,0 +1,2 @@
transformers >=4.49.0
sentencepiece >=0.2.0
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
from collections import defaultdict
from pathlib import Path

import torch.nn as nn
from huggingface_hub import repo_exists, snapshot_download
from safetensors import safe_open
from torch.distributed.tensor import distribute_tensor, DTensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from torchtitan.tools.logging import logger

INDEX_NAME_MAPPING = {
    "safetensors": SAFE_WEIGHTS_INDEX_NAME,
}

PATTERNS_TO_REMOVE = [
    "._orig_mod",  # torch.compile wrapper (OptimizedModule)
    "._fsdp_wrapped_module",  # FSDP wrapper
    "._checkpoint_wrapped_module",  # activation checkpoint wrapper
    ".module",  # DataParallel/DistributedDataParallel
    "_module.",  # Some wrappers add a prefix
]


def normalize_state_dict_key(
    key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE
) -> str:
    """
    Normalize a state dict key by removing the prefixes or suffixes added by various wrappers.
    Args:
        key: The original state dict key.
        patterns_to_remove: The wrapper patterns to strip from the key.
    Returns:
        The normalized key.
    """
    normalized_key = key
    for pattern in patterns_to_remove:
        normalized_key = normalized_key.replace(pattern, "")

    return normalized_key


def get_weight_map(pretrained_model_path: Path) -> dict[str, str]:
    """
    Get the weight map from the pretrained model.
    Args:
        pretrained_model_path: The path to the pretrained model.
    Returns:
        weight_map: A dictionary mapping each state dict key to the safetensors file that stores it,
            or None if the checkpoint has no index file (i.e. it is not sharded).
    """
    index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"]
    if not index_file.exists():
        return None
    with open(index_file, "r") as f:
        metadata = json.load(f)
    return metadata["weight_map"]


def group_state_dict_keys_and_st_partition_paths(
    pretrained_model_path: Path,
    state_dict_keys,
    weight_map,
    state_dict_map: dict[str, str] = None,
):
    """
    Group state dict keys by the safetensors partition file that stores them.
    Args:
        pretrained_model_path: The path to the pretrained model.
        state_dict_keys: The state dict keys to group.
        weight_map: The weight map.
        state_dict_map: A dictionary mapping from the state dict key to the checkpoint key.
    Returns:
        st_partition_map: A dictionary mapping from the partition file path to the list of state dict keys.
    """
    st_partition_map = defaultdict(list)
    for state_dict_key in state_dict_keys:
        ckpt_state_dict_key = (
            state_dict_map[state_dict_key]
            if state_dict_map is not None
            else state_dict_key
        )
        if weight_map is None:
            partition_path = pretrained_model_path / "model.safetensors"
        else:
            partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]

        st_partition_map[partition_path].append(state_dict_key)
    return st_partition_map

[Review comment] A Llama-3.2-3B can be used as a model for verification.


def load_sharded_state_dict_for_model_from_path(
    pretrained_model_path: Path,
    model: nn.Module,
    mapping_dict: dict[str, str] = None,
    **kwargs,
):
    """
    Load the state dict sharded (depends on DTensor) from the pretrained model path. It only loads the weights needed by the current rank.
    Args:
        pretrained_model_path: The path to the pretrained model; it must be a local directory containing the checkpoint.
        model: The model to load the state dict into.
        mapping_dict: An optional mapping from model state dict keys to checkpoint keys.
        **kwargs: other arguments for torch.nn.Module.load_state_dict
    """
    # check exceptions
    if not pretrained_model_path.exists():
        raise ValueError(
            f"The pretrained model path {pretrained_model_path} does not exist."
        )
    if not pretrained_model_path.is_dir():
        raise ValueError(
            f"The pretrained model path {pretrained_model_path} is not a directory."
        )
    # get the weight map
    weight_map = get_weight_map(pretrained_model_path)
    model_state_dict = model.state_dict()
    model_state_dict_keys = list(model_state_dict.keys())

    # create a mapping_dict between the original state_dict_key and the weight_map_key if not provided
    mapping_dict = (
        mapping_dict
        if mapping_dict is not None
        else {key: normalize_state_dict_key(key) for key in model_state_dict_keys}
    )
    st_partition_map = group_state_dict_keys_and_st_partition_paths(
        pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict
    )

    # get the sharded state dict
    state_dict = {}
    for safetensor_partition_path, state_dict_keys in st_partition_map.items():
        with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f:
            for state_dict_key in state_dict_keys:
                model_tensor = model_state_dict[state_dict_key]
                ckpt_state_dict_key = mapping_dict[state_dict_key]
                if isinstance(model_tensor, DTensor):
                    local_tensor = f.get_tensor(ckpt_state_dict_key)
                    state_dict[state_dict_key] = distribute_tensor(
                        local_tensor,
                        model_tensor.device_mesh,
                        model_tensor.placements,
                    )
                else:
                    state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key)
    model.load_state_dict(state_dict, **kwargs)
    del state_dict
    gc.collect()


def load_sharded_state_dict_for_model_from_hf(
    pretrained_model_id_or_path: str,
    model: nn.Module,
    **kwargs,
):
    """
    Load the state dict sharded (depends on DTensor) from the pretrained model path. It only loads the weights needed by the current rank.
    Args:
        pretrained_model_id_or_path: The id or path of the pretrained model; it can be a Hugging Face repo id
            or a local path.
        model: The model to load the state dict into.
        **kwargs: other arguments for torch.nn.Module.load_state_dict
    """
    logger.info(f"Loading the state dict from {pretrained_model_id_or_path}")
    pretrained_model_id_or_path = Path(pretrained_model_id_or_path)
    if not pretrained_model_id_or_path.exists():
        if not repo_exists(str(pretrained_model_id_or_path)):
            raise ValueError(
                f"The pretrained model {pretrained_model_id_or_path} does not exist"
            )
        logger.info(
            f"Trying to download the model from Hugging Face: {pretrained_model_id_or_path}"
        )
        pretrained_model_path = Path(
            snapshot_download(str(pretrained_model_id_or_path))
        )
    elif not pretrained_model_id_or_path.is_dir():
        raise ValueError(
            f"The pretrained model path {pretrained_model_id_or_path} is not a directory."
        )
    else:
        pretrained_model_path = pretrained_model_id_or_path

    load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs)
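The DTensor branch above is what makes this loader usable with torchtitan's parallelized models; in the PR the sharding itself is applied by `parallelize_llama` and the trainer rather than by hand. The following is only a hedged sketch of that branch, assuming a recent PyTorch where the composable `fully_shard` is importable from `torch.distributed.fsdp`, a `torchrun`-style launch, and an illustrative model id:

```python
# Hedged sketch of the DTensor branch (not the PR's actual training entry point).
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2 / composable API
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-3B"  # illustrative checkpoint
world_size = int(os.environ.get("WORLD_SIZE", "1"))
mesh = init_device_mesh("cuda", (world_size,))

config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)
fully_shard(model, mesh=mesh)   # parameters become DTensors sharded over the mesh
model.to_empty(device="cuda")   # allocate each rank's local shards (uninitialized)

# Each checkpoint tensor is read on CPU and distributed onto the parameter's
# device_mesh/placements; non-persistent buffers (e.g. rotary inv_freq) are not
# in the checkpoint and would still need re-initialization, as in the full flow.
load_sharded_state_dict_for_model_from_hf(model_id, model)
```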
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def cross_entropy_loss_hf(preds, labels):
    # preds is the HF CausalLM output; preds[0] is the logits tensor of shape
    # (batch, seq_len, vocab_size), flattened here to (batch * seq_len, vocab_size).
    loss = torch.nn.functional.cross_entropy(
        preds[0].flatten(0, 1).float(), labels.flatten(0, 1)
    )
    return loss
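Since the loss takes the raw HF model output rather than a logits tensor, a toy shape check makes the indexing explicit (illustrative only, not part of the PR):

```python
# Toy shape check for the loss above (illustrative only).
import torch

batch, seq_len, vocab = 2, 8, 128
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# The loss indexes the model output with [0] to get the logits, so wrap the
# logits in a tuple here to mimic a HF CausalLM output.
loss = cross_entropy_loss_hf((logits,), labels)
print(loss.item())  # scalar cross-entropy averaged over batch * seq_len tokens
```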
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.experiments.train_llama_hf.dataset import (
    build_pos_included_hf_dataloader,
)
from torchtitan.experiments.train_llama_hf.loss import cross_entropy_loss_hf
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

register_train_spec(
    TrainSpec(
        name="llama3_hf",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_pos_included_hf_dataloader,
        tokenizer_cls=AutoTokenizer,
        loss_fn=cross_entropy_loss_hf,
    )
)
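Nothing in this file runs beyond registration; the trainer later resolves the spec by its name (presumably via the model name in the job config, which is not shown in this diff). A hypothetical lookup sketch, assuming a `get_train_spec` accessor exists alongside `register_train_spec`:

```python
# Hypothetical sketch -- assumes train_spec exposes get_train_spec.
from torchtitan.protocols.train_spec import get_train_spec

spec = get_train_spec("llama3_hf")
assert spec.build_dataloader_fn is build_pos_included_hf_dataloader
assert spec.loss_fn is cross_entropy_loss_hf
```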