@@ -754,14 +754,6 @@ def setUp(self):
754754
755755 self .hidden_states = torch .randn (self .num_tokens , self .hidden_size )
756756 self .router_logits = torch .randn (self .num_tokens , self .num_experts )
757- """Mock custom routing"""
758- self .mock_custom_routing = MagicMock ()
759- self .mock_custom_routing .return_value = (torch .ones (
760- self .num_tokens , self .top_k ),
761- torch .zeros (
762- self .num_tokens ,
763- self .top_k ,
764- dtype = torch .int32 ))
765757
766758 self .mock_ctx = MagicMock ()
767759 self .mock_ctx .weight_prefetch_method = MagicMock ()
@@ -771,7 +763,7 @@ def setUp(self):
771763 self .addCleanup (patcher .stop )
772764 patcher .start ()
773765
774- @patch ('torch_npu.npu_moe_gating_top_k ' )
766+ @patch ('torch_npu.npu_moe_gating_top_k_softmax ' )
775767 def test_softmax_scoring (self , mock_topk ):
776768 """Test softmax scoring function"""
777769 mock_topk .return_value = (torch .ones (self .num_tokens , self .top_k ),
@@ -798,14 +790,12 @@ def test_softmax_scoring(self, mock_topk):
798790 def test_sigmoid_scoring (self ):
799791 """Test sigmoid scoring function"""
800792
801- weights , ids = select_experts (
802- hidden_states = self .hidden_states ,
803- router_logits = self .router_logits ,
804- top_k = self .top_k ,
805- use_grouped_topk = False ,
806- renormalize = False ,
807- scoring_func = "sigmoid" ,
808- custom_routing_function = self .mock_custom_routing )
793+ weights , ids = select_experts (hidden_states = self .hidden_states ,
794+ router_logits = self .router_logits ,
795+ top_k = self .top_k ,
796+ use_grouped_topk = False ,
797+ renormalize = False ,
798+ scoring_func = "sigmoid" )
809799
810800 self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
811801 self .assertEqual (ids .shape , (self .num_tokens , self .top_k ))
@@ -818,8 +808,7 @@ def test_invalid_scoring_func(self):
818808 top_k = self .top_k ,
819809 use_grouped_topk = False ,
820810 renormalize = False ,
821- scoring_func = "invalid_func" ,
822- custom_routing_function = self .mock_custom_routing )
811+ scoring_func = "invalid_func" )
823812
824813 @patch ('torch.topk' )
825814 def test_grouped_topk (self , mock_topk ):
@@ -829,15 +818,13 @@ def test_grouped_topk(self, mock_topk):
829818 self .top_k ,
830819 dtype = torch .long ))
831820
832- weights , ids = select_experts (
833- hidden_states = self .hidden_states ,
834- router_logits = self .router_logits ,
835- top_k = self .top_k ,
836- use_grouped_topk = True ,
837- renormalize = False ,
838- topk_group = 4 ,
839- num_expert_group = 2 ,
840- custom_routing_function = self .mock_custom_routing )
821+ weights , ids = select_experts (hidden_states = self .hidden_states ,
822+ router_logits = self .router_logits ,
823+ top_k = self .top_k ,
824+ use_grouped_topk = True ,
825+ renormalize = False ,
826+ topk_group = 4 ,
827+ num_expert_group = 2 )
841828
842829 mock_topk .assert_called ()
843830 self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
@@ -859,29 +846,35 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
859846 renormalize = False ,
860847 topk_group = 4 ,
861848 num_expert_group = 2 ,
862- e_score_correction_bias = e_score_correction_bias ,
863- custom_routing_function = self .mock_custom_routing )
849+ e_score_correction_bias = e_score_correction_bias )
864850
865851 mock_grouped_topk .assert_called_once ()
866852 self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
867853 self .assertEqual (ids .shape , (self .num_tokens , self .top_k ))
868854
def test_custom_routing_function(self):
    """A caller-supplied routing function is used and its outputs propagated.

    The mock stands in for a custom router; select_experts must invoke it
    exactly once and hand back tensors with the expected shape and dtype.
    """
    routing_fn = MagicMock()
    stub_weights = torch.ones(self.num_tokens, self.top_k)
    stub_ids = torch.zeros(self.num_tokens, self.top_k, dtype=torch.int32)
    routing_fn.return_value = (stub_weights, stub_ids)

    weights, ids = select_experts(hidden_states=self.hidden_states,
                                  router_logits=self.router_logits,
                                  top_k=self.top_k,
                                  use_grouped_topk=False,
                                  renormalize=False,
                                  custom_routing_function=routing_fn)

    # The custom router must be the single routing path taken.
    routing_fn.assert_called_once()
    self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
    self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
    self.assertEqual(ids.dtype, torch.int32)
883876
884- @patch ('torch_npu.npu_moe_gating_top_k ' )
877+ @patch ('torch_npu.npu_moe_gating_top_k_softmax ' )
885878 def test_renormalize (self , mock_topk ):
886879 """Test renormalization"""
887880 mock_topk .return_value = (torch .ones (self .num_tokens , self .top_k ),
@@ -907,13 +900,13 @@ def test_renormalize(self, mock_topk):
907900 sums = weights .sum (dim = - 1 )
908901 self .assertTrue (torch .allclose (sums , torch .ones_like (sums )))
909902
910- @patch ('torch_npu.npu_moe_gating_top_k ' )
903+ @patch ('torch_npu.npu_moe_gating_top_k_softmax ' )
911904 def test_output_dtypes (self , mock_topk ):
912905 """Test output dtypes"""
913906 mock_topk .return_value = (torch .ones (self .num_tokens , self .top_k ),
914907 torch .zeros (self .num_tokens ,
915908 self .top_k ,
916- dtype = torch .int32 ),
909+ dtype = torch .long ),
917910 torch .arange (0 ,
918911 self .num_tokens * self .top_k ,
919912 dtype = torch .int32 ).view (
0 commit comments