Commit 8a23df1

Add distributed CI job (4xH100) and example unit tests (#1106)

1 parent 5d0cd02 commit 8a23df1

File tree

6 files changed: +408 −15 lines

.github/matrix.json

Lines changed: 10 additions & 0 deletions
@@ -52,6 +52,16 @@
     "pytorch-version": "pytorch-nightly",
     "alias": "h100"
   },
+  {
+    "runner": "linux.aws.h100.4",
+    "python-version": "3.12",
+    "ref-eager": false,
+    "image": "nvidia/cuda:12.8.1-devel-ubuntu24.04",
+    "runtime-version": "cu128",
+    "container-options": "--gpus all",
+    "pytorch-version": "pytorch-nightly",
+    "alias": "h100-distributed"
+  },
   {
     "runner": "linux.dgx.b200",
     "python-version": "3.12",

.github/workflows/test.yml

Lines changed: 21 additions & 3 deletions
@@ -66,7 +66,21 @@ jobs:
        run: |
          set -eux
          apt-get update
-         apt-get install -y libdw1
+         apt-get install -y libdw1 curl wget git pkg-config zlib1g-dev build-essential
+
+     - name: Install NVSHMEM
+       if: contains(matrix.alias, 'distributed')
+       run: |
+         set -euxo pipefail
+         GPU_COUNT=$(nvidia-smi -L | wc -l)
+         if [ "$GPU_COUNT" -ne 4 ]; then
+           echo "Error: Expected 4 GPUs but found $GPU_COUNT"
+           exit 1
+         fi
+         curl -L https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh -o install_cuda.sh
+         chmod +x install_cuda.sh
+         source install_cuda.sh
+         install_nvshmem 13 3.4.5
 
      - name: Install uv
        uses: astral-sh/setup-uv@v7
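The NVSHMEM step fails fast unless exactly 4 GPUs are visible, so a mis-scheduled runner shows up as an immediate error rather than a hang in the distributed tests. A minimal Python equivalent of that guard (a sketch for illustration, assuming torch is installed; not part of the commit):

    import sys

    import torch

    # Fail fast if the runner does not expose exactly 4 GPUs, mirroring
    # the nvidia-smi check in the workflow step above.
    EXPECTED_GPUS = 4
    found = torch.cuda.device_count()
    if found != EXPECTED_GPUS:
        sys.exit(f"Error: Expected {EXPECTED_GPUS} GPUs but found {found}")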
@@ -131,7 +145,7 @@ jobs:
      - name: Install Helion
        run: |
          source .venv/bin/activate
-         uv pip install setuptools
+         uv pip install setuptools ninja
          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
          python -c "import helion; print(helion.__name__)"
 
@@ -145,7 +159,11 @@ jobs:
          if [[ "${{ contains(matrix.alias, 'cpu') }}" == "true" ]]; then export TRITON_CPU_BACKEND=1; fi
          # -rf: print failed tests
          # --timeout: max allowed time for each test
-         pytest -rf --timeout=60
+         TEST_PATH=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "test/test_examples_dist.py" || echo ".")
+         EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "-rs" || echo "--ignore=test/test_examples_dist.py")
+         # For distributed tests, fail if any test is skipped
+         SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! grep -q SKIPPED" || echo "cat")
+         pytest -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH | tee >(eval $SKIP_CHECK)
 
   test-notebooks:
     name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g
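On the distributed runner, the test step pipes pytest output through tee into a process substitution: `! grep -q SKIPPED` fails the job if any distributed test was skipped, while non-distributed jobs pass the output through cat unchanged. A rough Python rendering of the distributed branch (a sketch, not part of the commit):

    import subprocess
    import sys

    # Run only the distributed tests (-rs reports skip reasons) and treat
    # any SKIPPED marker in the output as a failure, like `! grep -q SKIPPED`.
    proc = subprocess.run(
        ["pytest", "-rf", "-rs", "--timeout=60", "test/test_examples_dist.py"],
        capture_output=True,
        text=True,
    )
    sys.stdout.write(proc.stdout)
    if "SKIPPED" in proc.stdout:
        sys.exit("Distributed job: at least one test was skipped")
    sys.exit(proc.returncode)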

examples/all_gather_matmul.py

Lines changed: 2 additions & 2 deletions
@@ -235,8 +235,8 @@ def main() -> None:
 if __name__ == "__main__":
     """
     Run with:
-    torchrun \
-    --nnodes 1 --nproc-per-node 8 \
+    python -m torch.distributed.run --standalone \
+    --nproc-per-node 4 \
     --rdzv-backend c10d --rdzv-endpoint localhost:0 \
     --no_python python3 examples/all_gather_matmul.py
     """

examples/all_reduce.py

Lines changed: 19 additions & 10 deletions
@@ -23,7 +23,6 @@
 
 import helion
 from helion._testing import DEVICE
-from helion._testing import run_example
 import helion.language as hl
 
 # %%
@@ -252,15 +251,25 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     assert dist_group is not None
 
     world_size = dist.get_world_size()
+    rank = dist.get_rank()
+
+    # Create symmetric memory tensor for Helion implementation
     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
 
-    run_example(
-        helion_one_shot_all_reduce,
-        reference_one_shot_all_reduce,
-        (a_shared,),
-        rtol=1e-1,
-        atol=1e-1,
-    )
+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
+
+    # Create symmetric memory tensor for reference implementation
+    a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
+    a_shared_ref.copy_(a_shared)
+
+    print(f"[Rank {rank}] Running reference all-reduce...")
+    result_ref = reference_one_shot_all_reduce(a_shared_ref)
+
+    # Compare results
+    print(f"[Rank {rank}] Comparing results...")
+    torch.testing.assert_close(result_helion, result_ref, rtol=1e-1, atol=1e-1)
+    print(f"[Rank {rank}] Results match! ✓")
 
 
 def main() -> None:
@@ -283,8 +292,8 @@ def main() -> None:
 if __name__ == "__main__":
     """
     Run with:
-    torchrun \
-    --nnodes 1 --nproc-per-node 8 \
+    python -m torch.distributed.run --standalone \
+    --nproc-per-node 4 \
     --rdzv-backend c10d --rdzv-endpoint localhost:0 \
     --no_python python3 examples/all_reduce.py
     """

test/test_examples_dist.expected

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+This file is automatically generated by assertExpectedJournal calls in test_examples_dist.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestExamplesDist.test_all_gather_matmul)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_helion_matmul_w_progress(progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    num_blocks_0 = tl.cdiv(4096, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[all_gather_matmul.py:N]: tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
+    floordiv = triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK)
+    floordiv_1 = triton_helpers.div_floor_integer(offset_0, triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK))
+    # src[all_gather_matmul.py:N]: hl.wait(
+    # src[all_gather_matmul.py:N]: progress,
+    # src[all_gather_matmul.py:N]: [
+    # src[all_gather_matmul.py:N-N]: ...
+    helion.runtime.triton_wait_signal(addr=progress + floordiv_1 * 1, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+    # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
+    # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
+    for offset_2 in tl.range(0, 16384, _BLOCK_SIZE_2):
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
+        load = tl.load(tl.make_block_ptr(a, [4096, 16384], [16384, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        load_1 = tl.load(tl.make_block_ptr(b, [16384, 6656], [1, 16384], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [0, 1]), boundary_check=[0, 1], padding_option='zero')
+        acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    # src[all_gather_matmul.py:N]: out[tile_m, tile_n] = acc
+    v_0 = tl.cast(acc, tl.bfloat16)
+    tl.store(tl.make_block_ptr(out, [4096, 6656], [6656, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])
+
+def helion_matmul_w_progress(a: torch.Tensor, a_shared: torch.Tensor, b: torch.Tensor, progress: torch.Tensor, SPLITS_PER_RANK: int, RANK: int, *, _launcher=_default_launcher):
+    """
+    Performs matrix multiplication with progress tracking.
+    Args:
+        a (torch.Tensor): First input tensor for matrix multiplication.
+        a_shared (torch.Tensor): Shared tensor across ranks.
+        b (torch.Tensor): Second input tensor for matrix multiplication.
+        progress (torch.Tensor): Tensor used to track progress of the operation.
+        SPLITS_PER_RANK (int): Number of splits per rank.
+        RANK (int): Current process rank.
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
+    # src[all_gather_matmul.py:N]: M, K = a.size()
+    M, K = a.size()
+    # src[all_gather_matmul.py:N]: K2, N = b.size()
+    K2, N = b.size()
+    # src[all_gather_matmul.py:N]: assert K2 == K, f"size mismatch {K2} != {K}"
+    assert K2 == K, f'size mismatch {K2} != {K}'
+    # src[all_gather_matmul.py:N]: out = torch.empty(
+    # src[all_gather_matmul.py:N]: [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
+    # src[all_gather_matmul.py:N]: )
+    out = torch.empty([M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device)
+    # src[all_gather_matmul.py:N]: M_per_rank = a_shared.size(0)
+    M_per_rank = a_shared.size(0)
+    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    _BLOCK_SIZE_0 = 128
+    _BLOCK_SIZE_1 = 256
+    # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
+    # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
+    _BLOCK_SIZE_2 = 64
+    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[all_gather_matmul.py:N]: hl.wait(
+    # src[all_gather_matmul.py:N-N]: ...
+    _launcher(_helion_helion_matmul_w_progress, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(6656, _BLOCK_SIZE_1),), progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=3)
+    # src[all_gather_matmul.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestExamplesDist.test_all_reduce)
+from __future__ import annotations
+
+import torch
+import helion
+import helion.language as hl
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, a_shared_tuple_item_3, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 4096
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    # src[all_reduce.py:N]: ptr_tile = signal_pad_addrs[:]
+    ptr_tile = tl.load(signal_pad_addrs + indices_1 * 1, None)
+    # src[all_reduce.py:N]: [tile_n.id, my_rank],
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    # src[all_reduce.py:N]: hl.signal(
+    # src[all_reduce.py:N]: stack_signalpad,
+    # src[all_reduce.py:N]: [tile_n.id, my_rank],
+    # src[all_reduce.py:N-N]: ...
+    helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not True)
+    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
+    # src[all_reduce.py:N]: hl.wait(
+    # src[all_reduce.py:N]: local_signal_pad,
+    # src[all_reduce.py:N-N]: ...
+    for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        # src[all_reduce.py:N]: [tile_n.id, world],
+        tile_id_1 = offset_0 // _BLOCK_SIZE_0
+        # src[all_reduce.py:N]: hl.wait(
+        # src[all_reduce.py:N]: local_signal_pad,
+        # src[all_reduce.py:N]: [tile_n.id, world],
+        # src[all_reduce.py:N-N]: ...
+        helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
+    # src[all_reduce.py:N]: acc = hl.zeros(
+    # src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
+    # src[all_reduce.py:N]: )
+    acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
+    # src[all_reduce.py:N]: acc += a[tile_n]
+    load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
+    v_0 = acc + load_1
+    load_2 = tl.load(a_shared_tuple_item_1 + indices_0 * 1, mask_0, other=0)
+    v_1 = v_0 + load_2
+    load_3 = tl.load(a_shared_tuple_item_2 + indices_0 * 1, mask_0, other=0)
+    v_2 = v_1 + load_3
+    load_4 = tl.load(a_shared_tuple_item_3 + indices_0 * 1, mask_0, other=0)
+    v_3 = v_2 + load_4
+    # src[all_reduce.py:N]: out[tile_n] = acc
+    tl.store(out + indices_0 * 1, v_3, mask_0)
+    # src[all_reduce.py:N]: hl.signal(
+    # src[all_reduce.py:N]: stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
+    # src[all_reduce.py:N]: )
+    helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='release', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not False)
+    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
+    # src[all_reduce.py:N]: hl.wait(
+    # src[all_reduce.py:N]: local_signal_pad,
+    # src[all_reduce.py:N-N]: ...
+    for offset_3 in tl.range(0, 4, _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        # src[all_reduce.py:N]: [tile_n.id, world],
+        tile_id_2 = offset_0 // _BLOCK_SIZE_0
+        # src[all_reduce.py:N]: hl.wait(
+        # src[all_reduce.py:N]: local_signal_pad,
+        # src[all_reduce.py:N]: [tile_n.id, world],
+        # src[all_reduce.py:N-N]: ...
+        helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)
+
+def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
+    """
+    Helion JIT-compiled kernel for one-shot all-reduce operation.
+
+    This kernel implements a distributed all-reduce using symmetric memory and signal pads
+    for cross-device synchronization. It performs element-wise summation across all devices
+    in the distributed group using tiled computation for memory efficiency.
+
+    Args:
+        signal_pad_addrs: Tensor containing addresses of signal pads for all devices
+        local_signal_pad: Local signal pad for synchronization
+        a_shared_tuple: Tuple of shared tensors from all devices in the group
+        my_rank: Current device's rank in the distributed group
+
+    Returns:
+        Tensor containing the all-reduced result (sum across all devices)
+    """
+    # src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
+    _, world_size = local_signal_pad.size()
+    # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
+    out = torch.empty_like(a_shared_tuple[0])
+    # src[all_reduce.py:N]: N = out.size(0)
+    N = out.size(0)
+    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
+    _BLOCK_SIZE_0 = 8192
+    _RDIM_SIZE_1 = 4
+    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
+    # src[all_reduce.py:N]: hl.wait(
+    # src[all_reduce.py:N]: local_signal_pad,
+    # src[all_reduce.py:N-N]: ...
+    _BLOCK_SIZE_2 = 4
+    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
+    # src[all_reduce.py:N]: hl.wait(
+    # src[all_reduce.py:N]: local_signal_pad,
+    # src[all_reduce.py:N-N]: ...
+    _BLOCK_SIZE_3 = 4
+    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
+    # src[all_reduce.py:N]: # Sync all devices through signal_pad to make sure
+    # src[all_reduce.py:N]: # all previous writes to the shared tensor are visible
+    # src[all_reduce.py:N-N]: ...
+    _launcher(_helion_one_shot_all_reduce_kernel, (triton.cdiv(4096, _BLOCK_SIZE_0),), signal_pad_addrs, local_signal_pad, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], a_shared_tuple[3], out, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=32, num_stages=1)
+    # src[all_reduce.py:N]: return out
+    return out
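As the journal header notes, this file is regenerated rather than edited by hand. A minimal sketch of the regeneration step on a 4-GPU host (for illustration; not part of the commit):

    import os
    import subprocess

    # Re-run the distributed example tests with expecttest in accept mode
    # so the journal is rewritten in place.
    env = dict(os.environ, EXPECTTEST_ACCEPT="1")
    subprocess.run(["pytest", "test/test_examples_dist.py"], env=env, check=True)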
