|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import argparse |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import os.path as osp |
| 6 | +from functools import partial |
| 7 | + |
| 8 | +import mmcv |
| 9 | +import torch.multiprocessing as mp |
| 10 | +from torch.multiprocessing import Process, set_start_method |
| 11 | + |
| 12 | +from mmdeploy.apis import (create_calib_input_data, extract_model, |
| 13 | + get_predefined_partition_cfg, torch2onnx, |
| 14 | + torch2torchscript, visualize_model) |
| 15 | +from mmdeploy.apis.core import PIPELINE_MANAGER |
| 16 | +from mmdeploy.apis.utils import to_backend |
| 17 | +from mmdeploy.backend.sdk.export_info import export2SDK |
| 18 | +from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename, |
| 19 | + get_ir_config, get_partition_config, |
| 20 | + get_root_logger, load_config, target_wrapper) |
| 21 | + |
| 22 | +import mmcv_custom |
| 23 | +import mmdet_custom |
| 24 | + |
| 25 | +def parse_args(): |
| 26 | + parser = argparse.ArgumentParser(description='Export model to backends.') |
| 27 | + parser.add_argument('deploy_cfg', help='deploy config path') |
| 28 | + parser.add_argument('model_cfg', help='model config path') |
| 29 | + parser.add_argument('checkpoint', help='model checkpoint path') |
| 30 | + parser.add_argument('img', help='image used to convert model model') |
| 31 | + parser.add_argument( |
| 32 | + '--test-img', |
| 33 | + default=None, |
| 34 | + type=str, |
| 35 | + nargs='+', |
| 36 | + help='image used to test model') |
| 37 | + parser.add_argument( |
| 38 | + '--work-dir', |
| 39 | + default=os.getcwd(), |
| 40 | + help='the dir to save logs and models') |
| 41 | + parser.add_argument( |
| 42 | + '--calib-dataset-cfg', |
| 43 | + help=('dataset config path used to calibrate in int8 mode. If not ' |
| 44 | + 'specified, it will use "val" dataset in model config instead.'), |
| 45 | + default=None) |
| 46 | + parser.add_argument( |
| 47 | + '--device', help='device used for conversion', default='cpu') |
| 48 | + parser.add_argument( |
| 49 | + '--log-level', |
| 50 | + help='set log level', |
| 51 | + default='INFO', |
| 52 | + choices=list(logging._nameToLevel.keys())) |
| 53 | + parser.add_argument( |
| 54 | + '--show', action='store_true', help='Show detection outputs') |
| 55 | + parser.add_argument( |
| 56 | + '--dump-info', action='store_true', help='Output information for SDK') |
| 57 | + parser.add_argument( |
| 58 | + '--quant-image-dir', |
| 59 | + default=None, |
| 60 | + help='Image directory for quantize model.') |
| 61 | + parser.add_argument( |
| 62 | + '--quant', action='store_true', help='Quantize model to low bit.') |
| 63 | + parser.add_argument( |
| 64 | + '--uri', |
| 65 | + default='192.168.1.1:60000', |
| 66 | + help='Remote ipv4:port or ipv6:port for inference on edge device.') |
| 67 | + args = parser.parse_args() |
| 68 | + return args |
| 69 | + |
| 70 | + |
| 71 | +def create_process(name, target, args, kwargs, ret_value=None): |
| 72 | + logger = get_root_logger() |
| 73 | + logger.info(f'{name} start.') |
| 74 | + log_level = logger.level |
| 75 | + |
| 76 | + wrap_func = partial(target_wrapper, target, log_level, ret_value) |
| 77 | + |
| 78 | + process = Process(target=wrap_func, args=args, kwargs=kwargs) |
| 79 | + process.start() |
| 80 | + process.join() |
| 81 | + |
| 82 | + if ret_value is not None: |
| 83 | + if ret_value.value != 0: |
| 84 | + logger.error(f'{name} failed.') |
| 85 | + exit(1) |
| 86 | + else: |
| 87 | + logger.info(f'{name} success.') |
| 88 | + |
| 89 | + |
| 90 | +def torch2ir(ir_type: IR): |
| 91 | + """Return the conversion function from torch to the intermediate |
| 92 | + representation. |
| 93 | +
|
| 94 | + Args: |
| 95 | + ir_type (IR): The type of the intermediate representation. |
| 96 | + """ |
| 97 | + if ir_type == IR.ONNX: |
| 98 | + return torch2onnx |
| 99 | + elif ir_type == IR.TORCHSCRIPT: |
| 100 | + return torch2torchscript |
| 101 | + else: |
| 102 | + raise KeyError(f'Unexpected IR type {ir_type}') |
| 103 | + |
| 104 | + |
| 105 | +def main(): |
| 106 | + args = parse_args() |
| 107 | + set_start_method('spawn', force=True) |
| 108 | + logger = get_root_logger() |
| 109 | + log_level = logging.getLevelName(args.log_level) |
| 110 | + logger.setLevel(log_level) |
| 111 | + |
| 112 | + pipeline_funcs = [ |
| 113 | + torch2onnx, torch2torchscript, extract_model, create_calib_input_data |
| 114 | + ] |
| 115 | + PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs) |
| 116 | + PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs) |
| 117 | + |
| 118 | + deploy_cfg_path = args.deploy_cfg |
| 119 | + model_cfg_path = args.model_cfg |
| 120 | + checkpoint_path = args.checkpoint |
| 121 | + quant = args.quant |
| 122 | + quant_image_dir = args.quant_image_dir |
| 123 | + |
| 124 | + # load deploy_cfg |
| 125 | + deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path) |
| 126 | + |
| 127 | + # create work_dir if not |
| 128 | + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) |
| 129 | + |
| 130 | + if args.dump_info: |
| 131 | + export2SDK( |
| 132 | + deploy_cfg, |
| 133 | + model_cfg, |
| 134 | + args.work_dir, |
| 135 | + pth=checkpoint_path, |
| 136 | + device=args.device) |
| 137 | + |
| 138 | + ret_value = mp.Value('d', 0, lock=False) |
| 139 | + |
| 140 | + # convert to IR |
| 141 | + ir_config = get_ir_config(deploy_cfg) |
| 142 | + ir_save_file = ir_config['save_file'] |
| 143 | + ir_type = IR.get(ir_config['type']) |
| 144 | + torch2ir(ir_type)( |
| 145 | + args.img, |
| 146 | + args.work_dir, |
| 147 | + ir_save_file, |
| 148 | + deploy_cfg_path, |
| 149 | + model_cfg_path, |
| 150 | + checkpoint_path, |
| 151 | + device=args.device) |
| 152 | + |
| 153 | + # convert backend |
| 154 | + ir_files = [osp.join(args.work_dir, ir_save_file)] |
| 155 | + |
| 156 | + # partition model |
| 157 | + partition_cfgs = get_partition_config(deploy_cfg) |
| 158 | + |
| 159 | + if partition_cfgs is not None: |
| 160 | + |
| 161 | + if 'partition_cfg' in partition_cfgs: |
| 162 | + partition_cfgs = partition_cfgs.get('partition_cfg', None) |
| 163 | + else: |
| 164 | + assert 'type' in partition_cfgs |
| 165 | + partition_cfgs = get_predefined_partition_cfg( |
| 166 | + deploy_cfg, partition_cfgs['type']) |
| 167 | + |
| 168 | + origin_ir_file = ir_files[0] |
| 169 | + ir_files = [] |
| 170 | + for partition_cfg in partition_cfgs: |
| 171 | + save_file = partition_cfg['save_file'] |
| 172 | + save_path = osp.join(args.work_dir, save_file) |
| 173 | + start = partition_cfg['start'] |
| 174 | + end = partition_cfg['end'] |
| 175 | + dynamic_axes = partition_cfg.get('dynamic_axes', None) |
| 176 | + |
| 177 | + extract_model( |
| 178 | + origin_ir_file, |
| 179 | + start, |
| 180 | + end, |
| 181 | + dynamic_axes=dynamic_axes, |
| 182 | + save_file=save_path) |
| 183 | + |
| 184 | + ir_files.append(save_path) |
| 185 | + |
| 186 | + # calib data |
| 187 | + calib_filename = get_calib_filename(deploy_cfg) |
| 188 | + if calib_filename is not None: |
| 189 | + calib_path = osp.join(args.work_dir, calib_filename) |
| 190 | + create_calib_input_data( |
| 191 | + calib_path, |
| 192 | + deploy_cfg_path, |
| 193 | + model_cfg_path, |
| 194 | + checkpoint_path, |
| 195 | + dataset_cfg=args.calib_dataset_cfg, |
| 196 | + dataset_type='val', |
| 197 | + device=args.device) |
| 198 | + |
| 199 | + backend_files = ir_files |
| 200 | + # convert backend |
| 201 | + backend = get_backend(deploy_cfg) |
| 202 | + |
| 203 | + # preprocess deploy_cfg |
| 204 | + if backend == Backend.RKNN: |
| 205 | + # TODO: Add this to task_processor in the future |
| 206 | + import tempfile |
| 207 | + |
| 208 | + from mmdeploy.utils import (get_common_config, get_normalization, |
| 209 | + get_quantization_config, |
| 210 | + get_rknn_quantization) |
| 211 | + quantization_cfg = get_quantization_config(deploy_cfg) |
| 212 | + common_params = get_common_config(deploy_cfg) |
| 213 | + if get_rknn_quantization(deploy_cfg) is True: |
| 214 | + transform = get_normalization(model_cfg) |
| 215 | + common_params.update( |
| 216 | + dict( |
| 217 | + mean_values=[transform['mean']], |
| 218 | + std_values=[transform['std']])) |
| 219 | + |
| 220 | + dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name |
| 221 | + with open(dataset_file, 'w') as f: |
| 222 | + f.writelines([osp.abspath(args.img)]) |
| 223 | + quantization_cfg.setdefault('dataset', dataset_file) |
| 224 | + if backend == Backend.ASCEND: |
| 225 | + # TODO: Add this to backend manager in the future |
| 226 | + if args.dump_info: |
| 227 | + from mmdeploy.backend.ascend import update_sdk_pipeline |
| 228 | + update_sdk_pipeline(args.work_dir) |
| 229 | + |
| 230 | + # convert to backend |
| 231 | + PIPELINE_MANAGER.set_log_level(log_level, [to_backend]) |
| 232 | + if backend == Backend.TENSORRT: |
| 233 | + PIPELINE_MANAGER.enable_multiprocess(True, [to_backend]) |
| 234 | + backend_files = to_backend( |
| 235 | + backend, |
| 236 | + ir_files, |
| 237 | + work_dir=args.work_dir, |
| 238 | + deploy_cfg=deploy_cfg, |
| 239 | + log_level=log_level, |
| 240 | + device=args.device, |
| 241 | + uri=args.uri) |
| 242 | + |
| 243 | + # ncnn quantization |
| 244 | + if backend == Backend.NCNN and quant: |
| 245 | + from onnx2ncnn_quant_table import get_table |
| 246 | + |
| 247 | + from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8 |
| 248 | + model_param_paths = backend_files[::2] |
| 249 | + model_bin_paths = backend_files[1::2] |
| 250 | + backend_files = [] |
| 251 | + for onnx_path, model_param_path, model_bin_path in zip( |
| 252 | + ir_files, model_param_paths, model_bin_paths): |
| 253 | + |
| 254 | + deploy_cfg, model_cfg = load_config(deploy_cfg_path, |
| 255 | + model_cfg_path) |
| 256 | + quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501 |
| 257 | + onnx_path, args.work_dir) |
| 258 | + |
| 259 | + create_process( |
| 260 | + 'ncnn quant table', |
| 261 | + target=get_table, |
| 262 | + args=(onnx_path, deploy_cfg, model_cfg, quant_onnx, |
| 263 | + quant_table, quant_image_dir, args.device), |
| 264 | + kwargs=dict(), |
| 265 | + ret_value=ret_value) |
| 266 | + |
| 267 | + create_process( |
| 268 | + 'ncnn_int8', |
| 269 | + target=ncnn2int8, |
| 270 | + args=(model_param_path, model_bin_path, quant_table, |
| 271 | + quant_param, quant_bin), |
| 272 | + kwargs=dict(), |
| 273 | + ret_value=ret_value) |
| 274 | + backend_files += [quant_param, quant_bin] |
| 275 | + |
| 276 | + if args.test_img is None: |
| 277 | + args.test_img = args.img |
| 278 | + |
| 279 | + extra = dict( |
| 280 | + backend=backend, |
| 281 | + output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'), |
| 282 | + show_result=args.show) |
| 283 | + if backend == Backend.SNPE: |
| 284 | + extra['uri'] = args.uri |
| 285 | + |
| 286 | + # get backend inference result, try render |
| 287 | + create_process( |
| 288 | + f'visualize {backend.value} model', |
| 289 | + target=visualize_model, |
| 290 | + args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img, |
| 291 | + args.device), |
| 292 | + kwargs=extra, |
| 293 | + ret_value=ret_value) |
| 294 | + |
| 295 | + # get pytorch model inference result, try visualize if possible |
| 296 | + create_process( |
| 297 | + 'visualize pytorch model', |
| 298 | + target=visualize_model, |
| 299 | + args=(model_cfg_path, deploy_cfg_path, [checkpoint_path], |
| 300 | + args.test_img, args.device), |
| 301 | + kwargs=dict( |
| 302 | + backend=Backend.PYTORCH, |
| 303 | + output_file=osp.join(args.work_dir, 'output_pytorch.jpg'), |
| 304 | + show_result=args.show), |
| 305 | + ret_value=ret_value) |
| 306 | + logger.info('All process success.') |
| 307 | + |
| 308 | + |
| 309 | +if __name__ == '__main__': |
| 310 | + main() |
0 commit comments