Skip to content

Commit

Permalink
Use correct arg_names in model_functions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 28, 2023
1 parent 66babea commit 8587548
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
47 changes: 41 additions & 6 deletions src/lcm/model_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

import jax.numpy as jnp
from dags import concatenate_functions
from dags.signature import with_signature
Expand All @@ -23,6 +25,8 @@ def get_utility_and_feasibility_function(
# ==================================================================================
# Generate dynamic functions
# ==================================================================================
current_u_and_f = get_current_u_and_f(model)

if not is_last_period:
next_states = get_next_states_function(model)
next_weights = get_next_weights_function(model)
Expand All @@ -37,9 +41,27 @@ def get_utility_and_feasibility_function(

multiply_weights = get_multiply_weights(stochastic_variables)

current_u_and_f = get_current_u_and_f(model)
relevant_functions = [
current_u_and_f,
next_states,
next_weights,
scalar_function_evaluator,
]
else:
relevant_functions = [current_u_and_f]

arg_names = [*state_variables, *choice_variables, "params", "vf_arr"]
# ==================================================================================
# Update this section

arg_names = set()
for func in relevant_functions:
parameters = inspect.signature(func).parameters
arg_names.update(parameters.keys())
arg_names = list({"vf_arr", *arg_names})
arg_names = [arg for arg in arg_names if "next_" not in arg]

# Update this section
# ==================================================================================

# ==================================================================================
# Create the utility and feasability function
Expand All @@ -66,10 +88,23 @@ def u_and_f(*args, **kwargs):
variables=[f"next_{var}" for var in stochastic_variables],
)

ccvs_at_nodes = function_evaluator(
**_next_states,
vf_arr=kwargs["vf_arr"],
)
# ==========================================================================
# Update this section

if "state_indexer" in kwargs:
ccvs_at_nodes = function_evaluator(
**_next_states,
vf_arr=kwargs["vf_arr"],
state_indexer=kwargs["state_indexer"],
)
else:
ccvs_at_nodes = function_evaluator(
**_next_states,
vf_arr=kwargs["vf_arr"],
)

# Update this section
# ==========================================================================

node_weights = multiply_weights(**weights)

Expand Down
1 change: 0 additions & 1 deletion tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pybaum import tree_equal, tree_map

MODELS = {
"simple": PHELPS_DEATON,
"with_filters": PHELPS_DEATON_WITH_FILTERS,
}

Expand Down

0 comments on commit 8587548

Please sign in to comment.