From 3d3da109e3afa617c513e78aa999f5a1f44ffbce Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Thu, 29 Jun 2023 19:21:31 +0000
Subject: [PATCH] Add autograd_inlining flag to torch.onnx.export (Fix #88286)

---
 torch/csrc/autograd/python_function.cpp |  7 ++++++-
 torch/csrc/jit/passes/onnx.cpp          |  9 ++++++---
 torch/onnx/_globals.py                  | 12 ++++++++++++
 torch/onnx/utils.py                     | 12 +++++++++++-
 4 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index a9da6d328cb0c..2c1ccf7f31528 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -747,9 +747,14 @@ static void _trace_post_record(
       }
     }
   }
+  py::object onnx_globals = py::module::import("torch.onnx._globals");
   py::bool_ is_in_onnx_export =
       py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
-  if (py::cast<bool>(is_in_onnx_export)) {
+  py::bool_ is_autograd_inlining_enabled =
+      py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
+
+  if (py::cast<bool>(is_in_onnx_export) &&
+      py::cast<bool>(is_autograd_inlining_enabled)) {
     _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
   }

diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 0caf734b0a6fa..630a20a4d265e 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -478,15 +478,18 @@ void NodeToONNX(
         onnx_registration.attr("registry")
             .attr("is_registered_op")("prim::PythonOp", opset_version)
             .cast<bool>();
+    py::bool_ is_autograd_inlining_enabled =
+        py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
     if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
       // Inline the subgraph within the prim::PythonOp unless
       // either of these conditions are satisfied
       // 1. The torch.autograd.Function class of this node object has `symbolic`
      // method defined.
      // 2. Custom export symbolic is registered for prim::PythonOp.
-      if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
-          operator_export_type ==
-              ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
+      if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
+           operator_export_type ==
+               ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
+          (py::cast<bool>(is_autograd_inlining_enabled))) {
         try {
           inlineAutograd(op);
         } catch (const std::exception& ex) {
diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py
index 4f2e54cf6962e..f827d12be7fbf 100644
--- a/torch/onnx/_globals.py
+++ b/torch/onnx/_globals.py
@@ -29,6 +29,7 @@ def __init__(self):
             _C_onnx.OperatorExportTypes.ONNX
         )
         self.onnx_shape_inference: bool = True
+        self._autograd_inlining: bool = True

     @property
     def training_mode(self):
@@ -69,5 +70,16 @@ def in_onnx_export(self, value: bool):
             raise TypeError("in_onnx_export must be a boolean")
         self._in_onnx_export = value

+    @property
+    def autograd_inlining(self) -> bool:
+        """Whether Autograd must be inlined."""
+        return self._autograd_inlining
+
+    @autograd_inlining.setter
+    def autograd_inlining(self, value: bool):
+        if type(value) is not bool:
+            raise TypeError("autograd_inlining must be a boolean")
+        self._autograd_inlining = value
+

 GLOBALS = _InternalGlobals()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 9863fa6613115..c88aaf437ddc1 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -205,6 +205,7 @@ def export(
     keep_initializers_as_inputs: Optional[bool] = None,
     custom_opsets: Optional[Mapping[str, int]] = None,
     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
+    autograd_inlining: Optional[bool] = True,
 ) -> None:
     r"""Exports a model into ONNX format.

@@ -496,6 +497,9 @@ def forward(self, x):
                 * Set of type of nn.Module: export ``nn.Module`` forward calls as local
                   function nodes, only if the type of the ``nn.Module`` is found in the set.
+        autograd_inlining (bool, default True): Flag used to control whether to inline autograd functions.
+            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
+

     Raises:
         :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
         :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
@@ -520,6 +524,7 @@ def forward(self, x):
         keep_initializers_as_inputs=keep_initializers_as_inputs,
         custom_opsets=custom_opsets,
         export_modules_as_functions=export_modules_as_functions,
+        autograd_inlining=autograd_inlining,
     )


@@ -581,7 +586,8 @@ def _optimize_graph(
     # Remove fork/wait nodes
     _C._jit_pass_inline_fork_wait(graph)
     _C._jit_pass_lint(graph)
-    _C._jit_pass_onnx_autograd_function_process(graph)
+    if GLOBALS.autograd_inlining:
+        _C._jit_pass_onnx_autograd_function_process(graph)
     _C._jit_pass_lower_all_tuples(graph)

     # we now record some ops like ones/zeros
@@ -1483,6 +1489,7 @@ def _export(
     add_node_names=True,
     onnx_shape_inference=True,
     export_modules_as_functions=False,
+    autograd_inlining=True,
 ):
     assert GLOBALS.in_onnx_export is False

@@ -1534,6 +1541,8 @@ def _export(

     try:
         GLOBALS.in_onnx_export = True
+        _autograd_inlining_previous = GLOBALS.autograd_inlining
+        GLOBALS.autograd_inlining = autograd_inlining

         module_typenames_to_export_as_functions: Set[str] = set()
         if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
@@ -1658,6 +1667,7 @@ def _export(
     finally:
         assert GLOBALS.in_onnx_export
         GLOBALS.in_onnx_export = False
+        GLOBALS.autograd_inlining = _autograd_inlining_previous
         _reset_trace_module_map()

     return torch_out
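
For reference, a minimal usage sketch of the new flag (not part of the patch): the model, class names, and output file names below are made up for illustration; only torch.onnx.export and its autograd_inlining argument come from the change above. The flag only matters when the traced graph contains prim::PythonOp nodes, i.e. calls into a torch.autograd.Function without a registered symbolic.

# Usage sketch (illustrative only, not part of the patch). "ScaledReLU" and the
# output file names are hypothetical.
import torch


class ScaledReLU(torch.autograd.Function):
    # A custom autograd function with no `symbolic` method, so tracing it
    # produces a prim::PythonOp node that the exporter may inline.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 2 * x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * grad_output * (x > 0).type_as(grad_output)


class Model(torch.nn.Module):
    def forward(self, x):
        return ScaledReLU.apply(x)


model = Model()
dummy = torch.randn(2, 3)

# Default: autograd_inlining=True, the prim::PythonOp subgraph is inlined into
# plain ONNX ops before the graph is serialized.
torch.onnx.export(model, (dummy,), "scaled_relu_inlined.onnx", opset_version=17)

# Disabling inlining keeps the prim::PythonOp node un-inlined. Whether the
# resulting graph can still be serialized depends on operator_export_type and
# on any symbolic registered for prim::PythonOp, so the call is guarded here.
try:
    torch.onnx.export(
        model,
        (dummy,),
        "scaled_relu_not_inlined.onnx",
        opset_version=17,
        autograd_inlining=False,
    )
except Exception as e:
    print(f"Export without autograd inlining failed: {e}")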