Skip to content

Commit

Permalink
Add tests for DiscreteStateConverter class
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 19, 2024
1 parent 227b859 commit efb91ae
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/lcm/input_processing/discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ class DiscreteStateConverter:
Attributes:
converted_states: The names of the states that have been converted.
index_to_label: A dictionary of functions mapping from the internal index to the
label for each converted state.
label_to_index: A dictionary of functions mapping from the label to the internal
index_to_code: A dictionary of functions mapping from the internal index to the
code for each converted state.
code_to_index: A dictionary of functions mapping from the code to the internal
index for each converted state.
"""

converted_states: list[str] = field(default_factory=list)
index_to_label: dict[str, Callable[[Array], Array]] = field(default_factory=dict)
label_to_index: dict[str, Callable[[Array], Array]] = field(default_factory=dict)
index_to_code: dict[str, Callable[[Array], Array]] = field(default_factory=dict)
code_to_index: dict[str, Callable[[Array], Array]] = field(default_factory=dict)

def internal_to_params(self, params: ParamsDict) -> ParamsDict:
"""Convert parameters from internal to external representation.
Expand Down Expand Up @@ -66,27 +66,27 @@ def internal_to_states(self, states: dict[str, Array]) -> dict[str, Array]:
If a state has been converted, the name of its corresponding index function must
be changed from `___{var}_index__` to `{var}`, and the values of the state must
be converted from indices to labels.
be converted from indices to codes.
"""
out = states.copy()
for var in self.converted_states:
out.pop(f"__{var}_index__")
out[var] = self.index_to_label[var](states[f"__{var}_index__"])
out[var] = self.index_to_code[var](states[f"__{var}_index__"])
return out

def states_to_internal(self, states: dict[str, Array]) -> dict[str, Array]:
"""Convert states from external to internal representation.
If a state has been converted, the name of its corresponding index function must
be changed from `{var}` to `___{var}_index__`, and the values of the state must
be converted from labels to indices.
be converted from codes to indices.
"""
out = states.copy()
for var in self.converted_states:
out.pop(var)
out[f"__{var}_index__"] = self.label_to_index[var](states[var])
out[f"__{var}_index__"] = self.code_to_index[var](states[var])
return out


Expand Down Expand Up @@ -147,32 +147,32 @@ def convert_discrete_codes_to_indices(
for var in non_index_states:
functions[f"next___{var}_index__"] = functions.pop(f"next_{var}")

# Add index to label functions
# Add index to code functions
# ----------------------------------------------------------------------------------
index_to_label_funcs = {
index_to_code_funcs = {
var: _get_index_to_code_func(gridspecs[var].to_jax(), name=var)
for var in non_index_discrete_vars
}
functions = functions | index_to_label_funcs
functions = functions | index_to_code_funcs

# Construct label to index functions for states
# Construct code to index functions for states
# ----------------------------------------------------------------------------------
converted_states = [s for s in non_index_discrete_vars if s in model.states]

label_to_index_funcs_for_states = {
code_to_index_funcs_for_states = {
var: _get_code_to_index_func(gridspecs[var].to_jax(), name=var)
for var in converted_states
}

# Subset index to label functions to only include states for converter
index_to_label_funcs_for_states = {
k: v for k, v in index_to_label_funcs.items() if k in model.states
# Subset index to code functions to only include states for converter
index_to_code_funcs_for_states = {
k: v for k, v in index_to_code_funcs.items() if k in model.states
}

converter = DiscreteStateConverter(
converted_states=converted_states,
index_to_label=index_to_label_funcs_for_states,
label_to_index=label_to_index_funcs_for_states,
index_to_code=index_to_code_funcs_for_states,
code_to_index=code_to_index_funcs_for_states,
)

new_model = model.replace(
Expand Down Expand Up @@ -223,7 +223,7 @@ def func(*args, **kwargs):


def _get_code_to_index_func(codes_array: Array, name: str) -> Callable[[Array], Array]:
"""Get function mapping from label to index.
"""Get function mapping from code to index.
Args:
codes_array: An array of codes.
Expand All @@ -236,9 +236,9 @@ def _get_code_to_index_func(codes_array: Array, name: str) -> Callable[[Array],
"""

@with_signature(args=[name])
def label_to_index(*args, **kwargs):
def code_to_index(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=[name])
data = kwargs[name]
return jnp.argmax(data[:, None] == codes_array[None, :], axis=1)

return label_to_index
return code_to_index
50 changes: 50 additions & 0 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from lcm import DiscreteGrid
from lcm.input_processing.discrete_state_conversion import (
DiscreteStateConverter,
_get_code_to_index_func,
_get_discrete_vars_with_non_index_codes,
_get_index_to_code_func,
Expand Down Expand Up @@ -49,6 +50,55 @@ def next_c(a, b):
)


@pytest.fixture
def discrete_state_converter_kwargs():
return {
"converted_states": ["c"],
"index_to_code": {"c": _get_index_to_code_func(jnp.array([1, 0]), name="c")},
"code_to_index": {"c": _get_code_to_index_func(jnp.array([1, 0]), name="c")},
}


def test_discrete_state_converter_internal_to_params(discrete_state_converter_kwargs):
expected = {
"next_c": 1,
}
internal_params = {
"next___c_index__": 1,
}
converter = DiscreteStateConverter(**discrete_state_converter_kwargs)
assert converter.internal_to_params(internal_params) == expected


def test_discrete_state_converter_params_to_internal(discrete_state_converter_kwargs):
expected = {
"next___c_index__": 1,
}
params = {
"next_c": 1,
}
converter = DiscreteStateConverter(**discrete_state_converter_kwargs)
assert converter.params_to_internal(params) == expected


def test_discrete_state_converter_internal_to_states(discrete_state_converter_kwargs):
expected = jnp.array([1, 0])
internal_states = {
"__c_index__": jnp.array([0, 1]),
}
converter = DiscreteStateConverter(**discrete_state_converter_kwargs)
assert_array_equal(converter.internal_to_states(internal_states)["c"], expected)


def test_discrete_state_converter_states_to_internal(discrete_state_converter_kwargs):
expected = jnp.array([0, 1])
states = {
"c": jnp.array([1, 0]),
}
converter = DiscreteStateConverter(**discrete_state_converter_kwargs)
assert_array_equal(converter.states_to_internal(states)["__c_index__"], expected)


def test_get_index_to_label_func():
codes_array = jnp.array([1, 0])
got = _get_index_to_code_func(codes_array, name="foo")
Expand Down

0 comments on commit efb91ae

Please sign in to comment.