Commit 5acb57b
feat: Enable API Logging for Better Debugging POC (#2108)
## 📌 Description

tl;dr: This PR adds a logging system that tracks inputs and outputs to aid debugging of FlashInfer APIs via a `@flashinfer_api` decorator.

**This PR does not apply `@flashinfer_api` to every FlashInfer API -- many operations are still undecorated. Further labeling is left for subsequent work.**

This PR introduces a production-ready API logging infrastructure that tracks function calls, arguments, and return values via a simple one-line decorator. Any function can be decorated to have its inputs and outputs recorded by the API logger.

Key features:

* Logging level controlled by `FLASHINFER_LOGLEVEL`
* Log destination set by `FLASHINFER_LOGDEST`; defaults to `stdout`
* Zero overhead when disabled (level 0 returns the original function), as measured by `benchmarks/bench_logging_overhead.py`

Example usage:

```
export FLASHINFER_LOGLEVEL=1
export FLASHINFER_LOGDEST="./flashinfer_api.log"
python3 benchmarks/flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16
```

produces the log

```
================================================================================
[2025-11-20 17:51:18] FlashInfer API Logging - System Information
================================================================================
FlashInfer version: 0.5.2
CUDA toolkit version: 13.0
cuDNN version: 91600
Number of GPUs: 1
GPU 0: NVIDIA B200
  Compute capability: 10.0 (SM100)
PyTorch version: 2.9.0+cu130
================================================================================
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.run
[2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.run
...
```

`export FLASHINFER_LOGLEVEL=3` produces:

```
(System Info same as above)
================================================================================
[2025-11-20 17:51:58] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
--------------------------------------------------------------------------------
Positional input arguments:
  arg[0]: <flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper object at 0x1234399e3410>
  arg[1]: Tensor( shape=(134217728,) stride=(1,) dtype=torch.int8 device=cuda:0 requires_grad=False is_contiguous=True )
  arg[2]: 'HND'
Keyword input arguments:
  use_cuda_graph= True
  use_tensor_cores= False
  paged_kv_indptr_buffer= Tensor( shape=(2,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True )
  paged_kv_indices_buffer= Tensor( shape=(6,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True )
  paged_kv_last_page_len_buffer= Tensor( shape=(1,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True )
  backend= 'fa2'
Default parameters (not explicitly provided):
  jit_args= [DEFAULT] None
Output value:
  None
================================================================================
...
```

`export FLASHINFER_LOGLEVEL=5` produces:

```
(System Info same as above)
================================================================================
[2025-11-20 17:52:23] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
--------------------------------------------------------------------------------
Positional input arguments:
  arg[0]: <flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper object at 0x7a9fd9a88c0>
  arg[1]: Tensor( shape=(134217728,) stride=(1,) dtype=torch.int8 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=0 mean=0.000000 )
  arg[2]: 'HND'
Keyword input arguments:
  use_cuda_graph= True
  use_tensor_cores= False
  paged_kv_indptr_buffer= Tensor( shape=(2,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=6 mean=3.000000 )
  paged_kv_indices_buffer= Tensor( shape=(6,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=5 mean=2.500000 )
  paged_kv_last_page_len_buffer= Tensor( shape=(1,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=4 max=4 mean=4.000000 )
  backend= 'fa2'
Default parameters (not explicitly provided):
  jit_args= [DEFAULT] None
Output value:
  None
================================================================================
...
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added API logging configurable via environment variables (`FLASHINFER_LOGLEVEL` for level control, `FLASHINFER_LOGDEST` for destination)
  * Supports five verbosity levels with function names, inputs, outputs, metadata, and tensor statistics
  * Zero-overhead operation when disabled
* **Tests**
  * Added a comprehensive logging test suite
* **Documentation**
  * Added logging configuration and usage documentation
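As a quick reference, here is a minimal usage sketch of the decorator described above. The decorator name, module path (`flashinfer.api_logging.flashinfer_api`), and environment variables come from this PR; the function `scaled_add` and the log path are made-up examples.

```python
import os

# Configure logging BEFORE importing flashinfer; the benchmark script in this PR
# reads FLASHINFER_LOGLEVEL before the import, suggesting it is consumed at import time.
os.environ["FLASHINFER_LOGLEVEL"] = "3"  # 0=off, 1=names, 3=+inputs/outputs, 5=+tensor stats
os.environ["FLASHINFER_LOGDEST"] = "./flashinfer_api.log"  # stdout (default), stderr, or a file path

import torch
from flashinfer.api_logging import flashinfer_api


@flashinfer_api
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical user function; its name, arguments, and return value are logged."""
    return x + alpha * y


out = scaled_add(torch.ones(4, device="cuda"), torch.ones(4, device="cuda"), alpha=2.0)
```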
1 parent 7128c7b commit 5acb57b

File tree

13 files changed: +1667 −0 lines changed


README.md

Lines changed: 14 additions & 0 deletions
@@ -169,6 +169,20 @@ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill att

Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels.

## API Logging

FlashInfer provides comprehensive API logging for debugging. Enable it using environment variables:

```bash
# Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics)
export FLASHINFER_LOGLEVEL=3

# Set log destination (stdout (default), stderr, or file path)
export FLASHINFER_LOGDEST=stdout
```

For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md).

## Custom Attention Variants

Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py).
benchmarks/bench_logging_overhead.py

Lines changed: 333 additions & 0 deletions

@@ -0,0 +1,333 @@
```python
#!/usr/bin/env python3
"""
Benchmark script to measure the overhead of API logging at different levels.

This script creates decorated and undecorated versions of a test function
(torch.matmul) and compares their performance to accurately measure logging overhead.

Usage:
    # Set the logging level before running
    export FLASHINFER_LOGLEVEL=3
    python bench_logging_overhead.py

    # Or run with different levels
    FLASHINFER_LOGLEVEL=0 python bench_logging_overhead.py
    FLASHINFER_LOGLEVEL=1 python bench_logging_overhead.py
    FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py
    FLASHINFER_LOGLEVEL=5 python bench_logging_overhead.py

    # Or use the helper script to run all levels
    bash benchmark_all_levels.sh
"""

import os
import sys
import time
import torch
import numpy as np
from typing import List, Tuple

# Get logging level BEFORE importing flashinfer
LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt")

# Import the decorator
from flashinfer.api_logging import flashinfer_api


# Create two versions of a test function:
# 1. Undecorated (baseline)
# 2. Decorated (with logging)
def test_matmul_undecorated(A, B):
    return torch.matmul(A, B)


@flashinfer_api
def test_matmul_decorated(A, B):
    return torch.matmul(A, B)


class BenchmarkResults:
    """Store and display benchmark results."""

    def __init__(self):
        self.undecorated_times = []
        self.decorated_times = []

    def set_undecorated(self, times: List[float]):
        """Set benchmark results for undecorated function."""
        self.undecorated_times = times

    def set_decorated(self, times: List[float]):
        """Set benchmark results for decorated function."""
        self.decorated_times = times

    def print_summary(self, logging_level: int):
        """Print a summary of benchmark results."""
        print("\n" + "=" * 80)
        print("BENCHMARK RESULTS")
        print("=" * 80)

        undecorated_mean = np.mean(self.undecorated_times)
        undecorated_std = np.std(self.undecorated_times)

        decorated_mean = np.mean(self.decorated_times)
        decorated_std = np.std(self.decorated_times)

        overhead_abs = (decorated_mean - undecorated_mean) * 1000  # ms
        overhead_pct = (
            ((decorated_mean - undecorated_mean) / undecorated_mean * 100)
            if undecorated_mean > 0
            else 0
        )

        print(
            f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}"
        )
        print("-" * 80)
        print(
            f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}"
        )
        print(
            f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}"
        )

        print("\n" + "=" * 80)
        print("OVERHEAD ANALYSIS")
        print("=" * 80)
        print(f"\nLogging Level: {logging_level}")
        print(f"Absolute overhead: {overhead_abs:.4f} ms")
        print(f"Relative overhead: {overhead_pct:.2f}%")

        print("\n" + "=" * 80)
        print("DETAILED STATISTICS")
        print("=" * 80)

        print("\nUndecorated (baseline):")
        print(f"  Mean:   {undecorated_mean * 1000:.4f} ms")
        print(f"  Median: {np.median(self.undecorated_times) * 1000:.4f} ms")
        print(f"  Std:    {undecorated_std * 1000:.4f} ms")
        print(f"  Min:    {np.min(self.undecorated_times) * 1000:.4f} ms")
        print(f"  Max:    {np.max(self.undecorated_times) * 1000:.4f} ms")

        print("\nDecorated (with logging):")
        print(f"  Mean:   {decorated_mean * 1000:.4f} ms")
        print(f"  Median: {np.median(self.decorated_times) * 1000:.4f} ms")
        print(f"  Std:    {decorated_std * 1000:.4f} ms")
        print(f"  Min:    {np.min(self.decorated_times) * 1000:.4f} ms")
        print(f"  Max:    {np.max(self.decorated_times) * 1000:.4f} ms")


def setup_test_inputs(
    batch_size: int = 32,
    m: int = 512,
    n: int = 512,
    k: int = 512,
    device: str = "cuda:0",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Set up test inputs for matmul.

    Parameters
    ----------
    batch_size : int
        Batch size for the matrix multiplication
    m, n, k : int
        Matrix dimensions
    device : str
        Device to use

    Returns
    -------
    A, B : torch.Tensor
        Input tensors for matrix multiplication
    """
    # Create random tensors
    A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device)
    B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device)

    return A, B


def warmup(func, A, B, num_warmup: int = 10):
    """Warmup the GPU and JIT compilation."""
    for _ in range(num_warmup):
        _ = func(A, B)
    torch.cuda.synchronize()


def benchmark_function(
    func, func_name: str, A, B, num_iterations: int = 100
) -> List[float]:
    """
    Benchmark a specific function.

    Parameters
    ----------
    func : callable
        Function to benchmark
    func_name : str
        Name of the function (for display)
    A, B : torch.Tensor
        Input tensors for matrix multiplication
    num_iterations : int
        Number of iterations to run

    Returns
    -------
    List[float]
        List of execution times in seconds
    """
    print(f"\nBenchmarking: {func_name}")
    print(f"  Running {num_iterations} iterations...")

    times = []

    for _ in range(num_iterations):
        # Synchronize before timing
        torch.cuda.synchronize()

        # Time the execution
        start = time.perf_counter()
        _ = func(A, B)
        torch.cuda.synchronize()
        end = time.perf_counter()

        elapsed = end - start
        times.append(elapsed)

    print(f"  Complete. Mean time: {np.mean(times) * 1000:.4f} ms")

    return times


def main():
    """Main benchmark function."""
    print("=" * 80)
    print("FlashInfer API Logging Overhead Benchmark")
    print("=" * 80)

    # Display logging configuration
    print("\nLogging Configuration:")
    print(f"  FLASHINFER_LOGLEVEL = {LOGGING_LEVEL}")
    print(f"  FLASHINFER_LOGDEST = {LOG_DEST}")

    # Get level name
    level_names = {
        0: "No logging (zero-overhead)",
        1: "Function name only",
        3: "Name + inputs/outputs + metadata",
        5: "Name + inputs/outputs + metadata + statistics",
    }
    print(f"  Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}")

    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("\nError: CUDA is not available. This benchmark requires a CUDA device.")
        exit(1)

    device = "cuda:0"
    print(f"\nDevice: {device}")
    print(f"Device Name: {torch.cuda.get_device_name(device)}")

    # Setup test inputs
    print("\nSetting up test inputs...")
    batch_size = 32
    m, n, k = 128, 128, 128
    print(f"  Batch size: {batch_size}")
    print(f"  Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]")

    A, B = setup_test_inputs(batch_size, m, n, k, device)

    # Benchmark parameters
    num_iterations = 100
    print("\nBenchmark parameters:")
    print(f"  Iterations: {num_iterations}")
    print("  Warmup iterations: 10")

    # Clear log file before starting
    if os.path.exists(LOG_DEST):
        os.remove(LOG_DEST)

    print("\n" + "=" * 80)
    print("WARMUP PHASE")
    print("=" * 80)

    # Warmup undecorated version
    print("\nWarming up undecorated version...")
    warmup(test_matmul_undecorated, A, B, num_warmup=10)
    print("  Complete.")

    # Warmup decorated version
    print("\nWarming up decorated version...")
    warmup(test_matmul_decorated, A, B, num_warmup=10)
    print("  Complete.")

    print("\n" + "=" * 80)
    print("BENCHMARK PHASE")
    print("=" * 80)

    # Store results
    results = BenchmarkResults()

    # Benchmark undecorated version
    undecorated_times = benchmark_function(
        test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations
    )
    results.set_undecorated(undecorated_times)

    # Benchmark decorated version
    decorated_times = benchmark_function(
        test_matmul_decorated,
        f"Decorated (logging level {LOGGING_LEVEL})",
        A,
        B,
        num_iterations,
    )
    results.set_decorated(decorated_times)

    # Print summary
    results.print_summary(LOGGING_LEVEL)

    # Check log file size
    if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST):
        log_size = os.path.getsize(LOG_DEST)
        print("\n" + "=" * 80)
        print("LOG FILE INFO")
        print("=" * 80)
        print(f"Log file: {LOG_DEST}")
        print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)")
        print(f"Iterations logged: {num_iterations}")
        print(f"Bytes per iteration: {log_size / num_iterations:.2f}")

        # Cleanup option
        cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true"
        if cleanup_log:
            os.remove(LOG_DEST)
            print("\n  Log file removed (set CLEANUP_LOG=false to keep it)")
        else:
            print(f"\n  Log file preserved at {LOG_DEST}")

    print("\n" + "=" * 80)
    print("RECOMMENDATIONS")
    print("=" * 80)
    print("\nTo benchmark other levels, run:")
    for level in [0, 1, 3, 5]:
        if level != LOGGING_LEVEL:
            print(f"  FLASHINFER_LOGLEVEL={level} python {sys.argv[0]}")

    print("\n" + "=" * 80)
    print("Benchmark complete!")
    print("=" * 80)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\nBenchmark interrupted by user.")
    except Exception as e:
        print(f"\n\nError during benchmark: {e}")
        import traceback

        traceback.print_exc()
```
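The docstring above references a helper script, `benchmark_all_levels.sh`, which is not shown in this excerpt. A minimal stand-in driver, written here as a Python sketch and assuming only that the benchmark reads `FLASHINFER_LOGLEVEL` from the environment of a fresh process, could look like:

```python
import os
import subprocess
import sys

# Hypothetical driver playing the role of benchmark_all_levels.sh: run the
# benchmark once per logging level in a separate process, so the level is
# picked up before flashinfer is imported.
for level in (0, 1, 3, 5):
    env = dict(os.environ, FLASHINFER_LOGLEVEL=str(level))
    subprocess.run(
        [sys.executable, "benchmarks/bench_logging_overhead.py"],
        env=env,
        check=True,
    )
```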

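The zero-overhead claim for level 0 rests on the decorator returning the original function when logging is disabled, so the decorated and undecorated paths are identical. The sketch below illustrates that dispatch pattern in general terms; it is not FlashInfer's actual `api_logging` implementation.

```python
import functools
import os

_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))


def logging_decorator(func):
    # Level 0: hand back the original function object, so decorated calls have
    # no wrapper frame and therefore no per-call overhead.
    if _LEVEL == 0:
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print(f"[api-log] calling {func.__qualname__}")  # level >= 1: name only
        result = func(*args, **kwargs)
        if _LEVEL >= 3:
            # level >= 3: also record inputs and outputs
            print(f"[api-log] args={args!r} kwargs={kwargs!r} -> {result!r}")
        return result

    return wrapper
```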
docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov
   :caption: Get Started

   installation
   logging

.. toctree::
   :maxdepth: 2