Skip to content

Commit

Permalink
Add description to _get_functions function
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 29, 2023
1 parent 77ff58e commit a9f5641
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 19 deletions.
36 changes: 22 additions & 14 deletions src/lcm/model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,35 @@ def get_utility_and_feasibility_function(
inspect.signature(scalar_value_function).parameters,
)

arg_names = {"vf_arr"} | get_union_of_arguments(relevant_functions)
arg_names = [arg for arg in arg_names if "next_" not in arg]

# ==================================================================================
# Create the utility and feasability function
# ==================================================================================

@with_signature(args=arg_names)
def u_and_f(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names)
arg_names = {"vf_arr"} | get_union_of_arguments(relevant_functions)
arg_names = [arg for arg in arg_names if "next_" not in arg]

if is_last_period:

@with_signature(args=arg_names)
def u_and_f(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names)

states = {k: v for k, v in kwargs.items() if k in state_variables}
choices = {k: v for k, v in kwargs.items() if k in choice_variables}
states = {k: v for k, v in kwargs.items() if k in state_variables}
choices = {k: v for k, v in kwargs.items() if k in choice_variables}

u, f = current_u_and_f(**states, **choices, params=kwargs["params"])
return current_u_and_f(**states, **choices, params=kwargs["params"])

else:

if is_last_period:
big_u = u
@with_signature(args=arg_names)
def u_and_f(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names)

states = {k: v for k, v in kwargs.items() if k in state_variables}
choices = {k: v for k, v in kwargs.items() if k in choice_variables}

u, f = current_u_and_f(**states, **choices, params=kwargs["params"])

else:
_next_states = next_states(**states, **choices, params=kwargs["params"])
weights = next_weights(**states, **choices, params=kwargs["params"])

Expand All @@ -94,8 +103,7 @@ def u_and_f(*args, **kwargs):
ccv = (ccvs_at_nodes * node_weights).sum()

big_u = u + kwargs["params"]["beta"] * ccv

return big_u, f
return big_u, f

return u_and_f

Expand Down
18 changes: 13 additions & 5 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,19 +273,27 @@ def _get_functions(user_model, function_info, variable_info, grids, params):
name=var,
)

# ==================================================================================
# Add 'params' argument to functions
# ==================================================================================
# We wrap the user functions such that they can be called with the 'params' argument
# instead of the individual parameters. This is done for all functions except for
# filter functions, because they cannot depend on model parameters. And dynamically
# generated weighting functions for stochastic next functions, since they are
# constructed to accept the 'params' argument.

functions = {}
for name, func in raw_functions.items():
# if the raw function is a weighting function for a stochastic variable, skip
is_weight_next_function = name.startswith("weight_next_")

if is_weight_next_function:
functions[name] = func
continue
processed_func = func

is_filter = function_info.loc[name, "is_filter"]
if is_filter:
elif function_info.loc[name, "is_filter"]:
if params.get(name, {}):
raise ValueError("filters cannot depend on model parameters.")
processed_func = func

elif params[name]:
processed_func = _get_extracting_function(
func=func,
Expand Down

0 comments on commit a9f5641

Please sign in to comment.