Skip to content

Commit

Permalink
Integrate comments from review
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jun 14, 2024
1 parent 401580c commit 55dbef3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/lcm/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def func_with_only_kwargs(*args, **kwargs):
if args:
raise ValueError(
(
"This function was decorated with allow_only_kwargs, but was "
"called with positional arguments."
"This function has been decorated so that it allows only kwargs, "
"but was called with positional arguments."
),
)

Expand Down
28 changes: 9 additions & 19 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import pandas as pd
from dags import concatenate_functions
from jax import vmap

from lcm.argmax import argmax, segment_argmax
from lcm.dispatchers import spacemap, vmap_1d
Expand Down Expand Up @@ -441,28 +442,17 @@ def retrieve_non_sparse_choices(indices, grids, grid_shape):
if indices is None:
out = {}
else:
out = _retrieve_non_sparse_choices(indices, grids, grid_shape)
indices = vmapped_unravel_index(indices, grid_shape)
out = {
name: grid[index]
for (name, grid), index in zip(grids.items(), indices, strict=True)
}
return out


@partial(vmap_1d, variables=["index"], apply_allow_kwargs=False)
def _retrieve_non_sparse_choices(index, grids, grid_shape):
"""Retrieve dense or continuous choices given index.
Args:
index (int): General index. Represents the index of the flattened grid.
grids (dict): Dictionary of grids.
grid_shape (tuple): Shape of the grids. Is used to unravel the index.
Returns:
dict: Dictionary of choices.
"""
indices = jnp.unravel_index(index, shape=grid_shape)
return {
name: grid[index]
for (name, grid), index in zip(grids.items(), indices, strict=True)
}
# vmap jnp.unravel_index over the first axis of the `indices` argument, while holding
# the `shape` argument constant (in_axes = (0, None)).
vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None))


# ======================================================================================
Expand Down
10 changes: 7 additions & 3 deletions tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,19 @@ def test_productmap_with_all_arguments_mapped(func, args, grids, expected, reque
expected = request.getfixturevalue(expected)

decorated = productmap(func, args)

calculated = decorated(**grids)
aaae(calculated, expected)


def test_productmap_with_positional_args(setup_productmap_f):
decorated = productmap(f, ["a", "b", "c"])
match = (
"This function was decorated with allow_only_kwargs, but was called with "
"positional arguments."
"This function has been decorated so that it allows only kwargs, but was "
"called with positional arguments."
)
with pytest.raises(ValueError, match=match):
decorated(*grids.values())
decorated(*setup_productmap_f.values())


def test_productmap_different_func_order(setup_productmap_f):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,19 @@ def f(a, /, *, b):
assert allow_only_kwargs(f)(a=1, b=2) == 3


def test_allow_only_kwargs_incorrect_number_of_args():
def test_allow_only_kwargs_too_many_args():
def f(a, /, b):
return a + b

too_many_match = re.escape("Expected arguments: ['a', 'b'], got extra: {'c'}")
with pytest.raises(ValueError, match=too_many_match):
allow_only_kwargs(f)(a=1, b=2, c=3)


def test_allow_only_kwargs_too_few_args():
def f(a, /, b):
return a + b

too_few_match = re.escape("Expected arguments: ['a', 'b'], missing: {'b'}")
with pytest.raises(ValueError, match=too_few_match):
allow_only_kwargs(f)(a=1)
Expand Down Expand Up @@ -195,13 +200,18 @@ def f(a, b, c, *, d):
assert allow_args(f)(1, 2, d=4, c=3) == 10


def test_allow_args_incorrect_number_of_args():
def test_allow_args_too_many_args():
def f(a, *, b):
return a + b

with pytest.raises(ValueError, match="Too many arguments provided."):
allow_args(f)(1, 2, b=3)


def test_allow_args_too_few_args():
def f(a, *, b):
return a + b

with pytest.raises(ValueError, match="Not all arguments provided."):
allow_args(f)(1)

Expand Down
9 changes: 9 additions & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,15 @@ def test_retrieve_non_sparse_choices():
assert_array_equal(got["b"], jnp.array([10, 16, 12]))


def test_retrieve_non_sparse_choices_no_indices():
got = retrieve_non_sparse_choices(
indices=None,
grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)},
grid_shape=(5, 6),
)
assert got == {}


def test_filter_ccv_policy():
ccc_policy = jnp.array(
[
Expand Down

0 comments on commit 55dbef3

Please sign in to comment.