10 changes: 10 additions & 0 deletions .github/matrix.json
@@ -52,6 +52,16 @@
"pytorch-version": "pytorch-nightly",
"alias": "h100"
},
{
"runner": "linux.aws.h100.4",
"python-version": "3.12",
"ref-eager": false,
"image": "nvidia/cuda:12.8.1-devel-ubuntu24.04",
"runtime-version": "cu128",
"container-options": "--gpus all",
"pytorch-version": "pytorch-nightly",
"alias": "h100-distributed"
},
{
"runner": "linux.dgx.b200",
"python-version": "3.12",
24 changes: 21 additions & 3 deletions .github/workflows/test.yml
@@ -66,7 +66,21 @@ jobs:
run: |
set -eux
apt-get update
apt-get install -y libdw1
apt-get install -y libdw1 curl wget git pkg-config zlib1g-dev build-essential

- name: Install NVSHMEM
if: contains(matrix.alias, 'distributed')
run: |
set -euxo pipefail
GPU_COUNT=$(nvidia-smi -L | wc -l)
if [ "$GPU_COUNT" -ne 4 ]; then
echo "Error: Expected 4 GPUs but found $GPU_COUNT"
exit 1
fi
curl -L https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh -o install_cuda.sh
chmod +x install_cuda.sh
source install_cuda.sh
install_nvshmem 13 3.4.5

- name: Install uv
uses: astral-sh/setup-uv@v7
@@ -131,7 +145,7 @@ jobs:
- name: Install Helion
run: |
source .venv/bin/activate
uv pip install setuptools
uv pip install setuptools ninja
SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
python -c "import helion; print(helion.__name__)"

@@ -145,7 +159,11 @@
if [[ "${{ contains(matrix.alias, 'cpu') }}" == "true" ]]; then export TRITON_CPU_BACKEND=1; fi
# -rf: print failed tests
# --timeout: max allowed time for each test
pytest -rf --timeout=60
TEST_PATH=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "test/test_examples_dist.py" || echo ".")
EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "-rs" || echo "--ignore=test/test_examples_dist.py")
# For distributed tests, fail if any test is skipped
SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! grep -q SKIPPED" || echo "cat")
pytest -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH | tee >(eval $SKIP_CHECK)

test-notebooks:
name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g
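For reference, the test-selection one-liners in the "Run tests" step above resolve to roughly the following two invocations. This is an illustrative expansion of the workflow's conditionals (with the eval indirection elided), not an additional command in the job:

    # alias contains "distributed" (e.g. h100-distributed): run only the
    # distributed example tests, report skips, and fail if any test was skipped
    pytest -rf --timeout=60 -rs test/test_examples_dist.py | tee >(! grep -q SKIPPED)

    # all other runners: run everything except the distributed tests
    pytest -rf --timeout=60 --ignore=test/test_examples_dist.py . | tee >(cat)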
4 changes: 2 additions & 2 deletions examples/all_gather_matmul.py
@@ -235,8 +235,8 @@ def main() -> None:
if __name__ == "__main__":
"""
Run with:
torchrun \
--nnodes 1 --nproc-per-node 8 \
python -m torch.distributed.run --standalone \
--nproc-per-node 4 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 examples/all_gather_matmul.py
"""
29 changes: 19 additions & 10 deletions examples/all_reduce.py
@@ -23,7 +23,6 @@

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
@@ -252,15 +251,25 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
assert dist_group is not None

world_size = dist.get_world_size()
rank = dist.get_rank()

# Create symmetric memory tensor for Helion implementation
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

run_example(
helion_one_shot_all_reduce,
reference_one_shot_all_reduce,
(a_shared,),
rtol=1e-1,
atol=1e-1,
)
print(f"[Rank {rank}] Running Helion all-reduce...")
result_helion = helion_one_shot_all_reduce(a_shared)

# Create symmetric memory tensor for reference implementation
a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
a_shared_ref.copy_(a_shared)

print(f"[Rank {rank}] Running reference all-reduce...")
result_ref = reference_one_shot_all_reduce(a_shared_ref)

# Compare results
print(f"[Rank {rank}] Comparing results...")
torch.testing.assert_close(result_helion, result_ref, rtol=1e-1, atol=1e-1)
print(f"[Rank {rank}] Results match! ✓")


def main() -> None:
@@ -283,8 +292,8 @@ def main() -> None:
if __name__ == "__main__":
"""
Run with:
torchrun \
--nnodes 1 --nproc-per-node 8 \
python -m torch.distributed.run --standalone \
--nproc-per-node 4 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 examples/all_reduce.py
"""
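Both distributed examples now target a single node with 4 ranks, matching the new h100-distributed runner. A minimal launch sketch, assuming 4 local GPUs (the preflight mirrors the GPU-count check in the workflow's NVSHMEM step; the launch line is the one from the docstring above, and the same form works for examples/all_gather_matmul.py):

    # Preflight: the CI check expects exactly 4 GPUs on the distributed runner
    GPU_COUNT=$(nvidia-smi -L | wc -l)
    if [ "$GPU_COUNT" -ne 4 ]; then
      echo "Expected 4 GPUs but found $GPU_COUNT" >&2
      exit 1
    fi

    # Launch the example on all 4 local ranks
    python -m torch.distributed.run --standalone \
      --nproc-per-node 4 \
      --rdzv-backend c10d --rdzv-endpoint localhost:0 \
      --no_python python3 examples/all_reduce.py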
199 changes: 199 additions & 0 deletions test/test_examples_dist.expected
@@ -0,0 +1,199 @@
This file is automatically generated by assertExpectedJournal calls in test_examples_dist.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestExamplesDist.test_all_gather_matmul)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_helion_matmul_w_progress(progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
num_blocks_0 = tl.cdiv(4096, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
offset_1 = pid_1 * _BLOCK_SIZE_1
# src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
# src[all_gather_matmul.py:N]: tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
floordiv = triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK)
floordiv_1 = triton_helpers.div_floor_integer(offset_0, triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK))
# src[all_gather_matmul.py:N]: hl.wait(
# src[all_gather_matmul.py:N]: progress,
# src[all_gather_matmul.py:N]: [
# src[all_gather_matmul.py:N-N]: ...
helion.runtime.triton_wait_signal(addr=progress + floordiv_1 * 1, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
# src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
for offset_2 in tl.range(0, 16384, _BLOCK_SIZE_2):
acc_copy = acc
acc_copy_0 = acc_copy
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
load = tl.load(tl.make_block_ptr(a, [4096, 16384], [16384, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
load_1 = tl.load(tl.make_block_ptr(b, [16384, 6656], [1, 16384], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [0, 1]), boundary_check=[0, 1], padding_option='zero')
acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
# src[all_gather_matmul.py:N]: out[tile_m, tile_n] = acc
v_0 = tl.cast(acc, tl.bfloat16)
tl.store(tl.make_block_ptr(out, [4096, 6656], [6656, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])

def helion_matmul_w_progress(a: torch.Tensor, a_shared: torch.Tensor, b: torch.Tensor, progress: torch.Tensor, SPLITS_PER_RANK: int, RANK: int, *, _launcher=_default_launcher):
"""
Performs matrix multiplication with progress tracking.
Args:
a (torch.Tensor): First input tensor for matrix multiplication.
a_shared (torch.Tensor): Shared tensor across ranks.
b (torch.Tensor): Second input tensor for matrix multiplication.
progress (torch.Tensor): Tensor used to track progress of the operation.
SPLITS_PER_RANK (int): Number of splits per rank.
RANK (int): Current process rank.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
# src[all_gather_matmul.py:N]: M, K = a.size()
M, K = a.size()
# src[all_gather_matmul.py:N]: K2, N = b.size()
K2, N = b.size()
# src[all_gather_matmul.py:N]: assert K2 == K, f"size mismatch {K2} != {K}"
assert K2 == K, f'size mismatch {K2} != {K}'
# src[all_gather_matmul.py:N]: out = torch.empty(
# src[all_gather_matmul.py:N]: [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
# src[all_gather_matmul.py:N]: )
out = torch.empty([M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device)
# src[all_gather_matmul.py:N]: M_per_rank = a_shared.size(0)
M_per_rank = a_shared.size(0)
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
_BLOCK_SIZE_0 = 128
_BLOCK_SIZE_1 = 256
# src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
# src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
_BLOCK_SIZE_2 = 64
# src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
# src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
# src[all_gather_matmul.py:N]: hl.wait(
# src[all_gather_matmul.py:N-N]: ...
_launcher(_helion_helion_matmul_w_progress, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(6656, _BLOCK_SIZE_1),), progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=3)
# src[all_gather_matmul.py:N]: return out
return out

--- assertExpectedJournal(TestExamplesDist.test_all_reduce)
from __future__ import annotations

import torch
import helion
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, a_shared_tuple_item_3, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < 4096
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
# src[all_reduce.py:N]: ptr_tile = signal_pad_addrs[:]
ptr_tile = tl.load(signal_pad_addrs + indices_1 * 1, None)
# src[all_reduce.py:N]: [tile_n.id, my_rank],
tile_id = offset_0 // _BLOCK_SIZE_0
# src[all_reduce.py:N]: hl.signal(
# src[all_reduce.py:N]: stack_signalpad,
# src[all_reduce.py:N]: [tile_n.id, my_rank],
# src[all_reduce.py:N-N]: ...
helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not True)
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N-N]: ...
for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
# src[all_reduce.py:N]: [tile_n.id, world],
tile_id_1 = offset_0 // _BLOCK_SIZE_0
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N]: [tile_n.id, world],
# src[all_reduce.py:N-N]: ...
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
# src[all_reduce.py:N]: acc = hl.zeros(
# src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
# src[all_reduce.py:N]: )
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
# src[all_reduce.py:N]: acc += a[tile_n]
load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
v_0 = acc + load_1
load_2 = tl.load(a_shared_tuple_item_1 + indices_0 * 1, mask_0, other=0)
v_1 = v_0 + load_2
load_3 = tl.load(a_shared_tuple_item_2 + indices_0 * 1, mask_0, other=0)
v_2 = v_1 + load_3
load_4 = tl.load(a_shared_tuple_item_3 + indices_0 * 1, mask_0, other=0)
v_3 = v_2 + load_4
# src[all_reduce.py:N]: out[tile_n] = acc
tl.store(out + indices_0 * 1, v_3, mask_0)
# src[all_reduce.py:N]: hl.signal(
# src[all_reduce.py:N]: stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
# src[all_reduce.py:N]: )
helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='release', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not False)
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N-N]: ...
for offset_3 in tl.range(0, 4, _BLOCK_SIZE_3):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
# src[all_reduce.py:N]: [tile_n.id, world],
tile_id_2 = offset_0 // _BLOCK_SIZE_0
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N]: [tile_n.id, world],
# src[all_reduce.py:N-N]: ...
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)

def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
"""
Helion JIT-compiled kernel for one-shot all-reduce operation.

This kernel implements a distributed all-reduce using symmetric memory and signal pads
for cross-device synchronization. It performs element-wise summation across all devices
in the distributed group using tiled computation for memory efficiency.

Args:
signal_pad_addrs: Tensor containing addresses of signal pads for all devices
local_signal_pad: Local signal pad for synchronization
a_shared_tuple: Tuple of shared tensors from all devices in the group
my_rank: Current device's rank in the distributed group

Returns:
Tensor containing the all-reduced result (sum across all devices)
"""
# src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
_, world_size = local_signal_pad.size()
# src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
out = torch.empty_like(a_shared_tuple[0])
# src[all_reduce.py:N]: N = out.size(0)
N = out.size(0)
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
_BLOCK_SIZE_0 = 8192
_RDIM_SIZE_1 = 4
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N-N]: ...
_BLOCK_SIZE_2 = 4
# src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
# src[all_reduce.py:N]: hl.wait(
# src[all_reduce.py:N]: local_signal_pad,
# src[all_reduce.py:N-N]: ...
_BLOCK_SIZE_3 = 4
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
# src[all_reduce.py:N]: # Sync all devices through signal_pad to make sure
# src[all_reduce.py:N]: # all previous writes to the shared tensor are visible
# src[all_reduce.py:N-N]: ...
_launcher(_helion_one_shot_all_reduce_kernel, (triton.cdiv(4096, _BLOCK_SIZE_0),), signal_pad_addrs, local_signal_pad, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], a_shared_tuple[3], out, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=32, num_stages=1)
# src[all_reduce.py:N]: return out
return out
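Per the note at the top of test/test_examples_dist.expected, the journals above are regenerated by running the tests with EXPECTTEST_ACCEPT=1 set. A hedged sketch of that invocation, reusing the distributed pytest flags from the workflow and assuming a 4-GPU machine:

    # Re-record the expected journals after an intentional codegen change
    EXPECTTEST_ACCEPT=1 pytest -rf --timeout=60 -rs test/test_examples_dist.py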