@@ -285,6 +285,116 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
285285 # src[test_indexing.py:N]: return out
286286 return out
287287
288+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
289+ from __future__ import annotations
290+
291+ import torch
292+ import triton
293+ import triton.language as tl
294+ from helion.runtime import default_launcher as _default_launcher
295+
@triton.jit
def _helion_test(col, B_flat, val, C, B_flat_stride_0, C_stride_0, C_stride_1, col_stride_0, col_stride_1, val_stride_0, val_stride_1, M, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    # Indirect-gather contraction over a flattened B:
    #   C[m, n] = sum_k val[m, k] * B_flat[col[m, k] * N + n]
    # Launched on a 1D grid; the single program id is decomposed into an
    # (M-tile, N-tile) pair below.
    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    # Row (M) tile offsets and bounds mask.
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < M
    # Column (N) tile offsets and bounds mask.
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < N
    # Accumulate the output tile in float32.
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2):
        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_3 < K
        acc_copy = acc
        acc_copy_0 = acc_copy
        # Indirect indices: col[m, k] for this (M-tile, K-tile); masked-off
        # lanes read 0 so their flat offsets stay in range.
        cols_2d = tl.load(col + (indices_0[:, None] * col_stride_0 + indices_3[None, :] * col_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
        # Flatten (row, n) into an offset into B_flat: row * N + n,
        # broadcast to shape (M-tile, K-tile, N-tile).
        v_0 = cols_2d * N
        subscript = v_0[:, :, None]
        v_1 = tl.cast(indices_1, tl.int64)
        v_2 = subscript + v_1
        # Gather the addressed B elements; all three masks apply.
        B_slice = tl.load(B_flat + v_2 * B_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], other=0)
        vals_2d = tl.load(val + (indices_0[:, None] * val_stride_0 + indices_3[None, :] * val_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
        # Broadcast-multiply val over the gathered N dimension.
        subscript_1 = vals_2d[:, :, None]
        v_3 = subscript_1 * B_slice
        # Reduce over the K-tile axis (dim 1) and fold into the accumulator.
        contrib_1 = tl.cast(tl.sum(v_3, 1), tl.float32)
        acc = acc_copy_0 + contrib_1
    # Write back the finished (M-tile, N-tile) output block.
    tl.store(C + (indices_0[:, None] * C_stride_0 + indices_1[None, :] * C_stride_1), acc, mask_0[:, None] & mask_1[None, :])
325+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Compute C[m, n] = sum_k val[m, k] * B[col[m, k], n] via a Triton kernel.

    Args:
        col: (M, K) integer tensor of row indices into B.
        val: (M, K) tensor of multipliers, same leading shape as col.
        B: 2D tensor whose rows are gathered through col; second dim is N.
        _launcher: kernel-launch hook (testing seam), defaults to the
            standard Helion launcher.

    Returns:
        (M, N) tensor with dtype torch.promote_types(val.dtype, B.dtype)
        on B's device.
    """
    M, K = col.shape
    _, N = B.shape
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    # The kernel addresses B through a flat 1D view (offset = row * N + n).
    B_flat = B.reshape(-1)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 8
    _BLOCK_SIZE_2 = 4
    # Fix: dropped the unused `_RDIM_SIZE_3 = triton.next_power_of_2(...)`
    # dead local — it was never passed to the kernel nor read here.
    # One program per (M-tile, N-tile) pair, flattened onto a 1D grid.
    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1),), col, B_flat, val, C, B_flat.stride(0), C.stride(0), C.stride(1), col.stride(0), col.stride(1), val.stride(0), val.stride(1), M, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return C
338+
339+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
340+ from __future__ import annotations
341+
342+ import torch
343+ import triton
344+ import triton.language as tl
345+ from helion.runtime import default_launcher as _default_launcher
346+
@triton.jit
def _helion_test(col, B, val, C, B_stride_0, B_stride_1, B_stride_2, C_stride_0, C_stride_1, C_stride_2, C_stride_3, col_stride_0, col_stride_1, col_stride_2, val_stride_0, val_stride_1, val_stride_2, M, N, P, Q, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
    # Indirect-gather contraction:
    #   C[m, n, p, q] = sum_k val[m, n, k] * B[col[m, n, k], p, q]
    # Launched on a 1D grid; the program id is decomposed into an
    # (M-tile, N-tile, P-tile, Q-tile) quadruple below.
    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
    num_blocks_1 = tl.cdiv(N, _BLOCK_SIZE_1)
    num_blocks_2 = tl.cdiv(P, _BLOCK_SIZE_2)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
    # Per-dimension tile offsets and bounds masks (M, N, P, Q).
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < M
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < N
    offset_2 = pid_2 * _BLOCK_SIZE_2
    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
    mask_2 = indices_2 < P
    offset_3 = pid_3 * _BLOCK_SIZE_3
    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
    mask_3 = indices_3 < Q
    # Accumulate the 4D output tile in float32.
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
    for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_4):
        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        mask_4 = indices_5 < K
        acc_copy = acc
        acc_copy_0 = acc_copy
        # Indirect indices col[m, n, k] for this (M, N, K) tile; masked
        # lanes read 0 so the gathered offsets stay in range.
        cols_3d = tl.load(col + (indices_0[:, None, None] * col_stride_0 + indices_1[None, :, None] * col_stride_1 + indices_5[None, None, :] * col_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
        # Broadcast the indirect row index over the (P, Q) axes, then gather
        # B[col[m, n, k], p, q] into shape (M, N, K, P, Q) tiles.
        subscript = cols_3d[:, :, :, None, None]
        B_slice = tl.load(B + (subscript * B_stride_0 + indices_2[None, None, None, :, None] * B_stride_1 + indices_3[None, None, None, None, :] * B_stride_2), mask_0[:, None, None, None, None] & mask_1[None, :, None, None, None] & mask_4[None, None, :, None, None] & mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
        vals_3d = tl.load(val + (indices_0[:, None, None] * val_stride_0 + indices_1[None, :, None] * val_stride_1 + indices_5[None, None, :] * val_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
        # Broadcast-multiply val over the gathered (P, Q) dimensions.
        subscript_1 = vals_3d[:, :, :, None, None]
        v_0 = subscript_1 * B_slice
        # Reduce over the K-tile axis (dim 2) and fold into the accumulator.
        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
        acc = acc_copy_0 + contrib_1
    # Write back the finished (M, N, P, Q) output tile.
    tl.store(C + (indices_0[:, None, None, None] * C_stride_0 + indices_1[None, :, None, None] * C_stride_1 + indices_2[None, None, :, None] * C_stride_2 + indices_3[None, None, None, :] * C_stride_3), acc, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :])
383+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Compute C[m, n, p, q] = sum_k val[m, n, k] * B[col[m, n, k], p, q].

    Args:
        col: (M, N, K) integer tensor of indices into B's leading dim.
        val: (M, N, K) tensor of multipliers, same shape as col.
        B: 3D tensor gathered along its first dimension; trailing dims (P, Q).
        _launcher: kernel-launch hook (testing seam), defaults to the
            standard Helion launcher.

    Returns:
        (M, N, P, Q) tensor with dtype torch.promote_types(val.dtype, B.dtype)
        on B's device.
    """
    M, N, K = col.shape
    _, P, Q = B.shape
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
    _BLOCK_SIZE_0 = 4
    _BLOCK_SIZE_1 = 4
    _BLOCK_SIZE_2 = 4
    _BLOCK_SIZE_3 = 4
    _BLOCK_SIZE_4 = 4
    # Fix: dropped the unused `_RDIM_SIZE_5 = triton.next_power_of_2(...)`
    # dead local — it was never passed to the kernel nor read here.
    # One program per (M-tile, N-tile, P-tile, Q-tile), on a flattened 1D grid.
    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1) * triton.cdiv(P, _BLOCK_SIZE_2) * triton.cdiv(Q, _BLOCK_SIZE_3),), col, B, val, C, B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), C.stride(3), col.stride(0), col.stride(1), col.stride(2), val.stride(0), val.stride(1), val.stride(2), M, N, P, Q, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
    return C
397+
288398--- assertExpectedJournal(TestIndexing.test_mask_load)
289399from __future__ import annotations
290400
0 commit comments