From 09cc52399de35ce7b469b7bf97d74d425ee2b4c5 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jun 2022 17:39:01 -0400
Subject: [PATCH] "set_inplace_requires_grad_allowed" should be a context
 manager (#870)

Test Plan:
- run existing tests; code reading
---
 functorch/_src/eager_transforms.py | 17 ++++++++++++-----
 functorch/csrc/DynamicLayer.cpp    |  5 +++++
 functorch/csrc/DynamicLayer.h      |  1 +
 functorch/csrc/init.cpp            |  1 +
 4 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py
index cd3959086..4675f3639 100644
--- a/functorch/_src/eager_transforms.py
+++ b/functorch/_src/eager_transforms.py
@@ -33,20 +33,27 @@
     _assert_wrapped_functional,
     _propagate_functional_input_mutation,
     set_inplace_requires_grad_allowed,
+    get_inplace_requires_grad_allowed,
 )
 
 argnums_t = Union[int, Tuple[int, ...]]
 
 
+@contextlib.contextmanager
+def enable_inplace_requires_grad(enabled=True):
+    prev_state = get_inplace_requires_grad_allowed()
+    set_inplace_requires_grad_allowed(enabled)
+    try:
+        yield
+    finally:
+        set_inplace_requires_grad_allowed(prev_state)
+
+
 def _create_differentiable(inps, level=None):
     def create_differentiable(x):
         if isinstance(x, torch.Tensor):
-            try:
-                set_inplace_requires_grad_allowed(True)
+            with enable_inplace_requires_grad():
                 return x.requires_grad_()
-            finally:
-                set_inplace_requires_grad_allowed(False)
-
         raise ValueError(f'Thing passed to transform API must be Tensor, '
                          f'got {type(x)}')
     return tree_map(create_differentiable, inps)
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index 2812c85f6..d1d24b3fb 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -140,6 +140,11 @@ void setInplaceRequiresGradAllowed(bool allowed) {
   functorch_tls->allow_inplace_requires_grad_ = allowed;
 }
 
+bool getInplaceRequiresGradAllowed() {
+  auto* functorch_tls = getRawFunctorchTLS();
+  return functorch_tls->allow_inplace_requires_grad_;
+}
+
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
   return getRawFunctorchTLS()->dynamicLayerStack;
 }
diff --git a/functorch/csrc/DynamicLayer.h b/functorch/csrc/DynamicLayer.h
index fe912980c..7d5b5f4a9 100644
--- a/functorch/csrc/DynamicLayer.h
+++ b/functorch/csrc/DynamicLayer.h
@@ -86,6 +86,7 @@ std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
 void setInplaceRequiresGradAllowed(bool allowed);
+bool getInplaceRequiresGradAllowed();
 
 }
diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp
index b0699ce7c..3c60db9a0 100644
--- a/functorch/csrc/init.cpp
+++ b/functorch/csrc/init.cpp
@@ -380,6 +380,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
   m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
   m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
+  m.def("get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed);
   m.def("dlevel", &at::functorch::dlevel, "dlevel");
   m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
   m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
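
Note on the pattern: the replaced try/finally unconditionally reset the flag
to False on exit, which would clobber the flag for any enclosing caller that
had already enabled it; the context manager instead restores whatever value
it observed on entry, so nested uses compose. Below is a minimal
self-contained sketch of the same save/set/restore pattern, using a plain
module-level Python flag as a hypothetical stand-in for functorch's C++
thread-local state (the names _flag, get_flag, set_flag, and enable_flag are
illustrative only, not part of functorch):

    import contextlib

    _flag = False  # hypothetical stand-in for the C++ thread-local flag

    def get_flag():
        return _flag

    def set_flag(value):
        global _flag
        _flag = value

    @contextlib.contextmanager
    def enable_flag(enabled=True):
        # Save the state observed on entry, apply the requested state,
        # and restore the saved state on exit -- even if the body raises.
        prev = get_flag()
        set_flag(enabled)
        try:
            yield
        finally:
            set_flag(prev)

    # Nested uses compose: each level restores the value it observed.
    with enable_flag():
        assert get_flag()
        with enable_flag(False):
            assert not get_flag()
        assert get_flag()       # inner exit restored True
    assert not get_flag()       # outer exit restored False

The try/finally inside the generator mirrors enable_inplace_requires_grad in
the patch above: the restore runs even when the with-body raises.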