Beef up transform limitations doc (#879)
I want to be able to point someone at this page whenever we get asked
about the limitations of vmap. Please let me know if there are things
we're still missing from here.
zou3519 committed Jun 17, 2022
1 parent 4386955 commit e7f7fd1
Showing 1 changed file with 65 additions and 0 deletions: docs/source/ux_limitations.rst
@@ -53,6 +53,23 @@ Please rewrite ``f`` to return ``intermediate``:

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you try to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside a function being transformed by
:func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to handle it.
If it cannot, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and is the reason we designed the functorch library. Please use the functorch
equivalents of the ``torch.autograd`` APIs instead (see the sketch after this list):

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad``
- ``torch.autograd.functional.jvp`` -> ``functorch.jvp``
- ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``functorch.hessian``
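
For example, here is a minimal sketch of that substitution for computing
per-sample gradients (``compute_loss`` is made up for illustration):

::

    import torch
    from functorch import grad, vmap

    def compute_loss(x):
        # toy per-example loss; stands in for whatever you are differentiating
        return (x ** 2).sum()

    xs = torch.randn(5, 3)

    # Instead of calling torch.autograd.grad inside a vmapped function,
    # compose functorch.grad with vmap:
    per_sample_grads = vmap(grad(compute_loss))(xs)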

vmap limitations
----------------

@@ -144,6 +161,14 @@ elements (or more):
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

Mutation: out= PyTorch operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters one in your code.

This is not a fundamental limitation; we could theoretically support this in the
future but we have chosen not to for now.
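
For example, here is a minimal sketch of the distinction (``torch.sin`` is used
purely for illustration):

::

    import torch
    from functorch import vmap

    x = torch.randn(3, 4)

    def f(x):
        out = torch.empty_like(x)
        torch.sin(x, out=out)  # writes into ``out`` in-place
        return out

    # vmap(f)(x)  # errors: out= is not supported under vmap

    def g(x):
        return torch.sin(x)  # the out-of-place version works

    vmap(g)(x)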

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
@@ -180,6 +205,46 @@ using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``)
We're investigating adding equivalents of those to functorch
(open an issue on `GitHub <https://github.com/pytorch/functorch>`_ to voice your support!).
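
For example, here is a minimal sketch of code that hits this limitation, along
with one possible rewrite using ``torch.where`` (the function is made up for
illustration):

::

    import torch
    from functorch import vmap

    def f(x):
        # The branch taken depends on the values in ``x``, which can differ
        # across batch elements, so vmap cannot pick a single path.
        if x.sum() > 0:
            return x
        return -x

    xs = torch.randn(5, 3)
    # vmap(f)(xs)  # errors: data-dependent control flow

    def g(x):
        # Compute both branches and select; this works under vmap.
        return torch.where(x.sum() > 0, x, -x)

    vmap(g)(xs)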

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We do not (and will not) support ``vmap`` over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.
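
For example, here is a minimal sketch of one such rewrite (the normalization
function is invented for illustration); keeping the value as a 0-dim Tensor
instead of converting it to a Python number lets vmap handle it:

::

    import torch
    from functorch import vmap

    # Before: errors under vmap because of the ``.item()`` call
    def f(x):
        scale = x.abs().max().item()
        return x / scale

    # After: keep ``scale`` as a Tensor; the arithmetic is unchanged
    def g(x):
        scale = x.abs().max()
        return x / scale

    xs = torch.randn(5, 3)
    vmap(g)(xs)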

You may also encounter an error message about ``.item()`` even though your code
does not call it. In those cases, PyTorch may be calling ``.item()`` internally
-- please file an issue on GitHub and we'll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero`` and
``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape ``(2, 1)``,
but ``torch.nonzero(xs[1])`` returns a Tensor of shape ``(1, 1)``:
the size of the first dimension depends on how many nonzero elements each
example has. We are unable to construct a single Tensor as the output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).
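
If you need per-example results with different shapes, one possible fallback
(sketched below; this is plain Python, not a vmap feature) is an explicit loop:

::

    import torch

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])

    # Each result keeps its own shape: (2, 1) and (1, 1) here.
    results = [torch.nonzero(x) for x in xs]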


Randomness
----------
The user's intention when calling a random operation can be unclear. Specifically, some users may want
