|
31 | 31 | # variable setting is shown for each example. |
32 | 32 |
|
33 | 33 | import torch |
| 34 | +import sys |
34 | 35 |
|
35 | | -# exit cleanly if we are on a device that doesn't support torch.compile |
36 | | -if torch.cuda.get_device_capability() < (7, 0): |
37 | | - print("Skipping because torch.compile is not supported on this device.") |
38 | | -else: |
39 | | - @torch.compile() |
40 | | - def fn(x, y): |
41 | | - z = x + y |
42 | | - return z + 2 |
43 | | - |
44 | | - |
45 | | - inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) |
46 | 36 |
|
def env_setup():
    """Prepare the environment for the example.

    Exits the process cleanly (status 0) when CUDA is unavailable or the
    GPU is too old for torch.compile (compute capability below 7.0);
    otherwise returns normally.
    """
    skip_message = None
    if not torch.cuda.is_available():
        skip_message = "CUDA is not available. Exiting."
    elif torch.cuda.get_device_capability() < (7, 0):
        skip_message = "Skipping because torch.compile is not supported on this device."

    # Single exit point: print the reason, then leave with a success code
    # so CI treats the skip as a pass.
    if skip_message is not None:
        print(skip_message)
        sys.exit(0)
47 | 46 |
|
48 | | -# print separator and reset dynamo |
49 | | -# between each example |
50 | | - def separator(name): |
51 | | - print(f"==================={name}=========================") |
52 | | - torch._dynamo.reset() |
53 | 47 |
|
| 48 | +def separator(name): |
| 49 | + """Print separator and reset dynamo between each example""" |
| 50 | + print(f"\n{'='*20} {name} {'='*20}") |
| 51 | + torch._dynamo.reset() |
54 | 52 |
|
55 | | - separator("Dynamo Tracing") |
56 | | -# View dynamo tracing |
57 | | -# TORCH_LOGS="+dynamo" |
58 | | - torch._logging.set_logs(dynamo=logging.DEBUG) |
59 | | - fn(*inputs) |
60 | 53 |
|
61 | | - separator("Traced Graph") |
62 | | -# View traced graph |
63 | | -# TORCH_LOGS="graph" |
64 | | - torch._logging.set_logs(graph=True) |
65 | | - fn(*inputs) |
def run_debugging_suite():
    """Compile a toy function and run it under each TORCH_LOGS setting.

    Calls env_setup() first, so on machines without a torch.compile-capable
    GPU the process exits cleanly before any CUDA work is attempted.
    """
    env_setup()

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (
        torch.ones(2, 2, device="cuda"),
        torch.zeros(2, 2, device="cuda"),
    )

    # (title, kwargs for torch._logging.set_logs) pairs; the trailing comment
    # on each entry gives the equivalent TORCH_LOGS environment variable.
    # NOTE(review): `logging.DEBUG` assumes `import logging` appears above
    # this chunk (numbering starts at line 31) -- confirm it is present.
    scenarios = (
        ("Dynamo Tracing", {"dynamo": logging.DEBUG}),  # TORCH_LOGS="+dynamo"
        ("Traced Graph", {"graph": True}),              # TORCH_LOGS="graph"
        ("Fusion Decisions", {"fusion": True}),         # TORCH_LOGS="fusion"
        ("Output Code", {"output_code": True}),         # TORCH_LOGS="output_code"
    )

    for title, log_kwargs in scenarios:
        separator(title)
        torch._logging.set_logs(**log_kwargs)
        try:
            out = fn(*inputs)
            print(f"Function output shape: {out.shape}")
        except Exception as err:
            # Best-effort demo: report the failure and continue with the
            # next logging scenario rather than aborting the whole suite.
            print(f"Error during {title}: {str(err)}")

if __name__ == "__main__":
    run_debugging_suite()
80 | 93 |
|
81 | 94 | ###################################################################### |
82 | 95 | # Conclusion |
|
0 commit comments