@@ -99,6 +99,9 @@ def jit_add_combine_fn(x, y):
9999
100100
101101class TestAssociativeScan (RefEagerTestBase , TestCase ):
102+ @skipIfRefEager (
103+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
104+ )
102105 def test_associative_scan_basic_addition (self ):
103106 """Test basic associative_scan functionality with prefix sum."""
104107
@@ -132,6 +135,9 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
132135 self .assertIn ("param_0 + param_1" , code )
133136 self .assertIn ("tl.associative_scan" , code )
134137
138+ @skipIfRefEager (
139+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
140+ )
135141 def test_associative_scan_maximum (self ):
136142 """Test associative_scan with maximum combine function."""
137143
@@ -164,6 +170,9 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
164170 "tl.maximum" in code or "triton_helpers.maximum" in code
165171 )
166172
173+ @skipIfRefEager (
174+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
175+ )
167176 def test_associative_scan_multiplication (self ):
168177 """Test associative_scan with multiplication combine function."""
169178
@@ -194,6 +203,9 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
194203 # Verify the generated code contains multiplication
195204 self .assertIn ("param_0 * param_1" , code )
196205
206+ @skipIfRefEager (
207+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
208+ )
197209 def test_associative_scan_minimum (self ):
198210 """Test associative_scan with minimum combine function."""
199211
@@ -226,6 +238,9 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
226238 "tl.minimum" in code or "triton_helpers.minimum" in code
227239 )
228240
241+ @skipIfRefEager (
242+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
243+ )
229244 def test_associative_scan_multiple_functions (self ):
230245 """Test using multiple different combine functions in one kernel."""
231246
@@ -262,6 +277,9 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor:
262277 "tl.maximum" in code or "triton_helpers.maximum" in code
263278 )
264279
280+ @skipIfRefEager (
281+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
282+ )
265283 def test_associative_scan_type_propagation (self ):
266284 """Test that associative_scan type propagation works correctly."""
267285
@@ -286,6 +304,9 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
286304 # Use relaxed tolerance for large tensors due to accumulated floating-point errors
287305 torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
288306
307+ @skipIfRefEager (
308+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
309+ )
289310 def test_associative_scan_different_dtypes (self ):
290311 """Test associative_scan with different data types."""
291312
@@ -320,6 +341,9 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
320341 expected = expected .to (result .dtype )
321342 torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
322343
344+ @skipIfRefEager (
345+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
346+ )
323347 def test_associative_scan_different_sizes (self ):
324348 """Test associative_scan with different tensor sizes."""
325349
@@ -356,6 +380,9 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
356380 expected = torch .cumsum (x , dim = 1 )
357381 torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
358382
383+ @skipIfRefEager (
384+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
385+ )
359386 def test_associative_scan_reverse (self ):
360387 """Test associative_scan with reverse=True parameter."""
361388
@@ -381,6 +408,9 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
381408 # Verify reverse parameter is in generated code
382409 self .assertIn ("reverse=True" , code )
383410
411+ @skipIfRefEager (
412+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
413+ )
384414 def test_associative_scan_edge_cases (self ):
385415 """Test associative_scan edge cases."""
386416
@@ -406,6 +436,9 @@ def test_single_element(x: torch.Tensor) -> torch.Tensor:
406436 expected = torch .tensor ([[3.0 , 10.0 ]], device = DEVICE )
407437 torch .testing .assert_close (result , expected , rtol = 1e-4 , atol = 1e-4 )
408438
439+ @skipIfRefEager (
440+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
441+ )
409442 def test_associative_scan_large_scale (self ):
410443 """Test associative_scan with large tensors for performance validation."""
411444
@@ -431,6 +464,9 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
431464 self .assertEqual (result .shape , x .shape )
432465 self .assertEqual (result .dtype , x .dtype )
433466
467+ @skipIfRefEager (
468+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
469+ )
434470 def test_associative_scan_torch_hops_mapping (self ):
435471 """Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""
436472
@@ -466,6 +502,9 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
466502 self .assertIn ("tl.associative_scan" , code )
467503 self .assertIn ("param_0 + param_1" , code )
468504
505+ @skipIfRefEager (
506+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
507+ )
469508 def test_associative_scan_code_generation (self ):
470509 """Test that the generated code structure is correct."""
471510
@@ -705,6 +744,9 @@ def cumulative_argmax_kernel(
705744 self .assertIn ("def argmax_combine_fn_" , code )
706745 self .assertIn ("tl.associative_scan" , code )
707746
747+ @skipIfRefEager (
748+ "torch._higher_order_ops.associative_scan is not supported by ref eager mode yet"
749+ )
708750 def test_associative_scan_in_helper_function (self ):
709751 """Test calling a function that internally uses hl.associative_scan."""
710752
@@ -766,6 +808,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
766808 self .assertIn ("param_0 + param_1" , code )
767809 self .assertIn ("tl.associative_scan" , code )
768810
811+ @skipIfRefEager ("hl.cumsum is not supported by ref eager mode yet" )
769812 def test_cumsum_reverse (self ):
770813 """Test cumsum with reverse=True."""
771814
@@ -847,6 +890,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
847890 self .assertIn ("param_0 * param_1" , code )
848891 self .assertIn ("tl.associative_scan" , code )
849892
893+ @skipIfRefEager ("hl.cumprod is not supported by ref eager mode yet" )
850894 def test_cumprod_reverse (self ):
851895 """Test cumprod with reverse=True."""
852896
@@ -870,6 +914,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
870914 # Verify reverse parameter is used
871915 self .assertIn ("reverse=True" , code )
872916
917+ @skipIfRefEager ("torch.cumprod is not supported by ref eager mode yet" )
873918 def test_cumprod_different_dtypes (self ):
874919 """Test cumprod with different data types."""
875920
@@ -988,6 +1033,9 @@ def test_segmented_tuple_kernel(
9881033 self .assertIn ("def helion_combine_tuple_fn_" , code )
9891034 self .assertIn ("tl.associative_scan" , code )
9901035
1036+ @skipIfRefEager (
1037+ "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
1038+ )
9911039 def test_associative_scan_argmax_tuple_format (self ):
9921040 """Test cumulative argmax using tuple format combine function."""
9931041
0 commit comments