@@ -1604,6 +1604,130 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         self.assertEqual(code3, code4)
         self.assertExpectedJournal(code4)
 
+    def test_indirect_indexing_2d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*N]
+
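+            # B is row-major, so element B[r, j] lives at flat offset r * N + j
+            # in B_flat; e.g. with N = 24, B[3, 5] maps to 3 * 24 + 5 = 77.
+            # The hl.load in the inner loop gathers through these offsets.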
+            for tile_m, tile_n in hl.tile([M, N]):
+                # [tile_m, tile_n]
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_k]
+                    cols_2d = col[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    B_slice = hl.load(
+                        B_flat,
+                        [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]],
+                    )
+                    # [tile_m, tile_k]
+                    vals_2d = val[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    # [tile_m, tile_n]
+                    contrib = contrib.sum(dim=1)
+                    # [tile_m, tile_n]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[8, 8, 4],
+        )
+
+        # For each output position (i,j), compute sum over k: val[i,k] * B[col[i,k], j]
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
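+        # Vectorized cross-check of the reference loop above (a sanity sketch,
+        # not part of the kernel under test): B[col] gathers rows of B into
+        # shape [M, K, N], and einsum contracts over k.
+        expected_vec = torch.einsum("mk,mkn->mn", val, B[col])
+        torch.testing.assert_close(expected_vec, expected, rtol=1e-5, atol=1e-5)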
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
+            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
+            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                # [tile_m, tile_n, tile_p, tile_q]
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_n, tile_k]
+                    cols_3d = col[tile_m, tile_n, tile_k]
+
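+                    # The three index tensors below broadcast right-aligned:
+                    # cols_3d occupies dims 0-2, tile_p lands in dim 3, and
+                    # tile_q in dim 4 of the gathered result.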
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    # Direct indexing into B using gather
+                    B_slice = B[
+                        cols_3d[:, :, :, None, None],
+                        tile_p.index[None, None, :, None],
+                        tile_q.index[None, None, None, :],
+                    ]
+
+                    # [tile_m, tile_n, tile_k]
+                    vals_3d = val[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+
+                    # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension
+                    contrib = contrib.sum(dim=2)
+
+                    # [tile_m, tile_n, tile_p, tile_q]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
+        )
+
+        # For each output position (i,j,p,q), compute sum over k: val[i,j,k] * B[col[i,j,k], p, q]
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
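+        # As in the 2D test, a vectorized sanity sketch of the reference loop:
+        # B[col] gathers into shape [M, N, K, P, Q] before contracting over k.
+        expected_vec = torch.einsum("mnk,mnkpq->mnpq", val, B[col])
+        torch.testing.assert_close(expected_vec, expected, rtol=1e-5, atol=1e-5)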
+        self.assertExpectedJournal(code)
 
 if __name__ == "__main__":
     unittest.main()