@@ -49,12 +49,14 @@ def setup_test_environment():
 ]
 
 COMBINE_PARAMS = [
-    (2, 64, 8, 2, torch.bfloat16),  # Small input, 2 ranks
-    (4, 32, 32768, 4, torch.bfloat16),  # Large input, 4 ranks
-    (8, 16, 2048, 8, torch.bfloat16),  # Medium input, 8 ranks
-    (2, 64, 8, 2, torch.float16),  # Small input, 2 ranks
-    (4, 32, 32768, 4, torch.float16),  # Large input, 4 ranks
-    (8, 16, 2048, 8, torch.float16),  # Medium input, 8 ranks
+    (2, 64, 8, 2, torch.bfloat16, True),  # Small input, 2 ranks
+    (4, 32, 32768, 4, torch.bfloat16, True),  # Large input, 4 ranks
+    (8, 16, 2048, 8, torch.bfloat16, True),  # Medium input, 8 ranks
+    (8, 16, 2048, 8, torch.bfloat16, False),  # Medium input, 8 ranks
+    (2, 64, 8, 2, torch.float16, True),  # Small input, 2 ranks
+    (4, 32, 32768, 4, torch.float16, True),  # Large input, 4 ranks
+    (8, 16, 2048, 8, torch.float16, True),  # Medium input, 8 ranks
+    (8, 16, 2048, 8, torch.float16, False),  # Medium input, 8 ranks
 ]
 
 
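Each COMBINE_PARAMS tuple gains a sixth boolean, payload_in_workspace, which the parametrized test in the next hunk receives as its last argument. A minimal standalone sketch of the scheme (the argument names match the decorator below; the reduced PARAMS list and test_unpacking function here are illustrative only):

    import pytest
    import torch

    PARAMS = [
        (2, 64, 8, 2, torch.bfloat16, True),
        (2, 64, 8, 2, torch.bfloat16, False),
    ]

    @pytest.mark.parametrize(
        "world_size,num_tokens,vector_dim,top_k,dtype,payload_in_workspace", PARAMS
    )
    def test_unpacking(
        world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace
    ):
        # pytest splits the comma-joined name string and binds each tuple
        # element positionally, so the boolean arrives as the final argument
        # and yields one test case per tuple.
        assert isinstance(payload_in_workspace, bool)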
@@ -429,9 +431,11 @@ def fake_moe(
     return processed_states.view(target_shape)
 
 
-@pytest.mark.parametrize("world_size,num_tokens,vector_dim,top_k,dtype", COMBINE_PARAMS)
+@pytest.mark.parametrize(
+    "world_size,num_tokens,vector_dim,top_k,dtype,payload_in_workspace", COMBINE_PARAMS
+)
 def test_moe_combine_multi_rank_single_gpu(
-    world_size, num_tokens, vector_dim, top_k, dtype
+    world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace
 ):
     torch.cuda.set_device(0)
     check_sufficient_sm_count(num_tokens, world_size)
@@ -489,16 +493,27 @@ def test_moe_combine_multi_rank_single_gpu(
 
     inplace_combine_tensors = []
     for rank in range(world_size):
-        inplace_combine_tensors.append(
-            trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace(
-                all_workspaces[rank, :],
-                [world_size, num_tokens],
-                combine_payload_offsets[rank],
-                combine_payload_offsets[rank]
-                + world_size * num_tokens * vector_dim * dtype.itemsize,
-                dtype,
+        if payload_in_workspace:
+            inplace_combine_tensors.append(
+                trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace(
+                    all_workspaces[rank, :],
+                    [world_size, num_tokens],
+                    combine_payload_offsets[rank],
+                    combine_payload_offsets[rank]
+                    + world_size * num_tokens * vector_dim * dtype.itemsize,
+                    dtype,
+                )
+            )
+        else:
+            inplace_combine_tensors.append(
+                torch.empty(
+                    world_size,
+                    num_tokens,
+                    vector_dim,
+                    dtype=dtype,
+                    device=torch.device("cuda"),
+                )
             )
-        )
 
     for rank in range(world_size):
         inplace_combine_tensors[rank].copy_(
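With payload_in_workspace=False the test now allocates a free-standing CUDA tensor, while the True path views the byte range [offset, offset + world_size * num_tokens * vector_dim * dtype.itemsize) of the shared workspace. A hedged plain-torch sketch of what such a workspace wrap amounts to (wrap_payload is a hypothetical stand-in, not the actual moe_a2a_wrap_payload_tensor_in_workspace implementation):

    import torch

    def wrap_payload(workspace, leading_shape, start, end, dtype):
        # Assumed behavior: reinterpret the byte range [start, end) of a flat
        # uint8 workspace as `leading_shape + [-1]` elements of `dtype`.
        # No copy is made, so writes through the returned view land directly
        # in the workspace buffer.
        span = workspace[start:end]
        return span.view(dtype).view(*leading_shape, -1)

    ws = torch.zeros(1 << 20, dtype=torch.uint8, device="cuda")
    nbytes = 2 * 64 * 8 * torch.bfloat16.itemsize
    t = wrap_payload(ws, [2, 64], 0, nbytes, torch.bfloat16)
    assert t.shape == (2, 64, 8) and t.data_ptr() == ws.data_ptr()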
@@ -520,7 +535,7 @@ def test_moe_combine_multi_rank_single_gpu(
         metainfo,
         world_size,
         combine_payload_offsets,
-        payload_in_workspace=True,
+        payload_in_workspace=payload_in_workspace,
     )
 
     reference_result = fake_moe(