From 3bed4b6cfc7cb98ec0e268bfc2257f9530a6a763 Mon Sep 17 00:00:00 2001 From: Max Jahn Date: Sat, 1 Jun 2024 13:01:36 +0200 Subject: [PATCH] Jit the returned lcm function (#77) --- src/lcm/entry_point.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 4b34c60..dc0c03d 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -22,6 +22,7 @@ def get_lcm_function( model, targets="solve", debug_mode=True, # noqa: FBT002 + jit=True, # noqa: FBT002 interpolation_options=None, ): """Entry point for users to get high level functions generated by lcm. @@ -47,6 +48,7 @@ def get_lcm_function( targets (str or iterable): The requested function types. Currently only "solve", "simulate" and "solve_and_simulate" are supported. debug_mode (bool): Whether to log debug messages. + jit (bool): Whether to jit the returned function. interpolation_options (dict): Dictionary of keyword arguments for interpolation via map_coordinates. @@ -139,13 +141,13 @@ def get_lcm_function( utility_and_feasibility=u_and_f, continuous_choice_variables=list(_choice_grids), ) - compute_ccv_functions.append(jax.jit(compute_ccv)) + compute_ccv_functions.append(compute_ccv) compute_ccv_argmax = create_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, continuous_choice_variables=list(_choice_grids), ) - compute_ccv_policy_functions.append(jax.jit(compute_ccv_argmax)) + compute_ccv_policy_functions.append(compute_ccv_argmax) # create list of emax_calculators # ============================================================================== @@ -158,7 +160,7 @@ def get_lcm_function( choice_segments=choice_segments[period], params=_mod.params, ) - emax_calculators.append(jax.jit(calculator)) + emax_calculators.append(calculator) # ================================================================================== # select requested solver and partial arguments into it @@ -172,7 +174,8 @@ def get_lcm_function( emax_calculators=emax_calculators, logger=logger, ) - solve_model = _solve_model + + solve_model = jax.jit(_solve_model) if jit else _solve_model _next_state_simulate = get_next_state_function(model=_mod, target="simulate") simulate_model = partial(