Skip to content

Commit e9e4957

Browse files
authored
Add extra line before repro log; update repro log tests (#1102)
1 parent 8452d52 commit e9e4957

File tree

3 files changed

+132
-42
lines changed

3 files changed

+132
-42
lines changed

helion/runtime/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]:
737737
output_lines.extend(["", "helion_repro_caller()"])
738738

739739
output_lines.append("# === END HELION KERNEL REPRO ===")
740-
repro_text = "\n".join(output_lines)
740+
repro_text = "\n" + "\n".join(output_lines)
741741
log_func(repro_text)
742742

743743

test/test_debug_utils.expected

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,75 @@ def helion_repro_caller():
2323

2424
helion_repro_caller()
2525
# === END HELION KERNEL REPRO ===
26+
27+
--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_autotune_error)
28+
# === HELION KERNEL REPRO ===
29+
import helion
30+
import helion.language as hl
31+
import torch
32+
from torch._dynamo.testing import rand_strided
33+
34+
@helion.kernel(config=helion.Config(block_sizes=[64], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
35+
def kernel(x: torch.Tensor) -> torch.Tensor:
36+
out = torch.empty_like(x)
37+
n = x.shape[0]
38+
for tile_n in hl.tile([n]):
39+
out[tile_n] = x[tile_n] + 1
40+
return out
41+
42+
def helion_repro_caller():
43+
torch.manual_seed(0)
44+
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
45+
return kernel(x)
46+
47+
helion_repro_caller()
48+
# === END HELION KERNEL REPRO ===
49+
50+
--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_device_ir_lowering_error)
51+
# === HELION KERNEL REPRO ===
52+
import helion
53+
import helion.language as hl
54+
import torch
55+
from torch._dynamo.testing import rand_strided
56+
57+
@helion.kernel(config=helion.Config(block_sizes=[32], indexing=[], load_eviction_policies=[], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
58+
def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
59+
out = torch.empty_like(x)
60+
n = x.shape[0]
61+
for tile_n in hl.tile([n]):
62+
# Using torch.nonzero inside device loop causes compilation error
63+
# because it produces data-dependent output shape
64+
torch.nonzero(x[tile_n])
65+
out[tile_n] = x[tile_n]
66+
return out
67+
68+
def helion_repro_caller():
69+
torch.manual_seed(0)
70+
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
71+
return kernel_with_compile_error(x)
72+
73+
helion_repro_caller()
74+
# === END HELION KERNEL REPRO ===
75+
76+
--- assertExpectedJournal(TestDebugUtils.test_print_repro_on_triton_codegen_error)
77+
# === HELION KERNEL REPRO ===
78+
import helion
79+
import helion.language as hl
80+
import torch
81+
from torch._dynamo.testing import rand_strided
82+
83+
@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
84+
def kernel_with_triton_error(x: torch.Tensor) -> torch.Tensor:
85+
out = torch.empty_like(x)
86+
n = x.shape[0]
87+
for tile_n in hl.tile([n]):
88+
out[tile_n] = x[tile_n] + 1
89+
return out
90+
91+
def helion_repro_caller():
92+
torch.manual_seed(0)
93+
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
94+
return kernel_with_triton_error(x)
95+
96+
helion_repro_caller()
97+
# === END HELION KERNEL REPRO ===

test/test_debug_utils.py

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,41 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
6868

6969
return kernel
7070

71+
def _extract_repro_script(self, text: str) -> str:
72+
"""Extract the repro code block between markers (including markers).
73+
74+
Args:
75+
text: The text containing the repro block. Can be a full string or log_capture object.
76+
77+
Returns:
78+
The extracted repro block including both markers.
79+
"""
80+
# If it's a log capture object, extract the repro script from logs first
81+
if hasattr(text, "records"):
82+
log_capture = text
83+
repro_script = None
84+
for record in log_capture.records:
85+
if "# === HELION KERNEL REPRO ===" in record.message:
86+
repro_script = record.message
87+
break
88+
if repro_script is None:
89+
self.fail("No repro script found in logs")
90+
text = repro_script
91+
92+
# Extract code block between markers
93+
start_marker = "# === HELION KERNEL REPRO ==="
94+
end_marker = "# === END HELION KERNEL REPRO ==="
95+
start_idx = text.find(start_marker)
96+
end_idx = text.find(end_marker)
97+
98+
if start_idx == -1:
99+
self.fail("Start marker not found")
100+
if end_idx == -1:
101+
self.fail("End marker not found")
102+
103+
# Extract content including both markers
104+
return text[start_idx : end_idx + len(end_marker)].strip()
105+
71106
def test_print_repro_env_var(self):
72107
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
73108
with self._with_print_repro_enabled():
@@ -83,15 +118,8 @@ def test_print_repro_env_var(self):
83118
result = kernel(x)
84119
torch.testing.assert_close(result, x + 1)
85120

86-
# Extract repro script from logs (use records to get the raw message without formatting)
87-
repro_script = None
88-
for record in log_capture.records:
89-
if "# === HELION KERNEL REPRO ===" in record.message:
90-
repro_script = record.message
91-
break
92-
93-
if repro_script is None:
94-
self.fail("No repro script found in logs")
121+
# Extract repro script from logs
122+
repro_script = self._extract_repro_script(log_capture)
95123

96124
# Normalize range_warp_specializes=[None] to [] for comparison
97125
normalized_script = repro_script.replace(
@@ -163,10 +191,14 @@ def mock_do_bench(*args, **kwargs):
163191
captured = "".join(output_capture.readouterr())
164192

165193
# Verify that a repro script was printed for the failing config
166-
self.assertIn("# === HELION KERNEL REPRO ===", captured)
167-
self.assertIn("# === END HELION KERNEL REPRO ===", captured)
168-
self.assertIn("kernel", captured)
169-
self.assertIn("helion_repro_caller()", captured)
194+
repro_script = self._extract_repro_script(captured)
195+
196+
# Normalize range_warp_specializes=[None] to [] for comparison
197+
normalized_script = repro_script.replace(
198+
"range_warp_specializes=[None]", "range_warp_specializes=[]"
199+
)
200+
201+
self.assertExpectedJournal(normalized_script)
170202

171203
def test_print_repro_on_device_ir_lowering_error(self):
172204
"""Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering."""
@@ -192,21 +224,14 @@ def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
192224
kernel_with_compile_error(x)
193225

194226
# Extract repro script from logs
195-
repro_script = None
196-
for record in log_capture.records:
197-
if "# === HELION KERNEL REPRO ===" in record.message:
198-
repro_script = record.message
199-
break
200-
201-
# Verify that a repro script was printed when compilation failed
202-
self.assertIsNotNone(
203-
repro_script,
204-
"Expected repro script to be printed when device IR lowering fails",
227+
repro_script = self._extract_repro_script(log_capture)
228+
229+
# Normalize range_warp_specializes=[None] to [] for comparison
230+
normalized_script = repro_script.replace(
231+
"range_warp_specializes=[None]", "range_warp_specializes=[]"
205232
)
206-
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
207-
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
208-
self.assertIn("kernel_with_compile_error", repro_script)
209-
self.assertIn("helion_repro_caller()", repro_script)
233+
234+
self.assertExpectedJournal(normalized_script)
210235

211236
def test_print_repro_on_triton_codegen_error(self):
212237
"""Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails."""
@@ -242,21 +267,14 @@ def mock_load(code, *args, **kwargs):
242267
kernel_with_triton_error(x)
243268

244269
# Extract repro script from logs
245-
repro_script = None
246-
for record in log_capture.records:
247-
if "# === HELION KERNEL REPRO ===" in record.message:
248-
repro_script = record.message
249-
break
250-
251-
# Verify that a repro script was printed when Triton codegen failed
252-
self.assertIsNotNone(
253-
repro_script,
254-
"Expected repro script to be printed when Triton codegen fails",
270+
repro_script = self._extract_repro_script(log_capture)
271+
272+
# Normalize range_warp_specializes=[None] to [] for comparison
273+
normalized_script = repro_script.replace(
274+
"range_warp_specializes=[None]", "range_warp_specializes=[]"
255275
)
256-
self.assertIn("# === HELION KERNEL REPRO ===", repro_script)
257-
self.assertIn("# === END HELION KERNEL REPRO ===", repro_script)
258-
self.assertIn("kernel_with_triton_error", repro_script)
259-
self.assertIn("helion_repro_caller()", repro_script)
276+
277+
self.assertExpectedJournal(normalized_script)
260278

261279

262280
if __name__ == "__main__":

0 commit comments

Comments
 (0)