Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Tensor Parallel Initialize Distributed Environment
==================================================

This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed.
"""

import logging
Expand All @@ -19,8 +19,65 @@
logger = logging.getLogger(__name__)


# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
def initialize_logger(
rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
):
"""Initialize rank-specific Torch-TensorRT logger with configurable handler levels.

Logger level is set to DEBUG (pass-through), handlers control filtering for files and stream buffers

Args:
rank: Process rank for multi-GPU
logger_file_name: Base name for log file (will add _rank.log)
file_level: What goes to file - default DEBUG (everything)
console_level: What prints to console - default INFO (clean output)
"""
logger = logging.getLogger("torch_tensorrt")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()

# File handler
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(file_level)
fh.setFormatter(
logging.Formatter(
f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
)
logger.addHandler(fh)

# console handler
ch = logging.StreamHandler()
ch.setLevel(
console_level
) # Console handler controls what's printed in console output
ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
logger.addHandler(ch)

# safegauard though not reqd
logger.propagate = False
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(
logger_file_name,
rank=0,
world_size=1,
port=29500,
file_level="debug",
console_level="info",
):
"""Initialize distributed environment with handler-based logging.

Args:
logger_file_name: Base name for log files
rank: Initial rank (overridden by OMPI env vars)
world_size: Initial world size (overridden by OMPI env vars)
port: Master port for distributed communication
file_level: File handler level - "debug", "info", "warning" (default: "debug")
console_level: Console handler level - "debug", "info", "warning" (default: "info")
"""
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -44,12 +101,40 @@ def initialize_distributed_env(rank=0, world_size=1, port=29500):
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
# Convert string handler levels to logging constants
level_map = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
console_level_int = level_map.get(console_level.lower(), logging.INFO)

# Initialize logger with handler-specific levels
# Logger itself is always DEBUG - handlers do the filtering
logger = initialize_logger(
rank,
logger_file_name,
file_level=file_level_int,
console_level=console_level_int,
)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank
# Set C++ TensorRT runtime log level based on most verbose handler
# Use the most verbose level to ensure all important logs are captured
cpp_level = min(file_level_int, console_level_int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont we have an API that abstracts needing to detect if the C++ runtime is available? If not we should add one

Copy link
Collaborator Author

@apbose apbose Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a function in _features.py for the above. And also moved all this to logging.py. Let me know if that function placment works

try:
import torch_tensorrt.logging as torchtrt_logging

torchtrt_logging.set_level(cpp_level)
except Exception as e:
logger.warning(f"Could not set C++ TensorRT log level: {e}")

return device_mesh, world_size, rank, logger


def cleanup_distributed_env():
Expand Down
36 changes: 34 additions & 2 deletions py/torch_tensorrt/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,29 @@

def _enabled_features_str() -> str:
enabled = lambda x: "ENABLED" if x else "DISABLED"
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call]
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call]
return out_str


# Inline helper functions for checking feature availability
def has_torch_tensorrt_runtime() -> bool:
"""Check if Torch-TensorRT C++ runtime is available.

Returns:
bool: True if libtorchtrt_runtime.so or libtorchtrt.so is available
"""
return bool(ENABLED_FEATURES.torch_tensorrt_runtime)


def has_torchscript_frontend() -> bool:
"""Check if TorchScript frontend is available.

Returns:
bool: True if libtorchtrt.so is available
"""
return bool(ENABLED_FEATURES.torchscript_frontend)


def needs_tensorrt_rtx(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.tensorrt_rtx:
Expand Down Expand Up @@ -163,14 +182,27 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:


def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Runtime check decorator for TensorRT-LLM NCCL plugin availability.

WARNING: This decorator CANNOT prevent registration of converters at import time.
When used with @dynamo_tensorrt_converter, the converter is always registered
regardless of decorator order, because registration happens at import time before
the wrapper is called.

This decorator is kept for potential non-registration use cases where
runtime checks are appropriate.
@apbose: to discuss if this is required
"""

def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.trtllm_for_nccl:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"Refit feature is currently not available in Python 3.13 or higher"
"TensorRT-LLM plugin for NCCL is not available"
)

return not_implemented(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import torch._dynamo as td
import torch_tensorrt.logging as torchtrt_logging
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
Expand All @@ -23,7 +24,6 @@
from torch_tensorrt.dynamo.utils import (
parse_dynamo_kwargs,
prepare_inputs,
set_log_level,
)

logger = logging.getLogger(__name__)
Expand All @@ -40,7 +40,7 @@ def torch_tensorrt_backend(
and "debug" in kwargs["options"]
and kwargs["options"]["debug"]
) or ("debug" in kwargs and kwargs["debug"]):
set_log_level(logger.parent, logging.DEBUG)
torchtrt_logging.set_level(logging.DEBUG)

DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

Expand Down
78 changes: 47 additions & 31 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt._features import needs_trtllm_for_nccl
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
Expand All @@ -20,37 +20,53 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


@needs_trtllm_for_nccl
@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
# Conditionally register NCCL converters only if TensorRT-LLM plugin is available.
# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because
# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator
# order. Conditional registration prevents registration when TRTLLM is unavailable,
# allowing fallback to PyTorch execution for NCCL ops.

# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator
# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch
if ENABLED_FEATURES.trtllm_for_nccl:
_LOGGER.debug(
"TensorRT-LLM plugin for NCCL is available. Registering NCCL converters."
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@needs_trtllm_for_nccl
@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
else:
_LOGGER.info(
"TensorRT-LLM plugin for NCCL is not available. "
"NCCL operations will fall back to PyTorch execution."
)
28 changes: 0 additions & 28 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo import _defaults
Expand Down Expand Up @@ -270,33 +269,6 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
return device


def set_log_level(parent_logger: Any, level: Any) -> None:
"""
Sets the log level to the user provided level.
This is used to set debug logging at a global level
at entry points of tracing, dynamo and torch_compile compilation.
And set log level for c++ torch trt logger if runtime is available.
"""
if parent_logger:
parent_logger.setLevel(level)

if ENABLED_FEATURES.torch_tensorrt_runtime:
if level == logging.DEBUG:
log_level = trt.ILogger.Severity.VERBOSE
elif level == logging.INFO:
log_level = trt.ILogger.Severity.INFO
elif level == logging.WARNING:
log_level = trt.ILogger.Severity.WARNING
elif level == logging.ERROR:
log_level = trt.ILogger.Severity.ERROR
elif level == logging.CRITICAL:
log_level = trt.ILogger.Severity.INTERNAL_ERROR
else:
raise AssertionError(f"{level} is not valid log level")

torch.ops.tensorrt.set_logging_level(int(log_level))


def prepare_inputs(
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
disable_memory_format_check: bool = False,
Expand Down
Loading
Loading