diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
new file mode 100644
index 000000000..05f82f08f
--- /dev/null
+++ b/docs/source/experimental.rst
@@ -0,0 +1,12 @@
+functorch.experimental
+======================
+
+.. currentmodule:: functorch.experimental
+
+Experimental Function Transforms
+--------------------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    functionalize
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 007f56ebf..6c66a86e2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -63,6 +63,7 @@ Check out our `whirlwind tour `_ or some of our tutorials mentio
    :caption: API Reference and Notes
 
    functorch
+   experimental
    aot_autograd
 
 .. toctree::
diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py
index 4675f3639..69a61a9af 100644
--- a/functorch/_src/eager_transforms.py
+++ b/functorch/_src/eager_transforms.py
@@ -1239,6 +1239,158 @@ def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
 
 
 def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
+    """
+    functionalize is a transform that can be used to remove (intermediate)
+    mutations and aliasing from a function, while preserving the function's
+    semantics.
+
+    ``functionalize(func)`` returns a new function with the same semantics
+    as ``func``, but with all intermediate mutations removed.
+    Every in-place operation performed on an intermediate tensor:
+    ``intermediate.foo_()``
+    gets replaced by its out-of-place equivalent:
+    ``intermediate_updated = intermediate.foo()``.
+
+    functionalize is useful for shipping a PyTorch program off to
+    backends or compilers that aren't able to easily represent
+    mutations or aliasing operators.
+
+    Args:
+        func (Callable): A Python function that takes one or more arguments.
+        remove (str): An optional string argument that takes either
+            the value 'mutations' or 'mutations_and_views'.
+            If 'mutations' is passed in, then all mutating operators
+            will be replaced with their non-mutating equivalents.
+            If 'mutations_and_views' is passed in, then additionally all aliasing
+            operators will be replaced with their non-aliasing equivalents.
+            Default: 'mutations'.
+
+    Returns:
+        Returns a new "functionalized" function. It takes the same inputs as
+        :attr:`func`, and has the same behavior, but any mutations
+        (and optionally aliasing) performed on intermediate tensors
+        in the function will be removed.
+
+    functionalize will also remove mutations (and views) that were performed on function inputs.
+    However, to preserve semantics, functionalize will "fix up" the mutations after
+    the transform has finished running, by detecting if any tensor inputs "should have"
+    been mutated, and copying the new data back to the inputs if necessary.
+
+
+    Example::
+
+        >>> import torch
+        >>> from functorch import make_fx
+        >>> from functorch.experimental import functionalize
+        >>>
+        >>> # A function that uses mutations and views, but only on intermediate tensors.
+        >>> def f(a):
+        ...     b = a + 1
+        ...     c = b.view(-1)
+        ...     c.add_(1)
+        ...     return b
+        ...
+        >>> inpt = torch.randn(2)
+        >>>
+        >>> out1 = f(inpt)
+        >>> out2 = functionalize(f)(inpt)
+        >>>
+        >>> # semantics are the same (outputs are equivalent)
+        >>> print(torch.allclose(out1, out2))
+        True
+        >>>
+        >>> f_traced = make_fx(f)(inpt)
+        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
+        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
+        >>>
+        >>> print(f_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1);  a_1 = None
+            view = torch.ops.aten.view(add, [-1])
+            add_ = torch.ops.aten.add_(view, 1);  view = None
+            return add
+
+        >>> print(f_no_mutations_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1);  a_1 = None
+            view = torch.ops.aten.view(add, [-1]);  add = None
+            add_1 = torch.ops.aten.add(view, 1);  view = None
+            view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
+            return view_1
+
+        >>> print(f_no_mutations_and_views_traced.code)
+
+
+
+        def forward(self, a_1):
+            add = torch.ops.aten.add(a_1, 1);  a_1 = None
+            view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
+            add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
+            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
+            return view_copy_1
+
+
+        >>> # A function that mutates its input tensor
+        >>> def f(a):
+        ...     b = a.view(-1)
+        ...     b.add_(1)
+        ...     return a
+        ...
+        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
+        >>>
+        >>> # All mutations and views have been removed,
+        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
+        >>> # after the function has completed.
+        >>> print(f_no_mutations_and_views_traced.code)
+
+
+
+        def forward(self, a_1):
+            view_copy = torch.ops.aten.view_copy(a_1, [-1])
+            add = torch.ops.aten.add(view_copy, 1);  view_copy = None
+            view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
+            copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
+            return view_copy_1
+
+
+    There are a few "failure modes" for functionalize that are worth calling out:
+      (1) Like other functorch transforms, ``functionalize()`` doesn't work with functions
+          that directly use ``.backward()``. The same is true for ``torch.autograd.grad``.
+          If you want to use autograd, you can compute gradients directly
+          with ``functionalize(grad(f))``.
+      (2) Like other functorch transforms, ``functionalize()`` doesn't work with global state.
+          If you call ``functionalize(f)`` on a function that takes views / mutations of
+          non-local state, functionalization will simply no-op and pass the view/mutation
+          calls directly to the backend.
+          One way to work around this is to ensure that any non-local state creation
+          is wrapped into a larger function, which you then call functionalize on.
+      (3) ``resize_()`` has some limitations: functionalize will only work on programs
+          that use ``resize_()`` as long as the tensor being resized is not a view.
+      (4) ``as_strided()`` has some limitations: functionalize will not work on
+          ``as_strided()`` calls that result in tensors with overlapping memory.
+
+
+    Finally, a helpful mental model for understanding functionalization is that
+    most user PyTorch programs are written with the public torch API.
+    When executed, torch operators are generally decomposed into
+    our internal C++ "ATen" API.
+    The logic for functionalization happens entirely at the level of ATen.
+    Functionalization knows how to take every aliasing operator in ATen,
+    and map it to its non-aliasing equivalent
+    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
+    and how to take every mutating operator in ATen,
+    and map it to its non-mutating equivalent
+    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
+    while tracking aliases and mutations out-of-line to know when to fix things up.
+    Information about which ATen operators are aliasing or mutating all comes from
+    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
+    """
     if remove == 'mutations':
         reapply_views = True
     elif remove == 'mutations_and_views':
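
As a companion to failure mode (1) in the docstring above, here is a minimal usage sketch (not part of the diff itself) of the recommended ``functionalize(grad(f))`` pattern for computing gradients without calling ``.backward()``. The helper ``f`` and the input ``inpt`` below are illustrative placeholders, not part of the functorch API.

    # Sketch: computing gradients under functionalize, per failure mode (1).
    # Assumes functorch is installed; `f` and `inpt` are illustrative only.
    import torch
    from functorch import grad, make_fx
    from functorch.experimental import functionalize

    def f(a):
        b = a + 1
        c = b.view(-1)
        c.add_(1)          # mutation on an intermediate tensor
        return b.sum()     # scalar output, so grad() applies directly

    inpt = torch.randn(2)

    # grad() computes gradients without .backward(), so it composes with functionalize.
    grad_fn = functionalize(grad(f), remove='mutations_and_views')
    print(grad_fn(inpt))                  # gradient of f at inpt (all ones here)
    print(make_fx(grad_fn)(inpt).code)    # traced graph with no mutations or views

Tracing the functionalized gradient function with ``make_fx``, as in the last line, should yield a mutation-free ATen graph analogous to the ones shown in the docstring example.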