diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 93b5f04b..79c60680 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -39,6 +39,8 @@ def spacemap(func, dense_vars, sparse_vars, *, dense_first): above but there might be additional dimensions. """ + func = allow_args(func) # vmap cannot deal with keyword-only arguments + if not set(dense_vars).isdisjoint(sparse_vars): raise ValueError("dense_vars and sparse_vars overlap.") @@ -96,6 +98,8 @@ def productmap(func, variables): might be additional dimensions. """ + func = allow_args(func) # vmap cannot deal with keyword-only arguments + if len(variables) != len(set(variables)): raise ValueError("Same argument provided more than once.") @@ -188,24 +192,102 @@ def allow_kwargs(func): possibility to call it with keyword arguments). """ + signature = inspect.signature(func) + parameters = signature.parameters + + # Get names of keyword-only arguments + kw_only_parameters = [ + p.name for p in parameters.values() if p.kind == inspect.Parameter.KEYWORD_ONLY + ] + + # Create new signature without positional-only arguments + new_parameters = [ + p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) + if p.kind == inspect.Parameter.POSITIONAL_ONLY + else p + for p in parameters.values() + ] + new_signature = signature.replace(parameters=new_parameters) @functools.wraps(func) def allow_kwargs_wrapper(*args, **kwargs): - parameters = list(inspect.signature(func).parameters) + # Retrieve keyword-only arguments + kw_only_kwargs = {k: kwargs[k] for k in kw_only_parameters} - positional = list(args) if args is not None else [] + # Get kwargs that will be converted to positional arguments + pos_kwargs = {k: v for k, v in kwargs.items() if k not in kw_only_parameters} - kwargs = {} if args is None else kwargs - if len(args) + len(kwargs) != len(parameters): + # Check if the total number of arguments matches the function signature + if len(args) + len(pos_kwargs) + len(kw_only_kwargs) != len(parameters): raise ValueError("Not enough or too many arguments provided.") - positional += convert_kwargs_to_args(kwargs, parameters) + # Separate positional arguments and convert keyword arguments to positional + positional = list(args) + positional += convert_kwargs_to_args(pos_kwargs, list(parameters)) - return func(*positional) + return func(*positional, **kw_only_kwargs) + allow_kwargs_wrapper.__signature__ = new_signature return allow_kwargs_wrapper +def allow_args(func): + """Allow a function to be called with positional arguments. + + Args: + func (callable): The function to be wrapped. + + Returns: + callable: A callable with the same arguments as func (but with the additional + possibility to call it with positional arguments). + + """ + signature = inspect.signature(func) + parameters = signature.parameters + + # Count the number of positional-only arguments + n_positional_only_parameters = len( + [p for p in parameters.values() if p.kind == inspect.Parameter.POSITIONAL_ONLY], + ) + + # Create new signature without keyword-only arguments + new_parameters = [ + p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) + if p.kind == inspect.Parameter.KEYWORD_ONLY + else p + for p in parameters.values() + ] + new_signature = signature.replace(parameters=new_parameters) + + @functools.wraps(func) + def allow_args_wrapper(*args, **kwargs): + # Check if the total number of arguments matches the function signature + if len(args) + len(kwargs) != len(parameters): + raise ValueError("Not enough or too many arguments provided.") + + # Convert all arguments to positional arguments in correct order + positional = list(args) + positional += convert_kwargs_to_args(kwargs, list(parameters)) + + # Extract positional-only arguments + positional_only = positional[:n_positional_only_parameters] + + # Create kwargs dictionary with remaining arguments + kwargs_names = list(parameters)[n_positional_only_parameters:] + kwargs = dict( + zip( + kwargs_names, + positional[n_positional_only_parameters:], + strict=True, + ), + ) + + return func(*positional_only, **kwargs) + + allow_args_wrapper.__signature__ = new_signature + return allow_args_wrapper + + def convert_kwargs_to_args(kwargs, parameters): """Convert kwargs to args in the order of parameters. diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index cd0e02e5..28937788 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -1,8 +1,11 @@ +import inspect import itertools import jax.numpy as jnp import pytest +from jax import vmap from lcm.dispatchers import ( + allow_args, allow_kwargs, convert_kwargs_to_args, productmap, @@ -12,16 +15,28 @@ from numpy.testing import assert_array_almost_equal as aaae -def f(a, b, c): +def f(a, /, *, b, c): + """Tests that dispatchers can handle positional-only and keyword-only arguments. + + a is positional-only, b and c are keyword-only + """ return jnp.sin(a) + jnp.cos(b) + jnp.tan(c) -def f2(b, a, c): +def f2(b, a, /, *, c): + """Tests that dispatchers can handle positional-only and keyword-only arguments. + + b and a are positional-only, c is keyword-only + """ return jnp.sin(a) + jnp.cos(b) + jnp.tan(c) -def g(a, b, c, d): - return f(a, b, c) + jnp.log(d) +def g(a, /, b, *, c, d): + """Tests that dispatchers can handle positional-only and keyword-only arguments. + + a is positional-only, b is positional-or-keyword, c and d are keyword-only + """ + return f(a, b=b, c=c) + jnp.log(d) # ====================================================================================== @@ -47,7 +62,7 @@ def expected_productmap_f(): } helper = jnp.array(list(itertools.product(*grids.values()))).T - return f(*helper).reshape(10, 7, 5) + return allow_kwargs(allow_args(f))(*helper).reshape(10, 7, 5) @pytest.fixture() @@ -70,7 +85,7 @@ def expected_productmap_g(): } helper = jnp.array(list(itertools.product(*grids.values()))).T - return g(*helper).reshape(10, 7, 5, 4) + return allow_kwargs(allow_args(g))(*helper).reshape(10, 7, 5, 4) @pytest.mark.parametrize( @@ -120,7 +135,7 @@ def test_productmap_with_all_arguments_mapped_some_len_one(): helper = jnp.array(list(itertools.product(*grids.values()))).T - expected = f(*helper).reshape(1, 1, 5) + expected = allow_kwargs(allow_args(f))(*helper).reshape(1, 1, 5) decorated = productmap(f, ["a", "b", "c"]) calculated = decorated(*grids.values()) @@ -148,7 +163,7 @@ def test_productmap_with_some_arguments_mapped(): helper = jnp.array(list(itertools.product(grids["a"], [grids["b"]], grids["c"]))).T - expected = f(*helper).reshape(10, 5) + expected = allow_kwargs(allow_args(f))(*helper).reshape(10, 5) decorated = productmap(f, ["a", "c"]) calculated = decorated(*grids.values()) @@ -202,7 +217,7 @@ def expected_spacemap(): all_grids = {**value_grid, **combination_grid} helper = jnp.array(list(itertools.product(*all_grids.values()))).T - return g(*helper).reshape(3, 2, 4 * 5) + return allow_kwargs(allow_args(g))(*helper).reshape(3, 2, 4 * 5) @pytest.mark.parametrize("dense_first", [True, False]) @@ -272,6 +287,113 @@ def f(a, /, b): assert allow_kwargs(f)(a=1, b=2) == 3 +def test_allow_kwargs_with_keyword_only_args(): + def f(a, /, *, b): + return a + b + + with pytest.raises(TypeError): + f(a=1, b=2) + + assert allow_kwargs(f)(a=1, b=2) == 3 + + +def test_allow_kwargs_incorrect_number_of_args(): + def f(a, /, b): + return a + b + + with pytest.raises(ValueError, match="Not enough or too many arguments"): + allow_kwargs(f)(a=1, b=2, c=3) + + with pytest.raises(ValueError, match="Not enough or too many arguments"): + allow_kwargs(f)(a=1) + + +def test_allow_kwargs_signature_change(): + def f(a, /, b, *, c): # noqa: ARG001 + pass + + decorated = allow_kwargs(f) + parameters = inspect.signature(decorated).parameters + + assert parameters["a"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + assert parameters["b"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + assert parameters["c"].kind == inspect.Parameter.KEYWORD_ONLY + + +# ====================================================================================== +# allow args +# ====================================================================================== + + +def test_allow_args(): + def f(a, *, b): + # b is keyword-only + return a + b + + with pytest.raises(TypeError): + f(1, 2) + + assert allow_args(f)(1, 2) == 3 + assert allow_args(f)(1, b=2) == 3 + assert allow_args(f)(b=2, a=1) == 3 + + +def test_allow_args_different_kwargs_order(): + def f(a, b, c, *, d): + return a + b + c + d + + with pytest.raises(TypeError): + f(1, 2, 3, 4) + + assert allow_args(f)(1, 2, 3, 4) == 10 + assert allow_args(f)(1, 2, d=4, c=3) == 10 + + +def test_allow_args_incorrect_number_of_args(): + def f(a, *, b): + return a + b + + with pytest.raises(ValueError, match="Not enough or too many arguments"): + allow_args(f)(1, 2, b=3) + + with pytest.raises(ValueError, match="Not enough or too many arguments"): + allow_args(f)(1) + + +def test_allow_args_with_vmap(): + def f(a, *, b): + # b is keyword-only + return a + b + + f_vmapped = vmap(f, in_axes=(0, 0)) + f_allow_args_vmapped = vmap(allow_args(f), in_axes=(0, 0)) + + a = jnp.arange(2) + b = jnp.arange(2) + + with pytest.raises(TypeError): + # TypeError since b is keyword-only + f_vmapped(a, b) + + with pytest.raises(ValueError, match="vmap in_axes specification"): + # ValueError since vmap doesn't support keyword arguments + f_vmapped(a, b=b) + + aaae(f_allow_args_vmapped(a, b), jnp.array([0, 2])) + + +def test_allow_args_signature_change(): + def f(a, /, b, *, c): # noqa: ARG001 + pass + + decorated = allow_args(f) + parameters = inspect.signature(decorated).parameters + + assert parameters["a"].kind == inspect.Parameter.POSITIONAL_ONLY + assert parameters["b"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + assert parameters["c"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + + # ====================================================================================== # vmap_1d # ======================================================================================