Skip to content

Commit

Permalink
Add ONNX support to KRCNNConvDeconvUpsampleHead
Browse files Browse the repository at this point in the history
The fix involves changes to both PyTorch and detectron2:
	* Pytorch had a bug which prevented some tensors to be identified as float
	  (refer to pytorch/pytorch#81386)

	* `detectron2/structures/keypoints.py::heatmaps_to_keypoints` internally
	  does advanced indexing on a `squeeze`d `roi_map`.
	  The aforentioned `squeeze` fails rank inference due to the
	  presence of `onnx::If` on its composition to support dynamic
	  dims. By replacing `squeeze` by `replace` on
	  detectron's `heatmaps_to_keypoints`, shape inference succeeds,
	  allowing ONNX export to succeed with dynamic axes support.
  • Loading branch information
Thiago Crepaldi committed Feb 8, 2023
1 parent 86fc8c6 commit 1a56083
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 12 deletions.
14 changes: 5 additions & 9 deletions detectron2/structures/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso
we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
"""
# The decorator use of torch.no_grad() was not supported by torchscript.
# https://github.com/pytorch/pytorch/issues/44768
maps = maps.detach()
rois = rois.detach()

offset_x = rois[:, 0]
offset_y = rois[:, 1]
Expand All @@ -202,11 +198,11 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso

for i in range(num_rois):
outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
roi_map = F.interpolate(
maps[[i]], size=outsize, mode="bicubic", align_corners=False
).squeeze(
0
) # #keypoints x H x W
roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False)

# Although semantically equivalent, `reshape` is used instead of `squeeze` due
# to limitation during ONNX export of `squeeze` in scripting mode
roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W

# softmax over the spatial region
max_score, _ = roi_map.view(num_keypoints, -1).max(1)
Expand Down
15 changes: 15 additions & 0 deletions detectron2/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,21 @@ def min_torch_version(min_version: str) -> bool:
return installed_version >= min_version


def has_dynamic_axes(onnx_model):
"""
Return True when all ONNX input/output have only dynamic axes for all ranks
"""
return all(
not dim.dim_param.isnumeric()
for inp in onnx_model.graph.input
for dim in inp.type.tensor_type.shape.dim
) and all(
not dim.dim_param.isnumeric()
for out in onnx_model.graph.output
for dim in out.type.tensor_type.shape.dim
)


def register_custom_op_onnx_export(
opname: str, symbolic_fn: Callable, opset_version: int, min_version: str
) -> None:
Expand Down
62 changes: 59 additions & 3 deletions tests/test_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
from detectron2.config import get_cfg
from detectron2.export import STABLE_ONNX_OPSET_VERSION
from detectron2.export.flatten import TracingAdapter
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.layers import ShapeSpec
from detectron2.modeling import build_model
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import (
_pytorch1111_symbolic_opset9_repeat_interleave,
_pytorch1111_symbolic_opset9_to,
get_sample_coco_image,
has_dynamic_axes,
random_boxes,
register_custom_op_onnx_export,
skipIfOnCPUCI,
skipIfUnsupportedMinOpsetVersion,
Expand All @@ -26,6 +32,8 @@
@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.")
@skipIfUnsupportedMinTorchVersion("1.10")
class TestONNXTracingExport(unittest.TestCase):
opset_version = STABLE_ONNX_OPSET_VERSION

def testMaskRCNNFPN(self):
def inference_func(model, images):
with warnings.catch_warnings(record=True):
Expand Down Expand Up @@ -85,9 +93,55 @@ def inference_func(model, image):
self._test_model_zoo_from_config_path(
"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
inference_func,
opset_version=STABLE_ONNX_OPSET_VERSION,
)

def testKeypointHead(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = KRCNNConvDeconvUpsampleHead(
ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,)
)

def forward(self, x, predbox1, predbox2):
inst = [
Instances((100, 100), pred_boxes=Boxes(predbox1)),
Instances((100, 100), pred_boxes=Boxes(predbox2)),
]
ret = self.model(x, inst)
return tuple(x.pred_keypoints for x in ret)

model = M()
model.eval()

def gen_input(num1, num2):
feat = torch.randn((num1 + num2, 4, 14, 14))
box1 = random_boxes(num1)
box2 = random_boxes(num2)
return feat, box1, box2

with patch_builtin_len():
onnx_model = self._test_model(
model,
gen_input(1, 2),
input_names=["features", "pred_boxes", "pred_classes"],
output_names=["box1", "box2"],
dynamic_axes={
"features": {0: "batch", 1: "static_four", 2: "height", 3: "width"},
"pred_boxes": {0: "batch", 1: "static_four"},
"pred_classes": {0: "batch", 1: "static_four"},
"box1": {0: "num_instance", 1: "K", 2: "static_three"},
"box2": {0: "num_instance", 1: "K", 2: "static_three"},
},
)

# Although ONNX models are not executable by PyTorch to verify
# support of batches with different sizes, we can verify model's IR
# does not hard-code input and/or output shapes.
# TODO: Add tests with different batch sizes when detectron2's CI
# support ONNX Runtime backend.
assert has_dynamic_axes(onnx_model)

################################################################################
# Testcase internals - DO NOT add tests below this point
################################################################################
Expand All @@ -114,6 +168,9 @@ def _test_model(
save_onnx_graph_path=None,
**export_kwargs,
):
# Not imported in the beginning of file to prevent runtime errors
# for environments without ONNX.
# This testcase checks dependencies before running
import onnx # isort:skip

f = io.BytesIO()
Expand All @@ -138,6 +195,7 @@ def _test_model(
assert onnx_model is not None
if save_onnx_graph_path:
onnx.save(onnx_model, save_onnx_graph_path)
return onnx_model

def _test_model_zoo_from_config_path(
self,
Expand Down Expand Up @@ -171,9 +229,7 @@ def _test_model_from_config_path(
point_rend.add_pointrend_config(cfg)
cfg.merge_from_file(config_path)
cfg.freeze()

model = build_model(cfg)

image = get_sample_coco_image()
inputs = tuple(image.clone() for _ in range(batch))
return self._test_model(
Expand Down

0 comments on commit 1a56083

Please sign in to comment.