Skip to content

Commit

Permalink
Update pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 2, 2024
1 parent 2c045d8 commit 85bcec4
Show file tree
Hide file tree
Showing 25 changed files with 89 additions and 74 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-useless-excludes
# - id: identity # Prints all files passed to pre-commits. Debugging.
- repo: https://github.com/lyz-code/yamlfix
rev: 1.16.0
rev: 1.17.0
hooks:
- id: yamlfix
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.7
rev: v0.6.3
hooks:
# Run the linter.
- id: ruff
Expand Down
1 change: 1 addition & 0 deletions examples/long_running.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Example specification for a consumption-savings model with health and exercise."""

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model

# ======================================================================================
Expand Down
1 change: 1 addition & 0 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"import jax.numpy as jnp\n",
"import pytest\n",
"from jax import vmap\n",
"\n",
"from lcm.dispatchers import productmap, spacemap, vmap_1d"
]
},
Expand Down
1 change: 1 addition & 0 deletions explanations/function_representation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"from lcm import DiscreteGrid, LinspaceGrid, Model\n",
"\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_analytical_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import numpy as np
import pytest
from lcm._config import TEST_DATA
from lcm.entry_point import get_lcm_function
from numpy.testing import assert_array_almost_equal as aaae

from lcm._config import TEST_DATA
from lcm.entry_point import get_lcm_function
from tests.test_models.deterministic import get_model_config, get_params

# ======================================================================================
Expand Down
3 changes: 2 additions & 1 deletion tests/test_argmax.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import jax.numpy as jnp
from jax import jit
from lcm.argmax import _flatten_last_n_axes, _move_axes_to_back, argmax, segment_argmax
from numpy.testing import assert_array_equal

from lcm.argmax import _flatten_last_n_axes, _move_axes_to_back, argmax, segment_argmax

# Test jitted functions
# ======================================================================================
jitted_segment_argmax = jit(segment_argmax, static_argnums=2)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_create_params.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_equal

from lcm.create_params_template import (
_create_function_params,
_create_stochastic_transition_params,
create_params_template,
)
from lcm.model import Model
from numpy.testing import assert_equal


def test_create_params_without_shocks():
Expand Down Expand Up @@ -51,7 +52,7 @@ def test_create_function_params():


def test_create_shock_params():
def next_a(a, _period): # noqa: ARG001
def next_a(a, _period):
pass

variable_info = pd.DataFrame(
Expand All @@ -74,7 +75,7 @@ def next_a(a, _period): # noqa: ARG001


def test_create_shock_params_invalid_variable():
def next_a(a): # noqa: ARG001
def next_a(a):
pass

variable_info = pd.DataFrame(
Expand All @@ -97,7 +98,7 @@ def next_a(a): # noqa: ARG001


def test_create_shock_params_invalid_dependency():
def next_a(a, b, _period): # noqa: ARG001
def next_a(a, b, _period):
pass

variable_info = pd.DataFrame(
Expand Down
21 changes: 11 additions & 10 deletions tests/test_discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pandas as pd
import pytest
from jax.ops import segment_max
from numpy.testing import assert_array_almost_equal as aaae

from lcm.discrete_problem import (
_calculate_emax_extreme_value_shocks,
_determine_dense_discrete_choice_axes,
Expand All @@ -13,18 +15,17 @@
_solve_discrete_problem_no_shocks,
get_solve_discrete_problem,
)
from numpy.testing import assert_array_almost_equal as aaae


@pytest.fixture()
@pytest.fixture
def cc_values():
"""Conditional continuation values."""
v_t = jnp.arange(20).reshape(2, 2, 5) / 2
# reuse old test case from when segment axis was last
return jnp.transpose(v_t, axes=(2, 0, 1))


@pytest.fixture()
@pytest.fixture
def segment_info():
return {
"segment_ids": jnp.array([0, 0, 1, 1, 1]),
Expand Down Expand Up @@ -141,7 +142,7 @@ def _get_reshaped_cc_values_and_variable_info(cc_values, collapse, n_extra_axes)
# ======================================================================================


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_get_solve_discrete_problem_illustrative():
variable_info = pd.DataFrame(
{
Expand Down Expand Up @@ -172,7 +173,7 @@ def test_get_solve_discrete_problem_illustrative():
aaae(got, jnp.array([1, 3, 5]))


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative():
cc_values = jnp.array(
[
Expand Down Expand Up @@ -213,7 +214,7 @@ def test_solve_discrete_problem_no_shocks_illustrative():
aaae(got, jnp.array([3, 5]))


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_calculate_emax_extreme_value_shocks_illustrative():
cc_values = jnp.array(
[
Expand Down Expand Up @@ -259,7 +260,7 @@ def test_calculate_emax_extreme_value_shocks_illustrative():
# ======================================================================================


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_segment_max_over_first_axis_illustrative():
a = jnp.arange(4)
segment_info = {
Expand All @@ -271,7 +272,7 @@ def test_segment_max_over_first_axis_illustrative():
aaae(got, expected)


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_segment_extreme_value_emax_over_first_axis_illustrative():
a = jnp.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]])

Expand All @@ -289,7 +290,7 @@ def test_segment_extreme_value_emax_over_first_axis_illustrative():
aaae(got, expected)


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_segment_logsumexp_illustrative():
a = jnp.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]])

Expand All @@ -308,7 +309,7 @@ def test_segment_logsumexp_illustrative():
# ======================================================================================


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_determine_discrete_choice_axes_illustrative():
# No discrete choice variable
# ==================================================================================
Expand Down
15 changes: 8 additions & 7 deletions tests/test_dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_almost_equal as aaae

from lcm.dispatchers import (
productmap,
spacemap,
vmap_1d,
)
from lcm.functools import allow_args
from numpy.testing import assert_array_almost_equal as aaae


def f(a, /, *, b, c):
Expand Down Expand Up @@ -40,7 +41,7 @@ def g(a, /, b, *, c, d):
# ======================================================================================


@pytest.fixture()
@pytest.fixture
def setup_productmap_f():
return {
"a": jnp.linspace(-5, 5, 10),
Expand All @@ -49,7 +50,7 @@ def setup_productmap_f():
}


@pytest.fixture()
@pytest.fixture
def expected_productmap_f():
grids = {
"a": jnp.linspace(-5, 5, 10),
Expand All @@ -61,7 +62,7 @@ def expected_productmap_f():
return allow_args(f)(*helper).reshape(10, 7, 5)


@pytest.fixture()
@pytest.fixture
def setup_productmap_g():
return {
"a": jnp.linspace(-5, 5, 10),
Expand All @@ -71,7 +72,7 @@ def setup_productmap_g():
}


@pytest.fixture()
@pytest.fixture
def expected_productmap_g():
grids = {
"a": jnp.linspace(-5, 5, 10),
Expand Down Expand Up @@ -185,7 +186,7 @@ def test_productmap_with_some_argument_mapped_twice():
# ======================================================================================


@pytest.fixture()
@pytest.fixture
def setup_spacemap():
value_grid = {
"a": jnp.array([1.0, 2, 3]),
Expand All @@ -206,7 +207,7 @@ def setup_spacemap():
return value_grid, combination_grid


@pytest.fixture()
@pytest.fixture
def expected_spacemap():
value_grid = {
"a": jnp.array([1.0, 2, 3]),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp
import pytest
from pybaum import tree_equal, tree_map

from lcm.entry_point import (
create_compute_conditional_continuation_policy,
create_compute_conditional_continuation_value,
Expand All @@ -8,8 +10,6 @@
from lcm.model_functions import get_utility_and_feasibility_function
from lcm.process_model import process_model
from lcm.state_space import create_state_choice_space
from pybaum import tree_equal, tree_map

from tests.test_models.deterministic import get_model_config
from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility

Expand Down
11 changes: 6 additions & 5 deletions tests/test_function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax.numpy as jnp
import pytest

from lcm import LinspaceGrid
from lcm.dispatchers import productmap
from lcm.function_representation import (
Expand Down Expand Up @@ -308,7 +309,7 @@ def _utility(wealth, working):
# ======================================================================================


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_get_function_evaluator_illustrative():
a_grid = LinspaceGrid(start=0, stop=1, n_points=3)

Expand Down Expand Up @@ -342,7 +343,7 @@ def test_get_function_evaluator_illustrative():
assert jnp.allclose(got, expected)


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_get_lookup_function_illustrative():
values = jnp.array([0, 1, 4])
func = _get_lookup_function(array_name="xyz", axis_names=["a"])
Expand All @@ -351,7 +352,7 @@ def test_get_lookup_function_illustrative():
assert pure_lookup_func(a=2) == 4


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_get_coordinate_finder_illustrative():
find_coordinate = _get_coordinate_finder(
in_name="a",
Expand All @@ -365,7 +366,7 @@ def test_get_coordinate_finder_illustrative():
assert find_coordinate(a=0.25) == 0.5


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_get_interpolator_illustrative():
interpolate = _get_interpolator(
name_of_values_on_grid="test_name",
Expand All @@ -386,7 +387,7 @@ def f(a, b):
assert interpolate(test_name=values, a=0.5, b=1.5) == -1


@pytest.mark.illustrative()
@pytest.mark.illustrative
def test_fail_if_interpolation_axes_are_not_last_illustrative():
# Empty intersection of axis_names and interpolation_info
# ==================================================================================
Expand Down
11 changes: 6 additions & 5 deletions tests/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax.numpy as jnp
import pytest
from jax import vmap
from numpy.testing import assert_array_almost_equal as aaae

from lcm.functools import (
all_as_args,
all_as_kwargs,
Expand All @@ -12,18 +14,17 @@
convert_kwargs_to_args,
get_union_of_arguments,
)
from numpy.testing import assert_array_almost_equal as aaae

# ======================================================================================
# get_union_of_arguments
# ======================================================================================


def test_get_union_of_arguments():
def f(a, b): # noqa: ARG001
def f(a, b):
pass

def g(b, c): # noqa: ARG001
def g(b, c):
pass

got = get_union_of_arguments([f, g])
Expand Down Expand Up @@ -160,7 +161,7 @@ def f(a, /, b):


def test_allow_only_kwargs_signature_change():
def f(a, /, b, *, c): # noqa: ARG001
def f(a, /, b, *, c):
pass

decorated = allow_only_kwargs(f)
Expand Down Expand Up @@ -239,7 +240,7 @@ def f(a, *, b):


def test_allow_args_signature_change():
def f(a, /, b, *, c): # noqa: ARG001
def f(a, /, b, *, c):
pass

decorated = allow_args(f)
Expand Down
Loading

0 comments on commit 85bcec4

Please sign in to comment.