Skip to content

Commit 2773276

Browse files
committed
fix comments
1 parent ea81677 commit 2773276

File tree

1 file changed

+134
-123
lines changed

1 file changed

+134
-123
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 134 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,134 @@ def infer_module_output_dtypes(
4747
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
4848

4949

50+
def insert_engine_to_cache(
    hash_val: str,
    interpreter_result: TRTInterpreterResult,
    engine_cache: BaseEngineCache,
    settings: CompilationSettings,
    inputs: Sequence[Input],
) -> bool:
    """Store a weight-stripped copy of the built TRT engine in the engine cache.

    The cached entry is always weight-stripped (weights are re-attached via refit
    on a later cache pull), regardless of the ``strip_engine_weights`` setting.

    Args:
        hash_val: Cache key computed from the module, inputs, and settings.
        interpreter_result: Result of TRT interpretation holding the built engine.
        engine_cache: Cache backend to insert into.
        settings: Compilation settings active for this build.
        inputs: Input specs the engine was built for (stored for later validation).

    Returns:
        True if the engine was inserted; False if refit is unavailable or an
        entry for ``hash_val`` already exists.
    """
    # Refit support is required to restore weights when pulling from cache,
    # so without it caching a stripped engine would be useless.
    if not ENABLED_FEATURES.refit:
        logger.info("Refit feature is not available, so the engine is not cached")
        return False

    # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
    if engine_cache.check(hash_val) is not None:
        logger.info(f"The engine already exists in cache for hash: {hash_val}")
        return False

    if settings.strip_engine_weights:
        # Engine was already built without weights; serialize it as-is.
        stripped_engine_bytes = interpreter_result.engine.serialize()
    else:
        # set EXCLUDE_WEIGHTS flag to strip weights
        serialization_config = interpreter_result.engine.create_serialization_config()
        serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
        stripped_engine_bytes = interpreter_result.engine.serialize_with_config(
            serialization_config
        )

    # Insert weight-stripped engine to cache, alongside everything needed to
    # rebuild a SerializedInterpreterResult on a later pull.
    cache_entry = (
        stripped_engine_bytes,
        interpreter_result.input_names,
        interpreter_result.output_names,
        inputs,
        settings,
        interpreter_result.weight_name_map,
        interpreter_result.requires_output_allocator,
    )
    engine_cache.insert(hash_val, cache_entry)
    logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
    return True
91+
92+
93+
def pull_cached_engine(
    hash_val: str,
    module: torch.fx.GraphModule,
    engine_cache: BaseEngineCache,
    settings: CompilationSettings,
    inputs: Sequence[Input],
) -> Optional[SerializedInterpreterResult]:
    """Load a cached TRT engine for ``hash_val`` and refit it with the module's weights.

    Cached engines are stored weight-stripped. Unless ``settings.strip_engine_weights``
    is set, the engine is deserialized and refit in place with the weights of
    ``module`` before being returned.

    Args:
        hash_val: Cache key computed from the module, inputs, and settings.
        module: Graph module supplying the weights used for refitting.
        engine_cache: Cache backend to query.
        settings: Current compilation settings; must be compatible with the
            settings the cached engine was built with.
        inputs: Current input specs; must match the cached engine's input specs.

    Returns:
        A ``SerializedInterpreterResult`` on a cache hit, or ``None`` if refit is
        unavailable or there is no cached entry for ``hash_val``.

    Raises:
        AssertionError: If the cached engine's settings or input specs are
            incompatible with the current ones.
    """
    if not ENABLED_FEATURES.refit:
        logger.info(
            "Refit feature is not available, so the engine is not loaded from cache"
        )
        return None

    # query the cached TRT engine
    cached_data = engine_cache.check(hash_val)
    if cached_data is None:  # cache miss
        return None

    (
        serialized_engine,  # weight-stripped engine
        input_names,
        output_names,
        cached_engine_inputs,
        cached_engine_compilation_settings,
        weight_name_map,
        requires_output_allocator,
    ) = cached_data

    # A cached engine may only be reused when its build-time settings are
    # compatible with the current compilation settings.
    # (Fixed: locals were previously misspelled `setting_compatiblity` /
    # `incompattible_settings`.)
    settings_compatible, incompatible_settings = settings_are_compatible(
        settings, cached_engine_compilation_settings
    )
    assert (
        settings_compatible
    ), f"Attempted to refit a cached engine with incompatible settings: {incompatible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"

    # Validate each current input spec against the spec the engine was built for.
    # (Fixed: previous comprehension shadowed the enumerate index `i`, and the
    # error message was missing its closing parenthesis.)
    for idx, (cached_input, new_input) in enumerate(zip(cached_engine_inputs, inputs)):
        assert Input.equivalent_spec(
            cached_input, new_input
        ), f"Attempted to refit a cached engine built for a different input size (input: {idx}, cached size: {cached_input}, new size: {new_input})"

    logger.info(
        "Found the cached engine that corresponds to this graph. It is directly loaded."
    )

    # refit the cached engine with the new graph module
    if not settings.strip_engine_weights:
        runtime = trt.Runtime(TRT_LOGGER)
        engine = runtime.deserialize_cuda_engine(
            serialized_engine
        )  # weight-stripped engine

        # Imported lazily to avoid a circular import with the refit module.
        from torch_tensorrt.dynamo._refit import (
            _refit_single_trt_engine_with_gm,
        )

        # weight-stripped engine --in place--> weight-included engine
        _refit_single_trt_engine_with_gm(
            new_gm=module,
            old_engine=engine,
            input_list=inputs,
            settings=settings,
            weight_name_map=weight_name_map,
        )

        # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
        serialization_config = engine.create_serialization_config()
        serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
        serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
        serialized_engine = engine.serialize_with_config(serialization_config)
        # Start from here, the engine is weight-included and refittable

    # Convert the serialized engine buffer (presumably a TRT IHostMemory
    # object — confirm) into plain Python bytes.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(serialized_engine)
        serialized_engine = engine_bytes.getvalue()

    return SerializedInterpreterResult(
        serialized_engine=serialized_engine,
        input_names=input_names,
        output_names=output_names,
        weight_name_map=weight_name_map,
        requires_output_allocator=requires_output_allocator,
    )
176+
177+
50178
def interpret_module_to_result(
51179
module: torch.fx.GraphModule,
52180
inputs: Sequence[Input],
@@ -68,127 +196,6 @@ def interpret_module_to_result(
68196
SerializedInterpreterResult
69197
"""
70198

71-
def _insert_engine_to_cache(
72-
hash_val: str, interpreter_result: TRTInterpreterResult
73-
) -> bool:
74-
if not ENABLED_FEATURES.refit:
75-
logger.info("Refit feature is not available, so the engine is not cached")
76-
return False
77-
78-
# Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
79-
if engine_cache.check(hash_val) is not None: # type: ignore[union-attr]
80-
logger.info(f"The engine already exists in cache for hash: {hash_val}")
81-
return False
82-
83-
if not settings.strip_engine_weights:
84-
# set EXCLUDE_WEIGHTS flag to strip weights
85-
serialization_config = (
86-
interpreter_result.engine.create_serialization_config()
87-
)
88-
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
89-
weight_stripped_serialized_engine = (
90-
interpreter_result.engine.serialize_with_config(serialization_config)
91-
)
92-
else:
93-
weight_stripped_serialized_engine = interpreter_result.engine.serialize()
94-
95-
# Insert weight-stripped engine to cache
96-
engine_cache.insert( # type: ignore[union-attr]
97-
hash_val,
98-
(
99-
weight_stripped_serialized_engine,
100-
interpreter_result.input_names,
101-
interpreter_result.output_names,
102-
inputs,
103-
settings,
104-
interpreter_result.weight_name_map,
105-
interpreter_result.requires_output_allocator,
106-
),
107-
)
108-
logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
109-
return True
110-
111-
def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
112-
if not ENABLED_FEATURES.refit:
113-
logger.info(
114-
"Refit feature is not available, so the engine is not loaded from cache"
115-
)
116-
return None
117-
118-
# query the cached TRT engine
119-
cached_data = engine_cache.check(hash_val) # type: ignore[union-attr]
120-
if cached_data is not None: # hit the cache
121-
(
122-
serialized_engine, # weight-stripped engine
123-
input_names,
124-
output_names,
125-
cached_engine_inputs,
126-
cached_engine_compilation_settings,
127-
weight_name_map,
128-
requires_output_allocator,
129-
) = cached_data
130-
131-
setting_compatiblity, incompattible_settings = settings_are_compatible(
132-
settings, cached_engine_compilation_settings
133-
)
134-
assert (
135-
setting_compatiblity
136-
), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
137-
138-
for i, e in enumerate(
139-
[
140-
Input.equivalent_spec(c, i)
141-
for c, i in zip(cached_engine_inputs, inputs)
142-
]
143-
):
144-
assert (
145-
e
146-
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
147-
148-
logger.info(
149-
"Found the cached engine that corresponds to this graph. It is directly loaded."
150-
)
151-
152-
# refit the cached engine with the new graph module
153-
if not settings.strip_engine_weights:
154-
runtime = trt.Runtime(TRT_LOGGER)
155-
engine = runtime.deserialize_cuda_engine(
156-
serialized_engine
157-
) # weight-stripped engine
158-
159-
from torch_tensorrt.dynamo._refit import (
160-
_refit_single_trt_engine_with_gm,
161-
)
162-
163-
# weight-stripped engine --in place--> weight-included engine
164-
_refit_single_trt_engine_with_gm(
165-
new_gm=module,
166-
old_engine=engine,
167-
input_list=inputs,
168-
settings=settings,
169-
weight_name_map=weight_name_map,
170-
)
171-
172-
# EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
173-
serialization_config = engine.create_serialization_config()
174-
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
175-
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
176-
serialized_engine = engine.serialize_with_config(serialization_config)
177-
# Start from here, the engine is weight-included and refittable
178-
179-
with io.BytesIO() as engine_bytes:
180-
engine_bytes.write(serialized_engine)
181-
serialized_engine = engine_bytes.getvalue()
182-
183-
return SerializedInterpreterResult(
184-
serialized_engine=serialized_engine,
185-
input_names=input_names,
186-
output_names=output_names,
187-
weight_name_map=weight_name_map,
188-
requires_output_allocator=requires_output_allocator,
189-
)
190-
return None
191-
192199
# engine_cache could be None if:
193200
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
194201
# 2) both cache_built_engines and reuse_cached_engines are False
@@ -201,7 +208,9 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
201208
hash_val = engine_cache.get_hash(module, inputs, settings)
202209

203210
if settings.reuse_cached_engines:
204-
serialized_interpreter_result = _pull_cached_engine(hash_val)
211+
serialized_interpreter_result = pull_cached_engine(
212+
hash_val, module, engine_cache, settings, inputs
213+
)
205214
if serialized_interpreter_result is not None: # hit the cache
206215
return serialized_interpreter_result
207216

@@ -235,7 +244,9 @@ def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
235244
and settings.cache_built_engines
236245
and engine_cache is not None
237246
):
238-
_ = _insert_engine_to_cache(hash_val, interpreter_result)
247+
_ = insert_engine_to_cache(
248+
hash_val, interpreter_result, engine_cache, settings, inputs
249+
)
239250

240251
serialized_engine = interpreter_result.engine.serialize()
241252
with io.BytesIO() as engine_bytes:

0 commit comments

Comments
 (0)