Skip to content

Commit

Permalink
Add pixi-kernel and update some text in notebooks.
Browse files Browse the repository at this point in the history
  • Loading branch information
hmgaudecker committed Jun 17, 2024
1 parent 913b3b1 commit a999b55
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 50 deletions.
9 changes: 4 additions & 5 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import pytest\n",
"from jax import vmap\n",
"from lcm.dispatchers import productmap, spacemap, vmap_1d"
]
Expand All @@ -27,7 +28,7 @@
"source": [
"# `vmap_1d`\n",
"\n",
"Let's vectorizing the function `f` over axis `a`."
"Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function."
]
},
{
Expand Down Expand Up @@ -100,8 +101,6 @@
"metadata": {},
"outputs": [],
"source": [
"import pytest\n",
"\n",
"with pytest.raises(\n",
" ValueError,\n",
" match=\"vmap in_axes must be an int, None, or a tuple of entries corresponding to\",\n",
Expand Down Expand Up @@ -247,7 +246,7 @@
"The `spacemap` function combines `productmap` and `vmap_1d` in a way that is often\n",
"needed in `lcm`.\n",
"\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"\n",
"Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n",
"\n",
Expand Down Expand Up @@ -810,7 +809,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.0"
}
},
"nbformat": 4,
Expand Down
68 changes: 31 additions & 37 deletions explanations/function_evaluator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
"# Explanation of the Function Evaluator\n",
"\n",
"In this notebook, we showcase how the function evaluator is used in `lcm`, and how it\n",
"works. Before we dive into the details, let us consider what is does on a high level.\n",
"works. Before we dive into the details, let us consider what it does on a high level.\n",
"\n",
"## Motivation\n",
"\n",
"Consider the last period of a finite dynamic programming problem. The value function\n",
"array for this period corresponds to the maximal utility in each state. If the\n",
"state-space is discretized into states $(x_1, \\ldots, x_p)$, the value function array\n",
"(in the last period) $V_T$ is a $p$-dimensional array, where the $i$-th entry\n",
"array for this period corresponds to the maximum of concurrent utility in each state,\n",
"where the maximum is taken over choices.\n",
"\n",
"If the state-space is discretized into states $(x_1, \\ldots, x_p)$, the value function\n",
"array (in the last period) $V_T$ is a $p$-dimensional array, where the $i$-th entry\n",
"$V_{T, i} = V_T(x_i)$ is the maximal utility the agent can achieve in state $x_i$.\n",
"\n",
"Consider now the Bellman equation for the second-to last period:\n",
Expand All @@ -31,9 +33,9 @@
"may need to evaluate the function $V_T$ at a different set of points than the\n",
"pre-calculated grid points $(x_1, \\ldots, x_p)$.\n",
"\n",
"Ideally, we would like to treat $V_T$ as an analytical function that can be evaluated\n",
"at any valid state $x$, ignoring the discretization. This is what the function evaluator\n",
"does.\n",
"Ideally, we would like to treat $V_T$ as an analytical function that can be evaluated at\n",
"any valid state $x$, ignoring the discretization. This is precisely what the function\n",
"evaluator does.\n",
" \n",
"\n",
"### Example\n",
Expand Down Expand Up @@ -85,7 +87,7 @@
"model = {\n",
" \"description\": (\n",
" \"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement \"\n",
" \"state, and adds wage function that depends on age.\"\n",
" \"state, and adds a wage function that depends on age.\"\n",
" ),\n",
" \"functions\": {\n",
" \"utility\": utility,\n",
Expand Down Expand Up @@ -138,15 +140,7 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"from lcm.process_model import process_model\n",
"\n",
Expand Down Expand Up @@ -175,7 +169,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mu_and_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconsumption\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretirement\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwealth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/.pixi/envs/default/lib/python3.12/site-packages/dags/signature.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/.pixi/envs/test-gpu/lib/python3.12/site-packages/dags/signature.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand All @@ -192,8 +186,8 @@
"metadata": {},
"source": [
"We can then evaluate `u_and_f` on scalar values. Notice that in the below example, the\n",
"action is not feasible since the consumption constraint restricts a consumption level\n",
"that is higher than the wealth level."
"action is not feasible since the consumption constraint forbids a consumption level that\n",
"is larger than wealth."
]
},
{
Expand All @@ -205,15 +199,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Utility: 4.3601579666137695, feasible: False\n"
"Utility: 4.355170249938965, feasible: False\n"
]
}
],
"source": [
"_u, _f = u_and_f(\n",
" consumption=100.5,\n",
" consumption=100,\n",
" retirement=0,\n",
" wealth=50.25,\n",
" wealth=50,\n",
" params=params,\n",
")\n",
"\n",
Expand Down Expand Up @@ -257,7 +251,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Length of (retirement, consumption, wealth) grids: (10, 20, 2)\n"
"Length of (wealth, consumption, retirement) grids: (10, 20, 2)\n"
]
}
],
Expand All @@ -268,7 +262,7 @@
"\n",
"u, f = u_and_f_mapped(**processed_model.grids, params=params)\n",
"\n",
"print(f\"Length of (retirement, consumption, wealth) grids: {u.shape}\")"
"print(f\"Length of (wealth, consumption, retirement) grids: {u.shape}\")"
]
},
{
Expand Down Expand Up @@ -340,8 +334,8 @@
"returns pre-calculated values if evaluated on a grid point, and linearly interpolated\n",
"values otherwise.\n",
"\n",
"To optimally utilize the structure of the grids when interpolating, the function evaluator\n",
"requires information on the state space."
"To optimally utilize the structure of the grids when interpolating, the function\n",
"evaluator requires information on the state space."
]
},
{
Expand All @@ -353,19 +347,19 @@
"from lcm.state_space import create_state_choice_space\n",
"\n",
"# the space info object contains information on the grid structure etc.\n",
"_, space_info, *_ = create_state_choice_space(\n",
"space_info = create_state_choice_space(\n",
" model=processed_model,\n",
" period=1,\n",
" is_last_period=True,\n",
" jit_filter=False,\n",
")"
")[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generating the function evaluator"
"#### Setting up the function evaluator"
]
},
{
Expand Down Expand Up @@ -489,7 +483,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mtranslator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhealth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand Down Expand Up @@ -543,7 +537,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mlookup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwealth_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvf_arr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand Down Expand Up @@ -667,7 +661,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mwealth_coordinate_finder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwealth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand All @@ -691,7 +685,7 @@
{
"data": {
"text/plain": [
"Array([0. , 0.50000006, 4.488722 , 8.999998 ], dtype=float32)"
"Array([0. , 0.50000006, 4.4887223 , 8.999998 ], dtype=float32)"
]
},
"execution_count": 21,
Expand Down Expand Up @@ -723,7 +717,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mvalue_function_interpolator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvf_arr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwealth_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/src/lcm/function_evaluator.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand All @@ -747,7 +741,7 @@
{
"data": {
"text/plain": [
"Array([0. , 1.8806003, 5.238375 , 5.991464 ], dtype=float32)"
"Array([0. , 1.8806003, 5.2383747, 5.991464 ], dtype=float32)"
]
},
"execution_count": 23,
Expand Down Expand Up @@ -963,7 +957,7 @@
"text": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mvalue_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvf_arr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwealth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mFile:\u001b[0m ~/sciebo-thinky/lcm/.pixi/envs/default/lib/python3.12/site-packages/dags/dag.py\n",
"\u001b[0;31mFile:\u001b[0m /mnt/econ/lcm/lcm/.pixi/envs/test-gpu/lib/python3.12/site-packages/dags/dag.py\n",
"\u001b[0;31mType:\u001b[0m function"
]
}
Expand Down
Loading

0 comments on commit a999b55

Please sign in to comment.