Skip to content

Commit

Permalink
feat: add linspace func rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
DDGRCF committed Dec 9, 2022
1 parent 1a2b94e commit fbe6376
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 24 deletions.
4 changes: 2 additions & 2 deletions docs/en/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -911,15 +911,15 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
<td align="center" rowspan="2">Instance Segmentation</td>
<td align="center" rowspan="2">COCO2017</td>
<td align="center">mask AP</td>
<td align="center">33.1</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">35.4</td>
<td align="center">32.7</td>
</tr>
</tbody>
</table>
Expand Down
4 changes: 2 additions & 2 deletions docs/zh_cn/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -906,15 +906,15 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center" rowspan="2">Instance Segmentation</td>
<td align="center" rowspan="2">COCO2017</td>
<td align="center">mask AP</td>
<td align="center">33.1</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">35.4</td>
<td align="center">32.7</td>
</tr>
</tbody>
</table>
Expand Down
1 change: 0 additions & 1 deletion mmdeploy/codebase/mmdet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from . import dense_heads # noqa: F401,F403
from . import detectors # noqa: F401,F403
from . import layers # noqa: F401,F403
from . import misc # noqa: F401,F403
from . import necks # noqa: F401,F403
from . import roi_heads # noqa: F401,F403
from . import task_modules # noqa: F401,F403
Expand Down
19 changes: 0 additions & 19 deletions mmdeploy/codebase/mmdet/models/misc.py

This file was deleted.

1 change: 1 addition & 0 deletions mmdeploy/pytorch/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import group_norm # noqa: F401,F403
from . import interpolate # noqa: F401,F403
from . import linear # noqa: F401,F403
from . import linspace # noqa: F401,F403
from . import masked_fill # noqa: F401,F403
from . import mod # noqa: F401,F403
from . import multi_head_attention_forward # noqa: F401,F403
Expand Down
24 changes: 24 additions & 0 deletions mmdeploy/pytorch/functions/linspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from torch.types import Number

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(func_name='torch.linspace')
@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.linspace')
def linspace__onnx(ctx,
start: Number,
end: Number,
steps: Optional[int] = None,
**kwargs):
"""Rewrite `linspace` for onnxruntime."""
steps = 100 if steps is None else steps
if steps == 1:
output = torch.arange(start, end + 1, **kwargs)[:steps]
else:
output = torch.arange(start, end + 1, (end - start) * 1. / (steps - 1),
**kwargs)[:steps]
return output
32 changes: 32 additions & 0 deletions tests/test_pytorch/test_pytorch_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,35 @@ def _pad_(x):
run_with_backend=True)
assert torch.allclose(
pytorch_output, rewrite_output[0], rtol=1e-3, atol=1e-5)


@backend_checker(Backend.ONNXRUNTIME)
def test_linspace_onnx():
import random

deploy_cfg_ort = Config(
dict(
onnx_config=dict(input_shape=None),
backend_config=dict(type='onnxruntime')))

def linspace_caller(*arg, **kwargs):
return torch.linspace(*arg, **kwargs)

steps_list = [None, 1, random.randint(1, 1000)]
for steps in steps_list:
start = random.random() * 100
end = random.random() * 100 + start

model_output = torch.linspace(start, end, steps)

wrapped_func = WrapFunction(
linspace_caller, start=start, end=end, steps=steps)

rewrite_outputs, _ = get_rewrite_outputs(
wrapped_func,
model_inputs={},
deploy_cfg=deploy_cfg_ort,
run_with_backend=True)

assert np.allclose(
model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)

0 comments on commit fbe6376

Please sign in to comment.