Skip to content

Commit 4d09e21

Browse files
IREE converter support in inference script (without full CLI validation)
1 parent 2ddb872 commit 4d09e21

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

src/inference/inference_iree.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
import traceback
5+
import tempfile
56
from pathlib import Path
67

78
import postprocessing_data as pp
@@ -17,6 +18,7 @@
1718
'iree_converter',
1819
'iree_auxiliary')))
1920
from compiler import IREECompiler # noqa: E402
21+
from converter import IREEConverter # noqa: E402
2022

2123
sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
2224
from logger_conf import configure_logger # noqa: E402
@@ -30,14 +32,45 @@
3032
sys.exit(1)
3133

3234

35+
def validate_cli_args(args):
    """Placeholder hook for validating parsed command-line arguments.

    The function inspects ``args.model`` but performs no checks yet: both the
    "model given" and "model missing" paths are intentionally empty, so the
    call always returns ``None`` without raising for any parsed namespace
    that has a ``model`` attribute.
    """
    if not args.model:
        # No model path supplied: nothing is validated for this case yet.
        return
    # Model path supplied: nothing is validated for this case yet either.
    return
40+
41+
3342
def cli_argument_parser():
3443
parser = argparse.ArgumentParser()
35-
44+
parser.add_argument('-f', '--source_framework',
45+
help='Source model framework (required for automatic conversion to MLIR)',
46+
type=str,
47+
choices=['onnx', 'pytorch'],
48+
dest='source_framework')
3649
parser.add_argument('-m', '--model',
37-
help='Path to .vmfb file with compiled model or .mlir.',
38-
required=True,
50+
help='Path to source framework model (.onnx, .pt),'
51+
'to file with compiled model (.vmfb)'
52+
'or MLIR (.mlir).',
3953
type=str,
4054
dest='model')
55+
parser.add_argument('-w', '--weights',
56+
help='Path to an .pth file with a trained weights.'
57+
'Availiable when source_framework=pytorch ',
58+
type=str,
59+
dest='model_weights')
60+
parser.add_argument('-tm', '--torch_module',
61+
help='Torch module with model architecture.'
62+
'Availiable when source_framework=pytorch',
63+
type=str,
64+
dest='torch_module')
65+
parser.add_argument('-mn', '--model_name',
66+
help='Model name.',
67+
type=str,
68+
dest='model_name')
69+
parser.add_argument('--onnx_opset_version',
70+
help='Path to an .onnx with a trained model.'
71+
'Availiable when source_framework=onnx',
72+
type=int,
73+
dest='onnx_opset_version')
4174
parser.add_argument('-fn', '--function_name',
4275
help='IREE module function name to execute.',
4376
required=True,
@@ -143,8 +176,25 @@ def cli_argument_parser():
143176
type=str,
144177
nargs=argparse.REMAINDER,
145178
default=[])
146-
147-
return parser.parse_args()
179+
args = parser.parse_args()
180+
validate_cli_args(args)
181+
return args
182+
183+
184+
def convert_model_to_mlir(model_path, model_weights, torch_module, model_name,
                          onnx_opset_version, source_framework, input_shape,
                          output_mlir):
    """Convert a source-framework model to MLIR through ``IREEConverter``.

    Every argument is forwarded unchanged, keyed by its own parameter name,
    in a single dictionary passed to ``IREEConverter.get_converter``. The
    converter returned by that call is then asked to ``convert_to_mlir()``.

    The function returns ``None``.
    """
    # Key order mirrors the dictionary the converter factory has always received.
    converter_params = dict(
        source_framework=source_framework,
        model_name=model_name,
        model_path=model_path,
        model_weights=model_weights,
        torch_module=torch_module,
        onnx_opset_version=onnx_opset_version,
        input_shape=input_shape,
        output_mlir=output_mlir,
    )
    IREEConverter.get_converter(converter_params).convert_to_mlir()
148198

149199

150200
def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
@@ -191,13 +241,33 @@ def create_iree_context_from_buffer(vmfb_buffer):
191241
raise
192242

193243

194-
def load_model(model_path, target_backend, opt_level, extra_compile_args):
244+
def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
               source_framework, input_shape, target_backend, opt_level, extra_compile_args):
    """Load a model into an IREE context, converting it to MLIR first if needed.

    If ``model_path`` is ``None`` or the text after its last ``.`` is neither
    ``vmfb`` nor ``mlir``, the model is first converted with
    ``convert_model_to_mlir`` into a temporary ``.mlir`` file, and that file is
    what gets passed to ``load_model_buffer``. Otherwise ``model_path`` is
    passed to ``load_model_buffer`` as is.

    The temporary file is always removed, including when conversion or buffer
    loading raises, so failed runs do not leave stray ``.mlir`` files behind.

    Parameters ``model_path``, ``model_weights``, ``torch_module``,
    ``model_name``, ``onnx_opset_version``, ``source_framework`` and
    ``input_shape`` are forwarded to ``convert_model_to_mlir``;
    ``target_backend``, ``opt_level`` and ``extra_compile_args`` are forwarded
    to ``load_model_buffer``.

    Returns:
        The result of ``create_iree_context_from_buffer`` for the loaded buffer.

    Raises:
        Whatever the conversion, buffer loading or context creation raise; the
        exception propagates after the temporary file has been cleaned up.
    """
    needs_conversion = (
        model_path is None
        or model_path.split('.')[-1] not in ['vmfb', 'mlir']
    )

    tmp_mlir_path = None
    if needs_conversion:
        # Only the unique file name is needed. The handle is closed right away
        # so that the converter can reopen the path itself.
        with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
            tmp_mlir_path = temp.name

    try:
        if tmp_mlir_path is not None:
            convert_model_to_mlir(model_path,
                                  model_weights,
                                  torch_module,
                                  model_name,
                                  onnx_opset_version,
                                  source_framework,
                                  input_shape,
                                  tmp_mlir_path)
            model_path = tmp_mlir_path

        vmfb_buffer = load_model_buffer(
            model_path,
            target_backend=target_backend,
            opt_level=opt_level,
            extra_compile_args=extra_compile_args
        )
    finally:
        # The file was created with delete=False, so it must be removed here on
        # both the success and the failure path.
        if tmp_mlir_path is not None and os.path.exists(tmp_mlir_path):
            os.remove(tmp_mlir_path)

    return create_iree_context_from_buffer(vmfb_buffer)
202272

203273

@@ -316,6 +386,12 @@ def main():
316386
log.info('Loading model')
317387
model_context = load_model(
318388
model_path=args.model,
389+
model_weights=args.model_weights,
390+
torch_module=args.torch_module,
391+
model_name=args.model_name,
392+
onnx_opset_version=args.onnx_opset_version,
393+
source_framework=args.source_framework,
394+
input_shape=args.input_shape,
319395
target_backend=args.target_backend,
320396
opt_level=args.opt_level,
321397
extra_compile_args=args.extra_compile_args

0 commit comments

Comments
 (0)