Skip to content

Commit

Permalink
Add options module to store LCM's default options
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jun 27, 2024
1 parent e0c8b10 commit 0616b3e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
21 changes: 11 additions & 10 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp
Expand All @@ -12,18 +13,20 @@
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.options import DefaultMapCoordinatesOptions
from lcm.process_model import process_model
from lcm.simulate import simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space


def get_lcm_function(
model,
targets="solve",
debug_mode=True, # noqa: FBT002
jit=True, # noqa: FBT002
interpolation_options=None,
model: dict,
targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve",
*,
debug_mode: bool = True,
jit: bool = True,
interpolation_options: dict | None = None,
):
"""Entry point for users to get high level functions generated by lcm.
Expand All @@ -50,7 +53,7 @@ def get_lcm_function(
debug_mode (bool): Whether to log debug messages.
jit (bool): Whether to jit the returned function.
interpolation_options (dict): Dictionary of keyword arguments for interpolation
via map_coordinates.
via map_coordinates. If None, the default options are used.
Returns:
callable: A function that takes params and returns the requested targets.
Expand All @@ -65,9 +68,7 @@ def get_lcm_function(

_mod = process_model(user_model=model)
last_period = _mod.n_periods - 1
interpolation_options = (
{} if interpolation_options is None else interpolation_options
)
interpolation_options = DefaultMapCoordinatesOptions | (interpolation_options or {})

logger = get_logger(debug_mode)

Expand All @@ -83,7 +84,7 @@ def get_lcm_function(
# Initialize other argument lists
# ==================================================================================
state_choice_spaces = []
state_indexers = []
state_indexers = [] # type:ignore[var-annotated]
space_infos = []
compute_ccv_functions = []
compute_ccv_policy_functions = []
Expand Down
17 changes: 4 additions & 13 deletions src/lcm/function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_function_representation(
space_info: SpaceInfo,
name_of_values_on_grid: str,
*,
interpolation_options: MapCoordinatesOptions | None = None,
interpolation_options: MapCoordinatesOptions,
input_prefix: str = "",
) -> Callable[..., Array]:
"""Create a function representation of pre-calculated values on a grid.
Expand Down Expand Up @@ -230,17 +230,10 @@ def find_coordinate(*args, **kwargs):
return find_coordinate


DefaultMapCoordinatesOptions: MapCoordinatesOptions = {
"order": 1,
"mode": "nearest",
"cval": 0.0,
}


def _get_interpolator(
name_of_values_on_grid: str,
axis_names: list[str],
map_coordinates_options: MapCoordinatesOptions | None,
map_coordinates_options: MapCoordinatesOptions,
) -> Callable[..., Array]:
"""Create a function interpolator via named axes.
Expand All @@ -250,15 +243,13 @@ def _get_interpolator(
resulting function.
axis_names: Names of the axes in the data array.
map_coordinates_options: Dictionary of interpolation options that will be passed
to jax.scipy.ndimage.map_coordinates. If None, DefaultMapCoordinatesOptions
will be used.
to jax.scipy.ndimage.map_coordinates.
Returns:
callable: A callable that interpolates a function via named axes.
"""
kwargs = DefaultMapCoordinatesOptions | (map_coordinates_options or {})
partialled_map_coordinates = partial(map_coordinates, **kwargs)
partialled_map_coordinates = partial(map_coordinates, **map_coordinates_options)

arg_names = [name_of_values_on_grid, *axis_names]

Expand Down
7 changes: 7 additions & 0 deletions src/lcm/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from lcm.typing import MapCoordinatesOptions

DefaultMapCoordinatesOptions: MapCoordinatesOptions = {
"order": 1,
"mode": "nearest",
"cval": 0.0,
}
10 changes: 8 additions & 2 deletions tests/test_function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
IndexerInfo,
SpaceInfo,
)
from lcm.options import DefaultMapCoordinatesOptions


def test_function_evaluator_with_one_continuous_variable():
Expand All @@ -41,6 +42,7 @@ def test_function_evaluator_with_one_continuous_variable():
space_info=space_info,
name_of_values_on_grid="vf_arr",
input_prefix="next_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand All @@ -67,6 +69,7 @@ def test_function_evaluator_with_one_discrete_variable():
space_info=space_info,
name_of_values_on_grid="vf_arr",
input_prefix="next_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand Down Expand Up @@ -144,6 +147,7 @@ def test_function_evaluator():
evaluator = get_function_representation(
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options=DefaultMapCoordinatesOptions,
)

# test the evaluator
Expand Down Expand Up @@ -227,6 +231,7 @@ def test_function_evaluator_longer_indexer():
evaluator = get_function_representation(
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options=DefaultMapCoordinatesOptions,
)

# test the evaluator
Expand Down Expand Up @@ -292,7 +297,7 @@ def test_get_interpolator():
interpolate = _get_interpolator(
name_of_values_on_grid="vf",
axis_names=["wealth", "working"],
map_coordinates_options=None,
map_coordinates_options=DefaultMapCoordinatesOptions,
)

def _utility(wealth, working):
Expand Down Expand Up @@ -337,6 +342,7 @@ def test_get_function_evaluator_illustrative():
space_info=space_info,
name_of_values_on_grid="values_name",
input_prefix="prefix_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand Down Expand Up @@ -376,7 +382,7 @@ def test_get_interpolator_illustrative():
interpolate = _get_interpolator(
name_of_values_on_grid="test_name",
axis_names=["a", "b"],
map_coordinates_options=None,
map_coordinates_options=DefaultMapCoordinatesOptions,
)

def f(a, b):
Expand Down

0 comments on commit 0616b3e

Please sign in to comment.