@@ -5,32 +5,27 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
 from __future__ import annotations

 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher

-helion.runtime.set_triton_allocator()
-
 @triton.jit
-def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2, out_size_0, out_size_1, q_in_size_1, q_view_size_0, q_view_size_1, v_view_size_0, v_view_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    q_view_desc = tl.make_tensor_descriptor(q_view, [q_view_size_0, q_view_size_1, 64], [q_view_stride_0, q_view_stride_1, q_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
-    k_view_desc = tl.make_tensor_descriptor(k_view, [k_view_size_0, k_view_size_2, 64], [k_view_stride_0, k_view_stride_2, k_view_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    v_view_desc = tl.make_tensor_descriptor(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    out_desc = tl.make_tensor_descriptor(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
+def _helion_attention(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     num_blocks_0 = q_in_size_1
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < m_dim
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
     m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)
     l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32)
-    q = q_view_desc.load([offset_0, offset_1, 0])
+    q = tl.load(q_view + (indices_0[:, None, None] * q_view_stride_0 + indices_1[None, :, None] * q_view_stride_1 + indices_4[None, None, :] * q_view_stride_2), mask_1[None, :, None], other=0)
     for offset_2 in tl.range(0, n_dim.to(tl.int32), _BLOCK_SIZE_3):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         mask_3 = indices_2 < n_dim
@@ -42,7 +37,7 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * k_view_stride_0 + indices_4[None, :, None] * k_view_stride_1 + indices_2[None, None, :] * k_view_stride_2), mask_3[None, None, :], other=0)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float32), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float32), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         _mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, tl.full([], float('-inf'), tl.float32))
         amax = tl.cast(tl.max(_mask_to_2, 2), tl.float32)
@@ -62,12 +57,12 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         l_i = v_9 + l_ij
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * v_view_stride_0 + indices_2[None, :, None] * v_view_stride_1 + indices_4[None, None, :] * v_view_stride_2), mask_3[None, :, None], other=0)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(_mask_to_3, tl.float32), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float32), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
         m_i = v_2
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
-    out_desc.store([offset_0, offset_1, 0], v_12)
+    tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_1[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_12, mask_1[None, :, None])

 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -93,39 +88,37 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 16
+    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 16
-    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())

 --- assertExpectedJournal(TestTensorDescriptor.test_attention_tensor_descriptor)
 from __future__ import annotations

 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher

-helion.runtime.set_triton_allocator()
-
 @triton.jit
-def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    q_view_desc = tl.make_tensor_descriptor(q_view, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
-    k_view_desc = tl.make_tensor_descriptor(k_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    v_view_desc = tl.make_tensor_descriptor(v_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    out_desc = tl.make_tensor_descriptor(out, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
+def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     num_blocks_0 = 64
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
     m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)
     l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32)
-    q = q_view_desc.load([offset_0, offset_1, 0])
+    q = tl.load(q_view + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
     for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         q_copy = q
         m_i_copy = m_i
         l_i_copy = l_i
@@ -134,7 +127,7 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float16), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float16), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
@@ -154,14 +147,14 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i = v_11 + l_ij
         subscript_1 = v_10[:, :, None]
         v_13 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
         v_14 = tl.cast(v_8, tl.float16)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(v_14, tl.float16), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float16), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
         m_i = v_3
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = tl.cast(v_15, tl.float16)
-    out_desc.store([offset_0, offset_1, 0], v_16)
+    tl.store(out + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_16, None)

 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -187,6 +180,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 128
+    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
-    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
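
Note: the regenerated expected output above swaps tensor-descriptor accesses (tl.make_tensor_descriptor(...) plus desc.load(...)/desc.store(...)) for plain tl.load/tl.store calls whose addresses are built from strides and index vectors, with masks guarding partially filled blocks. The following is a minimal standalone sketch of that addressing pattern only; it is not taken from the generated kernels, the kernel name, wrapper, and block size are hypothetical, and it assumes the row width D is a power of two (like head_dim == 64 above).

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_rows(x_ptr, y_ptr, M, stride_m, stride_d, _BLOCK_M: tl.constexpr, _BLOCK_D: tl.constexpr):
    # Same addressing scheme as the generated tl.load/tl.store calls:
    # base pointer + row_index * row_stride + col_index * col_stride,
    # masked along the row dimension for the last (partial) block.
    pid = tl.program_id(0)
    rows = pid * _BLOCK_M + tl.arange(0, _BLOCK_M)
    cols = tl.arange(0, _BLOCK_D)
    mask = rows[:, None] < M
    offs = rows[:, None] * stride_m + cols[None, :] * stride_d
    tile = tl.load(x_ptr + offs, mask, other=0)
    tl.store(y_ptr + offs, tile, mask)

def copy_rows(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: tile-by-tile copy of an [M, D] CUDA tensor, D a power of two.
    y = torch.empty_like(x)
    m, d = x.shape
    block_m = 16
    _copy_rows[(triton.cdiv(m, block_m),)](x, y, m, x.stride(0), x.stride(1), block_m, d)
    return y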