19 changes: 19 additions & 0 deletions docker/IREE/Dockerfile
@@ -0,0 +1,19 @@
FROM ubuntu_for_dli

# Install IREE
ARG IREE_VERSION=3.8.0
RUN python3 -m pip install iree-base-compiler==${IREE_VERSION} iree-base-runtime==${IREE_VERSION} iree-turbine==${IREE_VERSION}

# Install dependencies
RUN python3 -m pip install opencv-python numpy

# Install onnx for model conversion
ARG ONNX_VERSION=1.19.1
RUN python3 -m pip install onnx==${ONNX_VERSION}

# Install torch for model conversion
ARG TORCH_VERSION=2.9.1
ARG TORCHVISION_VERSION=0.24.1
RUN python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION}

WORKDIR /tmp/
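
A minimal smoke-test sketch for an image built from this Dockerfile, assuming `iree.compiler` and `iree.runtime` are the import names shipped by `iree-base-compiler` and `iree-base-runtime`:

```python
# Smoke test: run inside a container built from this Dockerfile to confirm
# the IREE, ONNX and torch packages installed above are importable.
import iree.compiler as ireec   # from iree-base-compiler
import iree.runtime as ireert   # from iree-base-runtime
import onnx
import torch

print('iree.compiler:', getattr(ireec, '__version__', 'imported'))
print('iree.runtime :', getattr(ireert, '__version__', 'imported'))
print('onnx:', onnx.__version__, '| torch:', torch.__version__)
```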
4 changes: 4 additions & 0 deletions requirements_frameworks.txt
@@ -23,3 +23,7 @@ paddlepaddle==2.6.0

ncnn
spektral==1.3.0

iree-base-compiler
iree-base-runtime
iree-turbine
2 changes: 1 addition & 1 deletion src/accuracy_checker/process.py
@@ -25,7 +25,7 @@ def execute(self, idx):
command_line = self.__fill_command_line()
if command_line == '':
self.__log.error('Command line is empty')
self.__log.info(f'Start accuracy check for {idx+1} test: {self._test.model.name}')
self.__log.info(f'Start accuracy check for {idx + 1} test: {self._test.model.name}')
self.__log.info(f'Command line is : {command_line}')
self._executor.set_target_framework(self._test.framework)
command_line = self._executor.prepare_command_line(self._test, command_line)
2 changes: 2 additions & 0 deletions src/benchmark/README.md
@@ -22,6 +22,7 @@ the following frameworks:
- [RKNN][rknn].
- [Spektral][spektral] (Python API).
- [PaddlePaddle][paddlepaddle] (Python API).
- [IREE][iree] (Python API).

### Implemented algorithm

@@ -274,3 +275,4 @@ pip install openvino_dev[mxnet,caffe,caffe2,onnx,pytorch,tensorflow2]==<your ver
[rknn]: https://github.com/rockchip-linux/rknpu2
[spektral]: https://graphneural.network
[paddlepaddle]: https://www.paddlepaddle.org.cn/en
[iree]: https://iree.dev
3 changes: 3 additions & 0 deletions src/benchmark/config_parser_factory.py
@@ -14,6 +14,7 @@
from frameworks.ncnn.ncnn_parameters_parser import NcnnParametersParser
from frameworks.spektral.spektral_parameters_parser import SpektralParametersParser
from frameworks.executorch.executorch_parameters_parser import ExecuTorchParametersParser
from frameworks.iree.iree_parameters_parser import IREEParametersParser


def get_parameters_parser(framework):
@@ -57,4 +58,6 @@ def get_parameters_parser(framework):
return CppParametersParser()
if framework == KnownFrameworks.executorch:
return ExecuTorchParametersParser()
if framework == KnownFrameworks.iree:
return IREEParametersParser()
raise NotImplementedError(f'Unknown framework {framework}')
4 changes: 3 additions & 1 deletion src/benchmark/frameworks/config_parser/test_reporter.py
@@ -77,7 +77,9 @@ def prepare_framework_params(self):
match_parameter_description['compile_with_backend'] = 'Pytorch compile backend'

match_parameter_description['high_level_api'] = 'TVM HighLevelAPI'
match_parameter_description['opt_level'] = 'TVM OptimizationLevel'
match_parameter_description['opt_level'] = 'Optimization level'

match_parameter_description['extra_compile_args'] = 'Extra compile args'

for parameter, description in match_parameter_description.items():
if hasattr(self.dep_parameters, parameter) and getattr(self.dep_parameters, parameter) is not None:
2 changes: 2 additions & 0 deletions src/benchmark/frameworks/framework_wrapper_registry.py
@@ -21,6 +21,7 @@
from .rknn.rknn_wrapper import RknnWrapper
from .executorch_cpp.executorch_cpp_wrapper import ExecuTorchCppWrapper
from .executorch.executorch_wrapper import ExecuTorchWrapper
from .iree.iree_wrapper import IREEWrapper


class FrameworkWrapperRegistry(metaclass=Singleton):
@@ -62,3 +63,4 @@ def _get_wrappers(self):
self._framework_wrappers[RknnWrapper.framework_name] = RknnWrapper()
self._framework_wrappers[ExecuTorchCppWrapper.framework_name] = ExecuTorchCppWrapper()
self._framework_wrappers[ExecuTorchWrapper.framework_name] = ExecuTorchWrapper()
self._framework_wrappers[IREEWrapper.framework_name] = IREEWrapper()
Empty file.
84 changes: 84 additions & 0 deletions src/benchmark/frameworks/iree/iree_parameters_parser.py
@@ -0,0 +1,84 @@
from ..config_parser.dependent_parameters_parser import DependentParametersParser
from ..config_parser.framework_parameters_parser import FrameworkParameters


class IREEParametersParser(DependentParametersParser):
CONFIG_FRAMEWORK_DEPENDENT_TAG = 'FrameworkDependent'
TAG_FUNCTION_NAME = 'FunctionName'
TAG_INPUT_SHAPE = 'InputShape'
TAG_LAYOUT = 'Layout'
TAG_NORMALIZE = 'Normalize'
TAG_MEAN = 'Mean'
TAG_STD = 'Std'
TAG_CHANNEL_SWAP = 'ChannelSwap'
TAG_TARGET_BACKEND = 'TargetBackend'
TAG_OPTIMIZATION_LEVEL = 'OptimizationLevel'
TAG_ONNX_OPSET = 'OnnxOpsetVersion'
TAG_EXTRA_COMPILE_ARGS = 'ExtraCompileArgs'

def parse_parameters(self, curr_test):
dep_parameters_tag = curr_test.getElementsByTagName(self.CONFIG_FRAMEWORK_DEPENDENT_TAG)[0]

def _read_tag(tag_name):
tag_nodes = dep_parameters_tag.getElementsByTagName(tag_name)
if not tag_nodes:
return None
node = tag_nodes[0].firstChild
return node.data.strip() if node else None

return IREEParameters(
function_name=_read_tag(self.TAG_FUNCTION_NAME),
input_shape=_read_tag(self.TAG_INPUT_SHAPE),
layout=_read_tag(self.TAG_LAYOUT),
normalize=_read_tag(self.TAG_NORMALIZE),
mean=_read_tag(self.TAG_MEAN),
std=_read_tag(self.TAG_STD),
channel_swap=_read_tag(self.TAG_CHANNEL_SWAP),
target_backend=_read_tag(self.TAG_TARGET_BACKEND),
optimization_level=_read_tag(self.TAG_OPTIMIZATION_LEVEL),
onnx_opset_version=_read_tag(self.TAG_ONNX_OPSET),
extra_compile_args=_read_tag(self.TAG_EXTRA_COMPILE_ARGS),
)


class IREEParameters(FrameworkParameters):
def __init__(self, function_name, input_shape, layout, normalize, mean, std, channel_swap,
target_backend, optimization_level, onnx_opset_version, extra_compile_args):
self.function_name = None
self.input_shape = None
self.layout = 'NHWC'
self.normalize = None
self.mean = None
self.std = None
self.channel_swap = None
self.target_backend = 'llvm-cpu'
self.opt_level = '2'
self.onnx_opset_version = None
self.extra_compile_args = None

if not self._parameter_is_not_none(function_name):
raise ValueError('FunctionName is a required parameter for IREE benchmark tests.')
self.function_name = function_name

if not self._parameter_is_not_none(input_shape):
raise ValueError('InputShape is a required parameter for IREE benchmark tests.')
self.input_shape = input_shape

if self._parameter_is_not_none(layout):
self.layout = layout
if self._parameter_is_not_none(normalize):
self.normalize = normalize
if self._parameter_is_not_none(mean):
self.mean = mean
if self._parameter_is_not_none(std):
self.std = std
if self._parameter_is_not_none(channel_swap):
self.channel_swap = channel_swap
if self._parameter_is_not_none(target_backend):
self.target_backend = target_backend
if self._parameter_is_not_none(optimization_level):
self.opt_level = optimization_level
if self._parameter_is_not_none(onnx_opset_version):
self.onnx_opset_version = onnx_opset_version
if self._parameter_is_not_none(extra_compile_args):
self.extra_compile_args = extra_compile_args
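
For reference, a sketch of the `<FrameworkDependent>` fragment this parser reads and how `parse_parameters` consumes it; the tag values (function name, shape string) are hypothetical and only illustrate the expected structure.

```python
# Hypothetical config fragment, parsed with the same xml.dom API the parser uses.
from xml.dom import minidom

sample_test = minidom.parseString("""
<Test>
  <FrameworkDependent>
    <FunctionName>main</FunctionName>
    <InputShape>1 224 224 3</InputShape>
    <TargetBackend>llvm-cpu</TargetBackend>
    <OptimizationLevel>2</OptimizationLevel>
  </FrameworkDependent>
</Test>
""").documentElement

params = IREEParametersParser().parse_parameters(sample_test)
# Tags that are absent fall back to the defaults set in IREEParameters,
# e.g. layout stays 'NHWC' here.
print(params.function_name, params.input_shape, params.target_backend,
      params.opt_level, params.layout)  # main 1 224 224 3 llvm-cpu 2 NHWC
```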
124 changes: 124 additions & 0 deletions src/benchmark/frameworks/iree/iree_process.py
@@ -0,0 +1,124 @@
from pathlib import Path

from ..processes import ProcessHandler


class IREEProcess(ProcessHandler):
benchmark_app_name = 'iree_python_benchmark'
launcher_latency_units = 'seconds'

def __init__(self, test, executor, log):
super().__init__(test, executor, log)
self.path_to_script = Path.joinpath(self.inference_script_root, 'inference_iree.py')

@staticmethod
def create_process(test, executor, log):
return IREEProcess(test, executor, log)

def get_performance_metrics(self):
return self.get_performance_metrics_from_json_report()

def _fill_command_line(self):
python = ProcessHandler.get_cmd_python_version(self._test)
arguments = self._compose_arguments()
return f'{python} {self.path_to_script} {arguments}'.strip()

def _compose_arguments(self):
model = self._test.model
dep = self._test.dep_parameters
indep = self._test.indep_parameters

dataset_path = self._normalize_optional(self._test.dataset.path if self._test.dataset else None)
model_path = self._normalize_optional(model.model)
weights_path = self._normalize_optional(model.weight)

command = (f'-fn {dep.function_name} -is {dep.input_shape} -ni {indep.iteration} '
f'--report_path {self.report_path}')

command = self._add_optional_argument_to_cmd_line(command, '-mn', model.name)

source_framework = self._get_source_framework(model.source_framework)
command = self._add_optional_argument_to_cmd_line(command, '-f', source_framework)

command = self._add_optional_argument_to_cmd_line(command, '-m', model_path)
command = self._add_optional_argument_to_cmd_line(command, '-w', weights_path)

module_path = self._normalize_optional(model.module)
command = self._add_optional_argument_to_cmd_line(command, '-tm', module_path)
command = self._add_optional_argument_to_cmd_line(command, '-i', dataset_path)
command = self._add_optional_argument_to_cmd_line(command, '-b', indep.batch_size)

task_type = self._resolve_task_type(model)
command = self._add_optional_argument_to_cmd_line(command, '--task', task_type)

time_limit = indep.test_time_limit
command = self._add_optional_argument_to_cmd_line(command, '--time', time_limit)

layout = self._normalize_optional(dep.layout)
command = self._add_optional_argument_to_cmd_line(command, '--layout', layout)
if self._parameter_is_true(dep.normalize):
command = self._add_flag_to_cmd_line(command, '--norm')

mean = self._normalize_optional(dep.mean)
std = self._normalize_optional(dep.std)
channel_swap = self._normalize_optional(dep.channel_swap)
command = self._add_optional_argument_to_cmd_line(command, '--mean', mean)
command = self._add_optional_argument_to_cmd_line(command, '--std', std)
command = self._add_optional_argument_to_cmd_line(command, '--channel_swap', channel_swap)

target_backend = self._normalize_optional(dep.target_backend) or 'llvm-cpu'
command = self._add_optional_argument_to_cmd_line(command, '-tb', target_backend)

opt_level = self._normalize_optional(dep.opt_level) or '2'
command = self._add_optional_argument_to_cmd_line(command, '--opt_level', opt_level)

onnx_opset = self._normalize_optional(dep.onnx_opset_version)
command = self._add_optional_argument_to_cmd_line(command, '--onnx_opset_version', onnx_opset)

if indep.raw_output:
command = self._add_argument_to_cmd_line(command, '--raw_output', indep.raw_output)

extra_compile_args = self._normalize_optional(dep.extra_compile_args)

if extra_compile_args:
command = f'{command} --extra_compile_args {extra_compile_args}'

return command.strip()

@staticmethod
def _normalize_optional(value):
if value is None:
return None
string_value = str(value).strip()
if not string_value or string_value.lower() == 'none':
return None
return string_value

@staticmethod
def _parameter_is_true(value):
if value is None:
return False
return str(value).strip().lower() in ['true', '1', 'yes']

@staticmethod
def _get_source_framework(value):
normalized_value = IREEProcess._normalize_optional(value)
if not normalized_value:
return None
normalized_value = normalized_value.lower()
allowed_frameworks = {'onnx', 'pytorch'}
if normalized_value in allowed_frameworks:
return normalized_value
return None

@staticmethod
def _resolve_task_type(model):
candidate = getattr(model, 'task', None)
normalized_candidate = IREEProcess._normalize_optional(candidate)
if not normalized_candidate:
return None
normalized_candidate = normalized_candidate.lower()
allowed_tasks = {'feedforward', 'classification'}
if normalized_candidate in allowed_tasks:
return normalized_candidate
return None
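
A quick standalone sanity sketch of the static helpers above, assuming `IREEProcess` can be imported outside a full benchmark run; it only exercises the string handling shown in this file.

```python
# Helper semantics: 'none'-like strings are dropped, boolean-like flags parsed,
# and only onnx/pytorch (case-insensitive) pass through as source frameworks.
assert IREEProcess._normalize_optional('  None ') is None
assert IREEProcess._normalize_optional(' llvm-cpu ') == 'llvm-cpu'
assert IREEProcess._parameter_is_true('True') and not IREEProcess._parameter_is_true(None)
assert IREEProcess._get_source_framework('ONNX') == 'onnx'
assert IREEProcess._get_source_framework('tensorflow') is None
assert IREEProcess._resolve_task_type(type('M', (), {'task': 'Classification'})) == 'classification'
```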
16 changes: 16 additions & 0 deletions src/benchmark/frameworks/iree/iree_wrapper.py
@@ -0,0 +1,16 @@
from ..config_parser.test_reporter import Test
from ..framework_wrapper import FrameworkWrapper
from ..known_frameworks import KnownFrameworks
from .iree_process import IREEProcess


class IREEWrapper(FrameworkWrapper):
framework_name = KnownFrameworks.iree

@staticmethod
def create_process(test, executor, log, **kwargs):
return IREEProcess.create_process(test, executor, log)

@staticmethod
def create_test(model, dataset, indep_parameters, dep_parameters):
return Test(model, dataset, indep_parameters, dep_parameters)
1 change: 1 addition & 0 deletions src/benchmark/frameworks/known_frameworks.py
@@ -19,3 +19,4 @@ class KnownFrameworks:
ncnn = 'ncnn'
executorch_cpp = 'ExecuTorch Cpp'
executorch = 'ExecuTorch'
iree = 'IREE'
2 changes: 2 additions & 0 deletions src/benchmark/tests/test_processes.py
@@ -21,6 +21,7 @@
from src.benchmark.frameworks.tensorflow.tensorflow_process import TensorFlowProcess
from src.benchmark.frameworks.tensorflow_lite.tensorflow_lite_process import TensorFlowLiteProcess
from src.benchmark.tests.test_executor import get_host_executor
from src.benchmark.frameworks.iree.iree_process import IREEProcess

log.basicConfig(
format='[ %(levelname)s ] %(message)s',
@@ -70,6 +71,7 @@ class DotDict(dict):
['OpenCV DNN Python', OpenCVDNNPythonProcess],
['ONNX Runtime Python', ONNXRuntimePythonProcess],
['TVM', TVMProcess],
['IREE', IREEProcess],
])
@pytest.mark.parametrize('complex_test', [['sync', 'handwritten', None, SyncOpenVINOProcess],
['async', 'handwritten', None, AsyncOpenVINOProcess],