Skip to content

Commit cbfb8e1

Browse files
authored
[compiler toolkit] Add tests and scripts for numerics check (#2015)
This PR adds the utils to automatically check the training numerics (losses, grad norms) of two runs to verify if they have bitwise equivalence. The added script triggers two runs with user-defined configs. Then it loads metrics saved during training and compares the numerics to verify bitwise equivalence. Currently we check for losses and grad norms during training steps. For example, we want to compare the numerics between compiler toolkit with aot_eager backend and eager on llama3-8B. ``` python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2 ``` It'll run the `simple_fsdp` experiment without `torch.compile` as the eager baseline, and the `compiler_toolkit` experiment as the compiled run. Then it compares the training numerics of these two runs to verify bitwise equivalence. When it is bitwise equivalent, we'll see the following output ``` Starting training: simple_fsdp.llama3 ✓ Training completed: simple_fsdp.llama3 Starting training: compiler_toolkit.llama3 ✓ Training completed: compiler_toolkit.llama3 ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ SUCCESS: All metrics are bitwise equivalent ``` Also added unit tests in `compiler_toolkit/tests/test_numerics.py` so that we can guard working parallelism combinations that already have bitwise equivalence in CI.
1 parent 55c63c1 commit cbfb8e1

File tree

3 files changed

+467
-0
lines changed

3 files changed

+467
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import sys
9+
from pathlib import Path
10+
11+
# Add parent directory to path to import numerics_utils
12+
sys.path.insert(0, str(Path(__file__).parent.parent))
13+
14+
from tests.numerics_utils import run_numerics_test
15+
16+
17+
def main():
    """Parse command-line options and run the eager-vs-compiled numerics check.

    Returns:
        int: 0 when ``run_numerics_test`` reports success, 1 otherwise, so the
        value can be used directly as a process exit code.
    """
    # (flag, add_argument keyword options), registered in this order.
    argument_table = (
        ("--ngpu", {"type": int, "required": True, "help": "Number of GPUs to use"}),
        (
            "--config-file",
            {"type": str, "required": True, "help": "Path to config file"},
        ),
        (
            "--dp-shard-degree",
            {"type": int, "default": 1, "help": "Data parallel shard degree"},
        ),
        (
            "--tp-degree",
            {"type": int, "default": 1, "help": "Tensor parallel degree"},
        ),
        (
            "--cp-degree",
            {"type": int, "default": 1, "help": "Context parallel degree"},
        ),
        (
            "--ep-degree",
            {"type": int, "default": 1, "help": "Expert parallel degree"},
        ),
        (
            "--ac-mode",
            {
                "type": str,
                "default": "selective",
                "choices": ["selective", "none", "full"],
                "help": "Activation checkpoint mode",
            },
        ),
        (
            "--steps",
            {"type": int, "default": 50, "help": "Number of training steps"},
        ),
        (
            "--seed",
            {
                "type": int,
                "default": 42,
                "help": "Random seed for deterministic training",
            },
        ),
        (
            "--eager-tb-folder",
            {
                "type": str,
                "default": "tb/eager_run",
                "help": "Tensorboard folder for eager run",
            },
        ),
        (
            "--compiled-tb-folder",
            {
                "type": str,
                "default": "tb/compiled_run",
                "help": "Tensorboard folder for compiled run",
            },
        ),
        (
            "--metrics",
            {
                "nargs": "+",
                "default": ["loss_metrics/global_avg_loss", "grad_norm"],
                "help": "Metrics to compare",
            },
        ),
        (
            "--passes",
            {
                "type": str,
                "default": None,
                "help": (
                    "Comma-separated list of compiler passes to apply "
                    "(e.g., 'autobucketing_reordering' or 'autobucketing_reordering,regional_inductor')"
                ),
            },
        ),
    )

    parser = argparse.ArgumentParser(
        description="Run two training jobs and compare their tensorboard metrics"
    )
    for flag, options in argument_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Every parsed destination (ngpu, config_file, ..., passes) is forwarded
    # under the same name as a keyword argument of run_numerics_test.
    if run_numerics_test(**vars(args)):
        return 0
    return 1
123+
124+
125+
if __name__ == "__main__":
    # Use sys.exit rather than the builtin exit(): the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist
    # (e.g. when Python is started with -S).
    sys.exit(main())
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Shared utilities for numerics testing."""
8+
9+
import glob
10+
import os
11+
import subprocess
12+
13+
import torch
14+
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
15+
16+
17+
def load_metrics(event_path, metric_names):
    """Read scalar metrics from tensorboard event data.

    Args:
        event_path: Path handed to ``EventAccumulator``.
        metric_names: Iterable of scalar tags to look up.

    Returns:
        dict: One entry per requested tag mapping ``step -> value``. A tag
        whose lookup raises ``KeyError`` gets an empty dict and a warning
        is printed.
    """
    accumulator = EventAccumulator(event_path)
    accumulator.Reload()

    loaded = {}
    for tag in metric_names:
        try:
            events = accumulator.Scalars(tag)
        except KeyError:
            print(f"Warning: Metric {tag!r} not found in event file")
            loaded[tag] = {}
        else:
            loaded[tag] = {event.step: event.value for event in events}

    return loaded
32+
33+
34+
def compare_metrics(metrics1, metrics2, label1="Eager", label2="Compiled"):
    """Compare two sets of metrics and verify bitwise equivalence using torch.equal().

    Args:
        metrics1: Mapping of metric name -> ``{step: value}`` for the first run.
        metrics2: Mapping with the same layout for the second run.
        label1: Display name of the first run in diagnostic output.
        label2: Display name of the second run in diagnostic output.

    Returns:
        bool: True only if every metric (from the union of both mappings) has
        the same set of steps in both runs and equal values at every step;
        False as soon as any metric has a step mismatch or a value mismatch.
    """
    all_metrics = set(metrics1) | set(metrics2)
    all_match = True

    for metric_name in sorted(all_metrics):
        # A metric present in only one mapping is treated as having no steps
        # on the other side, so it is reported as a step mismatch instead of
        # raising KeyError.
        series1 = metrics1.get(metric_name, {})
        series2 = metrics2.get(metric_name, {})

        steps1 = set(series1)
        steps2 = set(series2)

        if steps1 != steps2:
            print("  ERROR: Step mismatch!")
            print(f"    {label1} steps: {sorted(steps1)}")
            print(f"    {label2} steps: {sorted(steps2)}")
            all_match = False
            continue

        # Line the values up by step so both tensors share the same ordering.
        ordered_steps = sorted(steps1)
        values1 = [series1[step] for step in ordered_steps]
        values2 = [series2[step] for step in ordered_steps]

        if torch.equal(torch.tensor(values1), torch.tensor(values2)):
            print(
                f"  ✓ PASS: All {len(steps1)} steps match exactly (bitwise equivalent)"
            )
            continue

        # Values differ: this must fail the overall comparison.
        all_match = False

        mismatches = [
            (step, val1, val2, abs(val1 - val2))
            for step, val1, val2 in zip(ordered_steps, values1, values2)
            if val1 != val2
        ]

        print(
            f"  ERROR: Found {len(mismatches)} mismatches out of {len(steps1)} steps"
        )
        for step, val1, val2, diff in mismatches:
            print(
                f"    {metric_name} step {step}: "
                f"{label1}={val1!r}, {label2}={val2!r}, abs diff={diff!r}"
            )

    return all_match
77+
78+
79+
def find_latest_event_dir(base_path):
    """Return the most recently modified sub-directory of ``base_path``.

    Args:
        base_path: Directory to search (one level deep, via a ``*`` glob).

    Returns:
        The matched sub-directory with the greatest modification time, or
        ``base_path`` itself when it contains no sub-directories.

    Raises:
        ValueError: If ``base_path`` does not exist.
    """
    if not os.path.exists(base_path):
        raise ValueError(f"Path does not exist: {base_path}")

    pattern = os.path.join(base_path, "*")
    candidates = list(filter(os.path.isdir, glob.glob(pattern)))

    if candidates:
        return max(candidates, key=os.path.getmtime)
    return base_path
90+
91+
92+
def run_training(
    ngpu,
    config_file,
    model_name,
    dp_shard_degree,
    tp_degree,
    cp_degree,
    ep_degree,
    ac_mode,
    steps,
    seed,
    deterministic,
    tb_folder,
    passes=None,
):
    """Run a training job with the specified configuration.

    Invokes ``./run_train.sh`` (relative to the current working directory) as
    a subprocess, with ``NGPU`` and ``CONFIG_FILE`` set in its environment and
    the remaining settings passed as command-line overrides.

    Args:
        ngpu: Value exported as ``NGPU``.
        config_file: Value exported as ``CONFIG_FILE``.
        model_name: Passed as ``--model.name``.
        dp_shard_degree: Passed as ``--parallelism.data_parallel_shard_degree``.
        tp_degree: Passed as ``--parallelism.tensor_parallel_degree``.
        cp_degree: Passed as ``--parallelism.context_parallel_degree`` only
            when greater than 1.
        ep_degree: Passed as ``--parallelism.expert_parallel_degree`` only
            when greater than 1.
        ac_mode: Passed as ``--activation_checkpoint.mode``.
        steps: Passed as ``--training.steps``.
        seed: Passed as ``--debug.seed``.
        deterministic: When truthy, ``--debug.deterministic`` is added.
        tb_folder: Passed as ``--metrics.save_tb_folder``.
        passes: Optional value for ``--compile.passes``; when set, the
            compiler toolkit job config module is also selected.

    Returns:
        bool: True if the subprocess exits with status 0; False if it exits
        with a non-zero status or could not be launched at all.
    """
    print(f"\nStarting training: {model_name}")

    env = os.environ.copy()
    env["NGPU"] = str(ngpu)
    env["CONFIG_FILE"] = config_file

    cmd = [
        "./run_train.sh",
        "--model.name",
        model_name,
        "--parallelism.data_parallel_shard_degree",
        str(dp_shard_degree),
        "--parallelism.tensor_parallel_degree",
        str(tp_degree),
    ]

    if cp_degree > 1:
        cmd.extend(["--parallelism.context_parallel_degree", str(cp_degree)])
    if ep_degree > 1:
        cmd.extend(["--parallelism.expert_parallel_degree", str(ep_degree)])

    cmd.extend(
        [
            "--activation_checkpoint.mode",
            ac_mode,
            "--training.steps",
            str(steps),
            "--debug.seed",
            str(seed),
        ]
    )

    # Honor the caller's choice instead of always forcing the flag.
    if deterministic:
        cmd.append("--debug.deterministic")

    cmd.extend(
        [
            "--metrics.enable_tensorboard",
            "--metrics.save_tb_folder",
            tb_folder,
        ]
    )

    if passes:
        cmd.extend(
            [
                "--job.custom_config_module",
                "torchtitan.experiments.compiler_toolkit.job_config",
                "--compile.passes",
                passes,
            ]
        )

    print(f"Environment: NGPU={env['NGPU']}, CONFIG_FILE={env['CONFIG_FILE']}")
    print(f"Running command: {' '.join(cmd)}")

    try:
        # stderr is merged into stdout so a failure report shows the
        # interleaved output of the training script.
        subprocess.run(
            cmd,
            env=env,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"✗ Training failed: {model_name}")
        print(f"Error output:\n{e.stdout}")
        return False
    except OSError as e:
        # The launcher script is missing or not executable: report this as a
        # failed run rather than crashing the caller.
        print(f"✗ Training failed: {model_name}")
        print(f"Could not launch {cmd[0]}: {e}")
        return False

    print(f"✓ Training completed: {model_name}")
    return True
172+
173+
174+
def determine_model_names(config_file):
    """Derive the eager and compiled model names from a config file path.

    Args:
        config_file: Config path; it is searched for a known model keyword.

    Returns:
        tuple: ``(eager_model, compiled_model)``, i.e. the detected model name
        prefixed with ``simple_fsdp.`` and ``compiler_toolkit.`` respectively.

    Raises:
        ValueError: If no known model keyword occurs in ``config_file``.
    """
    # Checked in order; the first keyword found in the path wins.
    keyword_to_model = (
        ("deepseek", "deepseek_v3"),
        ("llama3", "llama3"),
    )

    for keyword, base_name in keyword_to_model:
        if keyword in config_file:
            return f"simple_fsdp.{base_name}", f"compiler_toolkit.{base_name}"

    raise ValueError(
        f"Unable to determine model names from config file: {config_file}"
    )
189+
190+
191+
def run_numerics_test(
    ngpu,
    config_file,
    dp_shard_degree,
    tp_degree,
    cp_degree,
    ep_degree,
    ac_mode,
    steps,
    seed,
    eager_tb_folder,
    compiled_tb_folder,
    metrics,
    passes=None,
):
    """
    Train the eager and the compiled model, then compare their logged metrics.

    The eager run is launched first; the compiled run (which additionally
    receives ``passes``) is only launched if the eager run succeeds. Metrics
    are then read from ``./outputs/<tb_folder>`` for each run and compared.

    Returns:
        bool: True if all metrics match, False if either training run fails
        or the metrics differ.
    """
    eager_model, compiled_model = determine_model_names(config_file)

    # Settings that are identical for both runs.
    shared_settings = {
        "ngpu": ngpu,
        "config_file": config_file,
        "dp_shard_degree": dp_shard_degree,
        "tp_degree": tp_degree,
        "cp_degree": cp_degree,
        "ep_degree": ep_degree,
        "ac_mode": ac_mode,
        "steps": steps,
        "seed": seed,
        "deterministic": True,
    }

    # Eager baseline.
    if not run_training(
        model_name=eager_model, tb_folder=eager_tb_folder, **shared_settings
    ):
        print("✗ Eager training failed")
        return False

    # Compiled run.
    if not run_training(
        model_name=compiled_model,
        tb_folder=compiled_tb_folder,
        passes=passes,
        **shared_settings,
    ):
        print("✗ Compiled training failed")
        return False

    # Resolve both event directories before loading anything.
    eager_path = find_latest_event_dir(f"./outputs/{eager_tb_folder}")
    compiled_path = find_latest_event_dir(f"./outputs/{compiled_tb_folder}")

    eager_metrics = load_metrics(eager_path, metrics)
    compiled_metrics = load_metrics(compiled_path, metrics)

    all_match = compare_metrics(eager_metrics, compiled_metrics)

    verdict = (
        "✓ SUCCESS: All metrics are bitwise equivalent"
        if all_match
        else "✗ FAILURE: Metrics differ between runs"
    )
    print(verdict)

    return all_match

0 commit comments

Comments
 (0)