Skip to content

Commit

Permalink
Allow general discrete grids (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Sep 23, 2024
1 parent ff66bad commit ee48e30
Show file tree
Hide file tree
Showing 32 changed files with 1,729 additions and 981 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ jobs:
- '3.12'
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].0
- uses: prefix-dev/[email protected].1
with:
pixi-version: v0.29.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test-cpu
activate-environment: true
frozen: true
- name: Run pytest
shell: bash -l {0}
shell: bash {0}
run: pixi run -e test-cpu tests
- name: Upload coverage report
if: runner.os == 'Linux' && matrix.python-version == '3.12'
Expand All @@ -46,26 +47,28 @@ jobs:
fail-fast: false
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].0
- uses: prefix-dev/[email protected].1
with:
pixi-version: v0.29.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: mypy
frozen: true
- name: Run mypy
shell: bash -l {0}
shell: bash {0}
run: pixi run mypy
run-explanation-notebooks:
name: Run explanation notebooks on Python 3.12
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].0
- uses: prefix-dev/[email protected].1
with:
pixi-version: v0.29.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test-cpu
frozen: true
- name: Run explanation notebooks
shell: bash -l {0}
shell: bash {0}
run: pixi run -e test-cpu explanation-notebooks
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
# Run the linter.
- id: ruff
Expand Down
13 changes: 12 additions & 1 deletion examples/long_running.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Example specification for a consumption-savings model with health and exercise."""

from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model
Expand All @@ -9,6 +11,15 @@
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class WorkingStatus:
retired: int = 0
working: int = 1


# --------------------------------------------------------------------------------------
# Utility function
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -67,7 +78,7 @@ def consumption_constraint(consumption, wealth, labor_income):
"age": age,
},
choices={
"working": DiscreteGrid([0, 1]),
"working": DiscreteGrid(WorkingStatus),
"consumption": LinspaceGrid(
start=1,
stop=100,
Expand Down
12 changes: 10 additions & 2 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"import jax.numpy as jnp\n",
"import pytest\n",
"from jax import vmap\n",
Expand Down Expand Up @@ -277,6 +279,12 @@
"from lcm import DiscreteGrid, LinspaceGrid, Model\n",
"\n",
"\n",
"@dataclass\n",
"class RetirementStatus:\n",
" working: int = 0\n",
" retired: int = 1\n",
"\n",
"\n",
"def utility(consumption, retirement, lagged_retirement, wealth):\n",
" working = 1 - retirement\n",
" retirement_habit = lagged_retirement * wealth\n",
Expand All @@ -296,11 +304,11 @@
" },\n",
" n_periods=1,\n",
" choices={\n",
" \"retirement\": DiscreteGrid([0, 1]),\n",
" \"retirement\": DiscreteGrid(RetirementStatus),\n",
" \"consumption\": LinspaceGrid(start=1, stop=2, n_points=2),\n",
" },\n",
" states={\n",
" \"lagged_retirement\": DiscreteGrid([0, 1]),\n",
" \"lagged_retirement\": DiscreteGrid(RetirementStatus),\n",
" \"wealth\": LinspaceGrid(start=1, stop=4, n_points=4),\n",
" },\n",
")"
Expand Down
10 changes: 9 additions & 1 deletion explanations/function_representation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,19 @@
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"import jax.numpy as jnp\n",
"\n",
"from lcm import DiscreteGrid, LinspaceGrid, Model\n",
"\n",
"\n",
"@dataclass\n",
"class RetirementStatus:\n",
" working: int = 0\n",
" retired: int = 1\n",
"\n",
"\n",
"def utility(consumption, working, disutility_of_work):\n",
" return jnp.log(consumption) - disutility_of_work * working\n",
"\n",
Expand Down Expand Up @@ -125,7 +133,7 @@
" \"age\": age,\n",
" },\n",
" choices={\n",
" \"retirement\": DiscreteGrid([0, 1]),\n",
" \"retirement\": DiscreteGrid(RetirementStatus),\n",
" \"consumption\": LinspaceGrid(start=1, stop=400, n_points=20),\n",
" },\n",
" states={\n",
Expand Down
Loading

0 comments on commit ee48e30

Please sign in to comment.