Skip to content

Commit

Permalink
Fix bug in choice_axes calculation (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Aug 24, 2023
1 parent af6661e commit d33a1e5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,15 @@ def determine_discrete_dense_choice_axes(variable_info):
discrete and dense choices.
"""
dense_vars = variable_info.query(
"is_dense & ~(is_choice & is_continuous)",
discrete_dense_choice_vars = variable_info.query(
"~is_continuous & is_dense & is_choice",
).index.tolist()

choice_vars = set(variable_info.query("is_choice").index.tolist())

choice_indices = [i for i, ax in enumerate(dense_vars) if ax in choice_vars]
# We add 1 because the first dimension corresponds to the sparse state variables
choice_indices = [
i + 1 for i, ax in enumerate(discrete_dense_choice_vars) if ax in choice_vars
]

return None if not choice_indices else tuple(choice_indices)
15 changes: 15 additions & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import jax.numpy as jnp
import pandas as pd
import pytest
from jax import random
from lcm.entry_point import (
Expand All @@ -17,6 +18,7 @@
_retrieve_non_sparse_choices,
create_choice_segments,
create_data_scs,
determine_discrete_dense_choice_axes,
dict_product,
filter_ccv_policy,
simulate,
Expand Down Expand Up @@ -333,3 +335,16 @@ def test_dict_product():
assert got_length == 4
for key, val in exp.items():
assert_array_equal(got_dict[key], val)


def test_determine_discrete_dense_choice_axes():
variable_info = pd.DataFrame(
{
"is_state": [True, True, False, True, False, False],
"is_dense": [False, True, True, False, True, True],
"is_choice": [False, False, True, True, True, True],
"is_continuous": [False, True, False, False, False, True],
},
)
got = determine_discrete_dense_choice_axes(variable_info)
assert got == (1, 2)

0 comments on commit d33a1e5

Please sign in to comment.