From 0da68884547c8e7ca0811d339b37059bb1a308b6 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Fri, 14 Jun 2024 15:21:52 +0200 Subject: [PATCH] Add callable_with option to vmap_1d --- src/lcm/dispatchers.py | 31 +++++++++++++++++++++------- tests/test_dispatchers.py | 43 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 8838249..3c96a55 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -1,6 +1,6 @@ import inspect from collections.abc import Callable -from typing import TypeVar +from typing import Literal, TypeVar from jax import Array, vmap @@ -72,11 +72,11 @@ def spacemap( if not sparse_vars: vmapped = _base_productmap(func, dense_vars) elif put_dense_first: - vmapped = vmap_1d(func, variables=sparse_vars, apply_allow_kwargs=False) + vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args") vmapped = _base_productmap(vmapped, dense_vars) else: vmapped = _base_productmap(func, dense_vars) - vmapped = vmap_1d(vmapped, variables=sparse_vars, apply_allow_kwargs=False) + vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args") # This raises a mypy error but is perfectly fine to do. See # https://github.com/python/mypy/issues/12472 @@ -85,7 +85,12 @@ def spacemap( return allow_only_kwargs(vmapped) -def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) -> F: +def vmap_1d( + func: F, + variables: list[str], + *, + callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs", +) -> F: """Apply vmap such that func is mapped over the specified variables. In contrast to a general vmap call, vmap_1d vectorizes along the leading axis of all @@ -95,8 +100,10 @@ def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) - Args: func (callable): The function to be dispatched. variables (list): List with names of arguments that over which we map. - apply_allow_kwargs (bool): Whether to apply the allow_kwargs decorator to the - dispatched function. + callable_with (str): Whether to apply the allow_kwargs decorator to the + dispatched function. If "only_args", the returned function can only be + called with positional arguments. If "only_kwargs", the returned function + can only be called with keyword arguments. Returns: callable: A callable with the same arguments as func (but with an additional @@ -134,7 +141,17 @@ def vmap_1d(func: F, variables: list[str], *, apply_allow_kwargs: bool = True) - # https://github.com/python/mypy/issues/12472 vmapped.__signature__ = signature # type: ignore[attr-defined] - return allow_only_kwargs(vmapped) if apply_allow_kwargs else vmapped + if callable_with == "only_kwargs": + out = allow_only_kwargs(vmapped) + elif callable_with == "only_args": + out = vmapped + else: + raise ValueError( + f"Invalid callable_with option: {callable_with}. Possible options are " + "('only_args', 'only_kwargs')", + ) + + return out def productmap(func: F, variables: list[str]) -> F: diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 286f7f6..8b0a4bc 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -286,3 +286,46 @@ def func(a, b, c): def test_vmap_1d_error(): with pytest.raises(ValueError, match="Same argument provided more than once."): vmap_1d(None, variables=["a", "a"]) + + +def test_vmap_1d_callable_with_only_args(): + def func(a): + return a + + vmapped = vmap_1d(func, variables=["a"], callable_with="only_args") + a = jnp.array([1, 2]) + # check that the function works with positional arguments + aaae(vmapped(a), a) + # check that the function fails with keyword arguments + with pytest.raises( + ValueError, + match="vmap in_axes must be an int, None, or a tuple of entries corresponding", + ): + vmapped(a=1) + + +def test_vmap_1d_callable_with_only_kwargs(): + def func(a): + return a + + vmapped = vmap_1d(func, variables=["a"], callable_with="only_kwargs") + a = jnp.array([1, 2]) + # check that the function works with keyword arguments + aaae(vmapped(a=a), a) + # check that the function fails with positional arguments + with pytest.raises( + ValueError, + match="This function has been decorated so that it allows only kwargs, but was", + ): + vmapped(a) + + +def test_vmap_1d_callable_with_invalid(): + def func(a): + return a + + with pytest.raises( + ValueError, + match="Invalid callable_with option: invalid. Possible options are", + ): + vmap_1d(func, variables=["a"], callable_with="invalid")