diff --git a/.github/matrix.json b/.github/matrix.json index 733fddaf2..663f8ad4b 100644 --- a/.github/matrix.json +++ b/.github/matrix.json @@ -52,6 +52,16 @@ "pytorch-version": "pytorch-nightly", "alias": "h100" }, + { + "runner": "linux.aws.h100.4", + "python-version": "3.12", + "ref-eager": false, + "image": "nvidia/cuda:12.8.1-devel-ubuntu24.04", + "runtime-version": "cu128", + "container-options": "--gpus all", + "pytorch-version": "pytorch-nightly", + "alias": "h100-distributed" + }, { "runner": "linux.dgx.b200", "python-version": "3.12", diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a568f259..d3652a265 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,7 +66,21 @@ jobs: run: | set -eux apt-get update - apt-get install -y libdw1 + apt-get install -y libdw1 curl wget git pkg-config zlib1g-dev build-essential + + - name: Install NVSHMEM + if: contains(matrix.alias, 'distributed') + run: | + set -euxo pipefail + GPU_COUNT=$(nvidia-smi -L | wc -l) + if [ "$GPU_COUNT" -ne 4 ]; then + echo "Error: Expected 4 GPUs but found $GPU_COUNT" + exit 1 + fi + curl -L https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh -o install_cuda.sh + chmod +x install_cuda.sh + source install_cuda.sh + install_nvshmem 13 3.4.5 - name: Install uv uses: astral-sh/setup-uv@v7 @@ -131,7 +145,7 @@ jobs: - name: Install Helion run: | source .venv/bin/activate - uv pip install setuptools + uv pip install setuptools ninja SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]' python -c "import helion; print(helion.__name__)" @@ -145,7 +159,11 @@ jobs: if [[ "${{ contains(matrix.alias, 'cpu') }}" == "true" ]]; then export TRITON_CPU_BACKEND=1; fi # -rf: print failed tests # --timeout: max allowed time for each test - pytest -rf --timeout=60 + TEST_PATH=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "test/test_examples_dist.py" || echo ".") + EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "-rs" || echo "--ignore=test/test_examples_dist.py") + # For distributed tests, fail if any test is skipped + SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! 
grep -q SKIPPED" || echo "cat") + pytest -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH | tee >(eval $SKIP_CHECK) test-notebooks: name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index 1b9d11ee9..001d67a3c 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -235,8 +235,8 @@ def main() -> None: if __name__ == "__main__": """ Run with: - torchrun \ - --nnodes 1 --nproc-per-node 8 \ + python -m torch.distributed.run --standalone \ + --nproc-per-node 4 \ --rdzv-backend c10d --rdzv-endpoint localhost:0 \ --no_python python3 examples/all_gather_matmul.py """ diff --git a/examples/all_reduce.py b/examples/all_reduce.py index e4399440c..17616c275 100644 --- a/examples/all_reduce.py +++ b/examples/all_reduce.py @@ -23,7 +23,6 @@ import helion from helion._testing import DEVICE -from helion._testing import run_example import helion.language as hl # %% @@ -252,15 +251,25 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None: assert dist_group is not None world_size = dist.get_world_size() + rank = dist.get_rank() + + # Create symmetric memory tensor for Helion implementation a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_() - run_example( - helion_one_shot_all_reduce, - reference_one_shot_all_reduce, - (a_shared,), - rtol=1e-1, - atol=1e-1, - ) + print(f"[Rank {rank}] Running Helion all-reduce...") + result_helion = helion_one_shot_all_reduce(a_shared) + + # Create symmetric memory tensor for reference implementation + a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device) + a_shared_ref.copy_(a_shared) + + print(f"[Rank {rank}] Running reference all-reduce...") + result_ref = reference_one_shot_all_reduce(a_shared_ref) + + # Compare results + print(f"[Rank {rank}] Comparing results...") + torch.testing.assert_close(result_helion, result_ref, rtol=1e-1, atol=1e-1) + print(f"[Rank {rank}] Results match! ✓") def main() -> None: @@ -283,8 +292,8 @@ def main() -> None: if __name__ == "__main__": """ Run with: - torchrun \ - --nnodes 1 --nproc-per-node 8 \ + python -m torch.distributed.run --standalone \ + --nproc-per-node 4 \ --rdzv-backend c10d --rdzv-endpoint localhost:0 \ --no_python python3 examples/all_reduce.py """ diff --git a/test/test_examples_dist.expected b/test/test_examples_dist.expected new file mode 100644 index 000000000..dadbf473a --- /dev/null +++ b/test/test_examples_dist.expected @@ -0,0 +1,199 @@ +This file is automatically generated by assertExpectedJournal calls in test_examples_dist.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
+ +--- assertExpectedJournal(TestExamplesDist.test_all_gather_matmul) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_helion_matmul_w_progress(progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]): + num_blocks_0 = tl.cdiv(4096, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[all_gather_matmul.py:N]: tile_m.begin // (M_per_rank // SPLITS_PER_RANK), + floordiv = triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK) + floordiv_1 = triton_helpers.div_floor_integer(offset_0, triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK)) + # src[all_gather_matmul.py:N]: hl.wait( + # src[all_gather_matmul.py:N]: progress, + # src[all_gather_matmul.py:N]: [ + # src[all_gather_matmul.py:N-N]: ... + helion.runtime.triton_wait_signal(addr=progress + floordiv_1 * 1, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) + # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K): + # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n]) + for offset_2 in tl.range(0, 16384, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n]) + load = tl.load(tl.make_block_ptr(a, [4096, 16384], [16384, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') + load_1 = tl.load(tl.make_block_ptr(b, [16384, 6656], [1, 16384], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [0, 1]), boundary_check=[0, 1], padding_option='zero') + acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[all_gather_matmul.py:N]: out[tile_m, tile_n] = acc + v_0 = tl.cast(acc, tl.bfloat16) + tl.store(tl.make_block_ptr(out, [4096, 6656], [6656, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1]) + +def helion_matmul_w_progress(a: torch.Tensor, a_shared: torch.Tensor, b: torch.Tensor, progress: torch.Tensor, SPLITS_PER_RANK: int, RANK: int, *, _launcher=_default_launcher): + """ + Performs matrix multiplication with progress tracking. + Args: + a (torch.Tensor): First input tensor for matrix multiplication. + a_shared (torch.Tensor): Shared tensor across ranks. + b (torch.Tensor): Second input tensor for matrix multiplication. + progress (torch.Tensor): Tensor used to track progress of the operation. + SPLITS_PER_RANK (int): Number of splits per rank. + RANK (int): Current process rank. + Returns: + torch.Tensor: The result of the matrix multiplication. 
+ """ + # src[all_gather_matmul.py:N]: M, K = a.size() + M, K = a.size() + # src[all_gather_matmul.py:N]: K2, N = b.size() + K2, N = b.size() + # src[all_gather_matmul.py:N]: assert K2 == K, f"size mismatch {K2} != {K}" + assert K2 == K, f'size mismatch {K2} != {K}' + # src[all_gather_matmul.py:N]: out = torch.empty( + # src[all_gather_matmul.py:N]: [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device + # src[all_gather_matmul.py:N]: ) + out = torch.empty([M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device) + # src[all_gather_matmul.py:N]: M_per_rank = a_shared.size(0) + M_per_rank = a_shared.size(0) + # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]): + _BLOCK_SIZE_0 = 128 + _BLOCK_SIZE_1 = 256 + # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K): + # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n]) + _BLOCK_SIZE_2 = 64 + # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]): + # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[all_gather_matmul.py:N]: hl.wait( + # src[all_gather_matmul.py:N-N]: ... + _launcher(_helion_helion_matmul_w_progress, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(6656, _BLOCK_SIZE_1),), progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=3) + # src[all_gather_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamplesDist.test_all_reduce) +from __future__ import annotations + +import torch +import helion +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, a_shared_tuple_item_3, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + # src[all_reduce.py:N]: for tile_n in hl.tile(N): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 4096 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + # src[all_reduce.py:N]: ptr_tile = signal_pad_addrs[:] + ptr_tile = tl.load(signal_pad_addrs + indices_1 * 1, None) + # src[all_reduce.py:N]: [tile_n.id, my_rank], + tile_id = offset_0 // _BLOCK_SIZE_0 + # src[all_reduce.py:N]: hl.signal( + # src[all_reduce.py:N]: stack_signalpad, + # src[all_reduce.py:N]: [tile_n.id, my_rank], + # src[all_reduce.py:N-N]: ... + helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not True) + # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size): + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N-N]: ... + for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + # src[all_reduce.py:N]: [tile_n.id, world], + tile_id_1 = offset_0 // _BLOCK_SIZE_0 + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N]: [tile_n.id, world], + # src[all_reduce.py:N-N]: ... 
+ helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False) + # src[all_reduce.py:N]: acc = hl.zeros( + # src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device + # src[all_reduce.py:N]: ) + acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16) + # src[all_reduce.py:N]: acc += a[tile_n] + load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0) + v_0 = acc + load_1 + load_2 = tl.load(a_shared_tuple_item_1 + indices_0 * 1, mask_0, other=0) + v_1 = v_0 + load_2 + load_3 = tl.load(a_shared_tuple_item_2 + indices_0 * 1, mask_0, other=0) + v_2 = v_1 + load_3 + load_4 = tl.load(a_shared_tuple_item_3 + indices_0 * 1, mask_0, other=0) + v_3 = v_2 + load_4 + # src[all_reduce.py:N]: out[tile_n] = acc + tl.store(out + indices_0 * 1, v_3, mask_0) + # src[all_reduce.py:N]: hl.signal( + # src[all_reduce.py:N]: stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys" + # src[all_reduce.py:N]: ) + helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='release', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not False) + # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size): + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N-N]: ... + for offset_3 in tl.range(0, 4, _BLOCK_SIZE_3): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + # src[all_reduce.py:N]: [tile_n.id, world], + tile_id_2 = offset_0 // _BLOCK_SIZE_0 + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N]: [tile_n.id, world], + # src[all_reduce.py:N-N]: ... + helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True) + +def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher): + """ + Helion JIT-compiled kernel for one-shot all-reduce operation. + + This kernel implements a distributed all-reduce using symmetric memory and signal pads + for cross-device synchronization. It performs element-wise summation across all devices + in the distributed group using tiled computation for memory efficiency. + + Args: + signal_pad_addrs: Tensor containing addresses of signal pads for all devices + local_signal_pad: Local signal pad for synchronization + a_shared_tuple: Tuple of shared tensors from all devices in the group + my_rank: Current device's rank in the distributed group + + Returns: + Tensor containing the all-reduced result (sum across all devices) + """ + # src[all_reduce.py:N]: _, world_size = local_signal_pad.size() + _, world_size = local_signal_pad.size() + # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0]) + out = torch.empty_like(a_shared_tuple[0]) + # src[all_reduce.py:N]: N = out.size(0) + N = out.size(0) + # src[all_reduce.py:N]: for tile_n in hl.tile(N): + _BLOCK_SIZE_0 = 8192 + _RDIM_SIZE_1 = 4 + # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size): + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N-N]: ... 
+ _BLOCK_SIZE_2 = 4 + # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size): + # src[all_reduce.py:N]: hl.wait( + # src[all_reduce.py:N]: local_signal_pad, + # src[all_reduce.py:N-N]: ... + _BLOCK_SIZE_3 = 4 + # src[all_reduce.py:N]: for tile_n in hl.tile(N): + # src[all_reduce.py:N]: # Sync all devices through signal_pad to make sure + # src[all_reduce.py:N]: # all previous writes to the shared tensor are visible + # src[all_reduce.py:N-N]: ... + _launcher(_helion_one_shot_all_reduce_kernel, (triton.cdiv(4096, _BLOCK_SIZE_0),), signal_pad_addrs, local_signal_pad, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], a_shared_tuple[3], out, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=32, num_stages=1) + # src[all_reduce.py:N]: return out + return out diff --git a/test/test_examples_dist.py b/test/test_examples_dist.py new file mode 100644 index 000000000..86acc5194 --- /dev/null +++ b/test/test_examples_dist.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing._internal.common_utils import run_tests + +from helion._testing import EXAMPLES_DIR +from helion._testing import TestCase +from helion._testing import code_and_output +from helion._testing import import_path + + +@instantiate_parametrized_tests +class TestExamplesDist(TestCase, MultiProcessTestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 4 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + torch.manual_seed(42 + self.rank) + + @skip_if_lt_x_gpu(4) + def test_all_gather_matmul(self): + self._init_process() + + mod = import_path(EXAMPLES_DIR / "all_gather_matmul.py") + + M, N, K = 4096, 6656, 16384 + + a_shared = symm_mem.empty( + M // self.world_size, K, dtype=torch.bfloat16, device=self.device + ).normal_() + + b = ( + torch.randn((K, N), device=self.device, dtype=torch.bfloat16) + .T.contiguous() + .T + ) + + symm_mem_group = dist.group.WORLD + if symm_mem_group is None: + raise RuntimeError("No symmetric memory group available") + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group) + a_shape = list(a_shared.shape) + a_shape[0] *= symm_mem_hdl.world_size + a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device) + progress = torch.zeros( + symm_mem_hdl.world_size, + dtype=torch.uint32, + device=a_shared.device, + ) + backend_stream = mod.copy_engine_all_gather_w_progress( + a_out, a_shared, progress, 1 + ) + + code, result = code_and_output( + mod.helion_matmul_w_progress, + (a_out, a_shared, b, progress, 1, symm_mem_hdl.rank), + ) + + if self.rank == 0: + if not hasattr(self.__class__, "_expected_journal"): + from helion._testing import AssertExpectedJournal + + self.__class__._expected_journal = AssertExpectedJournal(self.__class__) + 
self.assertExpectedJournal(code) + + golden_a = a_shared.clone() + ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( + golden_a, [b], gather_dim=0, group_name=symm_mem_group.group_name + ) + + torch.testing.assert_close(result, mm_golden[0], rtol=1e-1, atol=1e-1) + torch.testing.assert_close(a_out, ag_golden) + + torch.cuda.current_stream().wait_stream(backend_stream) + dist.destroy_process_group() + + @skip_if_lt_x_gpu(4) + def test_all_reduce(self): + self._init_process() + + mod = import_path(EXAMPLES_DIR / "all_reduce.py") + + N = 16384 + dtype = torch.bfloat16 + + a_shared = symm_mem.empty( + N // self.world_size, dtype=dtype, device=self.device + ).normal_() + + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) + a_shared_tuple = tuple( + [ + symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) + for i in range(symm_mem_hdl.world_size) + ] + ) + local_signal_pad = symm_mem_hdl.get_signal_pad( + symm_mem_hdl.rank, dtype=torch.int32 + ).view(-1, symm_mem_hdl.world_size) + signal_pad_addrs = mod.dev_array_to_tensor_short( + symm_mem_hdl.signal_pad_ptrs_dev, + (symm_mem_hdl.world_size,), + dtype=torch.uint64, + device=a_shared.device, + ) + + code, result = code_and_output( + mod.one_shot_all_reduce_kernel, + (signal_pad_addrs, local_signal_pad, a_shared_tuple, symm_mem_hdl.rank), + ) + + if self.rank == 0: + if not hasattr(self.__class__, "_expected_journal"): + from helion._testing import AssertExpectedJournal + + self.__class__._expected_journal = AssertExpectedJournal(self.__class__) + self.assertExpectedJournal(code) + + a_shared_ref = symm_mem.empty( + N // self.world_size, dtype=dtype, device=self.device + ) + a_shared_ref.copy_(a_shared) + expected = mod.reference_one_shot_all_reduce(a_shared_ref) + + torch.testing.assert_close(result, expected, rtol=1e-1, atol=1e-1) + + dist.destroy_process_group() + + +if __name__ == "__main__": + run_tests()
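
Note: a minimal sketch of how to exercise the new distributed tests locally, assuming a host with at least 4 GPUs and the same NVSHMEM-enabled PyTorch nightly used by the `h100-distributed` matrix entry above; the paths, flags, and environment variable below are taken from this diff, nothing else is assumed:

    # Distributed example tests only; MultiProcessTestCase spawns world_size=4 ranks,
    # so at least 4 GPUs are required (skip_if_lt_x_gpu(4)). CI runs the equivalent
    # command and additionally fails the job if any of these tests report SKIPPED.
    pytest -rf -rs --timeout=60 test/test_examples_dist.py

    # Refresh test/test_examples_dist.expected after intentional codegen changes,
    # as described in the header of that file.
    EXPECTTEST_ACCEPT=1 pytest test/test_examples_dist.py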