Skip to content

Commit

Permalink
Add ONNX support to KRCNNConvDeconvUpsampleHead
Browse files Browse the repository at this point in the history
The fix involves changes to both PyTorch and detectron2:
	* Pytorch had a bug which prevented some tensors to be identified as float
	  (refer to pytorch/pytorch#81386)

	* `detectron2/structures/keypoints.py::heatmaps_to_keypoints` internally
	  does advanced indexing on a `squeeze`d `roi_map`.
	  The aforentioned `squeeze` fails rank inference due to the
	  presence of `onnx::If` on its composition to support dynamic
	  dims. By replacing `squeeze` by `replace` on
	  detectron's `heatmaps_to_keypoints`, shape inference succeeds,
	  allowing ONNX export to succeed with dynamic axes support.
  • Loading branch information
Thiago Crepaldi committed Jul 13, 2022
1 parent 453d795 commit d4d3ec2
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 7 deletions.
12 changes: 5 additions & 7 deletions detectron2/structures/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso
we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
"""
# The decorator use of torch.no_grad() was not supported by torchscript.
# https://github.com/pytorch/pytorch/issues/44768
maps = maps.detach()
rois = rois.detach()

offset_x = rois[:, 0]
offset_y = rois[:, 1]
Expand All @@ -204,9 +200,11 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso
outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
roi_map = F.interpolate(
maps[[i]], size=outsize, mode="bicubic", align_corners=False
).squeeze(
0
) # #keypoints x H x W
)

# Although semantically equivalent, `reshape` is used instead of `squeeze` due
# to limitation during ONNX export of `squeeze` in scripting mode
roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W

# softmax over the spatial region
max_score, _ = roi_map.view(num_keypoints, -1).max(1)
Expand Down
208 changes: 208 additions & 0 deletions tests/test_export_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import unittest
import torch
from torch.hub import _check_module_exists

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.export.flatten import TracingAdapter
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.layers import ShapeSpec
from detectron2.modeling import build_model
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import get_sample_coco_image, random_boxes

# TODO: replace with `from detectron2.export import STABLE_ONNX_OPSET_VERSION`
# after https://github.com/facebookresearch/detectron2/pull/4291 is merged
STABLE_ONNX_OPSET_VERSION = 11


def has_dynamic_axes(onnx_model):
"""
Return True when all input/output only have dynamic axes in all ranks
"""
return all(
not dim.dim_param.isnumeric()
for inp in onnx_model.graph.input
for dim in inp.type.tensor_type.shape.dim
) and all(
not dim.dim_param.isnumeric()
for out in onnx_model.graph.output
for dim in out.type.tensor_type.shape.dim
)


# TODO: Remove after https://github.com/facebookresearch/detectron2/pull/4291 is merged
# Import it from detectron2.utils.testing instead
def min_torch_version(min_version: str) -> bool:
"""
Returns True when torch's version is at least `min_version`.
"""
from packaging import version

try:
import torch
except ImportError:
return False

installed_version = version.parse(torch.__version__.split("+")[0])
min_version = version.parse(min_version)
return installed_version >= min_version


# TODO: Remove after https://github.com/facebookresearch/detectron2/pull/4291 is merged
# Import it from detectron2.utils.testing instead
def skipIfUnsupportedMinTorchVersion(min_version):
"""
Skips tests for PyTorch versions older than min_version.
"""

def skip_dec(func):
def wrapper(self):
if not min_torch_version(min_version):
raise unittest.SkipTest(
f"module 'torch' has __version__ {torch.__version__}"
f", required is: {min_version}"
)
return func(self)

return wrapper

return skip_dec


@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.")
class TestONNXTracingExport(unittest.TestCase):
opset_version = STABLE_ONNX_OPSET_VERSION

def testKeypointHead(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = KRCNNConvDeconvUpsampleHead(
ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,)
)

def forward(self, x, predbox1, predbox2):
inst = [
Instances((100, 100), pred_boxes=Boxes(predbox1)),
Instances((100, 100), pred_boxes=Boxes(predbox2)),
]
ret = self.model(x, inst)
return tuple(x.pred_keypoints for x in ret)

model = M()
model.eval()

def gen_input(num1, num2):
feat = torch.randn((num1 + num2, 4, 14, 14))
box1 = random_boxes(num1)
box2 = random_boxes(num2)
return feat, box1, box2

with patch_builtin_len():
onnx_model = self._test_model(
model,
gen_input(1, 2),
opset_version=STABLE_ONNX_OPSET_VERSION,
input_names=["features", "pred_boxes", "pred_classes"],
output_names=["box1", "box2"],
dynamic_axes={
"features": {0: "batch", 1: "static_four", 2: "height", 3: "width"},
"pred_boxes": {0: "batch", 1: "static_four"},
"pred_classes": {0: "batch", 1: "static_four"},
"box1": {0: "num_instance", 1: "K", 2: "static_three"},
"box2": {0: "num_instance", 1: "K", 2: "static_three"},
},
save_onnx_graph_path='/home/thiagofc/_d2.onnx',
)

# Although ONNX models are not executable by PyTorch to verify
# support of batches with different sizes, we can verify model's IR
# does not hard-code input and/or output shapes.
# TODO: Add tests with different batch sizes when detectron2's CI
# support ONNX Runtime backend.
assert has_dynamic_axes(onnx_model)

################################################################################
# Testcase internals - DO NOT add tests below this point
################################################################################
def _test_model(
self,
model,
inputs,
inference_func=None,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
# Not imported in the beginning of file to prevent runtime errors
# for environments without ONNX.
# This testcase checks dependencies before running
import onnx # isort:skip

f = io.BytesIO()
adapter_model = TracingAdapter(model, inputs, inference_func)
adapter_model.eval()
with torch.no_grad():
try:
torch.onnx.enable_log()
except AttributeError:
# Older ONNX versions does not have this API
pass
torch.onnx.export(
adapter_model,
adapter_model.flattened_inputs,
f,
training=torch.onnx.TrainingMode.EVAL,
opset_version=opset_version,
verbose=True,
**export_kwargs,
)
onnx_model = onnx.load_from_string(f.getvalue())
assert onnx_model is not None
if save_onnx_graph_path:
onnx.save(onnx_model, save_onnx_graph_path)

return onnx_model

def _test_model_zoo_from_config_path(
self,
config_path,
inference_func,
batch=1,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
model = model_zoo.get(config_path, trained=True)
image = get_sample_coco_image()
inputs = tuple(image.clone() for _ in range(batch))
return self._test_model(
model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
)

def _test_model_from_config_path(
self,
config_path,
inference_func,
batch=1,
opset_version=STABLE_ONNX_OPSET_VERSION,
save_onnx_graph_path=None,
**export_kwargs,
):
from projects.PointRend import point_rend # isort:skip

cfg = get_cfg()
cfg.DATALOADER.NUM_WORKERS = 0
point_rend.add_pointrend_config(cfg)
cfg.merge_from_file(config_path)
cfg.freeze()
model = build_model(cfg)
image = get_sample_coco_image()
inputs = tuple(image.clone() for _ in range(batch))
return self._test_model(
model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
)

0 comments on commit d4d3ec2

Please sign in to comment.