Commit b15b026

Fix misaligned address error for matmul (#662)
1 parent 07b1182 commit b15b026

3 files changed: +89 −28 lines changed


helion/_compiler/indexing_strategy.py

Lines changed: 4 additions & 2 deletions
@@ -216,8 +216,10 @@ def valid_block_size(
         if fake_tensor.ndim == 2 and block_size < threshold:
             return False
 
-        # was getting some IMAs with small block sizes even in non-stride 1 dims
-        return block_size * element_size >= 16 or (block_size == 1 and stride != 1)
+        # Tensor-descriptor path (TMA + WGMMA / stmatrix writes)
+        # moves data in 16-byte chunks. Enforce a 16-byte minimum so the
+        # generated stores stay aligned and avoid misaligned-address errors.
+        return block_size * element_size >= 16
 
     # 4) Check minimum 16 bytes in each dimension
     sizes = fake_tensor.size()
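
Note: the new rule reduces to a single byte-size predicate per block dimension. A minimal sketch of how the 16-byte floor plays out for common dtypes (`block_is_td_friendly` is an illustrative name, not Helion's internal API):

```python
# Sketch of the alignment floor enforced above (illustrative helper, not a
# Helion function): a block dimension qualifies for the tensor-descriptor
# path only if it spans at least 16 contiguous bytes, matching the 16-byte
# granularity of TMA-style copies.
def block_is_td_friendly(block_size: int, element_size: int) -> bool:
    return block_size * element_size >= 16

assert block_is_td_friendly(8, 2)        # float16: 8 * 2 = 16 bytes -> ok
assert block_is_td_friendly(4, 4)        # float32: 4 * 4 = 16 bytes -> ok
assert not block_is_td_friendly(1, 2)    # float16, size-1 block: 2 bytes -> fall back
```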

test/test_tensor_descriptor.expected

Lines changed: 20 additions & 26 deletions
@@ -5,32 +5,27 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
-helion.runtime.set_triton_allocator()
-
 @triton.jit
-def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2, out_size_0, out_size_1, q_in_size_1, q_view_size_0, q_view_size_1, v_view_size_0, v_view_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    q_view_desc = tl.make_tensor_descriptor(q_view, [q_view_size_0, q_view_size_1, 64], [q_view_stride_0, q_view_stride_1, q_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
-    k_view_desc = tl.make_tensor_descriptor(k_view, [k_view_size_0, k_view_size_2, 64], [k_view_stride_0, k_view_stride_2, k_view_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    v_view_desc = tl.make_tensor_descriptor(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    out_desc = tl.make_tensor_descriptor(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
+def _helion_attention(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     num_blocks_0 = q_in_size_1
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < m_dim
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
     m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)
     l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32)
-    q = q_view_desc.load([offset_0, offset_1, 0])
+    q = tl.load(q_view + (indices_0[:, None, None] * q_view_stride_0 + indices_1[None, :, None] * q_view_stride_1 + indices_4[None, None, :] * q_view_stride_2), mask_1[None, :, None], other=0)
     for offset_2 in tl.range(0, n_dim.to(tl.int32), _BLOCK_SIZE_3):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         mask_3 = indices_2 < n_dim
@@ -42,7 +37,7 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * k_view_stride_0 + indices_4[None, :, None] * k_view_stride_1 + indices_2[None, None, :] * k_view_stride_2), mask_3[None, None, :], other=0)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float32), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float32), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         _mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, tl.full([], float('-inf'), tl.float32))
         amax = tl.cast(tl.max(_mask_to_2, 2), tl.float32)
@@ -62,12 +57,12 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         l_i = v_9 + l_ij
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * v_view_stride_0 + indices_2[None, :, None] * v_view_stride_1 + indices_4[None, None, :] * v_view_stride_2), mask_3[None, :, None], other=0)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(_mask_to_3, tl.float32), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float32), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
         m_i = v_2
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
-    out_desc.store([offset_0, offset_1, 0], v_12)
+    tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_1[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_12, mask_1[None, :, None])
 
 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -93,39 +88,37 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 16
+    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 16
-    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
 
 --- assertExpectedJournal(TestTensorDescriptor.test_attention_tensor_descriptor)
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
-helion.runtime.set_triton_allocator()
-
 @triton.jit
-def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    q_view_desc = tl.make_tensor_descriptor(q_view, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
-    k_view_desc = tl.make_tensor_descriptor(k_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    v_view_desc = tl.make_tensor_descriptor(v_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    out_desc = tl.make_tensor_descriptor(out, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
+def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     num_blocks_0 = 64
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
     m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)
     l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32)
-    q = q_view_desc.load([offset_0, offset_1, 0])
+    q = tl.load(q_view + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
     for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         q_copy = q
         m_i_copy = m_i
         l_i_copy = l_i
@@ -134,7 +127,7 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float16), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float16), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
@@ -154,14 +147,14 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i = v_11 + l_ij
         subscript_1 = v_10[:, :, None]
         v_13 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
         v_14 = tl.cast(v_8, tl.float16)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(v_14, tl.float16), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float16), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
         m_i = v_3
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = tl.cast(v_15, tl.float16)
-    out_desc.store([offset_0, offset_1, 0], v_16)
+    tl.store(out + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_16, None)
 
 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -187,6 +180,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 128
+    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
-    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return out.view(q_in.size())
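
Note: the attention journal changes because the kernel's batch dimension uses a size-1 block (_BLOCK_SIZE_0 = 1). The old predicate accepted size-1 blocks on non-contiguous dimensions; the new one does not, so code generation falls back from descriptor loads/stores to plain pointer indexing. A minimal sketch of the two predicates (written as free functions for illustration; the element size and stride are taken from the float16 journal above):

```python
# Old vs. new acceptance rule from indexing_strategy.py, paraphrased as
# standalone functions (Helion implements this inside valid_block_size).
def old_rule(block_size: int, element_size: int, stride: int) -> bool:
    return block_size * element_size >= 16 or (block_size == 1 and stride != 1)

def new_rule(block_size: int, element_size: int, stride: int) -> bool:
    return block_size * element_size >= 16

# Batch dimension of q_view in the journal above: block_size=1, float16
# elements (2 bytes), stride 65536.
print(old_rule(1, 2, 65536))  # True  -> old output used q_view_desc.load(...)
print(new_rule(1, 2, 65536))  # False -> new output uses tl.load with pointer math
```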

test/test_tensor_descriptor.py

Lines changed: 65 additions & 0 deletions
@@ -198,6 +198,71 @@ def kernel_different_blocks(x: torch.Tensor) -> torch.Tensor:
         # The block sizes should also be permuted in the tensor descriptor
         # This is important for correctness
 
+    @unittest.skipUnless(
+        supports_tensor_descriptor(), "Tensor descriptor support is required"
+    )
+    def test_tiny_matmul_tile_fallback(self) -> None:
+        """Tensor descriptor indexing should be rejected when the tile is too small."""
+
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 16, 16],
+                indexing="tensor_descriptor",
+                l2_groupings=[2],
+                loop_orders=[[0, 1]],
+                num_stages=4,
+                num_warps=1,
+                pid_type="persistent_blocked",
+                range_flattens=[True, True],
+                range_multi_buffers=[False, True],
+                range_num_stages=[0, 1],
+                range_unroll_factors=[0, 4],
+            ),
+            static_shapes=True,
+        )
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2
+            out = torch.empty(
+                [m, n],
+                dtype=torch.promote_types(x.dtype, y.dtype),
+                device=x.device,
+            )
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(out.dtype)
+            return out
+
+        x = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
+        y = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
+
+        code, result = code_and_output(matmul, (x, y))
+        torch.cuda.synchronize()
+        expected = torch.matmul(x, y)
+        torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
+
+        # Ensure we fall back to pointer indexing for accesses that would use the
+        # 1x16 tile - there should be no tensor descriptor for the x or out tensors.
+        self.assertNotIn("x_desc = tl.make_tensor_descriptor", code)
+        self.assertNotIn("out_desc = tl.make_tensor_descriptor", code)
+        # The K dimension still has a valid tile size, so the column operand can
+        # keep using tensor descriptors.
+        self.assertIn("y_desc = tl.make_tensor_descriptor", code)
+
+        # A larger tile should still be able to use tensor descriptors
+        code_large, result_large = code_and_output(
+            matmul,
+            (x, y),
+            block_sizes=[16, 16, 16],
+            indexing="tensor_descriptor",
+        )
+        torch.cuda.synchronize()
+        torch.testing.assert_close(result_large, expected, atol=1e-2, rtol=1e-2)
+        self.assertIn(get_tensor_descriptor_fn_name(), code_large)
+
     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
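
Note: the test's assertions mirror the per-dimension byte check. With block_sizes=[1, 16, 16] over float16 inputs, the 1-wide M block makes the x and out accesses ineligible, while y's 16x16 block stays descriptor-friendly. A simplified sketch of that arithmetic (plain Python, not part of the test suite; it ignores Helion's additional size/stride checks):

```python
# Per-dimension byte check for the tiny-tile config above (float16 -> 2 bytes).
ELEMENT_SIZE = 2
BLOCK_M, BLOCK_N, BLOCK_K = 1, 16, 16

def eligible(block_dims: list[int]) -> bool:
    # Every block dimension must cover at least 16 bytes.
    return all(b * ELEMENT_SIZE >= 16 for b in block_dims)

print(eligible([BLOCK_M, BLOCK_K]))  # x tile   [1, 16]  -> False (no x_desc)
print(eligible([BLOCK_K, BLOCK_N]))  # y tile   [16, 16] -> True  (y_desc kept)
print(eligible([BLOCK_M, BLOCK_N]))  # out tile [1, 16]  -> False (no out_desc)
```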
