diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 1e560fb..37ec5df 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,4 +1,5 @@ import functools +import inspect from collections.abc import Callable from functools import partial from typing import Literal, cast @@ -9,13 +10,14 @@ from lcm.argmax import argmax from lcm.discrete_problem import get_solve_discrete_problem from lcm.dispatchers import productmap +from lcm.functools import all_as_kwargs from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.model_functions import ( get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate +from lcm.simulate import _as_data_frame, _compute_targets, simulate from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -170,7 +172,7 @@ def get_lcm_function( solve_model = jax.jit(_solve_model) if jit else _solve_model _next_state_simulate = get_next_state_function(model=_mod, target="simulate") - simulate_model = partial( + _simulate_model = partial( simulate, state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, @@ -179,13 +181,43 @@ def get_lcm_function( next_state=jax.jit(_next_state_simulate), logger=logger, ) + simulate_model = jax.jit(_simulate_model) if jit else _simulate_model if targets == "solve": - _target = solve_model + + def _target(*args, **kwargs): + return solve_model(*args, **kwargs) elif targets == "simulate": - _target = simulate_model + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) + _simulated = simulate_model(**kwargs) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) elif targets == "solve_and_simulate": - _target = partial(simulate_model, solve_model=solve_model) + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) + _solved = solve_model(kwargs["params"]) + _simulated = simulate_model(**kwargs, vf_arr_list=_solved) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 561cfce..cc2283e 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -21,9 +21,7 @@ def simulate( model: InternalModel, next_state, logger, - solve_model=None, vf_arr_list=None, - additional_targets=None, seed=12345, ): """Simulate the model forward in time. @@ -59,13 +57,9 @@ def simulate( """ if vf_arr_list is None: - if solve_model is None: - raise ValueError( - "You need to provide either vf_arr_list or solve_model.", - ) - # We do not need to convert the params here, because the solve_model function - # will do it. - vf_arr_list = solve_model(params) + raise ValueError( + "You need to provide either vf_arr_list or solve_model.", + ) logger.info("Starting simulation") @@ -207,18 +201,7 @@ def simulate( logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results) - - if additional_targets is not None: - calculated_targets = _compute_targets( - processed, - targets=additional_targets, - model_functions=model.functions, - params=params, - ) - processed = {**processed, **calculated_targets} - - return _as_data_frame(processed, n_periods=n_periods) + return _process_simulated_data(_simulation_results) def solve_continuous_problem( @@ -316,21 +299,24 @@ def _compute_targets(processed_results, targets, model_functions, params): dict: Dict with computed targets. """ - target_func = concatenate_functions( - functions=model_functions, - targets=targets, - return_type="dict", - ) + if targets is not None: + target_func = concatenate_functions( + functions=model_functions, + targets=targets, + return_type="dict", + ) - # get list of variables over which we want to vectorize the target function - variables = [ - p for p in list(inspect.signature(target_func).parameters) if p != "params" - ] + # get list of variables over which we want to vectorize the target function + variables = [ + p for p in list(inspect.signature(target_func).parameters) if p != "params" + ] + + target_func = vmap_1d(target_func, variables=variables) - target_func = vmap_1d(target_func, variables=variables) + kwargs = {k: v for k, v in processed_results.items() if k in variables} - kwargs = {k: v for k, v in processed_results.items() if k in variables} - return target_func(params=params, **kwargs) + return {**processed_results, **target_func(params=params, **kwargs)} + return processed_results def _process_simulated_data(results): diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 52c360a..3f4ab3e 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -48,4 +48,4 @@ def test_regression_test(): # Compare # ================================================================================== aaae(expected_solve, got_solve, decimal=5) - assert_frame_equal(expected_simulate, got_simulate) + assert_frame_equal(expected_simulate, got_simulate, check_like=True) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 859aa99..6d0c1b1 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -94,8 +94,8 @@ def test_simulate_using_raw_inputs(simulate_inputs): **simulate_inputs, ) - assert_array_equal(got.loc[0, :]["retirement"], 1) - assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803])) + assert_array_equal(got["retirement"], 1) + assert_array_almost_equal(got["consumption"], jnp.array([1.0, 50.400803])) # ====================================================================================== @@ -336,6 +336,7 @@ def f_b(b, params): # noqa: ARG001 params={"disutility_of_work": -1.0}, ) expected = { + **processed_results, "fa": jnp.arange(3) - 1.0, "fb": 1 + jnp.arange(3), }