Skip to content

Commit

Permalink
[Fix] fix coreml for branch 1.x (open-mmlab#1669)
Browse files Browse the repository at this point in the history
* fix coreml for branch 1.x

* fix docstring

* update docsting1
  • Loading branch information
grimoire authored Jan 31, 2023
1 parent 0c1adba commit dc1d9df
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 20 deletions.
44 changes: 39 additions & 5 deletions mmdeploy/backend/coreml/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def is_available(cls, with_custom_ops: bool = False) -> bool:
bool: True if backend package is installed.
"""
import importlib
return importlib.util.find_spec('coreml') is not None
return importlib.util.find_spec('coremltools') is not None

@classmethod
def get_version(cls) -> str:
Expand All @@ -52,7 +52,7 @@ def get_version(cls) -> str:
else:
import pkg_resources
try:
return pkg_resources.get_distribution('coreml').version
return pkg_resources.get_distribution('coremltools').version
except Exception:
return 'None'

Expand All @@ -76,14 +76,48 @@ def to_backend(cls,
Returns:
Sequence[str]: Backend files.
"""
from .torchscript2coreml import from_torchscript
from mmdeploy.utils import (get_common_config, get_ir_config,
get_model_inputs, load_config)
from .torchscript2coreml import from_torchscript, get_model_suffix

coreml_files = []
for model_id, torchscript_path in enumerate(ir_files):
torchscript_name = osp.splitext(osp.split(torchscript_path)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)

from_torchscript(model_id, torchscript_path, output_file_prefix,
deploy_cfg, coreml_files)
deploy_cfg = load_config(deploy_cfg)[0]

common_params = get_common_config(deploy_cfg)
model_params = get_model_inputs(deploy_cfg)[model_id]

final_params = common_params
final_params.update(model_params)

ir_config = get_ir_config(deploy_cfg)
input_names = ir_config.get('input_names', [])
output_names = ir_config.get('output_names', [])
input_shapes = final_params['input_shapes']
compute_precision = final_params.get('compute_precision',
'FLOAT32')
convert_to = deploy_cfg.backend_config.convert_to

minimum_deployment_target = final_params.get(
'minimum_deployment_target', None)
skip_model_load = final_params.get('skip_model_load', False)

from_torchscript(
torchscript_path,
output_file_prefix,
input_names=input_names,
output_names=output_names,
input_shapes=input_shapes,
compute_precision=compute_precision,
convert_to=convert_to,
minimum_deployment_target=minimum_deployment_target,
skip_model_load=skip_model_load)

suffix = get_model_suffix(convert_to)
output_path = output_file_prefix + suffix
coreml_files.append(output_path)

return coreml_files
24 changes: 14 additions & 10 deletions mmdeploy/backend/coreml/torchscript2coreml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Dict, Sequence, Union
from typing import Dict, Optional, Sequence, Union

import coremltools as ct
import torch
Expand Down Expand Up @@ -52,9 +52,10 @@ def from_torchscript(torchscript_model: Union[str,
output_file_prefix: str,
input_names: Sequence[str],
output_names: Sequence[str],
input_shapes: Dict,
input_shapes: Dict[str, Dict],
compute_precision: str = 'FLOAT32',
convert_to: str = 'neuralnetwork',
fp16_mode: bool = False,
minimum_deployment_target: Optional[str] = None,
skip_model_load: bool = True,
**kwargs):
"""Create a coreml engine from torchscript.
Expand All @@ -67,9 +68,12 @@ def from_torchscript(torchscript_model: Union[str,
output_names (Sequence[str]): The output names of the model.
input_shapes (Dict): The input shapes include max_shape, min_shape and
default_shape
convert_to (str, optional): The converted model type, can be
compute_precision (str): The model precision,
FLOAT16 or FLOAT32, see coremltools.precision, default `FLOAT32`.
convert_to (str): The converted model type, can be
'neuralnetwork' or 'mlprogram'. Defaults to 'neuralnetwork'.
fp16_mode (bool, optional): Convert to fp16 model. Defaults to False.
minimum_deployment_target (str, optional): minimum deploy target.
iOS15, iOS16, etc., see coremltools.target
skip_model_load (bool, optional): Skip model load. Defaults to True.
"""

Expand Down Expand Up @@ -98,19 +102,19 @@ def from_torchscript(torchscript_model: Union[str,
if convert_to == 'neuralnetwork':
compute_precision = None
else:
if fp16_mode:
compute_precision = ct.precision.FLOAT16
else:
compute_precision = ct.precision.FLOAT32
compute_precision = ct.precision[compute_precision]

mlmodel = ct.convert(
model=torchscript_model,
inputs=inputs,
outputs=outputs,
compute_precision=compute_precision,
convert_to=convert_to,
skip_model_load=False)
minimum_deployment_target=ct.target[minimum_deployment_target]
if minimum_deployment_target else None,
skip_model_load=skip_model_load)

suffix = get_model_suffix(convert_to)
output_path = output_file_prefix + suffix

mlmodel.save(output_path)
24 changes: 24 additions & 0 deletions mmdeploy/pytorch/functions/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def topk__tensorrt(input: torch.Tensor,
k = TENSORRT_MAX_TOPK

return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)


@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='coreml')
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='coreml')
def topk__coreml(input: torch.Tensor,
k: int,
dim: Optional[int] = None,
largest: bool = True,
sorted: bool = True):
"""Rewrite `topk` for coreml.
Cast k to tensor and make sure k is smaller than input.shape[dim].
"""
ctx = FUNCTION_REWRITER.get_context()

if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input
k = torch.where(k < size, k, size)
return ctx.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
34 changes: 29 additions & 5 deletions tests/test_backend/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def generate_torchscript_file():
context_info=context_info)


def onnx2backend(backend, onnx_file):
def ir2backend(backend, onnx_file, ts_file):
if backend == Backend.TENSORRT:
from mmdeploy.backend.tensorrt import from_onnx
backend_file = tempfile.NamedTemporaryFile(suffix='.engine').name
Expand Down Expand Up @@ -142,6 +142,33 @@ def onnx2backend(backend, onnx_file):
onnx_file, lib_file, shape=shape, dtype=dtype, tuner=tuner_dict)
assert osp.exists(lib_file)
return lib_file
elif backend == Backend.TORCHSCRIPT:
return ts_file
elif backend == Backend.COREML:
output_names = ['output']
from mmdeploy.backend.coreml.torchscript2coreml import (
from_torchscript, get_model_suffix)
backend_dir = tempfile.TemporaryDirectory().name
work_dir = backend_dir
torchscript_name = osp.splitext(osp.split(ts_file)[1])[0]
output_file_prefix = osp.join(work_dir, torchscript_name)
convert_to = 'mlprogram'
from_torchscript(
ts_file,
output_file_prefix,
input_names=input_names,
output_names=output_names,
input_shapes=dict(
input=dict(
min_shape=[1, 3, 8, 8],
default_shape=[1, 3, 8, 8],
max_shape=[1, 3, 8, 8])),
convert_to=convert_to)
suffix = get_model_suffix(convert_to)
return output_file_prefix + suffix
else:
raise NotImplementedError(
f'Convert for {backend.value} has not been implemented.')


def create_wrapper(backend, model_files):
Expand Down Expand Up @@ -184,10 +211,7 @@ def run_wrapper(backend, wrapper, input):
@pytest.mark.parametrize('backend', ALL_BACKEND)
def test_wrapper(backend):
check_backend(backend, True)
if backend == Backend.TORCHSCRIPT:
model_files = ts_file
else:
model_files = onnx2backend(backend, onnx_file)
model_files = ir2backend(backend, onnx_file, ts_file)
assert model_files is not None
wrapper = create_wrapper(backend, model_files)
assert wrapper is not None
Expand Down

0 comments on commit dc1d9df

Please sign in to comment.