@djns99 djns99 commented Nov 18, 2025

📌 Description

This ports the latest MNNVL A2A communication implementation from TRT-LLM.

🔍 Related Issues

#2094

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a throughput‑optimized Mixture‑of‑Experts all‑to‑all backend with Python API, workspace management, dispatch/combine flows, and expert‑ID sanitization.
  • Configuration

    • Added runtime tuning controls for MoE A2A (one‑block behavior, dispatch/combine block sizes, force GDR copy) and renamed a KV cache timing config.
  • Documentation

    • Added API docs for the MNNVL A2A throughput backend.
  • Tests

    • Added comprehensive single‑GPU and multi‑rank end‑to‑end tests and updated test scripts to run them.

coderabbitai bot commented Nov 18, 2025

Walkthrough

Adds a throughput‑optimized Mixture‑of‑Experts all‑to‑all backend: new CUDA kernels and header, C++ FFI entry points and env helpers, Python/JIT bindings and workspace management, metadata helpers, and comprehensive single‑GPU and MPI multi‑rank tests and docs.

Changes

| Cohort / File(s) | Summary |
|---|---|
| CUDA Kernels & Header: csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu, csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h | New MoE A2A CUDA implementation and public header: dispatch/combine/sanitize kernels, Warp/Block policies, vectorized copy/dispatch/combine helpers, compute_target_rank_id, Dispatch/Combine pointer & params structs, TOP_K/dtype/policy macros, and host launch entrypoints. |
| FFI & Runtime C++: csrc/trtllm_moe_alltoall.cu | New TVM/FFI layer exposing workspace-size computation, initialize, dispatch, combine, sanitize_expert_ids, and metainfo index retrieval; workspace layout, payload descriptors, alignment, validation, and FFI exports. |
| Env utils: csrc/nv_internal/cpp/common/envUtils.cpp, csrc/nv_internal/tensorrt_llm/common/envUtils.h | Removed getEnvParallelCacheSend; renamed getEnvKVCacheTransferOutputPath → getEnvKVCacheTimeOutputPath (return type std::string const&); added MOE A2A env helpers (getEnvMoeA2AOneBlockPerToken, getEnvMoeA2ADispatchBlockSize, getEnvMoeA2ACombineBlockSize, getEnvEplbForceGdrcopy) and sanitizeBlockSize. |
| Meta info: csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h | New enum MoeA2AMetaInfoIndex, MoeA2ADataOffsets alias, and helper returning name→index pairs for metainfo fields. |
| Python module & public API: flashinfer/comm/trtllm_moe_alltoall.py, flashinfer/comm/__init__.py | New MoeAlltoAll class and top-level wrappers (initialize, dispatch, combine, sanitize, get_metainfo_index_pairs, get_workspace_size_per_rank); lazy JIT module builder, workspace caching/state, zero-copy payload handling, and package exports. |
| JIT & build integration: flashinfer/jit/comm.py, flashinfer/jit/__init__.py, flashinfer/aot.py | Added gen_mnnvl_moe_alltoall_module JitSpec (sources including kernels and envUtils), re-exported it publicly, and integrated into AOT generator when applicable. |
| Tests: tests/comm/test_trtllm_moe_alltoall.py, tests/comm/test_mnnvl_moe_alltoall.py, tests/comm/test_mnnvl_memory.py | New end‑to‑end tests (single‑GPU and MPI multi‑rank) for dispatch/combine/sanitize with helpers and fake‑MoE reference; device selection tweak (self.rank → self.local_rank) in memory test. |
| Docs & Scripts: docs/api/comm.rst, scripts/task_test_multi_node_comm_kernels.sh, scripts/task_test_single_node_comm_kernels.sh | Added API docs for MNNVL A2A throughput backend; new test invocations in CI scripts; commented out pycache cleanup in multi-node test script. |
| Packaging / imports: flashinfer/jit/__init__.py, flashinfer/aot.py | Re-exported gen_mnnvl_moe_alltoall_module in public JIT API and wired the generator into AOT flow. |

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API (MoeAlltoAll)
    participant JIT as JIT Module
    participant FFI as C++ FFI
    participant GPU as GPU Kernels
    participant Meta as Metainfo / Flags

    Note over Py,JIT: Ensure JIT module built/loaded
    Py->>JIT: ensure module
    Py->>FFI: moe_a2a_initialize(workspace, ep_rank, ep_size, max_tokens)
    FFI->>Meta: init metainfo, counters, flags

    Py->>FFI: moe_a2a_dispatch(routing, payloads, workspace, metainfo, ...)
    FFI->>GPU: launch moeA2ADispatchKernel / prepare
    GPU->>Meta: write recv buffers, update send/recv counters & completion flags
    FFI-->>Py: return recv tensors + combine_payload_offset

    Py->>FFI: moe_a2a_combine(payload, local_num_tokens, workspace, metainfo, ..., offset)
    FFI->>GPU: launch moeA2ACombineKernel / prepare
    GPU->>Meta: read recv buffers, accumulate TOP_K contributions, write outputs
    FFI-->>Py: return combined outputs

    Py->>FFI: moe_a2a_sanitize_expert_ids(expert_ids, workspace, metainfo, ep_rank, invalid_id)
    FFI->>GPU: launch sanitize kernel
    FFI-->>Py: sanitization complete

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Areas needing extra attention:

  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu: complex templated CUDA, atomics, vectorized memory ops, TOP_K and policy specializations.
  • csrc/trtllm_moe_alltoall.cu: workspace layout, offsets/alignment, FFI marshaling and error handling.
  • flashinfer/comm/trtllm_moe_alltoall.py: JIT integration, workspace caching/lifecycle, state transitions and zero-copy workspace tensor handling.
  • Tests: MPI multi‑rank orchestration and numerical tolerance for bf16.
  • csrc/nv_internal/cpp/common/envUtils.*: API change (removed function, renamed function with return-type change) may affect callers.

Suggested reviewers

  • aleozlx
  • joker-eph
  • djmmoss
  • cyx-6
  • yzh119
  • nvmbreughe
  • wenscarl

Poem

🐰 I hop through kernels, swift and small,
I route each token, heed the call,
I gather payloads rank by rank,
I stitch and sum—no step is blank,
A2A hops—carrots for all! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.04% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly and concisely summarizes the main change: porting TRT-LLM communication kernels to flashinfer, which matches the core objective of this changeset. |
| Description check | ✅ Passed | The PR description includes the required sections from the template (Description, Related Issues, Pre-commit Checks, Tests) with all checklist items marked as completed, and provides context about porting MNNVL A2A from TRT-LLM. |

@djns99 djns99 force-pushed the djns99/update-trtllm-kernels branch 5 times, most recently from 710a388 to bd82a2b Compare November 20, 2025 04:18
#define check_timeout(s) false
#else
// 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang
#define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll))
Author

I added this manually. Can someone sanity-check my logic here?

Contributor

Looks good to me

# echo ""

pip install -e . -v
# pip install -e . -v
Author

I'm pretty sure this file is intended to be run with mpirun, so doing this logic on every rank would be incorrect behaviour. I don't see where this task is used though, so I could be misunderstanding something.

Author

@dierksen @nvmbreughe will this cause any issues that you know of?

Contributor

I don't know of any issues. I don't think we actually use this in CI at the moment.

@djns99 djns99 force-pushed the djns99/update-trtllm-kernels branch from f766bfe to 8cdf8d8 Compare November 27, 2025 02:00
)


def moe_a2a_get_workspace_size_per_rank(
Author

This API is also an addition of mine; some extra eyes on it would be helpful.

Contributor

It looks like TRT-LLM just added (in the last 24 hours) something similar which looks slightly more user friendly: NVIDIA/TensorRT-LLM@8b5eded

Maybe that could be integrated into flashinfer in a future PR.

Author

Yeah, I am in touch with the author. I was going to port it across as part of this; it should be a small change.

Author

@trevor-m @bobboli I ported those changes into this MR. Can you take a look and see if it looks sensible? I made some tweaks to make it slightly more flexible, since @bobboli's version was specific to TRT-LLM's usage.


# Single shared workspace across the process
# _WORKSPACE: Optional[dict] = None
_WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
Author

I updated this to use a cache so that we support communicators of different shapes existing at the same time. Previously this assumed all communicators had the same shape, which made testing tricky.
I would appreciate it if people could double-check my working here.

Contributor

This looks good to me. I don't think sglang will need different shaped communicators right now but may be nice for other cases or in the future.

Author

Yeah, mainly useful for the tests. In TRT-LLM each test is launched under MPI, so it starts with fresh state, but we don't do that in flashinfer.

@djns99 djns99 marked this pull request as ready for review November 28, 2025 02:03
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (18)
csrc/nv_internal/cpp/common/envUtils.cpp (1)

357-357: Consider caching getEnvEplbForceGdrcopy like other bool env helpers

getEnvEplbForceGdrcopy calls getBoolEnv (and thus std::getenv) on every invocation, while most other helpers in this file cache the value in a static local. Functionally this is fine, but for consistency and to avoid repeated env lookups in hot paths you might want to align it:

-bool getEnvEplbForceGdrcopy() { return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY"); }
+bool getEnvEplbForceGdrcopy() {
+  static bool const forceGdrcopy = getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY");
+  return forceGdrcopy;
+}

Not critical, but it would match the rest of the env-utils style.

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (2)

428-431: Consider documenting why the acquire fence is commented out.

The fence.acquire.sys at line 430 is commented out after the dispatch wait loop. While the combine kernel (line 735) does have an acquire fence, having it commented here without explanation could cause confusion for future maintainers. If this is intentional (relying on the combine kernel's fence), a brief comment explaining the design decision would help.

       }
-      // asm volatile("fence.acquire.sys;");
+      // NOTE: Acquire fence intentionally omitted here; combine kernel provides
+      // the acquire semantics before reading peer data.
 #endif

596-609: Generic fallback is unreachable code.

The generic fallback reduction loop (lines 599-608) can never be reached because the SWITCH_TOP_K macro (lines 53-78) only allows TOP_K values of 1, 2, 4, or 8, and all these cases have explicit handling above. Consider removing this dead code or adding a static_assert to document the constraint.

     } else if constexpr (TOP_K == 1) {
       // nothing to do
-    } else {
-      // Generic fallback: accumulate all into acc[0]
-      T* a0 = reinterpret_cast<T*>(&acc[0]);
-#pragma unroll
-      for (int k = 1; k < TOP_K; ++k) {
-        T* ak = reinterpret_cast<T*>(&acc[k]);
-#pragma unroll
-        for (int j = 0; j < elems_per_vec; ++j) {
-          a0[j] += ak[j];
-        }
-      }
+    } else {
+      static_assert(TOP_K == 1 || TOP_K == 2 || TOP_K == 4 || TOP_K == 8,
+                    "Only TOP_K values 1, 2, 4, 8 are supported");
     }
scripts/task_test_multi_node_comm_kernels.sh (1)

9-13: Disabling cache cleanup may cause stale import issues.

The cache cleanup commands are commented out. If module refactoring occurs between test runs, stale .pyc files could cause import errors or unexpected behavior. Consider re-enabling these commands or documenting why they're disabled.

tests/comm/test_trtllm_moe_alltoall.py (4)

1-2: Copyright year should be updated to 2025.

The license header shows 2024 but this is a new file created in 2025.

-Copyright (c) 2024 by FlashInfer team.
+Copyright (c) 2025 by FlashInfer team.

112-112: Potential issue with payload size calculation.

x[0].numel() returns the number of elements in the first row, so for a list of 2D tensors this does compute the per-token size correctly. However, the x[0] indexing is indirect, and the computation could be made clearer.

-    payload_size_per_token = sum([x[0].numel() * x.itemsize for x in input_tensors])
+    payload_size_per_token = sum([x.shape[-1] * x.element_size() for x in input_tensors])

207-236: CUDA streams created but not explicitly cleaned up.

The cuda_streams_all_ranks list creates CUDA streams that are not explicitly destroyed. While Python's garbage collector will eventually clean them up, for test reliability consider using a context manager or explicit cleanup.


411-411: Minor typo in comment.

Extra slash at end of comment.

-        # For each expert selected for this token/
+        # For each expert selected for this token
tests/comm/test_mnnvl_moe_alltoall.py (4)

37-46: Consider using raise without exception name per Python best practices.

The explicit raise e is redundant; bare raise preserves the traceback better.

 def safe_run(func, *args, **kwargs):
     comm = MPI.COMM_WORLD
     try:
         func(*args, **kwargs)
     except MPIExit as e:
-        raise e
+        raise
     except Exception as e:
         traceback.print_exc()
         comm.allgather(True)
-        raise e
+        raise
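
To illustrate the difference this suggestion refers to, here is a small standalone sketch (the function names are made up for the example). Both forms re-raise the same exception; in CPython, `raise e` typically records the re-raise line as an additional traceback entry, whereas a bare `raise` leaves the traceback as it was.

```python
import traceback


def reraise_bare():
    try:
        raise ValueError("boom")
    except ValueError:
        raise  # re-raises the active exception unchanged


def reraise_named():
    try:
        raise ValueError("boom")
    except ValueError as e:
        raise e  # same exception object; CPython usually adds this line to the traceback


def traceback_depth(func) -> int:
    """Count the traceback entries recorded when `func` raises ValueError."""
    try:
        func()
    except ValueError as exc:
        return len(traceback.extract_tb(exc.__traceback__))
    return 0


bare_depth = traceback_depth(reraise_bare)
named_depth = traceback_depth(reraise_named)
```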

49-51: Test fixture should yield for proper cleanup semantics.

Even though no cleanup is needed, the fixture pattern should include yield for consistency.

 @pytest.fixture(autouse=True)
 def setup_test():
     torch.manual_seed(0x1234)
+    yield

571-576: Blind exception catch may mask real initialization errors.

Catching bare Exception when checking MNNVL support could hide legitimate configuration issues. Consider catching specific exception types or at least logging the exception.

     try:
         MnnvlMemory.initialize()
         if not MnnvlMemory.supports_mnnvl():
             pytest.skip("MNNVL not supported on this system")
-    except Exception:
+    except (RuntimeError, pynvml.NVMLError) as e:
+        # Log exception for debugging if needed
         pytest.skip("MNNVL not supported on this system")

709-712: Unused variable expert_id_payload_index as flagged by static analysis.

The unpacked variable is never used. Either prefix with underscore or remove from unpacking.

-    payloads, expert_id_payload_index = make_bfloat16_payloads(
+    payloads, _expert_id_payload_index = make_bfloat16_payloads(
         local_num_tokens, hidden_size, top_k, rank, token_selected_experts
     )
flashinfer/comm/trtllm_moe_alltoall.py (5)

8-8: TODO comment should be addressed or tracked.

The # TODO Review comment at the top suggests this module needs review. Consider removing after review or converting to a tracked issue.

Would you like me to open an issue to track any remaining review items?


351-351: Mutable class attribute should use ClassVar annotation.

Per static analysis and Python best practices, mutable class attributes should be annotated with typing.ClassVar.

+from typing import ClassVar
+
 class MoeAlltoAll:
     ...
     # Single shared workspace across the process
-    _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
+    _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}

456-463: Assertions for validation could use proper exceptions in production.

Using assert for validation is acceptable for debug builds but these checks may be skipped in optimized Python (python -O). Consider using explicit if/raise for critical invariants.
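
A minimal sketch of the explicit if/raise pattern suggested here. The function and parameter names are hypothetical and do not correspond to the checks in the reviewed file.

```python
def check_payload_count(num_payloads: int, max_payloads: int) -> None:
    """Validate with an explicit exception so the check survives `python -O`."""
    if not 0 < num_payloads <= max_payloads:
        raise ValueError(
            f"expected between 1 and {max_payloads} payloads, got {num_payloads}"
        )


check_payload_count(3, 8)  # passes silently
```

Unlike an `assert`, this check still runs when Python is started with `-O`.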


610-610: Inefficient way to get element size.

Creating an empty tensor just to get element size is wasteful. Use torch.finfo or torch.iinfo or a lookup table instead.

-        element_size = torch.tensor([], dtype=dtype).element_size()
+        # Query the element size from an empty tensor (no Python list is built)
+        element_size = torch.empty(0, dtype=dtype).element_size()

Or better, consider caching element sizes or using:

element_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8

621-628: __all__ is not sorted as noted by static analysis.

Consider sorting for consistency, though this is a minor issue.

 __all__ = [
     "MoeAlltoAll",
     "moe_a2a_initialize",
+    "moe_a2a_combine",
     "moe_a2a_dispatch",
-    "moe_a2a_combine",
+    "moe_a2a_get_workspace_size_per_rank",
     "moe_a2a_sanitize_expert_ids",
-    "moe_a2a_get_workspace_size_per_rank",
 ]
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1)

78-120: Well-documented struct with clear field descriptions.

The MoeA2ADispatchParams struct has excellent inline documentation explaining each field's purpose and dimensions. The TODO on line 90-91 about renaming max_tokens_per_rank to runtime_max_tokens_per_rank should be tracked.

Would you like me to open an issue to track the TODO about renaming max_tokens_per_rank?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 18004a8 and 71bb8fb.

📒 Files selected for processing (16)
  • csrc/nv_internal/cpp/common/envUtils.cpp (2 hunks)
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h (2 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1 hunks)
  • csrc/trtllm_moe_a2a.cu (1 hunks)
  • docs/api/comm.rst (1 hunks)
  • flashinfer/aot.py (1 hunks)
  • flashinfer/comm/__init__.py (1 hunks)
  • flashinfer/comm/trtllm_moe_alltoall.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/comm.py (1 hunks)
  • scripts/task_test_multi_node_comm_kernels.sh (1 hunks)
  • tests/comm/test_mnnvl_memory.py (1 hunks)
  • tests/comm/test_mnnvl_moe_alltoall.py (1 hunks)
  • tests/comm/test_trtllm_moe_alltoall.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
flashinfer/jit/__init__.py (1)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_a2a_module (83-109)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
csrc/nv_internal/cpp/common/envUtils.cpp (8)
  • getEnvKVCacheTimeOutputPath (275-278)
  • getEnvKVCacheTimeOutputPath (275-275)
  • getEnvMoeA2AOneBlockPerToken (326-333)
  • getEnvMoeA2AOneBlockPerToken (326-326)
  • getEnvMoeA2ADispatchBlockSize (347-350)
  • getEnvMoeA2ADispatchBlockSize (347-347)
  • getEnvMoeA2ACombineBlockSize (352-355)
  • getEnvMoeA2ACombineBlockSize (352-352)
tests/comm/test_mnnvl_memory.py (1)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
csrc/trtllm_moe_a2a.cu (2)
  • getMoeA2AMetaInfoIndexPairs (395-407)
  • getMoeA2AMetaInfoIndexPairs (395-395)
csrc/trtllm_moe_a2a.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (2)
  • getEnvMoeA2AOneBlockPerToken (326-333)
  • getEnvMoeA2AOneBlockPerToken (326-326)
flashinfer/aot.py (1)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_a2a_module (83-109)
flashinfer/jit/comm.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (4)
  • getEnvMoeA2ADispatchBlockSize (347-350)
  • getEnvMoeA2ADispatchBlockSize (347-347)
  • getEnvMoeA2ACombineBlockSize (352-355)
  • getEnvMoeA2ACombineBlockSize (352-352)
flashinfer/comm/trtllm_moe_alltoall.py (4)
flashinfer/comm/mnnvl.py (5)
  • MnnvlMemory (232-551)
  • MnnvlConfig (224-229)
  • as_torch_strided_tensor (264-273)
  • initialize (276-285)
  • set_comm_from_config (288-293)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • moe_ep_rank (349-350)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_a2a_module (83-109)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • num_experts (263-263)
tests/comm/test_mnnvl_moe_alltoall.py (2)
flashinfer/comm/trtllm_moe_alltoall.py (4)
  • MoeAlltoAll (336-618)
  • dispatch (484-541)
  • get_combine_payload_tensor_in_workspace (585-618)
  • combine (543-583)
flashinfer/comm/mnnvl.py (3)
  • MnnvlMemory (232-551)
  • initialize (276-285)
  • supports_mnnvl (545-551)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (3)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
  • tensorrt_llm (23-104)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
  • mnnvl_throughput (25-58)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (8)
  • moe_a2a_dispatch_launch (445-506)
  • moe_a2a_dispatch_launch (445-445)
  • moe_a2a_prepare_dispatch_launch (436-439)
  • moe_a2a_prepare_dispatch_launch (436-436)
  • moe_a2a_combine_launch (792-842)
  • moe_a2a_combine_launch (792-792)
  • moe_a2a_sanitize_expert_ids_launch (864-872)
  • moe_a2a_sanitize_expert_ids_launch (864-866)
🪛 Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h

[error] 19-19: 'array' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

[error] 18-18: 'cuda_bf16.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py

351-351: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


443-443: Avoid specifying long messages outside the exception class

(TRY003)


445-445: Avoid specifying long messages outside the exception class

(TRY003)


606-608: Avoid specifying long messages outside the exception class

(TRY003)


621-628: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

tests/comm/test_mnnvl_moe_alltoall.py

34-34: Avoid specifying long messages outside the exception class

(TRY003)


42-42: Use raise without specifying exception name

Remove exception name

(TRY201)


46-46: Use raise without specifying exception name

Remove exception name

(TRY201)


575-575: Do not catch blind exception: Exception

(BLE001)


676-676: Do not catch blind exception: Exception

(BLE001)


710-710: Unpacked variable expert_id_payload_index is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (38)
csrc/nv_internal/cpp/common/envUtils.cpp (2)

275-277: KV cache time output path now cached by reference – behavior looks consistent

Returning std::string const& to a function-local static is safe here and matches the pattern used by other helpers in this file (single evaluation of the env var, cheap repeated access). The only behavior change is that changes to TRTLLM_KVCACHE_TIME_OUTPUT_PATH after the first call are no longer observed, which is usually fine for env-based config.

If you don’t rely on mutating the env mid-process (including in tests), this change looks good to me.
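
The caching behavior described above can be sketched in Python as an analogue of a C++ function-local static. This is not the reviewed code: the helper name and the empty-string default are assumptions; only the environment variable name comes from the comment.

```python
import functools
import os


@functools.lru_cache(maxsize=None)
def get_kv_cache_time_output_path() -> str:
    """Read the environment variable on the first call and cache the result."""
    return os.environ.get("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")


os.environ["TRTLLM_KVCACHE_TIME_OUTPUT_PATH"] = "/tmp/first"
first = get_kv_cache_time_output_path()

# A later change to the environment is not observed by the cached helper.
os.environ["TRTLLM_KVCACHE_TIME_OUTPUT_PATH"] = "/tmp/second"
second = get_kv_cache_time_output_path()
```

Tests that need to vary the variable would have to set it before the first call, which is the caveat raised in the comment.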


326-355: Verify MOE A2A environment variable prefixes (TLLM_ vs TRTLLM_) against project documentation and call sites; fix the misleading comment in sanitizeBlockSize to say "round up" instead of "nearest"

The new helpers follow the file's consistent patterns (single-read statics, sane defaults), but need attention in two areas:

  1. Env var prefixes for MoE A2A knobs

    • These functions read TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN, TLLM_MOE_A2A_DISPATCH_BLOCK_SIZE, and TLLM_MOE_A2A_COMBINE_BLOCK_SIZE
    • Review the project documentation, call sites, and surrounding code to confirm whether the TLLM_ prefix matches intended usage; most other envs in this file appear to use the TRTLLM_ prefix, and silently reading the wrong names would be problematic.
  2. sanitizeBlockSize comment accuracy

    • The implementation rounds up to a multiple of 32: block = (block + 31) / 32 * 32; (e.g., 33 → 64, not 32)
    • The comment currently states "Round to nearest multiple of 32 (warp size)", which is misleading.

    Fix the comment:

-// Round to nearest multiple of 32 (warp size)
+// Round up to the next multiple of 32 (warp size)
    
    The extra `if (block == 0) block = 256;` check after clamping is redundant but harmless.
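
As a quick check of the rounding behavior described above, the expression can be mirrored in Python integer arithmetic. The function name is made up for this sketch, and the clamping that the real helper also performs is omitted because its bounds are not shown here.

```python
WARP_SIZE = 32


def round_up_to_warp(block: int) -> int:
    """Round `block` up to the next multiple of WARP_SIZE (integer arithmetic)."""
    return (block + WARP_SIZE - 1) // WARP_SIZE * WARP_SIZE


examples = {n: round_up_to_warp(n) for n in (1, 32, 33, 256)}
```

An input of 33 maps to 64 rather than 32, which is why "round up" describes the code and "nearest" does not.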
    
    
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)

95-102: LGTM! New MoE A2A environment variable accessors are properly declared.

The three new accessor functions (getEnvMoeA2AOneBlockPerToken, getEnvMoeA2ADispatchBlockSize, getEnvMoeA2ACombineBlockSize) are well-documented with default behaviors and align with their implementations in envUtils.cpp.

csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (3)

19-22: Static analysis false positive - standard headers are valid.

The Clang error about <array> not being found is a false positive. This is a standard C++11 header that should be available in any modern C++ environment. The includes are correct.


28-43: LGTM! Well-structured metadata index enum.

The MoeA2AMetaInfoIndex enum provides clear, sequential indexing for metadata fields with NUM_METAINFO_FIELDS = 9 correctly representing the count of actual data fields (0-8). The MoeA2ADataOffsets type alias correctly uses this count for the array size.


45-58: LGTM! Useful name-to-index mapping function.

The inline getMoeA2AMetaInfoIndexPairs() function provides a clean way to expose metadata field names and their corresponding indices, which is consumed by the TVM FFI interface in csrc/trtllm_moe_a2a.cu.

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (4)

114-116: Timeout calculation looks reasonable.

The timeout of 300ll * 2000ll * 1000ll * 1000ll cycles (~600 billion) translates to approximately 300 seconds at 2 GHz, which provides a generous upper bound to prevent infinite hangs while allowing ample time for legitimate synchronization delays. The calculation avoids overflow by using long long literals.
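
The arithmetic behind this estimate can be checked with a few lines of Python. This is only a worked calculation: it assumes the counter ticks at the stated clock rate, and the actual GPU clock varies by device and power state.

```python
TIMEOUT_CYCLES = 300 * 2000 * 1000 * 1000  # same product as in check_timeout
INT64_MAX = 2**63 - 1


def timeout_seconds(clock_hz: float) -> float:
    """Wall-clock time the cycle budget corresponds to at a given clock rate."""
    return TIMEOUT_CYCLES / clock_hz


at_2ghz = timeout_seconds(2e9)  # the 2000 MHz assumed in the code comment
at_1ghz = timeout_seconds(1e9)  # a slower clock stretches the timeout
```

The budget is 6e11 cycles, far below the signed 64-bit maximum, so the product does not overflow when computed in `long long`.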


272-276: LGTM! Single-threaded flag increment is safe.

The flag_val increment occurs only when idx == 0, ensuring single-threaded access. Since this kernel runs sequentially in the stream before the dispatch kernel, there's no race condition.


844-872: LGTM! Sanitize kernel implementation is correct.

The kernel correctly identifies invalid tokens (where token_idx >= recv_counters[source_rank]) and sets their expert IDs to invalid_id. Each thread operates on disjoint memory locations, avoiding any race conditions.
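
A pure-Python reference of the rule described above may help when reading the kernel. The nested-list layout (source rank, then token slot, then top-k expert IDs) is an assumption made for this sketch and is not taken from the kernel source.

```python
def sanitize_expert_ids(expert_ids, recv_counters, invalid_id):
    """Overwrite the expert IDs of slots that hold no received token.

    expert_ids[source_rank][token_idx] is a list of top_k expert IDs;
    recv_counters[source_rank] is the number of valid tokens from that rank.
    """
    for source_rank, tokens in enumerate(expert_ids):
        for token_idx, ids in enumerate(tokens):
            if token_idx >= recv_counters[source_rank]:
                for k in range(len(ids)):
                    ids[k] = invalid_id
    return expert_ids


ids = [
    [[0, 1], [2, 3], [4, 5]],  # from rank 0: 2 valid tokens
    [[6, 7], [8, 9], [1, 0]],  # from rank 1: 1 valid token
]
result = sanitize_expert_ids(ids, recv_counters=[2, 1], invalid_id=-1)
```

Every (source rank, token slot) pair is touched independently, which matches the observation that threads write to disjoint locations.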


315-344: Verify kMaxRanks does not exceed 64 to avoid undefined behavior.

The definition of kMaxRanks could not be checked during this review (the repository clone failed and no public documentation was found), so this concern could not be conclusively verified. Confirming it requires:

  1. The header file containing the kMaxRanks definition
  2. The actual value of kMaxRanks
  3. Runtime validation constraints on ep_size (expert-parallel size)


The already_copied bitmask uses uint64_t with bit operations 1ULL << target_rank. If target_rank can be 64 or greater, this causes undefined behavior (shifting by >= width of type). The code validates params.ep_size <= kMaxRanks at line 448, so ensure kMaxRanks is defined as ≤ 64 in the header.
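
The bitmask technique and its width limit can be sketched in Python. Python integers are unbounded, so the 64-bit limit is enforced explicitly here; the function name and the rank-deduplication framing are assumptions for illustration, not the kernel's code.

```python
MASK_BITS = 64  # width of the uint64_t bitmask described above


def unique_target_ranks(target_ranks, ep_size):
    """Return each target rank once, in first-seen order, using a bitmask."""
    if ep_size > MASK_BITS:
        raise ValueError(f"ep_size {ep_size} exceeds the {MASK_BITS}-bit mask")
    already_copied = 0
    unique = []
    for rank in target_ranks:
        if not 0 <= rank < ep_size:
            raise ValueError(f"target rank {rank} out of range")
        bit = 1 << rank
        if already_copied & bit:
            continue  # this rank was already handled
        already_copied |= bit
        unique.append(rank)
    return unique


deduped = unique_target_ranks([3, 1, 3, 0, 1], ep_size=4)
```

In C++ the equivalent guard is the bound on ep_size, since shifting a 64-bit value by 64 or more is undefined.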

flashinfer/aot.py (1)

515-522: LGTM! MNNVL A2A module integration follows existing patterns.

The new gen_mnnvl_a2a_module is correctly imported within the add_comm block and added under the has_sm100 condition, consistent with the existing gen_trtllm_comm_module and gen_trtllm_mnnvl_comm_module placement.

tests/comm/test_mnnvl_memory.py (1)

125-125: LGTM! Correct device selection for multi-node scenarios.

Using self.local_rank instead of self.rank is the correct fix for multi-node setups where the global rank may exceed the local GPU count. This aligns with the setup fixture (line 51) and the Mapping.local_rank property shown in the relevant snippet.

scripts/task_test_multi_node_comm_kernels.sh (1)

17-19: LGTM - new MoE A2A test coverage added.

The addition of test_mnnvl_moe_alltoall.py aligns with the new MoE A2A functionality introduced in this PR.

flashinfer/jit/__init__.py (1)

80-80: LGTM!

The re-export follows the established pattern for module generators in this file.

flashinfer/comm/__init__.py (1)

42-52: LGTM!

The new MoE A2A exports follow the established import patterns and properly expose the public API surface for the throughput backend.

csrc/trtllm_moe_a2a.cu (6)

53-99: LGTM - offset calculations and workspace sizing are well-structured.

The alignment logic using cache-line boundaries (128 bytes) is appropriate for GPU memory access patterns. The offset calculations correctly account for metadata fields and payload regions.
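
A minimal sketch of aligned offset calculation of the kind described above. Only the 128-byte alignment comes from the comment; the region names and sizes are invented for the example.

```python
CACHE_LINE = 128  # alignment in bytes, as stated in the review


def align_up(offset: int, alignment: int = CACHE_LINE) -> int:
    """Round `offset` up to the next multiple of `alignment`."""
    return (offset + alignment - 1) // alignment * alignment


def layout(region_sizes: dict[str, int]) -> tuple[dict[str, int], int]:
    """Assign each region an aligned start offset; return offsets and total size."""
    offsets = {}
    cursor = 0
    for name, size in region_sizes.items():
        cursor = align_up(cursor)
        offsets[name] = cursor
        cursor += size
    return offsets, align_up(cursor)


offsets, total = layout({"counters": 40, "flags": 8, "payload": 1000})
```

Each region starts on a 128-byte boundary, and the total is padded so that a following per-rank workspace would also start aligned.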


101-123: LGTM - initialization logic is correct.

Input validation is comprehensive. The cudaStreamSynchronize at line 119 ensures the memset completes before returning, which is necessary for correctness.


132-262: LGTM - dispatch implementation is solid.

The function has comprehensive input validation, proper payload descriptor setup, workspace size verification, and appropriate CUDA error checking after kernel launch.


264-277: LGTM!

The dtype conversion correctly handles the common floating-point types used in MoE computations. The unreachable return after the throw silences compiler warnings.


279-361: LGTM - combine implementation is well-validated.

The function properly validates all inputs including the optional payloadInWorkspace optimization for zero-copy scenarios. Error handling is comprehensive.


393-416: LGTM!

The metainfo accessor correctly wraps the C++ pairs for Python accessibility, and all FFI exports are properly registered.

flashinfer/jit/comm.py (1)

83-109: I'll verify whether architecture-specific NVCC flags are required for the gen_mnnvl_a2a_module function. Let me search the codebase for relevant information.
#!/bin/bash
# Find the moeAlltoAllKernels.cu file
find . -name "moeAlltoAllKernels.cu" -type f 2>/dev/null

#!/bin/bash
# Check if the file exists and search for architecture-specific patterns
if [ -f "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" ]; then
  echo "=== Checking for SM guards and architecture macros ==="
  rg -n 'sm|SM[0-9]|CUDA_ARCH|CUDA_SM' "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" || echo "No explicit SM guards found"

  echo ""
  echo "=== Checking for tensor core/MMA operations ==="
  rg -n 'mma|wmma|tensorop|__mma|__hmma' "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" || echo "No tensor core intrinsics found"
else
  echo "File not found at expected path"
fi

#!/bin/bash
# Look at the comm.py file to understand nvcc_flags definition and usage
echo "=== Current comm.py around nvcc_flags definition ==="
rg -n 'nvcc_flags' flashinfer/jit/comm.py -B3 -A3

echo ""
echo "=== All functions in comm.py to see pattern ==="
rg -n '^def gen_' flashinfer/jit/comm.py

#!/bin/bash
# Let's see the full context of comm.py to understand nvcc_flags and module patterns
wc -l flashinfer/jit/comm.py

#!/bin/bash
# Read the entire comm.py file to understand structure and nvcc_flags
cat -n flashinfer/jit/comm.py

#!/bin/bash
# Check what nvcc_flags is and where it comes from
rg -n 'nvcc_flags\s*=' flashinfer/jit/comm.py -B2 -A2

echo ""
echo "=== Check for imports in comm.py ==="
head -30 flashinfer/jit/comm.py

#!/bin/bash
# Check the jit module structure
ls -la flashinfer/jit/

#!/bin/bash
# Check jit_env for nvcc_flags
rg -n 'nvcc_flags' flashinfer/jit/ -A2 -B2

tests/comm/test_trtllm_moe_alltoall.py (6)

25-29: Session-scoped fixture with torch.manual_seed may not reset seed between tests.

The fixture sets the seed once at session start but individual tests may want reproducible seeds. Consider if this is intentional or if test-level seeding is needed.


62-69: Good defensive check for SM resources.

This helper appropriately skips tests when insufficient SMs are available for parallel kernel execution. The check prevents hangs on systems with limited GPU resources.


72-84: LGTM!

The make_payload helper correctly distinguishes between integer and floating-point types for random tensor generation.


138-141: Sorting approach for validation is reasonable but fragile.

Sorting both input and output tensors to compare them works for this test case but relies on unique values. If there are duplicate values, the sort order could differ. Consider documenting this assumption or using a more robust comparison.


388-429: Reference implementation for fake_moe looks correct.

The fake_moe function properly handles expert parallelism filtering and accumulation. The tree reduction comment on line 423 correctly explains why results are summed after collection.


530-536: Relatively loose tolerance for numerical comparison.

Using atol=1.5e-2 and rtol=1.5e-2 is quite loose for bf16/fp16. This may mask precision issues. Verify this tolerance is intentional given the accumulation order differences mentioned elsewhere.
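
For context on how loose these values are, the following sketch spells out the usual closeness rule, `|actual - expected| <= atol + rtol * |expected|`, in plain Python. It assumes the test uses that rule (as `torch.allclose`-style comparisons do); `is_close` is a hypothetical helper, and the epsilon constants are the standard machine epsilons for bf16 and fp16.

```python
BF16_EPS = 2.0 ** -7    # machine epsilon for bfloat16 (0.0078125)
FP16_EPS = 2.0 ** -10   # machine epsilon for float16 (0.0009765625)


def is_close(actual: float, expected: float,
             rtol: float = 1.5e-2, atol: float = 1.5e-2) -> bool:
    """Element-wise closeness rule with combined absolute and relative tolerance."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# How many machine epsilons the relative tolerance spans for each dtype.
rtol_in_bf16_eps = 1.5e-2 / BF16_EPS
rtol_in_fp16_eps = 1.5e-2 / FP16_EPS
```

Under these assumptions the relative tolerance is just under two bf16 epsilons but more than fifteen fp16 epsilons, which is one way to decide whether a single tolerance should cover both dtypes.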

tests/comm/test_mnnvl_moe_alltoall.py (3)

293-293: Direct modification of class variable _WORKSPACE is concerning.

Setting MoeAlltoAll._WORKSPACE = None directly before instantiation suggests test isolation concerns. This should be documented or handled via a proper reset method.

Consider whether _reset_workspace() method from MoeAlltoAll should be used instead, or if this pattern is intentional for test setup.


800-813: Good documentation of tolerance rationale.

The comment on line 809 clearly explains why a 99% match threshold is used instead of exact comparison due to bf16 accumulation order differences. This is helpful for future maintainers.


836-838: Helpful run instructions in docstring.

The comment showing how to run with mpirun is useful for developers unfamiliar with MPI testing.

flashinfer/comm/trtllm_moe_alltoall.py (3)

353-383: Workspace caching strategy looks correct.

The caching by (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens) tuple allows reusing workspaces across instances with compatible configurations. This addresses the past review comment about supporting different shaped communicators.
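
The caching scheme described here can be sketched as below. This is a simplified stand-in rather than the `MoeAlltoAll` code: `WorkspaceCache` is a hypothetical class, and a `bytearray` takes the place of the real MNNVL-backed allocation.

```python
from typing import ClassVar


class WorkspaceCache:
    """Cache one workspace per (size, ep_rank, ep_size, max_num_tokens) configuration."""

    _CACHE: ClassVar[dict] = {}

    @classmethod
    def get_workspace(cls, workspace_size_per_rank: int, ep_rank: int,
                      ep_size: int, max_num_tokens: int) -> dict:
        key = (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens)
        if key not in cls._CACHE:
            # Placeholder allocation; the real workspace is device memory.
            cls._CACHE[key] = {"buffer": bytearray(workspace_size_per_rank), "key": key}
        return cls._CACHE[key]


first = WorkspaceCache.get_workspace(1024, 0, 2, 16)
again = WorkspaceCache.get_workspace(1024, 0, 2, 16)   # same configuration, cache hit
other = WorkspaceCache.get_workspace(2048, 0, 2, 16)   # different size, new entry
```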


470-482: _reset_workspace method deletes from class cache without thread safety.

If multiple threads could access this class simultaneously, the del operation on _WORKSPACE_CACHE could cause issues. Document that this method is not thread-safe.

     def _reset_workspace(self):
-        """Reset the workspace to free up its state. This is mainly used for testing. Use this with caution. This object is no longer usable after this."""
+        """Reset the workspace to free up its state.
+        
+        Warning: This method is not thread-safe and is mainly used for testing.
+        This object is no longer usable after calling this method.
+        """
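
If thread safety were wanted rather than only documented, one option is to guard the cache with a lock. The sketch below is an assumption about how that could look, using a hypothetical `LockedCache` class; it is not part of the PR.

```python
import threading


class LockedCache:
    """Dictionary-backed cache whose reads, inserts, and deletes share one lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._entries = {}

    def get_or_create(self, key, factory):
        """Return the cached value for key, creating it under the lock if missing."""
        with self._lock:
            if key not in self._entries:
                self._entries[key] = factory()
            return self._entries[key]

    def reset(self, key) -> bool:
        """Remove key if present; return whether anything was removed."""
        with self._lock:
            if key in self._entries:
                del self._entries[key]
                return True
            return False


cache = LockedCache()
value = cache.get_or_create("workspace", dict)
```

Holding the lock while the factory runs keeps the sketch simple, at the cost of serializing allocations.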

505-508: Good use of state machine pattern for dispatch/combine sequencing.

The phase checking prevents calling dispatch twice without combine and ensures proper operation ordering. This is a clean design.
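
A minimal sketch of this kind of phase checking, assuming just two phases; `Phase` and `A2ASequencer` are hypothetical names, and the real class tracks more state than this.

```python
from enum import Enum


class Phase(Enum):
    IDLE = "idle"
    DISPATCHED = "dispatched"


class A2ASequencer:
    """Enforce that dispatch and combine strictly alternate."""

    def __init__(self):
        self.phase = Phase.IDLE

    def dispatch(self):
        if self.phase is not Phase.IDLE:
            raise RuntimeError("dispatch called twice without an intervening combine")
        self.phase = Phase.DISPATCHED

    def combine(self):
        if self.phase is not Phase.DISPATCHED:
            raise RuntimeError("combine called without a preceding dispatch")
        self.phase = Phase.IDLE   # ready for the next dispatch/combine cycle


seq = A2ASequencer()
seq.dispatch()
seq.combine()
```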

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (4)

17-19: Static analysis reports missing cuda_bf16.h - this is a false positive.

The cuda_bf16.h and cuda_fp16.h headers are provided by the CUDA toolkit and will be available during compilation with nvcc. Static analysis tools without CUDA environment cannot find these headers.


23-27: Configuration constants are well-documented and reasonable.

The limits (256 experts, 8 top-k, 8 payloads, 64 ranks) provide good flexibility while keeping fixed-size arrays manageable. Consider whether these should be configurable at runtime if larger deployments are anticipated.


173-179: Function declarations are clean and match the implementation.

The kernel launch function declarations align with the implementations shown in the relevant code snippets from moeAlltoAllKernels.cu.


148-148: Unable to verify include configuration due to repository access failure.

The repository clone failed, preventing me from examining the file's include structure, verifying whether nvinfer1::DataType is actually used, or confirming if the necessary headers are already present. Manual verification is required to confirm:

  1. Whether csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h currently includes NvInfer headers
  2. Whether nvinfer1::DataType is actually declared in the file or included transitively
  3. Whether the code compiles successfully without the suggested include

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (3)
flashinfer/comm/trtllm_moe_alltoall.py (3)

379-379: Annotate mutable class attribute with ClassVar.

Per Python best practices, mutable class attributes should be annotated with ClassVar to make clear they are shared across instances.

+from typing import ClassVar
+
 class MoeAlltoAll:
     ...
-    _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
+    _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}

638-638: Consider using torch.finfo or torch.iinfo for element size.

Creating an empty tensor just to get element size has minor overhead. Consider using dtype introspection directly.

-        element_size = torch.tensor([], dtype=dtype).element_size()
+        element_size = torch._utils._element_size(dtype)

Alternatively, keep the current approach if you prefer avoiding private APIs.


649-656: Consider adding moe_a2a_wrap_payload_tensor_in_workspace to __all__.

This function is used in tests and appears to be part of the public API. Also consider sorting __all__ for consistency.

 __all__ = [
     "MoeAlltoAll",
+    "moe_a2a_combine",
+    "moe_a2a_dispatch",
+    "moe_a2a_get_workspace_size_per_rank",
     "moe_a2a_initialize",
-    "moe_a2a_dispatch",
-    "moe_a2a_combine",
     "moe_a2a_sanitize_expert_ids",
-    "moe_a2a_get_workspace_size_per_rank",
+    "moe_a2a_wrap_payload_tensor_in_workspace",
 ]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71bb8fb and 222a2e8.

📒 Files selected for processing (6)
  • csrc/trtllm_moe_alltoall.cu (1 hunks)
  • flashinfer/aot.py (1 hunks)
  • flashinfer/comm/trtllm_moe_alltoall.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/comm.py (1 hunks)
  • tests/comm/test_trtllm_moe_alltoall.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/aot.py
  • flashinfer/jit/comm.py
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
  • Tensor (282-284)
  • get_current_stream (266-270)
  • encode_dlpack_dtype (29-31)
flashinfer/comm/trtllm_moe_alltoall.py (6)
  • moe_a2a_get_workspace_size_per_rank (175-198)
  • moe_a2a_get_workspace_size_per_rank (350-361)
  • moe_a2a_initialize (41-47)
  • moe_a2a_initialize (210-218)
  • moe_a2a_dispatch (53-93)
  • moe_a2a_dispatch (251-309)
tests/comm/test_trtllm_moe_alltoall.py (7)
flashinfer/comm/mapping.py (1)
  • Mapping (21-475)
tests/test_helpers/test_helpers.py (1)
  • get_device_properties (10-11)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • num_experts (263-263)
flashinfer/comm/trtllm_moe_alltoall.py (13)
  • moe_a2a_get_workspace_size_per_rank (175-198)
  • moe_a2a_get_workspace_size_per_rank (350-361)
  • MoeAlltoAll (364-646)
  • dispatch (512-569)
  • get_combine_payload_tensor_in_workspace (613-646)
  • combine (571-611)
  • moe_a2a_initialize (41-47)
  • moe_a2a_initialize (210-218)
  • moe_a2a_dispatch (53-93)
  • moe_a2a_dispatch (251-309)
  • moe_a2a_sanitize_expert_ids (146-155)
  • moe_a2a_sanitize_expert_ids (338-347)
  • moe_a2a_wrap_payload_tensor_in_workspace (221-248)
flashinfer/fused_moe/utils.py (1)
  • _ (157-163)
csrc/xqa/mha.cu (1)
  • any (157-157)
tests/comm/test_mnnvl_moe_alltoall.py (1)
  • fake_moe (121-181)
🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py

379-379: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


471-471: Avoid specifying long messages outside the exception class

(TRY003)


473-473: Avoid specifying long messages outside the exception class

(TRY003)


634-636: Avoid specifying long messages outside the exception class

(TRY003)


649-656: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🔇 Additional comments (16)
flashinfer/jit/__init__.py (1)

80-80: LGTM!

The new import follows the established pattern for re-exporting JIT module generators from the comm submodule.

csrc/trtllm_moe_alltoall.cu (5)

53-88: LGTM!

The offset calculation logic correctly aligns data structures to cache-line boundaries where needed, following a clear sequential layout pattern.


101-123: LGTM!

The initialization correctly validates inputs, zeros the workspace region, and returns metadata offsets. The stream synchronization ensures the workspace is properly initialized before returning.


278-360: LGTM!

The combine operation has thorough input validation, properly handles the workspace-backed payload case, and includes appropriate error checking after kernel launch.


362-394: LGTM!

The sanitization operation correctly validates inputs and launches the kernel with proper error checking.


396-419: LGTM!

The metainfo index pairs helper provides a clean mechanism to expose C++ constants to Python, and all required functions are properly exported.

tests/comm/test_trtllm_moe_alltoall.py (6)

74-86: LGTM!

The payload generator correctly handles both integer and floating-point dtypes for test data generation.


93-161: LGTM!

Comprehensive single-GPU test covering multiple payload dtypes, dispatch/combine workflow, and workspace-backed tensor operations.


164-240: LGTM!

The helper correctly simulates multi-rank dispatch on a single GPU using separate CUDA streams, with proper synchronization.


302-344: LGTM!

The multi-rank test correctly validates token routing across simulated ranks with proper verification of payload delivery.


390-431: LGTM!

The reference MoE implementation provides a deterministic baseline for verifying combine correctness, with appropriate handling of expert-parallel scenarios.


434-551: LGTM!

Comprehensive combine test covering multiple dtypes, workspace configurations, and ranks with appropriate numerical tolerances for reduced-precision arithmetic.

flashinfer/comm/trtllm_moe_alltoall.py (4)

32-207: LGTM!

The JIT module getter follows the established pattern with proper caching and custom op registration.


221-248: LGTM!

The function correctly creates a workspace-backed tensor view with properly documented parameters.


470-473: LGTM!

The validation logic is appropriate and the exception messages are concise.


498-510: LGTM!

The reset method appropriately handles workspace cleanup for testing scenarios, with clear documentation about post-call state.

@djns99 djns99 force-pushed the djns99/update-trtllm-kernels branch from 6e9bed5 to a51b1ea Compare November 28, 2025 02:42
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (6)
csrc/trtllm_moe_alltoall.cu (1)

263-276: Consider extending dtype support for future flexibility.

The toNvDataType function currently supports half, bfloat16, and float32. Consider documenting supported types or adding int8/fp8 support if those are common in MoE workloads.

tests/comm/test_mnnvl_moe_alltoall.py (3)

711-712: Unused variable is intentional; consider underscore prefix per Ruff hint.

The expert_id_payload_index is returned by the helper but not used in this test. Consider renaming to _expert_id_payload_index to signal intentional discard.

-    payloads, expert_id_payload_index = make_bfloat16_payloads(
+    payloads, _expert_id_payload_index = make_bfloat16_payloads(

293-294: Setting _WORKSPACE = None does not reset the class-level cache.

Assigning MoeAlltoAll._WORKSPACE = None creates a new class attribute that nothing reads; it does not reset any existing state. The class caches workspaces in _WORKSPACE_CACHE, so the assignment has no effect and is misleading.

Consider removing this line or using MoeAlltoAll._WORKSPACE_CACHE.clear() if the intent is to reset the cache:

-    MoeAlltoAll._WORKSPACE = None
+    MoeAlltoAll._WORKSPACE_CACHE.clear()
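
The difference between the two statements can be shown with a stand-in class. `Communicator` below is hypothetical and only mimics the caching attribute name from the PR.

```python
class Communicator:
    """Stand-in with a class-level cache, mirroring the attribute name in the PR."""

    _WORKSPACE_CACHE = {}

    @classmethod
    def get_workspace(cls, key):
        return cls._WORKSPACE_CACHE.setdefault(key, {"key": key})


key = (1024, 0, 2, 16)
first = Communicator.get_workspace(key)

# Creates an unrelated class attribute; the cache is untouched.
Communicator._WORKSPACE = None
still_cached = Communicator.get_workspace(key) is first

# Clearing the cache is what actually forces a fresh workspace.
Communicator._WORKSPACE_CACHE.clear()
fresh = Communicator.get_workspace(key)
```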

742-742: Same issue: _WORKSPACE = None assignment is ineffective.

This line also sets a non-existent attribute. Consider removing or using _WORKSPACE_CACHE.clear().

-    MoeAlltoAll._WORKSPACE = None
flashinfer/comm/trtllm_moe_alltoall.py (2)

375-377: Annotate mutable class attribute with ClassVar per Ruff hint.

The _WORKSPACE_CACHE is a mutable class-level attribute that should be annotated with ClassVar to make the intent clear.

+from typing import ClassVar
+
 class MoeAlltoAll:
     ...
-    _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
+    _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}

496-508: Document that _reset_workspace invalidates the instance.

The docstring mentions this but it's critical: after calling _reset_workspace, the object is unusable. Consider adding a stronger warning or raising an exception on subsequent method calls.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d375afe and db22fce.

📒 Files selected for processing (4)
  • csrc/trtllm_moe_alltoall.cu (1 hunks)
  • flashinfer/comm/trtllm_moe_alltoall.py (1 hunks)
  • tests/comm/test_mnnvl_moe_alltoall.py (1 hunks)
  • tests/comm/test_trtllm_moe_alltoall.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
  • Tensor (282-284)
  • get_current_stream (266-270)
  • encode_dlpack_dtype (29-31)
flashinfer/comm/trtllm_moe_alltoall.py (11)
  • moe_a2a_get_workspace_size_per_rank (173-196)
  • moe_a2a_get_workspace_size_per_rank (348-359)
  • moe_a2a_initialize (39-45)
  • moe_a2a_initialize (208-216)
  • moe_a2a_dispatch (51-91)
  • moe_a2a_dispatch (249-307)
  • moe_a2a_combine (97-138)
  • moe_a2a_combine (310-333)
  • moe_a2a_sanitize_expert_ids (144-153)
  • moe_a2a_sanitize_expert_ids (336-345)
  • moe_a2a_get_metainfo_index_pairs (159-167)
flashinfer/comm/trtllm_moe_alltoall.py (3)
flashinfer/comm/mnnvl.py (5)
  • MnnvlMemory (232-551)
  • MnnvlConfig (224-229)
  • as_torch_strided_tensor (264-273)
  • initialize (276-285)
  • set_comm_from_config (288-293)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • moe_ep_rank (349-350)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_moe_alltoall_module (83-109)
tests/comm/test_mnnvl_moe_alltoall.py (3)
flashinfer/comm/trtllm_moe_alltoall.py (4)
  • MoeAlltoAll (362-644)
  • dispatch (510-567)
  • get_combine_payload_tensor_in_workspace (611-644)
  • combine (569-609)
flashinfer/comm/mapping.py (1)
  • Mapping (21-475)
flashinfer/comm/mnnvl.py (3)
  • MnnvlMemory (232-551)
  • initialize (276-285)
  • supports_mnnvl (545-551)
🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py

377-377: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


469-469: Avoid specifying long messages outside the exception class

(TRY003)


471-471: Avoid specifying long messages outside the exception class

(TRY003)


632-634: Avoid specifying long messages outside the exception class

(TRY003)

tests/comm/test_mnnvl_moe_alltoall.py

34-34: Avoid specifying long messages outside the exception class

(TRY003)


576-576: Do not catch blind exception: Exception

(BLE001)


677-677: Do not catch blind exception: Exception

(BLE001)


711-711: Unpacked variable expert_id_payload_index is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (44)
csrc/trtllm_moe_alltoall.cu (8)

1-51: LGTM! Well-structured header and utility functions.

The license header, includes, and utility functions are appropriately organized. The alignOffset function correctly implements cache-line alignment using bitwise operations.


53-88: LGTM! Offset calculation is well-structured.

The calculateOffsets function properly calculates memory offsets for various MoE A2A data structures with appropriate alignment for cache-line boundaries.


90-99: LGTM! Workspace size calculation.

The workspace size calculation correctly accounts for metadata, payload, and combine regions with proper alignment.


101-123: LGTM! Initialize operation with proper validation.

Good input validation for workspace dimensions, rank bounds, and proper error checking for CUDA operations. The synchronization before returning metainfo is appropriate.


125-261: LGTM! Dispatch operation is well-implemented.

The dispatch function has comprehensive input validation, proper payload descriptor handling, and correct workspace pointer arithmetic. Error checking after kernel launch is appropriate.


278-360: LGTM! Combine operation with proper validation.

The combine function correctly validates payload dimensions, workspace pointer alignment, and handles the payloadInWorkspace flag appropriately. Error checking after kernel launch is proper.


362-394: LGTM! Sanitize operation is correctly implemented.

Proper input validation and error checking for the sanitize expert IDs kernel.


396-419: LGTM! Metainfo export and FFI registration.

The metainfo index pairs function and TVM FFI exports are correctly implemented, providing clean Python interoperability.

tests/comm/test_trtllm_moe_alltoall.py (12)

25-29: LGTM! Docstring has been corrected.

The fixture docstring now accurately describes that it sets the torch seed for deterministic tests.


32-60: Good test parameter coverage.

The test parameters cover a good range of configurations (small, medium, large) for both single-GPU and multi-rank scenarios, with various dtypes and payload configurations.


63-72: Good resource-aware skip logic.

The SM count check appropriately skips tests when hardware resources are insufficient, preventing false failures on less capable GPUs.


74-86: LGTM! Payload generation helper.

The make_payload function correctly handles both integer and floating-point dtypes with appropriate random value generation.


89-162: Comprehensive single-GPU test with proper verification.

The test covers dispatch and combine flows with multiple dtypes, validates output via sorting and exact comparison, and tests the workspace-backed combine path.


164-240: LGTM! Multi-rank dispatch helper is well-structured.

The helper properly manages workspaces, initializes per-rank metadata, uses separate CUDA streams for parallel execution, and synchronizes appropriately.


243-259: LGTM! Sanitize helper function.

Simple and correct delegation to the underlying sanitize function for each rank.


262-299: LGTM! Combine helper with parallel execution.

The combine helper correctly uses separate streams per rank and synchronizes before returning results.


302-345: LGTM! Multi-rank test with proper verification.

Good verification logic that filters non-zero tensors and compares sorted outputs against the reference filtered by expert assignment.


347-388: LGTM! Sanitize test with comprehensive verification.

The test properly clones tensors before sanitization to enable before/after comparison and correctly verifies the sanitization logic.


390-431: LGTM! Reference MoE implementation for verification.

The fake_moe function provides a clear reference implementation for verifying the distributed MoE behavior, with proper EP-rank filtering logic.


434-555: Good end-to-end combine test with tolerance handling.

The test covers the full dispatch-process-combine cycle with both in-workspace and external payload paths. The tolerance values for bf16 are reasonable.

tests/comm/test_mnnvl_moe_alltoall.py (11)

27-46: MPI error handling utilities are well-designed.

The MPIExit exception, check_any_rank_failed, and safe_run pattern provide robust MPI coordination for test failures across ranks, ensuring clean error propagation.


49-52: LGTM! Test fixture for deterministic seeding.


55-88: LGTM! Helper functions for expert routing and token generation.

compute_target_rank_id correctly implements contiguous expert partitioning, and generate_token_selected_experts properly generates random expert assignments.
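
Contiguous expert partitioning, as referred to here, can be sketched like this. The function below is a hypothetical illustration that assumes the expert count divides evenly across ranks; it is not the test helper itself.

```python
def compute_target_rank(expert_id: int, num_experts: int, ep_size: int) -> int:
    """Map an expert to the rank owning it when experts are split into contiguous blocks."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size in this sketch")
    if not 0 <= expert_id < num_experts:
        raise ValueError("expert_id out of range")
    experts_per_rank = num_experts // ep_size
    return expert_id // experts_per_rank


# 8 experts over 4 ranks: experts 0-1 on rank 0, 2-3 on rank 1, and so on.
owners = [compute_target_rank(e, num_experts=8, ep_size=4) for e in range(8)]
```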


91-119: LGTM! Expert weight creation with reproducible seeding.

Using ep_rank * 1000 + i as a seed ensures reproducibility across runs while differentiating experts per rank.
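
The seeding scheme can be illustrated with the standard library. This sketch uses `random.Random` in place of the torch generator the test presumably uses, and `make_expert_weights` is a hypothetical helper; note that the scheme only keeps seeds distinct while each rank has fewer than 1000 local experts.

```python
import random


def expert_seed(ep_rank: int, local_expert_index: int) -> int:
    """Seed formula quoted in the review: unique while local_expert_index < 1000."""
    return ep_rank * 1000 + local_expert_index


def make_expert_weights(ep_rank: int, num_local_experts: int, width: int = 4):
    """Generate one reproducible weight vector per local expert."""
    weights = []
    for i in range(num_local_experts):
        rng = random.Random(expert_seed(ep_rank, i))   # independent stream per expert
        weights.append([rng.uniform(-1.0, 1.0) for _ in range(width)])
    return weights


run_a = make_expert_weights(ep_rank=1, num_local_experts=2)
run_b = make_expert_weights(ep_rank=1, num_local_experts=2)
seeds = {expert_seed(r, i) for r in range(4) for i in range(8)}
```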


122-182: LGTM! Comprehensive fake MoE reference implementation.

The function correctly handles both EP-rank and global modes with proper local expert ID conversion.


185-258: LGTM! Payload creation helpers.

Both NV FP4 and BFloat16 payload creators are well-structured with appropriate rank-specific patterns for verification.


261-383: LGTM! Single-rank dispatch worker function.

Comprehensive workspace setup, dispatch execution, and metadata extraction for MPI-based testing.


386-556: LGTM! Thorough dispatch verification.

The verify_dispatch function provides exhaustive validation of shapes, dtypes, counters, routing, and payload content. This is excellent for catching regressions.


572-577: Broad Exception catch is intentional for the MNNVL availability check.

The broad exception catch here is acceptable since it's used to detect MNNVL support availability across various failure modes (driver issues, missing hardware, etc.).


673-678: Broad Exception catch is acceptable for capability detection.

Same as above - this is intentional for gracefully skipping tests on systems without MNNVL support.


657-818: LGTM! Full dispatch+combine cycle test.

The test properly verifies the complete MoE A2A workflow with appropriate tolerance for bf16 accumulation order differences. The 99% match threshold is reasonable given the expected numerical variations.
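
A match-fraction check of this kind might look like the sketch below. The tolerances and the helper names are assumptions for illustration; only the 99% threshold comes from the comment above.

```python
def count_matches(actual, expected, rtol: float = 1e-2, atol: float = 1e-2):
    """Return (hits, total) where a hit satisfies |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    hits = sum(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))
    return hits, len(actual)


def passes_threshold(hits: int, total: int, percent: int = 99) -> bool:
    """Integer comparison avoids floating-point edge cases at exactly 99%."""
    return total > 0 and hits * 100 >= percent * total


expected = [1.0] * 100
actual = [1.0] * 99 + [2.0]          # one outlier out of 100 elements
hits, total = count_matches(actual, expected)
ok = passes_threshold(hits, total)
```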

flashinfer/comm/trtllm_moe_alltoall.py (13)

1-19: LGTM! Module header and imports are well-organized.

Clean module docstring and appropriate imports for the MoE A2A functionality.


21-28: LGTM! State dataclass is appropriate.

The _A2AState dataclass cleanly tracks the dispatch/combine lifecycle with appropriate fields.


30-205: LGTM! JIT module initialization with custom op registration.

The lazy module loading with @functools.cache and custom op registration is well-structured. The returned SimpleNamespace provides a clean API surface.


208-246: LGTM! Public wrapper functions are clean delegations.

The top-level moe_a2a_* functions provide clean interfaces to the JIT module, with appropriate docstrings where needed.


249-307: LGTM! Dispatch wrapper with tensor wrapping.

The dispatch function correctly wraps the raw offsets into workspace-backed tensors for each payload.


310-359: LGTM! Combine, sanitize, and workspace size wrappers.

Clean delegation to the underlying JIT module.


379-409: LGTM! Workspace caching with proper key management.

The get_workspace classmethod correctly caches workspaces by configuration tuple, preventing redundant allocations.


411-432: LGTM! Lazy metainfo constant initialization.

The _init_constants method properly strips prefixes for a cleaner Python API.


434-494: LGTM! Constructor with proper validation and MNNVL configuration.

Good input validation for top_k and num_experts, with optional MnnvlConfig support as discussed in past reviews.


510-567: LGTM! Dispatch method with proper state management.

Good state assertions, lifecycle tracking, and optional sanitization flow.


569-609: LGTM! Combine method with state reset.

Proper state validation and reset after combine completes, enabling the next dispatch/combine cycle.


611-645: LGTM! Workspace-backed tensor accessor.

The get_combine_payload_tensor_in_workspace method correctly computes slice bounds and validates state.


647-654: LGTM! Clean __all__ export list.

Explicitly defines the public API surface.

@aleozlx
Collaborator

aleozlx commented Dec 3, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !174 has been created, and the CI pipeline #39564968 is currently running. I'll report back once the pipeline job completes.

@djns99 djns99 requested a review from kahyunnam as a code owner December 3, 2025 22:22
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)

748-786: Block size inconsistency between prepare and combine.

Line 749 hardcodes kBlockSize = 256 in moe_a2a_prepare_combine_launch, while moe_a2a_combine_launch (line 800) uses the configurable getEnvMoeA2ACombineBlockSize(). This creates an inconsistency if the environment variable TLLM_MOE_A2A_COMBINE_BLOCK_SIZE is set to a different value.

Apply this diff to use the same configurable block size:

 void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) {
-  constexpr int kBlockSize = 256;
-  constexpr int kWarpsPerBlock = kBlockSize / 32;  // 8 warps per block
+  int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
+  int const kWarpsPerBlock = kBlockSize / 32;
🧹 Nitpick comments (5)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)

95-102: Consider enhancing API documentation.

The documentation for the new MoE A2A configuration functions could be more complete:

  • Environment variable names are not specified (e.g., TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN, TLLM_MOE_A2A_DISPATCH_BLOCK_SIZE, TLLM_MOE_A2A_COMBINE_BLOCK_SIZE)
  • The accepted range (values up to 1024, defaulting to 256) and the alignment to multiples of 32 for the block size functions are not documented
  • The sanitization behavior (clamping and rounding) is not mentioned

Adding these details would help users understand the expected behavior without consulting the implementation.

csrc/nv_internal/cpp/common/envUtils.cpp (1)

335-345: Unreachable code at line 343.

After the rounding operation on line 342, block can never be 0. By the time the rounding runs, block is in the range [1, 1024]: unset or non-positive values have been replaced by 256, and larger values have been clamped to 1024. The formula (block + 31) / 32 * 32 rounds up to a multiple of 32, so it yields at least 32 for any block >= 1 and the check on line 343 can never fire.

Consider removing the dead code:

 static int sanitizeBlockSize(std::optional<int32_t> const& val) {
   // Default 256 when not set or invalid
   int block = val.value_or(256);
   // Clamp to sane CUDA bounds and warp multiples
   if (block <= 0) block = 256;
   if (block > 1024) block = 1024;
   // Round to nearest multiple of 32 (warp size)
   block = (block + 31) / 32 * 32;
-  if (block == 0) block = 256;
   return block;
 }
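
The unreachable-branch claim can be checked with a small Python model of the function. This is a transliteration of the logic quoted above for illustration only (with the dead check omitted), not the C++ source.

```python
from typing import Optional


def sanitize_block_size(val: Optional[int]) -> int:
    """Python model of sanitizeBlockSize without the final block == 0 check."""
    block = 256 if val is None else val   # default when the variable is unset
    if block <= 0:
        block = 256                        # invalid values fall back to the default
    if block > 1024:
        block = 1024                       # CUDA upper bound on threads per block
    return (block + 31) // 32 * 32         # round up to a multiple of the warp size


# Sweep a range of inputs and collect every possible result.
results = {sanitize_block_size(v) for v in range(-2048, 4097)}
results.add(sanitize_block_size(None))
```

The sweep never produces 0, and the smallest result is 32. It also shows that the rounding goes up rather than to the nearest multiple: an input of 33 gives 64.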
csrc/trtllm_moe_alltoall.cu (1)

101-123: Consider removing synchronous stream wait in initialization.

The cudaStreamSynchronize on line 119 blocks the host until the memset completes. This may be intentional to ensure the workspace is zeroed before returning, but it could impact performance if called frequently. If the caller can guarantee proper stream ordering, this sync could be deferred.

Is the synchronous wait required for correctness, or can the caller rely on stream ordering for subsequent operations?

flashinfer/comm/trtllm_moe_alltoall.py (2)

377-377: Annotate mutable class attribute with ClassVar.

The static analysis tool correctly identifies that _WORKSPACE_CACHE should be annotated with typing.ClassVar to indicate it's a class-level variable.

Apply this diff:

+from typing import ClassVar
+
 class MoeAlltoAll:
     # Single shared workspace across the process
     # _WORKSPACE: Optional[dict] = None
-    _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
+    _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}

468-471: Consider extracting validation logic to a helper method.

The exception messages for top_k and num_experts validation are specified inline. While not a critical issue, extracting this to a validation helper would improve maintainability.

This is a minor code organization suggestion. The current implementation is acceptable.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db22fce and 9509058.

📒 Files selected for processing (17)
  • csrc/nv_internal/cpp/common/envUtils.cpp (2 hunks)
  • csrc/nv_internal/tensorrt_llm/common/envUtils.h (2 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1 hunks)
  • csrc/trtllm_moe_alltoall.cu (1 hunks)
  • docs/api/comm.rst (1 hunks)
  • flashinfer/aot.py (1 hunks)
  • flashinfer/comm/__init__.py (1 hunks)
  • flashinfer/comm/trtllm_moe_alltoall.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/comm.py (1 hunks)
  • scripts/task_test_multi_node_comm_kernels.sh (1 hunks)
  • scripts/task_test_single_node_comm_kernels.sh (1 hunks)
  • tests/comm/test_mnnvl_memory.py (1 hunks)
  • tests/comm/test_mnnvl_moe_alltoall.py (1 hunks)
  • tests/comm/test_trtllm_moe_alltoall.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/aot.py
  • flashinfer/jit/comm.py
  • scripts/task_test_multi_node_comm_kernels.sh
🧰 Additional context used
🧬 Code graph analysis (8)
tests/comm/test_mnnvl_memory.py (1)
flashinfer/comm/mapping.py (1)
  • local_rank (391-392)
csrc/nv_internal/cpp/common/envUtils.cpp (1)
include/flashinfer/trtllm/common.h (1)
  • getBoolEnv (195-198)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
csrc/nv_internal/cpp/common/envUtils.cpp (8)
  • getEnvKVCacheTimeOutputPath (275-278)
  • getEnvKVCacheTimeOutputPath (275-275)
  • getEnvMoeA2AOneBlockPerToken (326-333)
  • getEnvMoeA2AOneBlockPerToken (326-326)
  • getEnvMoeA2ADispatchBlockSize (347-350)
  • getEnvMoeA2ADispatchBlockSize (347-347)
  • getEnvMoeA2ACombineBlockSize (352-355)
  • getEnvMoeA2ACombineBlockSize (352-352)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (4)
  • getEnvMoeA2ADispatchBlockSize (347-350)
  • getEnvMoeA2ADispatchBlockSize (347-347)
  • getEnvMoeA2ACombineBlockSize (352-355)
  • getEnvMoeA2ACombineBlockSize (352-352)
flashinfer/comm/trtllm_moe_alltoall.py (4)
flashinfer/comm/mnnvl.py (5)
  • MnnvlMemory (232-551)
  • MnnvlConfig (224-229)
  • as_torch_strided_tensor (264-273)
  • initialize (276-285)
  • set_comm_from_config (288-293)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • moe_ep_rank (349-350)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_moe_alltoall_module (83-109)
include/flashinfer/trtllm/fused_moe/runner.h (2)
  • num_experts (263-263)
  • hidden_size (265-265)
tests/comm/test_mnnvl_moe_alltoall.py (3)
flashinfer/comm/trtllm_moe_alltoall.py (4)
  • MoeAlltoAll (362-644)
  • dispatch (510-567)
  • get_combine_payload_tensor_in_workspace (611-644)
  • combine (569-609)
flashinfer/comm/mapping.py (1)
  • Mapping (21-475)
flashinfer/comm/mnnvl.py (3)
  • MnnvlMemory (232-551)
  • initialize (276-285)
  • supports_mnnvl (545-551)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (3)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
  • tensorrt_llm (23-104)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
  • mnnvl_throughput (25-58)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (10)
  • moe_a2a_dispatch_launch (445-506)
  • moe_a2a_dispatch_launch (445-445)
  • moe_a2a_prepare_dispatch_launch (436-439)
  • moe_a2a_prepare_dispatch_launch (436-436)
  • moe_a2a_combine_launch (792-842)
  • moe_a2a_combine_launch (792-792)
  • moe_a2a_prepare_combine_launch (748-786)
  • moe_a2a_prepare_combine_launch (748-748)
  • moe_a2a_sanitize_expert_ids_launch (864-872)
  • moe_a2a_sanitize_expert_ids_launch (864-866)
flashinfer/comm/__init__.py (1)
flashinfer/comm/trtllm_moe_alltoall.py (12)
  • MoeAlltoAll (362-644)
  • moe_a2a_combine (97-138)
  • moe_a2a_combine (310-333)
  • moe_a2a_dispatch (51-91)
  • moe_a2a_dispatch (249-307)
  • moe_a2a_initialize (39-45)
  • moe_a2a_initialize (208-216)
  • moe_a2a_get_workspace_size_per_rank (173-196)
  • moe_a2a_get_workspace_size_per_rank (348-359)
  • moe_a2a_sanitize_expert_ids (144-153)
  • moe_a2a_sanitize_expert_ids (336-345)
  • moe_a2a_wrap_payload_tensor_in_workspace (219-246)
🪛 Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

[error] 18-18: 'cuda_bf16.h' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h

[error] 19-19: 'array' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.7)
flashinfer/comm/trtllm_moe_alltoall.py

377-377: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


469-469: Avoid specifying long messages outside the exception class

(TRY003)


471-471: Avoid specifying long messages outside the exception class

(TRY003)


632-634: Avoid specifying long messages outside the exception class

(TRY003)

tests/comm/test_mnnvl_moe_alltoall.py

34-34: Avoid specifying long messages outside the exception class

(TRY003)


576-576: Do not catch blind exception: Exception

(BLE001)


677-677: Do not catch blind exception: Exception

(BLE001)


711-711: Unpacked variable expert_id_payload_index is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (30)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (2)

98-98: Clarify the intent of the TODO comment before merging to main.

The TODO comment indicates these APIs are "For DEV purpose temporarily." If these are indeed temporary development features, consider whether they should be merged to the main branch or if the comment should be updated to reflect their production readiness.


67-67: Verify that all call sites have been updated for the function rename and return type change.

The function name changed from getEnvKVCacheTransferOutputPath to getEnvKVCacheTimeOutputPath, and the return type changed from std::string to std::string const&. Ensure no stale references to the old function name remain in the codebase and that all callers have been updated accordingly, as this is a breaking API change.

scripts/task_test_single_node_comm_kernels.sh (1)

25-25: LGTM!

The new test invocation follows the existing pattern and is appropriately placed with other TRT-LLM MOE-related tests.

csrc/nv_internal/cpp/common/envUtils.cpp (2)

326-333: LGTM!

The default-true behavior with explicit opt-out via TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN=0 is correctly implemented using getIntEnv to distinguish "not set" from "set to 0".


347-358: LGTM!

The block size getters and EPLB force GDRcopy helper are correctly implemented with thread-safe static initialization.

csrc/trtllm_moe_alltoall.cu (4)

125-261: LGTM!

The dispatch operation has thorough input validation, proper error checking after kernel launch, and correctly populates the params structure. The alignment error message typo noted in past review has been fixed.


278-360: LGTM!

The combine operation correctly validates inputs, handles the in-workspace vs external payload distinction, and properly checks for kernel launch errors.


362-410: LGTM!

The sanitize operation and metainfo index pair helper are correctly implemented with proper validation and error handling.


263-276: LGTM!

The dtype conversion correctly handles the supported floating-point types. The return after throw is unreachable but satisfies the compiler's control flow analysis.

csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (2)

17-43: LGTM!

The header correctly defines the metainfo index enum with sequential values, the type alias for the offsets array, and includes all necessary headers. The static analysis hint about <array> not being found is a false positive—the include is present on line 19 and the codebase compiles with proper include paths.


45-58: LGTM!

The index pairs function correctly exposes all metainfo field indices with descriptive names for Python interoperability.

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (4)

17-27: LGTM!

The header guards and configuration constants are appropriately defined. The static analysis hint about cuda_bf16.h not being found is a false positive—the include is present and the file compiles correctly with proper CUDA include paths.


30-76: LGTM!

The PayloadDescriptor, DispatchKernelPointers, and CombineKernelPointers structs are well-designed with appropriate const-correctness (dispatch writes to recv buffers, combine reads from them) and clear documentation.


122-179: LGTM!

The function declarations correctly match the implementations (as seen in the relevant code snippets from moeAlltoAllKernels.cu). Parameters are passed by const reference appropriately.


148-148: Verify nvinfer1::DataType availability.

The header uses nvinfer1::DataType but doesn't include the TensorRT header that defines it. This relies on transitive includes from consumers, which can be fragile and may cause compilation issues if the header is included in isolation.

tests/comm/test_mnnvl_memory.py (1)

125-125: LGTM! Correct device selection for multi-node scenarios.

The change from self.rank to self.local_rank ensures proper CUDA device selection in multi-node environments where the global rank may exceed the local GPU count. This aligns with the local_rank calculation at line 46.
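A rough sketch of why the global rank is not a safe device index: assuming ranks are packed contiguously onto nodes with the same number of GPUs each (an assumption, since the calculation at line 46 is not shown here), the local rank is the global rank modulo the per-node GPU count. The helper name is hypothetical.

```python
def local_rank_of(global_rank: int, gpus_per_node: int) -> int:
    """Map a global rank to a per-node device index (assumes contiguous packing)."""
    return global_rank % gpus_per_node


# Two nodes with 4 GPUs each: global ranks 4..7 live on the second node.
device_indices = [local_rank_of(r, 4) for r in range(8)]
```

Global rank 5 would be an invalid device index on a 4-GPU node, while its local rank 1 is valid.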

flashinfer/jit/__init__.py (1)

80-80: LGTM! Proper API surface exposure.

The new import follows the established pattern and correctly exposes the MNNVL MoE A2A module generator as part of the public JIT API.

docs/api/comm.rst (1)

132-147: LGTM! Comprehensive API documentation.

The new MNNVL A2A section properly documents all eight public API symbols. The previously flagged missing function moe_a2a_get_workspace_size_per_rank is now included (line 146), addressing past review feedback.

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)

111-116: Verify timeout calculation is appropriate.

The timeout check multiplies 300 × 2000 MHz to derive a cycle threshold. The comment mentions "should be high enough on any GPU," but this assumes a specific clock frequency that may not hold across all GPU architectures. Consider using a time-based timeout API if available, or document the assumption more clearly.

Based on learnings, the developer has flagged this for manual review (line 115 comment).
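The sensitivity to clock frequency can be checked with simple arithmetic. The sketch below assumes the 300 is a duration in seconds and that the cycle counter advances at the SM clock rate; neither is confirmed by the excerpt, and the function name is hypothetical.

```python
# Assumed reading of the kernel constant: 300 seconds at a nominal 2000 MHz clock.
NOMINAL_SECONDS = 300
NOMINAL_MHZ = 2000
threshold_cycles = NOMINAL_SECONDS * NOMINAL_MHZ * 1_000_000


def effective_timeout_seconds(actual_mhz: int) -> float:
    """Wall-clock time needed to reach the fixed cycle threshold at a given clock."""
    return threshold_cycles / (actual_mhz * 1_000_000)


timeouts = {mhz: effective_timeout_seconds(mhz) for mhz in (1000, 2000, 3000)}
```

Under these assumptions a GPU clocked at 1000 MHz would wait 600 seconds and one at 3000 MHz would wait 200 seconds, so the threshold stays in the range of minutes either way.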

tests/comm/test_trtllm_moe_alltoall.py (3)

25-29: LGTM! Docstring accurately describes the fixture.

The docstring has been corrected from the previous "warm up JIT compilation" to "Set up torch seed for deterministic tests," which accurately reflects what the fixture does.


63-72: Helper function enhances test reliability.

The check_sufficient_sm_count function is a good practice for avoiding resource contention when running multiple kernels on a single GPU. This prevents hangs and improves test reliability.


390-432: Comprehensive reference implementation for verification.

The fake_moe function provides a clear reference implementation that processes tokens through experts. This is essential for validating the correctness of the A2A dispatch and combine operations. The EP-aware logic (lines 404-420) correctly handles expert parallelism scenarios.

flashinfer/comm/trtllm_moe_alltoall.py (4)

30-33: Good use of caching for JIT module.

Using @functools.cache ensures the JIT module is built only once and reused, improving performance.
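For readers unfamiliar with the pattern, a minimal sketch of build-once caching with functools.cache follows. The function get_module and the build_log list are hypothetical; the placeholder object stands in for the real JIT build.

```python
import functools

build_log = []  # records how many times the expensive build actually ran


@functools.cache
def get_module():
    """Stand-in for a JIT build; the body runs only on the first call."""
    build_log.append("built")
    return object()  # placeholder for the loaded module


first = get_module()
second = get_module()
```

The second call returns the cached object without re-running the body.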


219-246: LGTM! Docstring has been updated.

The docstring now correctly describes the actual parameters (workspace, leading_shape, slice_start, slice_end, dtype) instead of the stale parameters flagged in the previous review.


379-409: Workspace caching mechanism appears sound.

The caching strategy uses a composite key of (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens) to support different communicator shapes simultaneously. This addresses the developer's concern at line 377 about testing multiple shapes concurrently. The implementation correctly initializes MNNVL memory and metainfo once per unique configuration.

Based on learnings, the developer requested validation of the cache logic.


611-644: Zero-copy workspace tensor allocation is well-designed.

The get_combine_payload_tensor_in_workspace method provides a workspace-backed tensor for combine payloads, enabling zero-copy writes by expert processing. The state check at line 631 ensures proper phase ordering (dispatch before workspace allocation).

Based on learnings, the developer flagged this API for review (line 348).

tests/comm/test_mnnvl_moe_alltoall.py (4)

27-34: Good MPI safety mechanism.

The MPIExit exception and check_any_rank_failed function provide a clean way to handle failures across MPI ranks, ensuring that all ranks terminate gracefully if any rank fails.


55-68: Well-documented helper for expert-to-rank mapping.

The compute_target_rank_id function includes clear documentation with examples, making the contiguous partitioning strategy easy to understand. This mirrors the CUDA implementation.
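A minimal sketch of contiguous partitioning is shown below, assuming num_experts is evenly divisible by ep_size. The body of the actual helper is not visible in this review, so the formula and the function name here are illustrative assumptions.

```python
def target_rank_for_expert(expert_id: int, num_experts: int, ep_size: int) -> int:
    """Contiguous partitioning: each rank owns one consecutive block of experts."""
    experts_per_rank = num_experts // ep_size  # assumes even divisibility
    return expert_id // experts_per_rank


# 8 experts over 4 ranks: experts 0-1 -> rank 0, 2-3 -> rank 1, and so on.
owners = [target_rank_for_expert(e, num_experts=8, ep_size=4) for e in range(8)]
```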


386-556: Comprehensive verification of dispatch correctness.

The verify_dispatch function thoroughly validates:

  • Tensor dimensions and dtypes
  • Send/receive counter consistency across ranks
  • Payload content integrity using routing indices
  • Invalid token sanitization

This provides strong confidence in the dispatch implementation's correctness.


657-814: Full end-to-end validation with appropriate tolerance.

The moe_a2a_dispatch_moe_combine_test_impl function validates the complete dispatch → expert processing → combine pipeline. The match threshold approach (lines 802-814) is pragmatic for handling accumulation order differences in bfloat16, allowing up to 1% mismatches with a clear explanation.
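The match-threshold idea can be sketched in pure Python: count the elements that fall outside an absolute-plus-relative tolerance and accept the result if the mismatching fraction stays at or below 1%. The function name and the tolerance values are illustrative assumptions, not those of the test.

```python
def mismatch_fraction(actual, expected, atol, rtol):
    """Fraction of elements where |a - e| exceeds atol + rtol * |e|."""
    bad = sum(
        1 for a, e in zip(actual, expected)
        if abs(a - e) > atol + rtol * abs(e)
    )
    return bad / len(expected)


expected = [1.0] * 200
actual = list(expected)
actual[7] = 1.5  # one outlier, e.g. from a different accumulation order

fraction = mismatch_fraction(actual, expected, atol=1.5e-2, rtol=1.5e-2)
passes = fraction <= 0.01  # tolerate up to 1% mismatching elements
```

One outlier in 200 elements is a 0.5% mismatch rate, so this example passes the 1% threshold.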

@djns99
Author

djns99 commented Dec 3, 2025

/bot cancel

@djns99
Author

djns99 commented Dec 3, 2025

/bot run

@flashinfer-bot
Collaborator

Unknown Command

Command /bot cancel is not recognized.

Use /bot help for available commands.

@flashinfer-bot
Collaborator

GitLab MR !174 has been updated with latest changes, and the CI pipeline #39568335 is currently running. I'll report back once the pipeline job completes.

@djns99
Author

djns99 commented Dec 4, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !174 has been updated with latest changes, and the CI pipeline #39578645 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
flashinfer/comm/trtllm_moe_alltoall.py (2)

386-388: Mutable class attributes should use ClassVar annotation.

The static analysis correctly identifies that _WORKSPACE_CACHE is a mutable class attribute that should be annotated with ClassVar to satisfy PEP 526.

Apply this diff:

+from typing import ClassVar, Optional
+
 @dataclass
 class _A2AState:
     ...

 class MoeAlltoAll:
-    _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {}
+    _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}

Also apply the same fix to _METAINFO_INDEX at line 466:

-    _METAINFO_INDEX: Optional[dict] = None
+    _METAINFO_INDEX: ClassVar[Optional[dict]] = None

487-495: Use explicit Optional type hints for PEP 484 compliance.

Per PEP 484, a parameter with a = None default should be annotated as Optional[T] or T | None rather than relying on implicit Optional.

     def __init__(
         self,
         mapping: Mapping,
         max_num_tokens: int,
         top_k: int,
         num_experts: int,
-        workspace_size_per_rank: int = None,
-        hidden_size: int = None,
+        workspace_size_per_rank: Optional[int] = None,
+        hidden_size: Optional[int] = None,
         mnnvl_config: Optional[MnnvlConfig] = None,
     ):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b7da0b7 and 13f86a2.

📒 Files selected for processing (3)
  • csrc/trtllm_moe_alltoall.cu (1 hunks)
  • flashinfer/comm/trtllm_moe_alltoall.py (1 hunks)
  • tests/comm/test_trtllm_moe_alltoall.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_moe_alltoall.py (3)
flashinfer/comm/mnnvl.py (5)
  • MnnvlMemory (232-551)
  • MnnvlConfig (224-229)
  • as_torch_strided_tensor (264-273)
  • initialize (276-285)
  • set_comm_from_config (288-293)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • moe_ep_rank (349-350)
flashinfer/jit/comm.py (1)
  • gen_mnnvl_moe_alltoall_module (83-109)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
  • Tensor (282-284)
  • get_current_stream (266-270)
  • encode_dlpack_dtype (29-31)
flashinfer/comm/cuda_ipc.py (1)
  • cudaGetErrorString (146-147)
🪛 Ruff (0.14.7)
flashinfer/comm/trtllm_moe_alltoall.py

388-388: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


493-493: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


494-494: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


533-533: Avoid specifying long messages outside the exception class

(TRY003)


535-535: Avoid specifying long messages outside the exception class

(TRY003)


696-698: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (29)
csrc/trtllm_moe_alltoall.cu (9)

1-51: LGTM - License header and includes are appropriate.

The file has proper Apache 2.0 license headers, includes necessary TVM FFI headers, and sets up appropriate namespace aliases for the TRT-LLM throughput kernels.


53-88: Offset calculation logic is sound.

The sequential offset calculation with cache-line alignment is correct. The alignment formula (offset + alignment - 1) & ~(alignment - 1) is standard.

Minor consideration: the multiplications like static_cast<size_t>(maxNumTokens) * tl_throughput::kMaxTopK * kInt32Bytes could theoretically overflow for extremely large values, but given typical MoE configurations (maxNumTokens in thousands, kMaxTopK small), this is unlikely to be a practical concern.
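The alignment formula can be exercised directly. The sketch below applies it in Python and lays out a few regions sequentially; the 128-byte alignment matches the cache-line alignment mentioned elsewhere in this review, while the region names and sizes are purely hypothetical.

```python
def align_up(offset: int, alignment: int) -> int:
    """Round offset up to the next multiple of alignment (a power of two)."""
    return (offset + alignment - 1) & ~(alignment - 1)


def layout(region_sizes: dict, alignment: int = 128) -> tuple:
    """Assign each region an aligned start offset; return (offsets, total size)."""
    offsets = {}
    cursor = 0
    for name, size in region_sizes.items():
        cursor = align_up(cursor, alignment)
        offsets[name] = cursor
        cursor += size
    return offsets, align_up(cursor, alignment)


# Hypothetical regions, in bytes.
offsets, total = layout({"counters": 20, "flags": 130, "indices": 4})
```

Each region starts on a 128-byte boundary regardless of the size of the region before it.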


95-117: Initialization logic is correct with proper synchronization.

The function properly validates inputs, zeros the workspace for the specified rank, and returns a CPU tensor containing the offset metadata. The cudaStreamSynchronize ensures the memset completes before the metainfo is used.


119-186: Dispatch input validation is thorough.

The function properly validates:

  • Expert tensor shape and type
  • Payload count limits against kMaxPayloads
  • Individual payload shapes matching local token count
  • Metainfo format and workspace dimensions
  • Alignment requirements for payload buffers

The error message at line 185 has been corrected per the past review.


188-255: Dispatch parameter setup and kernel launch are well-structured.

The workspace pointer arithmetic is correct, using stride(0) for the rank dimension. The kernel launch is properly followed by error checking via cudaGetLastError(). The return tuple provides absolute offsets that can be used directly with the workspace tensor.


257-270: Data type conversion handles expected types.

The function correctly maps DLPack dtypes to nvinfer1 types for the three supported floating-point formats. The unreachable return after the throw is a reasonable way to silence missing-return compiler warnings.


272-354: Combine operation has proper validation and error handling.

The function correctly:

  • Validates 3D payload shape expectations
  • Ensures the combine payload region fits within workspace bounds
  • Verifies pointer alignment when payload_in_workspace is true
  • Checks kernel launch errors

The prepare_payload being null when payload_in_workspace is true correctly signals to the kernel that data is already in place.


356-388: Sanitize expert IDs operation is straightforward.

The function correctly infers dimensions from the expert_ids tensor shape and launches the sanitization kernel with proper error checking.


390-413: Metainfo index pairs and FFI exports are well-organized.

The helper function cleanly converts C++ pairs to TVM-compatible arrays, and the FFI exports follow a consistent naming convention matching the Python API surface.

tests/comm/test_trtllm_moe_alltoall.py (10)

17-32: Setup and fixture are correct.

The docstring has been updated per the past review to accurately describe setting the torch seed for deterministic tests. The pynvml.nvmlInit() call at module level is appropriate for GPU resource checks.


35-63: Test parameters provide good coverage.

The parameterization covers various scales (small to large), different rank configurations (2, 4, 8), both supported floating-point types (bfloat16, float16), and both workspace modes. This should catch most configuration-related issues.


66-89: Helper functions are well-designed.

The check_sufficient_sm_count function is a smart safeguard to prevent potential deadlocks when simulating multiple ranks on a single GPU. The make_payload function correctly handles both integer and floating-point types.


92-168: Single GPU test is comprehensive and correct.

The test validates the complete dispatch-combine cycle:

  • Creates payloads with multiple dtypes (bfloat16, float16, int32, uint8)
  • Verifies that dispatch output matches input after accounting for shuffling
  • Tests the in-workspace combine path

The x[0].numel() idiom at line 121 correctly computes the number of elements per row.


171-247: Multi-rank dispatch simulation is well-structured.

The helper correctly:

  • Calculates workspace size accounting for all ranks
  • Initializes metainfo for each rank before synchronization
  • Uses separate CUDA streams to simulate parallel rank execution
  • Properly synchronizes all streams before returning

This is a clever approach to test multi-rank behavior on a single GPU.


250-306: Sanitize and combine helpers follow the same pattern as dispatch.

Both helpers correctly use parallel CUDA streams with proper synchronization to simulate multi-rank execution.


309-394: Multi-rank dispatch and sanitize tests are thorough.

The tests validate:

  • Tokens are correctly routed to their target ranks
  • Non-matching tokens are properly filtered (via any() check)
  • Sanitization correctly marks invalid expert IDs

397-438: Fake MoE reference implementation is correct.

The function properly:

  • Scales hidden states by a deterministic factor based on expert ID
  • Filters by EP rank when in expert-parallel mode
  • Uses float32 for accumulation before casting back to the original dtype

This provides a reliable reference for validating the combine operation.


441-558: Combine test validates end-to-end correctness with appropriate tolerance.

The test exercises the full dispatch → process → combine pipeline with the fake_moe reference. The tolerances (atol=1.5e-2, rtol=1.5e-2) are appropriate for 16-bit floating point accumulation across multiple experts.


561-614: Workspace size test validates consistency between APIs.

The test confirms that:

  • get_moe_workspace_size_per_rank returns consistent values across calling methods
  • The aux data size calculation is correct
  • The data region size scales linearly with token count

Good coverage of the sizing API surface.

flashinfer/comm/trtllm_moe_alltoall.py (10)

21-27: State tracking dataclass is appropriate.

The _A2AState dataclass correctly tracks the operation phase and stores state needed between dispatch and combine calls.


30-196: JIT module initialization with custom op registration is well-structured.

The use of @functools.cache ensures single initialization, and each operation is properly registered as a custom op for graph compilation support. The docstrings accurately describe the parameters and return values.


210-237: Workspace tensor wrapping is correct with updated docstring.

The function correctly:

  • Views the workspace as uint8 bytes
  • Validates the slice bounds
  • Creates a properly typed and shaped view

The docstring has been updated per the past review.
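The three steps listed above can be approximated without torch using memoryview. This is a loose analogy under stated assumptions: wrap_payload is a hypothetical name, a bytearray stands in for the workspace tensor, and a struct format character stands in for the dtype.

```python
import struct


def wrap_payload(workspace, slice_start, slice_end, leading_shape, fmt="f"):
    """Return a typed, shaped, zero-copy view over workspace[slice_start:slice_end]."""
    if not (0 <= slice_start <= slice_end <= len(workspace)):
        raise ValueError("slice is out of workspace bounds")
    raw = memoryview(workspace)[slice_start:slice_end]  # byte view, no copy
    itemsize = struct.calcsize(fmt)
    lead = 1
    for dim in leading_shape:
        lead *= dim
    if len(raw) % (itemsize * lead) != 0:
        raise ValueError("slice size does not match dtype and leading shape")
    last_dim = len(raw) // (itemsize * lead)
    return raw.cast(fmt, shape=[*leading_shape, last_dim])


workspace = bytearray(64)
view = wrap_payload(workspace, 16, 40, leading_shape=(2,), fmt="f")
# Writing into the underlying bytes is visible through the view (zero-copy).
struct.pack_into("f", workspace, 16, 1.5)
first_value = view.tolist()[0][0]
```

The 24-byte slice becomes a 2 x 3 view of 4-byte floats, and the write to the raw buffer shows up in the view.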


240-298: Dispatch wrapper correctly wraps output payloads.

The function calls the C++ dispatch and then wraps each output payload as a workspace-backed tensor with the correct shape and dtype. The use of strict=True in zip ensures payload count consistency.


339-370: Workspace size calculation correctly mirrors C++ implementation.

The 128-byte alignment padding matches the kCachelineAlignment used in the C++ implementation, ensuring consistent sizing between Python and C++.


560-572: Reset workspace method is appropriately guarded.

The _reset_workspace method is clearly documented as for testing only with appropriate warnings. The state transition to "deleted" prevents accidental reuse.


574-631: Dispatch method has proper phase validation and state management.

The method correctly:

  • Validates the phase is "idle" before dispatch
  • Validates runtime tokens don't exceed configured max
  • Updates state after successful dispatch
  • Optionally sanitizes expert IDs with required parameter validation

633-673: Combine method correctly resets state for next cycle.

The method validates the phase, calls the C++ combine operation, and resets _state to a fresh _A2AState() instance, ready for the next dispatch-combine cycle.
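The dispatch and combine phase checks amount to a two-state machine. A minimal sketch follows; the "idle" phase name appears in the review above, while the "dispatched" name, the class names, and the fields are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class A2AState:
    phase: str = "idle"  # "idle" or "dispatched" (the latter name is assumed)
    local_num_tokens: Optional[int] = None


class PhaseGuardedA2A:  # hypothetical stand-in for the real class
    def __init__(self, max_num_tokens: int):
        self.max_num_tokens = max_num_tokens
        self._state = A2AState()

    def dispatch(self, num_tokens: int) -> None:
        if self._state.phase != "idle":
            raise RuntimeError("dispatch requires the idle phase")
        if num_tokens > self.max_num_tokens:
            raise ValueError("runtime tokens exceed the configured maximum")
        # The real implementation would launch the dispatch kernel here.
        self._state = A2AState(phase="dispatched", local_num_tokens=num_tokens)

    def combine(self) -> None:
        if self._state.phase != "dispatched":
            raise RuntimeError("combine requires a prior dispatch")
        # The real implementation would launch the combine kernel here.
        self._state = A2AState()  # fresh state, ready for the next cycle
```

Calling dispatch twice in a row, or combine before dispatch, raises instead of silently reusing stale state.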


675-708: Zero-copy workspace tensor accessor is well-implemented.

The get_combine_payload_tensor_in_workspace method correctly validates the phase and uses element_size() to calculate the proper byte range for the slice.


711-718: Module exports are complete and correctly ordered.

All public API functions are included in __all__, providing a clean public interface.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #39578645: canceled
