@@ -29,7 +29,9 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No
2929
3030 output_path = Path (dump_folder ) / "compiler" / f"{ name } .txt"
3131 output_path .parent .mkdir (parents = True , exist_ok = True )
32- output_path .write_text (gm .print_readable (print_output = False ))
32+ output_path .write_text (
33+ gm .print_readable (print_output = False , include_stride = True , include_device = True )
34+ )
3335
3436
3537def export_joint (
@@ -47,7 +49,11 @@ def export_joint(
4749 ):
4850 gm = dynamo_graph_capture_for_export (model )(* args , ** kwargs )
4951 logger .debug ("Dynamo gm:" )
50- logger .debug (gm .print_readable (print_output = False ))
52+ logger .debug (
53+ gm .print_readable (
54+ print_output = False , include_stride = True , include_device = True
55+ )
56+ )
5157 _dump_gm (dump_folder , gm , "dynamo_gm" )
5258
5359 tracing_context = gm .meta ["tracing_context" ]
@@ -224,15 +230,19 @@ def compiler(
224230 passes = DEFAULT_COMPILER_PASSES
225231
226232 logger .debug (f"{ name } before compiler:" )
227- logger .debug (gm .print_readable (print_output = False ))
233+ logger .debug (
234+ gm .print_readable (print_output = False , include_stride = True , include_device = True )
235+ )
228236 _dump_gm (dump_folder , gm , f"{ name } _before_compiler" )
229237
230238 for pass_fn in passes :
231239 logger .info (f"Applying pass: { pass_fn .__name__ } " )
232240 gm = pass_fn (gm , example_inputs )
233241
234242 logger .debug (f"{ name } after compiler:" )
235- logger .debug (gm .print_readable (print_output = False ))
243+ logger .debug (
244+ gm .print_readable (print_output = False , include_stride = True , include_device = True )
245+ )
236246 _dump_gm (dump_folder , gm , f"{ name } _after_compiler" )
237247 return gm
238248
0 commit comments