Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <tvm/ffi/extra/module.h>

#include <filesystem>

#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh"

namespace flashinfer {

/// @brief FFI entry point: forwards include directories to the DeepGEMM JIT compiler.
///
/// Converts each FFI string in @p include_dirs into a std::filesystem::path and
/// registers the resulting list via deep_gemm::jit::Compiler::setIncludeDirs(),
/// so later JIT compilations can locate the project headers.
///
/// @param include_dirs  Array of directory paths as TVM FFI strings.
void set_deepgemm_jit_include_dirs(tvm::ffi::Array<tvm::ffi::String> include_dirs) {
  std::vector<std::filesystem::path> dirs;
  dirs.reserve(include_dirs.size());  // one allocation; avoids realloc churn in the loop
  for (const auto& dir : include_dirs) {
    // Construct the path in place from the FFI string's std::string conversion.
    dirs.emplace_back(std::string(dir));
  }
  deep_gemm::jit::Compiler::setIncludeDirs(dirs);
}

} // namespace flashinfer

TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs,
flashinfer::set_deepgemm_jit_include_dirs);
98 changes: 25 additions & 73 deletions csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
#include "nvrtc.h"
#include "runtime.cuh"
#include "scheduler.cuh"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

#ifdef _WIN32
#include <windows.h>
Expand All @@ -44,7 +47,7 @@
namespace deep_gemm::jit {

// Generate a unique ID for temporary directories to avoid collisions
std::string generateUniqueId() {
inline std::string generateUniqueId() {
// Use current time and random number to generate a unique ID
static std::mt19937 gen(std::random_device{}());
static std::uniform_int_distribution<> distrib(0, 999999);
Expand All @@ -59,7 +62,7 @@ std::string generateUniqueId() {
return std::to_string(value) + "_" + std::to_string(random_value);
}

std::filesystem::path getDefaultUserDir() {
inline std::filesystem::path getDefaultUserDir() {
static std::filesystem::path userDir;
if (userDir.empty()) {
char const* cacheDir = getenv("TRTLLM_DG_CACHE_DIR");
Expand Down Expand Up @@ -91,7 +94,7 @@ inline std::filesystem::path getTmpDir() { return getDefaultUserDir() / "tmp"; }

inline std::filesystem::path getCacheDir() { return getDefaultUserDir() / "cache"; }

std::string getNvccCompiler() {
inline std::string getNvccCompiler() {
static std::string compiler;
if (compiler.empty()) {
// Check environment variable
Expand Down Expand Up @@ -121,75 +124,21 @@ std::string getNvccCompiler() {
return compiler;
}

std::vector<std::filesystem::path> getJitIncludeDirs() {
inline std::vector<std::filesystem::path>& getJitIncludeDirs() {
static std::vector<std::filesystem::path> includeDirs;
if (includeDirs.empty()) {
// Command to execute - try pip first, fallback to uv pip
char const* cmd =
"pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null";

// Buffer to store the output
std::array<char, 128> buffer;
std::string result;

// Open pipe to command
#ifdef _MSC_VER
FILE* pipe = _popen(cmd, "r");
#else
FILE* pipe = popen(cmd, "r");
#endif

if (pipe) {
// Read the output
while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) {
result += buffer.data();
}

// Close the pipe
#ifdef _MSC_VER
_pclose(pipe);
#else
pclose(pipe);
#endif

// Parse the location using regex
// `pip show tensorrt_llm` will output something like:
// Location: /usr/local/lib/python3.12/dist-packages
// Editable project location: /code
std::regex locationRegex("(Location|Editable project location): (.+)");

// Find all matches
auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex);
auto match_end = std::sregex_iterator();

// Get the number of matches
auto match_count = std::distance(match_begin, match_end);

if (match_count > 0) {
// Get the last match
auto last_match_iter = match_begin;
std::advance(last_match_iter, match_count - 1);

// Get the path from the second capture group
std::string location = last_match_iter->str(2);
location.erase(location.find_last_not_of(" \n\r\t") + 1);

// Set the include directory based on the package location
includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
"nv_internal" / "tensorrt_llm");
}
} else {
TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
}
}
return includeDirs;
}

std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
uint32_t const block_n, uint32_t const block_k,
uint32_t const num_groups, uint32_t const num_stages,
uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type,
bool swapAB = false) {
// Replace the process-wide JIT include-directory list returned by
// getJitIncludeDirs(). Intended to be called before kernels are compiled.
// NOTE(review): plain vector assignment with no synchronization — presumably
// callers set this once during initialization; confirm no concurrent readers.
inline void setJitIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
  // Assign straight through the reference returned by getJitIncludeDirs().
  // The previous `static` local reference was redundant: it always bound to
  // the same function-local static and only obscured the intent.
  getJitIncludeDirs() = dirs;
}

inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k,
uint32_t const block_m, uint32_t const block_n,
uint32_t const block_k, uint32_t const num_groups,
uint32_t const num_stages, uint32_t const num_tma_multicast,
deep_gemm::GemmType const gemm_type, bool swapAB = false) {
constexpr uint32_t kNumTMAThreads = 128;
constexpr uint32_t kNumMathThreadsPerGroup = 128;

Expand Down Expand Up @@ -289,7 +238,12 @@ class Compiler {
return instance;
}

[[nodiscard]] bool isValid() const { return !includeDirs_.empty(); }
[[nodiscard]] bool isValid() const { return !getJitIncludeDirs().empty(); }

// Set include directories before the singleton is initialized
static void setIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
setJitIncludeDirs(dirs);
}

// Build function
Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
Expand Down Expand Up @@ -362,7 +316,7 @@ class Compiler {
std::filesystem::create_directories(path);
}

for (auto const& dir : includeDirs_) {
for (auto const& dir : getJitIncludeDirs()) {
flags.push_back("-I" + dir.string());
}

Expand Down Expand Up @@ -518,10 +472,8 @@ class Compiler {
}

private:
std::vector<std::filesystem::path> includeDirs_;

// Private constructor for singleton pattern
Compiler() : includeDirs_(getJitIncludeDirs()) {
Compiler() {
// Create necessary directories
if (kJitUseNvcc || kJitDumpCubin) {
std::filesystem::create_directories(getTmpDir());
Expand Down
17 changes: 9 additions & 8 deletions csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>

Expand Down Expand Up @@ -67,7 +68,7 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha

namespace deep_gemm::jit {

std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
inline std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
switch (gemm_type) {
case deep_gemm::GemmType::Normal:
return std::string("Normal");
Expand All @@ -85,10 +86,10 @@ std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
}
}

int div_up(int a, int b) { return (a + b - 1) / b; }
inline int div_up(int a, int b) { return (a + b - 1) / b; }

int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
bool swap_ab = false) {
inline int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
bool swap_ab = false) {
if (!swap_ab) {
int smem_d = block_m * block_n * 2;
int smem_a_per_stage = block_m * block_k;
Expand Down Expand Up @@ -126,16 +127,16 @@ int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k =
}
}

bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
// A TMA multicast configuration is legal when it is trivial (factor == 1), or
// when both the n extent and the SM count divide evenly by the multicast factor.
// The short-circuit || preserves the original early-out: the modulo operands are
// never evaluated for the trivial case.
inline bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
  return num_tma_multicast == 1 ||
         (n % (block_n * num_tma_multicast) == 0 && num_sms % num_tma_multicast == 0);
}

GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
int num_groups, int num_device_sms,
bool is_grouped_contiguous = false, bool swap_ab = false) {
inline GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
int num_groups, int num_device_sms,
bool is_grouped_contiguous = false, bool swap_ab = false) {
// Choose candidate block sizes
std::vector<int> block_ms;
block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 64 : 128);
Expand Down
2 changes: 1 addition & 1 deletion csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,6 @@ class RuntimeCache {
};

// Global function to access the singleton
RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); }
inline RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); }

} // namespace deep_gemm::jit
8 changes: 8 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
else:
raise ValueError(f"Invalid backend: {backend}")

# Set DeepGEMM JIT include directories after module is loaded
from ..jit import env as jit_env

deepgemm_include_dir = str(
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm"
)
module.set_deepgemm_jit_include_dirs([deepgemm_include_dir])

class MoERunner(TunableRunner):
# avoid overhead of creating a new runner in forward pass
runner_dict: Dict[
Expand Down
4 changes: 3 additions & 1 deletion flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def gen_cutlass_fused_moe_module(
jit_env.FLASHINFER_CSRC_DIR
/ "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu",
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/deepgemm_jit_setup.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu",
# Add all generated kernels
Expand Down