Commit 94b0650

Fix register_block_size codegen (#659)
1 parent 0fc6dc7 commit 94b0650

File tree: 4 files changed, +98 -2 lines

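For context, the usage pattern this commit fixes, condensed from the new test in test/test_loops.py below: a block size registered with hl.register_block_size is later reused as the explicit block_size of an inner hl.tile loop (and, in the test, to size a host-side buffer), so codegen must map the returned symbol back to the correct block_id. The kernel below is an illustrative sketch with made-up names, not code from this commit; the authoritative reproduction is the test further down.

import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=True)
def copy_with_registered_block(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    BT, V_local = y_pred.shape
    out = torch.zeros_like(y_pred)
    # Register a tunable block size for the V dimension; the returned SymInt
    # is reused below as the inner loop's block_size.
    block_size_n = hl.register_block_size(V_local)
    for tile_bt in hl.tile(BT, block_size=64):
        for tile_v in hl.tile(V_local, block_size=block_size_n):
            out[tile_bt, tile_v] = y_true[tile_bt, tile_v]
    return out

The new test below builds a fuller version of this kernel and asserts both the generated Triton code and the numeric result.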

helion/_compiler/type_propagation.py

Lines changed: 13 additions & 0 deletions
@@ -1064,6 +1064,19 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
         return super().propagate_attribute(attr, origin)
 
 
+class BlockSizeType(SymIntType):
+    """Type for block sizes registered via register_block_size"""
+
+    block_id: int
+
+    def __init__(self, origin: Origin, value: torch.SymInt, block_id: int) -> None:
+        super().__init__(origin, value)
+        self.block_id = block_id
+
+    def __str__(self) -> str:
+        return f"{type(self).__name__}({self.block_id})"
+
+
 class GridIndexType(SymIntType):
     block_id: int
 
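A note on the class above: BlockSizeType subclasses SymIntType, so existing isinstance(..., SymIntType) checks keep working unchanged; the only addition is that the block_id assigned by register_block_size is stored on the type itself rather than recovered later from the SymInt value. A minimal standalone sketch of that pattern, with hypothetical class names (not Helion's actual types):

class SymIntLike:
    """Stand-in for a symbolic-integer wrapper type (hypothetical)."""

    def __init__(self, value: int) -> None:
        self.value = value


class BlockSizeLike(SymIntLike):
    """Same value, but also remembers which block size it was registered as."""

    def __init__(self, value: int, block_id: int) -> None:
        super().__init__(value)
        self.block_id = block_id


bs = BlockSizeLike(value=64, block_id=3)
assert isinstance(bs, SymIntLike)  # code written against the base type still applies
assert bs.block_id == 3            # new code paths can read the id directly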

helion/language/tunable_ops.py

Lines changed: 7 additions & 2 deletions
@@ -66,7 +66,7 @@ def _(min_or_max: int, max_or_none: int | None = None, /) -> int:
 def _(
     min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
 ) -> TypeInfo:
-    from .._compiler.type_propagation import SymIntType
+    from .._compiler.type_propagation import BlockSizeType
 
     min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
     min_proxy = _to_proxy(min_type)
@@ -85,22 +85,27 @@ def _(
     loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
     loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
     block_id = result.block_id
-    return SymIntType(origin, env.block_sizes[block_id].var)
+    return BlockSizeType(origin, env.block_sizes[block_id].var, block_id)
 
 
 def _block_id_from_state(state: CodegenState) -> int:
     """Extract the block_id from the current state for nodes hl.register_block_size."""
+    from .._compiler.type_propagation import BlockSizeType
     from .._compiler.type_propagation import SymIntType
 
     env = CompileEnvironment.current()
     if state.fx_node is not None:
         val = state.fx_node.meta["val"]
+        if isinstance(val, BlockSizeType):
+            return val.block_id
         assert isinstance(val, SymIntType)
         block_id = env.get_block_id(val.value)
         assert block_id is not None
         return block_id
     current_node = ExtendedAST.current()[-1]
     type_info = current_node._type_info
+    if isinstance(type_info, BlockSizeType):
+        return type_info.block_id
     assert isinstance(type_info, SymIntType)
     block_id = env.get_block_id(type_info.value)
     assert block_id is not None

test/test_loops.expected

Lines changed: 39 additions & 0 deletions
@@ -971,6 +971,45 @@ def nested_loop_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_nested_loop_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestLoops.test_register_block_size_codegen_size_hint)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    full = tl.full([64, 64], 0.0, tl.float32)
+    tl.store(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), full, None)
+    for offset_2 in tl.range(0, 128, _BLOCK_SIZE_3):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
+        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), y_true_val, None)
+        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
+        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_2[None, :] * 1), load_1, mask=None, sem='relaxed')
+    load = tl.load(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), None)
+    sum_1 = tl.cast(tl.sum(load, 1), tl.float32)
+    tl.store(loss + indices_1 * 1, sum_1, None)
+
+def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _launcher=_default_launcher):
+    BT, V_local = y_pred.shape
+    loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
+    kl_loss = torch.zeros_like(y_pred)
+    block_size_n = 128
+    BT_SIZE = 64
+    loss_sum = torch.zeros([BT_SIZE, block_size_n], dtype=torch.float32, device=y_pred.device)
+    _BLOCK_SIZE_1 = 64
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_3 = 64
+    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return torch.sum(loss) / BT
+
 --- assertExpectedJournal(TestLoops.test_reorder_with_register_block_size)
 from __future__ import annotations
 

test/test_loops.py

Lines changed: 39 additions & 0 deletions
@@ -324,6 +324,45 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(spec.min_size, 32)
         self.assertEqual(spec.max_size, 256)
 
+    @skipIfRefEager("Triton codegen is disabled in ref eager mode")
+    def test_register_block_size_codegen_size_hint(self):
+        @helion.kernel(static_shapes=True)
+        def kernel_fixed_block_size(
+            y_pred: torch.Tensor,
+            y_true: torch.Tensor,
+        ) -> torch.Tensor:
+            BT, V_local = y_pred.shape
+
+            loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
+            kl_loss = torch.zeros_like(y_pred)
+
+            block_size_n = hl.register_block_size(V_local)
+            BT_SIZE = 64
+            loss_sum = torch.zeros(
+                [BT_SIZE, block_size_n], dtype=torch.float32, device=y_pred.device
+            )
+
+            for tile_bt in hl.tile(BT, block_size=BT_SIZE):
+                loss_sum[:, :] = hl.zeros([BT_SIZE, block_size_n], dtype=torch.float32)
+                for tile_v in hl.tile(V_local, block_size=block_size_n):
+                    y_true_val = y_true[tile_bt, tile_v]
+                    kl_loss[tile_bt, tile_v] = y_true_val
+                    hl.atomic_add(loss_sum, [tile_bt, tile_v], kl_loss[tile_bt, tile_v])
+
+                loss[tile_bt] = loss_sum[:, :].sum(dim=-1)
+
+            return torch.sum(loss) / BT
+
+        y_pred = torch.randn(64, 128, device=DEVICE, dtype=torch.float32)
+        y_true = torch.randn(64, 128, device=DEVICE, dtype=torch.float32)
+        args = (y_pred, y_true)
+
+        code, result = code_and_output(kernel_fixed_block_size, args, block_sizes=[128])
+        self.assertExpectedJournal(code)
+
+        expected = y_true[:, : y_pred.size(0)].sum() / y_pred.size(0)
+        torch.testing.assert_close(result, expected)
+
     def test_reorder_with_register_block_size(self):
         @helion.kernel(
             config={