Skip to content

Commit 1501944

Browse files
committed
Add distributed unit tests for existing examples
1 parent d7e007e commit 1501944

File tree

4 files changed

+377
-12
lines changed

4 files changed

+377
-12
lines changed

examples/all_gather_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def main() -> None:
235235
if __name__ == "__main__":
236236
"""
237237
Run with:
238-
torchrun \
239-
--nnodes 1 --nproc-per-node 8 \
238+
python -m torch.distributed.run --standalone \
239+
--nproc-per-node 4 \
240240
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
241241
--no_python python3 examples/all_gather_matmul.py
242242
"""

examples/all_reduce.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import helion
2525
from helion._testing import DEVICE
26-
from helion._testing import run_example
2726
import helion.language as hl
2827

2928
# %%
@@ -252,15 +251,25 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
252251
assert dist_group is not None
253252

254253
world_size = dist.get_world_size()
254+
rank = dist.get_rank()
255+
256+
# Create symmetric memory tensor for Helion implementation
255257
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
256258

257-
run_example(
258-
helion_one_shot_all_reduce,
259-
reference_one_shot_all_reduce,
260-
(a_shared,),
261-
rtol=1e-1,
262-
atol=1e-1,
263-
)
259+
print(f"[Rank {rank}] Running Helion all-reduce...")
260+
result_helion = helion_one_shot_all_reduce(a_shared)
261+
262+
# Create symmetric memory tensor for reference implementation
263+
a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
264+
a_shared_ref.copy_(a_shared)
265+
266+
print(f"[Rank {rank}] Running reference all-reduce...")
267+
result_ref = reference_one_shot_all_reduce(a_shared_ref)
268+
269+
# Compare results
270+
print(f"[Rank {rank}] Comparing results...")
271+
torch.testing.assert_close(result_helion, result_ref, rtol=1e-1, atol=1e-1)
272+
print(f"[Rank {rank}] Results match! ✓")
264273

265274

266275
def main() -> None:
@@ -283,8 +292,8 @@ def main() -> None:
283292
if __name__ == "__main__":
284293
"""
285294
Run with:
286-
torchrun \
287-
--nnodes 1 --nproc-per-node 8 \
295+
python -m torch.distributed.run --standalone \
296+
--nproc-per-node 4 \
288297
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
289298
--no_python python3 examples/all_reduce.py
290299
"""

test/test_examples_dist.expected

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_examples_dist.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestExamplesDist.test_all_gather_matmul)
5+
from __future__ import annotations
6+
7+
import torch
8+
import helion
9+
import triton
10+
import triton.language as tl
11+
from torch._inductor.runtime import triton_helpers
12+
from helion.runtime import default_launcher as _default_launcher
13+
14+
@triton.jit
15+
def _helion_helion_matmul_w_progress(progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
16+
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
17+
num_blocks_0 = tl.cdiv(4096, _BLOCK_SIZE_0)
18+
pid_0 = tl.program_id(0) % num_blocks_0
19+
pid_1 = tl.program_id(0) // num_blocks_0
20+
offset_0 = pid_0 * _BLOCK_SIZE_0
21+
offset_1 = pid_1 * _BLOCK_SIZE_1
22+
# src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
23+
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
24+
# src[all_gather_matmul.py:N]: tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
25+
floordiv = triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK)
26+
floordiv_1 = triton_helpers.div_floor_integer(offset_0, triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK))
27+
# src[all_gather_matmul.py:N]: hl.wait(
28+
# src[all_gather_matmul.py:N]: progress,
29+
# src[all_gather_matmul.py:N]: [
30+
# src[all_gather_matmul.py:N-N]: ...
31+
helion.runtime.triton_wait_signal(addr=progress + floordiv_1 * 1, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
32+
# src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
33+
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
34+
for offset_2 in tl.range(0, 16384, _BLOCK_SIZE_2):
35+
acc_copy = acc
36+
acc_copy_0 = acc_copy
37+
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
38+
load = tl.load(tl.make_block_ptr(a, [4096, 16384], [16384, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
39+
load_1 = tl.load(tl.make_block_ptr(b, [16384, 6656], [1, 16384], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [0, 1]), boundary_check=[0, 1], padding_option='zero')
40+
acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
41+
# src[all_gather_matmul.py:N]: out[tile_m, tile_n] = acc
42+
v_0 = tl.cast(acc, tl.bfloat16)
43+
tl.store(tl.make_block_ptr(out, [4096, 6656], [6656, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])
44+
45+
def helion_matmul_w_progress(a: torch.Tensor, a_shared: torch.Tensor, b: torch.Tensor, progress: torch.Tensor, SPLITS_PER_RANK: int, RANK: int, *, _launcher=_default_launcher):
46+
"""
47+
Performs matrix multiplication with progress tracking.
48+
Args:
49+
a (torch.Tensor): First input tensor for matrix multiplication.
50+
a_shared (torch.Tensor): Shared tensor across ranks.
51+
b (torch.Tensor): Second input tensor for matrix multiplication.
52+
progress (torch.Tensor): Tensor used to track progress of the operation.
53+
SPLITS_PER_RANK (int): Number of splits per rank.
54+
RANK (int): Current process rank.
55+
Returns:
56+
torch.Tensor: The result of the matrix multiplication.
57+
"""
58+
# src[all_gather_matmul.py:N]: M, K = a.size()
59+
M, K = a.size()
60+
# src[all_gather_matmul.py:N]: K2, N = b.size()
61+
K2, N = b.size()
62+
# src[all_gather_matmul.py:N]: assert K2 == K, f"size mismatch {K2} != {K}"
63+
assert K2 == K, f'size mismatch {K2} != {K}'
64+
# src[all_gather_matmul.py:N]: out = torch.empty(
65+
# src[all_gather_matmul.py:N]: [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
66+
# src[all_gather_matmul.py:N]: )
67+
out = torch.empty([M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device)
68+
# src[all_gather_matmul.py:N]: M_per_rank = a_shared.size(0)
69+
M_per_rank = a_shared.size(0)
70+
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
71+
_BLOCK_SIZE_0 = 128
72+
_BLOCK_SIZE_1 = 256
73+
# src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
74+
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
75+
_BLOCK_SIZE_2 = 64
76+
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
77+
# src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
78+
# src[all_gather_matmul.py:N]: hl.wait(
79+
# src[all_gather_matmul.py:N-N]: ...
80+
_launcher(_helion_helion_matmul_w_progress, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(6656, _BLOCK_SIZE_1),), progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=3)
81+
# src[all_gather_matmul.py:N]: return out
82+
return out
83+
84+
--- assertExpectedJournal(TestExamplesDist.test_all_reduce)
85+
from __future__ import annotations
86+
87+
import torch
88+
import helion
89+
import helion.language as hl
90+
import triton
91+
import triton.language as tl
92+
from helion.runtime import default_launcher as _default_launcher
93+
94+
@triton.jit
95+
def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, a_shared_tuple_item_3, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
96+
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
97+
pid_0 = tl.program_id(0)
98+
offset_0 = pid_0 * _BLOCK_SIZE_0
99+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
100+
mask_0 = indices_0 < 4096
101+
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
102+
# src[all_reduce.py:N]: ptr_tile = signal_pad_addrs[:]
103+
ptr_tile = tl.load(signal_pad_addrs + indices_1 * 1, None)
104+
# src[all_reduce.py:N]: [tile_n.id, my_rank],
105+
tile_id = offset_0 // _BLOCK_SIZE_0
106+
# src[all_reduce.py:N]: hl.signal(
107+
# src[all_reduce.py:N]: stack_signalpad,
108+
# src[all_reduce.py:N]: [tile_n.id, my_rank],
109+
# src[all_reduce.py:N-N]: ...
110+
helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not True)
111+
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
112+
# src[all_reduce.py:N]: hl.wait(
113+
# src[all_reduce.py:N]: local_signal_pad,
114+
# src[all_reduce.py:N-N]: ...
115+
for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2):
116+
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
117+
# src[all_reduce.py:N]: [tile_n.id, world],
118+
tile_id_1 = offset_0 // _BLOCK_SIZE_0
119+
# src[all_reduce.py:N]: hl.wait(
120+
# src[all_reduce.py:N]: local_signal_pad,
121+
# src[all_reduce.py:N]: [tile_n.id, world],
122+
# src[all_reduce.py:N-N]: ...
123+
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
124+
# src[all_reduce.py:N]: acc = hl.zeros(
125+
# src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
126+
# src[all_reduce.py:N]: )
127+
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
128+
# src[all_reduce.py:N]: acc += a[tile_n]
129+
load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
130+
v_0 = acc + load_1
131+
load_2 = tl.load(a_shared_tuple_item_1 + indices_0 * 1, mask_0, other=0)
132+
v_1 = v_0 + load_2
133+
load_3 = tl.load(a_shared_tuple_item_2 + indices_0 * 1, mask_0, other=0)
134+
v_2 = v_1 + load_3
135+
load_4 = tl.load(a_shared_tuple_item_3 + indices_0 * 1, mask_0, other=0)
136+
v_3 = v_2 + load_4
137+
# src[all_reduce.py:N]: out[tile_n] = acc
138+
tl.store(out + indices_0 * 1, v_3, mask_0)
139+
# src[all_reduce.py:N]: hl.signal(
140+
# src[all_reduce.py:N]: stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
141+
# src[all_reduce.py:N]: )
142+
helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='release', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not False)
143+
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
144+
# src[all_reduce.py:N]: hl.wait(
145+
# src[all_reduce.py:N]: local_signal_pad,
146+
# src[all_reduce.py:N-N]: ...
147+
for offset_3 in tl.range(0, 4, _BLOCK_SIZE_3):
148+
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
149+
# src[all_reduce.py:N]: [tile_n.id, world],
150+
tile_id_2 = offset_0 // _BLOCK_SIZE_0
151+
# src[all_reduce.py:N]: hl.wait(
152+
# src[all_reduce.py:N]: local_signal_pad,
153+
# src[all_reduce.py:N]: [tile_n.id, world],
154+
# src[all_reduce.py:N-N]: ...
155+
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)
156+
157+
def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
158+
"""
159+
Helion JIT-compiled kernel for one-shot all-reduce operation.
160+
161+
This kernel implements a distributed all-reduce using symmetric memory and signal pads
162+
for cross-device synchronization. It performs element-wise summation across all devices
163+
in the distributed group using tiled computation for memory efficiency.
164+
165+
Args:
166+
signal_pad_addrs: Tensor containing addresses of signal pads for all devices
167+
local_signal_pad: Local signal pad for synchronization
168+
a_shared_tuple: Tuple of shared tensors from all devices in the group
169+
my_rank: Current device's rank in the distributed group
170+
171+
Returns:
172+
Tensor containing the all-reduced result (sum across all devices)
173+
"""
174+
# src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
175+
_, world_size = local_signal_pad.size()
176+
# src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
177+
out = torch.empty_like(a_shared_tuple[0])
178+
# src[all_reduce.py:N]: N = out.size(0)
179+
N = out.size(0)
180+
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
181+
_BLOCK_SIZE_0 = 8192
182+
_RDIM_SIZE_1 = 4
183+
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
184+
# src[all_reduce.py:N]: hl.wait(
185+
# src[all_reduce.py:N]: local_signal_pad,
186+
# src[all_reduce.py:N-N]: ...
187+
_BLOCK_SIZE_2 = 4
188+
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
189+
# src[all_reduce.py:N]: hl.wait(
190+
# src[all_reduce.py:N]: local_signal_pad,
191+
# src[all_reduce.py:N-N]: ...
192+
_BLOCK_SIZE_3 = 4
193+
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
194+
# src[all_reduce.py:N]: # Sync all devices through signal_pad to make sure
195+
# src[all_reduce.py:N]: # all previous writes to the shared tensor are visible
196+
# src[all_reduce.py:N-N]: ...
197+
_launcher(_helion_one_shot_all_reduce_kernel, (triton.cdiv(4096, _BLOCK_SIZE_0),), signal_pad_addrs, local_signal_pad, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], a_shared_tuple[3], out, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=32, num_stages=1)
198+
# src[all_reduce.py:N]: return out
199+
return out

0 commit comments

Comments
 (0)