Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions csrc/fused_moe/cutlass_backend/deepgemm_jit_setup.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <tvm/ffi/extra/module.h>

#include <filesystem>

#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh"

namespace flashinfer {

void set_deepgemm_jit_include_dirs(tvm::ffi::Array<tvm::ffi::String> include_dirs) {
std::vector<std::filesystem::path> dirs;
for (const auto& dir : include_dirs) {
dirs.push_back(std::filesystem::path(std::string(dir)));
}
deep_gemm::jit::Compiler::setIncludeDirs(dirs);
}

} // namespace flashinfer

TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs,
flashinfer::set_deepgemm_jit_include_dirs);
82 changes: 17 additions & 65 deletions csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
#include "nvrtc.h"
#include "runtime.cuh"
#include "scheduler.cuh"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

#ifdef _WIN32
#include <windows.h>
Expand Down Expand Up @@ -121,70 +124,16 @@ std::string getNvccCompiler() {
return compiler;
}

std::vector<std::filesystem::path> getJitIncludeDirs() {
inline std::vector<std::filesystem::path>& getJitIncludeDirs() {
static std::vector<std::filesystem::path> includeDirs;
if (includeDirs.empty()) {
// Command to execute - try pip first, fallback to uv pip
char const* cmd =
"pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null";

// Buffer to store the output
std::array<char, 128> buffer;
std::string result;

// Open pipe to command
#ifdef _MSC_VER
FILE* pipe = _popen(cmd, "r");
#else
FILE* pipe = popen(cmd, "r");
#endif

if (pipe) {
// Read the output
while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) {
result += buffer.data();
}

// Close the pipe
#ifdef _MSC_VER
_pclose(pipe);
#else
pclose(pipe);
#endif

// Parse the location using regex
// `pip show tensorrt_llm` will output something like:
// Location: /usr/local/lib/python3.12/dist-packages
// Editable project location: /code
std::regex locationRegex("(Location|Editable project location): (.+)");

// Find all matches
auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex);
auto match_end = std::sregex_iterator();

// Get the number of matches
auto match_count = std::distance(match_begin, match_end);

if (match_count > 0) {
// Get the last match
auto last_match_iter = match_begin;
std::advance(last_match_iter, match_count - 1);

// Get the path from the second capture group
std::string location = last_match_iter->str(2);
location.erase(location.find_last_not_of(" \n\r\t") + 1);

// Set the include directory based on the package location
includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
"nv_internal" / "tensorrt_llm");
}
} else {
TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
}
}
return includeDirs;
}

inline void setJitIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
static std::vector<std::filesystem::path>& includeDirs = getJitIncludeDirs();
includeDirs = dirs;
}

std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
uint32_t const block_n, uint32_t const block_k,
uint32_t const num_groups, uint32_t const num_stages,
Expand Down Expand Up @@ -289,7 +238,12 @@ class Compiler {
return instance;
}

[[nodiscard]] bool isValid() const { return !includeDirs_.empty(); }
[[nodiscard]] bool isValid() const { return !getJitIncludeDirs().empty(); }

// Set include directories before the singleton is initialized
static void setIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
setJitIncludeDirs(dirs);
}

// Build function
Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
Expand Down Expand Up @@ -362,7 +316,7 @@ class Compiler {
std::filesystem::create_directories(path);
}

for (auto const& dir : includeDirs_) {
for (auto const& dir : getJitIncludeDirs()) {
flags.push_back("-I" + dir.string());
}

Expand Down Expand Up @@ -518,10 +472,8 @@ class Compiler {
}

private:
std::vector<std::filesystem::path> includeDirs_;

// Private constructor for singleton pattern
Compiler() : includeDirs_(getJitIncludeDirs()) {
Compiler() {
// Create necessary directories
if (kJitUseNvcc || kJitDumpCubin) {
std::filesystem::create_directories(getTmpDir());
Expand Down
1 change: 1 addition & 0 deletions csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>

Expand Down
8 changes: 8 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
else:
raise ValueError(f"Invalid backend: {backend}")

# Set DeepGEMM JIT include directories after module is loaded
from ..jit import env as jit_env

deepgemm_include_dir = str(
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm"
)
module.set_deepgemm_jit_include_dirs([deepgemm_include_dir])

class MoERunner(TunableRunner):
# avoid overhead of creating a new runner in forward pass
runner_dict: Dict[
Expand Down
4 changes: 3 additions & 1 deletion flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def gen_cutlass_fused_moe_module(
jit_env.FLASHINFER_CSRC_DIR
/ "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu",
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/deepgemm_jit_setup.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu",
# Add all generated kernels
Expand Down