diff --git a/README.md b/README.md index 94eece5007..cd5c7e1e58 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,20 @@ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill att Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels. +## API Logging + +FlashInfer provides comprehensive API logging for debugging. Enable it using environment variables: + +```bash +# Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics) +export FLASHINFER_LOGLEVEL=3 + +# Set log destination (stdout (default), stderr, or file path) +export FLASHINFER_LOGDEST=stdout +``` + +For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md). + ## Custom Attention Variants Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py). diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py new file mode 100644 index 0000000000..e67edcfa45 --- /dev/null +++ b/benchmarks/bench_logging_overhead.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +Benchmark script to measure the overhead of API logging at different levels. + +This script creates decorated and undecorated versions of a test function +(torch.matmul) and compares their performance to accurately measure logging overhead. + +Usage: + # Set the logging level before running + export FLASHINFER_LOGLEVEL=3 + python bench_logging_overhead.py + + # Or run with different levels + FLASHINFER_LOGLEVEL=0 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=1 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=5 python bench_logging_overhead.py + + # Or use the helper script to run all levels + bash benchmark_all_levels.sh +""" + +import os +import sys +import time +import torch +import numpy as np +from typing import List, Tuple + +# Get logging level BEFORE importing flashinfer +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt") + +# Import the decorator +from flashinfer.api_logging import flashinfer_api + + +# Create two versions of a test function: +# 1. Undecorated (baseline) +# 2. 
Decorated (with logging) +def test_matmul_undecorated(A, B): + return torch.matmul(A, B) + + +@flashinfer_api +def test_matmul_decorated(A, B): + return torch.matmul(A, B) + + +class BenchmarkResults: + """Store and display benchmark results.""" + + def __init__(self): + self.undecorated_times = [] + self.decorated_times = [] + + def set_undecorated(self, times: List[float]): + """Set benchmark results for undecorated function.""" + self.undecorated_times = times + + def set_decorated(self, times: List[float]): + """Set benchmark results for decorated function.""" + self.decorated_times = times + + def print_summary(self, logging_level: int): + """Print a summary of benchmark results.""" + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + undecorated_mean = np.mean(self.undecorated_times) + undecorated_std = np.std(self.undecorated_times) + + decorated_mean = np.mean(self.decorated_times) + decorated_std = np.std(self.decorated_times) + + overhead_abs = (decorated_mean - undecorated_mean) * 1000 # ms + overhead_pct = ( + ((decorated_mean - undecorated_mean) / undecorated_mean * 100) + if undecorated_mean > 0 + else 0 + ) + + print( + f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}" + ) + print("-" * 80) + print( + f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}" + ) + print( + f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}" + ) + + print("\n" + "=" * 80) + print("OVERHEAD ANALYSIS") + print("=" * 80) + print(f"\nLogging Level: {logging_level}") + print(f"Absolute overhead: {overhead_abs:.4f} ms") + print(f"Relative overhead: {overhead_pct:.2f}%") + + print("\n" + "=" * 80) + print("DETAILED STATISTICS") + print("=" * 80) + + print("\nUndecorated (baseline):") + print(f" Mean: {undecorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.undecorated_times) * 1000:.4f} ms") + print(f" Std: {undecorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.undecorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.undecorated_times) * 1000:.4f} ms") + + print("\nDecorated (with logging):") + print(f" Mean: {decorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.decorated_times) * 1000:.4f} ms") + print(f" Std: {decorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.decorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.decorated_times) * 1000:.4f} ms") + + +def setup_test_inputs( + batch_size: int = 32, + m: int = 512, + n: int = 512, + k: int = 512, + device: str = "cuda:0", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set up test inputs for matmul. + + Parameters + ---------- + batch_size : int + Batch size for the matrix multiplication + m, n, k : int + Matrix dimensions + device : str + Device to use + + Returns + ------- + A, B : torch.Tensor + Input tensors for matrix multiplication + """ + # Create random tensors + A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device) + B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device) + + return A, B + + +def warmup(func, A, B, num_warmup: int = 10): + """Warmup the GPU and JIT compilation.""" + for _ in range(num_warmup): + _ = func(A, B) + torch.cuda.synchronize() + + +def benchmark_function( + func, func_name: str, A, B, num_iterations: int = 100 +) -> List[float]: + """ + Benchmark a specific function. 
+ + Parameters + ---------- + func : callable + Function to benchmark + func_name : str + Name of the function (for display) + A, B : torch.Tensor + Input tensors for matrix multiplication + num_iterations : int + Number of iterations to run + + Returns + ------- + List[float] + List of execution times in seconds + """ + print(f"\nBenchmarking: {func_name}") + print(f" Running {num_iterations} iterations...") + + times = [] + + for _ in range(num_iterations): + # Synchronize before timing + torch.cuda.synchronize() + + # Time the execution + start = time.perf_counter() + _ = func(A, B) + torch.cuda.synchronize() + end = time.perf_counter() + + elapsed = end - start + times.append(elapsed) + + print(f" Complete. Mean time: {np.mean(times) * 1000:.4f} ms") + + return times + + +def main(): + """Main benchmark function.""" + print("=" * 80) + print("FlashInfer API Logging Overhead Benchmark") + print("=" * 80) + + # Display logging configuration + print("\nLogging Configuration:") + print(f" FLASHINFER_LOGLEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_LOGDEST = {LOG_DEST}") + + # Get level name + level_names = { + 0: "No logging (zero-overhead)", + 1: "Function name only", + 3: "Name + inputs/outputs + metadata", + 5: "Name + inputs/outputs + metadata + statistics", + } + print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}") + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("\nError: CUDA is not available. This benchmark requires a CUDA device.") + exit(1) + + device = "cuda:0" + print(f"\nDevice: {device}") + print(f"Device Name: {torch.cuda.get_device_name(device)}") + + # Setup test inputs + print("\nSetting up test inputs...") + batch_size = 32 + m, n, k = 128, 128, 128 + print(f" Batch size: {batch_size}") + print(f" Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]") + + A, B = setup_test_inputs(batch_size, m, n, k, device) + + # Benchmark parameters + num_iterations = 100 + print("\nBenchmark parameters:") + print(f" Iterations: {num_iterations}") + print(" Warmup iterations: 10") + + # Clear log file before starting + if os.path.exists(LOG_DEST): + os.remove(LOG_DEST) + + print("\n" + "=" * 80) + print("WARMUP PHASE") + print("=" * 80) + + # Warmup undecorated version + print("\nWarming up undecorated version...") + warmup(test_matmul_undecorated, A, B, num_warmup=10) + print(" Complete.") + + # Warmup decorated version + print("\nWarming up decorated version...") + warmup(test_matmul_decorated, A, B, num_warmup=10) + print(" Complete.") + + print("\n" + "=" * 80) + print("BENCHMARK PHASE") + print("=" * 80) + + # Store results + results = BenchmarkResults() + + # Benchmark undecorated version + undecorated_times = benchmark_function( + test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations + ) + results.set_undecorated(undecorated_times) + + # Benchmark decorated version + decorated_times = benchmark_function( + test_matmul_decorated, + f"Decorated (logging level {LOGGING_LEVEL})", + A, + B, + num_iterations, + ) + results.set_decorated(decorated_times) + + # Print summary + results.print_summary(LOGGING_LEVEL) + + # Check log file size + if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST): + log_size = os.path.getsize(LOG_DEST) + print("\n" + "=" * 80) + print("LOG FILE INFO") + print("=" * 80) + print(f"Log file: {LOG_DEST}") + print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)") + print(f"Iterations logged: {num_iterations}") + print(f"Bytes per iteration: {log_size / 
num_iterations:.2f}") + + # Cleanup option + cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true" + if cleanup_log: + os.remove(LOG_DEST) + print("\n Log file removed (set CLEANUP_LOG=false to keep it)") + else: + print(f"\n Log file preserved at {LOG_DEST}") + + print("\n" + "=" * 80) + print("RECOMMENDATIONS") + print("=" * 80) + print("\nTo benchmark other levels, run:") + for level in [0, 1, 3, 5]: + if level != LOGGING_LEVEL: + print(f" FLASHINFER_LOGLEVEL={level} python {sys.argv[0]}") + + print("\n" + "=" * 80) + print("Benchmark complete!") + print("=" * 80) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n\nBenchmark interrupted by user.") + except Exception as e: + print(f"\n\nError during benchmark: {e}") + import traceback + + traceback.print_exc() diff --git a/docs/index.rst b/docs/index.rst index 6a5a9c6a19..f4e61d26c4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov :caption: Get Started installation + logging .. toctree:: :maxdepth: 2 diff --git a/docs/logging.rst b/docs/logging.rst new file mode 100644 index 0000000000..c3c2c83d8f --- /dev/null +++ b/docs/logging.rst @@ -0,0 +1,118 @@ +.. _logging: + +Logging +======= + +FlashInfer provides a logging feature to help debug issues and reproduce crashes. This document describes all available logging levels and their features. + +Quick Start +----------- + +Enable logging using two environment variables: + +.. code-block:: bash + + # Set logging level (0-5) + export FLASHINFER_LOGLEVEL=3 + + # Set log destination (default is stdout) + export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log" + +Logging Levels +-------------- + +.. list-table:: + :header-rows: 1 + :widths: 10 20 35 25 + + * - Level + - Name + - Features + - Use Case + * - **0** + - Disabled (Default) + - No logging (zero overhead) + - Production + * - **1** + - Function Names + - Function names only + - Basic tracing + * - **3** + - Inputs/Outputs + - Function names + arguments + outputs with metadata + - Standard debugging + * - **5** + - Statistics + - Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) + - Numerical analysis + +Environment Variables +--------------------- + +Main Configuration +^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 15 40 + + * - Variable + - Type + - Default + - Description + * - ``FLASHINFER_LOGLEVEL`` + - int + - 0 + - Logging level (0, 1, 3, 5) + * - ``FLASHINFER_LOGDEST`` + - str + - ``stdout`` + - Log destination: ``stdout``, ``stderr``, or file path + +Process ID Substitution +^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``%i`` in file paths for automatic process ID substitution (useful for multi-GPU training): + +.. code-block:: bash + + export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # → flashinfer_log_12345.txt + + +Miscellaneous Notes and Examples +--------------------------------- + +CUDA Graph Compatibility +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Level 5 statistics are **automatically skipped during CUDA graph capture** to avoid synchronization issues. + +.. code-block:: python + + # This works correctly - no synchronization errors + with torch.cuda.graph(cuda_graph): + result = mm_fp4(a, b, scales, ...) 
# Level 5 logging active + # Statistics automatically skipped during capture + +Output shows: ``[statistics skipped: CUDA graph capture in progress]`` + +Process IDs for Multi-GPU Environments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # Use %i for process ID substitution + export FLASHINFER_LOGLEVEL=3 + export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log" + + torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py + + # Creates separate logs: + # logs/flashinfer_api_12345.log (rank 0) + # logs/flashinfer_api_12346.log (rank 1) + # ... + +Level 0 has zero overhead +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py new file mode 100644 index 0000000000..734d6bae28 --- /dev/null +++ b/flashinfer/api_logging.py @@ -0,0 +1,565 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import enum +import functools +import inspect +import logging +import os +import sys +from typing import Any, Callable +import contextlib +import torch + + +# Helper function to substitute %i with process ID in file paths +def _substitute_process_id(path: str) -> str: + """ + Replace %i with the current process ID in a path. + + This is useful for multi-process/multi-GPU environments where each process + needs its own log file. 
+    """
+    if "%i" in path:
+        return path.replace("%i", str(os.getpid()))
+    return path
+
+
+# Read environment variables once at module load time
+_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
+_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout"))
+
+# Create logger using Python's logging library
+_logger = logging.getLogger("flashinfer.api")
+
+
+def _setup_logger():
+    """Set up the logger based on environment variables."""
+    if _API_LOG_LEVEL == 0:
+        # Completely disable logging for zero overhead
+        _logger.addHandler(logging.NullHandler())
+        _logger.setLevel(logging.CRITICAL + 1)  # Higher than any level
+        return
+
+    # All enabled levels use logging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL instead
+    _logger.setLevel(logging.DEBUG)
+
+    # Remove any existing handlers
+    _logger.handlers.clear()
+
+    # Create handler based on destination
+    if _API_LOG_DEST == "stdout":
+        handler = logging.StreamHandler(sys.stdout)
+    elif _API_LOG_DEST == "stderr":
+        handler = logging.StreamHandler(sys.stderr)
+    else:
+        handler = logging.FileHandler(_API_LOG_DEST, mode="a")
+
+    # Use a simple formatter (we'll add timestamps manually to key lines)
+    formatter = logging.Formatter("%(message)s")
+    handler.setFormatter(formatter)
+
+    _logger.addHandler(handler)
+    _logger.propagate = False  # Don't propagate to root logger
+
+
+# Initialize logger at module load time
+_setup_logger()
+
+
+def _get_timestamp() -> str:
+    """Get current timestamp in the format [YYYY-MM-DD HH:MM:SS]."""
+    from datetime import datetime
+
+    return datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
+
+
+def _log_system_info():
+    """Log system information once at module initialization."""
+    if _API_LOG_LEVEL == 0:
+        return
+
+    lines = []
+    lines.append("=" * 80)
+    lines.append(f"{_get_timestamp()} FlashInfer API Logging - System Information")
+    lines.append("=" * 80)
+
+    try:
+        # FlashInfer version
+        try:
+            from .version import __version__ as flashinfer_version
+
+            lines.append(f"FlashInfer version: {flashinfer_version}")
+        except Exception:
+            lines.append("FlashInfer version: <unknown>")
+
+        # CUDA toolkit version
+        cuda_version = torch.version.cuda
+        if cuda_version:
+            lines.append(f"CUDA toolkit version: {cuda_version}")
+        else:
+            lines.append("CUDA toolkit version: <not available>")
+
+        # cuDNN version
+        try:
+            if torch.backends.cudnn.is_available():
+                cudnn_version = torch.backends.cudnn.version()
+                if cudnn_version:
+                    lines.append(f"cuDNN version: {cudnn_version}")
+                else:
+                    lines.append("cuDNN version: <unknown>")
+            else:
+                lines.append("cuDNN version: <not available>")
+        except Exception as e:
+            lines.append(f"cuDNN version: <error: {e}>")
+
+        # GPU information (if CUDA is available)
+        if torch.cuda.is_available():
+            device_count = torch.cuda.device_count()
+            lines.append(f"Number of GPUs: {device_count}")
+
+            # Log information for each GPU
+            for i in range(device_count):
+                try:
+                    gpu_name = torch.cuda.get_device_name(i)
+                    capability = torch.cuda.get_device_capability(i)
+                    sm_arch = capability[0] * 10 + capability[1]
+                    lines.append(f"  GPU {i}: {gpu_name}")
+                    lines.append(
+                        f"    Compute capability: {capability[0]}.{capability[1]} (SM{sm_arch})"
+                    )
+                except Exception as e:
+                    lines.append(f"  GPU {i}: <error: {e}>")
+        else:
+            lines.append("CUDA: Not available (CPU-only mode)")
+
+        # PyTorch version
+        lines.append(f"PyTorch version: {torch.__version__}")
+
+    except Exception as e:
+        lines.append(f"Error gathering system information: {e}")
+
+    lines.append("=" * 80)
+    lines.append("")  # Empty line for readability
+
+    _logger.debug("\n".join(lines))
+
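+
+# For reference, the system-information header emitted by _log_system_info() looks
+# roughly like the following (all values below are illustrative placeholders, not
+# real output):
+#
+#   ================================================================================
+#   [2025-01-01 12:00:00] FlashInfer API Logging - System Information
+#   ================================================================================
+#   FlashInfer version: 0.4.0
+#   CUDA toolkit version: 12.8
+#   cuDNN version: 91002
+#   Number of GPUs: 1
+#     GPU 0: NVIDIA H100
+#       Compute capability: 9.0 (SM90)
+#   PyTorch version: 2.8.0
+#   ================================================================================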
+ +# Log system information once at module load time (if logging is enabled) +_log_system_info() + + +def _format_value(value: Any, level: int, indent: int = 0) -> str: + """ + Format a value for logging based on the log level. + + Parameters + ---------- + value : Any + The value to format + level : int + The logging level (1, 2, or 3) + indent : int + The indentation level for nested structures + + Returns + ------- + str + Formatted string representation of the value + """ + indent_str = " " * indent + + # Handle None + if value is None: + return f"{indent_str}None" + + # Handle Enum types + if isinstance(value, enum.Enum): + # Show both the name and value of the enum + return ( + f"{indent_str}{value.__class__.__name__}.{value.name} (value={value.value})" + ) + + # Handle torch.Tensor + if isinstance(value, torch.Tensor): + if level == 1: + return f"{indent_str}Tensor(...)" + + # Level 3+: Show metadata + lines = [f"{indent_str}Tensor("] + lines.append(f"{indent_str} shape={tuple(value.shape)}") + lines.append(f"{indent_str} stride={tuple(value.stride())}") + lines.append(f"{indent_str} dtype={value.dtype}") + lines.append(f"{indent_str} device={value.device}") + lines.append(f"{indent_str} requires_grad={value.requires_grad}") + lines.append(f"{indent_str} is_contiguous={value.is_contiguous()}") + + # Level 5: Add statistics + if level >= 5: + try: + # Skip statistics if we're in CUDA graph capture mode + # (operations like .min()/.max()/.mean() cause synchronization issues) + is_capturing = False + if value.is_cuda and hasattr(torch.cuda, "is_current_stream_capturing"): + with contextlib.suppress(Exception): + is_capturing = torch.cuda.is_current_stream_capturing() + + if is_capturing: + lines.append( + f"{indent_str} [statistics skipped: CUDA graph capture in progress]" + ) + elif value.numel() > 0: + # Convert to float for statistics if possible + if value.dtype in [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + val_float = value.float() + lines.append(f"{indent_str} min={val_float.min().item():.6f}") + lines.append(f"{indent_str} max={val_float.max().item():.6f}") + lines.append( + f"{indent_str} mean={val_float.mean().item():.6f}" + ) + nan_count = torch.isnan(val_float).sum().item() + lines.append(f"{indent_str} nan_count={nan_count}") + inf_count = torch.isinf(val_float).sum().item() + lines.append(f"{indent_str} inf_count={inf_count}") + elif value.dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ]: + lines.append(f"{indent_str} min={value.min().item()}") + lines.append(f"{indent_str} max={value.max().item()}") + lines.append( + f"{indent_str} mean={value.float().mean().item():.6f}" + ) + except Exception as e: + lines.append(f"{indent_str} [statistics error: {e}]") + + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle FP4Tensor (custom FlashInfer type) + if hasattr(value, "__class__") and value.__class__.__name__ == "FP4Tensor": + if level == 1: + return f"{indent_str}FP4Tensor(...)" + + lines = [f"{indent_str}FP4Tensor("] + lines.append( + f"{indent_str} data={_format_value(value.data, level, indent + 1)}" + ) + lines.append( + f"{indent_str} scale={_format_value(value.scale, level, indent + 1)}" + ) + lines.append(f"{indent_str} scale_start_index={value.scale_start_index}") + if hasattr(value, "original_shape") and value.original_shape is not None: + lines.append(f"{indent_str} original_shape={value.original_shape}") + 
lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle lists + if isinstance(value, list): + if len(value) == 0: + return f"{indent_str}[]" + if level == 1: + return f"{indent_str}[list with {len(value)} items]" + + lines = [f"{indent_str}["] + for i, item in enumerate(value): + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) + lines.append(f"{indent_str}]") + return "\n".join(lines) + + # Handle tuples + if isinstance(value, tuple): + if len(value) == 0: + return f"{indent_str}()" + if level == 1: + return f"{indent_str}(tuple with {len(value)} items)" + + lines = [f"{indent_str}("] + for i, item in enumerate(value): + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle dictionaries + if isinstance(value, dict): + if len(value) == 0: + return f"{indent_str}{{}}" + if level == 1: + return f"{indent_str}{{dict with {len(value)} keys}}" + + lines = [f"{indent_str}{{"] + for key, val in value.items(): + lines.append( + f"{indent_str} {repr(key)}: {_format_value(val, level, indent + 1)}" + ) + lines.append(f"{indent_str}}}") + return "\n".join(lines) + + # Handle numeric types (int, float, bool) + if isinstance(value, (int, float, bool, complex)): + return f"{indent_str}{value}" + + # Handle strings + if isinstance(value, str): + return f"{indent_str}{repr(value)}" + + # Default: use repr + try: + return f"{indent_str}{repr(value)}" + except Exception: + return f"{indent_str}<{type(value).__name__} object>" + + +def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: + """ + Extract parameters that have default values but were not explicitly provided. + + Parameters + ---------- + func : Callable + The function being called + args : tuple + Positional arguments that were provided + kwargs : dict + Keyword arguments that were provided + + Returns + ------- + dict + Dictionary of parameter names to default values for parameters that were not provided + """ + try: + sig = inspect.signature(func) + default_params = {} + + # Determine which parameters were NOT provided + for i, (param_name, param) in enumerate(sig.parameters.items()): + # Skip if parameter has no default + if param.default is inspect.Parameter.empty: + continue + + # Check if this parameter was provided + provided = False + + # Check positional args and keyword args + if i < len(args) or param_name in kwargs: + provided = True + + # If not provided, record the default value + if not provided: + default_params[param_name] = param.default + + return default_params + except Exception: + # If we can't inspect the signature, return empty dict + return {} + + +def _log_function_inputs( + func: Callable, func_name: str, args: tuple, kwargs: dict, level: int +) -> None: + """ + Log function inputs BEFORE execution for crash safety. + + This ensures inputs are captured even if the function crashes with a CUDA error. 
+
+    Parameters
+    ----------
+    func : Callable
+        The function being called (needed to extract default parameters)
+    func_name : str
+        Name of the function being called
+    args : tuple
+        Positional arguments
+    kwargs : dict
+        Keyword arguments
+    level : int
+        Logging level (3 or 5)
+    """
+    lines = []
+    lines.append("=" * 80)
+    lines.append(f"{_get_timestamp()} FlashInfer API Call: {func_name}")
+    lines.append("-" * 80)
+
+    # Log explicitly provided inputs
+    if args or kwargs:
+        # Positional arguments
+        if args:
+            lines.append("Positional input arguments:")
+            for i, arg in enumerate(args):
+                lines.append(f"  arg[{i}]:")
+                lines.append(_format_value(arg, level, indent=2))
+
+        # Keyword arguments
+        if kwargs:
+            lines.append("Keyword input arguments:")
+            for key, value in kwargs.items():
+                lines.append(f"  {key}=")
+                lines.append(_format_value(value, level, indent=2))
+    else:
+        lines.append("(No explicit arguments)")
+
+    # Log default parameters that were not explicitly provided
+    default_params = _get_default_params(func, args, kwargs)
+    if default_params:
+        lines.append("Default parameters (not explicitly provided):")
+        for param_name, default_value in default_params.items():
+            lines.append(f"  {param_name}= [DEFAULT]")
+            lines.append(_format_value(default_value, level, indent=2))
+
+    _logger.debug("\n".join(lines))
+
+
+def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
+    """
+    Log function outputs AFTER successful execution.
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the function
+    result : Any
+        Function return value
+    level : int
+        Logging level (3 or 5)
+    """
+    lines = []
+    # Log outputs
+    lines.append("Output value:")
+    lines.append(_format_value(result, level, indent=1))
+
+    lines.append("=" * 80)
+    lines.append("")  # Empty line for readability
+
+    _logger.debug("\n".join(lines))
+
+
+def flashinfer_api(func: Callable = None) -> Callable:
+    """
+    Decorator for FlashInfer's APIs.
+
+    Currently logs input and output values of the function using Python's logging library.
+    This decorator integrates with Python's standard logging infrastructure while
+    maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).
+
+    NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress.
+
+    Environment Variables
+    ---------------------
+    FLASHINFER_LOGLEVEL : int (default: 0)
+        - 0: No logging (zero overhead - decorator returns original function)
+        - 1: Log function name only (logged BEFORE execution - crash-safe)
+        - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe)
+        - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe)
+
+    FLASHINFER_LOGDEST : str (default: "stdout")
+        - "stdout": Log to standard output
+        - "stderr": Log to standard error
+        - <file path>: Log to specified file path
+        - Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt")
+
+    Examples
+    --------
+    Basic usage:
+
+    >>> @flashinfer_api
+    ... def my_function(x, y):
+    ...     return x + y
+
+    Notes
+    -----
+    - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS]
+      (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information")
+    - When FLASHINFER_LOGLEVEL=0, the decorator has truly zero overhead
+      as it returns the original function unchanged.
+ - Function names and inputs are logged BEFORE execution: + - Level 1: Function name only + - Levels 3-5: Function name + inputs with metadata + This means critical debugging information is preserved even if the function + crashes (e.g., CUDA illegal memory access, out-of-bounds, etc.). + - Outputs are logged AFTER successful execution for levels 3 and 5. + - **CUDA Graph Compatibility**: At level 5, tensor statistics (min/max/mean/nan_count) + are automatically skipped during CUDA graph capture to avoid synchronization issues. + The message "[statistics skipped: CUDA graph capture in progress]" will be logged. + - The %i pattern is automatically replaced with the process ID for multi-process environments. + - The logger does not propagate to the root logger to avoid duplicate logs. + """ + # If logging is disabled, return original function with zero overhead + if _API_LOG_LEVEL == 0: + if func is None: + return lambda f: f + return func + + def decorator(f: Callable) -> Callable: + @functools.wraps(f) + def wrapper(*args, **kwargs): + # Determine function name (with class name if applicable) + func_name = f.__name__ + if args and hasattr(args[0], "__class__"): + try: + class_name = args[0].__class__.__name__ + if "Wrapper" in class_name or class_name in [ + "BatchMLAPagedAttentionWrapper" + ]: + func_name = f"{class_name}.{func_name}" + except Exception: + pass + + # Log BEFORE execution (crash-safe for all levels!) + try: + if _API_LOG_LEVEL == 1: + # Level 1: Just log function name before execution (crash-safe) + _logger.debug( + f"{_get_timestamp()} FlashInfer API Call: {func_name}" + ) + elif _API_LOG_LEVEL >= 3: + # Level 3+: Log full inputs before execution (crash-safe) + _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") + + # Call the original function (may crash here with CUDA errors) + result = f(*args, **kwargs) + + # Log outputs AFTER successful execution (level 3+ only) + try: + if _API_LOG_LEVEL >= 3: + # Level 3+: Log outputs (inputs were already logged above) + _log_function_outputs(func_name, result, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}") + + return result + + return wrapper + + if func is None: + return decorator + return decorator(func) diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 6ef13b997f..195ca2d49d 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -252,6 +253,7 @@ def _batch_decode_with_kv_cache( return out +@flashinfer_api def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc573cf7cb..b8c09a66ee 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -383,6 +384,7 @@ def _batch_prefill_with_kv_cache( return out, None +@flashinfer_api def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index af8dda0345..ab34ba8857 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,6 +21,7 @@ import torch +from .api_logging import flashinfer_api from .xqa import xqa, xqa_mla from .cudnn 
import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( @@ -312,6 +313,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api def single_decode_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -388,6 +390,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -646,6 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -809,6 +813,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -1162,6 +1167,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -2059,6 +2065,7 @@ def _fake_paged_run( ) +@flashinfer_api def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2332,6 +2339,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout +@flashinfer_api def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2516,6 +2524,7 @@ def _check_trtllm_gen_mla_shape( ) +@flashinfer_api def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -2677,6 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") +@flashinfer_api def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 3c5e7a09c5..7b53c3f82c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, DynamicTensorSpec, @@ -685,6 +686,7 @@ def _fake_cutlass_fused_moe( # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 +@flashinfer_api def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1857,6 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe( ) +@flashinfer_api def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -1937,6 +1940,7 @@ def trtllm_bf16_moe( ) +@flashinfer_api def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2010,6 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe( ) +@flashinfer_api def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2087,6 +2092,7 @@ def trtllm_fp8_block_scale_moe( ) +@flashinfer_api def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2216,6 +2222,7 @@ def trtllm_fp4_block_scale_moe( ) +@flashinfer_api def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 589c651aca..251e2a4682 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,6 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, 
ConstraintSpec, @@ -539,6 +540,7 @@ def forward( ) +@flashinfer_api def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -884,6 +886,7 @@ def reset_workspace_buffer( self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer + @flashinfer_api def run( self, x: torch.Tensor, @@ -1551,6 +1554,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) +@flashinfer_api def mm_fp8( a: torch.Tensor, b: torch.Tensor, @@ -2024,6 +2028,7 @@ def _heuristic_func_mm_fp4( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) +@flashinfer_api def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2281,6 +2286,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) +@flashinfer_api def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2372,6 +2378,7 @@ def bmm_fp8( return out +@flashinfer_api def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2623,6 +2630,7 @@ def forward( ) +@flashinfer_api def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2651,6 +2659,7 @@ def gemm_fp8_nt_blockscaled( ) +@flashinfer_api def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2813,6 +2822,7 @@ def group_gemm_fp8_nt_groupwise( return out +@flashinfer_api def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2980,6 +2990,7 @@ def get_deepgemm_sm100_module(): return module +@flashinfer_api def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3110,6 +3121,7 @@ def group_deepgemm_fp8_nt_groupwise( return out +@flashinfer_api def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index da57d94e6b..22cf029a2e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit import gen_batch_mla_module from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend @@ -129,6 +130,7 @@ class BatchMLAPagedAttentionWrapper: torch.Size([114, 128, 512]) """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -199,6 +201,7 @@ def __init__( else: self._backend = backend + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -333,6 +336,7 @@ def run( return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 47d725c5d3..a2c4ceb0a8 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -22,6 +22,7 @@ import torch +from .api_logging import flashinfer_api from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, @@ -873,6 +874,7 @@ def _fake_paged_run( ) +@flashinfer_api def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -957,6 +959,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1325,6 +1328,7 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. 
""" + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1520,6 +1524,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -1976,6 +1981,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -2350,6 +2356,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -2493,6 +2500,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -2837,6 +2845,7 @@ def run( enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -3193,6 +3202,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3327,6 +3337,7 @@ def trtllm_ragged_attention_deepseek( return out +@flashinfer_api def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py new file mode 100644 index 0000000000..6ead5e7d6b --- /dev/null +++ b/tests/utils/test_logging.py @@ -0,0 +1,588 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import sys +import tempfile +from enum import Enum +from pathlib import Path + +import pytest +import torch + + +# Test enum classes +class TestEnum(Enum): + """Test enum with integer values.""" + + OPTION_A = 0 + OPTION_B = 1 + OPTION_C = 2 + + +class StringEnum(Enum): + """Test enum with string values. 
Names are for testing purposes.""" + + MODE_STANDARD = "standard" + MODE_OPTIMIZED = "optimized" + + +class TestAPILogging: + """Test suite for FlashInfer API logging infrastructure.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Reset environment and reimport logging module for each test.""" + # Store original environment + original_level = os.environ.get("FLASHINFER_LOGLEVEL") + original_dest = os.environ.get("FLASHINFER_LOGDEST") + + yield + + # Restore original environment + if original_level is not None: + os.environ["FLASHINFER_LOGLEVEL"] = original_level + elif "FLASHINFER_LOGLEVEL" in os.environ: + del os.environ["FLASHINFER_LOGLEVEL"] + + if original_dest is not None: + os.environ["FLASHINFER_LOGDEST"] = original_dest + elif "FLASHINFER_LOGDEST" in os.environ: + del os.environ["FLASHINFER_LOGDEST"] + + # Force reimport to pick up new environment variables + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + def setup_logging(self, level: int, dest: str = "stdout"): + """Helper to set up logging environment and reimport.""" + os.environ["FLASHINFER_LOGLEVEL"] = str(level) + os.environ["FLASHINFER_LOGDEST"] = dest + + # Force reimport + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + from flashinfer.api_logging import flashinfer_api + + return flashinfer_api + + def test_level_0_zero_overhead(self): + """Test that level 0 has truly zero overhead (returns original function).""" + decorator = self.setup_logging(level=0) + + def original_func(x, y): + return x + y + + decorated_func = decorator(original_func) + + # At level 0, decorator should return the original function unchanged + assert decorated_func is original_func + assert decorated_func(5, 3) == 8 + + def test_level_1_function_name(self): + """Test that level 1 logs function name only.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x, y): + return x + y + + result = test_function(10, 20) + assert result == 30 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + assert "FlashInfer API Call: test_function" in log_contents + # Level 1 should not log inputs/outputs details + assert "Positional input arguments" not in log_contents + assert "Output value" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_level_3_inputs_outputs(self): + """Test that level 3 logs inputs and outputs with metadata.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(tensor, value): + return tensor * value + + tensor = torch.tensor([1.0, 2.0, 3.0]) + test_function(tensor, 2.0) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log function name + assert "FlashInfer API Call: test_function" in log_contents + + # Should log inputs + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "Tensor(" in log_contents + assert "shape=(3,)" in log_contents + assert "dtype=torch.float32" in log_contents + + # Should log outputs + assert "Output value:" in log_contents + + # Should NOT log statistics (level 5 only) + assert "min=" not in log_contents + assert "max=" not in log_contents + finally: 
+ Path(log_file).unlink(missing_ok=True) + + def test_level_5_statistics(self): + """Test that level 5 logs tensor statistics.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=5, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + 1.0 + + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + test_function(tensor) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log statistics + assert "min=" in log_contents + assert "max=" in log_contents + assert "mean=" in log_contents + assert "nan_count=" in log_contents + assert "inf_count=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_enum_logging(self): + """Test that enum values are logged with name and value.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(mode: TestEnum, strategy: StringEnum): + return f"{mode.name}_{strategy.name}" + + test_function(TestEnum.OPTION_B, StringEnum.MODE_OPTIMIZED) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should show enum name and value + assert "TestEnum.OPTION_B" in log_contents + assert "(value=1)" in log_contents + assert "StringEnum.MODE_OPTIMIZED" in log_contents + assert ( + "(value=optimized)" in log_contents + or "(value='optimized')" in log_contents + or '(value="optimized")' in log_contents + ) + finally: + Path(log_file).unlink(missing_ok=True) + + def test_default_parameters(self): + """Test that default parameters are logged separately.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(x, y=10, z=20, mode=TestEnum.OPTION_A): + return x + y + z + + # Call with only required argument + result = test_function(5) + assert result == 35 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should show default parameters section + assert "Default parameters (not explicitly provided)" in log_contents + assert "[DEFAULT]" in log_contents + + # Should show the default values + assert "y=" in log_contents + assert "z=" in log_contents + assert "mode=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_explicit_vs_default_parameters(self): + """Test that explicitly provided parameters are not shown in defaults.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(x, y=10, z=20): + return x + y + z + + # Call with some explicit parameters + test_function(5, y=100) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # y should be in keyword arguments (explicit) + assert "Keyword input arguments:" in log_contents + + # Only z should be in defaults + lines = log_contents.split("\n") + default_section_started = False + defaults_found = [] + for line in lines: + if "Default parameters" in line: + default_section_started = True + if default_section_started and "=" in line and "[DEFAULT]" in line: + defaults_found.append(line) + + # Should have only one default parameter (z) + assert len(defaults_found) == 1 + 
assert "z=" in defaults_found[0] + finally: + Path(log_file).unlink(missing_ok=True) + + def test_class_method_logging(self): + """Test that class methods log with class name.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + class TestWrapper: + @decorator + def run(self, x): + return x * 2 + + wrapper = TestWrapper() + result = wrapper.run(5) + assert result == 10 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log class name for Wrapper classes + assert "TestWrapper.run" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_crash_safety_inputs_logged_before_execution(self): + """Test that inputs are logged BEFORE execution (crash-safe).""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def crashing_function(x, y): + raise RuntimeError("Simulated crash") + + # Call the function and expect it to crash + with pytest.raises(RuntimeError, match="Simulated crash"): + crashing_function(42, 99) + + # Check that inputs were still logged + with open(log_file, "r") as f: + log_contents = f.read() + + # Inputs should be in the log even though function crashed + assert "FlashInfer API Call: crashing_function" in log_contents + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "42" in log_contents + assert "arg[1]" in log_contents + assert "99" in log_contents + + # Outputs should NOT be in the log (function crashed) + assert "Output value:" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_different_data_types(self): + """Test logging of various data types.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function( + int_val, + float_val, + bool_val, + str_val, + list_val, + tuple_val, + dict_val, + none_val, + ): + return "success" + + test_function( + 42, 3.14, True, "hello", [1, 2, 3], (4, 5, 6), {"key": "value"}, None + ) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log all types correctly + assert "42" in log_contents + assert "3.14" in log_contents + assert "True" in log_contents + assert "'hello'" in log_contents + assert "None" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_tensor_metadata(self): + """Test that tensor metadata is logged correctly.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + + # Create a tensor with specific properties + tensor = torch.randn(2, 3, 4, dtype=torch.float32, device="cpu") + tensor = tensor.contiguous() + tensor.requires_grad = False + + test_function(tensor) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log all metadata + assert "shape=(2, 3, 4)" in log_contents + assert "dtype=torch.float32" in log_contents + assert "device=cpu" in log_contents + assert "requires_grad=False" in log_contents + assert "is_contiguous=True" in log_contents + assert "stride=" in log_contents + finally: + 
Path(log_file).unlink(missing_ok=True) + + def test_nested_structures(self): + """Test logging of nested data structures.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(nested): + return nested + + # Create nested structure + nested = { + "list": [1, 2, 3], + "dict": {"inner": "value"}, + "tuple": (4, 5), + } + + test_function(nested) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should handle nested structures + assert "list" in log_contents + assert "dict" in log_contents + assert "tuple" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_decorator_with_and_without_parentheses(self): + """Test that decorator works both as @decorator and @decorator().""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + # Without parentheses + @decorator + def func1(x): + return x + 1 + + # With parentheses + @decorator() + def func2(x): + return x + 2 + + result1 = func1(10) + result2 = func2(20) + + assert result1 == 11 + assert result2 == 22 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + assert "func1" in log_contents + assert "func2" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_multiple_calls_same_function(self): + """Test that multiple calls to the same function are all logged.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x): + return x + + # Call multiple times + for i in range(3): + test_function(i) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should have 3 log entries + assert log_contents.count("FlashInfer API Call: test_function") == 3 + finally: + Path(log_file).unlink(missing_ok=True) + + def test_kwargs_logging(self): + """Test that keyword arguments are logged correctly.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(a, b, c): + return a + b + c + + # Call with keyword arguments + result = test_function(a=1, b=2, c=3) + assert result == 6 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log keyword arguments + assert "Keyword input arguments:" in log_contents + assert "a=" in log_contents + assert "b=" in log_contents + assert "c=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_graph_compatibility(self): + """Test that level 5 logging is compatible with CUDA graph capture.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=5, dest=log_file) + + @decorator + def test_cuda_function(tensor): + return tensor * 2.0 + + # Create a CUDA tensor + tensor = torch.randn(10, 10, device="cuda") + + # Test 1: Normal execution (should have statistics) + test_cuda_function(tensor) + + with open(log_file, "r") as f: + log_normal = f.read() + + # Should have 
statistics in normal execution + # (unless PyTorch version is too old) + if hasattr(torch.cuda, "is_current_stream_capturing"): + # Normal execution should have min/max OR statistics error + has_stats = "min=" in log_normal or "statistics error" in log_normal + assert has_stats, "Expected statistics or error in normal execution" + + # Clear log file + with open(log_file, "w") as f: + f.write("") + + # Test 2: CUDA graph capture (should skip statistics) + if hasattr(torch.cuda, "CUDAGraph"): + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + test_cuda_function(tensor) + + with open(log_file, "r") as f: + log_capture = f.read() + + # Should skip statistics during capture + assert ( + "[statistics skipped: CUDA graph capture in progress]" + in log_capture + or "statistics" not in log_capture + ), "Expected statistics to be skipped during CUDA graph capture" + finally: + Path(log_file).unlink(missing_ok=True) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])