From 18de6318361fb1f302f338b98f03501c410a2370 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 18 Nov 2025 01:22:38 +0000
Subject: [PATCH 01/13] First commit

---
 flashinfer/api_logging.py    | 435 ++++++++++++++++++++++++++
 flashinfer/cudnn/decode.py   |   2 +
 flashinfer/cudnn/prefill.py  |   2 +
 flashinfer/decode.py         |  10 +
 flashinfer/fused_moe/core.py |   7 +
 flashinfer/gemm/gemm_base.py |  12 +
 flashinfer/mla.py            |   4 +
 flashinfer/prefill.py        |  11 +
 tests/utils/test_logging.py  | 583 +++++++++++++++++++++++++++++++++++
 9 files changed, 1066 insertions(+)
 create mode 100644 flashinfer/api_logging.py
 create mode 100644 tests/utils/test_logging.py

diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py
new file mode 100644
index 0000000000..c98d0f87df
--- /dev/null
+++ b/flashinfer/api_logging.py
@@ -0,0 +1,435 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import enum
+import functools
+import inspect
+import logging
+import os
+import sys
+from typing import Any, Callable
+
+import torch
+
+
+# Read environment variables once at module load time
+_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_APILOG_LEVEL", "0"))
+_API_LOG_DEST = os.environ.get("FLASHINFER_APILOG_DEST", "./flashinfer_log.txt")
+
+# Create logger using Python's logging library
+_logger = logging.getLogger("flashinfer.api")
+
+def _setup_logger():
+    """Set up the logger based on environment variables."""
+    if _API_LOG_LEVEL == 0:
+        # Completely disable logging for zero overhead
+        _logger.addHandler(logging.NullHandler())
+        _logger.setLevel(logging.CRITICAL + 1)  # Higher than any level
+        return
+
+    # All enabled levels use logging.DEBUG; verbosity is controlled by FLASHINFER_APILOG_LEVEL instead
+    _logger.setLevel(logging.DEBUG)
+
+    # Remove any existing handlers
+    _logger.handlers.clear()
+
+    # Create handler based on destination
+    if _API_LOG_DEST == "stdout":
+        handler = logging.StreamHandler(sys.stdout)
+    elif _API_LOG_DEST == "stderr":
+        handler = logging.StreamHandler(sys.stderr)
+    else:
+        handler = logging.FileHandler(_API_LOG_DEST, mode='a')
+
+    # Use a simple formatter (we'll format the detailed content ourselves)
+    formatter = logging.Formatter('%(message)s')
+    handler.setFormatter(formatter)
+
+    _logger.addHandler(handler)
+    _logger.propagate = False  # Don't propagate to root logger
+
+# Initialize logger at module load time
+_setup_logger()
+
+def _format_value(value: Any, level: int, indent: int = 0) -> str:
+    """
+    Format a value for logging based on the log level.
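Both variables are read exactly once, at module import, so they must be set before the first `flashinfer` import. A minimal usage sketch (the `scale` function here is hypothetical, for illustration only):

```python
# Hypothetical usage sketch: configure the environment *before* importing,
# since FLASHINFER_APILOG_LEVEL / FLASHINFER_APILOG_DEST are read at load time.
import os

os.environ["FLASHINFER_APILOG_LEVEL"] = "2"      # name + inputs/outputs
os.environ["FLASHINFER_APILOG_DEST"] = "stderr"  # or "stdout", or a file path

from flashinfer.api_logging import flashinfer_api_log

@flashinfer_api_log
def scale(x, factor=2.0):
    return x * factor

scale(3.0)  # inputs (and the defaulted `factor`) are logged before the call runs
```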
+ + Parameters + ---------- + value : Any + The value to format + level : int + The logging level (1, 2, or 3) + indent : int + The indentation level for nested structures + + Returns + ------- + str + Formatted string representation of the value + """ + indent_str = " " * indent + + # Handle None + if value is None: + return f"{indent_str}None" + + # Handle Enum types + if isinstance(value, enum.Enum): + # Show both the name and value of the enum + return f"{indent_str}{value.__class__.__name__}.{value.name} (value={value.value})" + + # Handle torch.Tensor + if isinstance(value, torch.Tensor): + if level == 1: + return f"{indent_str}Tensor(...)" + + # Level 2+: Show metadata + lines = [f"{indent_str}Tensor("] + lines.append(f"{indent_str} shape={tuple(value.shape)}") + lines.append(f"{indent_str} stride={tuple(value.stride())}") + lines.append(f"{indent_str} dtype={value.dtype}") + lines.append(f"{indent_str} device={value.device}") + lines.append(f"{indent_str} requires_grad={value.requires_grad}") + lines.append(f"{indent_str} is_contiguous={value.is_contiguous()}") + + # Level 3: Add statistics + if level >= 3: + try: + # Skip statistics if we're in CUDA graph capture mode + # (operations like .min()/.max()/.mean() cause synchronization issues) + is_capturing = False + if value.is_cuda and hasattr(torch.cuda, 'is_current_stream_capturing'): + try: + is_capturing = torch.cuda.is_current_stream_capturing() + except Exception: + pass # Fallback if detection fails + + if is_capturing: + lines.append(f"{indent_str} [statistics skipped: CUDA graph capture in progress]") + elif value.numel() > 0: + # Convert to float for statistics if possible + if value.dtype in [torch.float16, torch.float32, torch.float64, + torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]: + val_float = value.float() + lines.append(f"{indent_str} min={val_float.min().item():.6f}") + lines.append(f"{indent_str} max={val_float.max().item():.6f}") + lines.append(f"{indent_str} mean={val_float.mean().item():.6f}") + nan_count = torch.isnan(val_float).sum().item() + lines.append(f"{indent_str} nan_count={nan_count}") + inf_count = torch.isinf(val_float).sum().item() + lines.append(f"{indent_str} inf_count={inf_count}") + elif value.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, + torch.uint8]: + lines.append(f"{indent_str} min={value.min().item()}") + lines.append(f"{indent_str} max={value.max().item()}") + lines.append(f"{indent_str} mean={value.float().mean().item():.6f}") + except Exception as e: + lines.append(f"{indent_str} [statistics error: {e}]") + + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle FP4Tensor (custom FlashInfer type) + if hasattr(value, '__class__') and value.__class__.__name__ == 'FP4Tensor': + if level == 1: + return f"{indent_str}FP4Tensor(...)" + + lines = [f"{indent_str}FP4Tensor("] + lines.append(f"{indent_str} data={_format_value(value.data, level, indent + 1)}") + lines.append(f"{indent_str} scale={_format_value(value.scale, level, indent + 1)}") + lines.append(f"{indent_str} scale_start_index={value.scale_start_index}") + if hasattr(value, 'original_shape') and value.original_shape is not None: + lines.append(f"{indent_str} original_shape={value.original_shape}") + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle lists + if isinstance(value, list): + if len(value) == 0: + return f"{indent_str}[]" + if level == 1: + return f"{indent_str}[list with {len(value)} items]" + + lines = [f"{indent_str}["] + for i, item in enumerate(value): + 
lines.append(f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}") + lines.append(f"{indent_str}]") + return "\n".join(lines) + + # Handle tuples + if isinstance(value, tuple): + if len(value) == 0: + return f"{indent_str}()" + if level == 1: + return f"{indent_str}(tuple with {len(value)} items)" + + lines = [f"{indent_str}("] + for i, item in enumerate(value): + lines.append(f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}") + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle dictionaries + if isinstance(value, dict): + if len(value) == 0: + return f"{indent_str}{{}}" + if level == 1: + return f"{indent_str}{{dict with {len(value)} keys}}" + + lines = [f"{indent_str}{{"] + for key, val in value.items(): + lines.append(f"{indent_str} {repr(key)}: {_format_value(val, level, indent + 1)}") + lines.append(f"{indent_str}}}") + return "\n".join(lines) + + # Handle numeric types (int, float, bool) + if isinstance(value, (int, float, bool, complex)): + return f"{indent_str}{value}" + + # Handle strings + if isinstance(value, str): + return f"{indent_str}{repr(value)}" + + # Default: use repr + try: + return f"{indent_str}{repr(value)}" + except Exception: + return f"{indent_str}<{type(value).__name__} object>" + + +def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: + """ + Extract parameters that have default values but were not explicitly provided. + + Parameters + ---------- + func : Callable + The function being called + args : tuple + Positional arguments that were provided + kwargs : dict + Keyword arguments that were provided + + Returns + ------- + dict + Dictionary of parameter names to default values for parameters that were not provided + """ + try: + sig = inspect.signature(func) + default_params = {} + + # Get parameter names in order + param_names = list(sig.parameters.keys()) + + # Determine which parameters were NOT provided + for i, (param_name, param) in enumerate(sig.parameters.items()): + # Skip if parameter has no default + if param.default is inspect.Parameter.empty: + continue + + # Check if this parameter was provided + provided = False + + # Check positional args (accounting for 'self' in methods) + if i < len(args): + provided = True + # Check keyword args + elif param_name in kwargs: + provided = True + + # If not provided, record the default value + if not provided: + default_params[param_name] = param.default + + return default_params + except Exception: + # If we can't inspect the signature, return empty dict + return {} + + +def _log_function_inputs(func: Callable, func_name: str, args: tuple, kwargs: dict, level: int) -> None: + """ + Log function inputs BEFORE execution for crash safety. + + This ensures inputs are captured even if the function crashes with a CUDA error. 
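To make `_get_default_params` (defined above) concrete, here is how it classifies the parameters of a small hypothetical function:

```python
# Sketch: what _get_default_params reports for a hypothetical function.
def f(a, b=1, c=2):
    return a + b + c

# _get_default_params(f, (10,), {})        -> {"b": 1, "c": 2}  (both defaulted)
# _get_default_params(f, (10, 20), {})     -> {"c": 2}          (b passed positionally)
# _get_default_params(f, (10,), {"c": 5})  -> {"b": 1}          (c passed by keyword)
# 'a' has no default (inspect.Parameter.empty), so it is never reported.
```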
+ + Parameters + ---------- + func : Callable + The function being called (needed to extract default parameters) + func_name : str + Name of the function being called + args : tuple + Positional arguments + kwargs : dict + Keyword arguments + level : int + Logging level (2 or 3) + """ + lines = [] + lines.append("=" * 80) + lines.append(f"FlashInfer API Call: {func_name}") + lines.append("-" * 80) + + # Log explicitly provided inputs + if args or kwargs: + # Positional arguments + if args: + lines.append("Positional input arguments:") + for i, arg in enumerate(args): + lines.append(f" arg[{i}]:") + lines.append(_format_value(arg, level, indent=2)) + + # Keyword arguments + if kwargs: + lines.append("Keyword input arguments:") + for key, value in kwargs.items(): + lines.append(f" {key}=") + lines.append(_format_value(value, level, indent=2)) + else: + lines.append("(No explicit arguments)") + + # Log default parameters that were not explicitly provided + default_params = _get_default_params(func, args, kwargs) + if default_params: + lines.append("Default parameters (not explicitly provided):") + for param_name, default_value in default_params.items(): + lines.append(f" {param_name}= [DEFAULT]") + lines.append(_format_value(default_value, level, indent=2)) + + _logger.debug("\n".join(lines)) + + +def _log_function_outputs(func_name: str, result: Any, level: int) -> None: + """ + Log function outputs AFTER successful execution. + + Parameters + ---------- + func_name : str + Name of the function + result : Any + Function return value + level : int + Logging level (2 or 3) + """ + lines = [] + # Log outputs + lines.append("Output value:") + lines.append(_format_value(result, level, indent=1)) + + lines.append("=" * 80) + lines.append("") # Empty line for readability + + _logger.debug("\n".join(lines)) + + +def flashinfer_api_log(func: Callable = None) -> Callable: + """ + Decorator to log FlashInfer API calls using Python's logging library. + + This decorator integrates with Python's standard logging infrastructure while + maintaining zero overhead when disabled (FLASHINFER_APILOG_LEVEL=0). + + Environment Variables + --------------------- + FLASHINFER_APILOG_LEVEL : int (default: 0) + - 0: No logging (zero overhead - decorator returns original function) + - 1: Log function name only (logged BEFORE execution - crash-safe) + - 2: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) + - 3: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) + + FLASHINFER_APILOG_DEST : str (default: "./flashinfer_log.txt") + - "stdout": Log to standard output + - "stderr": Log to standard error + - : Log to specified file path + + Examples + -------- + Basic usage: + + >>> @flashinfer_api_log + ... def my_function(x, y): + ... return x + y + + Notes + ----- + - When FLASHINFER_APILOG_LEVEL=0, the decorator has truly zero overhead + as it returns the original function unchanged. + - Function names and inputs are logged BEFORE execution: + - Level 1: Function name only + - Levels 2-3: Function name + inputs with metadata + This means critical debugging information is preserved even if the function + crashes (e.g., CUDA illegal memory access, out-of-bounds, etc.). + - Outputs are logged AFTER successful execution for levels 2 and 3. + - **CUDA Graph Compatibility**: At level 3, tensor statistics (min/max/mean/nan_count) + are automatically skipped during CUDA graph capture to avoid synchronization issues. 
+ The message "[statistics skipped: CUDA graph capture in progress]" will be logged. + - The logger does not propagate to the root logger to avoid duplicate logs. + """ + # If logging is disabled, return original function with zero overhead + if _API_LOG_LEVEL == 0: + if func is None: + return lambda f: f + return func + + def decorator(f: Callable) -> Callable: + @functools.wraps(f) + def wrapper(*args, **kwargs): + # Determine function name (with class name if applicable) + func_name = f.__name__ + if args and hasattr(args[0], '__class__'): + try: + class_name = args[0].__class__.__name__ + if 'Wrapper' in class_name or class_name in ['BatchMLAPagedAttentionWrapper']: + func_name = f"{class_name}.{func_name}" + except: + pass + + # Log BEFORE execution (crash-safe for all levels!) + try: + if _API_LOG_LEVEL == 1: + # Level 1: Just log function name before execution (crash-safe) + _logger.debug(f"FlashInfer API Call: {func_name}") + elif _API_LOG_LEVEL >= 2: + # Level 2+: Log full inputs before execution (crash-safe) + _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") + + # Call the original function (may crash here with CUDA errors) + result = f(*args, **kwargs) + + # Log outputs AFTER successful execution (level 2+ only) + try: + if _API_LOG_LEVEL >= 2: + # Level 2+: Log outputs (inputs were already logged above) + _log_function_outputs(func_name, result, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}") + + return result + + return wrapper + + # Support both @flashinfer_api_log and @flashinfer_api_log() + if func is None: + return decorator + return decorator(func) + diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 6ef13b997f..39f4cf67c1 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api_log from .utils import get_cudnn_fmha_gen_module try: @@ -252,6 +253,7 @@ def _batch_decode_with_kv_cache( return out +@flashinfer_api_log def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc573cf7cb..9ca5cea66a 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api_log from .utils import get_cudnn_fmha_gen_module try: @@ -383,6 +384,7 @@ def _batch_prefill_with_kv_cache( return out, None +@flashinfer_api_log def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index af8dda0345..2f2072039a 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,6 +21,7 @@ import torch +from .api_logging import flashinfer_api_log from .xqa import xqa, xqa_mla from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( @@ -312,6 +313,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api_log def single_decode_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -388,6 +390,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api_log def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -646,6 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. 
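When API logging is enabled, each decorated method below emits a record before it executes; at `FLASHINFER_APILOG_LEVEL=1` the trace is just the qualified names (illustrative output, following the decorator's `Wrapper`-class naming rule):

```
FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__
FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan
FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.run
```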
""" + @flashinfer_api_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -809,6 +813,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api_log def plan( self, indptr: torch.Tensor, @@ -1162,6 +1167,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api_log def run( self, q: torch.Tensor, @@ -2059,6 +2065,7 @@ def _fake_paged_run( ) +@flashinfer_api_log def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2332,6 +2339,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout +@flashinfer_api_log def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2516,6 +2524,7 @@ def _check_trtllm_gen_mla_shape( ) +@flashinfer_api_log def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -2677,6 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") +@flashinfer_api_log def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 3c5e7a09c5..9ab621453a 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from ..api_logging import flashinfer_api_log from ..autotuner import ( AutoTuner, DynamicTensorSpec, @@ -685,6 +686,7 @@ def _fake_cutlass_fused_moe( # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 +@flashinfer_api_log def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1857,6 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe( ) +@flashinfer_api_log def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -1937,6 +1940,7 @@ def trtllm_bf16_moe( ) +@flashinfer_api_log def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2010,6 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe( ) +@flashinfer_api_log def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2087,6 +2092,7 @@ def trtllm_fp8_block_scale_moe( ) +@flashinfer_api_log def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2216,6 +2222,7 @@ def trtllm_fp4_block_scale_moe( ) +@flashinfer_api_log def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 589c651aca..6306336876 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,6 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch +from ..api_logging import flashinfer_api_log from ..autotuner import ( AutoTuner, ConstraintSpec, @@ -539,6 +540,7 @@ def forward( ) +@flashinfer_api_log def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -884,6 +886,7 @@ def reset_workspace_buffer( self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer + @flashinfer_api_log def run( self, x: torch.Tensor, @@ -1551,6 +1554,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) 
+@flashinfer_api_log def mm_fp8( a: torch.Tensor, b: torch.Tensor, @@ -2015,6 +2019,7 @@ def _heuristic_func_mm_fp4( return [c for c in candidate_backends if c in suitable_backends] +@flashinfer_api_log @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, @@ -2272,6 +2277,7 @@ def _heuristic_func_bmm_fp8( return heuristic_backends +@flashinfer_api_log @backend_requirement( { "cudnn": _cudnn_bmm_fp8_requirement, @@ -2372,6 +2378,7 @@ def bmm_fp8( return out +@flashinfer_api_log def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2623,6 +2630,7 @@ def forward( ) +@flashinfer_api_log def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2651,6 +2659,7 @@ def gemm_fp8_nt_blockscaled( ) +@flashinfer_api_log def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2813,6 +2822,7 @@ def group_gemm_fp8_nt_groupwise( return out +@flashinfer_api_log def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2980,6 +2990,7 @@ def get_deepgemm_sm100_module(): return module +@flashinfer_api_log def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3110,6 +3121,7 @@ def group_deepgemm_fp8_nt_groupwise( return out +@flashinfer_api_log def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index da57d94e6b..c94db531f9 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api_log from .jit import gen_batch_mla_module from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend @@ -129,6 +130,7 @@ class BatchMLAPagedAttentionWrapper: torch.Size([114, 128, 512]) """ + @flashinfer_api_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -199,6 +201,7 @@ def __init__( else: self._backend = backend + @flashinfer_api_log def plan( self, qo_indptr: torch.Tensor, @@ -333,6 +336,7 @@ def run( return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api_log def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 47d725c5d3..6fec42cbff 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -22,6 +22,7 @@ import torch +from .api_logging import flashinfer_api_log from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, @@ -873,6 +874,7 @@ def _fake_paged_run( ) +@flashinfer_api_log def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -957,6 +959,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api_log def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1325,6 +1328,7 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ + @flashinfer_api_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1520,6 +1524,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api_log def plan( self, qo_indptr: torch.Tensor, @@ -1976,6 +1981,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api_log def run( self, q: torch.Tensor, @@ -2350,6 +2356,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. 
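Because inputs are written out before the kernel launches, the log doubles as a crash trail; a postmortem sketch (the level, destination, and path are hypothetical choices):

```python
# Postmortem sketch: a previous run was launched with
#   FLASHINFER_APILOG_LEVEL=2 FLASHINFER_APILOG_DEST=./repro.txt
# and died inside run() with a CUDA illegal memory access. The final record
# still holds the offending shapes/strides/dtypes; only its
# "Output value:" section is missing, pinpointing the failing call.
from pathlib import Path

tail = Path("./repro.txt").read_text().splitlines()[-40:]
print("\n".join(tail))
```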
""" + @flashinfer_api_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -2493,6 +2500,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api_log def plan( self, qo_indptr: torch.Tensor, @@ -2837,6 +2845,7 @@ def run( enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api_log def run( self, q: torch.Tensor, @@ -3193,6 +3202,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api_log def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3327,6 +3337,7 @@ def trtllm_ragged_attention_deepseek( return out +@flashinfer_api_log def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py new file mode 100644 index 0000000000..5d6865c743 --- /dev/null +++ b/tests/utils/test_logging.py @@ -0,0 +1,583 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import sys +import tempfile +from enum import Enum +from io import StringIO +from pathlib import Path + +import pytest +import torch + + +# Test enum classes +class TestEnum(Enum): + """Test enum with integer values.""" + OPTION_A = 0 + OPTION_B = 1 + OPTION_C = 2 + + +class StringEnum(Enum): + """Test enum with string values. 
Names are for testing purposes.""" + MODE_STANDARD = "standard" + MODE_OPTIMIZED = "optimized" + + +class TestAPILogging: + """Test suite for FlashInfer API logging infrastructure.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Reset environment and reimport logging module for each test.""" + # Store original environment + original_level = os.environ.get("FLASHINFER_APILOG_LEVEL") + original_dest = os.environ.get("FLASHINFER_APILOG_DEST") + + yield + + # Restore original environment + if original_level is not None: + os.environ["FLASHINFER_APILOG_LEVEL"] = original_level + elif "FLASHINFER_APILOG_LEVEL" in os.environ: + del os.environ["FLASHINFER_APILOG_LEVEL"] + + if original_dest is not None: + os.environ["FLASHINFER_APILOG_DEST"] = original_dest + elif "FLASHINFER_APILOG_DEST" in os.environ: + del os.environ["FLASHINFER_APILOG_DEST"] + + # Force reimport to pick up new environment variables + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + def setup_logging(self, level: int, dest: str = "stdout"): + """Helper to set up logging environment and reimport.""" + os.environ["FLASHINFER_APILOG_LEVEL"] = str(level) + os.environ["FLASHINFER_APILOG_DEST"] = dest + + # Force reimport + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + from flashinfer.api_logging import flashinfer_api_log + return flashinfer_api_log + + def test_level_0_zero_overhead(self): + """Test that level 0 has truly zero overhead (returns original function).""" + decorator = self.setup_logging(level=0) + + def original_func(x, y): + return x + y + + decorated_func = decorator(original_func) + + # At level 0, decorator should return the original function unchanged + assert decorated_func is original_func + assert decorated_func(5, 3) == 8 + + def test_level_1_function_name(self): + """Test that level 1 logs function name only.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x, y): + return x + y + + result = test_function(10, 20) + assert result == 30 + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + assert "FlashInfer API Call: test_function" in log_contents + # Level 1 should not log inputs/outputs details + assert "Positional input arguments" not in log_contents + assert "Output value" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_level_2_inputs_outputs(self): + """Test that level 2 logs inputs and outputs with metadata.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(tensor, value): + return tensor * value + + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = test_function(tensor, 2.0) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log function name + assert "FlashInfer API Call: test_function" in log_contents + + # Should log inputs + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "Tensor(" in log_contents + assert "shape=(3,)" in log_contents + assert "dtype=torch.float32" in log_contents + + # Should log outputs + assert "Output value:" in log_contents + + # Should NOT log statistics (level 3 only) + assert "min=" not in 
log_contents + assert "max=" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_level_3_statistics(self): + """Test that level 3 logs tensor statistics.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + 1.0 + + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + result = test_function(tensor) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log statistics + assert "min=" in log_contents + assert "max=" in log_contents + assert "mean=" in log_contents + assert "nan_count=" in log_contents + assert "inf_count=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_enum_logging(self): + """Test that enum values are logged with name and value.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(mode: TestEnum, strategy: StringEnum): + return f"{mode.name}_{strategy.name}" + + result = test_function(TestEnum.OPTION_B, StringEnum.MODE_OPTIMIZED) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should show enum name and value + assert "TestEnum.OPTION_B" in log_contents + assert "(value=1)" in log_contents + assert "StringEnum.MODE_OPTIMIZED" in log_contents + assert "(value=optimized)" in log_contents or "(value='optimized')" in log_contents or '(value="optimized")' in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_default_parameters(self): + """Test that default parameters are logged separately.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(x, y=10, z=20, mode=TestEnum.OPTION_A): + return x + y + z + + # Call with only required argument + result = test_function(5) + assert result == 35 + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should show default parameters section + assert "Default parameters (not explicitly provided)" in log_contents + assert "[DEFAULT]" in log_contents + + # Should show the default values + assert "y=" in log_contents + assert "z=" in log_contents + assert "mode=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_explicit_vs_default_parameters(self): + """Test that explicitly provided parameters are not shown in defaults.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(x, y=10, z=20): + return x + y + z + + # Call with some explicit parameters + result = test_function(5, y=100) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # y should be in keyword arguments (explicit) + assert "Keyword input arguments:" in log_contents + + # Only z should be in defaults + lines = log_contents.split('\n') + default_section_started = False + defaults_found = [] + for line in lines: + if "Default parameters" in line: + default_section_started = True + if default_section_started and "=" in line and "[DEFAULT]" in line: + defaults_found.append(line) + + # Should 
have only one default parameter (z) + assert len(defaults_found) == 1 + assert "z=" in defaults_found[0] + finally: + Path(log_file).unlink(missing_ok=True) + + def test_class_method_logging(self): + """Test that class methods log with class name.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + class TestWrapper: + @decorator + def run(self, x): + return x * 2 + + wrapper = TestWrapper() + result = wrapper.run(5) + assert result == 10 + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log class name for Wrapper classes + assert "TestWrapper.run" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_crash_safety_inputs_logged_before_execution(self): + """Test that inputs are logged BEFORE execution (crash-safe).""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def crashing_function(x, y): + raise RuntimeError("Simulated crash") + + # Call the function and expect it to crash + with pytest.raises(RuntimeError, match="Simulated crash"): + crashing_function(42, 99) + + # Check that inputs were still logged + with open(log_file, 'r') as f: + log_contents = f.read() + + # Inputs should be in the log even though function crashed + assert "FlashInfer API Call: crashing_function" in log_contents + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "42" in log_contents + assert "arg[1]" in log_contents + assert "99" in log_contents + + # Outputs should NOT be in the log (function crashed) + assert "Output value:" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_different_data_types(self): + """Test logging of various data types.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function( + int_val, float_val, bool_val, str_val, + list_val, tuple_val, dict_val, none_val + ): + return "success" + + result = test_function( + 42, + 3.14, + True, + "hello", + [1, 2, 3], + (4, 5, 6), + {"key": "value"}, + None + ) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log all types correctly + assert "42" in log_contents + assert "3.14" in log_contents + assert "True" in log_contents + assert "'hello'" in log_contents + assert "None" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_tensor_metadata(self): + """Test that tensor metadata is logged correctly.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + + # Create a tensor with specific properties + tensor = torch.randn(2, 3, 4, dtype=torch.float32, device='cpu') + tensor = tensor.contiguous() + tensor.requires_grad = False + + result = test_function(tensor) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log all metadata + assert "shape=(2, 3, 4)" in log_contents + assert "dtype=torch.float32" in log_contents + assert "device=cpu" in log_contents + assert "requires_grad=False" in log_contents + assert 
"is_contiguous=True" in log_contents + assert "stride=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_nested_structures(self): + """Test logging of nested data structures.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(nested): + return nested + + # Create nested structure + nested = { + "list": [1, 2, 3], + "dict": {"inner": "value"}, + "tuple": (4, 5), + } + + result = test_function(nested) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should handle nested structures + assert "list" in log_contents + assert "dict" in log_contents + assert "tuple" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_decorator_with_and_without_parentheses(self): + """Test that decorator works both as @decorator and @decorator().""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + # Without parentheses + @decorator + def func1(x): + return x + 1 + + # With parentheses + @decorator() + def func2(x): + return x + 2 + + result1 = func1(10) + result2 = func2(20) + + assert result1 == 11 + assert result2 == 22 + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + assert "func1" in log_contents + assert "func2" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_multiple_calls_same_function(self): + """Test that multiple calls to the same function are all logged.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x): + return x + + # Call multiple times + for i in range(3): + test_function(i) + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should have 3 log entries + assert log_contents.count("FlashInfer API Call: test_function") == 3 + finally: + Path(log_file).unlink(missing_ok=True) + + def test_kwargs_logging(self): + """Test that keyword arguments are logged correctly.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=2, dest=log_file) + + @decorator + def test_function(a, b, c): + return a + b + c + + # Call with keyword arguments + result = test_function(a=1, b=2, c=3) + assert result == 6 + + # Check log contents + with open(log_file, 'r') as f: + log_contents = f.read() + + # Should log keyword arguments + assert "Keyword input arguments:" in log_contents + assert "a=" in log_contents + assert "b=" in log_contents + assert "c=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_graph_compatibility(self): + """Test that level 3 logging is compatible with CUDA graph capture.""" + with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_cuda_function(tensor): + return tensor * 2.0 + + # Create a CUDA tensor + tensor = torch.randn(10, 10, device='cuda') + + # Test 1: Normal execution (should have statistics) + result1 = 
test_cuda_function(tensor)
+
+            with open(log_file, 'r') as f:
+                log_normal = f.read()
+
+            # Should have statistics in normal execution
+            # (unless PyTorch version is too old)
+            if hasattr(torch.cuda, 'is_current_stream_capturing'):
+                # Normal execution should have min/max OR statistics error
+                has_stats = "min=" in log_normal or "statistics error" in log_normal
+                assert has_stats, "Expected statistics or error in normal execution"
+
+            # Clear log file
+            with open(log_file, 'w') as f:
+                f.write('')
+
+            # Test 2: CUDA graph capture (should skip statistics)
+            if hasattr(torch.cuda, 'CUDAGraph'):
+                graph = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(graph):
+                    result2 = test_cuda_function(tensor)
+
+                with open(log_file, 'r') as f:
+                    log_capture = f.read()
+
+                # Should skip statistics during capture
+                assert "[statistics skipped: CUDA graph capture in progress]" in log_capture or \
+                       "statistics" not in log_capture, \
+                       "Expected statistics to be skipped during CUDA graph capture"
+        finally:
+            Path(log_file).unlink(missing_ok=True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

From ae44417d8905d442e13ede019469547624dbfd44 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 18 Nov 2025 19:16:31 +0000
Subject: [PATCH 02/13] Adding benchmark. Applying pre-commit

---
 benchmarks/bench_logging_overhead.py | 350 +++++++++++++++++++++++++
 flashinfer/api_logging.py            | 203 ++++++++-------
 flashinfer/gemm/gemm_base.py         |   4 +-
 tests/utils/test_logging.py          | 367 ++++++++++++++------------
 4 files changed, 653 insertions(+), 271 deletions(-)
 create mode 100644 benchmarks/bench_logging_overhead.py

diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py
new file mode 100644
index 0000000000..de2781feb5
--- /dev/null
+++ b/benchmarks/bench_logging_overhead.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+"""
+Benchmark script to measure the overhead of API logging at different levels.
+
+This script creates decorated and undecorated versions of a test function
+(torch.matmul) and compares their performance to accurately measure logging overhead.
+
+Why torch.matmul instead of bmm_fp8?
+  - bmm_fp8 is already decorated in the FlashInfer source code
+  - Using it would cause double-decoration and inaccurate results
+  - torch.matmul gives us a clean baseline to measure pure decorator overhead
+
+Usage:
+    # Set the logging level before running
+    export FLASHINFER_APILOG_LEVEL=2
+    python bench_logging_overhead.py
+
+    # Or run with different levels
+    FLASHINFER_APILOG_LEVEL=0 python bench_logging_overhead.py
+    FLASHINFER_APILOG_LEVEL=1 python bench_logging_overhead.py
+    FLASHINFER_APILOG_LEVEL=2 python bench_logging_overhead.py
+    FLASHINFER_APILOG_LEVEL=3 python bench_logging_overhead.py
+
+    # Or use the helper script to run all levels
+    bash benchmark_all_levels.sh
+"""
+
+import os
+import sys
+import time
+import torch
+import numpy as np
+from typing import List, Tuple
+
+# Get logging level BEFORE importing flashinfer
+LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILOG_LEVEL", "0"))
+LOG_DEST = os.environ.get("FLASHINFER_APILOG_DEST", "/tmp/flashinfer_benchmark_log.txt")
+
+# Import the decorator
+try:
+    from flashinfer.api_logging import flashinfer_api_log
+except ImportError as e:
+    print(f"Error: Could not import flashinfer: {e}")
+    print("Make sure flashinfer is installed.")
+    exit(1)
+
+
+# Create two versions of a test function:
+# 1. Undecorated (baseline)
+# 2. 
Decorated (with logging) +# +# We use a simple torch.matmul instead of bmm_fp8 because bmm_fp8 is already +# decorated in the source code, which would cause double-decoration. + + +def test_matmul_undecorated(A, B): + """Undecorated version - baseline for comparison.""" + return torch.matmul(A, B) + + +@flashinfer_api_log +def test_matmul_decorated(A, B): + """Decorated version - with API logging.""" + return torch.matmul(A, B) + + +class BenchmarkResults: + """Store and display benchmark results.""" + + def __init__(self): + self.undecorated_times = [] + self.decorated_times = [] + + def set_undecorated(self, times: List[float]): + """Set benchmark results for undecorated function.""" + self.undecorated_times = times + + def set_decorated(self, times: List[float]): + """Set benchmark results for decorated function.""" + self.decorated_times = times + + def print_summary(self, logging_level: int): + """Print a summary of benchmark results.""" + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + undecorated_mean = np.mean(self.undecorated_times) + undecorated_std = np.std(self.undecorated_times) + + decorated_mean = np.mean(self.decorated_times) + decorated_std = np.std(self.decorated_times) + + overhead_abs = (decorated_mean - undecorated_mean) * 1000 # ms + overhead_pct = ( + ((decorated_mean - undecorated_mean) / undecorated_mean * 100) + if undecorated_mean > 0 + else 0 + ) + + print( + f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}" + ) + print("-" * 80) + print( + f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}" + ) + print( + f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}" + ) + + print("\n" + "=" * 80) + print("OVERHEAD ANALYSIS") + print("=" * 80) + print(f"\nLogging Level: {logging_level}") + print(f"Absolute overhead: {overhead_abs:.4f} ms") + print(f"Relative overhead: {overhead_pct:.2f}%") + + print("\n" + "=" * 80) + print("DETAILED STATISTICS") + print("=" * 80) + + print("\nUndecorated (baseline):") + print(f" Mean: {undecorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.undecorated_times) * 1000:.4f} ms") + print(f" Std: {undecorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.undecorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.undecorated_times) * 1000:.4f} ms") + + print("\nDecorated (with logging):") + print(f" Mean: {decorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.decorated_times) * 1000:.4f} ms") + print(f" Std: {decorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.decorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.decorated_times) * 1000:.4f} ms") + + +def setup_test_inputs( + batch_size: int = 32, + m: int = 512, + n: int = 512, + k: int = 512, + device: str = "cuda:0", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set up test inputs for matmul. 
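The helper script mentioned in the module docstring is not included in this patch; a hypothetical driver that sweeps every level looks like this (one subprocess per level, since the level is frozen at import time):

```python
# Hypothetical sweep driver: each level needs a fresh process because
# FLASHINFER_APILOG_LEVEL is read once when flashinfer is imported.
import os
import subprocess
import sys

for level in (0, 1, 2, 3):
    env = dict(os.environ, FLASHINFER_APILOG_LEVEL=str(level))
    subprocess.run(
        [sys.executable, "bench_logging_overhead.py"], env=env, check=True
    )
```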
+ + Parameters + ---------- + batch_size : int + Batch size for the matrix multiplication + m, n, k : int + Matrix dimensions + device : str + Device to use + + Returns + ------- + A, B : torch.Tensor + Input tensors for matrix multiplication + """ + # Create random tensors + A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device) + B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device) + + return A, B + + +def warmup(func, A, B, num_warmup: int = 10): + """Warmup the GPU and JIT compilation.""" + for _ in range(num_warmup): + _ = func(A, B) + torch.cuda.synchronize() + + +def benchmark_function( + func, func_name: str, A, B, num_iterations: int = 100 +) -> List[float]: + """ + Benchmark a specific function. + + Parameters + ---------- + func : callable + Function to benchmark + func_name : str + Name of the function (for display) + A, B : torch.Tensor + Input tensors for matrix multiplication + num_iterations : int + Number of iterations to run + + Returns + ------- + List[float] + List of execution times in seconds + """ + print(f"\nBenchmarking: {func_name}") + print(f" Running {num_iterations} iterations...") + + times = [] + + for _ in range(num_iterations): + # Synchronize before timing + torch.cuda.synchronize() + + # Time the execution + start = time.perf_counter() + _ = func(A, B) + torch.cuda.synchronize() + end = time.perf_counter() + + elapsed = end - start + times.append(elapsed) + + print(f" Complete. Mean time: {np.mean(times) * 1000:.4f} ms") + + return times + + +def main(): + """Main benchmark function.""" + print("=" * 80) + print("FlashInfer API Logging Overhead Benchmark") + print("=" * 80) + + # Display logging configuration + print("\nLogging Configuration:") + print(f" FLASHINFER_APILOG_LEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_APILOG_DEST = {LOG_DEST}") + + # Get level name + level_names = { + 0: "No logging (zero-overhead)", + 1: "Function name only", + 2: "Name + inputs/outputs + metadata", + 3: "Name + inputs/outputs + metadata + statistics", + } + print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}") + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("\nError: CUDA is not available. 
This benchmark requires a CUDA device.") + exit(1) + + device = "cuda:0" + print(f"\nDevice: {device}") + print(f"Device Name: {torch.cuda.get_device_name(device)}") + + # Setup test inputs + print("\nSetting up test inputs...") + batch_size = 32 + m, n, k = 128, 128, 128 + print(f" Batch size: {batch_size}") + print(f" Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]") + + A, B = setup_test_inputs(batch_size, m, n, k, device) + + # Benchmark parameters + num_iterations = 100 + print("\nBenchmark parameters:") + print(f" Iterations: {num_iterations}") + print(" Warmup iterations: 10") + + # Clear log file before starting + if os.path.exists(LOG_DEST): + os.remove(LOG_DEST) + + print("\n" + "=" * 80) + print("WARMUP PHASE") + print("=" * 80) + + # Warmup undecorated version + print("\nWarming up undecorated version...") + warmup(test_matmul_undecorated, A, B, num_warmup=10) + print(" Complete.") + + # Warmup decorated version + print("\nWarming up decorated version...") + warmup(test_matmul_decorated, A, B, num_warmup=10) + print(" Complete.") + + print("\n" + "=" * 80) + print("BENCHMARK PHASE") + print("=" * 80) + + # Store results + results = BenchmarkResults() + + # Benchmark undecorated version + undecorated_times = benchmark_function( + test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations + ) + results.set_undecorated(undecorated_times) + + # Benchmark decorated version + decorated_times = benchmark_function( + test_matmul_decorated, + f"Decorated (logging level {LOGGING_LEVEL})", + A, + B, + num_iterations, + ) + results.set_decorated(decorated_times) + + # Print summary + results.print_summary(LOGGING_LEVEL) + + # Check log file size + if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST): + log_size = os.path.getsize(LOG_DEST) + print("\n" + "=" * 80) + print("LOG FILE INFO") + print("=" * 80) + print(f"Log file: {LOG_DEST}") + print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)") + print(f"Iterations logged: {num_iterations}") + print(f"Bytes per iteration: {log_size / num_iterations:.2f}") + + # Cleanup option + cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true" + if cleanup_log: + os.remove(LOG_DEST) + print("\n Log file removed (set CLEANUP_LOG=false to keep it)") + else: + print(f"\n Log file preserved at {LOG_DEST}") + + print("\n" + "=" * 80) + print("RECOMMENDATIONS") + print("=" * 80) + print("\nTo benchmark other levels, run:") + for level in [0, 1, 2, 3]: + if level != LOGGING_LEVEL: + print(f" FLASHINFER_APILOG_LEVEL={level} python {sys.argv[0]}") + + print("\n" + "=" * 80) + print("Benchmark complete!") + print("=" * 80) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n\nBenchmark interrupted by user.") + except Exception as e: + print(f"\n\nError during benchmark: {e}") + import traceback + + traceback.print_exc() diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index c98d0f87df..2912d2c7da 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -21,7 +21,7 @@ import os import sys from typing import Any, Callable - +import contextlib import torch @@ -32,6 +32,7 @@ # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") + def _setup_logger(): """Set up the logger based on environment variables.""" if _API_LOG_LEVEL == 0: @@ -39,35 +40,37 @@ def _setup_logger(): _logger.addHandler(logging.NullHandler()) _logger.setLevel(logging.CRITICAL + 1) # Higher than any level return - + # All enabled levels 
use loggging.DEBUG; verbosity is controlled by FLASHINFER_APILOG_LEVEL instead _logger.setLevel(logging.DEBUG) - + # Remove any existing handlers _logger.handlers.clear() - + # Create handler based on destination if _API_LOG_DEST == "stdout": handler = logging.StreamHandler(sys.stdout) elif _API_LOG_DEST == "stderr": handler = logging.StreamHandler(sys.stderr) else: - handler = logging.FileHandler(_API_LOG_DEST, mode='a') - + handler = logging.FileHandler(_API_LOG_DEST, mode="a") + # Use a simple formatter (we'll format the detailed content ourselves) - formatter = logging.Formatter('%(message)s') + formatter = logging.Formatter("%(message)s") handler.setFormatter(formatter) - + _logger.addHandler(handler) _logger.propagate = False # Don't propagate to root logger + # Initialize logger at module load time _setup_logger() + def _format_value(value: Any, level: int, indent: int = 0) -> str: """ Format a value for logging based on the log level. - + Parameters ---------- value : Any @@ -76,28 +79,30 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: The logging level (1, 2, or 3) indent : int The indentation level for nested structures - + Returns ------- str Formatted string representation of the value """ indent_str = " " * indent - + # Handle None if value is None: return f"{indent_str}None" - + # Handle Enum types if isinstance(value, enum.Enum): # Show both the name and value of the enum - return f"{indent_str}{value.__class__.__name__}.{value.name} (value={value.value})" - + return ( + f"{indent_str}{value.__class__.__name__}.{value.name} (value={value.value})" + ) + # Handle torch.Tensor if isinstance(value, torch.Tensor): if level == 1: return f"{indent_str}Tensor(...)" - + # Level 2+: Show metadata lines = [f"{indent_str}Tensor("] lines.append(f"{indent_str} shape={tuple(value.shape)}") @@ -106,105 +111,130 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: lines.append(f"{indent_str} device={value.device}") lines.append(f"{indent_str} requires_grad={value.requires_grad}") lines.append(f"{indent_str} is_contiguous={value.is_contiguous()}") - + # Level 3: Add statistics if level >= 3: try: # Skip statistics if we're in CUDA graph capture mode # (operations like .min()/.max()/.mean() cause synchronization issues) is_capturing = False - if value.is_cuda and hasattr(torch.cuda, 'is_current_stream_capturing'): - try: + if value.is_cuda and hasattr(torch.cuda, "is_current_stream_capturing"): + with contextlib.suppress(Exception): is_capturing = torch.cuda.is_current_stream_capturing() - except Exception: - pass # Fallback if detection fails - + if is_capturing: - lines.append(f"{indent_str} [statistics skipped: CUDA graph capture in progress]") + lines.append( + f"{indent_str} [statistics skipped: CUDA graph capture in progress]" + ) elif value.numel() > 0: # Convert to float for statistics if possible - if value.dtype in [torch.float16, torch.float32, torch.float64, - torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]: + if value.dtype in [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: val_float = value.float() lines.append(f"{indent_str} min={val_float.min().item():.6f}") lines.append(f"{indent_str} max={val_float.max().item():.6f}") - lines.append(f"{indent_str} mean={val_float.mean().item():.6f}") + lines.append( + f"{indent_str} mean={val_float.mean().item():.6f}" + ) nan_count = torch.isnan(val_float).sum().item() lines.append(f"{indent_str} nan_count={nan_count}") 
inf_count = torch.isinf(val_float).sum().item() lines.append(f"{indent_str} inf_count={inf_count}") - elif value.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, - torch.uint8]: + elif value.dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ]: lines.append(f"{indent_str} min={value.min().item()}") lines.append(f"{indent_str} max={value.max().item()}") - lines.append(f"{indent_str} mean={value.float().mean().item():.6f}") + lines.append( + f"{indent_str} mean={value.float().mean().item():.6f}" + ) except Exception as e: lines.append(f"{indent_str} [statistics error: {e}]") - + lines.append(f"{indent_str})") return "\n".join(lines) - + # Handle FP4Tensor (custom FlashInfer type) - if hasattr(value, '__class__') and value.__class__.__name__ == 'FP4Tensor': + if hasattr(value, "__class__") and value.__class__.__name__ == "FP4Tensor": if level == 1: return f"{indent_str}FP4Tensor(...)" - + lines = [f"{indent_str}FP4Tensor("] - lines.append(f"{indent_str} data={_format_value(value.data, level, indent + 1)}") - lines.append(f"{indent_str} scale={_format_value(value.scale, level, indent + 1)}") + lines.append( + f"{indent_str} data={_format_value(value.data, level, indent + 1)}" + ) + lines.append( + f"{indent_str} scale={_format_value(value.scale, level, indent + 1)}" + ) lines.append(f"{indent_str} scale_start_index={value.scale_start_index}") - if hasattr(value, 'original_shape') and value.original_shape is not None: + if hasattr(value, "original_shape") and value.original_shape is not None: lines.append(f"{indent_str} original_shape={value.original_shape}") lines.append(f"{indent_str})") return "\n".join(lines) - + # Handle lists if isinstance(value, list): if len(value) == 0: return f"{indent_str}[]" if level == 1: return f"{indent_str}[list with {len(value)} items]" - + lines = [f"{indent_str}["] for i, item in enumerate(value): - lines.append(f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}") + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) lines.append(f"{indent_str}]") return "\n".join(lines) - + # Handle tuples if isinstance(value, tuple): if len(value) == 0: return f"{indent_str}()" if level == 1: return f"{indent_str}(tuple with {len(value)} items)" - + lines = [f"{indent_str}("] for i, item in enumerate(value): - lines.append(f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}") + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) lines.append(f"{indent_str})") return "\n".join(lines) - + # Handle dictionaries if isinstance(value, dict): if len(value) == 0: return f"{indent_str}{{}}" if level == 1: return f"{indent_str}{{dict with {len(value)} keys}}" - + lines = [f"{indent_str}{{"] for key, val in value.items(): - lines.append(f"{indent_str} {repr(key)}: {_format_value(val, level, indent + 1)}") + lines.append( + f"{indent_str} {repr(key)}: {_format_value(val, level, indent + 1)}" + ) lines.append(f"{indent_str}}}") return "\n".join(lines) - + # Handle numeric types (int, float, bool) if isinstance(value, (int, float, bool, complex)): return f"{indent_str}{value}" - + # Handle strings if isinstance(value, str): return f"{indent_str}{repr(value)}" - + # Default: use repr try: return f"{indent_str}{repr(value)}" @@ -215,7 +245,7 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: """ Extract parameters that have default values but were not explicitly 
provided. - + Parameters ---------- func : Callable @@ -224,7 +254,7 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: Positional arguments that were provided kwargs : dict Keyword arguments that were provided - + Returns ------- dict @@ -233,42 +263,38 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: try: sig = inspect.signature(func) default_params = {} - - # Get parameter names in order - param_names = list(sig.parameters.keys()) - + # Determine which parameters were NOT provided for i, (param_name, param) in enumerate(sig.parameters.items()): # Skip if parameter has no default if param.default is inspect.Parameter.empty: continue - + # Check if this parameter was provided provided = False - - # Check positional args (accounting for 'self' in methods) - if i < len(args): - provided = True - # Check keyword args - elif param_name in kwargs: + + # Check positional args and keyword args + if i < len(args) or param_name in kwargs: provided = True - + # If not provided, record the default value if not provided: default_params[param_name] = param.default - + return default_params except Exception: # If we can't inspect the signature, return empty dict return {} -def _log_function_inputs(func: Callable, func_name: str, args: tuple, kwargs: dict, level: int) -> None: +def _log_function_inputs( + func: Callable, func_name: str, args: tuple, kwargs: dict, level: int +) -> None: """ Log function inputs BEFORE execution for crash safety. - + This ensures inputs are captured even if the function crashes with a CUDA error. - + Parameters ---------- func : Callable @@ -286,16 +312,16 @@ def _log_function_inputs(func: Callable, func_name: str, args: tuple, kwargs: di lines.append("=" * 80) lines.append(f"FlashInfer API Call: {func_name}") lines.append("-" * 80) - + # Log explicitly provided inputs - if args or kwargs: + if args or kwargs: # Positional arguments if args: lines.append("Positional input arguments:") for i, arg in enumerate(args): lines.append(f" arg[{i}]:") lines.append(_format_value(arg, level, indent=2)) - + # Keyword arguments if kwargs: lines.append("Keyword input arguments:") @@ -304,7 +330,7 @@ def _log_function_inputs(func: Callable, func_name: str, args: tuple, kwargs: di lines.append(_format_value(value, level, indent=2)) else: lines.append("(No explicit arguments)") - + # Log default parameters that were not explicitly provided default_params = _get_default_params(func, args, kwargs) if default_params: @@ -312,14 +338,14 @@ def _log_function_inputs(func: Callable, func_name: str, args: tuple, kwargs: di for param_name, default_value in default_params.items(): lines.append(f" {param_name}= [DEFAULT]") lines.append(_format_value(default_value, level, indent=2)) - + _logger.debug("\n".join(lines)) def _log_function_outputs(func_name: str, result: Any, level: int) -> None: """ Log function outputs AFTER successful execution. - + Parameters ---------- func_name : str @@ -329,24 +355,24 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: level : int Logging level (2 or 3) """ - lines = [] + lines = [] # Log outputs lines.append("Output value:") lines.append(_format_value(result, level, indent=1)) - + lines.append("=" * 80) lines.append("") # Empty line for readability - + _logger.debug("\n".join(lines)) def flashinfer_api_log(func: Callable = None) -> Callable: """ Decorator to log FlashInfer API calls using Python's logging library. 
- + This decorator integrates with Python's standard logging infrastructure while maintaining zero overhead when disabled (FLASHINFER_APILOG_LEVEL=0). - + Environment Variables --------------------- FLASHINFER_APILOG_LEVEL : int (default: 0) @@ -354,20 +380,20 @@ def flashinfer_api_log(func: Callable = None) -> Callable: - 1: Log function name only (logged BEFORE execution - crash-safe) - 2: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - 3: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) - + FLASHINFER_APILOG_DEST : str (default: "./flashinfer_log.txt") - "stdout": Log to standard output - "stderr": Log to standard error - : Log to specified file path - + Examples -------- Basic usage: - + >>> @flashinfer_api_log ... def my_function(x, y): ... return x + y - + Notes ----- - When FLASHINFER_APILOG_LEVEL=0, the decorator has truly zero overhead @@ -388,20 +414,22 @@ def flashinfer_api_log(func: Callable = None) -> Callable: if func is None: return lambda f: f return func - + def decorator(f: Callable) -> Callable: @functools.wraps(f) def wrapper(*args, **kwargs): # Determine function name (with class name if applicable) func_name = f.__name__ - if args and hasattr(args[0], '__class__'): + if args and hasattr(args[0], "__class__"): try: class_name = args[0].__class__.__name__ - if 'Wrapper' in class_name or class_name in ['BatchMLAPagedAttentionWrapper']: + if "Wrapper" in class_name or class_name in [ + "BatchMLAPagedAttentionWrapper" + ]: func_name = f"{class_name}.{func_name}" - except: + except Exception: pass - + # Log BEFORE execution (crash-safe for all levels!) try: if _API_LOG_LEVEL == 1: @@ -412,7 +440,7 @@ def wrapper(*args, **kwargs): _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) except Exception as e: _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") - + # Call the original function (may crash here with CUDA errors) result = f(*args, **kwargs) @@ -427,9 +455,8 @@ def wrapper(*args, **kwargs): return result return wrapper - + # Support both @flashinfer_api_log and @flashinfer_api_log() if func is None: return decorator return decorator(func) - diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 6306336876..90900b2775 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2019,7 +2019,6 @@ def _heuristic_func_mm_fp4( return [c for c in candidate_backends if c in suitable_backends] -@flashinfer_api_log @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, @@ -2029,6 +2028,7 @@ def _heuristic_func_mm_fp4( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) +@flashinfer_api_log def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2277,7 +2277,6 @@ def _heuristic_func_bmm_fp8( return heuristic_backends -@flashinfer_api_log @backend_requirement( { "cudnn": _cudnn_bmm_fp8_requirement, @@ -2287,6 +2286,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) +@flashinfer_api_log def bmm_fp8( A: torch.Tensor, B: torch.Tensor, diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 5d6865c743..3559d1d11b 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -18,7 +18,6 @@ import sys import tempfile from enum import Enum -from io import StringIO from pathlib import Path import pytest @@ -28,6 +27,7 
@@ # Test enum classes class TestEnum(Enum): """Test enum with integer values.""" + OPTION_A = 0 OPTION_B = 1 OPTION_C = 2 @@ -35,145 +35,147 @@ class TestEnum(Enum): class StringEnum(Enum): """Test enum with string values. Names are for testing purposes.""" + MODE_STANDARD = "standard" MODE_OPTIMIZED = "optimized" class TestAPILogging: """Test suite for FlashInfer API logging infrastructure.""" - + @pytest.fixture(autouse=True) def setup_and_teardown(self): """Reset environment and reimport logging module for each test.""" # Store original environment original_level = os.environ.get("FLASHINFER_APILOG_LEVEL") original_dest = os.environ.get("FLASHINFER_APILOG_DEST") - + yield - + # Restore original environment if original_level is not None: os.environ["FLASHINFER_APILOG_LEVEL"] = original_level elif "FLASHINFER_APILOG_LEVEL" in os.environ: del os.environ["FLASHINFER_APILOG_LEVEL"] - + if original_dest is not None: os.environ["FLASHINFER_APILOG_DEST"] = original_dest elif "FLASHINFER_APILOG_DEST" in os.environ: del os.environ["FLASHINFER_APILOG_DEST"] - + # Force reimport to pick up new environment variables if "flashinfer.api_logging" in sys.modules: del sys.modules["flashinfer.api_logging"] - + def setup_logging(self, level: int, dest: str = "stdout"): """Helper to set up logging environment and reimport.""" os.environ["FLASHINFER_APILOG_LEVEL"] = str(level) os.environ["FLASHINFER_APILOG_DEST"] = dest - + # Force reimport if "flashinfer.api_logging" in sys.modules: del sys.modules["flashinfer.api_logging"] - + from flashinfer.api_logging import flashinfer_api_log + return flashinfer_api_log - + def test_level_0_zero_overhead(self): """Test that level 0 has truly zero overhead (returns original function).""" decorator = self.setup_logging(level=0) - + def original_func(x, y): return x + y - + decorated_func = decorator(original_func) - + # At level 0, decorator should return the original function unchanged assert decorated_func is original_func assert decorated_func(5, 3) == 8 - + def test_level_1_function_name(self): """Test that level 1 logs function name only.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=1, dest=log_file) - + @decorator def test_function(x, y): return x + y - + result = test_function(10, 20) assert result == 30 - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + assert "FlashInfer API Call: test_function" in log_contents # Level 1 should not log inputs/outputs details assert "Positional input arguments" not in log_contents assert "Output value" not in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_level_2_inputs_outputs(self): """Test that level 2 logs inputs and outputs with metadata.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(tensor, value): return tensor * value - + tensor = torch.tensor([1.0, 2.0, 3.0]) - result = test_function(tensor, 2.0) - + test_function(tensor, 2.0) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log function name assert "FlashInfer API Call: test_function" in 
log_contents - + # Should log inputs assert "Positional input arguments" in log_contents assert "arg[0]" in log_contents assert "Tensor(" in log_contents assert "shape=(3,)" in log_contents assert "dtype=torch.float32" in log_contents - + # Should log outputs assert "Output value:" in log_contents - + # Should NOT log statistics (level 3 only) assert "min=" not in log_contents assert "max=" not in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_level_3_statistics(self): """Test that level 3 logs tensor statistics.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=3, dest=log_file) - + @decorator def test_function(tensor): return tensor + 1.0 - + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) - result = test_function(tensor) - + test_function(tensor) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log statistics assert "min=" in log_contents assert "max=" in log_contents @@ -182,88 +184,92 @@ def test_function(tensor): assert "inf_count=" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_enum_logging(self): """Test that enum values are logged with name and value.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(mode: TestEnum, strategy: StringEnum): return f"{mode.name}_{strategy.name}" - - result = test_function(TestEnum.OPTION_B, StringEnum.MODE_OPTIMIZED) - + + test_function(TestEnum.OPTION_B, StringEnum.MODE_OPTIMIZED) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should show enum name and value assert "TestEnum.OPTION_B" in log_contents assert "(value=1)" in log_contents assert "StringEnum.MODE_OPTIMIZED" in log_contents - assert "(value=optimized)" in log_contents or "(value='optimized')" in log_contents or '(value="optimized")' in log_contents + assert ( + "(value=optimized)" in log_contents + or "(value='optimized')" in log_contents + or '(value="optimized")' in log_contents + ) finally: Path(log_file).unlink(missing_ok=True) - + def test_default_parameters(self): """Test that default parameters are logged separately.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(x, y=10, z=20, mode=TestEnum.OPTION_A): return x + y + z - + # Call with only required argument result = test_function(5) assert result == 35 - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should show default parameters section assert "Default parameters (not explicitly provided)" in log_contents assert "[DEFAULT]" in log_contents - + # Should show the default values assert "y=" in log_contents assert "z=" in log_contents assert "mode=" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_explicit_vs_default_parameters(self): """Test that explicitly provided parameters are not shown in 
defaults.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(x, y=10, z=20): return x + y + z - + # Call with some explicit parameters - result = test_function(5, y=100) - + test_function(5, y=100) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # y should be in keyword arguments (explicit) assert "Keyword input arguments:" in log_contents - + # Only z should be in defaults - lines = log_contents.split('\n') + lines = log_contents.split("\n") default_section_started = False defaults_found = [] for line in lines: @@ -271,59 +277,59 @@ def test_function(x, y=10, z=20): default_section_started = True if default_section_started and "=" in line and "[DEFAULT]" in line: defaults_found.append(line) - + # Should have only one default parameter (z) assert len(defaults_found) == 1 assert "z=" in defaults_found[0] finally: Path(log_file).unlink(missing_ok=True) - + def test_class_method_logging(self): """Test that class methods log with class name.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=1, dest=log_file) - + class TestWrapper: @decorator def run(self, x): return x * 2 - + wrapper = TestWrapper() result = wrapper.run(5) assert result == 10 - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log class name for Wrapper classes assert "TestWrapper.run" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_crash_safety_inputs_logged_before_execution(self): """Test that inputs are logged BEFORE execution (crash-safe).""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def crashing_function(x, y): raise RuntimeError("Simulated crash") - + # Call the function and expect it to crash with pytest.raises(RuntimeError, match="Simulated crash"): crashing_function(42, 99) - + # Check that inputs were still logged - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Inputs should be in the log even though function crashed assert "FlashInfer API Call: crashing_function" in log_contents assert "Positional input arguments" in log_contents @@ -331,42 +337,41 @@ def crashing_function(x, y): assert "42" in log_contents assert "arg[1]" in log_contents assert "99" in log_contents - + # Outputs should NOT be in the log (function crashed) assert "Output value:" not in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_different_data_types(self): """Test logging of various data types.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function( - int_val, float_val, bool_val, str_val, - list_val, tuple_val, dict_val, none_val + int_val, + float_val, + bool_val, + 
str_val, + list_val, + tuple_val, + dict_val, + none_val, ): return "success" - - result = test_function( - 42, - 3.14, - True, - "hello", - [1, 2, 3], - (4, 5, 6), - {"key": "value"}, - None + + test_function( + 42, 3.14, True, "hello", [1, 2, 3], (4, 5, 6), {"key": "value"}, None ) - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log all types correctly assert "42" in log_contents assert "3.14" in log_contents @@ -375,30 +380,30 @@ def test_function( assert "None" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_tensor_metadata(self): """Test that tensor metadata is logged correctly.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(tensor): return tensor - + # Create a tensor with specific properties - tensor = torch.randn(2, 3, 4, dtype=torch.float32, device='cpu') + tensor = torch.randn(2, 3, 4, dtype=torch.float32, device="cpu") tensor = tensor.contiguous() tensor.requires_grad = False - - result = test_function(tensor) - + + test_function(tensor) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log all metadata assert "shape=(2, 3, 4)" in log_contents assert "dtype=torch.float32" in log_contents @@ -408,117 +413,117 @@ def test_function(tensor): assert "stride=" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_nested_structures(self): """Test logging of nested data structures.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(nested): return nested - + # Create nested structure nested = { "list": [1, 2, 3], "dict": {"inner": "value"}, "tuple": (4, 5), } - - result = test_function(nested) - + + test_function(nested) + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should handle nested structures assert "list" in log_contents assert "dict" in log_contents assert "tuple" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_decorator_with_and_without_parentheses(self): """Test that decorator works both as @decorator and @decorator().""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=1, dest=log_file) - + # Without parentheses @decorator def func1(x): return x + 1 - + # With parentheses @decorator() def func2(x): return x + 2 - + result1 = func1(10) result2 = func2(20) - + assert result1 == 11 assert result2 == 22 - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + assert "func1" in log_contents assert "func2" in log_contents finally: Path(log_file).unlink(missing_ok=True) - + def test_multiple_calls_same_function(self): """Test that multiple calls to the same function are all logged.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with 
tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=1, dest=log_file) - + @decorator def test_function(x): return x - + # Call multiple times for i in range(3): test_function(i) - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should have 3 log entries assert log_contents.count("FlashInfer API Call: test_function") == 3 finally: Path(log_file).unlink(missing_ok=True) - + def test_kwargs_logging(self): """Test that keyword arguments are logged correctly.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=2, dest=log_file) - + @decorator def test_function(a, b, c): return a + b + c - + # Call with keyword arguments result = test_function(a=1, b=2, c=3) assert result == 6 - + # Check log contents - with open(log_file, 'r') as f: + with open(log_file, "r") as f: log_contents = f.read() - + # Should log keyword arguments assert "Keyword input arguments:" in log_contents assert "a=" in log_contents @@ -527,57 +532,57 @@ def test_function(a, b, c): finally: Path(log_file).unlink(missing_ok=True) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_cuda_graph_compatibility(self): """Test that level 3 logging is compatible with CUDA graph capture.""" - with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name - + try: decorator = self.setup_logging(level=3, dest=log_file) - + @decorator def test_cuda_function(tensor): return tensor * 2.0 - + # Create a CUDA tensor - tensor = torch.randn(10, 10, device='cuda') - + tensor = torch.randn(10, 10, device="cuda") + # Test 1: Normal execution (should have statistics) - result1 = test_cuda_function(tensor) - - with open(log_file, 'r') as f: + test_cuda_function(tensor) + + with open(log_file, "r") as f: log_normal = f.read() - + # Should have statistics in normal execution # (unless PyTorch version is too old) - if hasattr(torch.cuda, 'is_current_stream_capturing'): + if hasattr(torch.cuda, "is_current_stream_capturing"): # Normal execution should have min/max OR statistics error has_stats = "min=" in log_normal or "statistics error" in log_normal assert has_stats, "Expected statistics or error in normal execution" - + # Clear log file - with open(log_file, 'w') as f: - f.write('') - + with open(log_file, "w") as f: + f.write("") + # Test 2: CUDA graph capture (should skip statistics) - if hasattr(torch.cuda, 'CUDAGraph'): + if hasattr(torch.cuda, "CUDAGraph"): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): - result2 = test_cuda_function(tensor) - - with open(log_file, 'r') as f: + test_cuda_function(tensor) + + with open(log_file, "r") as f: log_capture = f.read() - + # Should skip statistics during capture - assert "[statistics skipped: CUDA graph capture in progress]" in log_capture or \ - "statistics" not in log_capture, \ - "Expected statistics to be skipped during CUDA graph capture" + assert ( + "[statistics skipped: CUDA graph capture in progress]" + in log_capture + or "statistics" not in log_capture + ), "Expected statistics to be skipped during CUDA graph capture" finally: Path(log_file).unlink(missing_ok=True) if __name__ == "__main__": 
     pytest.main([__file__, "-v"])
-

From 753e60038b10d1b40cbb99be5e3118d3a94967fe Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 18 Nov 2025 22:15:04 +0000
Subject: [PATCH 03/13] Log System Info

---
 flashinfer/api_logging.py | 75 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py
index 2912d2c7da..9626a39202 100644
--- a/flashinfer/api_logging.py
+++ b/flashinfer/api_logging.py
@@ -67,6 +67,81 @@ def _setup_logger():
 _setup_logger()
 
 
+def _log_system_info():
+    """Log system information once at module initialization."""
+    if _API_LOG_LEVEL == 0:
+        return
+
+    lines = []
+    lines.append("=" * 80)
+    lines.append("FlashInfer API Logging - System Information")
+    lines.append("=" * 80)
+
+    try:
+        # FlashInfer version
+        try:
+            from .version import __version__ as flashinfer_version
+
+            lines.append(f"FlashInfer version: {flashinfer_version}")
+        except Exception:
+            lines.append("FlashInfer version: <unavailable>")
+
+        # CUDA toolkit version
+        cuda_version = torch.version.cuda
+        if cuda_version:
+            lines.append(f"CUDA toolkit version: {cuda_version}")
+        else:
+            lines.append("CUDA toolkit version: <unavailable>")
+
+        # cuDNN version
+        try:
+            if torch.backends.cudnn.is_available():
+                cudnn_version = torch.backends.cudnn.version()
+                if cudnn_version:
+                    lines.append(f"cuDNN version: {cudnn_version}")
+                else:
+                    lines.append("cuDNN version: <unavailable>")
+            else:
+                lines.append("cuDNN version: <unavailable>")
+        except Exception as e:
+            lines.append(f"cuDNN version: <error: {e}>")
+
+        # GPU information (if CUDA is available)
+        if torch.cuda.is_available():
+            device_count = torch.cuda.device_count()
+            lines.append(f"Number of GPUs: {device_count}")
+
+            # Log information for each GPU
+            for i in range(device_count):
+                try:
+                    gpu_name = torch.cuda.get_device_name(i)
+                    capability = torch.cuda.get_device_capability(i)
+                    sm_arch = capability[0] * 10 + capability[1]
+                    lines.append(f"  GPU {i}: {gpu_name}")
+                    lines.append(
+                        f"    Compute capability: {capability[0]}.{capability[1]} (SM{sm_arch})"
+                    )
+                except Exception as e:
+                    lines.append(f"  GPU {i}: <error: {e}>")
+        else:
+            lines.append("CUDA: Not available (CPU-only mode)")
+
+        # PyTorch version
+        lines.append(f"PyTorch version: {torch.__version__}")
+
+    except Exception as e:
+        lines.append(f"Error gathering system information: {e}")
+
+    lines.append("=" * 80)
+    lines.append("")  # Empty line for readability
+
+    _logger.debug("\n".join(lines))
+
+
+# Log system information once at module load time (if logging is enabled)
+_log_system_info()
+
+
 def _format_value(value: Any, level: int, indent: int = 0) -> str:
     """
     Format a value for logging based on the log level.
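A minimal sketch of how the banner added by this patch can be surfaced (illustrative, not part of the series; these environment variable names are the ones still in use at this point and are renamed by the next patch):

import os

# The level and destination are read once at module load time, and
# _log_system_info() runs on import, so set the variables before importing.
os.environ["FLASHINFER_APILOG_LEVEL"] = "1"   # any level >= 1 enables the banner
os.environ["FLASHINFER_APILOG_DEST"] = "stdout"

import flashinfer.api_logging  # noqa: F401  -- emits the system info banner once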
From 6d0406b2c7261bb7a1dd937626ab13366ba55780 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Wed, 19 Nov 2025 22:02:03 +0000
Subject: [PATCH 04/13] Rename logging env vars. Set default dest to stdout

---
 benchmarks/bench_logging_overhead.py | 20 ++++++++++----------
 flashinfer/api_logging.py            | 14 +++++++-------
 tests/utils/test_logging.py          | 20 ++++++++++----------
 3 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py
index de2781feb5..391e588f66 100644
--- a/benchmarks/bench_logging_overhead.py
+++ b/benchmarks/bench_logging_overhead.py
@@ -12,14 +12,14 @@
 Usage:
     # Set the logging level before running
-    export FLASHINFER_APILOG_LEVEL=2
+    export FLASHINFER_LOGLEVEL_DBG=2
     python bench_logging_overhead.py
 
     # Or run with different levels
-    FLASHINFER_APILOG_LEVEL=0 python bench_logging_overhead.py
-    FLASHINFER_APILOG_LEVEL=1 python bench_logging_overhead.py
-    FLASHINFER_APILOG_LEVEL=2 python bench_logging_overhead.py
-    FLASHINFER_APILOG_LEVEL=3 python bench_logging_overhead.py
+    FLASHINFER_LOGLEVEL_DBG=0 python bench_logging_overhead.py
+    FLASHINFER_LOGLEVEL_DBG=1 python bench_logging_overhead.py
+    FLASHINFER_LOGLEVEL_DBG=2 python bench_logging_overhead.py
+    FLASHINFER_LOGLEVEL_DBG=3 python bench_logging_overhead.py
 
     # Or use the helper script to run all levels
     bash benchmark_all_levels.sh
@@ -33,8 +33,8 @@
 from typing import List, Tuple
 
 # Get logging level BEFORE importing flashinfer
-LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILOG_LEVEL", "0"))
-LOG_DEST = os.environ.get("FLASHINFER_APILOG_DEST", "/tmp/flashinfer_benchmark_log.txt")
+LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0"))
+LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "/tmp/flashinfer_benchmark_log.txt")
 
 # Import the decorator
 try:
@@ -226,8 +226,8 @@ def main():
 
     # Display logging configuration
     print("\nLogging Configuration:")
-    print(f"  FLASHINFER_APILOG_LEVEL = {LOGGING_LEVEL}")
-    print(f"  FLASHINFER_APILOG_DEST = {LOG_DEST}")
+    print(f"  FLASHINFER_LOGLEVEL_DBG = {LOGGING_LEVEL}")
+    print(f"  FLASHINFER_LOGDEST_DBG = {LOG_DEST}")
 
     # Get level name
     level_names = {
@@ -331,7 +331,7 @@ def main():
     print("\nTo benchmark other levels, run:")
     for level in [0, 1, 2, 3]:
         if level != LOGGING_LEVEL:
-            print(f"  FLASHINFER_APILOG_LEVEL={level} python {sys.argv[0]}")
+            print(f"  FLASHINFER_LOGLEVEL_DBG={level} python {sys.argv[0]}")
 
     print("\n" + "=" * 80)
     print("Benchmark complete!")
diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py
index 9626a39202..12566dc8fe 100644
--- a/flashinfer/api_logging.py
+++ b/flashinfer/api_logging.py
@@ -26,8 +26,8 @@
 
 
 # Read environment variables once at module load time
-_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_APILOG_LEVEL", "0"))
-_API_LOG_DEST = os.environ.get("FLASHINFER_APILOG_DEST", "./flashinfer_log.txt")
+_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0"))
+_API_LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout")
 
 # Create logger using Python's logging library
 _logger = logging.getLogger("flashinfer.api")
@@ -41,7 +41,7 @@ def _setup_logger():
         _logger.setLevel(logging.CRITICAL + 1)  # Higher than any level
         return
 
-    # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_APILOG_LEVEL instead
+    # All enabled levels use logging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL_DBG instead
     _logger.setLevel(logging.DEBUG)
 
     # Remove any existing handlers
@@ -446,17 +446,17 @@ def flashinfer_api_log(func: Callable = None) -> Callable:
     Decorator to log FlashInfer API calls using Python's logging library.
This decorator integrates with Python's standard logging infrastructure while - maintaining zero overhead when disabled (FLASHINFER_APILOG_LEVEL=0). + maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL_DBG=0). Environment Variables --------------------- - FLASHINFER_APILOG_LEVEL : int (default: 0) + FLASHINFER_LOGLEVEL_DBG : int (default: 0) - 0: No logging (zero overhead - decorator returns original function) - 1: Log function name only (logged BEFORE execution - crash-safe) - 2: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - 3: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) - FLASHINFER_APILOG_DEST : str (default: "./flashinfer_log.txt") + FLASHINFER_LOGDEST_DBG : str (default: "stdout") - "stdout": Log to standard output - "stderr": Log to standard error - : Log to specified file path @@ -471,7 +471,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable: Notes ----- - - When FLASHINFER_APILOG_LEVEL=0, the decorator has truly zero overhead + - When FLASHINFER_LOGLEVEL_DBG=0, the decorator has truly zero overhead as it returns the original function unchanged. - Function names and inputs are logged BEFORE execution: - Level 1: Function name only diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 3559d1d11b..43c44ddbd1 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -47,21 +47,21 @@ class TestAPILogging: def setup_and_teardown(self): """Reset environment and reimport logging module for each test.""" # Store original environment - original_level = os.environ.get("FLASHINFER_APILOG_LEVEL") - original_dest = os.environ.get("FLASHINFER_APILOG_DEST") + original_level = os.environ.get("FLASHINFER_LOGLEVEL_DBG") + original_dest = os.environ.get("FLASHINFER_LOGDEST_DBG") yield # Restore original environment if original_level is not None: - os.environ["FLASHINFER_APILOG_LEVEL"] = original_level - elif "FLASHINFER_APILOG_LEVEL" in os.environ: - del os.environ["FLASHINFER_APILOG_LEVEL"] + os.environ["FLASHINFER_LOGLEVEL_DBG"] = original_level + elif "FLASHINFER_LOGLEVEL_DBG" in os.environ: + del os.environ["FLASHINFER_LOGLEVEL_DBG"] if original_dest is not None: - os.environ["FLASHINFER_APILOG_DEST"] = original_dest - elif "FLASHINFER_APILOG_DEST" in os.environ: - del os.environ["FLASHINFER_APILOG_DEST"] + os.environ["FLASHINFER_LOGDEST_DBG"] = original_dest + elif "FLASHINFER_LOGDEST_DBG" in os.environ: + del os.environ["FLASHINFER_LOGDEST_DBG"] # Force reimport to pick up new environment variables if "flashinfer.api_logging" in sys.modules: @@ -69,8 +69,8 @@ def setup_and_teardown(self): def setup_logging(self, level: int, dest: str = "stdout"): """Helper to set up logging environment and reimport.""" - os.environ["FLASHINFER_APILOG_LEVEL"] = str(level) - os.environ["FLASHINFER_APILOG_DEST"] = dest + os.environ["FLASHINFER_LOGLEVEL_DBG"] = str(level) + os.environ["FLASHINFER_LOGDEST_DBG"] = dest # Force reimport if "flashinfer.api_logging" in sys.modules: From e984e1b68d0e37329aef061e43050a42ef8c01ea Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Wed, 19 Nov 2025 23:03:37 +0000 Subject: [PATCH 05/13] Allow %i substitution for process ID for multi-GPU environments. 
Logging level above 3 now automatically include cudnn and cublas API logging --- flashinfer/api_logging.py | 81 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 12566dc8fe..37fdb63e97 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -25,9 +25,58 @@ import torch +# Helper function to substitute %i with process ID in file paths +def _substitute_process_id(path: str) -> str: + """ + Replace %i with the current process ID in a path. + + This is useful for multi-process/multi-GPU environments where each process + needs its own log file. + + Example: "flashinfer_log_%i.txt" -> "flashinfer_log_12345.txt" + """ + if "%i" in path: + return path.replace("%i", str(os.getpid())) + return path + + # Read environment variables once at module load time _API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0")) -_API_LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout") +_API_LOG_DEST = _substitute_process_id( + os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout") +) + +# Enable cuDNN, cuBLAS, and cuBLASLt API logging when FlashInfer logging level >= 3 +# Only override if the user hasn't already configured the logging switch +# If the switch is not set, we override both the switch and destination as a bundle +if _API_LOG_LEVEL >= 3: + # cuBLAS logging: Check switch, set both switch and destination + if "CUBLAS_LOGINFO_DBG" not in os.environ: + os.environ["CUBLAS_LOGINFO_DBG"] = "1" + os.environ["CUBLAS_LOGDEST_DBG"] = _substitute_process_id( + "flashinfer_cublas_log_%i.txt" + ) + + # cuBLASLt logging: Check switch, set both switch and destination + if "CUBLASLT_LOG_LEVEL" not in os.environ: + os.environ["CUBLASLT_LOG_LEVEL"] = "2" + os.environ["CUBLASLT_LOG_FILE"] = _substitute_process_id( + "flashinfer_cublaslt_log_%i.txt" + ) + + # cuDNN backend logging: Check switch, set both switch and destination + if "CUDNN_LOGLEVEL_DBG" not in os.environ: + os.environ["CUDNN_LOGLEVEL_DBG"] = "2.5" + os.environ["CUDNN_LOGDEST_DBG"] = _substitute_process_id( + "flashinfer_cudnn_backend_log_%i.txt" + ) + + # cuDNN frontend logging: Check switch, set both switch and destination + if "CUDNN_FRONTEND_LOG_INFO" not in os.environ: + os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1" + os.environ["CUDNN_FRONTEND_LOG_FILE"] = _substitute_process_id( + "flashinfer_cudnn_frontend_log_%i.txt" + ) # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") @@ -129,6 +178,28 @@ def _log_system_info(): # PyTorch version lines.append(f"PyTorch version: {torch.__version__}") + # cuDNN/cuBLAS/cuBLASLt logging status + if _API_LOG_LEVEL >= 3: + lines.append("") + lines.append("cuDNN/cuBLAS/cuBLASLt Logging: Enabled (Level 3)") + cublas_info = os.environ.get("CUBLAS_LOGINFO_DBG", "not set") + cublas_dest = os.environ.get("CUBLAS_LOGDEST_DBG", "not set") + cublaslt_level = os.environ.get("CUBLASLT_LOG_LEVEL", "not set") + cublaslt_file = os.environ.get("CUBLASLT_LOG_FILE", "not set") + cudnn_level = os.environ.get("CUDNN_LOGLEVEL_DBG", "not set") + cudnn_dest = os.environ.get("CUDNN_LOGDEST_DBG", "not set") + cudnn_fe_info = os.environ.get("CUDNN_FRONTEND_LOG_INFO", "not set") + cudnn_fe_file = os.environ.get("CUDNN_FRONTEND_LOG_FILE", "not set") + + lines.append(f" CUBLAS_LOGINFO_DBG={cublas_info}") + lines.append(f" CUBLAS_LOGDEST_DBG={cublas_dest}") + lines.append(f" CUBLASLT_LOG_LEVEL={cublaslt_level}") + lines.append(f" CUBLASLT_LOG_FILE={cublaslt_file}") 
+ lines.append(f" CUDNN_LOGLEVEL_DBG={cudnn_level}") + lines.append(f" CUDNN_LOGDEST_DBG={cudnn_dest}") + lines.append(f" CUDNN_FRONTEND_LOG_INFO={cudnn_fe_info}") + lines.append(f" CUDNN_FRONTEND_LOG_FILE={cudnn_fe_file}") + except Exception as e: lines.append(f"Error gathering system information: {e}") @@ -460,6 +531,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable: - "stdout": Log to standard output - "stderr": Log to standard error - : Log to specified file path + - Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt") Examples -------- @@ -482,6 +554,13 @@ def flashinfer_api_log(func: Callable = None) -> Callable: - **CUDA Graph Compatibility**: At level 3, tensor statistics (min/max/mean/nan_count) are automatically skipped during CUDA graph capture to avoid synchronization issues. The message "[statistics skipped: CUDA graph capture in progress]" will be logged. + - **cuDNN/cuBLAS/cuBLASLt Integration**: At level 3, if not already set by the user, the following + environment variables are automatically configured to enable cuDNN, cuBLAS, and cuBLASLt logging: + - CUBLAS_LOGINFO_DBG=1, CUBLAS_LOGDEST_DBG=flashinfer_cublas_log_%i.txt + - CUBLASLT_LOG_LEVEL=2, CUBLASLT_LOG_FILE=flashinfer_cublaslt_log_%i.txt + - CUDNN_LOGLEVEL_DBG=2.5, CUDNN_LOGDEST_DBG=flashinfer_cudnn_backend_log_%i.txt + - CUDNN_FRONTEND_LOG_INFO=1, CUDNN_FRONTEND_LOG_FILE=flashinfer_cudnn_frontend_log_%i.txt + The %i pattern is automatically replaced with the process ID for multi-process environments. - The logger does not propagate to the root logger to avoid duplicate logs. """ # If logging is disabled, return original function with zero overhead From 62b6436413e52ea828e03720c679e9d5114bde5b Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Wed, 19 Nov 2025 23:35:08 +0000 Subject: [PATCH 06/13] Add time stamps and space out levels --- benchmarks/bench_logging_overhead.py | 10 ++--- flashinfer/api_logging.py | 59 +++++++++++++++++----------- tests/utils/test_logging.py | 34 ++++++++-------- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py index 391e588f66..5a1b6d72c8 100644 --- a/benchmarks/bench_logging_overhead.py +++ b/benchmarks/bench_logging_overhead.py @@ -12,14 +12,14 @@ Usage: # Set the logging level before running - export FLASHINFER_LOGLEVEL_DBG=2 + export FLASHINFER_LOGLEVEL_DBG=3 python bench_logging_overhead.py # Or run with different levels FLASHINFER_LOGLEVEL_DBG=0 python bench_logging_overhead.py FLASHINFER_LOGLEVEL_DBG=1 python bench_logging_overhead.py - FLASHINFER_LOGLEVEL_DBG=2 python bench_logging_overhead.py FLASHINFER_LOGLEVEL_DBG=3 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL_DBG=5 python bench_logging_overhead.py # Or use the helper script to run all levels bash benchmark_all_levels.sh @@ -233,8 +233,8 @@ def main(): level_names = { 0: "No logging (zero-overhead)", 1: "Function name only", - 2: "Name + inputs/outputs + metadata", - 3: "Name + inputs/outputs + metadata + statistics", + 3: "Name + inputs/outputs + metadata", + 5: "Name + inputs/outputs + metadata + statistics", } print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}") @@ -329,7 +329,7 @@ def main(): print("RECOMMENDATIONS") print("=" * 80) print("\nTo benchmark other levels, run:") - for level in [0, 1, 2, 3]: + for level in [0, 1, 3, 5]: if level != LOGGING_LEVEL: print(f" FLASHINFER_LOGLEVEL_DBG={level} python {sys.argv[0]}") diff --git a/flashinfer/api_logging.py 
b/flashinfer/api_logging.py index 37fdb63e97..583b6ec23b 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -46,10 +46,10 @@ def _substitute_process_id(path: str) -> str: os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout") ) -# Enable cuDNN, cuBLAS, and cuBLASLt API logging when FlashInfer logging level >= 3 +# Enable cuDNN, cuBLAS, and cuBLASLt API logging when FlashInfer logging level >= 5 # Only override if the user hasn't already configured the logging switch # If the switch is not set, we override both the switch and destination as a bundle -if _API_LOG_LEVEL >= 3: +if _API_LOG_LEVEL >= 5: # cuBLAS logging: Check switch, set both switch and destination if "CUBLAS_LOGINFO_DBG" not in os.environ: os.environ["CUBLAS_LOGINFO_DBG"] = "1" @@ -104,7 +104,7 @@ def _setup_logger(): else: handler = logging.FileHandler(_API_LOG_DEST, mode="a") - # Use a simple formatter (we'll format the detailed content ourselves) + # Use a simple formatter (we'll add timestamps manually to key lines) formatter = logging.Formatter("%(message)s") handler.setFormatter(formatter) @@ -116,6 +116,13 @@ def _setup_logger(): _setup_logger() +def _get_timestamp() -> str: + """Get current timestamp in the format [YYYY-MM-DD HH:MM:SS].""" + from datetime import datetime + + return datetime.now().strftime("[%Y-%m-%d %H:%M:%S]") + + def _log_system_info(): """Log system information once at module initialization.""" if _API_LOG_LEVEL == 0: @@ -123,7 +130,7 @@ def _log_system_info(): lines = [] lines.append("=" * 80) - lines.append("FlashInfer API Logging - System Information") + lines.append(f"{_get_timestamp()} FlashInfer API Logging - System Information") lines.append("=" * 80) try: @@ -179,9 +186,9 @@ def _log_system_info(): lines.append(f"PyTorch version: {torch.__version__}") # cuDNN/cuBLAS/cuBLASLt logging status - if _API_LOG_LEVEL >= 3: + if _API_LOG_LEVEL >= 5: lines.append("") - lines.append("cuDNN/cuBLAS/cuBLASLt Logging: Enabled (Level 3)") + lines.append("cuDNN/cuBLAS/cuBLASLt Logging: Enabled (Level 5)") cublas_info = os.environ.get("CUBLAS_LOGINFO_DBG", "not set") cublas_dest = os.environ.get("CUBLAS_LOGDEST_DBG", "not set") cublaslt_level = os.environ.get("CUBLASLT_LOG_LEVEL", "not set") @@ -249,7 +256,7 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: if level == 1: return f"{indent_str}Tensor(...)" - # Level 2+: Show metadata + # Level 3+: Show metadata lines = [f"{indent_str}Tensor("] lines.append(f"{indent_str} shape={tuple(value.shape)}") lines.append(f"{indent_str} stride={tuple(value.stride())}") @@ -258,8 +265,8 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: lines.append(f"{indent_str} requires_grad={value.requires_grad}") lines.append(f"{indent_str} is_contiguous={value.is_contiguous()}") - # Level 3: Add statistics - if level >= 3: + # Level 5: Add statistics + if level >= 5: try: # Skip statistics if we're in CUDA graph capture mode # (operations like .min()/.max()/.mean() cause synchronization issues) @@ -452,11 +459,11 @@ def _log_function_inputs( kwargs : dict Keyword arguments level : int - Logging level (2 or 3) + Logging level (3 or 5) """ lines = [] lines.append("=" * 80) - lines.append(f"FlashInfer API Call: {func_name}") + lines.append(f"{_get_timestamp()} FlashInfer API Call: {func_name}") lines.append("-" * 80) # Log explicitly provided inputs @@ -499,7 +506,7 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: result : Any Function return value level : int - Logging level (2 or 3) + Logging 
level (3 or 5) """ lines = [] # Log outputs @@ -524,8 +531,8 @@ def flashinfer_api_log(func: Callable = None) -> Callable: FLASHINFER_LOGLEVEL_DBG : int (default: 0) - 0: No logging (zero overhead - decorator returns original function) - 1: Log function name only (logged BEFORE execution - crash-safe) - - 2: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - - 3: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) + - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) + - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) FLASHINFER_LOGDEST_DBG : str (default: "stdout") - "stdout": Log to standard output @@ -543,18 +550,20 @@ def flashinfer_api_log(func: Callable = None) -> Callable: Notes ----- + - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS] + (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information") - When FLASHINFER_LOGLEVEL_DBG=0, the decorator has truly zero overhead as it returns the original function unchanged. - Function names and inputs are logged BEFORE execution: - Level 1: Function name only - - Levels 2-3: Function name + inputs with metadata + - Levels 3-5: Function name + inputs with metadata This means critical debugging information is preserved even if the function crashes (e.g., CUDA illegal memory access, out-of-bounds, etc.). - - Outputs are logged AFTER successful execution for levels 2 and 3. - - **CUDA Graph Compatibility**: At level 3, tensor statistics (min/max/mean/nan_count) + - Outputs are logged AFTER successful execution for levels 3 and 5. + - **CUDA Graph Compatibility**: At level 5, tensor statistics (min/max/mean/nan_count) are automatically skipped during CUDA graph capture to avoid synchronization issues. The message "[statistics skipped: CUDA graph capture in progress]" will be logged. 
- - **cuDNN/cuBLAS/cuBLASLt Integration**: At level 3, if not already set by the user, the following + - **cuDNN/cuBLAS/cuBLASLt Integration**: At level 5, if not already set by the user, the following environment variables are automatically configured to enable cuDNN, cuBLAS, and cuBLASLt logging: - CUBLAS_LOGINFO_DBG=1, CUBLAS_LOGDEST_DBG=flashinfer_cublas_log_%i.txt - CUBLASLT_LOG_LEVEL=2, CUBLASLT_LOG_FILE=flashinfer_cublaslt_log_%i.txt @@ -588,9 +597,11 @@ def wrapper(*args, **kwargs): try: if _API_LOG_LEVEL == 1: # Level 1: Just log function name before execution (crash-safe) - _logger.debug(f"FlashInfer API Call: {func_name}") - elif _API_LOG_LEVEL >= 2: - # Level 2+: Log full inputs before execution (crash-safe) + _logger.debug( + f"{_get_timestamp()} FlashInfer API Call: {func_name}" + ) + elif _API_LOG_LEVEL >= 3: + # Level 3+: Log full inputs before execution (crash-safe) _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) except Exception as e: _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") @@ -598,10 +609,10 @@ def wrapper(*args, **kwargs): # Call the original function (may crash here with CUDA errors) result = f(*args, **kwargs) - # Log outputs AFTER successful execution (level 2+ only) + # Log outputs AFTER successful execution (level 3+ only) try: - if _API_LOG_LEVEL >= 2: - # Level 2+: Log outputs (inputs were already logged above) + if _API_LOG_LEVEL >= 3: + # Level 3+: Log outputs (inputs were already logged above) _log_function_outputs(func_name, result, _API_LOG_LEVEL) except Exception as e: _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}") diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 43c44ddbd1..58efc52ac0 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -119,13 +119,13 @@ def test_function(x, y): finally: Path(log_file).unlink(missing_ok=True) - def test_level_2_inputs_outputs(self): - """Test that level 2 logs inputs and outputs with metadata.""" + def test_level_3_inputs_outputs(self): + """Test that level 3 logs inputs and outputs with metadata.""" with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(tensor, value): @@ -151,19 +151,19 @@ def test_function(tensor, value): # Should log outputs assert "Output value:" in log_contents - # Should NOT log statistics (level 3 only) + # Should NOT log statistics (level 5 only) assert "min=" not in log_contents assert "max=" not in log_contents finally: Path(log_file).unlink(missing_ok=True) - def test_level_3_statistics(self): - """Test that level 3 logs tensor statistics.""" + def test_level_5_statistics(self): + """Test that level 5 logs tensor statistics.""" with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name try: - decorator = self.setup_logging(level=3, dest=log_file) + decorator = self.setup_logging(level=5, dest=log_file) @decorator def test_function(tensor): @@ -191,7 +191,7 @@ def test_enum_logging(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(mode: TestEnum, strategy: StringEnum): @@ -221,7 +221,7 @@ def test_default_parameters(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = 
self.setup_logging(level=3, dest=log_file) @decorator def test_function(x, y=10, z=20, mode=TestEnum.OPTION_A): @@ -252,7 +252,7 @@ def test_explicit_vs_default_parameters(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(x, y=10, z=20): @@ -316,7 +316,7 @@ def test_crash_safety_inputs_logged_before_execution(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def crashing_function(x, y): @@ -349,7 +349,7 @@ def test_different_data_types(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function( @@ -387,7 +387,7 @@ def test_tensor_metadata(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(tensor): @@ -420,7 +420,7 @@ def test_nested_structures(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(nested): @@ -510,7 +510,7 @@ def test_kwargs_logging(self): log_file = f.name try: - decorator = self.setup_logging(level=2, dest=log_file) + decorator = self.setup_logging(level=3, dest=log_file) @decorator def test_function(a, b, c): @@ -534,12 +534,12 @@ def test_function(a, b, c): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_cuda_graph_compatibility(self): - """Test that level 3 logging is compatible with CUDA graph capture.""" + """Test that level 5 logging is compatible with CUDA graph capture.""" with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: log_file = f.name try: - decorator = self.setup_logging(level=3, dest=log_file) + decorator = self.setup_logging(level=5, dest=log_file) @decorator def test_cuda_function(tensor): From aed27cfd4ec97445c0fe286dedbc951addce6c00 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 01:20:19 +0000 Subject: [PATCH 07/13] Cleanup and streamline --- benchmarks/bench_logging_overhead.py | 21 +-------- flashinfer/api_logging.py | 69 ++-------------------------- flashinfer/cudnn/decode.py | 4 +- flashinfer/cudnn/prefill.py | 4 +- flashinfer/decode.py | 20 ++++---- flashinfer/fused_moe/core.py | 14 +++--- flashinfer/gemm/gemm_base.py | 24 +++++----- flashinfer/mla.py | 8 ++-- flashinfer/prefill.py | 22 ++++----- tests/utils/test_logging.py | 4 +- 10 files changed, 55 insertions(+), 135 deletions(-) diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py index 5a1b6d72c8..0fd6724a4f 100644 --- a/benchmarks/bench_logging_overhead.py +++ b/benchmarks/bench_logging_overhead.py @@ -5,11 +5,6 @@ This script creates decorated and undecorated versions of a test function (torch.matmul) and compares their performance to accurately measure logging overhead. -Why torch.matmul instead of bmm_fp8? 
- - bmm_fp8 is already decorated in the FlashInfer source code - - Using it would cause double-decoration and inaccurate results - - torch.matmul gives us a clean baseline to measure pure decorator overhead - Usage: # Set the logging level before running export FLASHINFER_LOGLEVEL_DBG=3 @@ -37,30 +32,18 @@ LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "/tmp/flashinfer_benchmark_log.txt") # Import the decorator -try: - from flashinfer.api_logging import flashinfer_api_log -except ImportError as e: - print(f"Error: Could not import flashinfer: {e}") - print("Make sure flashinfer is installed.") - exit(1) +from flashinfer.api_logging import flashinfer_log # Create two versions of a test function: # 1. Undecorated (baseline) # 2. Decorated (with logging) -# -# We use a simple torch.matmul instead of bmm_fp8 because bmm_fp8 is already -# decorated in the source code, which would cause double-decoration. - - def test_matmul_undecorated(A, B): - """Undecorated version - baseline for comparison.""" return torch.matmul(A, B) -@flashinfer_api_log +@flashinfer_log def test_matmul_decorated(A, B): - """Decorated version - with API logging.""" return torch.matmul(A, B) diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 583b6ec23b..6948e728d6 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -32,8 +32,6 @@ def _substitute_process_id(path: str) -> str: This is useful for multi-process/multi-GPU environments where each process needs its own log file. - - Example: "flashinfer_log_%i.txt" -> "flashinfer_log_12345.txt" """ if "%i" in path: return path.replace("%i", str(os.getpid())) @@ -46,38 +44,6 @@ def _substitute_process_id(path: str) -> str: os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout") ) -# Enable cuDNN, cuBLAS, and cuBLASLt API logging when FlashInfer logging level >= 5 -# Only override if the user hasn't already configured the logging switch -# If the switch is not set, we override both the switch and destination as a bundle -if _API_LOG_LEVEL >= 5: - # cuBLAS logging: Check switch, set both switch and destination - if "CUBLAS_LOGINFO_DBG" not in os.environ: - os.environ["CUBLAS_LOGINFO_DBG"] = "1" - os.environ["CUBLAS_LOGDEST_DBG"] = _substitute_process_id( - "flashinfer_cublas_log_%i.txt" - ) - - # cuBLASLt logging: Check switch, set both switch and destination - if "CUBLASLT_LOG_LEVEL" not in os.environ: - os.environ["CUBLASLT_LOG_LEVEL"] = "2" - os.environ["CUBLASLT_LOG_FILE"] = _substitute_process_id( - "flashinfer_cublaslt_log_%i.txt" - ) - - # cuDNN backend logging: Check switch, set both switch and destination - if "CUDNN_LOGLEVEL_DBG" not in os.environ: - os.environ["CUDNN_LOGLEVEL_DBG"] = "2.5" - os.environ["CUDNN_LOGDEST_DBG"] = _substitute_process_id( - "flashinfer_cudnn_backend_log_%i.txt" - ) - - # cuDNN frontend logging: Check switch, set both switch and destination - if "CUDNN_FRONTEND_LOG_INFO" not in os.environ: - os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1" - os.environ["CUDNN_FRONTEND_LOG_FILE"] = _substitute_process_id( - "flashinfer_cudnn_frontend_log_%i.txt" - ) - # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") @@ -185,28 +151,6 @@ def _log_system_info(): # PyTorch version lines.append(f"PyTorch version: {torch.__version__}") - # cuDNN/cuBLAS/cuBLASLt logging status - if _API_LOG_LEVEL >= 5: - lines.append("") - lines.append("cuDNN/cuBLAS/cuBLASLt Logging: Enabled (Level 5)") - cublas_info = os.environ.get("CUBLAS_LOGINFO_DBG", "not set") - cublas_dest = 
os.environ.get("CUBLAS_LOGDEST_DBG", "not set") - cublaslt_level = os.environ.get("CUBLASLT_LOG_LEVEL", "not set") - cublaslt_file = os.environ.get("CUBLASLT_LOG_FILE", "not set") - cudnn_level = os.environ.get("CUDNN_LOGLEVEL_DBG", "not set") - cudnn_dest = os.environ.get("CUDNN_LOGDEST_DBG", "not set") - cudnn_fe_info = os.environ.get("CUDNN_FRONTEND_LOG_INFO", "not set") - cudnn_fe_file = os.environ.get("CUDNN_FRONTEND_LOG_FILE", "not set") - - lines.append(f" CUBLAS_LOGINFO_DBG={cublas_info}") - lines.append(f" CUBLAS_LOGDEST_DBG={cublas_dest}") - lines.append(f" CUBLASLT_LOG_LEVEL={cublaslt_level}") - lines.append(f" CUBLASLT_LOG_FILE={cublaslt_file}") - lines.append(f" CUDNN_LOGLEVEL_DBG={cudnn_level}") - lines.append(f" CUDNN_LOGDEST_DBG={cudnn_dest}") - lines.append(f" CUDNN_FRONTEND_LOG_INFO={cudnn_fe_info}") - lines.append(f" CUDNN_FRONTEND_LOG_FILE={cudnn_fe_file}") - except Exception as e: lines.append(f"Error gathering system information: {e}") @@ -519,7 +463,7 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: _logger.debug("\n".join(lines)) -def flashinfer_api_log(func: Callable = None) -> Callable: +def flashinfer_log(func: Callable = None) -> Callable: """ Decorator to log FlashInfer API calls using Python's logging library. @@ -544,7 +488,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable: -------- Basic usage: - >>> @flashinfer_api_log + >>> @flashinfer_log ... def my_function(x, y): ... return x + y @@ -563,13 +507,7 @@ def flashinfer_api_log(func: Callable = None) -> Callable: - **CUDA Graph Compatibility**: At level 5, tensor statistics (min/max/mean/nan_count) are automatically skipped during CUDA graph capture to avoid synchronization issues. The message "[statistics skipped: CUDA graph capture in progress]" will be logged. - - **cuDNN/cuBLAS/cuBLASLt Integration**: At level 5, if not already set by the user, the following - environment variables are automatically configured to enable cuDNN, cuBLAS, and cuBLASLt logging: - - CUBLAS_LOGINFO_DBG=1, CUBLAS_LOGDEST_DBG=flashinfer_cublas_log_%i.txt - - CUBLASLT_LOG_LEVEL=2, CUBLASLT_LOG_FILE=flashinfer_cublaslt_log_%i.txt - - CUDNN_LOGLEVEL_DBG=2.5, CUDNN_LOGDEST_DBG=flashinfer_cudnn_backend_log_%i.txt - - CUDNN_FRONTEND_LOG_INFO=1, CUDNN_FRONTEND_LOG_FILE=flashinfer_cudnn_frontend_log_%i.txt - The %i pattern is automatically replaced with the process ID for multi-process environments. + - The %i pattern is automatically replaced with the process ID for multi-process environments. - The logger does not propagate to the root logger to avoid duplicate logs. 
""" # If logging is disabled, return original function with zero overhead @@ -621,7 +559,6 @@ def wrapper(*args, **kwargs): return wrapper - # Support both @flashinfer_api_log and @flashinfer_api_log() if func is None: return decorator return decorator(func) diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 39f4cf67c1..0635623d36 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -3,7 +3,7 @@ import torch -from ..api_logging import flashinfer_api_log +from ..api_logging import flashinfer_log from .utils import get_cudnn_fmha_gen_module try: @@ -253,7 +253,7 @@ def _batch_decode_with_kv_cache( return out -@flashinfer_api_log +@flashinfer_log def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index 9ca5cea66a..8d299355b6 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -3,7 +3,7 @@ import torch -from ..api_logging import flashinfer_api_log +from ..api_logging import flashinfer_log from .utils import get_cudnn_fmha_gen_module try: @@ -384,7 +384,7 @@ def _batch_prefill_with_kv_cache( return out, None -@flashinfer_api_log +@flashinfer_log def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 2f2072039a..91077cb814 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,7 +21,7 @@ import torch -from .api_logging import flashinfer_api_log +from .api_logging import flashinfer_log from .xqa import xqa, xqa_mla from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( @@ -313,7 +313,7 @@ def get_trtllm_gen_fmha_module(): return op -@flashinfer_api_log +@flashinfer_log def single_decode_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -390,7 +390,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_api_log +@flashinfer_log def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -649,7 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ - @flashinfer_api_log + @flashinfer_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -813,7 +813,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_api_log + @flashinfer_log def plan( self, indptr: torch.Tensor, @@ -1167,7 +1167,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
- @flashinfer_api_log + @flashinfer_log def run( self, q: torch.Tensor, @@ -2065,7 +2065,7 @@ def _fake_paged_run( ) -@flashinfer_api_log +@flashinfer_log def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2339,7 +2339,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout -@flashinfer_api_log +@flashinfer_log def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2524,7 +2524,7 @@ def _check_trtllm_gen_mla_shape( ) -@flashinfer_api_log +@flashinfer_log def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -2686,7 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") -@flashinfer_api_log +@flashinfer_log def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 9ab621453a..b66c69dc6a 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from ..api_logging import flashinfer_api_log +from ..api_logging import flashinfer_log from ..autotuner import ( AutoTuner, DynamicTensorSpec, @@ -686,7 +686,7 @@ def _fake_cutlass_fused_moe( # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 -@flashinfer_api_log +@flashinfer_log def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1859,7 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe( ) -@flashinfer_api_log +@flashinfer_log def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -1940,7 +1940,7 @@ def trtllm_bf16_moe( ) -@flashinfer_api_log +@flashinfer_log def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2014,7 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe( ) -@flashinfer_api_log +@flashinfer_log def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2092,7 +2092,7 @@ def trtllm_fp8_block_scale_moe( ) -@flashinfer_api_log +@flashinfer_log def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2222,7 +2222,7 @@ def trtllm_fp4_block_scale_moe( ) -@flashinfer_api_log +@flashinfer_log def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 90900b2775..ec88b09620 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,7 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch -from ..api_logging import flashinfer_api_log +from ..api_logging import flashinfer_log from ..autotuner import ( AutoTuner, ConstraintSpec, @@ -540,7 +540,7 @@ def forward( ) -@flashinfer_api_log +@flashinfer_log def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -886,7 +886,7 @@ def reset_workspace_buffer( self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - @flashinfer_api_log + @flashinfer_log def run( self, x: torch.Tensor, @@ -1554,7 +1554,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) -@flashinfer_api_log 
+@flashinfer_log def mm_fp8( a: torch.Tensor, b: torch.Tensor, @@ -2028,7 +2028,7 @@ def _heuristic_func_mm_fp4( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) -@flashinfer_api_log +@flashinfer_log def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2286,7 +2286,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) -@flashinfer_api_log +@flashinfer_log def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2378,7 +2378,7 @@ def bmm_fp8( return out -@flashinfer_api_log +@flashinfer_log def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2630,7 +2630,7 @@ def forward( ) -@flashinfer_api_log +@flashinfer_log def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2659,7 +2659,7 @@ def gemm_fp8_nt_blockscaled( ) -@flashinfer_api_log +@flashinfer_log def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2822,7 +2822,7 @@ def group_gemm_fp8_nt_groupwise( return out -@flashinfer_api_log +@flashinfer_log def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2990,7 +2990,7 @@ def get_deepgemm_sm100_module(): return module -@flashinfer_api_log +@flashinfer_log def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3121,7 +3121,7 @@ def group_deepgemm_fp8_nt_groupwise( return out -@flashinfer_api_log +@flashinfer_log def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index c94db531f9..4ae2fdcd5e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,7 +19,7 @@ import torch -from .api_logging import flashinfer_api_log +from .api_logging import flashinfer_log from .jit import gen_batch_mla_module from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend @@ -130,7 +130,7 @@ class BatchMLAPagedAttentionWrapper: torch.Size([114, 128, 512]) """ - @flashinfer_api_log + @flashinfer_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -201,7 +201,7 @@ def __init__( else: self._backend = backend - @flashinfer_api_log + @flashinfer_log def plan( self, qo_indptr: torch.Tensor, @@ -336,7 +336,7 @@ def run( return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api_log + @flashinfer_log def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 6fec42cbff..c266b5e33c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -22,7 +22,7 @@ import torch -from .api_logging import flashinfer_api_log +from .api_logging import flashinfer_log from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, @@ -874,7 +874,7 @@ def _fake_paged_run( ) -@flashinfer_api_log +@flashinfer_log def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -959,7 +959,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_api_log +@flashinfer_log def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1328,7 +1328,7 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. 
""" - @flashinfer_api_log + @flashinfer_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1524,7 +1524,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_api_log + @flashinfer_log def plan( self, qo_indptr: torch.Tensor, @@ -1981,7 +1981,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api_log + @flashinfer_log def run( self, q: torch.Tensor, @@ -2356,7 +2356,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ - @flashinfer_api_log + @flashinfer_log def __init__( self, float_workspace_buffer: torch.Tensor, @@ -2500,7 +2500,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_api_log + @flashinfer_log def plan( self, qo_indptr: torch.Tensor, @@ -2845,7 +2845,7 @@ def run( enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api_log + @flashinfer_log def run( self, q: torch.Tensor, @@ -3202,7 +3202,7 @@ def get_trtllm_gen_fmha_module(): return op -@flashinfer_api_log +@flashinfer_log def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3337,7 +3337,7 @@ def trtllm_ragged_attention_deepseek( return out -@flashinfer_api_log +@flashinfer_log def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 58efc52ac0..92cff6fc7e 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -76,9 +76,9 @@ def setup_logging(self, level: int, dest: str = "stdout"): if "flashinfer.api_logging" in sys.modules: del sys.modules["flashinfer.api_logging"] - from flashinfer.api_logging import flashinfer_api_log + from flashinfer.api_logging import flashinfer_log - return flashinfer_api_log + return flashinfer_log def test_level_0_zero_overhead(self): """Test that level 0 has truly zero overhead (returns original function).""" From 5ec246b2e2427d707cee342c5478e14fb1ad6833 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 01:36:05 +0000 Subject: [PATCH 08/13] Adding documentation updates --- LOGGING.md | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 14 +++++++++ 2 files changed, 97 insertions(+) create mode 100644 LOGGING.md diff --git a/LOGGING.md b/LOGGING.md new file mode 100644 index 0000000000..1e5c22db27 --- /dev/null +++ b/LOGGING.md @@ -0,0 +1,83 @@ +# FlashInfer Logging + +FlashInfer provides a logging feature to help debug issues, and reproduce crashes. This document describes all available logging levels and their features. 
+ +## Quick Start + +Enable logging using two environment variables: + +```bash +# Set logging level (0-5) +export FLASHINFER_LOGLEVEL_DBG=3 + +# Set log destination (default is stdout) +export FLASHINFER_LOGDEST_DBG=stdout # or stderr, or a file path like "flashinfer.log" + +# Run your code +python train.py +``` + +## Logging Levels + +| Level | Name | Features | Use Case | +|-------|------|----------|----------| +| **0** | Disabled (Default) | No logging (zero overhad) | Production | +| **1** | Function Names | Function names only | Basic tracing | +| **3** | Inputs/Outputs | Function names + arguments + outputs with metadata | Standard debugging | +| **5** | Statistics | Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) | Numerical analysis | + + +## Environment Variables + +### Main Configuration + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `FLASHINFER_LOGLEVEL_DBG` | int | 0 | Logging level (0, 1, 3, 5) | +| `FLASHINFER_LOGDEST_DBG` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path | + +### Process ID Substitution + +Use `%i` in file paths for automatic process ID substitution (useful for multi-GPU training): + +```bash +export FLASHINFER_LOGDEST_DBG="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt +``` + +This works for: +- `FLASHINFER_LOGDEST_DBG` + +## Miscellaneous Notes and Examples +### CUDA Graph Compatibility + +Level 5 statistics are **automatically skipped during CUDA graph capture** to avoid synchronization issues. + +```python +# This works correctly - no synchronization errors +with torch.cuda.graph(cuda_graph): + result = mm_fp4(a, b, scales) # Level 5 logging active + # Statistics automatically skipped during capture +``` + +Output shows: `[statistics skipped: CUDA graph capture in progress]` + +### Process IDs for Multi-GPU Environments + +```bash +# Use %i for process ID substitution +export FLASHINFER_LOGLEVEL_DBG=3 +export FLASHINFER_LOGDEST_DBG="logs/flashinfer_api_%i.log" + +torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py + +# Creates separate logs: +# logs/flashinfer_api_12345.log (rank 0) +# logs/flashinfer_api_12346.log (rank 1) +# ... +``` + +## Frequently Asked Questions + +### Q: Does Level 0 really have zero overhead? + +**A: Yes.** At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. \ No newline at end of file diff --git a/README.md b/README.md index 94eece5007..d928e9e053 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,20 @@ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill att Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels. +## API Logging + +FlashInfer provides comprehensive API logging for debugging. Enable it using environment variables: + +```bash +# Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics) +export FLASHINFER_LOGLEVEL_DBG=3 + +# Set log destination (stdout (default), stderr, or file path) +export FLASHINFER_LOGDEST_DBG=stdout +``` + +For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md). + ## Custom Attention Variants Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py). 
From 8459eb16ca1dbf3260e82015b355d65b1665e28b Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 01:42:06 +0000 Subject: [PATCH 09/13] Fix typo and apply pre-commit --- LOGGING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LOGGING.md b/LOGGING.md index 1e5c22db27..4cf4825c36 100644 --- a/LOGGING.md +++ b/LOGGING.md @@ -21,7 +21,7 @@ python train.py | Level | Name | Features | Use Case | |-------|------|----------|----------| -| **0** | Disabled (Default) | No logging (zero overhad) | Production | +| **0** | Disabled (Default) | No logging (zero overhead) | Production | | **1** | Function Names | Function names only | Basic tracing | | **3** | Inputs/Outputs | Function names + arguments + outputs with metadata | Standard debugging | | **5** | Statistics | Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) | Numerical analysis | @@ -80,4 +80,4 @@ torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py ### Q: Does Level 0 really have zero overhead? -**A: Yes.** At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. \ No newline at end of file +**A: Yes.** At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. From 41ad5581f3573bc177c99ea1840098283fa3d9d0 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 21:59:09 +0000 Subject: [PATCH 10/13] Rename decorator and environment flags --- LOGGING.md | 16 ++++++++-------- README.md | 4 ++-- benchmarks/bench_logging_overhead.py | 24 ++++++++++++------------ flashinfer/api_logging.py | 20 +++++++++----------- flashinfer/cudnn/decode.py | 4 ++-- flashinfer/cudnn/prefill.py | 4 ++-- flashinfer/decode.py | 20 ++++++++++---------- flashinfer/fused_moe/core.py | 14 +++++++------- flashinfer/gemm/gemm_base.py | 24 ++++++++++++------------ flashinfer/mla.py | 8 ++++---- flashinfer/prefill.py | 22 +++++++++++----------- tests/utils/test_logging.py | 24 ++++++++++++------------ 12 files changed, 91 insertions(+), 93 deletions(-) diff --git a/LOGGING.md b/LOGGING.md index 4cf4825c36..b44a0036d7 100644 --- a/LOGGING.md +++ b/LOGGING.md @@ -8,10 +8,10 @@ Enable logging using two environment variables: ```bash # Set logging level (0-5) -export FLASHINFER_LOGLEVEL_DBG=3 +export FLASHINFER_LOGLEVEL=3 # Set log destination (default is stdout) -export FLASHINFER_LOGDEST_DBG=stdout # or stderr, or a file path like "flashinfer.log" +export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log" # Run your code python train.py @@ -33,19 +33,19 @@ python train.py | Variable | Type | Default | Description | |----------|------|---------|-------------| -| `FLASHINFER_LOGLEVEL_DBG` | int | 0 | Logging level (0, 1, 3, 5) | -| `FLASHINFER_LOGDEST_DBG` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path | +| `FLASHINFER_LOGLEVEL` | int | 0 | Logging level (0, 1, 3, 5) | +| `FLASHINFER_LOGDEST` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path | ### Process ID Substitution Use `%i` in file paths for automatic process ID substitution (useful for multi-GPU training): ```bash -export FLASHINFER_LOGDEST_DBG="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt +export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt ``` This works for: -- `FLASHINFER_LOGDEST_DBG` +- `FLASHINFER_LOGDEST` ## Miscellaneous Notes and Examples ### CUDA Graph Compatibility @@ -65,8 +65,8 @@ Output shows: `[statistics skipped: CUDA graph capture in progress]` 
```bash # Use %i for process ID substitution -export FLASHINFER_LOGLEVEL_DBG=3 -export FLASHINFER_LOGDEST_DBG="logs/flashinfer_api_%i.log" +export FLASHINFER_LOGLEVEL=3 +export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log" torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py diff --git a/README.md b/README.md index d928e9e053..cd5c7e1e58 100644 --- a/README.md +++ b/README.md @@ -175,10 +175,10 @@ FlashInfer provides comprehensive API logging for debugging. Enable it using env ```bash # Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics) -export FLASHINFER_LOGLEVEL_DBG=3 +export FLASHINFER_LOGLEVEL=3 # Set log destination (stdout (default), stderr, or file path) -export FLASHINFER_LOGDEST_DBG=stdout +export FLASHINFER_LOGDEST=stdout ``` For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md). diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py index 0fd6724a4f..db916a0ce2 100644 --- a/benchmarks/bench_logging_overhead.py +++ b/benchmarks/bench_logging_overhead.py @@ -7,14 +7,14 @@ Usage: # Set the logging level before running - export FLASHINFER_LOGLEVEL_DBG=3 + export FLASHINFER_APILEVEL=3 python bench_logging_overhead.py # Or run with different levels - FLASHINFER_LOGLEVEL_DBG=0 python bench_logging_overhead.py - FLASHINFER_LOGLEVEL_DBG=1 python bench_logging_overhead.py - FLASHINFER_LOGLEVEL_DBG=3 python bench_logging_overhead.py - FLASHINFER_LOGLEVEL_DBG=5 python bench_logging_overhead.py + FLASHINFER_APILEVEL=0 python bench_logging_overhead.py + FLASHINFER_APILEVEL=1 python bench_logging_overhead.py + FLASHINFER_APILEVEL=3 python bench_logging_overhead.py + FLASHINFER_APILEVEL=5 python bench_logging_overhead.py # Or use the helper script to run all levels bash benchmark_all_levels.sh @@ -28,11 +28,11 @@ from typing import List, Tuple # Get logging level BEFORE importing flashinfer -LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0")) -LOG_DEST = os.environ.get("FLASHINFER_LOGDEST_DBG", "/tmp/flashinfer_benchmark_log.txt") +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0")) +LOG_DEST = os.environ.get("FLASHINFER_APIDEST", "/tmp/flashinfer_benchmark_log.txt") # Import the decorator -from flashinfer.api_logging import flashinfer_log +from flashinfer.api_logging import flashinfer_api # Create two versions of a test function: @@ -42,7 +42,7 @@ def test_matmul_undecorated(A, B): return torch.matmul(A, B) -@flashinfer_log +@flashinfer_api def test_matmul_decorated(A, B): return torch.matmul(A, B) @@ -209,8 +209,8 @@ def main(): # Display logging configuration print("\nLogging Configuration:") - print(f" FLASHINFER_LOGLEVEL_DBG = {LOGGING_LEVEL}") - print(f" FLASHINFER_LOGDEST_DBG = {LOG_DEST}") + print(f" FLASHINFER_APILEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_APIDEST = {LOG_DEST}") # Get level name level_names = { @@ -314,7 +314,7 @@ def main(): print("\nTo benchmark other levels, run:") for level in [0, 1, 3, 5]: if level != LOGGING_LEVEL: - print(f" FLASHINFER_LOGLEVEL_DBG={level} python {sys.argv[0]}") + print(f" FLASHINFER_APILEVEL={level} python {sys.argv[0]}") print("\n" + "=" * 80) print("Benchmark complete!") diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 6948e728d6..65414725f0 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -39,10 +39,8 @@ def _substitute_process_id(path: str) -> str: # Read environment variables once at module load time -_API_LOG_LEVEL = 
int(os.environ.get("FLASHINFER_LOGLEVEL_DBG", "0")) -_API_LOG_DEST = _substitute_process_id( - os.environ.get("FLASHINFER_LOGDEST_DBG", "stdout") -) +_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0")) +_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_APIDEST", "stdout")) # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") @@ -56,7 +54,7 @@ def _setup_logger(): _logger.setLevel(logging.CRITICAL + 1) # Higher than any level return - # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL_DBG instead + # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_APILEVEL instead _logger.setLevel(logging.DEBUG) # Remove any existing handlers @@ -463,22 +461,22 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: _logger.debug("\n".join(lines)) -def flashinfer_log(func: Callable = None) -> Callable: +def flashinfer_api(func: Callable = None) -> Callable: """ Decorator to log FlashInfer API calls using Python's logging library. This decorator integrates with Python's standard logging infrastructure while - maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL_DBG=0). + maintaining zero overhead when disabled (FLASHINFER_APILEVEL=0). Environment Variables --------------------- - FLASHINFER_LOGLEVEL_DBG : int (default: 0) + FLASHINFER_APILEVEL : int (default: 0) - 0: No logging (zero overhead - decorator returns original function) - 1: Log function name only (logged BEFORE execution - crash-safe) - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) - FLASHINFER_LOGDEST_DBG : str (default: "stdout") + FLASHINFER_APIDEST : str (default: "stdout") - "stdout": Log to standard output - "stderr": Log to standard error - : Log to specified file path @@ -488,7 +486,7 @@ def flashinfer_log(func: Callable = None) -> Callable: -------- Basic usage: - >>> @flashinfer_log + >>> @flashinfer_api ... def my_function(x, y): ... return x + y @@ -496,7 +494,7 @@ def flashinfer_log(func: Callable = None) -> Callable: ----- - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS] (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information") - - When FLASHINFER_LOGLEVEL_DBG=0, the decorator has truly zero overhead + - When FLASHINFER_APILEVEL=0, the decorator has truly zero overhead as it returns the original function unchanged. 
- Function names and inputs are logged BEFORE execution: - Level 1: Function name only diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 0635623d36..195ca2d49d 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -3,7 +3,7 @@ import torch -from ..api_logging import flashinfer_log +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -253,7 +253,7 @@ def _batch_decode_with_kv_cache( return out -@flashinfer_log +@flashinfer_api def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index 8d299355b6..b8c09a66ee 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -3,7 +3,7 @@ import torch -from ..api_logging import flashinfer_log +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -384,7 +384,7 @@ def _batch_prefill_with_kv_cache( return out, None -@flashinfer_log +@flashinfer_api def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 91077cb814..ab34ba8857 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,7 +21,7 @@ import torch -from .api_logging import flashinfer_log +from .api_logging import flashinfer_api from .xqa import xqa, xqa_mla from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( @@ -313,7 +313,7 @@ def get_trtllm_gen_fmha_module(): return op -@flashinfer_log +@flashinfer_api def single_decode_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -390,7 +390,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_log +@flashinfer_api def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -649,7 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ - @flashinfer_log + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -813,7 +813,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_log + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -1167,7 +1167,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
- @flashinfer_log + @flashinfer_api def run( self, q: torch.Tensor, @@ -2065,7 +2065,7 @@ def _fake_paged_run( ) -@flashinfer_log +@flashinfer_api def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2339,7 +2339,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout -@flashinfer_log +@flashinfer_api def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2524,7 +2524,7 @@ def _check_trtllm_gen_mla_shape( ) -@flashinfer_log +@flashinfer_api def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -2686,7 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") -@flashinfer_log +@flashinfer_api def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index b66c69dc6a..7b53c3f82c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from ..api_logging import flashinfer_log +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, DynamicTensorSpec, @@ -686,7 +686,7 @@ def _fake_cutlass_fused_moe( # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 -@flashinfer_log +@flashinfer_api def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1859,7 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe( ) -@flashinfer_log +@flashinfer_api def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -1940,7 +1940,7 @@ def trtllm_bf16_moe( ) -@flashinfer_log +@flashinfer_api def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2014,7 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe( ) -@flashinfer_log +@flashinfer_api def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2092,7 +2092,7 @@ def trtllm_fp8_block_scale_moe( ) -@flashinfer_log +@flashinfer_api def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2222,7 +2222,7 @@ def trtllm_fp4_block_scale_moe( ) -@flashinfer_log +@flashinfer_api def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ec88b09620..251e2a4682 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,7 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch -from ..api_logging import flashinfer_log +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, ConstraintSpec, @@ -540,7 +540,7 @@ def forward( ) -@flashinfer_log +@flashinfer_api def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -886,7 +886,7 @@ def reset_workspace_buffer( self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - @flashinfer_log + @flashinfer_api def run( self, x: torch.Tensor, @@ -1554,7 +1554,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) -@flashinfer_log +@flashinfer_api def mm_fp8( a: torch.Tensor, b: torch.Tensor, 
@@ -2028,7 +2028,7 @@ def _heuristic_func_mm_fp4( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) -@flashinfer_log +@flashinfer_api def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2286,7 +2286,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) -@flashinfer_log +@flashinfer_api def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2378,7 +2378,7 @@ def bmm_fp8( return out -@flashinfer_log +@flashinfer_api def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2630,7 +2630,7 @@ def forward( ) -@flashinfer_log +@flashinfer_api def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2659,7 +2659,7 @@ def gemm_fp8_nt_blockscaled( ) -@flashinfer_log +@flashinfer_api def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2822,7 +2822,7 @@ def group_gemm_fp8_nt_groupwise( return out -@flashinfer_log +@flashinfer_api def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2990,7 +2990,7 @@ def get_deepgemm_sm100_module(): return module -@flashinfer_log +@flashinfer_api def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3121,7 +3121,7 @@ def group_deepgemm_fp8_nt_groupwise( return out -@flashinfer_log +@flashinfer_api def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 4ae2fdcd5e..22cf029a2e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,7 +19,7 @@ import torch -from .api_logging import flashinfer_log +from .api_logging import flashinfer_api from .jit import gen_batch_mla_module from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend @@ -130,7 +130,7 @@ class BatchMLAPagedAttentionWrapper: torch.Size([114, 128, 512]) """ - @flashinfer_log + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -201,7 +201,7 @@ def __init__( else: self._backend = backend - @flashinfer_log + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -336,7 +336,7 @@ def run( return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_log + @flashinfer_api def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index c266b5e33c..a2c4ceb0a8 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -22,7 +22,7 @@ import torch -from .api_logging import flashinfer_log +from .api_logging import flashinfer_api from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, @@ -874,7 +874,7 @@ def _fake_paged_run( ) -@flashinfer_log +@flashinfer_api def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -959,7 +959,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_log +@flashinfer_api def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1328,7 +1328,7 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. 
""" - @flashinfer_log + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1524,7 +1524,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_log + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -1981,7 +1981,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_log + @flashinfer_api def run( self, q: torch.Tensor, @@ -2356,7 +2356,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ - @flashinfer_log + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -2500,7 +2500,7 @@ def reset_workspace_buffer( pin_memory=True, ) - @flashinfer_log + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -2845,7 +2845,7 @@ def run( enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_log + @flashinfer_api def run( self, q: torch.Tensor, @@ -3202,7 +3202,7 @@ def get_trtllm_gen_fmha_module(): return op -@flashinfer_log +@flashinfer_api def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3337,7 +3337,7 @@ def trtllm_ragged_attention_deepseek( return out -@flashinfer_log +@flashinfer_api def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 92cff6fc7e..f6888ed4ed 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -47,21 +47,21 @@ class TestAPILogging: def setup_and_teardown(self): """Reset environment and reimport logging module for each test.""" # Store original environment - original_level = os.environ.get("FLASHINFER_LOGLEVEL_DBG") - original_dest = os.environ.get("FLASHINFER_LOGDEST_DBG") + original_level = os.environ.get("FLASHINFER_APILEVEL") + original_dest = os.environ.get("FLASHINFER_APIDEST") yield # Restore original environment if original_level is not None: - os.environ["FLASHINFER_LOGLEVEL_DBG"] = original_level - elif "FLASHINFER_LOGLEVEL_DBG" in os.environ: - del os.environ["FLASHINFER_LOGLEVEL_DBG"] + os.environ["FLASHINFER_APILEVEL"] = original_level + elif "FLASHINFER_APILEVEL" in os.environ: + del os.environ["FLASHINFER_APILEVEL"] if original_dest is not None: - os.environ["FLASHINFER_LOGDEST_DBG"] = original_dest - elif "FLASHINFER_LOGDEST_DBG" in os.environ: - del os.environ["FLASHINFER_LOGDEST_DBG"] + os.environ["FLASHINFER_APIDEST"] = original_dest + elif "FLASHINFER_APIDEST" in os.environ: + del os.environ["FLASHINFER_APIDEST"] # Force reimport to pick up new environment variables if "flashinfer.api_logging" in sys.modules: @@ -69,16 +69,16 @@ def setup_and_teardown(self): def setup_logging(self, level: int, dest: str = "stdout"): """Helper to set up logging environment and reimport.""" - os.environ["FLASHINFER_LOGLEVEL_DBG"] = str(level) - os.environ["FLASHINFER_LOGDEST_DBG"] = dest + os.environ["FLASHINFER_APILEVEL"] = str(level) + os.environ["FLASHINFER_APIDEST"] = dest # Force reimport if "flashinfer.api_logging" in sys.modules: del sys.modules["flashinfer.api_logging"] - from flashinfer.api_logging import flashinfer_log + from flashinfer.api_logging import flashinfer_api - return flashinfer_log + return flashinfer_api def test_level_0_zero_overhead(self): """Test that level 0 has truly zero overhead (returns original function).""" From 4a80fc42509d95b4a40d1f11ef66dbca0ba4568c Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 
22:02:47 +0000 Subject: [PATCH 11/13] Fix typo --- benchmarks/bench_logging_overhead.py | 20 ++++++++++---------- flashinfer/api_logging.py | 14 +++++++------- tests/utils/test_logging.py | 20 ++++++++++---------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py index db916a0ce2..e67edcfa45 100644 --- a/benchmarks/bench_logging_overhead.py +++ b/benchmarks/bench_logging_overhead.py @@ -7,14 +7,14 @@ Usage: # Set the logging level before running - export FLASHINFER_APILEVEL=3 + export FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py # Or run with different levels - FLASHINFER_APILEVEL=0 python bench_logging_overhead.py - FLASHINFER_APILEVEL=1 python bench_logging_overhead.py - FLASHINFER_APILEVEL=3 python bench_logging_overhead.py - FLASHINFER_APILEVEL=5 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=0 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=1 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=5 python bench_logging_overhead.py # Or use the helper script to run all levels bash benchmark_all_levels.sh @@ -28,8 +28,8 @@ from typing import List, Tuple # Get logging level BEFORE importing flashinfer -LOGGING_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0")) -LOG_DEST = os.environ.get("FLASHINFER_APIDEST", "/tmp/flashinfer_benchmark_log.txt") +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt") # Import the decorator from flashinfer.api_logging import flashinfer_api @@ -209,8 +209,8 @@ def main(): # Display logging configuration print("\nLogging Configuration:") - print(f" FLASHINFER_APILEVEL = {LOGGING_LEVEL}") - print(f" FLASHINFER_APIDEST = {LOG_DEST}") + print(f" FLASHINFER_LOGLEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_LOGDEST = {LOG_DEST}") # Get level name level_names = { @@ -314,7 +314,7 @@ def main(): print("\nTo benchmark other levels, run:") for level in [0, 1, 3, 5]: if level != LOGGING_LEVEL: - print(f" FLASHINFER_APILEVEL={level} python {sys.argv[0]}") + print(f" FLASHINFER_LOGLEVEL={level} python {sys.argv[0]}") print("\n" + "=" * 80) print("Benchmark complete!") diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 65414725f0..3946c3818b 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -39,8 +39,8 @@ def _substitute_process_id(path: str) -> str: # Read environment variables once at module load time -_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_APILEVEL", "0")) -_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_APIDEST", "stdout")) +_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout")) # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") @@ -54,7 +54,7 @@ def _setup_logger(): _logger.setLevel(logging.CRITICAL + 1) # Higher than any level return - # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_APILEVEL instead + # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL instead _logger.setLevel(logging.DEBUG) # Remove any existing handlers @@ -466,17 +466,17 @@ def flashinfer_api(func: Callable = None) -> Callable: Decorator to log FlashInfer API calls using Python's logging library. 
This decorator integrates with Python's standard logging infrastructure while - maintaining zero overhead when disabled (FLASHINFER_APILEVEL=0). + maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0). Environment Variables --------------------- - FLASHINFER_APILEVEL : int (default: 0) + FLASHINFER_LOGLEVEL : int (default: 0) - 0: No logging (zero overhead - decorator returns original function) - 1: Log function name only (logged BEFORE execution - crash-safe) - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) - FLASHINFER_APIDEST : str (default: "stdout") + FLASHINFER_LOGDEST : str (default: "stdout") - "stdout": Log to standard output - "stderr": Log to standard error - : Log to specified file path @@ -494,7 +494,7 @@ def flashinfer_api(func: Callable = None) -> Callable: ----- - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS] (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information") - - When FLASHINFER_APILEVEL=0, the decorator has truly zero overhead + - When FLASHINFER_LOGLEVEL=0, the decorator has truly zero overhead as it returns the original function unchanged. - Function names and inputs are logged BEFORE execution: - Level 1: Function name only diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index f6888ed4ed..6ead5e7d6b 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -47,21 +47,21 @@ class TestAPILogging: def setup_and_teardown(self): """Reset environment and reimport logging module for each test.""" # Store original environment - original_level = os.environ.get("FLASHINFER_APILEVEL") - original_dest = os.environ.get("FLASHINFER_APIDEST") + original_level = os.environ.get("FLASHINFER_LOGLEVEL") + original_dest = os.environ.get("FLASHINFER_LOGDEST") yield # Restore original environment if original_level is not None: - os.environ["FLASHINFER_APILEVEL"] = original_level - elif "FLASHINFER_APILEVEL" in os.environ: - del os.environ["FLASHINFER_APILEVEL"] + os.environ["FLASHINFER_LOGLEVEL"] = original_level + elif "FLASHINFER_LOGLEVEL" in os.environ: + del os.environ["FLASHINFER_LOGLEVEL"] if original_dest is not None: - os.environ["FLASHINFER_APIDEST"] = original_dest - elif "FLASHINFER_APIDEST" in os.environ: - del os.environ["FLASHINFER_APIDEST"] + os.environ["FLASHINFER_LOGDEST"] = original_dest + elif "FLASHINFER_LOGDEST" in os.environ: + del os.environ["FLASHINFER_LOGDEST"] # Force reimport to pick up new environment variables if "flashinfer.api_logging" in sys.modules: @@ -69,8 +69,8 @@ def setup_and_teardown(self): def setup_logging(self, level: int, dest: str = "stdout"): """Helper to set up logging environment and reimport.""" - os.environ["FLASHINFER_APILEVEL"] = str(level) - os.environ["FLASHINFER_APIDEST"] = dest + os.environ["FLASHINFER_LOGLEVEL"] = str(level) + os.environ["FLASHINFER_LOGDEST"] = dest # Force reimport if "flashinfer.api_logging" in sys.modules: From a75c35949538af265353e0617feac4ba956d1039 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 21 Nov 2025 22:29:27 +0000 Subject: [PATCH 12/13] Add disclaimer to decorator docstring --- flashinfer/api_logging.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 3946c3818b..734d6bae28 100644 --- a/flashinfer/api_logging.py +++ 
b/flashinfer/api_logging.py
@@ -463,11 +463,14 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
 
 def flashinfer_api(func: Callable = None) -> Callable:
     """
-    Decorator to log FlashInfer API calls using Python's logging library.
+    Decorator for FlashInfer's APIs.
 
+    Currently logs the input and output values of each call using Python's logging library.
     This decorator integrates with Python's standard logging infrastructure while
     maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).
 
+    NOTE/TODO: Not all FlashInfer APIs are decorated yet; coverage is a work in progress.
+
     Environment Variables
     ---------------------
     FLASHINFER_LOGLEVEL : int (default: 0)

From 3bb0b73ee698c1c523595e6550326931c3ce8ef5 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Sat, 22 Nov 2025 01:17:02 +0000
Subject: [PATCH 13/13] Move logging.md to documentation

---
 LOGGING.md       |  83 ---------------------------------
 docs/index.rst   |   1 +
 docs/logging.rst | 118 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 119 insertions(+), 83 deletions(-)
 delete mode 100644 LOGGING.md
 create mode 100644 docs/logging.rst

diff --git a/LOGGING.md b/LOGGING.md
deleted file mode 100644
index b44a0036d7..0000000000
--- a/LOGGING.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# FlashInfer Logging
-
-FlashInfer provides a logging feature to help debug issues, and reproduce crashes. This document describes all available logging levels and their features.
-
-## Quick Start
-
-Enable logging using two environment variables:
-
-```bash
-# Set logging level (0-5)
-export FLASHINFER_LOGLEVEL=3
-
-# Set log destination (default is stdout)
-export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log"
-
-# Run your code
-python train.py
-```
-
-## Logging Levels
-
-| Level | Name | Features | Use Case |
-|-------|------|----------|----------|
-| **0** | Disabled (Default) | No logging (zero overhead) | Production |
-| **1** | Function Names | Function names only | Basic tracing |
-| **3** | Inputs/Outputs | Function names + arguments + outputs with metadata | Standard debugging |
-| **5** | Statistics | Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) | Numerical analysis |
-
-
-## Environment Variables
-
-### Main Configuration
-
-| Variable | Type | Default | Description |
-|----------|------|---------|-------------|
-| `FLASHINFER_LOGLEVEL` | int | 0 | Logging level (0, 1, 3, 5) |
-| `FLASHINFER_LOGDEST` | str | `stdout` | Log destination: `stdout`, `stderr`, or file path |
-
-### Process ID Substitution
-
-Use `%i` in file paths for automatic process ID substitution (useful for multi-GPU training):
-
-```bash
-export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt
-```
-
-This works for:
-- `FLASHINFER_LOGDEST`
-
-## Miscellaneous Notes and Examples
-### CUDA Graph Compatibility
-
-Level 5 statistics are **automatically skipped during CUDA graph capture** to avoid synchronization issues.
- -```python -# This works correctly - no synchronization errors -with torch.cuda.graph(cuda_graph): - result = mm_fp4(a, b, scales) # Level 5 logging active - # Statistics automatically skipped during capture -``` - -Output shows: `[statistics skipped: CUDA graph capture in progress]` - -### Process IDs for Multi-GPU Environments - -```bash -# Use %i for process ID substitution -export FLASHINFER_LOGLEVEL=3 -export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log" - -torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py - -# Creates separate logs: -# logs/flashinfer_api_12345.log (rank 0) -# logs/flashinfer_api_12346.log (rank 1) -# ... -``` - -## Frequently Asked Questions - -### Q: Does Level 0 really have zero overhead? - -**A: Yes.** At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. diff --git a/docs/index.rst b/docs/index.rst index 6a5a9c6a19..f4e61d26c4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov :caption: Get Started installation + logging .. toctree:: :maxdepth: 2 diff --git a/docs/logging.rst b/docs/logging.rst new file mode 100644 index 0000000000..c3c2c83d8f --- /dev/null +++ b/docs/logging.rst @@ -0,0 +1,118 @@ +.. _logging: + +Logging +======= + +FlashInfer provides a logging feature to help debug issues and reproduce crashes. This document describes all available logging levels and their features. + +Quick Start +----------- + +Enable logging using two environment variables: + +.. code-block:: bash + + # Set logging level (0-5) + export FLASHINFER_LOGLEVEL=3 + + # Set log destination (default is stdout) + export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log" + +Logging Levels +-------------- + +.. list-table:: + :header-rows: 1 + :widths: 10 20 35 25 + + * - Level + - Name + - Features + - Use Case + * - **0** + - Disabled (Default) + - No logging (zero overhead) + - Production + * - **1** + - Function Names + - Function names only + - Basic tracing + * - **3** + - Inputs/Outputs + - Function names + arguments + outputs with metadata + - Standard debugging + * - **5** + - Statistics + - Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) + - Numerical analysis + +Environment Variables +--------------------- + +Main Configuration +^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 15 40 + + * - Variable + - Type + - Default + - Description + * - ``FLASHINFER_LOGLEVEL`` + - int + - 0 + - Logging level (0, 1, 3, 5) + * - ``FLASHINFER_LOGDEST`` + - str + - ``stdout`` + - Log destination: ``stdout``, ``stderr``, or file path + +Process ID Substitution +^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``%i`` in file paths for automatic process ID substitution (useful for multi-GPU training): + +.. code-block:: bash + + export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt + + +Miscellaneous Notes and Examples +--------------------------------- + +CUDA Graph Compatibility +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Level 5 statistics are **automatically skipped during CUDA graph capture** to avoid synchronization issues. + +.. code-block:: python + + # This works correctly - no synchronization errors + with torch.cuda.graph(cuda_graph): + result = mm_fp4(a, b, scales, ...) 
# Level 5 logging active + # Statistics automatically skipped during capture + +Output shows: ``[statistics skipped: CUDA graph capture in progress]`` + +Process IDs for Multi-GPU Environments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # Use %i for process ID substitution + export FLASHINFER_LOGLEVEL=3 + export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log" + + torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py + + # Creates separate logs: + # logs/flashinfer_api_12345.log (rank 0) + # logs/flashinfer_api_12346.log (rank 1) + # ... + +Level 0 has zero overhead +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead.
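+
+As a rough sketch (simplified for exposition; ``print`` stands in for the module's
+logger, and the real decorator in ``flashinfer/api_logging.py`` also implements the
+level 1/3/5 formatting), the dispatch looks like:
+
+.. code-block:: python
+
+   import functools
+   import os
+
+   _LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
+
+   def flashinfer_api(func=None):
+       if _LEVEL == 0:
+           # Logging disabled: hand back the original function object,
+           # so decorated calls carry no wrapper cost at all.
+           return func if func is not None else (lambda f: f)
+
+       def decorator(f):
+           @functools.wraps(f)
+           def wrapper(*args, **kwargs):
+               # Inputs are logged before execution (crash-safe);
+               # outputs only after the call returns.
+               print(f"FlashInfer API Call: {f.__name__}")
+               result = f(*args, **kwargs)
+               print(f"{f.__name__} returned {type(result).__name__}")
+               return result
+           return wrapper
+
+       # Supports both @flashinfer_api and @flashinfer_api() usage.
+       return decorator if func is None else decorator(func)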
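+
+Process ID substitution (sketch)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The ``%i`` substitution mirrors ``_substitute_process_id`` in
+``flashinfer/api_logging.py``: a plain string replacement applied once, when the
+module reads the environment variables at import time.
+
+.. code-block:: python
+
+   import os
+
+   def substitute_process_id(path: str) -> str:
+       # "flashinfer_log_%i.txt" -> "flashinfer_log_12345.txt"
+       if "%i" in path:
+           return path.replace("%i", str(os.getpid()))
+       return path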
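+
+CUDA graph capture detection (sketch)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The statistics skipping described above relies on PyTorch's public stream-capture
+query. A simplified version of the guard (a sketch, not the exact logic in
+``flashinfer/api_logging.py``) is:
+
+.. code-block:: python
+
+   import torch
+
+   def stats_safe(t: torch.Tensor) -> bool:
+       # Reductions such as .min()/.max()/.mean() synchronize the device,
+       # which is illegal while a CUDA graph is being captured.
+       if t.is_cuda and torch.cuda.is_current_stream_capturing():
+           return False
+       return t.numel() > 0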