Skip to content

Commit a283d6b

Browse files
IREE fix problems and comments
1 parent 71f8774 commit a283d6b

File tree

7 files changed

+97
-68
lines changed

7 files changed

+97
-68
lines changed

src/inference/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1569,4 +1569,4 @@ python3 inference_iree.py \
15691569
[dgl]: https://www.dgl.ai/pages/start.html
15701570
[ogb]: https://ogb.stanford.edu/
15711571
[tensorflow-gpu]: https://www.tensorflow.org/install/pip
1572-
[iree]: https://iree.dev/
1572+
[iree]: https://iree.dev

src/inference/inference_iree.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import os
23
import sys
34
import traceback
45
from pathlib import Path
@@ -25,14 +26,13 @@
2526
try:
2627
import iree.runtime as ireert # noqa: E402
2728
except ImportError as e:
28-
log.error(f"IREE import error: {e}")
29+
log.error(f'IREE import error: {e}')
2930
sys.exit(1)
3031

3132

3233
def cli_argument_parser():
3334
parser = argparse.ArgumentParser()
3435

35-
3636
parser.add_argument('-m', '--model',
3737
help='Path to .vmfb file with compiled model or .mlir.',
3838
required=True,
@@ -129,12 +129,12 @@ def cli_argument_parser():
129129
nargs=3,
130130
dest='channel_swap')
131131
parser.add_argument('-tb', '--target_backend',
132-
help='Target backend, for example "llvm-cpu" for CPU.',
132+
help='Target backend, for example `llvm-cpu` for CPU.',
133133
default='llvm-cpu',
134134
type=str,
135135
dest='target_backend')
136136
parser.add_argument('--opt_level',
137-
help='The optimization level of the task extractions.',
137+
help='The optimization level of the compilation.',
138138
type=int,
139139
choices=[0, 1, 2, 3],
140140
default=2)
@@ -149,38 +149,67 @@ def cli_argument_parser():
149149

150150
def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
151151
try:
152-
log.info(f'Starting model compilation')
152+
log.info('Starting model compilation')
153153
return IREECompiler.compile(mlir_path, target_backend, opt_level, extra_compile_args)
154154
except Exception as e:
155-
log.error(f"Failed to compile MLIR: {e}")
155+
log.error(f'Failed to compile MLIR: {e}')
156156
raise
157157

158158

159-
def load_iree_model(vmfb_buffer):
159+
def load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
160+
if not os.path.exists(model_path):
161+
raise FileNotFoundError(f'Model file not found: {model_path}')
162+
163+
file_type = model_path.split('.')[-1]
164+
165+
if file_type == 'mlir':
166+
if target_backend is None:
167+
raise ValueError('target_backend is required for MLIR compilation')
168+
vmfb_buffer = compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
169+
elif file_type == 'vmfb':
170+
with open(model_path, 'rb') as f:
171+
vmfb_buffer = f.read()
172+
else:
173+
raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')
174+
175+
log.info(f'Successfully loaded model buffer from {model_path}')
176+
return vmfb_buffer
177+
178+
179+
def create_iree_context_from_buffer(vmfb_buffer):
160180
try:
161181
config = ireert.Config('local-task')
162-
163182
vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
164183
context = ireert.SystemContext(config=config)
165184
context.add_vm_module(vm_module)
166185

167-
log.info(f"Successfully loaded IREE model")
186+
log.info('Successfully created IREE context from buffer')
168187
return context
169188

170189
except Exception as e:
171-
log.error(f"Failed to load IREE model: {e}")
190+
log.error(f'Failed to create IREE context: {e}')
172191
raise
173192

174193

194+
def load_model(model_path, target_backend, opt_level, extra_compile_args):
195+
vmfb_buffer = load_model_buffer(
196+
model_path,
197+
target_backend=target_backend,
198+
opt_level=opt_level,
199+
extra_compile_args=extra_compile_args
200+
)
201+
return create_iree_context_from_buffer(vmfb_buffer)
202+
203+
175204
def get_inference_function(model_context, function_name):
176205
try:
177206
main_module = model_context.modules.module
178207
inference_func = main_module[function_name]
179-
log.info(f"Using function '{function_name}' for inference")
208+
log.info(f'Using function {function_name} for inference')
180209
return inference_func
181210

182211
except Exception as e:
183-
log.error(f"Failed to get inference function: {e}")
212+
log.error(f'Failed to get inference function: {e}')
184213
raise
185214

186215

@@ -196,7 +225,7 @@ def inference_iree(inference_func, number_iter, get_slice, test_duration):
196225
time_infer = loop_inference(number_iter, test_duration)(
197226
inference_iteration
198227
)(inference_func, get_slice)['time_infer']
199-
228+
200229
log.info('Inference completed')
201230
return result, time_infer
202231

@@ -215,7 +244,7 @@ def infer_slice(inference_func, slice_input):
215244
input_buffers = list()
216245
for input_ in slice_input:
217246
input_buffers.append(ireert.asdevicearray(device, input_))
218-
247+
219248
result = inference_func(*input_buffers)
220249

221250
if hasattr(result, 'to_host'):
@@ -230,7 +259,7 @@ def prepare_output(result, task):
230259
elif task == 'classification':
231260
if hasattr(result, 'to_host'):
232261
result = result.to_host()
233-
262+
234263
# Extract tensor from dict if needed
235264
if isinstance(result, dict):
236265
result_key = next(iter(result))
@@ -239,18 +268,18 @@ def prepare_output(result, task):
239268
else:
240269
logits = np.array(result)
241270
output_key = 'output'
242-
271+
243272
# Ensure correct shape (batch_size, num_classes)
244273
if logits.ndim == 1:
245274
logits = logits.reshape(1, -1)
246275
elif logits.ndim > 2:
247276
logits = logits.reshape(logits.shape[0], -1)
248-
277+
249278
# Apply softmax
250279
max_logits = np.max(logits, axis=-1, keepdims=True)
251280
exp_logits = np.exp(logits - max_logits)
252281
probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
253-
282+
254283
return {output_key: probabilities}
255284
else:
256285
raise ValueError(f'Unsupported task {task}')
@@ -270,7 +299,7 @@ def create_dict_for_transformer(args):
270299

271300
def main():
272301
args = cli_argument_parser()
273-
302+
274303
try:
275304
model_wrapper = IREEModelWrapper(args)
276305
data_transformer = IREETransformer(create_dict_for_transformer(args))
@@ -284,16 +313,13 @@ def main():
284313
target_device=args.target_backend
285314
)
286315

287-
file_type = args.model.split('.')[-1]
288-
if file_type == 'mlir':
289-
vmfb_buffer = compile_mlir(args.model, args.target_backend, args.opt_level, args.extra_compile_args)
290-
elif file_type == 'vmfb':
291-
with open(args.model, 'rb') as f:
292-
vmfb_buffer = f.read()
293-
else:
294-
raise ValueError(f'The file type {file_type} is not supported')
295-
296-
model_context = load_iree_model(vmfb_buffer)
316+
log.info('Loading model')
317+
model_context = load_model(
318+
model_path=args.model,
319+
target_backend=args.target_backend,
320+
opt_level=args.opt_level,
321+
extra_compile_args=args.extra_compile_args
322+
)
297323
inference_func = get_inference_function(model_context, args.function_name)
298324

299325
log.info(f'Preparing input data: {args.input}')
@@ -309,10 +335,10 @@ def main():
309335

310336
log.info('Computing performance metrics')
311337
inference_result = pp.calculate_performance_metrics_sync_mode(
312-
args.batch_size,
338+
args.batch_size,
313339
inference_time
314340
)
315-
341+
316342
report_writer.update_execution_results(**inference_result)
317343
report_writer.write_report(args.report_path)
318344

src/inference/io_model_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ class IREEModelWrapper(IOModelWrapper):
415415
def __init__(self, args):
416416
self._input_shapes = [args.input_shape]
417417
self._model_path = args.model
418-
418+
419419
def get_input_layer_names(self, model):
420420
return ['input']
421421

src/model_converters/iree_converter/README.md

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ This script converts model from `<source_framework>` to the IREE MLIR format.
2727
- `-w / --weights` is a path to a `.pth` file with trained weights for PyTorch models.
2828
- `-tm / --torch_module` is a module with the model architecture for PyTorch models. Default: `torchvision.models`.
2929
- `-is / --input_shape` is an input shape in the format BxWxHxC, where B is a batch size, W is an input tensor width, H is an input tensor height, C is an input tensor number of channels. Required for PyTorch models.
30-
- `--onnx_opset_version` is the ONNX opset version for ONNX models. Default: `18`.
31-
- `-o / --output_mlir` is path to save the MLIR file. Required.
30+
- `--onnx_opset_version` is an ONNX opset version for ONNX models. Default: `18`.
31+
- `-o / --output_mlir` is a path to save the MLIR file. Required.
3232

3333
### Parameter combinations
3434
#### For ONNX models:
@@ -44,21 +44,21 @@ Two loading methods are supported (mutually exclusive):
4444
- Optional: `--weights <path/to/weights.pth>`
4545

4646
### Examples of usage
47-
ONNX model conversion:
47+
ONNX model conversion ([source of the model efficientnet-b0.onnx](https://github.com/onnx/models/blob/main/Computer_Vision/efficientnet_b0_Opset17_timm/efficientnet_b0_Opset17.onnx)):
4848
```sh
4949
python3 iree_converter.py -f onnx -m efficientnet-b0.onnx \
5050
--onnx_opset_version 18 \
5151
-o ./output/efficientnet-b0.mlir
5252
```
5353

54-
PyTorch model from file:
54+
PyTorch model from file (a `.pt` file can be created by following this [tutorial](https://docs.pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules)):
5555
```sh
5656
python3 iree_converter.py -f pytorch -m resnet50.pt \
5757
-is 1 224 224 3 \
5858
-o ./output/resnet50.mlir
5959
```
6060

61-
PyTorch model from torchvision with pretrained weights:
61+
PyTorch model from [torchvision](https://docs.pytorch.org/vision/main/models.html) with pretrained weights:
6262
```sh
6363
python3 iree_converter.py -f pytorch -mn resnet50 \
6464
-tm torchvision.models \
@@ -88,19 +88,19 @@ iree_compiler.py --mlir <input.mlir> \
8888
This script compiles a model from `.mlir` format to the deployable binary format for the specified target backend.
8989

9090
### IREE compiler parameters
91-
- `-m / --mlir` - Path to an .mlir file with a model. Required.
92-
- `-tb / --target_backend` - Target backend for compilation. Required. Examples: `llvm-cpu`, `cuda`, `vulkan`, `vmvx`.
93-
- `--opt_level` - The optimization level of the compilation. Choices: `0`, `1`, `2`, `3`. Default: `2`.
94-
- `-o / --output_file` - Path to save the compiled model. Required.
95-
- `--extra_args` - Extra arguments for compilation. Optional.
91+
- `-m / --mlir` is a path to an .mlir file with a model. Required.
92+
- `-tb / --target_backend` is a target backend for compilation. Required. Examples: `llvm-cpu`, `cuda`, `vulkan`, `vmvx`.
93+
- `--opt_level` is an optimization level of the compilation. Choices: `0`, `1`, `2`, `3`. Default: `2`.
94+
- `-o / --output_file` is a path to save the compiled model. Required.
95+
- `--extra_args` is a set of extra arguments for compilation. Optional.
9696

9797
### Supported target backends
98-
- `llvm-cpu` - CPU execution using LLVM
99-
- `cuda` - NVIDIA GPU execution using CUDA
100-
- `vulkan` - GPU execution using Vulkan API
101-
- `vmvx` - Portable VM bytecode execution
102-
- `metal` - Apple GPU execution using Metal
103-
- `rocm` - AMD GPU execution using ROCm
98+
- `llvm-cpu` - CPU execution using LLVM.
99+
- `cuda` - NVIDIA GPU execution using CUDA.
100+
- `vulkan` - GPU execution using Vulkan API.
101+
- `vmvx` - Portable VM bytecode execution.
102+
- `metal` - Apple GPU execution using Metal.
103+
- `rocm` - AMD GPU execution using ROCm.
104104

105105
### Examples of usage
106106
```sh

src/model_converters/iree_converter/iree_auxiliary/onnx_format.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def source_framework(self):
1616

1717
def _validate_arguments(self):
1818
if self.model_path is None or self.model_path == '':
19-
raise ValueError("The model_path parameter is required for ONNX conversion.")
19+
raise ValueError('The model_path parameter is required for ONNX conversion.')
2020

2121
if not os.path.exists(self.model_path):
22-
raise FileNotFoundError(f"Model file not found: {self.model_path}")
22+
raise FileNotFoundError(f'Model file not found: {self.model_path}')
2323

2424
if self.onnx_opset_version is None:
25-
raise ValueError("The onnx_opset_version parameter is required for ONNX conversion.")
25+
raise ValueError('The onnx_opset_version parameter is required for ONNX conversion.')
2626

2727
def _convert_model_from_framework(self):
2828
if not os.path.exists(self.output_mlir):
@@ -36,5 +36,5 @@ def _convert_model_from_framework(self):
3636
self.output_mlir,
3737
]
3838
import_cmd = subprocess.list2cmdline(import_args)
39-
ret = subprocess.run(import_cmd, shell=True, capture_output=True)
39+
subprocess.run(import_cmd, shell=True, capture_output=True)
4040
return

src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,38 @@ def __init__(self, args):
1717
@property
1818
def source_framework(self):
1919
return 'PyTorch'
20-
20+
2121
def _validate_arguments(self):
2222
if self.input_shape is None:
23-
raise ValueError("The input_shape parameter is required for PyTorch conversion.")
24-
23+
raise ValueError('The input_shape parameter is required for PyTorch conversion.')
24+
2525
# Check load methods:
2626
# 1. model_path (load from file)
2727
# 2. module + model_name (load from torch module)
2828
has_model_path = self.model_path is not None and self.model_path != ''
29-
has_module_model = (self.module is not None and self.module != '' and
30-
self.model_name is not None and self.model_name != '')
31-
29+
has_module_model = (self.module is not None
30+
and self.module != ''
31+
and self.model_name is not None
32+
and self.model_name != '')
33+
3234
if not has_model_path and not has_module_model:
3335
raise ValueError(
34-
"For PyTorch conversion, you must specify either model_path, "
35-
"or torch_module and model_name"
36+
'For PyTorch conversion, you must specify either model_path, \
37+
or torch_module and model_name'
3638
)
37-
39+
3840
if has_model_path and has_module_model:
3941
raise ValueError(
40-
"Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). "
41-
"Please choose only one method of this."
42+
'Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). \
43+
Please choose only one method of this.'
4244
)
4345

4446
if has_model_path and not os.path.exists(self.model_path):
45-
raise FileNotFoundError(f"Model file not found: {self.model_path}")
47+
raise FileNotFoundError(f'Model file not found: {self.model_path}')
4648

47-
if (self.model_weights is not None and self.model_weights != '' and
48-
not os.path.exists(self.model_weights)):
49-
raise FileNotFoundError(f"Model weights not found: {self.model_weights}")
49+
if (self.model_weights is not None and self.model_weights != ''
50+
and not os.path.exists(self.model_weights)):
51+
raise FileNotFoundError(f'Model weights not found: {self.model_weights}')
5052

5153
def __get_model_from_path(self):
5254
self.log.info(f'Loading model from path {self.model_path}')

src/model_converters/iree_converter/iree_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
log = configure_logger()
1212

13+
1314
def cli_argument_parser():
1415
parser = argparse.ArgumentParser()
1516
parser.add_argument('-m', '--mlir',

0 commit comments

Comments
 (0)