Skip to content

Commit 830fbfb

Browse files
committed
expected
1 parent 12d224e commit 830fbfb

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

test/test_indexing.expected

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,116 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
285285
# src[test_indexing.py:N]: return out
286286
return out
287287

288+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
289+
from __future__ import annotations
290+
291+
import torch
292+
import triton
293+
import triton.language as tl
294+
from helion.runtime import default_launcher as _default_launcher
295+
296+
@triton.jit
297+
def _helion_test(col, B_flat, val, C, B_flat_stride_0, C_stride_0, C_stride_1, col_stride_0, col_stride_1, val_stride_0, val_stride_1, M, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
298+
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
299+
pid_0 = tl.program_id(0) % num_blocks_0
300+
pid_1 = tl.program_id(0) // num_blocks_0
301+
offset_0 = pid_0 * _BLOCK_SIZE_0
302+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
303+
mask_0 = indices_0 < M
304+
offset_1 = pid_1 * _BLOCK_SIZE_1
305+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
306+
mask_1 = indices_1 < N
307+
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
308+
for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2):
309+
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
310+
mask_2 = indices_3 < K
311+
acc_copy = acc
312+
acc_copy_0 = acc_copy
313+
cols_2d = tl.load(col + (indices_0[:, None] * col_stride_0 + indices_3[None, :] * col_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
314+
v_0 = cols_2d * N
315+
subscript = v_0[:, :, None]
316+
v_1 = tl.cast(indices_1, tl.int64)
317+
v_2 = subscript + v_1
318+
B_slice = tl.load(B_flat + v_2 * B_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], other=0)
319+
vals_2d = tl.load(val + (indices_0[:, None] * val_stride_0 + indices_3[None, :] * val_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
320+
subscript_1 = vals_2d[:, :, None]
321+
v_3 = subscript_1 * B_slice
322+
contrib_1 = tl.cast(tl.sum(v_3, 1), tl.float32)
323+
acc = acc_copy_0 + contrib_1
324+
tl.store(C + (indices_0[:, None] * C_stride_0 + indices_1[None, :] * C_stride_1), acc, mask_0[:, None] & mask_1[None, :])
325+
326+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
327+
M, K = col.shape
328+
_, N = B.shape
329+
out_dtype = torch.promote_types(val.dtype, B.dtype)
330+
C = torch.empty((M, N), dtype=out_dtype, device=B.device)
331+
B_flat = B.reshape(-1)
332+
_BLOCK_SIZE_0 = 8
333+
_BLOCK_SIZE_1 = 8
334+
_BLOCK_SIZE_2 = 4
335+
_RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
336+
_launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1),), col, B_flat, val, C, B_flat.stride(0), C.stride(0), C.stride(1), col.stride(0), col.stride(1), val.stride(0), val.stride(1), M, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
337+
return C
338+
339+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
340+
from __future__ import annotations
341+
342+
import torch
343+
import triton
344+
import triton.language as tl
345+
from helion.runtime import default_launcher as _default_launcher
346+
347+
@triton.jit
348+
def _helion_test(col, B, val, C, B_stride_0, B_stride_1, B_stride_2, C_stride_0, C_stride_1, C_stride_2, C_stride_3, col_stride_0, col_stride_1, col_stride_2, val_stride_0, val_stride_1, val_stride_2, M, N, P, Q, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
349+
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
350+
num_blocks_1 = tl.cdiv(N, _BLOCK_SIZE_1)
351+
num_blocks_2 = tl.cdiv(P, _BLOCK_SIZE_2)
352+
pid_0 = tl.program_id(0) % num_blocks_0
353+
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
354+
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
355+
pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
356+
offset_0 = pid_0 * _BLOCK_SIZE_0
357+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
358+
mask_0 = indices_0 < M
359+
offset_1 = pid_1 * _BLOCK_SIZE_1
360+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
361+
mask_1 = indices_1 < N
362+
offset_2 = pid_2 * _BLOCK_SIZE_2
363+
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
364+
mask_2 = indices_2 < P
365+
offset_3 = pid_3 * _BLOCK_SIZE_3
366+
indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
367+
mask_3 = indices_3 < Q
368+
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
369+
for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_4):
370+
indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
371+
mask_4 = indices_5 < K
372+
acc_copy = acc
373+
acc_copy_0 = acc_copy
374+
cols_3d = tl.load(col + (indices_0[:, None, None] * col_stride_0 + indices_1[None, :, None] * col_stride_1 + indices_5[None, None, :] * col_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
375+
subscript = cols_3d[:, :, :, None, None]
376+
B_slice = tl.load(B + (subscript * B_stride_0 + indices_2[None, None, None, :, None] * B_stride_1 + indices_3[None, None, None, None, :] * B_stride_2), mask_0[:, None, None, None, None] & mask_1[None, :, None, None, None] & mask_4[None, None, :, None, None] & mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
377+
vals_3d = tl.load(val + (indices_0[:, None, None] * val_stride_0 + indices_1[None, :, None] * val_stride_1 + indices_5[None, None, :] * val_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
378+
subscript_1 = vals_3d[:, :, :, None, None]
379+
v_0 = subscript_1 * B_slice
380+
contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
381+
acc = acc_copy_0 + contrib_1
382+
tl.store(C + (indices_0[:, None, None, None] * C_stride_0 + indices_1[None, :, None, None] * C_stride_1 + indices_2[None, None, :, None] * C_stride_2 + indices_3[None, None, None, :] * C_stride_3), acc, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :])
383+
384+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
385+
M, N, K = col.shape
386+
_, P, Q = B.shape
387+
out_dtype = torch.promote_types(val.dtype, B.dtype)
388+
C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
389+
_BLOCK_SIZE_0 = 4
390+
_BLOCK_SIZE_1 = 4
391+
_BLOCK_SIZE_2 = 4
392+
_BLOCK_SIZE_3 = 4
393+
_BLOCK_SIZE_4 = 4
394+
_RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
395+
_launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1) * triton.cdiv(P, _BLOCK_SIZE_2) * triton.cdiv(Q, _BLOCK_SIZE_3),), col, B, val, C, B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), C.stride(3), col.stride(0), col.stride(1), col.stride(2), val.stride(0), val.stride(1), val.stride(2), M, N, P, Q, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
396+
return C
397+
288398
--- assertExpectedJournal(TestIndexing.test_mask_load)
289399
from __future__ import annotations
290400

0 commit comments

Comments
 (0)