Skip to content

Commit

Permalink
Add Jit to simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 committed Oct 21, 2024
1 parent 7e8e3f8 commit 0055821
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
from collections.abc import Callable
from dags import concatenate_functions
from functools import partial
from typing import Literal, cast

Expand All @@ -15,7 +16,7 @@
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.simulate import simulate
from lcm.simulate import simulate, _as_data_frame
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.typing import ParamsDict
Expand Down Expand Up @@ -170,7 +171,7 @@ def get_lcm_function(
solve_model = jax.jit(_solve_model) if jit else _solve_model

_next_state_simulate = get_next_state_function(model=_mod, target="simulate")
simulate_model = partial(
_simulate_model = partial(
simulate,
state_indexers=state_indexers,
continuous_choice_grids=continuous_choice_grids,
Expand All @@ -179,14 +180,14 @@ def get_lcm_function(
next_state=jax.jit(_next_state_simulate),
logger=logger,
)
simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model

if targets == "solve":
_target = solve_model
elif targets == "simulate":
_target = simulate_model
_target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods)
elif targets == "solve_and_simulate":
_target = partial(simulate_model, solve_model=solve_model)

_target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods)
return cast(Callable, _target), _mod.params


Expand Down
2 changes: 1 addition & 1 deletion src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def simulate(
)
processed = {**processed, **calculated_targets}

return _as_data_frame(processed, n_periods=n_periods)
return processed


def solve_continuous_problem(
Expand Down

0 comments on commit 0055821

Please sign in to comment.