Skip to content

Commit

Permalink
Add type hints to process model
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 2, 2024
1 parent eb798ed commit 6135338
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,32 +300,41 @@ def _get_functions(

else:
is_filter_function = function_info.loc[name, "is_filter"]
# params[name] contains the dictionary of parameters for the function
# params[name] contains the dictionary of parameters for the function, which
# is empty if the function does not depend on any model parameters.
depends_on_params = bool(params[name])

if is_filter_function:
if params.get(name, {}):
raise ValueError("filters cannot depend on model parameters.")
if name in params:
raise ValueError(
f"filters cannot depend on model parameters, but {name} does."
)
processed_func = func

elif depends_on_params:
processed_func = _get_extracting_function(
processed_func = _replace_func_parameters_by_params(
func=func,
params=params,
name=name,
)

else:
processed_func = _get_function_with_dummy_params(func=func)
processed_func = _add_dummy_params_argument(func)

functions[name] = processed_func

return functions


def _get_extracting_function(func, params, name):
def _replace_func_parameters_by_params(
func: Callable, params: Params, name: str
) -> Callable:
old_signature = list(inspect.signature(func).parameters)
new_kwargs = [p for p in old_signature if p not in params[name]] + ["params"]
new_kwargs = [
p
for p in old_signature
if p not in params[name] # type: ignore[operator]
] + ["params"]

@with_signature(args=new_kwargs)
@functools.wraps(func)
Expand All @@ -337,7 +346,7 @@ def processed_func(*args, **kwargs):
return processed_func


def _get_function_with_dummy_params(func):
def _add_dummy_params_argument(func: Callable) -> Callable:
old_signature = list(inspect.signature(func).parameters)

new_kwargs = [*old_signature, "params"]
Expand All @@ -352,15 +361,17 @@ def processed_func(*args, **kwargs):
return processed_func


def _get_stochastic_next_function(raw_func, grid):
def _get_stochastic_next_function(raw_func: Callable, grid: Array):
@functools.wraps(raw_func)
def next_func(*args, **kwargs): # noqa: ARG001
return grid

return next_func


def _get_stochastic_weight_function(raw_func, name, variable_info):
def _get_stochastic_weight_function(
raw_func: Callable, name: str, variable_info: pd.DataFrame
):
"""Get a function that returns the transition weights of a stochastic variable.
Example:
Expand All @@ -384,9 +395,9 @@ def _get_stochastic_weight_function(raw_func, name, variable_info):
Args:
raw_func (callable): The raw next function of the stochastic variable.
name (str): The name of the stochastic variable.
variable_info (pd.DataFrame): A table with information about model variables.
raw_func: The raw next function of the stochastic variable.
name: The name of the stochastic variable.
variable_info: A table with information about model variables.
Returns:
callable: A function that returns the transition weights of the stochastic
Expand All @@ -396,12 +407,17 @@ def _get_stochastic_weight_function(raw_func, name, variable_info):
function_parameters = list(inspect.signature(raw_func).parameters)

# Assert that stochastic next function only depends on discrete variables or period
for arg in function_parameters:
if arg != "_period" and not variable_info.loc[arg, "is_discrete"]:
raise ValueError(
f"Stochastic variables can only depend on discrete variables and "
f"'_period', but {name} depends on {arg}.",
)
invalid = {
arg
for arg in function_parameters
if arg != "_period" and not variable_info.loc[arg, "is_discrete"]
}

if invalid:
raise ValueError(
"Stochastic variables can only depend on discrete variables and '_period', "
f"but {name} depends on {invalid}.",
)

new_kwargs = [*function_parameters, "params"]

Expand Down

0 comments on commit 6135338

Please sign in to comment.