Skip to content

Commit

Permalink
Split up test model modules and create common get config / params fun…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
timmens committed Sep 23, 2024
1 parent 240c617 commit 9b7133f
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 210 deletions.
2 changes: 1 addition & 1 deletion tests/input_processing/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
process_model,
)
from lcm.mark import StochasticInfo
from tests.test_models.deterministic import (
from tests.test_models import (
get_model_config,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_analytical_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from lcm._config import TEST_DATA
from lcm.entry_point import get_lcm_function
from tests.test_models.deterministic import get_model_config, get_params
from tests.test_models import get_model_config, get_params

# ======================================================================================
# Model specifications
Expand Down
12 changes: 4 additions & 8 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from lcm.input_processing import process_model
from lcm.model_functions import get_utility_and_feasibility_function
from lcm.state_space import create_state_choice_space
from tests.test_models import get_model_config
from tests.test_models.deterministic import (
DiscreteConsumptionChoice,
RetirementStatus,
get_model_config,
)
from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility
from tests.test_models.discrete_deterministic import ConsumptionChoice

# ======================================================================================
# Test cases
Expand Down Expand Up @@ -255,9 +255,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
)

val = compute_ccv(
consumption=jnp.array(
[DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high]
),
consumption=jnp.array([ConsumptionChoice.low, ConsumptionChoice.high]),
retirement=RetirementStatus.retired,
wealth=2,
params=params,
Expand Down Expand Up @@ -361,9 +359,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
)

policy, val = compute_ccv_policy(
consumption=jnp.array(
[DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high]
),
consumption=jnp.array([ConsumptionChoice.low, ConsumptionChoice.high]),
retirement=RetirementStatus.retired,
wealth=2,
params=params,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
get_utility_and_feasibility_function,
)
from lcm.state_space import create_state_choice_space
from tests.test_models.deterministic import get_model_config, utility
from tests.test_models import get_model_config
from tests.test_models.deterministic import utility


def test_get_combined_constraint():
Expand Down
3 changes: 3 additions & 0 deletions tests/test_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .get_model import get_model_config, get_params

__all__ = ["get_model_config", "get_params"]
80 changes: 0 additions & 80 deletions tests/test_models/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

from copy import deepcopy
from dataclasses import dataclass

import jax.numpy as jnp
Expand All @@ -28,19 +27,6 @@ class RetirementStatus:
retired: int = 1


@dataclass
class DiscreteConsumptionChoice:
low: int = 0
high: int = 1


@dataclass
class DiscreteWealthLevels:
low: int = 0
medium: int = 1
high: int = 2


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
Expand All @@ -61,13 +47,6 @@ def utility_with_filter(
return utility(consumption, working, disutility_of_work)


def utility_discrete(consumption, working, disutility_of_work):
# In the discrete model, consumption is defined as "low" or "high". This can be
# translated to the levels 1 and 2.
consumption_level = 1 + (consumption == DiscreteConsumptionChoice.high)
return utility(consumption_level, working, disutility_of_work)


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
Expand All @@ -94,15 +73,6 @@ def next_wealth(wealth, consumption, labor_income, interest_rate):
return (1 + interest_rate) * (wealth - consumption) + labor_income


def next_wealth_discrete(wealth, consumption, labor_income, interest_rate):
# For discrete state variables, we need to assure that the next state is also a
# valid state, i.e., it is a member of the discrete grid.
continuous = next_wealth(wealth, consumption, labor_income, interest_rate)
return jnp.clip(
jnp.rint(continuous), DiscreteWealthLevels.low, DiscreteWealthLevels.high
).astype(jnp.int32)


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -190,53 +160,3 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
),
},
)


ISKHAKOV_ET_AL_2017_DISCRETE = Model(
description=(
"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement "
"state, and makes the consumption decision and the wealth state discrete."
),
n_periods=3,
functions={
"utility": utility_discrete,
"next_wealth": next_wealth_discrete,
"consumption_constraint": consumption_constraint,
"labor_income": labor_income,
"working": working,
},
choices={
"retirement": DiscreteGrid(RetirementStatus),
"consumption": DiscreteGrid(DiscreteConsumptionChoice),
},
states={
"wealth": DiscreteGrid(DiscreteWealthLevels),
},
)


# ======================================================================================
# Get models and params
# ======================================================================================

IMPLEMENTED_MODELS = {
"iskhakov_et_al_2017": ISKHAKOV_ET_AL_2017,
"iskhakov_et_al_2017_stripped_down": ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
"iskhakov_et_al_2017_discrete": ISKHAKOV_ET_AL_2017_DISCRETE,
}


def get_model_config(model_name: str, n_periods: int):
model_config = deepcopy(IMPLEMENTED_MODELS[model_name])
return model_config.replace(n_periods=n_periods)


def get_params(beta=0.95, disutility_of_work=0.25, interest_rate=0.05, wage=5.0):
return {
"beta": beta,
"utility": {"disutility_of_work": disutility_of_work},
"next_wealth": {
"interest_rate": interest_rate,
},
"labor_income": {"wage": wage},
}
107 changes: 107 additions & 0 deletions tests/test_models/discrete_deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Example specifications of fully discrete deterministic consumption-saving model.
The specification builds on the example model presented in the paper: "The endogenous
grid method for discrete-continuous dynamic choice models with (or without) taste
shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017,
https://doi.org/10.3982/QE643). See module `tests.test_models.deterministic` for the
continuous version.
"""

from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, Model
from tests.test_models.deterministic import (
RetirementStatus,
labor_income,
next_wealth,
utility,
working,
)

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class ConsumptionChoice:
low: int = 0
high: int = 1


@dataclass
class WealthStatus:
low: int = 0
medium: int = 1
high: int = 2


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility_discrete(consumption, working, disutility_of_work):
# In the discrete model, consumption is defined as "low" or "high". This can be
# translated to the levels 1 and 2.
consumption_level = 1 + (consumption == ConsumptionChoice.high)
return utility(consumption_level, working, disutility_of_work)


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth_discrete(wealth, consumption, labor_income, interest_rate):
# For discrete state variables, we need to assure that the next state is also a
# valid state, i.e., it is a member of the discrete grid.
continuous = next_wealth(wealth, consumption, labor_income, interest_rate)
return jnp.clip(jnp.rint(continuous), WealthStatus.low, WealthStatus.high).astype(
jnp.int32
)


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def consumption_constraint(consumption, wealth):
return consumption <= wealth


# --------------------------------------------------------------------------------------
# Filters
# --------------------------------------------------------------------------------------
def absorbing_retirement_filter(retirement, lagged_retirement):
return jnp.logical_or(
retirement == RetirementStatus.retired,
lagged_retirement == RetirementStatus.working,
)


# ======================================================================================
# Model specifications
# ======================================================================================
ISKHAKOV_ET_AL_2017_DISCRETE = Model(
description=(
"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement "
"state, and makes the consumption decision and the wealth state discrete."
),
n_periods=3,
functions={
"utility": utility_discrete,
"next_wealth": next_wealth_discrete,
"consumption_constraint": consumption_constraint,
"labor_income": labor_income,
"working": working,
},
choices={
"retirement": DiscreteGrid(RetirementStatus),
"consumption": DiscreteGrid(ConsumptionChoice),
},
states={
"wealth": DiscreteGrid(WealthStatus),
},
)
Loading

0 comments on commit 9b7133f

Please sign in to comment.