Skip to content

Commit

Permalink
Update functools.py (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Jun 14, 2024
1 parent 7ec7e4b commit 133ed92
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 192 deletions.
127 changes: 70 additions & 57 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import inspect
from collections.abc import Callable
from typing import TypeVar
from typing import Literal, TypeVar

from jax import vmap
from jax import Array, vmap

from lcm.functools import allow_args, allow_kwargs
from lcm.functools import allow_args, allow_only_kwargs

F = TypeVar("F", bound=Callable)
F = TypeVar("F", bound=Callable[..., Array])


def spacemap(
Expand All @@ -25,26 +25,25 @@ def spacemap(
keyword arguments.
Args:
func (callable): The function to be dispatched.
dense_vars (list): Names of the dense variables, i.e. those that are stored
as arrays of possible values in the grid.
sparse_vars (list): Names of the sparse variables, i.e. those that are stored
as arrays of possible combinations of variables in the grid.
put_dense_first (bool): Whether the dense or sparse dimensions should come first
in the output of the dispatched function.
func: The function to be dispatched.
dense_vars: Names of the dense variables, i.e. those that are stored as arrays
of possible values in the grid.
sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays
of possible combinations of variables in the grid.
put_dense_first: Whether the dense or sparse dimensions should come first in the
output of the dispatched function.
Returns:
callable: A callable with the same arguments as func (but with an additional
leading dimension) that returns a jax.numpy.ndarray or pytree of arrays.
If ``func`` returns a scalar, the dispatched function returns a
jax.numpy.ndarray with k + 1 dimensions, where k is the length of
``dense_vars`` and the additional dimension corresponds to the
``sparse_vars``. The order of the dimensions is determined by the order of
``dense_vars`` as well as the ``put_dense_first`` argument.
If the output of ``func`` is a jax pytree, the usual jax behavior applies,
i.e. the leading dimensions of all arrays in the pytree are as described
above but there might be additional dimensions.
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars``
and the additional dimension corresponds to the ``sparse_vars``. The order of
the dimensions is determined by the order of ``dense_vars`` as well as the
``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the
usual jax behavior applies, i.e. the leading dimensions of all arrays in the
pytree are as described above but there might be additional dimensions.
"""
# Check inputs and prepare function
Expand Down Expand Up @@ -72,42 +71,49 @@ def spacemap(
if not sparse_vars:
vmapped = _base_productmap(func, dense_vars)
elif put_dense_first:
vmapped = vmap_1d(func, variables=sparse_vars, apply_allow_kwargs=False)
vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args")
vmapped = _base_productmap(vmapped, dense_vars)
else:
vmapped = _base_productmap(func, dense_vars)
vmapped = vmap_1d(vmapped, variables=sparse_vars, apply_allow_kwargs=False)
vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args")

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined]

return allow_kwargs(vmapped)
return allow_only_kwargs(vmapped)


def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -> F:
def vmap_1d(
func: F,
variables: list[str],
*,
callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs",
) -> F:
"""Apply vmap such that func is mapped over the specified variables.
In contrast to a general vmap call, vmap_1d vectorizes along the leading axis of all
of the requested variables simultaneously. Moreover, it preserves the function
signature and allows the function to be called with keyword arguments.
Args:
func (callable): The function to be dispatched.
variables (list): List with names of arguments that over which we map.
apply_allow_kwargs (bool): Whether to apply the allow_kwargs decorator to the
dispatched function.
func: The function to be dispatched.
variables: List with names of arguments that over which we map.
callable_with: Whether to apply the allow_kwargs decorator to the dispatched
function. If "only_args", the returned function can only be called with
positional arguments. If "only_kwargs", the returned function can only be
called with keyword arguments.
Returns:
callable: A callable with the same arguments as func (but with an additional
leading dimension) that returns a jax.numpy.ndarray or pytree of arrays.
If ``func`` returns a scalar, the dispatched function returns a
jax.numpy.ndarray with 1 dimension and length k, where k is the length of
one of the mapped inputs in ``variables``. The order of the dimensions is
determined by the order of ``variables`` which can be different to the order
of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual
jax behavior applies, i.e. the leading dimensions of all arrays in the
pytree are as described above but there might be additional dimensions.
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with 1 dimension and length k, where k is the length of one of
the mapped inputs in ``variables``. The order of the dimensions is determined by
the order of ``variables`` which can be different to the order of ``funcs``
arguments. If the output of ``func`` is a jax pytree, the usual jax behavior
applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.
"""
if duplicates := {v for v in variables if variables.count(v) > 1}:
Expand All @@ -134,7 +140,17 @@ def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = signature # type: ignore[attr-defined]

return allow_kwargs(vmapped) if apply_allow_kwargs else vmapped
if callable_with == "only_kwargs":
out = allow_only_kwargs(vmapped)
elif callable_with == "only_args":
out = vmapped
else:
raise ValueError(
f"Invalid callable_with option: {callable_with}. Possible options are "
"('only_args', 'only_kwargs')",
)

return out


def productmap(func: F, variables: list[str]) -> F:
Expand All @@ -146,20 +162,19 @@ def productmap(func: F, variables: list[str]) -> F:
allows the function to be called with keyword arguments.
Args:
func (callable): The function to be dispatched.
variables (list): List with names of arguments that over which the Cartesian
product should be formed.
func: The function to be dispatched.
variables: List with names of arguments that over which the Cartesian product
should be formed.
Returns:
callable: A callable with the same arguments as func (but with an additional
leading dimension) that returns a jax.numpy.ndarray or pytree of arrays.
If ``func`` returns a scalar, the dispatched function returns a
jax.numpy.ndarray with k dimensions, where k is the length of ``variables``.
The order of the dimensions is determined by the order of ``variables``
which can be different to the order of ``funcs`` arguments. If the output of
``func`` is a jax pytree, the usual jax behavior applies, i.e. the leading
dimensions of all arrays in the pytree are as described above but there
might be additional dimensions.
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with k
dimensions, where k is the length of ``variables``. The order of the dimensions
is determined by the order of ``variables`` which can be different to the order
of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.
"""
func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments
Expand All @@ -176,7 +191,7 @@ def productmap(func: F, variables: list[str]) -> F:
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = signature # type: ignore[attr-defined]

return allow_kwargs(vmapped)
return allow_only_kwargs(vmapped)


def _base_productmap(func: F, product_axes: list[str]) -> F:
Expand All @@ -186,13 +201,11 @@ def _base_productmap(func: F, product_axes: list[str]) -> F:
the function to be called with keyword arguments.
Args:
func (callable): The function to be dispatched. Cannot have keyword-only
arguments.
product_axes (list): List with names of arguments over which we apply vmap.
func: The function to be dispatched. Cannot have keyword-only arguments.
product_axes: List with names of arguments over which we apply vmap.
Returns:
callable: A callable with the same arguments as func. See ``product_map`` for
details.
A callable with the same arguments as func. See ``product_map`` for details.
"""
signature = inspect.signature(func)
Expand Down
Loading

0 comments on commit 133ed92

Please sign in to comment.