Skip to content

Commit

Permalink
"set_inplace_requires_grad_allowed" should be a context manager (#870)
Browse files Browse the repository at this point in the history
Test Plan:
- run existing tests; code reading
  • Loading branch information
zou3519 authored Jun 13, 2022
1 parent 5d9e50b commit 056ff1f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
17 changes: 12 additions & 5 deletions functorch/_src/eager_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,27 @@
_assert_wrapped_functional,
_propagate_functional_input_mutation,
set_inplace_requires_grad_allowed,
get_inplace_requires_grad_allowed,
)

argnums_t = Union[int, Tuple[int, ...]]


@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
prev_state = get_inplace_requires_grad_allowed()
set_inplace_requires_grad_allowed(enabled)
try:
yield
finally:
set_inplace_requires_grad_allowed(prev_state)


def _create_differentiable(inps, level=None):
def create_differentiable(x):
if isinstance(x, torch.Tensor):
try:
set_inplace_requires_grad_allowed(True)
with enable_inplace_requires_grad():
return x.requires_grad_()
finally:
set_inplace_requires_grad_allowed(False)

raise ValueError(f'Thing passed to transform API must be Tensor, '
f'got {type(x)}')
return tree_map(create_differentiable, inps)
Expand Down
5 changes: 5 additions & 0 deletions functorch/csrc/DynamicLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ void setInplaceRequiresGradAllowed(bool allowed) {
functorch_tls->allow_inplace_requires_grad_ = allowed;
}

bool getInplaceRequiresGradAllowed() {
auto* functorch_tls = getRawFunctorchTLS();
return functorch_tls->allow_inplace_requires_grad_;
}


static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
return getRawFunctorchTLS()->dynamicLayerStack;
Expand Down
1 change: 1 addition & 0 deletions functorch/csrc/DynamicLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

void setInplaceRequiresGradAllowed(bool allowed);
bool getInplaceRequiresGradAllowed();


}
Expand Down
1 change: 1 addition & 0 deletions functorch/csrc/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
m.def("get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed);
m.def("dlevel", &at::functorch::dlevel, "dlevel");
m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
Expand Down

0 comments on commit 056ff1f

Please sign in to comment.