@@ -47,6 +47,134 @@ def infer_module_output_dtypes(
     return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]


+def insert_engine_to_cache(
+    hash_val: str,
+    interpreter_result: TRTInterpreterResult,
+    engine_cache: BaseEngineCache,
+    settings: CompilationSettings,
+    inputs: Sequence[Input],
+) -> bool:
+    if not ENABLED_FEATURES.refit:
+        logger.info("Refit feature is not available, so the engine is not cached")
+        return False
+
+    # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
+    if engine_cache.check(hash_val) is not None:
+        logger.info(f"The engine already exists in cache for hash: {hash_val}")
+        return False
+
+    if not settings.strip_engine_weights:
+        # set EXCLUDE_WEIGHTS flag to strip weights
+        serialization_config = interpreter_result.engine.create_serialization_config()
+        serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+        weight_stripped_serialized_engine = (
+            interpreter_result.engine.serialize_with_config(serialization_config)
+        )
+    else:
+        weight_stripped_serialized_engine = interpreter_result.engine.serialize()
+
+    # Insert weight-stripped engine to cache
+    engine_cache.insert(
+        hash_val,
+        (
+            weight_stripped_serialized_engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+            inputs,
+            settings,
+            interpreter_result.weight_name_map,
+            interpreter_result.requires_output_allocator,
+        ),
+    )
+    logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
+    return True
+
+
+def pull_cached_engine(
+    hash_val: str,
+    module: torch.fx.GraphModule,
+    engine_cache: BaseEngineCache,
+    settings: CompilationSettings,
+    inputs: Sequence[Input],
+) -> Optional[SerializedInterpreterResult]:
+    if not ENABLED_FEATURES.refit:
+        logger.info(
+            "Refit feature is not available, so the engine is not loaded from cache"
+        )
+        return None
+
+    # query the cached TRT engine
+    cached_data = engine_cache.check(hash_val)
+    if cached_data is not None:  # hit the cache
+        (
+            serialized_engine,  # weight-stripped engine
+            input_names,
+            output_names,
+            cached_engine_inputs,
+            cached_engine_compilation_settings,
+            weight_name_map,
+            requires_output_allocator,
+        ) = cached_data
+
+        setting_compatiblity, incompattible_settings = settings_are_compatible(
+            settings, cached_engine_compilation_settings
+        )
+        assert (
+            setting_compatiblity
+        ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
+
+        for i, e in enumerate(
+            [Input.equivalent_spec(c, i) for c, i in zip(cached_engine_inputs, inputs)]
+        ):
+            assert (
+                e
+            ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
+
+        logger.info(
+            "Found the cached engine that corresponds to this graph. It is directly loaded."
+        )
+
+        # refit the cached engine with the new graph module
+        if not settings.strip_engine_weights:
+            runtime = trt.Runtime(TRT_LOGGER)
+            engine = runtime.deserialize_cuda_engine(
+                serialized_engine
+            )  # weight-stripped engine
+
+            from torch_tensorrt.dynamo._refit import (
+                _refit_single_trt_engine_with_gm,
+            )
+
+            # weight-stripped engine --in place--> weight-included engine
+            _refit_single_trt_engine_with_gm(
+                new_gm=module,
+                old_engine=engine,
+                input_list=inputs,
+                settings=settings,
+                weight_name_map=weight_name_map,
+            )
+
+            # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
+            serialization_config = engine.create_serialization_config()
+            serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
+            serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
+            serialized_engine = engine.serialize_with_config(serialization_config)
+            # Start from here, the engine is weight-included and refittable
+
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(serialized_engine)
+            serialized_engine = engine_bytes.getvalue()
+
+        return SerializedInterpreterResult(
+            serialized_engine=serialized_engine,
+            input_names=input_names,
+            output_names=output_names,
+            weight_name_map=weight_name_map,
+            requires_output_allocator=requires_output_allocator,
+        )
+    return None
+
+
 def interpret_module_to_result(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
@@ -68,127 +196,6 @@ def interpret_module_to_result(
         SerializedInterpreterResult
     """

-    def _insert_engine_to_cache(
-        hash_val: str, interpreter_result: TRTInterpreterResult
-    ) -> bool:
-        if not ENABLED_FEATURES.refit:
-            logger.info("Refit feature is not available, so the engine is not cached")
-            return False
-
-        # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
-        if engine_cache.check(hash_val) is not None:  # type: ignore[union-attr]
-            logger.info(f"The engine already exists in cache for hash: {hash_val}")
-            return False
-
-        if not settings.strip_engine_weights:
-            # set EXCLUDE_WEIGHTS flag to strip weights
-            serialization_config = (
-                interpreter_result.engine.create_serialization_config()
-            )
-            serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
-            weight_stripped_serialized_engine = (
-                interpreter_result.engine.serialize_with_config(serialization_config)
-            )
-        else:
-            weight_stripped_serialized_engine = interpreter_result.engine.serialize()
-
-        # Insert weight-stripped engine to cache
-        engine_cache.insert(  # type: ignore[union-attr]
-            hash_val,
-            (
-                weight_stripped_serialized_engine,
-                interpreter_result.input_names,
-                interpreter_result.output_names,
-                inputs,
-                settings,
-                interpreter_result.weight_name_map,
-                interpreter_result.requires_output_allocator,
-            ),
-        )
-        logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
-        return True
-
-    def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
-        if not ENABLED_FEATURES.refit:
-            logger.info(
-                "Refit feature is not available, so the engine is not loaded from cache"
-            )
-            return None
-
-        # query the cached TRT engine
-        cached_data = engine_cache.check(hash_val)  # type: ignore[union-attr]
-        if cached_data is not None:  # hit the cache
-            (
-                serialized_engine,  # weight-stripped engine
-                input_names,
-                output_names,
-                cached_engine_inputs,
-                cached_engine_compilation_settings,
-                weight_name_map,
-                requires_output_allocator,
-            ) = cached_data
-
-            setting_compatiblity, incompattible_settings = settings_are_compatible(
-                settings, cached_engine_compilation_settings
-            )
-            assert (
-                setting_compatiblity
-            ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
-
-            for i, e in enumerate(
-                [
-                    Input.equivalent_spec(c, i)
-                    for c, i in zip(cached_engine_inputs, inputs)
-                ]
-            ):
-                assert (
-                    e
-                ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
-
-            logger.info(
-                "Found the cached engine that corresponds to this graph. It is directly loaded."
-            )
-
-            # refit the cached engine with the new graph module
-            if not settings.strip_engine_weights:
-                runtime = trt.Runtime(TRT_LOGGER)
-                engine = runtime.deserialize_cuda_engine(
-                    serialized_engine
-                )  # weight-stripped engine
-
-                from torch_tensorrt.dynamo._refit import (
-                    _refit_single_trt_engine_with_gm,
-                )
-
-                # weight-stripped engine --in place--> weight-included engine
-                _refit_single_trt_engine_with_gm(
-                    new_gm=module,
-                    old_engine=engine,
-                    input_list=inputs,
-                    settings=settings,
-                    weight_name_map=weight_name_map,
-                )
-
-                # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
-                serialization_config = engine.create_serialization_config()
-                serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
-                serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
-                serialized_engine = engine.serialize_with_config(serialization_config)
-                # Start from here, the engine is weight-included and refittable
-
-            with io.BytesIO() as engine_bytes:
-                engine_bytes.write(serialized_engine)
-                serialized_engine = engine_bytes.getvalue()
-
-            return SerializedInterpreterResult(
-                serialized_engine=serialized_engine,
-                input_names=input_names,
-                output_names=output_names,
-                weight_name_map=weight_name_map,
-                requires_output_allocator=requires_output_allocator,
-            )
-        return None
-
     # engine_cache could be None if:
     # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
     # 2) both cache_built_engines and reuse_cached_engines are False
@@ -201,7 +208,9 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
         hash_val = engine_cache.get_hash(module, inputs, settings)

         if settings.reuse_cached_engines:
-            serialized_interpreter_result = _pull_cached_engine(hash_val)
+            serialized_interpreter_result = pull_cached_engine(
+                hash_val, module, engine_cache, settings, inputs
+            )
             if serialized_interpreter_result is not None:  # hit the cache
                 return serialized_interpreter_result

@@ -235,7 +244,9 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
         and settings.cache_built_engines
         and engine_cache is not None
     ):
-        _ = _insert_engine_to_cache(hash_val, interpreter_result)
+        _ = insert_engine_to_cache(
+            hash_val, interpreter_result, engine_cache, settings, inputs
+        )

     serialized_engine = interpreter_result.engine.serialize()
     with io.BytesIO() as engine_bytes:
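
Below is a minimal usage sketch, not part of this diff, showing how the now module-level helpers could be wired together outside interpret_module_to_result. The import paths for DiskEngineCache and CompilationSettings, the DiskEngineCache constructor signature, the helper name cache_round_trip, and the cache directory path are assumptions for illustration; only insert_engine_to_cache, pull_cached_engine, and engine_cache.get_hash come from the change above.

from typing import Optional, Sequence

import torch
from torch_tensorrt import Input
from torch_tensorrt.dynamo._engine_cache import DiskEngineCache  # assumed import path
from torch_tensorrt.dynamo._settings import CompilationSettings  # assumed import path
from torch_tensorrt.dynamo.conversion._conversion import (
    SerializedInterpreterResult,
    TRTInterpreterResult,
    insert_engine_to_cache,
    pull_cached_engine,
)


def cache_round_trip(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    interpreter_result: TRTInterpreterResult,  # produced by an earlier conversion
    cache_dir: str = "/tmp/torch_trt_engine_cache",  # hypothetical location
) -> Optional[SerializedInterpreterResult]:
    # Opt in to both storing and reusing cached engines.
    settings = CompilationSettings(cache_built_engines=True, reuse_cached_engines=True)
    engine_cache = DiskEngineCache(cache_dir)  # assumed constructor signature

    hash_val = engine_cache.get_hash(module, inputs, settings)

    # Try to reuse a previously cached (weight-stripped) engine first.
    cached = pull_cached_engine(hash_val, module, engine_cache, settings, inputs)
    if cached is not None:
        return cached

    # Cache miss: store the freshly built, weight-stripped engine for future runs.
    insert_engine_to_cache(hash_val, interpreter_result, engine_cache, settings, inputs)
    return None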