Skip to content

Commit

Permalink
Remove unnecessary constant outputs from ONNX exported graph
Browse files Browse the repository at this point in the history
`TracingAdapter` creates extra outputs (through `flatten_to_tuple`) to hold
metadata information to rebuild the original data format during
deserialization.

When exporting a PyTorch model to ONNX, the support to de-serialize the
output to the original formatThis is unnecessary during ONNX export as the original data will never
be reconstructed to its original format using Schema.__call__ API.
This PR suppresses such extra output constants during
torch.onnx.export() execution. Outside this API, the behavior is not
changed, ensuring BC.

Although not stricly necessary to achieve the same numerical results as
PyTorch, when a ONNX model schema is compared to PyTorch's, the diffrent
number of outputs (ONNX model will have more outputs than PyTorch) may
not only confuse users, but also result in false negative when coding
model comparison helpers.
  • Loading branch information
Thiago Crepaldi committed Jun 3, 2022
1 parent 523c402 commit 4ba9841
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
8 changes: 7 additions & 1 deletion detectron2/export/c10.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ def has(self, name):
return name in self.batch_extra_fields

def set(self, name, value):
data_len = len(value)
# len(tensor) leads to constants during tracing mode
if isinstance(value, Boxes):
data_len = value.tensor.shape[0]
elif isinstance(value, torch.Tensor):
data_len = value.shape[0]
else:
data_len = len(value)
if len(self.batch_extra_fields):
assert (
len(self) == data_len
Expand Down
5 changes: 0 additions & 5 deletions detectron2/export/caffe2_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,6 @@ def _check_eval(module):
)
onnx_model = onnx.load_from_string(f.getvalue())

# Apply ONNX's Optimization
all_passes = onnx.optimizer.get_available_passes()
passes = ["fuse_bn_into_conv"]
assert all(p in all_passes for p in passes)
onnx_model = onnx.optimizer.optimize(onnx_model, passes)
return onnx_model


Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ and then export the model into Caffe2, TorchScript or ONNX format.
The converted model is able to run in either Python or C++ without detectron2/torchvision dependency, on CPU or GPUs.
It has a runtime optimized for CPU & mobile inference, but not optimized for GPU inference.

This feature requires 1.9 > ONNX ≥ 1.6.
This feature requires ONNX ≥ 1.6.

### Coverage

Expand Down
13 changes: 10 additions & 3 deletions tests/export/test_c10.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest

from detectron2.config import get_cfg
from detectron2.export.c10 import Caffe2RPN
from detectron2.layers import ShapeSpec
try:
# Caffe2 used to be included in PyTorch, but since PyTorch 1.10+,
# it is not included in pre-built packages. This is a safety BC check
from detectron2.config import get_cfg
from detectron2.export.c10 import Caffe2RPN
from detectron2.layers import ShapeSpec
except ImportError:
raise unittest.SkipTest(
f"PyTorch does not have Caffe2 support. Skipping all tests in {__name__}"
)


class TestCaffe2RPN(unittest.TestCase):
Expand Down
12 changes: 11 additions & 1 deletion tests/test_export_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,26 @@
import tempfile
import unittest
import torch
from torch.hub import _check_module_exists

from detectron2 import model_zoo
from detectron2.export import Caffe2Model, Caffe2Tracer
from detectron2.utils.logger import setup_logger
from detectron2.utils.testing import get_sample_coco_image

try:
# Caffe2 used to be included in PyTorch, but since PyTorch 1.10+,
# Caffe2 is not included in pre-built packages. This is a safety BC check
from detectron2.export import Caffe2Model, Caffe2Tracer
except ImportError:
raise unittest.SkipTest(
f"PyTorch does not have Caffe2 support. Skipping all tests in {__name__}"
)


# TODO: this test requires manifold access, see: T88318502
# Running it on CircleCI causes crash, not sure why.
@unittest.skipIf(os.environ.get("CIRCLECI"), "Caffe2 tests crash on CircleCI.")
@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.")
class TestCaffe2Export(unittest.TestCase):
def setUp(self):
setup_logger()
Expand Down

0 comments on commit 4ba9841

Please sign in to comment.