diff --git a/problems/amd_distributed/all2all/submission.py b/problems/amd_distributed/all2all/submission.py index 5eddcf68..eb8d451f 100644 --- a/problems/amd_distributed/all2all/submission.py +++ b/problems/amd_distributed/all2all/submission.py @@ -38,7 +38,7 @@ def dispatch(self, dp_x: torch.Tensor, indices: torch.Tensor): ) # srcGobalExpert, srcRank, srcIndex, expert index send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device) - # 1.3 token nums to recv from each rank + # 1.4 token nums to recv from each rank recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device) dist.all_to_all_single(recv_counts_t, send_counts_t) # ---------2. send and recv buffer, order by tokens on each rank ----------