Add ONNX support to KRCNNConvDeconvUpsampleHead
The fix involves changes to both PyTorch and detectron2:
* PyTorch had a bug that prevented some tensors from being identified as float (see pytorch/pytorch#81386).
* `detectron2/structures/keypoints.py::heatmaps_to_keypoints` internally does advanced indexing on a `squeeze`d `roi_map`. That `squeeze` fails rank inference because its ONNX decomposition uses `onnx::If` (to support dynamic dims), which prevents the output rank from being inferred statically. Replacing the `squeeze` with an equivalent `reshape` in detectron2's `heatmaps_to_keypoints` lets shape inference succeed, so the ONNX export succeeds with dynamic axes support (see the sketch below).
Thiago Crepaldi committed on Jul 13, 2022
1 parent 453d795 · commit d4d3ec2
Showing 2 changed files with 213 additions and 7 deletions.
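Only the new test file is reproduced below; the companion change to `detectron2/structures/keypoints.py` is not included in this view. As a rough, hedged sketch of the substitution described above (the helper name `_upsample_roi_map` and the `N x K x H x W` heatmap layout are assumptions for illustration, not the actual diff):

import torch.nn.functional as F


def _upsample_roi_map(maps, i, outsize):
    # `maps` is assumed to be an N x K x H x W tensor of keypoint heatmaps;
    # `maps[[i]]` keeps the leading dim, giving a 1 x K x H x W tensor.
    roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False)
    # Previously: roi_map = roi_map.squeeze(0)
    # `squeeze` exports with an `onnx::If` (the dim may or may not be 1), which breaks
    # rank/shape inference for the advanced indexing that follows. Since the leading
    # dim is always 1 here, a `reshape` to the trailing shape is equivalent and keeps
    # shape inference working under dynamic axes.
    roi_map = roi_map.reshape(roi_map.shape[1:])  # K x H x W
    return roi_map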
@@ -0,0 +1,208 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import unittest
import torch
from torch.hub import _check_module_exists

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.export.flatten import TracingAdapter
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.layers import ShapeSpec
from detectron2.modeling import build_model
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import get_sample_coco_image, random_boxes

# TODO: replace with `from detectron2.export import STABLE_ONNX_OPSET_VERSION`
# after https://github.com/facebookresearch/detectron2/pull/4291 is merged
STABLE_ONNX_OPSET_VERSION = 11

def has_dynamic_axes(onnx_model):
    """
    Return True when all model inputs/outputs have only dynamic axes in all ranks
    """
    return all(
        not dim.dim_param.isnumeric()
        for inp in onnx_model.graph.input
        for dim in inp.type.tensor_type.shape.dim
    ) and all(
        not dim.dim_param.isnumeric()
        for out in onnx_model.graph.output
        for dim in out.type.tensor_type.shape.dim
    )

# TODO: Remove after https://github.com/facebookresearch/detectron2/pull/4291 is merged
# Import it from detectron2.utils.testing instead
def min_torch_version(min_version: str) -> bool:
    """
    Returns True when torch's version is at least `min_version`.
    """
    from packaging import version

    try:
        import torch
    except ImportError:
        return False

    installed_version = version.parse(torch.__version__.split("+")[0])
    min_version = version.parse(min_version)
    return installed_version >= min_version


# TODO: Remove after https://github.com/facebookresearch/detectron2/pull/4291 is merged
# Import it from detectron2.utils.testing instead
def skipIfUnsupportedMinTorchVersion(min_version):
    """
    Skips tests for PyTorch versions older than min_version.
    """

    def skip_dec(func):
        def wrapper(self):
            if not min_torch_version(min_version):
                raise unittest.SkipTest(
                    f"module 'torch' has __version__ {torch.__version__}"
                    f", required is: {min_version}"
                )
            return func(self)

        return wrapper

    return skip_dec

@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.")
class TestONNXTracingExport(unittest.TestCase):
    opset_version = STABLE_ONNX_OPSET_VERSION

    def testKeypointHead(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = KRCNNConvDeconvUpsampleHead(
                    ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,)
                )

            def forward(self, x, predbox1, predbox2):
                inst = [
                    Instances((100, 100), pred_boxes=Boxes(predbox1)),
                    Instances((100, 100), pred_boxes=Boxes(predbox2)),
                ]
                ret = self.model(x, inst)
                return tuple(x.pred_keypoints for x in ret)

        model = M()
        model.eval()

        def gen_input(num1, num2):
            feat = torch.randn((num1 + num2, 4, 14, 14))
            box1 = random_boxes(num1)
            box2 = random_boxes(num2)
            return feat, box1, box2

        with patch_builtin_len():
            onnx_model = self._test_model(
                model,
                gen_input(1, 2),
                opset_version=STABLE_ONNX_OPSET_VERSION,
                input_names=["features", "pred_boxes", "pred_classes"],
                output_names=["box1", "box2"],
                dynamic_axes={
                    "features": {0: "batch", 1: "static_four", 2: "height", 3: "width"},
                    "pred_boxes": {0: "batch", 1: "static_four"},
                    "pred_classes": {0: "batch", 1: "static_four"},
                    "box1": {0: "num_instance", 1: "K", 2: "static_three"},
                    "box2": {0: "num_instance", 1: "K", 2: "static_three"},
                },
            )

            # Although PyTorch cannot execute ONNX models to verify support for
            # batches with different sizes, we can verify that the model's IR
            # does not hard-code input and/or output shapes.
            # TODO: Add tests with different batch sizes when detectron2's CI
            # supports an ONNX Runtime backend.
            assert has_dynamic_axes(onnx_model)

    ################################################################################
    # Testcase internals - DO NOT add tests below this point
    ################################################################################
    def _test_model(
        self,
        model,
        inputs,
        inference_func=None,
        opset_version=STABLE_ONNX_OPSET_VERSION,
        save_onnx_graph_path=None,
        **export_kwargs,
    ):
        # Not imported at the top of the file to avoid runtime errors in
        # environments without ONNX; this test case checks its dependencies
        # before running.
        import onnx  # isort:skip

        f = io.BytesIO()
        adapter_model = TracingAdapter(model, inputs, inference_func)
        adapter_model.eval()
        with torch.no_grad():
            try:
                torch.onnx.enable_log()
            except AttributeError:
                # Older PyTorch versions do not have this API
                pass
            torch.onnx.export(
                adapter_model,
                adapter_model.flattened_inputs,
                f,
                training=torch.onnx.TrainingMode.EVAL,
                opset_version=opset_version,
                verbose=True,
                **export_kwargs,
            )
        onnx_model = onnx.load_from_string(f.getvalue())
        assert onnx_model is not None
        if save_onnx_graph_path:
            onnx.save(onnx_model, save_onnx_graph_path)

        return onnx_model

    def _test_model_zoo_from_config_path(
        self,
        config_path,
        inference_func,
        batch=1,
        opset_version=STABLE_ONNX_OPSET_VERSION,
        save_onnx_graph_path=None,
        **export_kwargs,
    ):
        model = model_zoo.get(config_path, trained=True)
        image = get_sample_coco_image()
        inputs = tuple(image.clone() for _ in range(batch))
        return self._test_model(
            model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
        )

    def _test_model_from_config_path(
        self,
        config_path,
        inference_func,
        batch=1,
        opset_version=STABLE_ONNX_OPSET_VERSION,
        save_onnx_graph_path=None,
        **export_kwargs,
    ):
        from projects.PointRend import point_rend  # isort:skip

        cfg = get_cfg()
        cfg.DATALOADER.NUM_WORKERS = 0
        point_rend.add_pointrend_config(cfg)
        cfg.merge_from_file(config_path)
        cfg.freeze()
        model = build_model(cfg)
        image = get_sample_coco_image()
        inputs = tuple(image.clone() for _ in range(batch))
        return self._test_model(
            model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs
        )
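As a side note, the `has_dynamic_axes` check above is self-contained, so it can be reused outside the test harness to sanity-check any exported graph. A minimal sketch, assuming you copy the helper alongside this snippet (the file name `exported_model.onnx` is only a placeholder):

# Reuse (or copy) the has_dynamic_axes helper defined in the test above to verify
# that an exported graph keeps every input/output dimension symbolic.
import onnx

onnx_model = onnx.load("exported_model.onnx")  # placeholder path to an exported graph
if not has_dynamic_axes(onnx_model):
    raise RuntimeError("Exported ONNX graph hard-codes at least one input/output dim")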