diff --git a/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu b/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu
new file mode 100644
index 0000000000..0d2bd4d2cb
--- /dev/null
+++ b/csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <filesystem>
+
+#include <tvm/ffi/container/array.h>
+
+#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh"
+
+namespace flashinfer {
+
+void set_deepgemm_jit_include_dirs(tvm::ffi::Array<tvm::ffi::String> include_dirs) {
+  std::vector<std::filesystem::path> dirs;
+  for (const auto& dir : include_dirs) {
+    dirs.push_back(std::filesystem::path(std::string(dir)));
+  }
+  deep_gemm::jit::Compiler::setIncludeDirs(dirs);
+}
+
+}  // namespace flashinfer
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs,
+                              flashinfer::set_deepgemm_jit_include_dirs);
diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
similarity index 100%
rename from csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
rename to csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
index 25ca90927d..fbdc902972 100644
--- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
+++ 
b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
@@ -36,6 +36,9 @@
 #include "nvrtc.h"
 #include "runtime.cuh"
 #include "scheduler.cuh"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/logger.h"
 
 #ifdef _WIN32
 #include <windows.h>
@@ -44,7 +47,7 @@
 namespace deep_gemm::jit {
 
 // Generate a unique ID for temporary directories to avoid collisions
-std::string generateUniqueId() {
+inline std::string generateUniqueId() {
   // Use current time and random number to generate a unique ID
   static std::mt19937 gen(std::random_device{}());
   static std::uniform_int_distribution<> distrib(0, 999999);
@@ -59,7 +62,7 @@ std::string generateUniqueId() {
   return std::to_string(value) + "_" + std::to_string(random_value);
 }
 
-std::filesystem::path getDefaultUserDir() {
+inline std::filesystem::path getDefaultUserDir() {
   static std::filesystem::path userDir;
   if (userDir.empty()) {
     char const* cacheDir = getenv("TRTLLM_DG_CACHE_DIR");
@@ -91,7 +94,7 @@
 inline std::filesystem::path getTmpDir() { return getDefaultUserDir() / "tmp"; }
 
 inline std::filesystem::path getCacheDir() { return getDefaultUserDir() / "cache"; }
-std::string getNvccCompiler() {
+inline std::string getNvccCompiler() {
   static std::string compiler;
   if (compiler.empty()) {
     // Check environment variable
@@ -121,75 +124,21 @@ std::string getNvccCompiler() {
   return compiler;
 }
 
-std::vector<std::filesystem::path> getJitIncludeDirs() {
+inline std::vector<std::filesystem::path>& getJitIncludeDirs() {
   static std::vector<std::filesystem::path> includeDirs;
-  if (includeDirs.empty()) {
-    // Command to execute - try pip first, fallback to uv pip
-    char const* cmd =
-        "pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null";
-
-    // Buffer to store the output
-    std::array<char, 128> buffer;
-    std::string result;
-
-// Open pipe to command
-#ifdef _MSC_VER
-    FILE* pipe = _popen(cmd, "r");
-#else
-    FILE* pipe = popen(cmd, "r");
-#endif
-
-    if (pipe) {
-      // Read the output
-      while (fgets(buffer.data(), buffer.size(), pipe) != 
nullptr) {
-        result += buffer.data();
-      }
-
-// Close the pipe
-#ifdef _MSC_VER
-      _pclose(pipe);
-#else
-      pclose(pipe);
-#endif
-
-      // Parse the location using regex
-      // `pip show tensorrt_llm` will output something like:
-      // Location: /usr/local/lib/python3.12/dist-packages
-      // Editable project location: /code
-      std::regex locationRegex("(Location|Editable project location): (.+)");
-
-      // Find all matches
-      auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex);
-      auto match_end = std::sregex_iterator();
-
-      // Get the number of matches
-      auto match_count = std::distance(match_begin, match_end);
-
-      if (match_count > 0) {
-        // Get the last match
-        auto last_match_iter = match_begin;
-        std::advance(last_match_iter, match_count - 1);
-
-        // Get the path from the second capture group
-        std::string location = last_match_iter->str(2);
-        location.erase(location.find_last_not_of(" \n\r\t") + 1);
-
-        // Set the include directory based on the package location
-        includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
-                              "nv_internal" / "tensorrt_llm");
-      }
-    } else {
-      TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
-    }
-  }
   return includeDirs;
 }
 
-std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
-                           uint32_t const block_n, uint32_t const block_k,
-                           uint32_t const num_groups, uint32_t const num_stages,
-                           uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type,
-                           bool swapAB = false) {
+inline void setJitIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
+  static std::vector<std::filesystem::path>& includeDirs = getJitIncludeDirs();
+  includeDirs = dirs;
+}
+
+inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k,
+                                  uint32_t const block_m, uint32_t const block_n,
+                                  uint32_t const block_k, uint32_t const num_groups,
+                                  uint32_t const num_stages, uint32_t const num_tma_multicast,
+                                  deep_gemm::GemmType const gemm_type, bool 
swapAB = false) {
   constexpr uint32_t kNumTMAThreads = 128;
   constexpr uint32_t kNumMathThreadsPerGroup = 128;
 
@@ -289,7 +238,12 @@ class Compiler {
     return instance;
   }
 
-  [[nodiscard]] bool isValid() const { return !includeDirs_.empty(); }
+  [[nodiscard]] bool isValid() const { return !getJitIncludeDirs().empty(); }
+
+  // Set include directories before the singleton is initialized
+  static void setIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
+    setJitIncludeDirs(dirs);
+  }
 
   // Build function
   Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
@@ -362,7 +316,7 @@ class Compiler {
       std::filesystem::create_directories(path);
     }
 
-    for (auto const& dir : includeDirs_) {
+    for (auto const& dir : getJitIncludeDirs()) {
      flags.push_back("-I" + dir.string());
    }
 
@@ -518,10 +472,8 @@ class Compiler {
   }
 
  private:
-  std::vector<std::filesystem::path> includeDirs_;
-
   // Private constructor for singleton pattern
-  Compiler() : includeDirs_(getJitIncludeDirs()) {
+  Compiler() {
     // Create necessary directories
     if (kJitUseNvcc || kJitDumpCubin) {
       std::filesystem::create_directories(getTmpDir());
diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh
index 25c47eb8f6..8e26486d4d 100644
--- a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh
+++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh
@@ -16,6 +16,7 @@
  */
 
 #pragma once
+#include <algorithm>
 #include <string>
 #include <vector>
 
@@ -67,7 +68,7 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha
 
 namespace deep_gemm::jit {
 
-std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
+inline std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
   switch (gemm_type) {
     case deep_gemm::GemmType::Normal:
       return std::string("Normal");
@@ -85,10 +86,10 @@
   }
 }
 
-int div_up(int a, int b) { return (a + b - 1) / b; }
+inline int div_up(int a, int b) { return (a + b - 1) / b; }
 
-int 
get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
-                  bool swap_ab = false) {
+inline int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
+                         bool swap_ab = false) {
   if (!swap_ab) {
     int smem_d = block_m * block_n * 2;
     int smem_a_per_stage = block_m * block_k;
@@ -126,16 +127,16 @@ int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k =
   }
 }
 
-bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
+inline bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
   if (num_tma_multicast == 1) {
     return true;
   }
   return (n % (block_n * num_tma_multicast) == 0) && num_sms % num_tma_multicast == 0;
 }
 
-GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
-                                int num_groups, int num_device_sms,
-                                bool is_grouped_contiguous = false, bool swap_ab = false) {
+inline GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
+                                       int num_groups, int num_device_sms,
+                                       bool is_grouped_contiguous = false, bool swap_ab = false) {
   // Choose candidate block sizes
   std::vector<int> block_ms;
   block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 
64 : 128); diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh index f4e6ab124e..f960af632d 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh @@ -181,6 +181,6 @@ class RuntimeCache { }; // Global function to access the singleton -RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); } +inline RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); } } // namespace deep_gemm::jit diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7b53c3f82c..07728eee39 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -328,6 +328,14 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa else: raise ValueError(f"Invalid backend: {backend}") + # Set DeepGEMM JIT include directories after module is loaded + from ..jit import env as jit_env + + deepgemm_include_dir = str( + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm" + ) + module.set_deepgemm_jit_include_dirs([deepgemm_include_dir]) + class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict: Dict[ diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 152d92f161..e890a76681 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -164,7 +164,9 @@ def gen_cutlass_fused_moe_module( jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", jit_env.FLASHINFER_CSRC_DIR - / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu", + / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/cutlass_backend/deepgemm_jit_setup.cu", jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu", # Add all 
generated kernels