From b92ca14c85bf0b64605a58b58eedce1a64992d44 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 13 Jul 2022 18:16:19 -0400 Subject: [PATCH] Add ONNX support to KRCNNConvDeconvUpsampleHead The fix involves changes to both PyTorch and detectron2: * Pytorch had a bug which prevented some tensors to be identified as float (refer to https://github.com/pytorch/pytorch/pull/81386) * `detectron2/structures/keypoints.py::heatmaps_to_keypoints` internally does advanced indexing on a `squeeze`d `roi_map`. The aforentioned `squeeze` fails rank inference due to the presence of `onnx::If` on its composition to support dynamic dims. By replacing `squeeze` by `replace` on detectron's `heatmaps_to_keypoints`, shape inference succeeds, allowing ONNX export to succeed with dynamic axes support. --- detectron2/structures/instances.py | 2 - detectron2/structures/keypoints.py | 14 +++---- detectron2/utils/testing.py | 15 ++++++++ tests/test_export_onnx.py | 62 ++++++++++++++++++++++++++++-- 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/detectron2/structures/instances.py b/detectron2/structures/instances.py index 519d5de7e9..c9579bce27 100644 --- a/detectron2/structures/instances.py +++ b/detectron2/structures/instances.py @@ -4,8 +4,6 @@ from typing import Any, Dict, List, Tuple, Union import torch -from detectron2.structures import Boxes - class Instances: """ diff --git a/detectron2/structures/keypoints.py b/detectron2/structures/keypoints.py index d0ee8724ac..b93ebed4f6 100644 --- a/detectron2/structures/keypoints.py +++ b/detectron2/structures/keypoints.py @@ -179,10 +179,6 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. """ - # The decorator use of torch.no_grad() was not supported by torchscript. - # https://github.com/pytorch/pytorch/issues/44768 - maps = maps.detach() - rois = rois.detach() offset_x = rois[:, 0] offset_y = rois[:, 1] @@ -202,11 +198,11 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso for i in range(num_rois): outsize = (int(heights_ceil[i]), int(widths_ceil[i])) - roi_map = F.interpolate( - maps[[i]], size=outsize, mode="bicubic", align_corners=False - ).squeeze( - 0 - ) # #keypoints x H x W + roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False) + + # Although semantically equivalent, `reshape` is used instead of `squeeze` due + # to limitation during ONNX export of `squeeze` in scripting mode + roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W # softmax over the spatial region max_score, _ = roi_map.view(num_keypoints, -1).max(1) diff --git a/detectron2/utils/testing.py b/detectron2/utils/testing.py index b597ed92fd..3f5b9dbe44 100644 --- a/detectron2/utils/testing.py +++ b/detectron2/utils/testing.py @@ -176,6 +176,21 @@ def min_torch_version(min_version: str) -> bool: return installed_version >= min_version +def has_dynamic_axes(onnx_model): + """ + Return True when all ONNX input/output have only dynamic axes for all ranks + """ + return all( + not dim.dim_param.isnumeric() + for inp in onnx_model.graph.input + for dim in inp.type.tensor_type.shape.dim + ) and all( + not dim.dim_param.isnumeric() + for out in onnx_model.graph.output + for dim in out.type.tensor_type.shape.dim + ) + + def register_custom_op_onnx_export( opname: str, symbolic_fn: Callable, opset_version: int, min_version: str ) -> None: diff --git a/tests/test_export_onnx.py b/tests/test_export_onnx.py index ffab02ae48..aa15e1a406 100644 --- a/tests/test_export_onnx.py +++ b/tests/test_export_onnx.py @@ -10,11 +10,17 @@ from detectron2.config import get_cfg from detectron2.export import STABLE_ONNX_OPSET_VERSION from detectron2.export.flatten import TracingAdapter +from detectron2.export.torchscript_patch import patch_builtin_len +from detectron2.layers import ShapeSpec from detectron2.modeling import build_model +from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead +from detectron2.structures import Boxes, Instances from detectron2.utils.testing import ( _pytorch1111_symbolic_opset9_repeat_interleave, _pytorch1111_symbolic_opset9_to, get_sample_coco_image, + has_dynamic_axes, + random_boxes, register_custom_op_onnx_export, skipIfOnCPUCI, skipIfUnsupportedMinOpsetVersion, @@ -26,6 +32,8 @@ @unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.") @skipIfUnsupportedMinTorchVersion("1.10") class TestONNXTracingExport(unittest.TestCase): + opset_version = STABLE_ONNX_OPSET_VERSION + def testMaskRCNNFPN(self): def inference_func(model, images): with warnings.catch_warnings(record=True): @@ -85,9 +93,55 @@ def inference_func(model, image): self._test_model_zoo_from_config_path( "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, - opset_version=STABLE_ONNX_OPSET_VERSION, ) + def testKeypointHead(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = KRCNNConvDeconvUpsampleHead( + ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,) + ) + + def forward(self, x, predbox1, predbox2): + inst = [ + Instances((100, 100), pred_boxes=Boxes(predbox1)), + Instances((100, 100), pred_boxes=Boxes(predbox2)), + ] + ret = self.model(x, inst) + return tuple(x.pred_keypoints for x in ret) + + model = M() + model.eval() + + def gen_input(num1, num2): + feat = torch.randn((num1 + num2, 4, 14, 14)) + box1 = random_boxes(num1) + box2 = random_boxes(num2) + return feat, box1, box2 + + with patch_builtin_len(): + onnx_model = self._test_model( + model, + gen_input(1, 2), + input_names=["features", "pred_boxes", "pred_classes"], + output_names=["box1", "box2"], + dynamic_axes={ + "features": {0: "batch", 1: "static_four", 2: "height", 3: "width"}, + "pred_boxes": {0: "batch", 1: "static_four"}, + "pred_classes": {0: "batch", 1: "static_four"}, + "box1": {0: "num_instance", 1: "K", 2: "static_three"}, + "box2": {0: "num_instance", 1: "K", 2: "static_three"}, + }, + ) + + # Although ONNX models are not executable by PyTorch to verify + # support of batches with different sizes, we can verify model's IR + # does not hard-code input and/or output shapes. + # TODO: Add tests with different batch sizes when detectron2's CI + # support ONNX Runtime backend. + assert has_dynamic_axes(onnx_model) + ################################################################################ # Testcase internals - DO NOT add tests below this point ################################################################################ @@ -114,6 +168,9 @@ def _test_model( save_onnx_graph_path=None, **export_kwargs, ): + # Not imported in the beginning of file to prevent runtime errors + # for environments without ONNX. + # This testcase checks dependencies before running import onnx # isort:skip f = io.BytesIO() @@ -138,6 +195,7 @@ def _test_model( assert onnx_model is not None if save_onnx_graph_path: onnx.save(onnx_model, save_onnx_graph_path) + return onnx_model def _test_model_zoo_from_config_path( self, @@ -171,9 +229,7 @@ def _test_model_from_config_path( point_rend.add_pointrend_config(cfg) cfg.merge_from_file(config_path) cfg.freeze() - model = build_model(cfg) - image = get_sample_coco_image() inputs = tuple(image.clone() for _ in range(batch)) return self._test_model(