Skip to content

Commit

Permalink
Implement label translator
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Oct 2, 2023
1 parent ba76985 commit 33aa499
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 7 deletions.
54 changes: 48 additions & 6 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import lcm.grids as grids_module
from lcm.create_params import create_params
from lcm.function_evaluator import _get_label_translator
from lcm.functools import all_as_kwargs
from lcm.interfaces import GridSpec, Model

Expand Down Expand Up @@ -271,6 +272,8 @@ def _get_functions(user_model, function_info, variable_info, grids, params):
raw_functions[f"weight_next_{var}"] = _get_stochastic_weight_function(
raw_func=raw_functions[f"next_{var}"],
name=var,
variable_info=variable_info,
grids=grids,
)

# ==================================================================================
Expand Down Expand Up @@ -351,7 +354,7 @@ def next_func(*args, **kwargs): # noqa: ARG001
return next_func


def _get_stochastic_weight_function(raw_func, name):
def _get_stochastic_weight_function(raw_func, name, variable_info, grids):
"""Get a function that returns the transition weights of a stochastic variable.
Example:
Expand All @@ -364,29 +367,68 @@ def _get_stochastic_weight_function(raw_func, name):
>>> pass
>>> next_health._stochastic_info = StochasticInfo()
>>> params = {"shocks": {"health": np.arange(4).reshape(2, 2)}}
>>> weight_func = _get_stochastic_weight_function(next_health, "health")
>>> weight_func = _get_stochastic_weight_function(
>>> raw_func=next_health,
>>> name="health"
>>> variable_info=variable_info,
>>> grids=grids,
>>> )
>>> weight_func(health=0, params=params)
>>> array([0, 1])
Args:
raw_func (callable): The raw next function of the stochastic variable.
name (str): The name of the stochastic variable.
variable_info (pd.DataFrame): A table with information about model variables.
grids (dict): Dictionary containing all variables of the model. The keys are
the names of the variables. The values are the grids.
Returns:
callable: A function that returns the transition weights of the stochastic
variable.
"""
signature = list(inspect.signature(raw_func).parameters)
function_parameters = list(inspect.signature(raw_func).parameters)

# Assert that stochastic next function only depends on state variables
for arg in function_parameters:
if not variable_info.loc[arg, "is_state"]:
raise ValueError(
f"Stochastic variables can only depend on state variables, but {name} "
f"depends on {arg}.",
)

new_kwargs = [*signature, "params"]
label_translators = _get_label_translators(
variables=function_parameters,
grids=grids,
)

new_kwargs = [*function_parameters, "params"]

@with_signature(args=new_kwargs)
def weight_func(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=new_kwargs)
indices = [kwargs[arg] for arg in signature] # TODO(@timmens): this should be
# an indexer lookup
indices = [
label_translators[arg](**{arg: kwargs[arg]}) for arg in function_parameters
]
return kwargs["params"]["shocks"][name][*indices]

return weight_func


def _get_label_translators(variables, grids):
"""Get a dictionary of label translators.
Args:
variables (list): List of variable names.
grids (dict): Dictionary containing all variables of the model. The keys are
the names of the variables. The values are the grids.
Returns:
dict: Dictionary that maps variable names to label translators.
"""
return {
var: _get_label_translator(labels=grids[var], in_name=var) for var in variables
}
34 changes: 33 additions & 1 deletion tests/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,19 @@ def raw_func(health, wealth): # noqa: ARG001

raw_func._stochastic_info = StochasticInfo()

got_function = _get_stochastic_weight_function(raw_func, name="health")
variable_info = pd.DataFrame({"is_state": [True, True]}, index=["health", "wealth"])

grids = {
"health": jnp.arange(2),
"wealth": jnp.arange(2),
}

got_function = _get_stochastic_weight_function(
raw_func,
name="health",
variable_info=variable_info,
grids=grids,
)

params = {"shocks": {"health": np.arange(24).reshape(2, 3, 4)}}

Expand All @@ -228,5 +240,25 @@ def raw_func(health, wealth): # noqa: ARG001
assert_array_equal(got, expected)


def test_get_stochastic_weight_function_non_state_dependency():
def raw_func(health, wealth): # noqa: ARG001
pass

raw_func._stochastic_info = StochasticInfo()

variable_info = pd.DataFrame(
{"is_state": [False, True]},
index=["health", "wealth"],
)

with pytest.raises(ValueError, match="Stochastic variables"):
_get_stochastic_weight_function(
raw_func,
name="health",
variable_info=variable_info,
grids=None,
)


def test_create_shock_params():
pass

0 comments on commit 33aa499

Please sign in to comment.