@@ -5,10 +5,13 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
55from __future__ import annotations
66
77import torch
8+ import helion
89import triton
910import triton.language as tl
1011from helion.runtime import default_launcher as _default_launcher
1112
13+ helion.runtime.set_triton_allocator()
14+
1215@triton.jit
1316def _stack_load_kernel_2d_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
1417 pid_0 = tl.program_id(0)
@@ -31,10 +34,13 @@ def stack_load_kernel_2d(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
3134 return outfrom __future__ import annotations
3235
3336import torch
37+ import helion
3438import triton
3539import triton.language as tl
3640from helion.runtime import default_launcher as _default_launcher
3741
42+ helion.runtime.set_triton_allocator()
43+
3844@triton.jit
3945def _stack_load_2d_looped_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, M1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
4046 pid_0 = tl.program_id(0)
@@ -61,10 +67,13 @@ def stack_load_2d_looped(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
6167from __future__ import annotations
6268
6369import torch
70+ import helion
6471import triton
6572import triton.language as tl
6673from helion.runtime import default_launcher as _default_launcher
6774
75+ helion.runtime.set_triton_allocator()
76+
6877@triton.jit
6978def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, example_tensor_stride_1, out_stride_0, out_stride_1, out_stride_2, N1, N2, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
7079 num_blocks_0 = tl.cdiv(N1, _BLOCK_SIZE_0)
@@ -96,10 +105,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
96105from __future__ import annotations
97106
98107import torch
108+ import helion
99109import triton
100110import triton.language as tl
101111from helion.runtime import default_launcher as _default_launcher
102112
113+ helion.runtime.set_triton_allocator()
114+
103115@triton.jit
104116def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, _RDIM_SIZE_1: tl.constexpr):
105117 pid_0 = tl.program_id(0)
@@ -121,10 +133,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
121133from __future__ import annotations
122134
123135import torch
136+ import helion
124137import triton
125138import triton.language as tl
126139from helion.runtime import default_launcher as _default_launcher
127140
141+ helion.runtime.set_triton_allocator()
142+
128143@triton.jit
129144def _stack_load_w_mask_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, N, M, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
130145 pid_0 = tl.program_id(0)
@@ -154,10 +169,13 @@ def stack_load_w_mask(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
154169from __future__ import annotations
155170
156171import torch
172+ import helion
157173import triton
158174import triton.language as tl
159175from helion.runtime import default_launcher as _default_launcher
160176
177+ helion.runtime.set_triton_allocator()
178+
161179@triton.jit
162180def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, N, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
163181 pid_0 = tl.program_id(0)
@@ -181,10 +199,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
181199from __future__ import annotations
182200
183201import torch
202+ import helion
184203import triton
185204import triton.language as tl
186205from helion.runtime import default_launcher as _default_launcher
187206
207+ helion.runtime.set_triton_allocator()
208+
188209@triton.jit
189210def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, _RDIM_SIZE_1: tl.constexpr):
190211 pid_0 = tl.program_id(0)
@@ -203,10 +224,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
203224from __future__ import annotations
204225
205226import torch
227+ import helion
206228import triton
207229import triton.language as tl
208230from helion.runtime import default_launcher as _default_launcher
209231
232+ helion.runtime.set_triton_allocator()
233+
210234@triton.jit
211235def _stack_store_arange_kernel_kernel(dev_ptrs, dev_ptrs_stride_0, example_tensor_stride_0, _RDIM_SIZE_1: tl.constexpr):
212236 pid_0 = tl.program_id(0)
0 commit comments