Skip to content

Commit 9c2faf5

Browse files
committed
fix comments
1 parent d39a29e commit 9c2faf5

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import tensorrt as trt
88
import torch
99
from torch_tensorrt._enums import dtype
10-
from torch_tensorrt._features import ENABLED_FEATURES, needs_refit
10+
from torch_tensorrt._features import ENABLED_FEATURES
1111
from torch_tensorrt._Input import Input
1212
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1313
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
@@ -70,11 +70,16 @@ def interpret_module_to_result(
7070

7171
def _insert_engine_to_cache(
7272
hash_val: str, interpreter_result: TRTInterpreterResult
73-
) -> None: # type: ignore[unused-ignore]
73+
) -> bool:
74+
if not ENABLED_FEATURES.refit:
75+
logger.info("Refit feature is not available, so the engine is not cached")
76+
return False
77+
7478
# Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
7579
if engine_cache.check(hash_val) is not None: # type: ignore[union-attr]
76-
logger.info(f"Engine already exists in cache for hash: {hash_val}")
77-
return
80+
logger.info(f"The engine already exists in cache for hash: {hash_val}")
81+
return False
82+
7883
if not settings.strip_engine_weights:
7984
# set EXCLUDE_WEIGHTS flag to strip weights
8085
serialization_config = (
@@ -101,9 +106,15 @@ def _insert_engine_to_cache(
101106
),
102107
)
103108
logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
109+
return True
104110

105-
@needs_refit # type: ignore[misc]
106111
def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
112+
if not ENABLED_FEATURES.refit:
113+
logger.info(
114+
"Refit feature is not available, so the engine is not loaded from cache"
115+
)
116+
return None
117+
107118
# query the cached TRT engine
108119
cached_data = engine_cache.check(hash_val) # type: ignore[union-attr]
109120
if cached_data is not None: # hit the cache
@@ -181,14 +192,18 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
181192
# engine_cache could be None if:
182193
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
183194
# 2) both cache_built_engines and reuse_cached_engines are False
184-
if engine_cache is not None and not settings.immutable_weights:
195+
if (
196+
ENABLED_FEATURES.refit
197+
and engine_cache is not None
198+
and not settings.immutable_weights
199+
):
185200
if settings.cache_built_engines or settings.reuse_cached_engines:
186201
hash_val = engine_cache.get_hash(module, inputs, settings)
187202

188203
if settings.reuse_cached_engines:
189204
serialized_interpreter_result = _pull_cached_engine(hash_val)
190205
if serialized_interpreter_result is not None: # hit the cache
191-
return serialized_interpreter_result # type: ignore[no-any-return]
206+
return serialized_interpreter_result
192207

193208
output_dtypes = infer_module_output_dtypes(
194209
module, truncate_double=settings.truncate_double
@@ -215,11 +230,12 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
215230

216231
# Engine caching only for refittable engines
217232
if (
218-
not settings.immutable_weights
233+
ENABLED_FEATURES.refit
234+
and not settings.immutable_weights
219235
and settings.cache_built_engines
220236
and engine_cache is not None
221237
):
222-
_insert_engine_to_cache(hash_val, interpreter_result)
238+
_ = _insert_engine_to_cache(hash_val, interpreter_result)
223239

224240
serialized_engine = interpreter_result.engine.serialize()
225241
with io.BytesIO() as engine_bytes:
@@ -237,7 +253,7 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
237253
requires_output_allocator=interpreter_result.requires_output_allocator,
238254
)
239255

240-
return serialized_interpreter_result # type: ignore[no-any-return]
256+
return serialized_interpreter_result
241257

242258

243259
def convert_module(

0 commit comments

Comments
 (0)