Commit 41ad558

Rename decorator and environment flags
1 parent 8459eb1 commit 41ad558

12 files changed: +91 additions, −93 deletions

LOGGING.md

Lines changed: 8 additions & 8 deletions

````diff
@@ -8,10 +8,10 @@ Enable logging using two environment variables:
 
 ```bash
 # Set logging level (0-5)
-export FLASHINFER_LOGLEVEL_DBG=3
+export FLASHINFER_LOGLEVEL=3
 
 # Set log destination (default is stdout)
-export FLASHINFER_LOGDEST_DBG=stdout  # or stderr, or a file path like "flashinfer.log"
+export FLASHINFER_LOGDEST=stdout  # or stderr, or a file path like "flashinfer.log"
 
 # Run your code
 python train.py
@@ -33,19 +33,19 @@ python train.py
 
 | Variable | Type | Default | Description |
 |----------|------|---------|-------------|
-| `FLASHINFER_LOGLEVEL_DBG` | int | 0 | Logging level (0, 1, 3, 5) |
-| `FLASHINFER_LOGDEST_DBG` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path |
+| `FLASHINFER_LOGLEVEL` | int | 0 | Logging level (0, 1, 3, 5) |
+| `FLASHINFER_LOGDEST` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path |
 
 ### Process ID Substitution
 
 Use `%i` in file paths for automatic process ID substitution (useful for multi-GPU training):
 
 ```bash
-export FLASHINFER_LOGDEST_DBG="flashinfer_log_%i.txt"  # → flashinfer_log_12345.txt
+export FLASHINFER_LOGDEST="flashinfer_log_%i.txt"  # → flashinfer_log_12345.txt
 ```
 
 This works for:
-- `FLASHINFER_LOGDEST_DBG`
+- `FLASHINFER_LOGDEST`
 
 ## Miscellaneous Notes and Examples
 ### CUDA Graph Compatibility
@@ -65,8 +65,8 @@ Output shows: `[statistics skipped: CUDA graph capture in progress]`
 
 ```bash
 # Use %i for process ID substitution
-export FLASHINFER_LOGLEVEL_DBG=3
-export FLASHINFER_LOGDEST_DBG="logs/flashinfer_api_%i.log"
+export FLASHINFER_LOGLEVEL=3
+export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log"
 
 torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py
 
````
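The `%i` substitution documented above is straightforward: the destination path is rewritten once with the current process ID, so each rank of a multi-GPU job writes to its own log file. A minimal standalone sketch (modeled on the `_substitute_process_id` helper visible in `flashinfer/api_logging.py` later in this commit, not the library's exact code):

```python
import os

def substitute_process_id(path: str) -> str:
    """Replace the %i placeholder with the current process ID (sketch)."""
    return path.replace("%i", str(os.getpid()))

# Each process resolves to a distinct file, so concurrent ranks
# never interleave writes in a single log.
log_dest = substitute_process_id("flashinfer_log_%i.txt")
print(log_dest)  # e.g. flashinfer_log_12345.txt when the PID is 12345
```

Paths without the placeholder pass through unchanged, which is why the same variable works for single-process runs.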

README.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -175,10 +175,10 @@ FlashInfer provides comprehensive API logging for debugging. Enable it using env
 
 ```bash
 # Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics)
-export FLASHINFER_LOGLEVEL_DBG=3
+export FLASHINFER_LOGLEVEL=3
 
 # Set log destination (stdout (default), stderr, or file path)
-export FLASHINFER_LOGDEST_DBG=stdout
+export FLASHINFER_LOGDEST=stdout
 ```
 
 For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md).
````

benchmarks/bench_logging_overhead.py

Lines changed: 12 additions & 12 deletions

```diff
@@ -7,14 +7,14 @@
 
 Usage:
     # Set the logging level before running
-    export FLASHINFER_LOGLEVEL_DBG=3
+    export FLASHINFER_APILEVEL=3
     python bench_logging_overhead.py
 
     # Or run with different levels
-    FLASHINFER_LOGLEVEL_DBG=0 python bench_logging_overhead.py
-    FLASHINFER_LOGLEVEL_DBG=1 python bench_logging_overhead.py
-    FLASHINFER_LOGLEVEL_DBG=3 python bench_logging_overhead.py
-    FLASHINFER_LOGLEVEL_DBG=5 python bench_logging_overhead.py
+    FLASHINFER_APILEVEL=0 python bench_logging_overhead.py
+    FLASHINFER_APILEVEL=1 python bench_logging_overhead.py
+    FLASHINFER_APILEVEL=3 python bench_logging_overhead.py
+    FLASHINFER_APILEVEL=5 python bench_logging_overhead.py
 
     # Or use the helper script to run all levels
     bash benchmark_all_levels.sh
@@ -28,11 +28,11 @@
 from typing import List, Tuple
 
 # Get logging level BEFORE importing flashinfer
-LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0"))
-LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "/tmp/flashinfer_benchmark_log.txt")
+LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0"))
+LOG_DEST = os.environ.get("FLASHINFER_APIDEST", "/tmp/flashinfer_benchmark_log.txt")
 
 # Import the decorator
-from flashinfer.api_logging import flashinfer_log
+from flashinfer.api_logging import flashinfer_api
 
 
 # Create two versions of a test function:
@@ -42,7 +42,7 @@ def test_matmul_undecorated(A, B):
     return torch.matmul(A, B)
 
 
-@flashinfer_log
+@flashinfer_api
 def test_matmul_decorated(A, B):
     return torch.matmul(A, B)
 
@@ -209,8 +209,8 @@ def main():
 
     # Display logging configuration
     print("\nLogging Configuration:")
-    print(f"  FLASHINFER_LOGLEVEL_DBG = {LOGGING_LEVEL}")
-    print(f"  FLASHINFER_LOGDEST_DBG = {LOG_DEST}")
+    print(f"  FLASHINFER_APILEVEL = {LOGGING_LEVEL}")
+    print(f"  FLASHINFER_APIDEST = {LOG_DEST}")
 
     # Get level name
     level_names = {
@@ -314,7 +314,7 @@ def main():
     print("\nTo benchmark other levels, run:")
     for level in [0, 1, 3, 5]:
         if level != LOGGING_LEVEL:
-            print(f"  FLASHINFER_LOGLEVEL_DBG={level} python {sys.argv[0]}")
+            print(f"  FLASHINFER_APILEVEL={level} python {sys.argv[0]}")
 
     print("\n" + "=" * 80)
     print("Benchmark complete!")
```
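The benchmark's core idea — timing a decorated function against an identical undecorated twin — can be sketched without FlashInfer at all. Here `noop_decorator` is a stand-in for a logging decorator disabled at level 0, not the real `flashinfer_api`:

```python
import time

def noop_decorator(func):
    # Stand-in for a disabled logging decorator: returning func
    # unchanged means calls pay no wrapper cost at all.
    return func

def square_plain(x):
    return x * x

@noop_decorator
def square_decorated(x):
    return x * x

def bench(fn, n=50_000):
    # Wall-clock time for n calls; the delta between the two
    # functions approximates per-call decorator overhead.
    start = time.perf_counter()
    for i in range(n):
        fn(i)
    return time.perf_counter() - start

t_plain, t_decorated = bench(square_plain), bench(square_decorated)
print(f"plain: {t_plain:.4f}s  decorated: {t_decorated:.4f}s")
```

Because the disabled decorator hands back the original function object, the two timings should be statistically indistinguishable — which is the zero-overhead property the benchmark checks at level 0.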

flashinfer/api_logging.py

Lines changed: 9 additions & 11 deletions

```diff
@@ -39,10 +39,8 @@ def _substitute_process_id(path: str) -> str:
 
 
 # Read environment variables once at module load time
-_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0"))
-_API_LOG_DEST = _substitute_process_id(
-    os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout")
-)
+_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0"))
+_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_APIDEST", "stdout"))
 
 # Create logger using Python's logging library
 _logger = logging.getLogger("flashinfer.api")
@@ -56,7 +54,7 @@ def _setup_logger():
         _logger.setLevel(logging.CRITICAL + 1)  # Higher than any level
         return
 
-    # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL_DBG instead
+    # All enabled levels use logging.DEBUG; verbosity is controlled by FLASHINFER_APILEVEL instead
     _logger.setLevel(logging.DEBUG)
 
     # Remove any existing handlers
@@ -463,22 +461,22 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
     _logger.debug("\n".join(lines))
 
 
-def flashinfer_log(func: Callable = None) -> Callable:
+def flashinfer_api(func: Callable = None) -> Callable:
     """
     Decorator to log FlashInfer API calls using Python's logging library.
 
     This decorator integrates with Python's standard logging infrastructure while
-    maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL_DBG=0).
+    maintaining zero overhead when disabled (FLASHINFER_APILEVEL=0).
 
     Environment Variables
     ---------------------
-    FLASHINFER_LOGLEVEL_DBG : int (default: 0)
+    FLASHINFER_APILEVEL : int (default: 0)
         - 0: No logging (zero overhead - decorator returns original function)
         - 1: Log function name only (logged BEFORE execution - crash-safe)
         - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe)
         - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe)
 
-    FLASHINFER_LOGDEST_DBG : str (default: "stdout")
+    FLASHINFER_APIDEST : str (default: "stdout")
         - "stdout": Log to standard output
         - "stderr": Log to standard error
         - <path>: Log to specified file path
@@ -488,15 +486,15 @@ def flashinfer_log(func: Callable = None) -> Callable:
     --------
     Basic usage:
 
-    >>> @flashinfer_log
+    >>> @flashinfer_api
     ... def my_function(x, y):
     ...     return x + y
 
     Notes
     -----
     - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS]
       (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information")
-    - When FLASHINFER_LOGLEVEL_DBG=0, the decorator has truly zero overhead
+    - When FLASHINFER_APILEVEL=0, the decorator has truly zero overhead
       as it returns the original function unchanged.
     - Function names and inputs are logged BEFORE execution:
       - Level 1: Function name only
```
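The docstring's zero-overhead claim rests on a standard pattern: when the level is 0, the decorator hands back the original function object, so there is no wrapper frame on any call. A simplified sketch — the env var name follows this commit, but the body is an illustration, not the actual `flashinfer_api` implementation:

```python
import functools
import os

# Read once at import time, mirroring the module-load read in the diff.
_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0"))

def api_log(func):
    if _LEVEL == 0:
        # Disabled: return the original function unchanged, so
        # decorated calls cost exactly the same as undecorated ones.
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Log the name BEFORE execution, so a crash inside func
        # still leaves a record of which API was entered.
        print(f"FlashInfer API Call: {func.__name__}")
        return func(*args, **kwargs)

    return wrapper

@api_log
def add(x, y):
    return x + y
```

`functools.wraps` keeps the wrapped function's name and docstring intact, so logging output and introspection stay readable at levels above 0.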

flashinfer/cudnn/decode.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -3,7 +3,7 @@
 
 import torch
 
-from ..api_logging import flashinfer_log
+from ..api_logging import flashinfer_api
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -253,7 +253,7 @@ def _batch_decode_with_kv_cache(
     return out
 
 
-@flashinfer_log
+@flashinfer_api
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
```

flashinfer/cudnn/prefill.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -3,7 +3,7 @@
 
 import torch
 
-from ..api_logging import flashinfer_log
+from ..api_logging import flashinfer_api
 from .utils import get_cudnn_fmha_gen_module
 
 try:
@@ -384,7 +384,7 @@ def _batch_prefill_with_kv_cache(
     return out, None
 
 
-@flashinfer_log
+@flashinfer_api
 def cudnn_batch_prefill_with_kv_cache(
     q: torch.Tensor,
     k_cache: torch.Tensor,
```

flashinfer/decode.py

Lines changed: 10 additions & 10 deletions

```diff
@@ -21,7 +21,7 @@
 
 import torch
 
-from .api_logging import flashinfer_log
+from .api_logging import flashinfer_api
 from .xqa import xqa, xqa_mla
 from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .jit import (
@@ -313,7 +313,7 @@ def get_trtllm_gen_fmha_module():
     return op
 
 
-@flashinfer_log
+@flashinfer_api
 def single_decode_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -390,7 +390,7 @@ def single_decode_with_kv_cache(
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
 
-@flashinfer_log
+@flashinfer_api
 def single_decode_with_kv_cache(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -649,7 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
     manages the lifecycle of these data structures.
     """
 
-    @flashinfer_log
+    @flashinfer_api
     def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
@@ -813,7 +813,7 @@ def reset_workspace_buffer(
             pin_memory=True,
         )
 
-    @flashinfer_log
+    @flashinfer_api
     def plan(
         self,
         indptr: torch.Tensor,
@@ -1167,7 +1167,7 @@ def run(
         window_left: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
-    @flashinfer_log
+    @flashinfer_api
     def run(
         self,
         q: torch.Tensor,
@@ -2065,7 +2065,7 @@ def _fake_paged_run(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2339,7 +2339,7 @@ def trtllm_batch_decode_with_kv_cache(
 
 
 # xqa uses NHD layout
-@flashinfer_log
+@flashinfer_api
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -2524,7 +2524,7 @@ def _check_trtllm_gen_mla_shape(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
@@ -2686,7 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     raise ValueError(f"Backend {backend} not supported")
 
 
-@flashinfer_log
+@flashinfer_api
 def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
```

flashinfer/fused_moe/core.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -20,7 +20,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 
-from ..api_logging import flashinfer_log
+from ..api_logging import flashinfer_api
 from ..autotuner import (
     AutoTuner,
     DynamicTensorSpec,
@@ -686,7 +686,7 @@ def _fake_cutlass_fused_moe(
 
 
 # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121
-@flashinfer_log
+@flashinfer_api
 def cutlass_fused_moe(
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
@@ -1859,7 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -1940,7 +1940,7 @@ def trtllm_bf16_moe(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2014,7 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2092,7 +2092,7 @@ def trtllm_fp8_block_scale_moe(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2222,7 +2222,7 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@flashinfer_log
+@flashinfer_api
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
```
