
Commit 00d38cc

Add test for payload not in the workspace and fix coderabbit comments

1 parent: 94df845

File tree: 3 files changed (+72, -26 lines)

csrc/trtllm_moe_alltoall.cu (4 additions, 1 deletion)

@@ -127,7 +127,6 @@ Tuple<Array<int64_t>, Array<int64_t>, int64_t> moeA2ADispatchOp(
     TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize,
     int64_t topK, int64_t numExperts) {
   using tl_throughput::PayloadDescriptor;
-  fflush(stdout);
 
   CHECK_INPUT(tokenSelectedExperts);
   CHECK_INPUT_TYPE(tokenSelectedExperts, dl_int32);
@@ -388,6 +387,10 @@ void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, Tenso
       static_cast<int32_t*>(expertIds.data_ptr()), recvCounters,
       static_cast<int32_t>(invalidExpertId), static_cast<int>(epSize),
       static_cast<int>(runtimeMaxTokensPerRank), static_cast<int>(topK), get_current_stream());
+
+  auto err = cudaGetLastError();
+  TVM_FFI_ICHECK(err == cudaSuccess)
+      << "moe_a2a_sanitize_expert_ids launch failed: " << cudaGetErrorString(err);
 }
 
 // Expose metainfo index constants for Python access

flashinfer/comm/trtllm_moe_alltoall.py (35 additions, 7 deletions)

@@ -178,6 +178,18 @@ def moe_a2a_get_workspace_size_per_rank(
     total_dispatch_payload_size_per_token: int,
     combine_payload_size_per_token: int,
 ):
+    """
+    Get the workspace size per rank for the MoeAlltoAll operation.
+
+    Args:
+        ep_size: Total expert parallel size
+        max_num_tokens: Maximum number of tokens across all ranks
+        total_dispatch_payload_size_per_token: The size of the payload per token in the dispatch phase. This should be the sum of the per-token sizes of all payload tensors.
+        combine_payload_size_per_token: The size of the payload per token in the combine phase.
+
+    Returns:
+        workspace_size_per_rank: Size of the workspace per rank in bytes
+    """
     return module.moe_a2a_get_workspace_size_per_rank(
         ep_size,
         max_num_tokens,
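A hedged usage sketch for the sizing helper documented above; the payload widths, dtype, and EP configuration below are made up for illustration:

```python
import torch
from flashinfer.comm import trtllm_moe_alltoall

# Hypothetical layout: each dispatched token carries one bf16 hidden-state
# payload of width 2048, and combine returns a bf16 vector of the same width.
ep_size, max_num_tokens, hidden = 8, 1024, 2048
dispatch_bytes_per_token = hidden * torch.bfloat16.itemsize  # sum over all dispatch payloads
combine_bytes_per_token = hidden * torch.bfloat16.itemsize

size_per_rank = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank(
    ep_size,
    max_num_tokens,
    dispatch_bytes_per_token,
    combine_bytes_per_token,
)
# Per the docstrings in this file, a [ep_size, size_per_rank] workspace tensor
# then backs both the dispatch and combine phases.
```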
@@ -218,15 +230,13 @@ def moe_a2a_wrap_payload_tensor_in_workspace(
 
     Args:
         workspace: [ep_size, size_per_rank] workspace tensor
-        ep_rank: Current expert parallel rank
-        ep_size: Total expert parallel size
-        runtime_max_tokens_per_rank: Max tokens per rank in this batch
-        total_size: Total size of the payload
-        offset: Offset from dispatch
-        dtype: Data type for the tensor
+        leading_shape: The leading shape to wrap the tensor with
+        slice_start: The start of the slice in the workspace
+        slice_end: The end of the slice in the workspace
+        dtype: Data type for the output tensor
 
     Returns:
-        tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor
+        tensor: [leading_shape, *] workspace-backed tensor
     """
     workspace_base = workspace.view(-1).view(dtype=torch.uint8)
     assert slice_end <= workspace.numel(), (
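For readers skimming the diff, here is a self-contained sketch of the byte-slice-and-view mechanics the new docstring describes. It is a plausible reconstruction from the two visible body lines, not the function's actual implementation:

```python
import torch

def wrap_slice(workspace, leading_shape, slice_start, slice_end, dtype):
    # Flatten the workspace and reinterpret it as raw bytes (as in the diff),
    # then carve out [slice_start, slice_end) and view it as `dtype` with the
    # requested leading dimensions; -1 absorbs whatever trails them.
    workspace_base = workspace.view(-1).view(dtype=torch.uint8)
    assert slice_end <= workspace_base.numel(), "slice exceeds workspace"
    payload = workspace_base[slice_start:slice_end].view(dtype=dtype)
    return payload.view(*leading_shape, -1)

# Hypothetical use: carve a [2, 64]-leading bf16 tensor out of a byte buffer.
ws = torch.zeros(1 << 20, dtype=torch.uint8)  # real workspaces live on GPU
t = wrap_slice(ws, [2, 64], 0, 2 * 64 * 8 * torch.bfloat16.itemsize, torch.bfloat16)
assert t.shape == (2, 64, 8)
```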
@@ -249,6 +259,24 @@ def moe_a2a_dispatch(
     top_k: int,
     num_experts: int,
 ):
+    """
+    Dispatch tokens and payloads to expert ranks.
+
+    Args:
+        token_selected_experts: [local_num_tokens, top_k] int32 tensor
+        input_payloads: List of [local_num_tokens, *] tensors to dispatch
+        workspace: [ep_size, size_per_rank] workspace tensor
+        metainfo: Metadata tensor from initialize
+        runtime_max_tokens_per_rank: Max tokens per rank in this batch
+        ep_rank: Current expert parallel rank
+        ep_size: Total expert parallel size
+        top_k: Number of experts per token
+        num_experts: Total number of experts
+
+    Returns:
+        output_payloads: List of payloads for this rank, backed by data in the workspace
+        combine_payload_offset: The offset to place the combine payload in the workspace
+    """
     recv_offsets, recv_sizes, combine_payload_offset = (
         get_mnnvl_moe_alltoall_module().moe_a2a_dispatch(
             token_selected_experts,
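And a matching sketch of a dispatch call, following the argument order listed in the new docstring (the shapes, rank values, and single-payload setup are illustrative, and `workspace` and `metainfo` are assumed to come from the sizing and initialization steps, which are not shown here):

```python
import torch
from flashinfer.comm import trtllm_moe_alltoall

local_num_tokens, top_k, num_experts, ep_size = 16, 2, 64, 8
token_selected_experts = torch.randint(
    0, num_experts, (local_num_tokens, top_k), dtype=torch.int32, device="cuda"
)
hidden_states = torch.randn(
    local_num_tokens, 2048, dtype=torch.bfloat16, device="cuda"
)

# Per the docstring: returns workspace-backed output payloads plus the byte
# offset at which the combine-phase payload should later be placed.
output_payloads, combine_payload_offset = trtllm_moe_alltoall.moe_a2a_dispatch(
    token_selected_experts,
    [hidden_states],   # input_payloads: a single payload tensor here
    workspace,         # [ep_size, size_per_rank], sized as sketched earlier
    metainfo,          # metadata tensor from initialization
    local_num_tokens,  # runtime_max_tokens_per_rank
    0,                 # ep_rank
    ep_size,
    top_k,
    num_experts,
)
```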

tests/comm/test_trtllm_moe_alltoall.py (33 additions, 18 deletions)

@@ -49,12 +49,14 @@ def setup_test_environment():
 ]
 
 COMBINE_PARAMS = [
-    (2, 64, 8, 2, torch.bfloat16),  # Small input, 2 ranks
-    (4, 32, 32768, 4, torch.bfloat16),  # Large input, 4 ranks
-    (8, 16, 2048, 8, torch.bfloat16),  # Medium input, 8 ranks
-    (2, 64, 8, 2, torch.float16),  # Small input, 2 ranks
-    (4, 32, 32768, 4, torch.float16),  # Large input, 4 ranks
-    (8, 16, 2048, 8, torch.float16),  # Medium input, 8 ranks
+    (2, 64, 8, 2, torch.bfloat16, True),  # Small input, 2 ranks
+    (4, 32, 32768, 4, torch.bfloat16, True),  # Large input, 4 ranks
+    (8, 16, 2048, 8, torch.bfloat16, True),  # Medium input, 8 ranks
+    (8, 16, 2048, 8, torch.bfloat16, False),  # Medium input, 8 ranks
+    (2, 64, 8, 2, torch.float16, True),  # Small input, 2 ranks
+    (4, 32, 32768, 4, torch.float16, True),  # Large input, 4 ranks
+    (8, 16, 2048, 8, torch.float16, True),  # Medium input, 8 ranks
+    (8, 16, 2048, 8, torch.float16, False),  # Medium input, 8 ranks
 ]
 
 
@@ -429,9 +431,11 @@ def fake_moe(
     return processed_states.view(target_shape)
 
 
-@pytest.mark.parametrize("world_size,num_tokens,vector_dim,top_k,dtype", COMBINE_PARAMS)
+@pytest.mark.parametrize(
+    "world_size,num_tokens,vector_dim,top_k,dtype,payload_in_workspace", COMBINE_PARAMS
+)
 def test_moe_combine_multi_rank_single_gpu(
-    world_size, num_tokens, vector_dim, top_k, dtype
+    world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace
 ):
     torch.cuda.set_device(0)
     check_sufficient_sm_count(num_tokens, world_size)
@@ -489,16 +493,27 @@ def test_moe_combine_multi_rank_single_gpu(
 
     inplace_combine_tensors = []
     for rank in range(world_size):
-        inplace_combine_tensors.append(
-            trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace(
-                all_workspaces[rank, :],
-                [world_size, num_tokens],
-                combine_payload_offsets[rank],
-                combine_payload_offsets[rank]
-                + world_size * num_tokens * vector_dim * dtype.itemsize,
-                dtype,
+        if payload_in_workspace:
+            inplace_combine_tensors.append(
+                trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace(
+                    all_workspaces[rank, :],
+                    [world_size, num_tokens],
+                    combine_payload_offsets[rank],
+                    combine_payload_offsets[rank]
+                    + world_size * num_tokens * vector_dim * dtype.itemsize,
+                    dtype,
+                )
+            )
+        else:
+            inplace_combine_tensors.append(
+                torch.empty(
+                    world_size,
+                    num_tokens,
+                    vector_dim,
+                    dtype=dtype,
+                    device=torch.device("cuda"),
+                )
             )
-        )
 
     for rank in range(world_size):
         inplace_combine_tensors[rank].copy_(
@@ -520,7 +535,7 @@ def test_moe_combine_multi_rank_single_gpu(
         metainfo,
         world_size,
         combine_payload_offsets,
-        payload_in_workspace=True,
+        payload_in_workspace=payload_in_workspace,
     )
 
     reference_result = fake_moe(
