Skip to content

Commit 8ede969

Browse files
authored
Fix moe UT test case (#5376)
1 parent a7a4ee5 commit 8ede969

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/gpu/examples/moe/test_moe_scatter_gather.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class TestTorchMethod:
1010

1111
def init(self, n_token, n_expert, n_topk):
12-
gating_logits = torch.randn(n_token, n_expert, device=dpcpp_device)
12+
gating_logits = torch.rand(n_token, n_expert, device=dpcpp_device)
1313
gating_logits = gating_logits.to(torch.float)
1414
softmax = torch.nn.functional.softmax(gating_logits, dim=-1, dtype=torch.float)
1515
topk_weights, topk_indices = torch.topk(softmax, n_topk, dim=-1)
@@ -108,7 +108,7 @@ def test_moe_scatter(self, dtype, n_expert, n_token):
108108
def test_moe_gather(self, dtype, n_expert, n_token):
109109
n_channels = 1024
110110
n_topk = 2
111-
activation = torch.randn(
111+
activation = torch.rand(
112112
(n_token * n_topk, n_channels), dtype=dtype, device=dpcpp_device
113113
)
114114
topk_weights, topk_indices, token_for_experts, token_offset = self.init(

0 commit comments

Comments
 (0)