Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve simulation output #36

Merged
merged 9 commits into from
Oct 4, 2023
119 changes: 114 additions & 5 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from functools import partial

import jax.numpy as jnp
import pandas as pd
from dags import concatenate_functions
from pybaum import leaf_names, tree_flatten

from lcm.argmax import argmax, segment_argmax
from lcm.dispatchers import spacemap, vmap_1d
Expand All @@ -19,6 +21,7 @@ def simulate(
initial_states,
solve_model=None,
vf_arr_list=None,
additional_targets=None,
):
"""Simulate the model forward in time.

Expand All @@ -41,6 +44,8 @@ def simulate(
vf_arr_list (list): List of value function arrays of length n_periods. This is
the output of the model's `solve` function. If not provided, the model is
solved first.
additional_targets (list): List of targets to compute. If provided, the targets
are computed and added to the simulation results.

Returns:
list: List of length n_periods containing the valuations, optimal choices, and
Expand Down Expand Up @@ -80,7 +85,7 @@ def simulate(

# Forward simulation
# ==================================================================================
result = []
_simulation_results = []

for period in range(n_periods):
# Create data state choice space
Expand Down Expand Up @@ -167,19 +172,123 @@ def simulate(
# Store results
# ==============================================================================
choices = {**dense_choices, **sparse_choices, **cont_choices}
result.append(

_simulation_results.append(
{
"choices": choices,
"value": value,
"choices": choices,
"states": states,
},
)

# Update states
# ==============================================================================
states = next_state(**choices, **states)
states = {key.lstrip("next_"): val for key, val in states.items()}
states = {key.removeprefix("next_"): val for key, val in states.items()}

processed = _process_simulated_data(_simulation_results)

if additional_targets is not None:
calculated_targets = _compute_targets(
processed,
targets=additional_targets,
model_functions=model.functions,
params=params,
)
processed = {**processed, **calculated_targets}

return _as_data_frame(processed, n_periods=n_periods)


def _as_data_frame(processed, n_periods):
"""Convert processed simulation results to DataFrame.

Args:
processed (dict): Dict with processed simulation results.
n_periods (int): Number of periods.

Returns:
pd.DataFrame: DataFrame with the simulation results. The index is a multi-index
with the first level corresponding to the period and the second level
corresponding to the initial state id. The columns correspond to the value,
and the choice and state variables, and potentially auxiliary variables.

"""
n_initial_states = len(processed["value"]) // n_periods
index = pd.MultiIndex.from_product(
[range(n_periods), range(n_initial_states)],
names=["period", "initial_state_id"],
)
return pd.DataFrame(processed, index=index)


def _compute_targets(processed_results, targets, model_functions, params):
"""Compute targets.

Args:
processed_results (dict): Dict with processed simulation results. Values must be
one-dimensional arrays.
targets (list): List of targets to compute.
model_functions (dict): Dict with model functions.
params (dict): Dict with model parameters.

Returns:
dict: Dict with computed targets.

"""
target_func = concatenate_functions(
functions=model_functions,
targets=targets,
return_type="dict",
)

# get list of variables over which we want to vectorize the target function
variables = [
p for p in list(inspect.signature(target_func).parameters) if p != "params"
]

target_func = vmap_1d(target_func, variables=variables)

kwargs = {k: v for k, v in processed_results.items() if k in variables}
return target_func(params=params, **kwargs)


def _process_simulated_data(results):
"""Process and flatten the simulation results.

Args:
results (list): List of dicts with simulation results. Each dict contains the
value, choices, and states for one period.

Returns:
dict: Dict with processed simulation results. The keys are the variable names
and the values are the flattened arrays, with dimension (n_periods *
n_initial_states, ).

"""
column_names = [
# remove prefixes 'choice_' and 'states_' from variable names, which are added
# by the leaf_names function
name.removeprefix("choices_").removeprefix("states_")
for name in leaf_names(results[0])
]

return result
# ==================================================================================
# Get dict of arrays for each var with dimension (n_periods * n_initial_states, )
# ----------------------------------------------------------------------------------
# The arrays are flattened, so that the resulting dictionary has a one-dimensional
# array for each variable. The length of this array is the number of periods times
# the number of initial states. The order of array elements is given by an outer
# level of periods and an inner level of initial states ids.
# ==================================================================================

# flatten the nested dictionary structure to get a list of dicts, where each dict
# has only array values
processed = [
dict(zip(column_names, tree_flatten(vals)[0], strict=True)) for vals in results
]
processed = {key: [d[key] for d in processed] for key in column_names}
return {key: jnp.concatenate(values) for key, values in processed.items()}


@partial(vmap_1d, variables=["ccv_policy", "dense_argmax"])
Expand Down
45 changes: 29 additions & 16 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ def test_simulate_using_raw_inputs(simulate_inputs):
**simulate_inputs,
)

choices = got[0]["choices"]

assert_array_equal(choices["retirement"], 1)
assert_array_almost_equal(choices["consumption"], jnp.array([1.0, 50.400803]))
assert_array_equal(got.loc[0, :]["retirement"], 1)
assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803]))


# ======================================================================================
Expand Down Expand Up @@ -119,7 +117,7 @@ def _model_solution(n_periods):
return _model_solution


@pytest.mark.parametrize("n_periods", range(1, PHELPS_DEATON["n_periods"] + 1))
@pytest.mark.parametrize("n_periods", range(3, PHELPS_DEATON["n_periods"] + 1))
def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods):
vf_arr_list, params, model = phelps_deaton_model_solution(n_periods)

Expand All @@ -131,14 +129,25 @@ def test_simulate_using_get_lcm_function(phelps_deaton_model_solution, n_periods
initial_states={
"wealth": jnp.array([1.0, 20, 40, 70]),
},
additional_targets=["utility", "consumption_constraint"],
)

assert {
"value",
"retirement",
"consumption",
"wealth",
"utility",
"consumption_constraint",
} == set(res.columns)

# assert that everyone retires in the last period
assert_array_equal(res[-1]["choices"]["retirement"], 1)
last_period_index = n_periods - 1
assert_array_equal(res.loc[last_period_index, :]["retirement"], 1)

# assert that higher wealth leads to higher consumption
for period in range(n_periods):
assert jnp.all(jnp.diff(res[period]["choices"]["consumption"]) >= 0)
assert (res.loc[period, :]["consumption"].diff()[1:] >= 0).all()

# The following does not work. I.e. the continuation value in each period is not
# weakly increasing in wealth. It is unclear if this needs to hold.
Expand Down Expand Up @@ -199,7 +208,11 @@ def test_effect_of_beta_on_last_period():

# Asserting
# ==================================================================================
assert jnp.all(res_low[-1]["value"] <= res_high[-1]["value"])
last_period_index = 4
assert (
res_low.loc[last_period_index, :]["value"]
<= res_high.loc[last_period_index, :]["value"]
).all()


def test_effect_of_delta():
Expand Down Expand Up @@ -251,14 +264,14 @@ def test_effect_of_delta():
# Asserting
# ==================================================================================
for period in range(5):
assert jnp.all(
res_low[period]["choices"]["consumption"]
<= res_high[period]["choices"]["consumption"],
)
assert jnp.all(
res_low[period]["choices"]["retirement"]
>= res_high[period]["choices"]["retirement"],
)
assert (
res_low.loc[period, :]["consumption"]
<= res_high.loc[period, :]["consumption"]
).all()
assert (
res_low.loc[period, :]["retirement"]
>= res_high.loc[period, :]["retirement"]
).all()


# ======================================================================================
Expand Down