Skip to content

Commit 028a455

Browse files
authored
Print device and stride when print module (#2045)
Before: <img width="978" height="93" alt="image" src="https://github.com/user-attachments/assets/48dc39d9-e897-4396-ac62-025574303403" /> After: <img width="1318" height="82" alt="image" src="https://github.com/user-attachments/assets/47b4771a-aaf9-4f61-80bc-757f3a08c1d2" />
1 parent 23c993c commit 028a455

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No
2929

3030
output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
3131
output_path.parent.mkdir(parents=True, exist_ok=True)
32-
output_path.write_text(gm.print_readable(print_output=False))
32+
output_path.write_text(
33+
gm.print_readable(print_output=False, include_stride=True, include_device=True)
34+
)
3335

3436

3537
def export_joint(
@@ -47,7 +49,11 @@ def export_joint(
4749
):
4850
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
4951
logger.debug("Dynamo gm:")
50-
logger.debug(gm.print_readable(print_output=False))
52+
logger.debug(
53+
gm.print_readable(
54+
print_output=False, include_stride=True, include_device=True
55+
)
56+
)
5157
_dump_gm(dump_folder, gm, "dynamo_gm")
5258

5359
tracing_context = gm.meta["tracing_context"]
@@ -224,15 +230,19 @@ def compiler(
224230
passes = DEFAULT_COMPILER_PASSES
225231

226232
logger.debug(f"{name} before compiler:")
227-
logger.debug(gm.print_readable(print_output=False))
233+
logger.debug(
234+
gm.print_readable(print_output=False, include_stride=True, include_device=True)
235+
)
228236
_dump_gm(dump_folder, gm, f"{name}_before_compiler")
229237

230238
for pass_fn in passes:
231239
logger.info(f"Applying pass: {pass_fn.__name__}")
232240
gm = pass_fn(gm, example_inputs)
233241

234242
logger.debug(f"{name} after compiler:")
235-
logger.debug(gm.print_readable(print_output=False))
243+
logger.debug(
244+
gm.print_readable(print_output=False, include_stride=True, include_device=True)
245+
)
236246
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
237247
return gm
238248

0 commit comments

Comments
 (0)