77import tensorrt as trt
88import torch
99from torch_tensorrt ._enums import dtype
10- from torch_tensorrt ._features import ENABLED_FEATURES , needs_refit
10+ from torch_tensorrt ._features import ENABLED_FEATURES
1111from torch_tensorrt ._Input import Input
1212from torch_tensorrt .dynamo ._engine_cache import BaseEngineCache
1313from torch_tensorrt .dynamo ._settings import CompilationSettings , settings_are_compatible
@@ -70,11 +70,16 @@ def interpret_module_to_result(
7070
7171 def _insert_engine_to_cache (
7272 hash_val : str , interpreter_result : TRTInterpreterResult
73- ) -> None : # type: ignore[unused-ignore]
73+ ) -> bool :
74+ if not ENABLED_FEATURES .refit :
75+ logger .info ("Refit feature is not available, so the engine is not cached" )
76+ return False
77+
7478 # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
7579 if engine_cache .check (hash_val ) is not None : # type: ignore[union-attr]
76- logger .info (f"Engine already exists in cache for hash: { hash_val } " )
77- return
80+ logger .info (f"The engine already exists in cache for hash: { hash_val } " )
81+ return False
82+
7883 if not settings .strip_engine_weights :
7984 # set EXCLUDE_WEIGHTS flag to strip weights
8085 serialization_config = (
@@ -101,9 +106,15 @@ def _insert_engine_to_cache(
101106 ),
102107 )
103108 logger .info (f"Engine was successfully inserted into cache for hash: { hash_val } " )
109+ return True
104110
105- @needs_refit # type: ignore[misc]
106111 def _pull_cached_engine (hash_val : str ) -> Optional [SerializedInterpreterResult ]:
112+ if not ENABLED_FEATURES .refit :
113+ logger .info (
114+ "Refit feature is not available, so the engine is not loaded from cache"
115+ )
116+ return None
117+
107118 # query the cached TRT engine
108119 cached_data = engine_cache .check (hash_val ) # type: ignore[union-attr]
109120 if cached_data is not None : # hit the cache
@@ -181,14 +192,18 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
181192 # engine_cache could be None if:
182193 # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
183194 # 2) both cache_built_engines and reuse_cached_engines are False
184- if engine_cache is not None and not settings .immutable_weights :
195+ if (
196+ ENABLED_FEATURES .refit
197+ and engine_cache is not None
198+ and not settings .immutable_weights
199+ ):
185200 if settings .cache_built_engines or settings .reuse_cached_engines :
186201 hash_val = engine_cache .get_hash (module , inputs , settings )
187202
188203 if settings .reuse_cached_engines :
189204 serialized_interpreter_result = _pull_cached_engine (hash_val )
190205 if serialized_interpreter_result is not None : # hit the cache
191- return serialized_interpreter_result # type: ignore[no-any-return]
206+ return serialized_interpreter_result
192207
193208 output_dtypes = infer_module_output_dtypes (
194209 module , truncate_double = settings .truncate_double
@@ -215,11 +230,12 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
215230
216231 # Engine caching only for refittable engines
217232 if (
218- not settings .immutable_weights
233+ ENABLED_FEATURES .refit
234+ and not settings .immutable_weights
219235 and settings .cache_built_engines
220236 and engine_cache is not None
221237 ):
222- _insert_engine_to_cache (hash_val , interpreter_result )
238+ _ = _insert_engine_to_cache (hash_val , interpreter_result )
223239
224240 serialized_engine = interpreter_result .engine .serialize ()
225241 with io .BytesIO () as engine_bytes :
@@ -237,7 +253,7 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
237253 requires_output_allocator = interpreter_result .requires_output_allocator ,
238254 )
239255
240- return serialized_interpreter_result # type: ignore[no-any-return]
256+ return serialized_interpreter_result
241257
242258
243259def convert_module (
0 commit comments