Add autograd_inlining flag to torch.onnx.export (Fix pytorch#88286)
Thiago Crepaldi committed Jun 29, 2023
1 parent 67babf7 commit 3d3da10
Showing 4 changed files with 35 additions and 5 deletions.
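
For context, a minimal usage sketch of the flag this commit adds to torch.onnx.export. The model, inputs, and output filename below are hypothetical and only illustrate the new keyword argument; autograd_inlining defaults to True, which keeps the previous export behavior.

import torch


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


model = Doubler()
dummy_input = torch.randn(1, 3)

# Pass autograd_inlining=False to ask the exporter not to inline
# autograd functions (prim::PythonOp nodes) while building the ONNX graph.
torch.onnx.export(
    model,
    (dummy_input,),
    "doubler.onnx",
    autograd_inlining=False,
)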
7 changes: 6 additions & 1 deletion torch/csrc/autograd/python_function.cpp
@@ -747,9 +747,14 @@ static void _trace_post_record(
       }
     }
   }
+  py::object onnx_globals = py::module::import("torch.onnx._globals");
   py::bool_ is_in_onnx_export =
       py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
-  if (py::cast<bool>(is_in_onnx_export)) {
+  py::bool_ is_autograd_inlining_enabled =
+      py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
+
+  if (py::cast<bool>(is_in_onnx_export) &&
+      py::cast<bool>(is_autograd_inlining_enabled)) {
     _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
   }
9 changes: 6 additions & 3 deletions torch/csrc/jit/passes/onnx.cpp
@@ -478,15 +478,18 @@ void NodeToONNX(
           onnx_registration.attr("registry")
               .attr("is_registered_op")("prim::PythonOp", opset_version)
               .cast<bool>();
+      py::bool_ is_autograd_inlining_enabled =
+          py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
       if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
         // Inline the subgraph within the prim::PythonOp unless
         // either of these conditions are satisfied
         // 1. The torch.autograd.Function class of this node object has `symbolic`
         // method defined.
         // 2. Custom export symbolic is registered for prim::PythonOp.
-        if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
-            operator_export_type ==
-                ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
+        if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
+             operator_export_type ==
+                 ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
+            (py::cast<bool>(is_autograd_inlining_enabled))) {
           try {
             inlineAutograd(op);
           } catch (const std::exception& ex) {
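
The comment in the hunk above spells out when inlining of a prim::PythonOp is skipped: the torch.autograd.Function behind the node defines a `symbolic` method, or a custom symbolic is registered for prim::PythonOp. A minimal sketch of the first case, illustrative only and not part of this commit:

import torch


class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).type_as(grad_output)

    @staticmethod
    def symbolic(g, x):
        # Because this method exists, the exporter emits a single ONNX Relu
        # node for this op instead of inlining its traced subgraph.
        return g.op("Relu", x)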
12 changes: 12 additions & 0 deletions torch/onnx/_globals.py
@@ -29,6 +29,7 @@ def __init__(self):
             _C_onnx.OperatorExportTypes.ONNX
         )
         self.onnx_shape_inference: bool = True
+        self._autograd_inlining: bool = True

     @property
     def training_mode(self):
@@ -69,5 +70,16 @@ def in_onnx_export(self, value: bool):
             raise TypeError("in_onnx_export must be a boolean")
         self._in_onnx_export = value

+    @property
+    def autograd_inlining(self) -> bool:
+        """Whether Autograd must be inlined."""
+        return self._autograd_inlining
+
+    @autograd_inlining.setter
+    def autograd_inlining(self, value: bool):
+        if type(value) is not bool:
+            raise TypeError("autograd_inlining must be a boolean")
+        self._autograd_inlining = value
+

 GLOBALS = _InternalGlobals()
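
A short sketch of how the new internal property behaves. GLOBALS lives in the private torch.onnx._globals module and is not public API; this only illustrates the getter and setter added above:

from torch.onnx._globals import GLOBALS

print(GLOBALS.autograd_inlining)    # True by default

GLOBALS.autograd_inlining = False   # the setter accepts plain bools
try:
    GLOBALS.autograd_inlining = 1   # anything else raises TypeError
except TypeError as err:
    print(err)                      # autograd_inlining must be a boolean
finally:
    GLOBALS.autograd_inlining = True  # restore the default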
12 changes: 11 additions & 1 deletion torch/onnx/utils.py
@@ -205,6 +205,7 @@ def export(
     keep_initializers_as_inputs: Optional[bool] = None,
     custom_opsets: Optional[Mapping[str, int]] = None,
     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
+    autograd_inlining: Optional[bool] = True,
 ) -> None:
     r"""Exports a model into ONNX format.
@@ -496,6 +497,9 @@ def forward(self, x):
             * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
               only if the type of the ``nn.Module`` is found in the set.
+        autograd_inlining (bool, default True): Flag used to control whether to inline autograd functions.
+            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
+
     Raises:
         :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
         :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@@ -520,6 +524,7 @@ def forward(self, x):
         keep_initializers_as_inputs=keep_initializers_as_inputs,
         custom_opsets=custom_opsets,
         export_modules_as_functions=export_modules_as_functions,
+        autograd_inlining=autograd_inlining,
     )


@@ -581,7 +586,8 @@ def _optimize_graph(
     # Remove fork/wait nodes
     _C._jit_pass_inline_fork_wait(graph)
     _C._jit_pass_lint(graph)
-    _C._jit_pass_onnx_autograd_function_process(graph)
+    if GLOBALS.autograd_inlining:
+        _C._jit_pass_onnx_autograd_function_process(graph)
     _C._jit_pass_lower_all_tuples(graph)

     # we now record some ops like ones/zeros
@@ -1483,6 +1489,7 @@ def _export(
     add_node_names=True,
     onnx_shape_inference=True,
     export_modules_as_functions=False,
+    autograd_inlining=True,
 ):
     assert GLOBALS.in_onnx_export is False

@@ -1534,6 +1541,8 @@

     try:
         GLOBALS.in_onnx_export = True
+        _autograd_inlining_previous = GLOBALS.autograd_inlining
+        GLOBALS.autograd_inlining = autograd_inlining

         module_typenames_to_export_as_functions: Set[str] = set()
         if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
@@ -1658,6 +1667,7 @@ def _export(
     finally:
         assert GLOBALS.in_onnx_export
         GLOBALS.in_onnx_export = False
+        GLOBALS.autograd_inlining = _autograd_inlining_previous
         _reset_trace_module_map()

     return torch_out
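
The _export changes above follow a save-and-restore pattern, so the global flag applies only for the duration of one export call and is reset even when the export fails. A simplified, self-contained sketch of that pattern with a stand-in globals object (the names here are hypothetical):

class _Globals:
    def __init__(self) -> None:
        self.autograd_inlining: bool = True


GLOBALS = _Globals()


def export_sketch(autograd_inlining: bool = True) -> None:
    previous = GLOBALS.autograd_inlining
    GLOBALS.autograd_inlining = autograd_inlining
    try:
        # ... run graph capture and ONNX serialization here ...
        pass
    finally:
        # Restore the prior value even if the export raised.
        GLOBALS.autograd_inlining = previous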
