Skip to content

Commit 22e959a

Browse files
authored
Add export_dtype parameter to convert_to_hf function (#2041)
The current `convert_to_hf.py` does not support `export_dtype`, so exported checkpoints default to `float32`. This PR adds support for the export dtypes `["float16", "bfloat16", "float32"]`.
1 parent d9bdfbb commit 22e959a

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,18 @@
1212
import torchtitan.protocols.train_spec as train_spec_module
1313
from torch.distributed.checkpoint import HuggingFaceStorageWriter
1414
from torchtitan.components.checkpoint import ModelWrapper
15+
from torchtitan.config import TORCH_DTYPE_MAP
1516

1617

1718
@torch.inference_mode()
18-
def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path):
19-
if model_name == "flux":
20-
import torchtitan.experiments.flux # noqa: F401
19+
def convert_to_hf(
20+
input_dir,
21+
output_dir,
22+
model_name,
23+
model_flavor,
24+
hf_assets_path,
25+
export_dtype,
26+
):
2127
# load model and model args so that we can get the state dict shape
2228
train_spec = train_spec_module.get_train_spec(model_name)
2329
model_args = train_spec.model_args[model_flavor]
@@ -49,6 +55,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
4955
thread_count_consolidation=5,
5056
)
5157

58+
# map and apply export dtype if needed
59+
target_dtype = TORCH_DTYPE_MAP[export_dtype]
60+
if target_dtype != torch.float32:
61+
hf_state_dict = {k: v.to(target_dtype) for k, v in hf_state_dict.items()}
62+
5263
dcp.save(
5364
hf_state_dict,
5465
storage_writer=storage_writer,
@@ -71,6 +82,14 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
7182
)
7283
parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
7384
parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
85+
parser.add_argument(
86+
"--export_dtype",
87+
type=str,
88+
nargs="?",
89+
choices=["float16", "bfloat16", "float32"],
90+
default="float32",
91+
help="Export dtype for HF checkpoint (default: float32)",
92+
)
7493
args = parser.parse_args()
7594

7695
convert_to_hf(
@@ -79,4 +98,5 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat
7998
args.model_name,
8099
args.model_flavor,
81100
args.hf_assets_path,
101+
args.export_dtype,
82102
)

0 commit comments

Comments (0)