Skip to content

Commit

Permalink
Use correct coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 committed Oct 4, 2024
1 parent 505f653 commit 805727b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions examples/example_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,16 @@ def u_and_f_last(consumption, working, health, exercise,wealth, period, vf_arr,p
def u_and_f(consumption, working, health, exercise,wealth, period, vf_arr,params, last_period):
age = period + 18
income = labor_income(wage(age),working)
next_state = jnp.array([next_health(health, exercise, working), next_wealth(wealth, consumption, income, params['interest_rate'])])
step_length_wealth = (100 - 1) / (100 - 1)
next_wealth_pos = (next_wealth(wealth, consumption, income, params['interest_rate']) - 1) / step_length_wealth
step_length_health = (1 - 0) / (100 - 1)
next_health_pos = (next_health(health, exercise, working) - 0) / step_length_health
next_state = jnp.array([next_health_pos,next_wealth_pos])

ccv = map_coordinates(vf_arr, list(next_state), order = 1, mode= 'nearest')
big_u = utility(consumption, working, health, exercise, params['disutility_of_work']) + params['beta'] * ccv
return big_u, consumption_constraint(consumption, wealth, income)

def _base_productmap(func, product_axes: list[str]):
signature = inspect.signature(func)
parameters = list(signature.parameters)
Expand Down Expand Up @@ -105,7 +111,7 @@ def compute_ccv(*args, **kwargs):
cont_mapped = _base_productmap(compute_ccv, product_axes=['health','wealth','working'])
jit_cont_mapped = jax.jit(cont_mapped)
ccvs = jit_cont_mapped(consumption, working, health, exercise,wealth, period, vf_arr,params, last_period)
vf_arr = ccvs.max(axis = 2)
vf_arr = ccvs.max(axis = 2,initial=-jnp.inf)
reversed_solution.append(vf_arr)
last_period = False
return reversed_solution
Expand Down

0 comments on commit 805727b

Please sign in to comment.