Skip to content

Commit

Permalink
Move Compute Targets
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 committed Oct 31, 2024
1 parent 0055821 commit ff093d9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 37 deletions.
27 changes: 20 additions & 7 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
from collections.abc import Callable
from dags import concatenate_functions
from functools import partial
import inspect
from typing import Literal, cast

import jax
Expand All @@ -10,13 +10,14 @@
from lcm.argmax import argmax
from lcm.discrete_problem import get_solve_discrete_problem
from lcm.dispatchers import productmap
from lcm.functools import all_as_kwargs
from lcm.input_processing import process_model
from lcm.logging import get_logger
from lcm.model_functions import (
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.simulate import simulate, _as_data_frame
from lcm.simulate import simulate, _as_data_frame, _compute_targets
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.typing import ParamsDict
Expand All @@ -28,7 +29,7 @@ def get_lcm_function(
targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve",
*,
debug_mode: bool = True,
jit: bool = True,
jit: bool = False,
) -> tuple[Callable, ParamsDict]:
"""Entry point for users to get high level functions generated by lcm.
Expand Down Expand Up @@ -180,14 +181,26 @@ def get_lcm_function(
next_state=jax.jit(_next_state_simulate),
logger=logger,
)
simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model
simulate_model = jax.jit(_simulate_model) if jit else _simulate_model

if targets == "solve":
_target = solve_model
def _target(*args,**kwargs):
return solve_model(*args,**kwargs)
elif targets == "simulate":
_target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods)
def _target(*args,**kwargs):
kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,))
additional_targets = kwargs.get('additional_targets')
kwargs.pop('additional_targets',None)
_simulated = simulate_model(**kwargs)
return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods)
elif targets == "solve_and_simulate":
_target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods)
def _target(*args,**kwargs):
kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,))
additional_targets = kwargs.get('additional_targets')
kwargs.pop('additional_targets',None)
_solved = solve_model(kwargs['params'])
_simulated = simulate_model(**kwargs, vf_arr_list=_solved)
return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods)
return cast(Callable, _target), _mod.params


Expand Down
49 changes: 19 additions & 30 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def simulate(
model: InternalModel,
next_state,
logger,
solve_model=None,
vf_arr_list=None,
additional_targets=None,
seed=12345,
):
"""Simulate the model forward in time.
Expand Down Expand Up @@ -59,13 +57,9 @@ def simulate(
"""
if vf_arr_list is None:
if solve_model is None:
raise ValueError(
"You need to provide either vf_arr_list or solve_model.",
)
# We do not need to convert the params here, because the solve_model function
# will do it.
vf_arr_list = solve_model(params)
raise ValueError(
"You need to provide either vf_arr_list or solve_model.",
)

logger.info("Starting simulation")

Expand Down Expand Up @@ -209,15 +203,6 @@ def simulate(

processed = _process_simulated_data(_simulation_results)

if additional_targets is not None:
calculated_targets = _compute_targets(
processed,
targets=additional_targets,
model_functions=model.functions,
params=params,
)
processed = {**processed, **calculated_targets}

return processed


Expand Down Expand Up @@ -316,21 +301,25 @@ def _compute_targets(processed_results, targets, model_functions, params):
dict: Dict with computed targets.
"""
target_func = concatenate_functions(
functions=model_functions,
targets=targets,
return_type="dict",
)
if targets is not None:
target_func = concatenate_functions(
functions=model_functions,
targets=targets,
return_type="dict",
)

# get list of variables over which we want to vectorize the target function
variables = [
p for p in list(inspect.signature(target_func).parameters) if p != "params"
]
# get list of variables over which we want to vectorize the target function
variables = [
p for p in list(inspect.signature(target_func).parameters) if p != "params"
]

target_func = vmap_1d(target_func, variables=variables)
target_func = vmap_1d(target_func, variables=variables)

kwargs = {k: v for k, v in processed_results.items() if k in variables}
return target_func(params=params, **kwargs)
kwargs = {k: v for k, v in processed_results.items() if k in variables}

return {**processed_results, **target_func(params=params, **kwargs)}
else:
return processed_results


def _process_simulated_data(results):
Expand Down

0 comments on commit ff093d9

Please sign in to comment.