@@ -971,6 +971,45 @@ def nested_loop_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
971971 _launcher(_helion_nested_loop_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
972972 return out
973973
974+ --- assertExpectedJournal(TestLoops.test_register_block_size_codegen_size_hint)
975+ from __future__ import annotations
976+
977+ import torch
978+ import triton
979+ import triton.language as tl
980+ from helion.runtime import default_launcher as _default_launcher
981+
@triton.jit
def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    # Generated Helion/Triton kernel (expected-journal output for
    # TestLoops.test_register_block_size_codegen_size_hint). One program per
    # row-block of size _BLOCK_SIZE_1 along dimension 0.
    pid_0 = tl.program_id(0)
    offset_1 = pid_0 * _BLOCK_SIZE_1
    # Row indices handled by this program.
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    # Reduction-dimension indices (used both to zero and to re-read loss_sum).
    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    # Zero-initialize a fixed 64x64 tile of the loss_sum accumulator before
    # the atomic adds below. NOTE(review): the [64, 64] literal matches the
    # host-side _BLOCK_SIZE_1/_RDIM_SIZE_2 values of 64 — presumably a
    # size-hint baked in at codegen time; confirm it tracks the constexprs.
    full = tl.full([64, 64], 0.0, tl.float32)
    tl.store(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), full, None)
    # Sweep the 128-wide inner dimension in tiles of _BLOCK_SIZE_3.
    for offset_2 in tl.range(0, 128, _BLOCK_SIZE_3):
        # NOTE(review): unlike indices_1, the .to(tl.int32) here applies only
        # to tl.arange(...), not to the sum — likely benign since arange is
        # already int32, but inconsistent with the line above; confirm.
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
        # Copy the y_true tile into kl_loss (stride 128 along rows), then
        # read it back and accumulate into loss_sum with relaxed atomics.
        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), y_true_val, None)
        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_2[None, :] * 1), load_1, mask=None, sem='relaxed')
    # Row-wise sum of the accumulated tile -> per-row loss.
    load = tl.load(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), None)
    sum_1 = tl.cast(tl.sum(load, 1), tl.float32)
    tl.store(loss + indices_1 * 1, sum_1, None)
999+
def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _launcher=_default_launcher):
    """Host-side wrapper: allocate output buffers and launch the Triton kernel.

    Returns the mean per-row loss as a 0-dim tensor.
    """
    num_rows, _num_cols = y_pred.shape
    # Tile sizes fixed at codegen time by Helion.
    _BLOCK_SIZE_1 = 64
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_3 = 64
    bt_size, n_block = 64, 128
    # Output / scratch buffers on the input's device.
    loss = torch.zeros((num_rows,), dtype=torch.float32, device=y_pred.device)
    kl_loss = torch.zeros_like(y_pred)
    loss_sum = torch.zeros([bt_size, n_block], dtype=torch.float32, device=y_pred.device)
    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return torch.sum(loss) / num_rows
1012+
9741013--- assertExpectedJournal(TestLoops.test_reorder_with_register_block_size)
9751014from __future__ import annotations
9761015
0 commit comments