This file is automatically generated by assertExpectedJournal calls in test_examples_dist.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
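For example (the exact invocation depends on how you run the test suite): EXPECTTEST_ACCEPT=1 python test_examples_dist.py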

--- assertExpectedJournal(TestExamplesDist.test_all_gather_matmul)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_helion_matmul_w_progress(progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    num_blocks_0 = tl.cdiv(4096, _BLOCK_SIZE_0)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    offset_0 = pid_0 * _BLOCK_SIZE_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    # src[all_gather_matmul.py:N]: tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
    floordiv = triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK)
    floordiv_1 = triton_helpers.div_floor_integer(offset_0, triton_helpers.div_floor_integer(1024, SPLITS_PER_RANK))
    # src[all_gather_matmul.py:N]: hl.wait(
    # src[all_gather_matmul.py:N]: progress,
    # src[all_gather_matmul.py:N]: [
    # src[all_gather_matmul.py:N-N]: ...
    helion.runtime.triton_wait_signal(addr=progress + floordiv_1 * 1, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
    # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
    # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
    for offset_2 in tl.range(0, 16384, _BLOCK_SIZE_2):
        acc_copy = acc
        acc_copy_0 = acc_copy
        # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
        load = tl.load(tl.make_block_ptr(a, [4096, 16384], [16384, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
        load_1 = tl.load(tl.make_block_ptr(b, [16384, 6656], [1, 16384], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [0, 1]), boundary_check=[0, 1], padding_option='zero')
        acc = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
    # src[all_gather_matmul.py:N]: out[tile_m, tile_n] = acc
    v_0 = tl.cast(acc, tl.bfloat16)
    tl.store(tl.make_block_ptr(out, [4096, 6656], [6656, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_0, boundary_check=[0, 1])

def helion_matmul_w_progress(a: torch.Tensor, a_shared: torch.Tensor, b: torch.Tensor, progress: torch.Tensor, SPLITS_PER_RANK: int, RANK: int, *, _launcher=_default_launcher):
    """
    Performs matrix multiplication with progress tracking.
    Args:
        a (torch.Tensor): First input tensor for matrix multiplication.
        a_shared (torch.Tensor): Shared tensor across ranks.
        b (torch.Tensor): Second input tensor for matrix multiplication.
        progress (torch.Tensor): Tensor used to track progress of the operation.
        SPLITS_PER_RANK (int): Number of splits per rank.
        RANK (int): Current process rank.
    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
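    # Illustrative usage sketch (editor-added; the call values are assumptions, not part of
    # the generated journal). In this journal M=4096, K=16384, N=6656 and M_per_rank=1024;
    # `progress` is assumed to hold one flag per gathered split (world_size * SPLITS_PER_RANK
    # entries), each set to 1 by the producer once that split of the all-gathered `a` is valid:
    #   out = helion_matmul_w_progress(a, a_shared, b, progress, SPLITS_PER_RANK=4, RANK=rank)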
    # src[all_gather_matmul.py:N]: M, K = a.size()
    M, K = a.size()
    # src[all_gather_matmul.py:N]: K2, N = b.size()
    K2, N = b.size()
    # src[all_gather_matmul.py:N]: assert K2 == K, f"size mismatch {K2} != {K}"
    assert K2 == K, f'size mismatch {K2} != {K}'
    # src[all_gather_matmul.py:N]: out = torch.empty(
    # src[all_gather_matmul.py:N]: [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
    # src[all_gather_matmul.py:N]: )
    out = torch.empty([M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device)
    # src[all_gather_matmul.py:N]: M_per_rank = a_shared.size(0)
    M_per_rank = a_shared.size(0)
    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    _BLOCK_SIZE_0 = 128
    _BLOCK_SIZE_1 = 256
    # src[all_gather_matmul.py:N]: for tile_k in hl.tile(K):
    # src[all_gather_matmul.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
    _BLOCK_SIZE_2 = 64
    # src[all_gather_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    # src[all_gather_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    # src[all_gather_matmul.py:N]: hl.wait(
    # src[all_gather_matmul.py:N-N]: ...
    _launcher(_helion_helion_matmul_w_progress, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(6656, _BLOCK_SIZE_1),), progress, a, b, out, SPLITS_PER_RANK, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=3)
    # src[all_gather_matmul.py:N]: return out
    return out

--- assertExpectedJournal(TestExamplesDist.test_all_reduce)
from __future__ import annotations

import torch
import helion
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, a_shared_tuple_item_3, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < 4096
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    # src[all_reduce.py:N]: ptr_tile = signal_pad_addrs[:]
    ptr_tile = tl.load(signal_pad_addrs + indices_1 * 1, None)
    # src[all_reduce.py:N]: [tile_n.id, my_rank],
    tile_id = offset_0 // _BLOCK_SIZE_0
    # src[all_reduce.py:N]: hl.signal(
    # src[all_reduce.py:N]: stack_signalpad,
    # src[all_reduce.py:N]: [tile_n.id, my_rank],
    # src[all_reduce.py:N-N]: ...
    helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not True)
    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
    # src[all_reduce.py:N]: hl.wait(
    # src[all_reduce.py:N]: local_signal_pad,
    # src[all_reduce.py:N-N]: ...
    for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        # src[all_reduce.py:N]: [tile_n.id, world],
        tile_id_1 = offset_0 // _BLOCK_SIZE_0
        # src[all_reduce.py:N]: hl.wait(
        # src[all_reduce.py:N]: local_signal_pad,
        # src[all_reduce.py:N]: [tile_n.id, world],
        # src[all_reduce.py:N-N]: ...
        helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
    # src[all_reduce.py:N]: acc = hl.zeros(
    # src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
    # src[all_reduce.py:N]: )
    acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
    # src[all_reduce.py:N]: acc += a[tile_n]
    load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
    v_0 = acc + load_1
    load_2 = tl.load(a_shared_tuple_item_1 + indices_0 * 1, mask_0, other=0)
    v_1 = v_0 + load_2
    load_3 = tl.load(a_shared_tuple_item_2 + indices_0 * 1, mask_0, other=0)
    v_2 = v_1 + load_3
    load_4 = tl.load(a_shared_tuple_item_3 + indices_0 * 1, mask_0, other=0)
    v_3 = v_2 + load_4
    # src[all_reduce.py:N]: out[tile_n] = acc
    tl.store(out + indices_0 * 1, v_3, mask_0)
    # src[all_reduce.py:N]: hl.signal(
    # src[all_reduce.py:N]: stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys"
    # src[all_reduce.py:N]: )
    helion.runtime.triton_wait_multiple_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (tile_id * 4 + 0 * 1)[None], expect=0, update=1, sem='release', scope='sys', op='atomic_cas', skip_sync=True, sync_before=not False)
    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
    # src[all_reduce.py:N]: hl.wait(
    # src[all_reduce.py:N]: local_signal_pad,
    # src[all_reduce.py:N-N]: ...
    for offset_3 in tl.range(0, 4, _BLOCK_SIZE_3):
        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
        # src[all_reduce.py:N]: [tile_n.id, world],
        tile_id_2 = offset_0 // _BLOCK_SIZE_0
        # src[all_reduce.py:N]: hl.wait(
        # src[all_reduce.py:N]: local_signal_pad,
        # src[all_reduce.py:N]: [tile_n.id, world],
        # src[all_reduce.py:N-N]: ...
        helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)

def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
    """
    Helion JIT-compiled kernel for one-shot all-reduce operation.

    This kernel implements a distributed all-reduce using symmetric memory and signal pads
    for cross-device synchronization. It performs element-wise summation across all devices
    in the distributed group using tiled computation for memory efficiency.

    Args:
        signal_pad_addrs: Tensor containing addresses of signal pads for all devices
        local_signal_pad: Local signal pad for synchronization
        a_shared_tuple: Tuple of shared tensors from all devices in the group
        my_rank: Current device's rank in the distributed group

    Returns:
        Tensor containing the all-reduced result (sum across all devices)
    """
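    # Illustrative usage sketch (editor-added; shapes are read off the generated code, the
    # call itself is an assumption). Here world_size=4, each tensor in a_shared_tuple holds
    # 4096 bf16 elements, local_signal_pad is a [num_tiles, world_size] 32-bit signal pad,
    # and signal_pad_addrs holds one device pointer to each rank's signal pad:
    #   out = one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad,
    #                                    a_shared_tuple, my_rank=rank)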
    # src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
    _, world_size = local_signal_pad.size()
    # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
    out = torch.empty_like(a_shared_tuple[0])
    # src[all_reduce.py:N]: N = out.size(0)
    N = out.size(0)
    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
    _BLOCK_SIZE_0 = 8192
    _RDIM_SIZE_1 = 4
    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
    # src[all_reduce.py:N]: hl.wait(
    # src[all_reduce.py:N]: local_signal_pad,
    # src[all_reduce.py:N-N]: ...
    _BLOCK_SIZE_2 = 4
    # src[all_reduce.py:N]: for world in hl.tile(world_size, block_size=world_size):
    # src[all_reduce.py:N]: hl.wait(
    # src[all_reduce.py:N]: local_signal_pad,
    # src[all_reduce.py:N-N]: ...
    _BLOCK_SIZE_3 = 4
    # src[all_reduce.py:N]: for tile_n in hl.tile(N):
    # src[all_reduce.py:N]: # Sync all devices through signal_pad to make sure
    # src[all_reduce.py:N]: # all previous writes to the shared tensor are visible
    # src[all_reduce.py:N-N]: ...
    _launcher(_helion_one_shot_all_reduce_kernel, (triton.cdiv(4096, _BLOCK_SIZE_0),), signal_pad_addrs, local_signal_pad, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], a_shared_tuple[3], out, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=32, num_stages=1)
    # src[all_reduce.py:N]: return out
    return out