From 058e205f535a1c2ed046170c0edb749be89092c0 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 13 Nov 2025 22:07:29 -0500 Subject: [PATCH 1/9] upd --- ...> flashinfer_cutlass_fused_moe_binding.cu} | 16 +++++ .../tensorrt_llm/deep_gemm/compiler.cuh | 71 +++---------------- flashinfer/fused_moe/core.py | 8 +++ flashinfer/jit/fused_moe.py | 2 +- 4 files changed, 36 insertions(+), 61 deletions(-) rename csrc/fused_moe/cutlass_backend/{flashinfer_cutlass_fused_moe_sm100_binding.cu => flashinfer_cutlass_fused_moe_binding.cu} (99%) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu similarity index 99% rename from csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu rename to csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index 8d996da98e..df6cc16752 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -27,6 +27,7 @@ #include "../../tvm_ffi_utils.h" #include "cutlass_kernel_selector.h" #include "moe_gemm_kernels.h" +#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" @@ -1199,3 +1200,18 @@ tvm::ffi::Module init(DLDataType activation_dtype, DLDataType weight_dtype, DLDa } TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); + +namespace flashinfer { + +void set_deepgemm_jit_include_dirs(tvm::ffi::Array include_dirs) { + std::vector dirs; + for (const auto& dir : include_dirs) { + dirs.push_back(std::filesystem::path(std::string(dir))); + } + deep_gemm::jit::Compiler::setIncludeDirs(dirs); +} + +} // namespace flashinfer + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs, + flashinfer::set_deepgemm_jit_include_dirs); diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index 25ca90927d..90363b3503 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -121,67 +121,13 @@ std::string getNvccCompiler() { return compiler; } -std::vector getJitIncludeDirs() { - static std::vector includeDirs; - if (includeDirs.empty()) { - // Command to execute - try pip first, fallback to uv pip - char const* cmd = - "pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null"; - - // Buffer to store the output - std::array buffer; - std::string result; - -// Open pipe to command -#ifdef _MSC_VER - FILE* pipe = _popen(cmd, "r"); -#else - FILE* pipe = popen(cmd, "r"); -#endif - - if (pipe) { - // Read the output - while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) { - result += buffer.data(); - } - -// Close the pipe -#ifdef _MSC_VER - _pclose(pipe); -#else - pclose(pipe); -#endif - - // Parse the location using regex - // `pip show tensorrt_llm` will output something like: - // Location: /usr/local/lib/python3.12/dist-packages - // Editable project location: /code - std::regex locationRegex("(Location|Editable project location): (.+)"); - - // Find all matches - auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex); - auto match_end = std::sregex_iterator(); - - // Get the number of matches - auto match_count = std::distance(match_begin, match_end); - - if (match_count > 0) { - // Get the last match - auto last_match_iter = match_begin; - std::advance(last_match_iter, match_count - 1); - - // Get the path from the second capture group - std::string location = last_match_iter->str(2); - location.erase(location.find_last_not_of(" \n\r\t") + 1); +inline void setJitIncludeDirs(std::vector const& dirs) { + static std::vector& includeDirs = getJitIncludeDirs(); + includeDirs = dirs; +} - // Set the include directory based on the package location - includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" / - "nv_internal" / "tensorrt_llm"); - } - } else { - TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled."); - } - } +inline std::vector& getJitIncludeDirs() { + static std::vector includeDirs; return includeDirs; } @@ -291,6 +237,11 @@ class Compiler { [[nodiscard]] bool isValid() const { return !includeDirs_.empty(); } + // Set include directories before the singleton is initialized + static void setIncludeDirs(std::vector const& dirs) { + setJitIncludeDirs(dirs); + } + // Build function Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, uint32_t const block_n, uint32_t const block_k, uint32_t const num_groups, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 83f186673b..e9f746a281 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -324,6 +324,14 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa else: raise ValueError(f"Invalid backend: {backend}") + # Set DeepGEMM JIT include directories after module is loaded + from ..jit import env as jit_env + + deepgemm_include_dir = str( + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm" + ) + module.set_deepgemm_jit_include_dirs([deepgemm_include_dir]) + class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict: Dict[ diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 78c19e98ac..037e7f10f5 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -146,7 +146,7 @@ def gen_cutlass_fused_moe_module( jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", jit_env.FLASHINFER_CSRC_DIR - / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu", + / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu", jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu", # Add all generated kernels From b36e1fbf23ebd4bdee57034f324a546f47e667c2 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 13 Nov 2025 22:28:19 -0500 Subject: [PATCH 2/9] upd --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 10 +++++----- csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh | 5 ++++- csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh | 11 ++++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index 90363b3503..03a5dbf4a1 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -121,16 +121,16 @@ std::string getNvccCompiler() { return compiler; } -inline void setJitIncludeDirs(std::vector const& dirs) { - static std::vector& includeDirs = getJitIncludeDirs(); - includeDirs = dirs; -} - inline std::vector& getJitIncludeDirs() { static std::vector includeDirs; return includeDirs; } +inline void setJitIncludeDirs(std::vector const& dirs) { + static std::vector& includeDirs = getJitIncludeDirs(); + includeDirs = dirs; +} + std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, uint32_t const block_n, uint32_t const block_k, uint32_t const num_groups, uint32_t const num_stages, diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh index 25c47eb8f6..a2f5dd3701 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -39,7 +39,9 @@ } while (0) // Helper function to check CUDA driver errors -#define CHECK_CUDA(call) \ +// Only define if not already defined (avoid conflicts with tvm_ffi_utils.h) +#ifndef CHECK_CUDA_DRIVER +#define CHECK_CUDA_DRIVER(call) \ do { \ CUresult result = call; \ if (result != CUDA_SUCCESS) { \ @@ -49,6 +51,7 @@ exit(1); \ } \ } while (0) +#endif namespace deep_gemm::jit { diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh index f4e6ab124e..3a86c6c7b5 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh @@ -65,7 +65,7 @@ class Runtime { ~Runtime() { if (lib_ != nullptr) { - CHECK_CUDA(cuLibraryUnload(lib_)); + CHECK_CUDA_DRIVER(cuLibraryUnload(lib_)); } } @@ -89,17 +89,18 @@ class Runtime { cubin_ = std::vector(std::istreambuf_iterator(cubinFile), {}); } - CHECK_CUDA(cuLibraryLoadData(&lib_, cubin_.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + CHECK_CUDA_DRIVER( + cuLibraryLoadData(&lib_, cubin_.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); unsigned int numKernels = 0; - CHECK_CUDA(cuLibraryGetKernelCount(&numKernels, lib_)); + CHECK_CUDA_DRIVER(cuLibraryGetKernelCount(&numKernels, lib_)); std::vector kernels(numKernels); - CHECK_CUDA(cuLibraryEnumerateKernels(kernels.data(), numKernels, lib_)); + CHECK_CUDA_DRIVER(cuLibraryEnumerateKernels(kernels.data(), numKernels, lib_)); for (auto kernel : kernels) { char const* kernelName; - CHECK_CUDA(cuKernelGetName(&kernelName, kernel)); + CHECK_CUDA_DRIVER(cuKernelGetName(&kernelName, kernel)); std::string kernelNameStr(kernelName); if (kernelNameStr.find("fp8_gemm_kernel") != std::string::npos) { kernel_ = kernel; From 45f680c098b96e30625d2e96a32a3e8b2a819346 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Nov 2025 19:02:45 -0500 Subject: [PATCH 3/9] fix --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index 03a5dbf4a1..efe977e801 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -235,7 +235,7 @@ class Compiler { return instance; } - [[nodiscard]] bool isValid() const { return !includeDirs_.empty(); } + [[nodiscard]] bool isValid() const { return !getJitIncludeDirs().empty(); } // Set include directories before the singleton is initialized static void setIncludeDirs(std::vector const& dirs) { @@ -313,7 +313,7 @@ class Compiler { std::filesystem::create_directories(path); } - for (auto const& dir : includeDirs_) { + for (auto const& dir : getJitIncludeDirs()) { flags.push_back("-I" + dir.string()); } @@ -469,10 +469,8 @@ class Compiler { } private: - std::vector includeDirs_; - // Private constructor for singleton pattern - Compiler() : includeDirs_(getJitIncludeDirs()) { + Compiler() { // Create necessary directories if (kJitUseNvcc || kJitDumpCubin) { std::filesystem::create_directories(getTmpDir()); From 482b3b8fc72214a23646268b8d27912030902af7 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Nov 2025 19:27:46 -0500 Subject: [PATCH 4/9] upd --- .../cutlass_backend/deepgemm_jit_setup.cu | 36 +++++++++++++++++++ .../flashinfer_cutlass_fused_moe_binding.cu | 16 --------- .../tensorrt_llm/deep_gemm/jit_utils.cuh | 5 +-- .../tensorrt_llm/deep_gemm/runtime.cuh | 11 +++--- flashinfer/jit/fused_moe.py | 2 ++ 5 files changed, 44 insertions(+), 26 deletions(-) create mode 100644 csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu diff --git a/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu b/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu new file mode 100644 index 0000000000..0d2bd4d2cb --- /dev/null +++ b/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh" + +namespace flashinfer { + +void set_deepgemm_jit_include_dirs(tvm::ffi::Array include_dirs) { + std::vector dirs; + for (const auto& dir : include_dirs) { + dirs.push_back(std::filesystem::path(std::string(dir))); + } + deep_gemm::jit::Compiler::setIncludeDirs(dirs); +} + +} // namespace flashinfer + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs, + flashinfer::set_deepgemm_jit_include_dirs); diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index df6cc16752..8d996da98e 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -27,7 +27,6 @@ #include "../../tvm_ffi_utils.h" #include "cutlass_kernel_selector.h" #include "moe_gemm_kernels.h" -#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" @@ -1200,18 +1199,3 @@ tvm::ffi::Module init(DLDataType activation_dtype, DLDataType weight_dtype, DLDa } TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); - -namespace flashinfer { - -void set_deepgemm_jit_include_dirs(tvm::ffi::Array include_dirs) { - std::vector dirs; - for (const auto& dir : include_dirs) { - dirs.push_back(std::filesystem::path(std::string(dir))); - } - deep_gemm::jit::Compiler::setIncludeDirs(dirs); -} - -} // namespace flashinfer - -TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs, - flashinfer::set_deepgemm_jit_include_dirs); diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh index a2f5dd3701..25c47eb8f6 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -39,9 +39,7 @@ } while (0) // Helper function to check CUDA driver errors -// Only define if not already defined (avoid conflicts with tvm_ffi_utils.h) -#ifndef CHECK_CUDA_DRIVER -#define CHECK_CUDA_DRIVER(call) \ +#define CHECK_CUDA(call) \ do { \ CUresult result = call; \ if (result != CUDA_SUCCESS) { \ @@ -51,7 +49,6 @@ exit(1); \ } \ } while (0) -#endif namespace deep_gemm::jit { diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh index 3a86c6c7b5..f4e6ab124e 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh @@ -65,7 +65,7 @@ class Runtime { ~Runtime() { if (lib_ != nullptr) { - CHECK_CUDA_DRIVER(cuLibraryUnload(lib_)); + CHECK_CUDA(cuLibraryUnload(lib_)); } } @@ -89,18 +89,17 @@ class Runtime { cubin_ = std::vector(std::istreambuf_iterator(cubinFile), {}); } - CHECK_CUDA_DRIVER( - cuLibraryLoadData(&lib_, cubin_.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + CHECK_CUDA(cuLibraryLoadData(&lib_, cubin_.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); unsigned int numKernels = 0; - CHECK_CUDA_DRIVER(cuLibraryGetKernelCount(&numKernels, lib_)); + CHECK_CUDA(cuLibraryGetKernelCount(&numKernels, lib_)); std::vector kernels(numKernels); - CHECK_CUDA_DRIVER(cuLibraryEnumerateKernels(kernels.data(), numKernels, lib_)); + CHECK_CUDA(cuLibraryEnumerateKernels(kernels.data(), numKernels, lib_)); for (auto kernel : kernels) { char const* kernelName; - CHECK_CUDA_DRIVER(cuKernelGetName(&kernelName, kernel)); + CHECK_CUDA(cuKernelGetName(&kernelName, kernel)); std::string kernelNameStr(kernelName); if (kernelNameStr.find("fp8_gemm_kernel") != std::string::npos) { kernel_ = kernel; diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 037e7f10f5..add4335022 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -148,6 +148,8 @@ def gen_cutlass_fused_moe_module( jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu", jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/cutlass_backend/deepgemm_jit_setup.cu", + jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu", # Add all generated kernels *(output_dir / kernel for kernel in output_dir.rglob("*.generated.cu")), From d0fb9ae836229a477ccb7659ea749a2d9af6095c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 17 Nov 2025 08:20:04 -0500 Subject: [PATCH 5/9] fix header --- csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh index 25c47eb8f6..018cb22a11 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -16,6 +16,7 @@ */ #pragma once +#include #include #include From 90a0149fecb03569b0116ef9b94eaf07d2e0f271 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 17 Nov 2025 10:53:55 -0500 Subject: [PATCH 6/9] add missing headers --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index efe977e801..2dd3ae6a54 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -36,6 +36,9 @@ #include "nvrtc.h" #include "runtime.cuh" #include "scheduler.cuh" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" #ifdef _WIN32 #include From 3b3a25653cc70bbef23b95c0e7f8ec89eb23f5d8 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 30 Nov 2025 11:13:50 -0800 Subject: [PATCH 7/9] fix-inline --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 8 ++++---- csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index 2dd3ae6a54..f8fa6bd5ef 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -47,7 +47,7 @@ namespace deep_gemm::jit { // Generate a unique ID for temporary directories to avoid collisions -std::string generateUniqueId() { +inline std::string generateUniqueId() { // Use current time and random number to generate a unique ID static std::mt19937 gen(std::random_device{}()); static std::uniform_int_distribution<> distrib(0, 999999); @@ -62,7 +62,7 @@ std::string generateUniqueId() { return std::to_string(value) + "_" + std::to_string(random_value); } -std::filesystem::path getDefaultUserDir() { +inline std::filesystem::path getDefaultUserDir() { static std::filesystem::path userDir; if (userDir.empty()) { char const* cacheDir = getenv("TRTLLM_DG_CACHE_DIR"); @@ -94,7 +94,7 @@ inline std::filesystem::path getTmpDir() { return getDefaultUserDir() / "tmp"; } inline std::filesystem::path getCacheDir() { return getDefaultUserDir() / "cache"; } -std::string getNvccCompiler() { +inline std::string getNvccCompiler() { static std::string compiler; if (compiler.empty()) { // Check environment variable @@ -134,7 +134,7 @@ inline void setJitIncludeDirs(std::vector const& dirs) { includeDirs = dirs; } -std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, +inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, uint32_t const block_n, uint32_t const block_k, uint32_t const num_groups, uint32_t const num_stages, uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type, diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh index 018cb22a11..ebbfda0219 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -68,7 +68,7 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha namespace deep_gemm::jit { -std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) { +inline std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) { switch (gemm_type) { case deep_gemm::GemmType::Normal: return std::string("Normal"); @@ -86,9 +86,9 @@ std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) { } } -int div_up(int a, int b) { return (a + b - 1) / b; } +inline int div_up(int a, int b) { return (a + b - 1) / b; } -int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128, +inline int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128, bool swap_ab = false) { if (!swap_ab) { int smem_d = block_m * block_n * 2; @@ -127,14 +127,14 @@ int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = } } -bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) { +inline bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) { if (num_tma_multicast == 1) { return true; } return (n % (block_n * num_tma_multicast) == 0) && num_sms % num_tma_multicast == 0; } -GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, +inline GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, int num_groups, int num_device_sms, bool is_grouped_contiguous = false, bool swap_ab = false) { // Choose candidate block sizes From 061f47c7ebd369bd4e42c0e846b22e2012dd81c5 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 30 Nov 2025 11:14:01 -0800 Subject: [PATCH 8/9] fix-inline --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 10 +++++----- csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index f8fa6bd5ef..fbdc902972 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -134,11 +134,11 @@ inline void setJitIncludeDirs(std::vector const& dirs) { includeDirs = dirs; } -inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, - uint32_t const block_n, uint32_t const block_k, - uint32_t const num_groups, uint32_t const num_stages, - uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type, - bool swapAB = false) { +inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, + uint32_t const block_m, uint32_t const block_n, + uint32_t const block_k, uint32_t const num_groups, + uint32_t const num_stages, uint32_t const num_tma_multicast, + deep_gemm::GemmType const gemm_type, bool swapAB = false) { constexpr uint32_t kNumTMAThreads = 128; constexpr uint32_t kNumMathThreadsPerGroup = 128; diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh index ebbfda0219..8e26486d4d 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -89,7 +89,7 @@ inline std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) { inline int div_up(int a, int b) { return (a + b - 1) / b; } inline int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128, - bool swap_ab = false) { + bool swap_ab = false) { if (!swap_ab) { int smem_d = block_m * block_n * 2; int smem_a_per_stage = block_m * block_k; @@ -135,8 +135,8 @@ inline bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, in } inline GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - int num_groups, int num_device_sms, - bool is_grouped_contiguous = false, bool swap_ab = false) { + int num_groups, int num_device_sms, + bool is_grouped_contiguous = false, bool swap_ab = false) { // Choose candidate block sizes std::vector block_ms; block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 64 : 128); From a61855cdcfc0b506e4a1cf81f5c0a975c7f95491 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 30 Nov 2025 13:14:10 -0800 Subject: [PATCH 9/9] fix-inline --- csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh index f4e6ab124e..f960af632d 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh @@ -181,6 +181,6 @@ class RuntimeCache { }; // Global function to access the singleton -RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); } +inline RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); } } // namespace deep_gemm::jit