Commit 3d11a48

Cleanup and streamline
1 parent e7abe89 commit 3d11a48

10 files changed: +55 −135 lines
benchmarks/bench_logging_overhead.py

Lines changed: 2 additions & 19 deletions
@@ -5,11 +5,6 @@
 This script creates decorated and undecorated versions of a test function
 (torch.matmul) and compares their performance to accurately measure logging overhead.

-Why torch.matmul instead of bmm_fp8?
-- bmm_fp8 is already decorated in the FlashInfer source code
-- Using it would cause double-decoration and inaccurate results
-- torch.matmul gives us a clean baseline to measure pure decorator overhead
-
 Usage:
     # Set the logging level before running
     export FLASHINFER_LOGLEVEL_DBG=3
@@ -37,30 +32,18 @@
 LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "/tmp/flashinfer_benchmark_log.txt")

 # Import the decorator
-try:
-    from flashinfer.api_logging import flashinfer_api_log
-except ImportError as e:
-    print(f"Error: Could not import flashinfer: {e}")
-    print("Make sure flashinfer is installed.")
-    exit(1)
+from flashinfer.api_logging import flashinfer_log


 # Create two versions of a test function:
 # 1. Undecorated (baseline)
 # 2. Decorated (with logging)
-#
-# We use a simple torch.matmul instead of bmm_fp8 because bmm_fp8 is already
-# decorated in the source code, which would cause double-decoration.
-
-
 def test_matmul_undecorated(A, B):
-    """Undecorated version - baseline for comparison."""
     return torch.matmul(A, B)


-@flashinfer_api_log
+@flashinfer_log
 def test_matmul_decorated(A, B):
-    """Decorated version - with API logging."""
     return torch.matmul(A, B)

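Note: the script's actual timing loop is not part of this diff. A minimal sketch of how the decorated/undecorated comparison could be made is below; the function and helper names, matrix sizes, and iteration counts are illustrative, not taken from the benchmark.

    import time
    import torch
    from flashinfer.api_logging import flashinfer_log  # as in the patched script

    def matmul_plain(A, B):
        return torch.matmul(A, B)

    @flashinfer_log
    def matmul_logged(A, B):
        return torch.matmul(A, B)

    def avg_seconds(fn, A, B, iters=100):
        """Average wall-clock seconds per call to fn(A, B)."""
        for _ in range(10):  # warm-up: exclude one-time costs
            fn(A, B)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(A, B)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    A = torch.randn(256, 256)
    B = torch.randn(256, 256)
    overhead = avg_seconds(matmul_logged, A, B) - avg_seconds(matmul_plain, A, B)
    print(f"decorator overhead per call: {overhead * 1e6:.2f} us")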

flashinfer/api_logging.py

Lines changed: 3 additions & 66 deletions
@@ -32,8 +32,6 @@ def _substitute_process_id(path: str) -> str:

     This is useful for multi-process/multi-GPU environments where each process
     needs its own log file.
-
-    Example: "flashinfer_log_%i.txt" -> "flashinfer_log_12345.txt"
     """
     if "%i" in path:
         return path.replace("%i", str(os.getpid()))
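For reference, the retained branch above is the whole substitution: a destination such as "flashinfer_log_%i.txt" expands to a per-process file name. A quick illustration (the printed PID is machine-dependent):

    import os

    # Same substitution as shown in the hunk: "%i" -> current process ID.
    dest = "flashinfer_log_%i.txt".replace("%i", str(os.getpid()))
    print(dest)  # e.g. flashinfer_log_12345.txt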
@@ -46,38 +44,6 @@ def _substitute_process_id(path: str) -> str:
     os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout")
 )

-# Enable cuDNN, cuBLAS, and cuBLASLt API logging when FlashInfer logging level >= 5
-# Only override if the user hasn't already configured the logging switch
-# If the switch is not set, we override both the switch and destination as a bundle
-if _API_LOG_LEVEL >= 5:
-    # cuBLAS logging: Check switch, set both switch and destination
-    if "CUBLAS_LOGINFO_DBG" not in os.environ:
-        os.environ["CUBLAS_LOGINFO_DBG"] = "1"
-        os.environ["CUBLAS_LOGDEST_DBG"] = _substitute_process_id(
-            "flashinfer_cublas_log_%i.txt"
-        )
-
-    # cuBLASLt logging: Check switch, set both switch and destination
-    if "CUBLASLT_LOG_LEVEL" not in os.environ:
-        os.environ["CUBLASLT_LOG_LEVEL"] = "2"
-        os.environ["CUBLASLT_LOG_FILE"] = _substitute_process_id(
-            "flashinfer_cublaslt_log_%i.txt"
-        )
-
-    # cuDNN backend logging: Check switch, set both switch and destination
-    if "CUDNN_LOGLEVEL_DBG" not in os.environ:
-        os.environ["CUDNN_LOGLEVEL_DBG"] = "2.5"
-        os.environ["CUDNN_LOGDEST_DBG"] = _substitute_process_id(
-            "flashinfer_cudnn_backend_log_%i.txt"
-        )
-
-    # cuDNN frontend logging: Check switch, set both switch and destination
-    if "CUDNN_FRONTEND_LOG_INFO" not in os.environ:
-        os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"
-        os.environ["CUDNN_FRONTEND_LOG_FILE"] = _substitute_process_id(
-            "flashinfer_cudnn_frontend_log_%i.txt"
-        )
-
 # Create logger using Python's logging library
 _logger = logging.getLogger("flashinfer.api")

@@ -185,28 +151,6 @@ def _log_system_info():
         # PyTorch version
         lines.append(f"PyTorch version: {torch.__version__}")

-        # cuDNN/cuBLAS/cuBLASLt logging status
-        if _API_LOG_LEVEL >= 5:
-            lines.append("")
-            lines.append("cuDNN/cuBLAS/cuBLASLt Logging: Enabled (Level 5)")
-            cublas_info = os.environ.get("CUBLAS_LOGINFO_DBG", "not set")
-            cublas_dest = os.environ.get("CUBLAS_LOGDEST_DBG", "not set")
-            cublaslt_level = os.environ.get("CUBLASLT_LOG_LEVEL", "not set")
-            cublaslt_file = os.environ.get("CUBLASLT_LOG_FILE", "not set")
-            cudnn_level = os.environ.get("CUDNN_LOGLEVEL_DBG", "not set")
-            cudnn_dest = os.environ.get("CUDNN_LOGDEST_DBG", "not set")
-            cudnn_fe_info = os.environ.get("CUDNN_FRONTEND_LOG_INFO", "not set")
-            cudnn_fe_file = os.environ.get("CUDNN_FRONTEND_LOG_FILE", "not set")
-
-            lines.append(f"  CUBLAS_LOGINFO_DBG={cublas_info}")
-            lines.append(f"  CUBLAS_LOGDEST_DBG={cublas_dest}")
-            lines.append(f"  CUBLASLT_LOG_LEVEL={cublaslt_level}")
-            lines.append(f"  CUBLASLT_LOG_FILE={cublaslt_file}")
-            lines.append(f"  CUDNN_LOGLEVEL_DBG={cudnn_level}")
-            lines.append(f"  CUDNN_LOGDEST_DBG={cudnn_dest}")
-            lines.append(f"  CUDNN_FRONTEND_LOG_INFO={cudnn_fe_info}")
-            lines.append(f"  CUDNN_FRONTEND_LOG_FILE={cudnn_fe_file}")
-
     except Exception as e:
         lines.append(f"Error gathering system information: {e}")

@@ -519,7 +463,7 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
     _logger.debug("\n".join(lines))


-def flashinfer_api_log(func: Callable = None) -> Callable:
+def flashinfer_log(func: Callable = None) -> Callable:
     """
     Decorator to log FlashInfer API calls using Python's logging library.

@@ -544,7 +488,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable:
     --------
     Basic usage:

-    >>> @flashinfer_api_log
+    >>> @flashinfer_log
     ... def my_function(x, y):
     ...     return x + y

@@ -563,13 +507,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable:
     - **CUDA Graph Compatibility**: At level 5, tensor statistics (min/max/mean/nan_count)
       are automatically skipped during CUDA graph capture to avoid synchronization issues.
       The message "[statistics skipped: CUDA graph capture in progress]" will be logged.
-    - **cuDNN/cuBLAS/cuBLASLt Integration**: At level 5, if not already set by the user, the following
-      environment variables are automatically configured to enable cuDNN, cuBLAS, and cuBLASLt logging:
-      - CUBLAS_LOGINFO_DBG=1, CUBLAS_LOGDEST_DBG=flashinfer_cublas_log_%i.txt
-      - CUBLASLT_LOG_LEVEL=2, CUBLASLT_LOG_FILE=flashinfer_cublaslt_log_%i.txt
-      - CUDNN_LOGLEVEL_DBG=2.5, CUDNN_LOGDEST_DBG=flashinfer_cudnn_backend_log_%i.txt
-      - CUDNN_FRONTEND_LOG_INFO=1, CUDNN_FRONTEND_LOG_FILE=flashinfer_cudnn_frontend_log_%i.txt
-      The %i pattern is automatically replaced with the process ID for multi-process environments.
+    - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
     """
     # If logging is disabled, return original function with zero overhead
@@ -621,7 +559,6 @@ def wrapper(*args, **kwargs):

         return wrapper

-    # Support both @flashinfer_api_log and @flashinfer_api_log()
     if func is None:
         return decorator
     return decorator(func)
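The hunk above drops the explanatory comment but keeps the dual-form dispatch: the decorator works both as @flashinfer_log and as @flashinfer_log(). A minimal standalone sketch of that pattern is shown below; it is a skeleton only, the real decorator performs the logging inside wrapper.

    import functools
    from typing import Callable

    def flashinfer_log(func: Callable = None) -> Callable:
        # Skeleton of the dual-form decorator pattern.
        def decorator(f: Callable) -> Callable:
            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                # The real decorator logs inputs/outputs here; this sketch just calls through.
                return f(*args, **kwargs)
            return wrapper

        if func is None:
            # Used as @flashinfer_log(): return the decorator itself.
            return decorator
        # Used as @flashinfer_log: decorate immediately.
        return decorator(func)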

flashinfer/cudnn/decode.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@

 import torch

-from ..api_logging import flashinfer_api_log
+from ..api_logging import flashinfer_log
 from .utils import get_cudnn_fmha_gen_module

 try:
@@ -253,7 +253,7 @@ def _batch_decode_with_kv_cache(
     return out


-@flashinfer_api_log
+@flashinfer_log
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,

flashinfer/cudnn/prefill.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@

 import torch

-from ..api_logging import flashinfer_api_log
+from ..api_logging import flashinfer_log
 from .utils import get_cudnn_fmha_gen_module

 try:
@@ -384,7 +384,7 @@ def _batch_prefill_with_kv_cache(
     return out, None


-@flashinfer_api_log
+@flashinfer_log
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,

flashinfer/decode.py

Lines changed: 10 additions & 10 deletions
@@ -21,7 +21,7 @@

 import torch

-from .api_logging import flashinfer_api_log
+from .api_logging import flashinfer_log
 from .xqa import xqa, xqa_mla
 from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .jit import (
@@ -313,7 +313,7 @@ def get_trtllm_gen_fmha_module():
     return op


-@flashinfer_api_log
+@flashinfer_log
 def single_decode_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -390,7 +390,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...


-@flashinfer_api_log
+@flashinfer_log
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -649,7 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
     manages the lifecycle of these data structures.
     """

-    @flashinfer_api_log
+    @flashinfer_log
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
@@ -813,7 +813,7 @@ def reset_workspace_buffer(
             pin_memory=True,
         )

-    @flashinfer_api_log
+    @flashinfer_log
     def plan(
         self,
         indptr: torch.Tensor,
@@ -1167,7 +1167,7 @@ def run(
         window_left: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...

-    @flashinfer_api_log
+    @flashinfer_log
     def run(
         self,
         q: torch.Tensor,
@@ -2065,7 +2065,7 @@ def _fake_paged_run(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2339,7 +2339,7 @@ trtllm_batch_decode_with_kv_cache


 # xqa uses NHD layout
-@flashinfer_api_log
+@flashinfer_log
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2524,7 +2524,7 @@ def _check_trtllm_gen_mla_shape(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -2686,7 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         raise ValueError(f"Backend {backend} not supported")


-@flashinfer_api_log
+@flashinfer_log
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
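Note: every public decode entry point above is wrapped with @flashinfer_log, so enabling the log is purely environment-driven. A rough usage sketch follows; the tensor shapes, dtypes, and log level are illustrative only, and the environment variables must be set before importing flashinfer because api_logging.py reads them at import time.

    import os

    # Configure before importing flashinfer; "%i" in the destination becomes the PID.
    os.environ["FLASHINFER_LOGLEVEL_DBG"] = "3"
    os.environ["FLASHINFER_LOGDEST_DBG"] = "flashinfer_api_log_%i.txt"

    import torch
    import flashinfer

    # Illustrative shapes only (NHD layout); see the API reference for exact requirements.
    q = torch.randn(32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")

    # This call is wrapped by @flashinfer_log, so its inputs and outputs are
    # recorded to the configured destination at the configured level.
    out = flashinfer.single_decode_with_kv_cache(q, k, v)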

flashinfer/fused_moe/core.py

Lines changed: 7 additions & 7 deletions
@@ -20,7 +20,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 import torch

-from ..api_logging import flashinfer_api_log
+from ..api_logging import flashinfer_log
 from ..autotuner import (
     AutoTuner,
     DynamicTensorSpec,
@@ -686,7 +686,7 @@


 # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
-@flashinfer_api_log
+@flashinfer_log
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1859,7 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -1940,7 +1940,7 @@ def trtllm_bf16_moe(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2014,7 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2092,7 +2092,7 @@ def trtllm_fp8_block_scale_moe(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2222,7 +2222,7 @@ def trtllm_fp4_block_scale_moe(
     )


-@flashinfer_api_log
+@flashinfer_log
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
