@@ -485,6 +485,144 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
     )
 
 
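+# Smoke test: FusedMoE built with has_bias=True should accept per-expert
+# biases and run a forward pass on both a 1-device and an all-device mesh.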
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
+                        num_experts, topk):
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    w1_bias = torch.randn(
+        (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+    w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            has_bias=True,
+        )
+    vllm_fused_moe.w13_weight.data = w1
+    vllm_fused_moe.w2_weight.data = w2
+    vllm_fused_moe.w13_bias.data = w1_bias
+    vllm_fused_moe.w2_bias.data = w2_bias
+
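+    # Replicate the activations and router scores across the mesh
+    # (P(None, None)) so the layer can consume them under torchax.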
+    jax_a = torch_view(t2j(a, use_dlpack=False))
+    jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+    score = torch_view(t2j(score))
+    score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+
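+    # process_weights_after_loading() finalizes the loaded weights for the
+    # unquantized MoE path (including placing them onto the mesh); the
+    # forward call checks that the biased MoE path runs end to end.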
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method,
+                          VllmUnquantizedFusedMoEMethod)
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe(jax_a, score)
+
+
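+# Same smoke test, but exercising the supported activation functions
+# ("silu" and "swigluoai") instead of expert biases.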
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("activation", ["silu", "swigluoai"])
+def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
+                              num_experts, topk, activation):
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            activation=activation,
+        )
+    vllm_fused_moe.w13_weight.data = w1
+    vllm_fused_moe.w2_weight.data = w2
+
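+    # As above: replicate the inputs across the mesh for the torchax run.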
+    jax_a = torch_view(t2j(a, use_dlpack=False))
+    jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+    score = torch_view(t2j(score))
+    score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
+
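+    # Prepare the weights, then run forward with the selected activation.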
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method,
+                          VllmUnquantizedFusedMoEMethod)
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe(jax_a, score)
+
+
 @pytest.mark.parametrize("use_ep", [True])
 @pytest.mark.parametrize("mesh",
                          [test_utils.get_spmd_mesh(jax.local_device_count())])