@@ -484,9 +484,6 @@ def kernel(
484484 torch .testing .assert_close (src_result , expected_src )
485485 torch .testing .assert_close (dst_result , expected_dst )
486486
487- @skipIfNormalMode (
488- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
489- )
490487 def test_2d_full_slice (self ):
491488 """Test both setter from scalar and getter for [:,:]"""
492489
@@ -537,33 +534,79 @@ def kernel(
537534 torch .testing .assert_close (src_result , expected_src )
538535 torch .testing .assert_close (dst_result , expected_dst )
539536
540- @skipIfNormalMode (
541- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
542- )
543537 def test_1d_full_slice (self ):
544- """Test both setter from scalar and getter for [:]"""
538+ """Test both setter from scalar and getter for [:] with multiple scalar types """
545539
546- @helion .kernel (use_default_config = True )
540+ @helion .kernel (config = { "block_size" : 128 } )
547541 def kernel (
548- src : torch .Tensor , dst : torch .Tensor
549- ) -> tuple [torch .Tensor , torch .Tensor ]:
550- N = src .shape [0 ]
551- for _ in hl .grid (N ):
552- dst [:] = 1.0 # Test setter with scalar
553- src [:] = dst [:] # Test getter from dst and setter to src
554- return src , dst
542+ src_float : torch .Tensor ,
543+ dst_float : torch .Tensor ,
544+ src_int : torch .Tensor ,
545+ dst_int : torch .Tensor ,
546+ src_symint : torch .Tensor ,
547+ dst_symint : torch .Tensor ,
548+ ) -> tuple [
549+ torch .Tensor ,
550+ torch .Tensor ,
551+ torch .Tensor ,
552+ torch .Tensor ,
553+ torch .Tensor ,
554+ torch .Tensor ,
555+ ]:
556+ N = src_float .shape [0 ]
557+ for tile in hl .tile (N ):
558+ # Test float scalar
559+ dst_float [:] = 1.0
560+ src_float [:] = dst_float [:]
561+
562+ # Test int scalar
563+ dst_int [:] = 99
564+ src_int [:] = dst_int [:]
565+
566+ # Test SymInt scalar
567+ dst_symint [:] = tile .block_size
568+ src_symint [:] = dst_symint [:]
569+
570+ return (
571+ src_float ,
572+ dst_float ,
573+ src_int ,
574+ dst_int ,
575+ src_symint ,
576+ dst_symint ,
577+ )
555578
556579 N = 128
557- src = torch .zeros ([N ], device = DEVICE )
558- dst = torch .zeros ([N ], device = DEVICE )
580+ src_float = torch .zeros ([N ], device = DEVICE )
581+ dst_float = torch .zeros ([N ], device = DEVICE )
582+ src_int = torch .zeros ([N ], device = DEVICE )
583+ dst_int = torch .zeros ([N ], device = DEVICE )
584+ src_symint = torch .zeros ([N ], device = DEVICE )
585+ dst_symint = torch .zeros ([N ], device = DEVICE )
586+
587+ results = kernel (
588+ src_float ,
589+ dst_float ,
590+ src_int ,
591+ dst_int ,
592+ src_symint ,
593+ dst_symint ,
594+ )
559595
560- src_result , dst_result = kernel (src , dst )
596+ # Check float results
597+ expected_float = torch .ones ([N ], device = DEVICE )
598+ torch .testing .assert_close (results [0 ], expected_float )
599+ torch .testing .assert_close (results [1 ], expected_float )
561600
562- # Both should be ones after the kernel
563- expected_src = torch .ones ([N ], device = DEVICE )
564- expected_dst = torch .ones ([N ], device = DEVICE )
565- torch .testing .assert_close (src_result , expected_src )
566- torch .testing .assert_close (dst_result , expected_dst )
601+ # Check int results
602+ expected_int = torch .full ([N ], 99.0 , device = DEVICE )
603+ torch .testing .assert_close (results [2 ], expected_int )
604+ torch .testing .assert_close (results [3 ], expected_int )
605+
606+ # Check SymInt results
607+ expected_symint = torch .full ([N ], 128.0 , device = DEVICE )
608+ torch .testing .assert_close (results [4 ], expected_symint )
609+ torch .testing .assert_close (results [5 ], expected_symint )
567610
568611 @skipIfNormalMode (
569612 "RankMismatch: Expected ndim=1, but got ndim=0 - LHS/RHS shape mismatch in type_propagation.py"
@@ -624,9 +667,6 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
624667 expected = torch .zeros ([N ], device = DEVICE )
625668 torch .testing .assert_close (result , expected )
626669
627- @skipIfNormalMode (
628- "AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
629- )
630670 def test_mixed_slice_index (self ):
631671 """Test both setter from scalar and getter for [i,:]"""
632672
0 commit comments