|
2 | 2 | import os |
3 | 3 | import sys |
4 | 4 | import traceback |
| 5 | +import tempfile |
5 | 6 | from pathlib import Path |
6 | 7 |
|
7 | 8 | import postprocessing_data as pp |
|
17 | 18 | 'iree_converter', |
18 | 19 | 'iree_auxiliary'))) |
19 | 20 | from compiler import IREECompiler # noqa: E402 |
| 21 | +from converter import IREEConverter # noqa: E402 |
20 | 22 |
|
21 | 23 | sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils'))) |
22 | 24 | from logger_conf import configure_logger # noqa: E402 |
|
30 | 32 | sys.exit(1) |
31 | 33 |
|
32 | 34 |
|
| 35 | +def validate_cli_args(args): |
| 36 | + if args.model: |
| 37 | + pass |
| 38 | + else: |
| 39 | + pass |
| 40 | + |
| 41 | + |
33 | 42 | def cli_argument_parser(): |
34 | 43 | parser = argparse.ArgumentParser() |
35 | | - |
| 44 | + parser.add_argument('-f', '--source_framework', |
| 45 | + help='Source model framework (required for automatic conversion to MLIR)', |
| 46 | + type=str, |
| 47 | + choices=['onnx', 'pytorch'], |
| 48 | + dest='source_framework') |
36 | 49 | parser.add_argument('-m', '--model', |
37 | | - help='Path to .vmfb file with compiled model or .mlir.', |
38 | | - required=True, |
| 50 | + help='Path to source framework model (.onnx, .pt),' |
| 51 | + 'to file with compiled model (.vmfb)' |
| 52 | + 'or MLIR (.mlir).', |
39 | 53 | type=str, |
40 | 54 | dest='model') |
| 55 | + parser.add_argument('-w', '--weights', |
| 56 | + help='Path to an .pth file with a trained weights.' |
| 57 | + 'Availiable when source_framework=pytorch ', |
| 58 | + type=str, |
| 59 | + dest='model_weights') |
| 60 | + parser.add_argument('-tm', '--torch_module', |
| 61 | + help='Torch module with model architecture.' |
| 62 | + 'Availiable when source_framework=pytorch', |
| 63 | + type=str, |
| 64 | + dest='torch_module') |
| 65 | + parser.add_argument('-mn', '--model_name', |
| 66 | + help='Model name.', |
| 67 | + type=str, |
| 68 | + dest='model_name') |
| 69 | + parser.add_argument('--onnx_opset_version', |
| 70 | + help='Path to an .onnx with a trained model.' |
| 71 | + 'Availiable when source_framework=onnx', |
| 72 | + type=int, |
| 73 | + dest='onnx_opset_version') |
41 | 74 | parser.add_argument('-fn', '--function_name', |
42 | 75 | help='IREE module function name to execute.', |
43 | 76 | required=True, |
@@ -143,8 +176,25 @@ def cli_argument_parser(): |
143 | 176 | type=str, |
144 | 177 | nargs=argparse.REMAINDER, |
145 | 178 | default=[]) |
146 | | - |
147 | | - return parser.parse_args() |
| 179 | + args = parser.parse_args() |
| 180 | + validate_cli_args(args) |
| 181 | + return args |
| 182 | + |
| 183 | + |
| 184 | +def convert_model_to_mlir(model_path, model_weights, torch_module, model_name, onnx_opset_version, source_framework, input_shape, output_mlir): |
| 185 | + dictionary = { |
| 186 | + 'source_framework': source_framework, |
| 187 | + 'model_name': model_name, |
| 188 | + 'model_path': model_path, |
| 189 | + 'model_weights': model_weights, |
| 190 | + 'torch_module': torch_module, |
| 191 | + 'onnx_opset_version': onnx_opset_version, |
| 192 | + 'input_shape': input_shape, |
| 193 | + 'output_mlir': output_mlir |
| 194 | + } |
| 195 | + converter = IREEConverter.get_converter(dictionary) |
| 196 | + converter.convert_to_mlir() |
| 197 | + return |
148 | 198 |
|
149 | 199 |
|
150 | 200 | def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args): |
@@ -191,13 +241,33 @@ def create_iree_context_from_buffer(vmfb_buffer): |
191 | 241 | raise |
192 | 242 |
|
193 | 243 |
|
194 | | -def load_model(model_path, target_backend, opt_level, extra_compile_args): |
| 244 | +def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version, |
| 245 | + source_framework, input_shape, target_backend, opt_level, extra_compile_args): |
| 246 | + is_tmp_mlir = False |
| 247 | + if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']: |
| 248 | + with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp: |
| 249 | + output_mlir = temp.name |
| 250 | + convert_model_to_mlir(model_path, |
| 251 | + model_weights, |
| 252 | + torch_module, |
| 253 | + model_name, |
| 254 | + onnx_opset_version, |
| 255 | + source_framework, |
| 256 | + input_shape, |
| 257 | + output_mlir) |
| 258 | + model_path = output_mlir |
| 259 | + is_tmp_mlir = True |
| 260 | + |
195 | 261 | vmfb_buffer = load_model_buffer( |
196 | 262 | model_path, |
197 | 263 | target_backend=target_backend, |
198 | 264 | opt_level=opt_level, |
199 | 265 | extra_compile_args=extra_compile_args |
200 | 266 | ) |
| 267 | + |
| 268 | + if is_tmp_mlir: |
| 269 | + os.remove(model_path) |
| 270 | + |
201 | 271 | return create_iree_context_from_buffer(vmfb_buffer) |
202 | 272 |
|
203 | 273 |
|
@@ -316,6 +386,12 @@ def main(): |
316 | 386 | log.info('Loading model') |
317 | 387 | model_context = load_model( |
318 | 388 | model_path=args.model, |
| 389 | + model_weights=args.model_weights, |
| 390 | + torch_module=args.torch_module, |
| 391 | + model_name=args.model_name, |
| 392 | + onnx_opset_version=args.onnx_opset_version, |
| 393 | + source_framework=args.source_framework, |
| 394 | + input_shape=args.input_shape, |
319 | 395 | target_backend=args.target_backend, |
320 | 396 | opt_level=args.opt_level, |
321 | 397 | extra_compile_args=args.extra_compile_args |
|
0 commit comments