Skip to content

Commit

Permalink
Add callable_with option to vmap_1d
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jun 14, 2024
1 parent 55dbef3 commit 0da6888
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from collections.abc import Callable
from typing import TypeVar
from typing import Literal, TypeVar

from jax import Array, vmap

Expand Down Expand Up @@ -72,11 +72,11 @@ def spacemap(
if not sparse_vars:
vmapped = _base_productmap(func, dense_vars)
elif put_dense_first:
vmapped = vmap_1d(func, variables=sparse_vars, apply_allow_kwargs=False)
vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args")
vmapped = _base_productmap(vmapped, dense_vars)
else:
vmapped = _base_productmap(func, dense_vars)
vmapped = vmap_1d(vmapped, variables=sparse_vars, apply_allow_kwargs=False)
vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args")

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
Expand All @@ -85,7 +85,12 @@ def spacemap(
return allow_only_kwargs(vmapped)


def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -> F:
def vmap_1d(
func: F,
variables: list[str],
*,
callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs",
) -> F:
"""Apply vmap such that func is mapped over the specified variables.
In contrast to a general vmap call, vmap_1d vectorizes along the leading axis of all
Expand All @@ -95,8 +100,10 @@ def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -
Args:
func (callable): The function to be dispatched.
variables (list): List with names of arguments that over which we map.
apply_allow_kwargs (bool): Whether to apply the allow_kwargs decorator to the
dispatched function.
callable_with (str): Whether to apply the allow_kwargs decorator to the
dispatched function. If "only_args", the returned function can only be
called with positional arguments. If "only_kwargs", the returned function
can only be called with keyword arguments.
Returns:
callable: A callable with the same arguments as func (but with an additional
Expand Down Expand Up @@ -134,7 +141,17 @@ def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = signature # type: ignore[attr-defined]

return allow_only_kwargs(vmapped) if apply_allow_kwargs else vmapped
if callable_with == "only_kwargs":
out = allow_only_kwargs(vmapped)
elif callable_with == "only_args":
out = vmapped
else:
raise ValueError(
f"Invalid callable_with option: {callable_with}. Possible options are "
"('only_args', 'only_kwargs')",
)

return out


def productmap(func: F, variables: list[str]) -> F:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,46 @@ def func(a, b, c):
def test_vmap_1d_error():
with pytest.raises(ValueError, match="Same argument provided more than once."):
vmap_1d(None, variables=["a", "a"])


def test_vmap_1d_callable_with_only_args():
def func(a):
return a

vmapped = vmap_1d(func, variables=["a"], callable_with="only_args")
a = jnp.array([1, 2])
# check that the function works with positional arguments
aaae(vmapped(a), a)
# check that the function fails with keyword arguments
with pytest.raises(
ValueError,
match="vmap in_axes must be an int, None, or a tuple of entries corresponding",
):
vmapped(a=1)


def test_vmap_1d_callable_with_only_kwargs():
def func(a):
return a

vmapped = vmap_1d(func, variables=["a"], callable_with="only_kwargs")
a = jnp.array([1, 2])
# check that the function works with keyword arguments
aaae(vmapped(a=a), a)
# check that the function fails with positional arguments
with pytest.raises(
ValueError,
match="This function has been decorated so that it allows only kwargs, but was",
):
vmapped(a)


def test_vmap_1d_callable_with_invalid():
def func(a):
return a

with pytest.raises(
ValueError,
match="Invalid callable_with option: invalid. Possible options are",
):
vmap_1d(func, variables=["a"], callable_with="invalid")

0 comments on commit 0da6888

Please sign in to comment.