
Commit 12d224e

test
fix test
1 parent 913f7c7 commit 12d224e

File tree

1 file changed: +124 -0 lines changed

test/test_indexing.py

Lines changed: 124 additions & 0 deletions
@@ -1604,6 +1604,130 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         self.assertEqual(code3, code4)
         self.assertExpectedJournal(code4)
 
+    def test_indirect_indexing_2d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*N]
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                # [tile_m, tile_n]
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_k]
+                    cols_2d = col[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    B_slice = hl.load(
+                        B_flat,
+                        [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+                    )
+                    # [tile_m, tile_k]
+                    vals_2d = val[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    # [tile_m, tile_n]
+                    contrib = contrib.sum(dim=1)
+                    # [tile_m, tile_n]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[8, 8, 4],
+        )
+
+        # For each output position (i,j), compute sum over k: val[i,k] * B[col[i,k], j]
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
+            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
+            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                # [tile_m, tile_n, tile_p, tile_q]
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_n, tile_k]
+                    cols_3d = col[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    # Direct indexing into B using gather
+                    B_slice = B[
+                        cols_3d[:, :, :, None, None],
+                        tile_p.index[None, None, :, None],
+                        tile_q.index[None, None, None, :],
+                    ]
+
+                    # [tile_m, tile_n, tile_k]
+                    vals_3d = val[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+
+                    # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension
+                    contrib = contrib.sum(dim=2)
+
+                    # [tile_m, tile_n, tile_p, tile_q]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
+        )
+
+        # For each output position (i,j,p,q), compute sum over k: val[i,j,k] * B[col[i,j,k], p, q]
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
 
 if __name__ == "__main__":
     unittest.main()
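The 2D test gathers through a flattened view of B: since B is a row-major [K, N] tensor, element B[k, j] sits at flat offset k * N + j, which is exactly what the kernel computes with (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]. Below is a minimal eager-mode PyTorch sketch of the same gather plus a vectorized form of the test's triple reference loop; the variable names are illustrative only and not part of the test.

import torch

M, K, N = 32, 16, 24
col = torch.randint(0, K, (M, K), dtype=torch.int64)
val = torch.rand((M, K))
B = torch.rand((K, N))

# Row-major flat offset of B[k, j] is k * N + j; broadcasting the
# [M, K, 1] row offsets against [1, 1, N] column indices gives [M, K, N].
flat_idx = (col * N)[:, :, None] + torch.arange(N)[None, None, :]
gathered = B.reshape(-1)[flat_idx]  # [M, K, N]

# Direct advanced indexing gathers the same elements.
assert torch.equal(gathered, B[col])

# Vectorized equivalent of the reference loop:
# expected[i, j] = sum_k val[i, k] * B[col[i, k], j]
expected = torch.einsum("mk,mkn->mn", val, gathered)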

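The 3D test instead uses direct advanced indexing: indexing B's first dimension with the [M, N, K] tensor col yields a [M, N, K, P, Q] gather, and the unsqueezed tile_p/tile_q index vectors broadcast against the 5D index tensor. A hedged eager-mode sketch of the reference computation follows; the names are illustrative and not the test's API.

import torch

M, N, K, P, Q = 16, 12, 8, 10, 14
col = torch.randint(0, K, (M, N, K), dtype=torch.int64)
val = torch.rand((M, N, K))
B = torch.rand((K, P, Q))

# Advanced indexing along B's first dim: [M, N, K] indices -> [M, N, K, P, Q].
gathered = B[col]

# Vectorized equivalent of the five nested reference loops:
# expected[i, j, p, q] = sum_k val[i, j, k] * B[col[i, j, k], p, q]
expected = torch.einsum("mnk,mnkpq->mnpq", val, gathered)

# The same contraction via broadcasting and an explicit sum over k,
# mirroring the kernel's vals_3d[:, :, :, None, None] * B_slice step.
expected_alt = (val[:, :, :, None, None] * gathered).sum(dim=2)
assert torch.allclose(expected, expected_alt)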