Skip to content

Commit

Permalink
Extend and fix filters (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored May 24, 2024
1 parent 20db0ec commit bbdb2a1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def _get_variable_info(user_model, function_info):
order += info.query("is_sparse & is_choice").index.tolist()
order += info.query("is_dense & is_discrete & is_state").index.tolist()
order += info.query("is_dense & is_discrete & is_choice").index.tolist()
order += info.query("is_continuous & is_state").index.tolist()
order += info.query("is_continuous & is_choice").index.tolist()
order += info.query("is_dense & is_continuous & is_state").index.tolist()
order += info.query("is_dense & is_continuous & is_choice").index.tolist()

if set(order) != set(info.index):
raise ValueError("Order and index do not match.")
Expand Down
13 changes: 11 additions & 2 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def simulate(
data_scs, data_choice_segments = create_data_scs(
states=states,
model=model,
period=period,
)

# Compute objects dependent on data-state-choice-space
Expand Down Expand Up @@ -472,12 +473,14 @@ def _retrieve_non_sparse_choices(index, grids, grid_shape):
def create_data_scs(
states,
model,
period,
):
"""Create data state choice space.
Args:
states (dict): Dict with initial states.
model (Model): Model instance.
period (int): Period.
Returns:
- Space: Data state choice space.
Expand Down Expand Up @@ -548,10 +551,16 @@ def create_data_scs(
aggregator=jnp.logical_and,
)

fixed_inputs = {"_period": period}
potential_kwargs = _combination_grid | fixed_inputs

parameters = list(inspect.signature(scalar_filter).parameters)
kwargs = {k: v for k, v in _combination_grid.items() if k in parameters}
kwargs = {k: v for k, v in potential_kwargs.items() if k in parameters}

# we do not vmap over the period variable
vmapped_parameters = [p for p in parameters if p != "_period"]

_filter = vmap_1d(scalar_filter, variables=parameters)
_filter = vmap_1d(scalar_filter, variables=vmapped_parameters)
mask = _filter(**kwargs)

# filter infeasible combinations
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_state_choice_space(model, period, *, is_last_period, jit_filter):
_filter_mask = create_filter_mask(
model=model,
subset=vi.query("is_sparse").index.tolist(),
fixed_inputs={"period": period},
fixed_inputs={"_period": period},
jit_filter=jit_filter,
)

Expand Down
18 changes: 18 additions & 0 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,21 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
working=0,
disutility_of_work=1.0,
)


# ======================================================================================
# Test filter with _period argument
# ======================================================================================


def test_get_lcm_function_with_period_argument_in_filter():
user_model = get_model_config("iskhakov_et_al_2017", n_periods=3)

def absorbing_retirement_filter(retirement, lagged_retirement, _period):
return jnp.logical_or(retirement == 1, lagged_retirement == 0)

user_model["functions"]["absorbing_retirement_filter"] = absorbing_retirement_filter

solve_model, params_template = get_lcm_function(model=user_model)
params = tree_map(lambda _: 0.2, params_template)
solve_model(params)
16 changes: 16 additions & 0 deletions tests/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,19 @@ def raw_func(health, wealth): # noqa: ARG001
variable_info=variable_info,
grids=None,
)


def test_variable_info_with_continuous_filter_has_unique_index():
user_model = get_model_config("iskhakov_et_al_2017", n_periods=3)

def wealth_filter(wealth):
return wealth > 200

user_model["functions"]["wealth_filter"] = wealth_filter

function_info = _get_function_info(user_model)
got = _get_variable_info(
user_model,
function_info,
)
assert got.index.is_unique
1 change: 1 addition & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def test_create_data_state_choice_space():
"lagged_retirement": jnp.array([0, 1]),
},
model=model,
period=0,
)
assert got_space.dense_vars == {}
assert_array_equal(got_space.sparse_vars["wealth"], jnp.array([10.0, 10.0, 20.0]))
Expand Down

0 comments on commit bbdb2a1

Please sign in to comment.