32 changes: 30 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -104,6 +104,7 @@ def cross_compile_for_windows(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -178,6 +179,7 @@ def cross_compile_for_windows(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
cpu_memory_budget (int): The maximum amount of CPU memory, in bytes, to use for compilation. If compilation requires more memory than this budget, it will fail. If set to -1, compilation will use all available CPU memory.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
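The same `cpu_memory_budget` keyword is added to `compile()` further down in this file. A minimal usage sketch, assuming the public `torch_tensorrt.dynamo.compile` entry point; the model and input names are placeholders:

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(64, 64).eval().cuda()  # placeholder model
example_input = torch.randn(8, 64).cuda()      # placeholder input

exported = torch.export.export(model, (example_input,))

# Cap host-memory usage during compilation at ~8 GiB (the value is in bytes).
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[example_input],
    cpu_memory_budget=8 * 1024**3,
)
```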
@@ -333,6 +335,7 @@ def cross_compile_for_windows(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

# disabling the following settings is not supported for the cross compilation for windows feature
@@ -434,6 +437,7 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -614,6 +618,10 @@ def compile(
"'arg_inputs' and 'inputs' should not be used at the same time."
)

assert (
cpu_memory_budget >= 2 * 1024 * 1024 * 1024
), "CPU memory budget must be greater than 10GB"

arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
@@ -680,8 +688,8 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
@@ -850,6 +858,16 @@ def preserve_module_specs(
require_full_compilation=settings.require_full_compilation,
)

from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
resource_partition,
)

partitioned_module = resource_partition(
gm,
partitioned_module,
cpu_memory_budget=settings.cpu_memory_budget,
)

dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

# The global partitioner leaves non-TRT nodes as-is
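The `resource_partition` pass itself lives in `torch_tensorrt.dynamo.partitioning._resource_partitioner` and is not shown in this diff. As an illustration only, one plausible way to reason about a CPU memory budget is to compare the byte size of each submodule's weights and buffers against the budget; the helper names below are hypothetical and not part of the library:

```python
import torch

def estimated_cpu_footprint(mod: torch.nn.Module) -> int:
    """Rough CPU-side footprint of a module, in bytes (parameters + buffers only)."""
    tensors = list(mod.parameters()) + list(mod.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)

def fits_cpu_budget(mod: torch.nn.Module, cpu_memory_budget: int) -> bool:
    # A real partitioner would also have to account for activation and
    # workspace memory, not just the weights held on the host.
    return estimated_cpu_footprint(mod) <= cpu_memory_budget
```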
@@ -868,6 +886,16 @@ def preserve_module_specs(
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

# Delete the frozen parameters from the graph module to release CPU memory.
# Note that this does not affect the submodules; their frozen parameters are
# deleted later, in the convert_module function.
for attr in dir(gm):
if attr.startswith("_frozen_param"):
delattr(gm, attr)

from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

DYNAMO_CONVERTERS.disallowed_targets = set()

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
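A self-contained sketch of the cleanup pattern used above: once the last Python reference to a `_frozen_param*` tensor is dropped, its host memory can be reclaimed. The graph module and attribute below are stand-ins, not the real compiler state:

```python
import gc
import torch
import torch.fx

# Stand-in graph module with a fake frozen parameter attached to it.
gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
gm._frozen_param0 = torch.zeros(1024, 1024)  # ~4 MiB of host memory

# Same pattern as in the diff: drop every _frozen_param* attribute so the
# tensors become unreferenced and their CPU memory can be released.
for attr in dir(gm):
    if attr.startswith("_frozen_param"):
        delattr(gm, attr)
gc.collect()  # encourage prompt reclamation of the freed storage
```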
@@ -1243,7 +1271,7 @@ def convert_exported_program_to_serialized_trt_engine(

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}

2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -2,6 +2,7 @@
import platform
import tempfile

import psutil
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
@@ -57,6 +58,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
CPU_MEMORY_BUDGET = psutil.virtual_memory().available

if platform.system() == "Linux":
import pwd
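For reference, `psutil.virtual_memory().available` reports the host RAM currently available, in bytes, so the default budget tracks whatever is free at import time. A quick check:

```python
import psutil

avail = psutil.virtual_memory().available  # bytes of host RAM currently available
print(f"Default cpu_memory_budget ~= {avail / 1024**3:.1f} GiB")
```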
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -7,6 +7,7 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
@@ -140,6 +141,7 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
cpu_memory_budget: int = CPU_MEMORY_BUDGET

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
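The new field threads through `CompilationSettings` like any other option. A minimal sketch constructing the (internal) settings dataclass directly, for illustration only:

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

# All other fields fall back to their defaults from _defaults.py.
settings = CompilationSettings(cpu_memory_budget=8 * 1024**3)
print(settings.cpu_memory_budget)
```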