From 0055821b4f37201be2fd4b3561341fc43766b77b Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 21 Oct 2024 22:49:25 +0200 Subject: [PATCH 1/4] Add Jit to simulation --- src/lcm/entry_point.py | 11 ++++++----- src/lcm/simulate.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 1e560fb..e049af6 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,5 +1,6 @@ import functools from collections.abc import Callable +from dags import concatenate_functions from functools import partial from typing import Literal, cast @@ -15,7 +16,7 @@ get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate +from lcm.simulate import simulate, _as_data_frame from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -170,7 +171,7 @@ def get_lcm_function( solve_model = jax.jit(_solve_model) if jit else _solve_model _next_state_simulate = get_next_state_function(model=_mod, target="simulate") - simulate_model = partial( + _simulate_model = partial( simulate, state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, @@ -179,14 +180,14 @@ def get_lcm_function( next_state=jax.jit(_next_state_simulate), logger=logger, ) + simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model if targets == "solve": _target = solve_model elif targets == "simulate": - _target = simulate_model + _target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods) elif targets == "solve_and_simulate": - _target = partial(simulate_model, solve_model=solve_model) - + _target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods) return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 561cfce..686c98c 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -218,7 +218,7 @@ def simulate( ) processed = {**processed, **calculated_targets} - return _as_data_frame(processed, n_periods=n_periods) + return processed def solve_continuous_problem( From ff093d99eab14896ef5f920fea10562e299b99e9 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 31 Oct 2024 03:59:13 +0100 Subject: [PATCH 2/4] Move Compute Targets --- src/lcm/entry_point.py | 27 +++++++++++++++++------ src/lcm/simulate.py | 49 ++++++++++++++++-------------------------- 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index e049af6..da98e44 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,7 +1,7 @@ import functools from collections.abc import Callable -from dags import concatenate_functions from functools import partial +import inspect from typing import Literal, cast import jax @@ -10,13 +10,14 @@ from lcm.argmax import argmax from lcm.discrete_problem import get_solve_discrete_problem from lcm.dispatchers import productmap +from lcm.functools import all_as_kwargs from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.model_functions import ( get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate, _as_data_frame +from lcm.simulate import simulate, _as_data_frame, _compute_targets from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -28,7 +29,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve", *, debug_mode: bool = True, - jit: bool = True, + jit: bool = False, ) -> tuple[Callable, ParamsDict]: """Entry point for users to get high level functions generated by lcm. @@ -180,14 +181,26 @@ def get_lcm_function( next_state=jax.jit(_next_state_simulate), logger=logger, ) - simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model + simulate_model = jax.jit(_simulate_model) if jit else _simulate_model if targets == "solve": - _target = solve_model + def _target(*args,**kwargs): + return solve_model(*args,**kwargs) elif targets == "simulate": - _target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods) + def _target(*args,**kwargs): + kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) + additional_targets = kwargs.get('additional_targets') + kwargs.pop('additional_targets',None) + _simulated = simulate_model(**kwargs) + return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) elif targets == "solve_and_simulate": - _target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods) + def _target(*args,**kwargs): + kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) + additional_targets = kwargs.get('additional_targets') + kwargs.pop('additional_targets',None) + _solved = solve_model(kwargs['params']) + _simulated = simulate_model(**kwargs, vf_arr_list=_solved) + return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 686c98c..468542e 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -21,9 +21,7 @@ def simulate( model: InternalModel, next_state, logger, - solve_model=None, vf_arr_list=None, - additional_targets=None, seed=12345, ): """Simulate the model forward in time. @@ -59,13 +57,9 @@ def simulate( """ if vf_arr_list is None: - if solve_model is None: - raise ValueError( - "You need to provide either vf_arr_list or solve_model.", - ) - # We do not need to convert the params here, because the solve_model function - # will do it. - vf_arr_list = solve_model(params) + raise ValueError( + "You need to provide either vf_arr_list or solve_model.", + ) logger.info("Starting simulation") @@ -209,15 +203,6 @@ def simulate( processed = _process_simulated_data(_simulation_results) - if additional_targets is not None: - calculated_targets = _compute_targets( - processed, - targets=additional_targets, - model_functions=model.functions, - params=params, - ) - processed = {**processed, **calculated_targets} - return processed @@ -316,21 +301,25 @@ def _compute_targets(processed_results, targets, model_functions, params): dict: Dict with computed targets. """ - target_func = concatenate_functions( - functions=model_functions, - targets=targets, - return_type="dict", - ) + if targets is not None: + target_func = concatenate_functions( + functions=model_functions, + targets=targets, + return_type="dict", + ) - # get list of variables over which we want to vectorize the target function - variables = [ - p for p in list(inspect.signature(target_func).parameters) if p != "params" - ] + # get list of variables over which we want to vectorize the target function + variables = [ + p for p in list(inspect.signature(target_func).parameters) if p != "params" + ] - target_func = vmap_1d(target_func, variables=variables) + target_func = vmap_1d(target_func, variables=variables) - kwargs = {k: v for k, v in processed_results.items() if k in variables} - return target_func(params=params, **kwargs) + kwargs = {k: v for k, v in processed_results.items() if k in variables} + + return {**processed_results, **target_func(params=params, **kwargs)} + else: + return processed_results def _process_simulated_data(results): From 8f419d3928dc3be588be0d2293b6b6b19cf41801 Mon Sep 17 00:00:00 2001 From: mj023 Date: Wed, 6 Nov 2024 16:54:02 +0100 Subject: [PATCH 3/4] Adapt Tests --- src/lcm/entry_point.py | 2 +- tests/test_regression_test.py | 2 +- tests/test_simulate.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index da98e44..fe03e22 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -29,7 +29,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve", *, debug_mode: bool = True, - jit: bool = False, + jit: bool = True, ) -> tuple[Callable, ParamsDict]: """Entry point for users to get high level functions generated by lcm. diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 52c360a..3f4ab3e 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -48,4 +48,4 @@ def test_regression_test(): # Compare # ================================================================================== aaae(expected_solve, got_solve, decimal=5) - assert_frame_equal(expected_simulate, got_simulate) + assert_frame_equal(expected_simulate, got_simulate, check_like=True) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 859aa99..6d0c1b1 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -94,8 +94,8 @@ def test_simulate_using_raw_inputs(simulate_inputs): **simulate_inputs, ) - assert_array_equal(got.loc[0, :]["retirement"], 1) - assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803])) + assert_array_equal(got["retirement"], 1) + assert_array_almost_equal(got["consumption"], jnp.array([1.0, 50.400803])) # ====================================================================================== @@ -336,6 +336,7 @@ def f_b(b, params): # noqa: ARG001 params={"disutility_of_work": -1.0}, ) expected = { + **processed_results, "fa": jnp.arange(3) - 1.0, "fb": 1 + jnp.arange(3), } From ec123cd23ed8dea7011e57c83f574578f6855882 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 7 Nov 2024 00:46:57 +0100 Subject: [PATCH 4/4] Fix Formatting --- src/lcm/entry_point.py | 48 +++++++++++++++++++++++++++++------------- src/lcm/simulate.py | 9 +++----- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index fe03e22..37ec5df 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,7 +1,7 @@ import functools +import inspect from collections.abc import Callable from functools import partial -import inspect from typing import Literal, cast import jax @@ -17,7 +17,7 @@ get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate, _as_data_frame, _compute_targets +from lcm.simulate import _as_data_frame, _compute_targets, simulate from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -184,23 +184,41 @@ def get_lcm_function( simulate_model = jax.jit(_simulate_model) if jit else _simulate_model if targets == "solve": - def _target(*args,**kwargs): - return solve_model(*args,**kwargs) + + def _target(*args, **kwargs): + return solve_model(*args, **kwargs) elif targets == "simulate": - def _target(*args,**kwargs): - kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) - additional_targets = kwargs.get('additional_targets') - kwargs.pop('additional_targets',None) + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) _simulated = simulate_model(**kwargs) - return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) elif targets == "solve_and_simulate": - def _target(*args,**kwargs): - kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) - additional_targets = kwargs.get('additional_targets') - kwargs.pop('additional_targets',None) - _solved = solve_model(kwargs['params']) + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) + _solved = solve_model(kwargs["params"]) _simulated = simulate_model(**kwargs, vf_arr_list=_solved) - return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) + return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 468542e..cc2283e 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -201,9 +201,7 @@ def simulate( logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results) - - return processed + return _process_simulated_data(_simulation_results) def solve_continuous_problem( @@ -316,10 +314,9 @@ def _compute_targets(processed_results, targets, model_functions, params): target_func = vmap_1d(target_func, variables=variables) kwargs = {k: v for k, v in processed_results.items() if k in variables} - + return {**processed_results, **target_func(params=params, **kwargs)} - else: - return processed_results + return processed_results def _process_simulated_data(results):