@@ -68,6 +68,41 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
6868
6969 return kernel
7070
71+ def _extract_repro_script (self , text : str ) -> str :
72+ """Extract the repro code block between markers (including markers).
73+
74+ Args:
75+ text: The text containing the repro block. Can be a full string or log_capture object.
76+
77+ Returns:
78+ The extracted repro block including both markers.
79+ """
80+ # If it's a log capture object, extract the repro script from logs first
81+ if hasattr (text , "records" ):
82+ log_capture = text
83+ repro_script = None
84+ for record in log_capture .records :
85+ if "# === HELION KERNEL REPRO ===" in record .message :
86+ repro_script = record .message
87+ break
88+ if repro_script is None :
89+ self .fail ("No repro script found in logs" )
90+ text = repro_script
91+
92+ # Extract code block between markers
93+ start_marker = "# === HELION KERNEL REPRO ==="
94+ end_marker = "# === END HELION KERNEL REPRO ==="
95+ start_idx = text .find (start_marker )
96+ end_idx = text .find (end_marker )
97+
98+ if start_idx == - 1 :
99+ self .fail ("Start marker not found" )
100+ if end_idx == - 1 :
101+ self .fail ("End marker not found" )
102+
103+ # Extract content including both markers
104+ return text [start_idx : end_idx + len (end_marker )].strip ()
105+
71106 def test_print_repro_env_var (self ):
72107 """Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
73108 with self ._with_print_repro_enabled ():
@@ -83,15 +118,8 @@ def test_print_repro_env_var(self):
83118 result = kernel (x )
84119 torch .testing .assert_close (result , x + 1 )
85120
86- # Extract repro script from logs (use records to get the raw message without formatting)
87- repro_script = None
88- for record in log_capture .records :
89- if "# === HELION KERNEL REPRO ===" in record .message :
90- repro_script = record .message
91- break
92-
93- if repro_script is None :
94- self .fail ("No repro script found in logs" )
121+ # Extract repro script from logs
122+ repro_script = self ._extract_repro_script (log_capture )
95123
96124 # Normalize range_warp_specializes=[None] to [] for comparison
97125 normalized_script = repro_script .replace (
@@ -163,10 +191,14 @@ def mock_do_bench(*args, **kwargs):
163191 captured = "" .join (output_capture .readouterr ())
164192
165193 # Verify that a repro script was printed for the failing config
166- self .assertIn ("# === HELION KERNEL REPRO ===" , captured )
167- self .assertIn ("# === END HELION KERNEL REPRO ===" , captured )
168- self .assertIn ("kernel" , captured )
169- self .assertIn ("helion_repro_caller()" , captured )
194+ repro_script = self ._extract_repro_script (captured )
195+
196+ # Normalize range_warp_specializes=[None] to [] for comparison
197+ normalized_script = repro_script .replace (
198+ "range_warp_specializes=[None]" , "range_warp_specializes=[]"
199+ )
200+
201+ self .assertExpectedJournal (normalized_script )
170202
171203 def test_print_repro_on_device_ir_lowering_error (self ):
172204 """Ensure HELION_PRINT_REPRO=1 prints repro when compilation fails during device IR lowering."""
@@ -192,21 +224,14 @@ def kernel_with_compile_error(x: torch.Tensor) -> torch.Tensor:
192224 kernel_with_compile_error (x )
193225
194226 # Extract repro script from logs
195- repro_script = None
196- for record in log_capture .records :
197- if "# === HELION KERNEL REPRO ===" in record .message :
198- repro_script = record .message
199- break
200-
201- # Verify that a repro script was printed when compilation failed
202- self .assertIsNotNone (
203- repro_script ,
204- "Expected repro script to be printed when device IR lowering fails" ,
227+ repro_script = self ._extract_repro_script (log_capture )
228+
229+ # Normalize range_warp_specializes=[None] to [] for comparison
230+ normalized_script = repro_script .replace (
231+ "range_warp_specializes=[None]" , "range_warp_specializes=[]"
205232 )
206- self .assertIn ("# === HELION KERNEL REPRO ===" , repro_script )
207- self .assertIn ("# === END HELION KERNEL REPRO ===" , repro_script )
208- self .assertIn ("kernel_with_compile_error" , repro_script )
209- self .assertIn ("helion_repro_caller()" , repro_script )
233+
234+ self .assertExpectedJournal (normalized_script )
210235
211236 def test_print_repro_on_triton_codegen_error (self ):
212237 """Ensure HELION_PRINT_REPRO=1 prints repro when Triton codegen fails."""
@@ -242,21 +267,14 @@ def mock_load(code, *args, **kwargs):
242267 kernel_with_triton_error (x )
243268
244269 # Extract repro script from logs
245- repro_script = None
246- for record in log_capture .records :
247- if "# === HELION KERNEL REPRO ===" in record .message :
248- repro_script = record .message
249- break
250-
251- # Verify that a repro script was printed when Triton codegen failed
252- self .assertIsNotNone (
253- repro_script ,
254- "Expected repro script to be printed when Triton codegen fails" ,
270+ repro_script = self ._extract_repro_script (log_capture )
271+
272+ # Normalize range_warp_specializes=[None] to [] for comparison
273+ normalized_script = repro_script .replace (
274+ "range_warp_specializes=[None]" , "range_warp_specializes=[]"
255275 )
256- self .assertIn ("# === HELION KERNEL REPRO ===" , repro_script )
257- self .assertIn ("# === END HELION KERNEL REPRO ===" , repro_script )
258- self .assertIn ("kernel_with_triton_error" , repro_script )
259- self .assertIn ("helion_repro_caller()" , repro_script )
276+
277+ self .assertExpectedJournal (normalized_script )
260278
261279
262280if __name__ == "__main__" :
0 commit comments