Commit 6f580b6

committed: squashed and cleaned the commits
1 parent 3a9bc3b commit 6f580b6

6 files changed: +781 -3 lines changed

6 files changed

+781
-3
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 26 additions & 2 deletions
@@ -104,6 +104,7 @@ def cross_compile_for_windows(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -178,6 +179,7 @@ def cross_compile_for_windows(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -333,6 +335,7 @@ def cross_compile_for_windows(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "cpu_memory_budget": cpu_memory_budget,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -434,6 +437,7 @@ def compile(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -680,8 +684,8 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "cpu_memory_budget": cpu_memory_budget,
     }
-
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)
@@ -850,6 +854,16 @@ def preserve_module_specs(
             require_full_compilation=settings.require_full_compilation,
         )
 
+    from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
+        resource_partition,
+    )
+
+    partitioned_module = resource_partition(
+        gm,
+        partitioned_module,
+        cpu_memory_budget=settings.cpu_memory_budget,
+    )
+
     dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
 
     # The global partitioner leaves non-TRT nodes as-is
@@ -868,6 +882,16 @@ def preserve_module_specs(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
 
+    # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
+    # This is done to release CPU memory.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
+
+    from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS
+
+    DYNAMO_CONVERTERS.disallowed_targets = set()
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1243,7 +1267,7 @@ def convert_exported_program_to_serialized_trt_engine(
 
     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
 
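For reference, a minimal sketch of how the new knob might be exercised from user code, assuming it is forwarded through torch_tensorrt.dynamo.compile exactly as the hunks above add it; the toy model and the 8 GiB figure are illustrative only and not part of this commit:

    import torch
    import torch_tensorrt

    # Hypothetical toy model standing in for a real workload.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(1, 64).cuda()]
    exported_program = torch.export.export(model, tuple(inputs))

    # Cap compile-time host memory at 8 GiB; per the docstring above, -1 means no limit.
    trt_module = torch_tensorrt.dynamo.compile(
        exported_program,
        arg_inputs=inputs,
        cpu_memory_budget=8 * 1024**3,
    )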
py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import platform
 import tempfile
 
+import psutil
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
@@ -57,6 +58,7 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+CPU_MEMORY_BUDGET = psutil.virtual_memory().available
 
 if platform.system() == "Linux":
     import pwd
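The default comes straight from psutil, so the budget is simply whatever host RAM is reported as available when the module is imported. A quick sanity check of what that value looks like (the GiB formatting is only for readability):

    import psutil

    # Bytes of RAM that can be allocated without swapping, as reported by psutil.
    default_budget = psutil.virtual_memory().available
    print(f"default CPU_MEMORY_BUDGET ~ {default_budget / 1024**3:.1f} GiB")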

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    CPU_MEMORY_BUDGET,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -140,6 +141,7 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    cpu_memory_budget: int = CPU_MEMORY_BUDGET
 
     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
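Since CompilationSettings is a dataclass, the new field behaves like any other setting and can be overridden directly; a small sketch, with the 4 GiB value chosen purely for illustration:

    from torch_tensorrt.dynamo._settings import CompilationSettings

    # Override the default (available system memory at import time) with an explicit 4 GiB budget.
    settings = CompilationSettings(cpu_memory_budget=4 * 1024**3)
    print(settings.cpu_memory_budget)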
