From 55dad829821b2071ae01b018e2181a554bc16ba0 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 23 Sep 2024 16:41:50 +0200 Subject: [PATCH] Integrate comments from review --- tests/test_entry_point.py | 52 +++++++++++++++++++----------- tests/test_models/deterministic.py | 40 +++++++++++++++-------- tests/test_simulate.py | 2 +- 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 7ab9f3d..2e0e307 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -10,7 +10,11 @@ from lcm.input_processing import process_model from lcm.model_functions import get_utility_and_feasibility_function from lcm.state_space import create_state_choice_space -from tests.test_models.deterministic import get_model_config +from tests.test_models.deterministic import ( + DiscreteConsumptionChoice, + RetirementStatus, + get_model_config, +) from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility # ====================================================================================== @@ -18,7 +22,7 @@ # ====================================================================================== -STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS = [ +STRIPPED_DOWN_AND_DISCRETE_MODELS = [ "iskhakov_et_al_2017_stripped_down", "iskhakov_et_al_2017_discrete", ] @@ -90,11 +94,8 @@ def test_get_lcm_function_with_simulation_target_simple_fully_discrete(): @pytest.mark.parametrize( "model", - [ - get_model_config(name, n_periods=3) - for name in STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS - ], - ids=STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS, + [get_model_config(name, n_periods=3) for name in STRIPPED_DOWN_AND_DISCRETE_MODELS], + ids=STRIPPED_DOWN_AND_DISCRETE_MODELS, ) def test_get_lcm_function_with_simulation_is_coherent(model): """Test that solve_and_simulate creates same output as solve then simulate.""" @@ -153,7 +154,13 @@ def test_get_lcm_function_with_simulation_target_iskhakov_et_al_2017(model): vf_arr_list=vf_arr_list, initial_states={ "wealth": jnp.array([10.0, 10.0, 20.0]), - "lagged_retirement": jnp.array([0, 1, 1]), + "lagged_retirement": jnp.array( + [ + RetirementStatus.working, + RetirementStatus.retired, + RetirementStatus.retired, + ] + ), }, ) @@ -200,14 +207,14 @@ def test_create_compute_conditional_continuation_value(): val = compute_ccv( consumption=jnp.array([10, 20, 30.0]), - retirement=1, + retirement=RetirementStatus.retired, wealth=30, params=params, vf_arr=None, ) assert val == iskhakov_et_al_2017_utility( consumption=30.0, - working=0, + working=RetirementStatus.working, disutility_of_work=1.0, ) @@ -248,15 +255,17 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): ) val = compute_ccv( - consumption=jnp.array([0, 1]), - retirement=1, + consumption=jnp.array( + [DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high] + ), + retirement=RetirementStatus.retired, wealth=2, params=params, vf_arr=None, ) assert val == iskhakov_et_al_2017_utility( consumption=2, - working=0, + working=RetirementStatus.working, disutility_of_work=1.0, ) @@ -303,7 +312,7 @@ def test_create_compute_conditional_continuation_policy(): policy, val = compute_ccv_policy( consumption=jnp.array([10, 20, 30.0]), - retirement=1, + retirement=RetirementStatus.retired, wealth=30, params=params, vf_arr=None, @@ -311,7 +320,7 @@ def test_create_compute_conditional_continuation_policy(): assert policy == 2 assert val == iskhakov_et_al_2017_utility( consumption=30.0, - working=0, + working=RetirementStatus.working, disutility_of_work=1.0, ) @@ -352,8 +361,10 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): ) policy, val = compute_ccv_policy( - consumption=jnp.array([0, 1]), - retirement=1, + consumption=jnp.array( + [DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high] + ), + retirement=RetirementStatus.retired, wealth=2, params=params, vf_arr=None, @@ -361,7 +372,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): assert policy == 1 assert val == iskhakov_et_al_2017_utility( consumption=2, - working=0, + working=RetirementStatus.working, disutility_of_work=1.0, ) @@ -375,7 +386,10 @@ def test_get_lcm_function_with_period_argument_in_filter(): model = get_model_config("iskhakov_et_al_2017", n_periods=3) def absorbing_retirement_filter(retirement, lagged_retirement, _period): - return jnp.logical_or(retirement == 1, lagged_retirement == 0) + return jnp.logical_or( + retirement == RetirementStatus.retired, + lagged_retirement == RetirementStatus.working, + ) model.functions["absorbing_retirement_filter"] = absorbing_retirement_filter diff --git a/tests/test_models/deterministic.py b/tests/test_models/deterministic.py index d88cc52..8ff625c 100644 --- a/tests/test_models/deterministic.py +++ b/tests/test_models/deterministic.py @@ -28,6 +28,19 @@ class RetirementStatus: retired: int = 1 +@dataclass +class DiscreteConsumptionChoice: + low: int = 0 + high: int = 1 + + +@dataclass +class DiscreteWealthLevels: + low: int = 0 + medium: int = 1 + high: int = 2 + + # -------------------------------------------------------------------------------------- # Utility functions # -------------------------------------------------------------------------------------- @@ -51,7 +64,7 @@ def utility_with_filter( def utility_discrete(consumption, working, disutility_of_work): # In the discrete model, consumption is defined as "low" or "high". This can be # translated to the levels 1 and 2. - consumption_level = 1 + (consumption == ConsumptionStatus.high) + consumption_level = 1 + (consumption == DiscreteConsumptionChoice.high) return utility(consumption_level, working, disutility_of_work) @@ -81,6 +94,15 @@ def next_wealth(wealth, consumption, labor_income, interest_rate): return (1 + interest_rate) * (wealth - consumption) + labor_income +def next_wealth_discrete(wealth, consumption, labor_income, interest_rate): + # For discrete state variables, we need to assure that the next state is also a + # valid state, i.e., it is a member of the discrete grid. + continuous = next_wealth(wealth, consumption, labor_income, interest_rate) + return jnp.clip( + jnp.rint(continuous), DiscreteWealthLevels.low, DiscreteWealthLevels.high + ).astype(jnp.int32) + + # -------------------------------------------------------------------------------------- # Constraints # -------------------------------------------------------------------------------------- @@ -170,12 +192,6 @@ def absorbing_retirement_filter(retirement, lagged_retirement): ) -@dataclass -class ConsumptionStatus: - low: int = 0 - high: int = 1 - - ISKHAKOV_ET_AL_2017_DISCRETE = Model( description=( "Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement " @@ -184,21 +200,17 @@ class ConsumptionStatus: n_periods=3, functions={ "utility": utility_discrete, - "next_wealth": next_wealth, + "next_wealth": next_wealth_discrete, "consumption_constraint": consumption_constraint, "labor_income": labor_income, "working": working, }, choices={ "retirement": DiscreteGrid(RetirementStatus), - "consumption": DiscreteGrid(ConsumptionStatus), + "consumption": DiscreteGrid(DiscreteConsumptionChoice), }, states={ - "wealth": LinspaceGrid( - start=0, - stop=400, - n_points=100, - ), + "wealth": DiscreteGrid(DiscreteWealthLevels), }, ) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index f9fd516..3d7b9cb 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -182,7 +182,7 @@ def test_simulate_with_only_discrete_choices(): assert_array_equal(res["retirement"], jnp.array([0, 1, 1, 1])) assert_array_equal(res["consumption"], jnp.array([0, 1, 1, 1])) - assert_array_equal(res["wealth"], jnp.array([0, 4, 1.5, 3])) + assert_array_equal(res["wealth"], jnp.array([0, 4, 2, 2])) # ======================================================================================