diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88b5520..e2b33d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,16 +9,6 @@ repos: rev: 1.16.0 hooks: - id: yamlfix - - repo: local - hooks: - - id: check-nbqa-version-mismatch - name: Check for version mismatch between black, ruff, and nbQA - entry: python scripts/check_nbqa_version_mismatch.py - language: python - always_run: true - require_serial: true - additional_dependencies: - - pyyaml - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: @@ -55,30 +45,27 @@ repos: rev: v1.35.1 hooks: - id: yamllint - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - language_version: python3.12 - repo: https://github.com/asottile/blacken-docs rev: 1.18.0 hooks: - id: blacken-docs - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.5.1 hooks: + # Run the linter. - id: ruff - # args: - # - --verbose - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.8.5 - hooks: - - id: nbqa-black - additional_dependencies: - - black==24.4.2 - - id: nbqa-ruff - additional_dependencies: - - ruff==v0.5.0 + types_or: + - python + - pyi + - jupyter + args: + - --fix + # Run the formatter. + - id: ruff-format + types_or: + - python + - pyi + - jupyter - repo: https://github.com/executablebooks/mdformat rev: 0.7.17 hooks: diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb index b6cea60..5a73dd4 100644 --- a/explanations/dispatchers.ipynb +++ b/explanations/dispatchers.ipynb @@ -682,10 +682,8 @@ " strict=False,\n", " ),\n", "):\n", - "\n", " # loop over product of dense variables\n", " for j, wealth in enumerate(sc_space.dense_vars[\"wealth\"]):\n", - "\n", " u = utility(\n", " wealth=wealth,\n", " retirement=retirement,\n", diff --git a/pixi.lock b/pixi.lock index f5eaf12..f330525 100644 --- a/pixi.lock +++ b/pixi.lock @@ -6735,9 +6735,9 @@ packages: timestamp: 1664996250081 - kind: pypi name: lcm - version: 0.1.dev185+g33959a4.d20240618 + version: 0.1.dev178+gf51ffae.d20240711 path: . - sha256: 31807f257930d650403d90b08f8f0448c091d216cc368d21e1f903d437bb983f + sha256: 5a26861b4465172ada3aa21512b70fb9fda05861c66d221b3aabd003aad8a1b6 requires_dist: - dags - numpy @@ -10795,8 +10795,8 @@ packages: - kind: pypi name: pandas version: 2.2.2 - url: https://files.pythonhosted.org/packages/40/10/79e52ef01dfeb1c1ca47a109a01a248754ebe990e159a844ece12914de83/pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - sha256: eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad + url: https://files.pythonhosted.org/packages/db/7c/9a60add21b96140e22465d9adf09832feade45235cd22f4cb1668a25e443/pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl + sha256: e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce requires_dist: - numpy>=1.22.4 ; python_version < '3.11' - numpy>=1.23.2 ; python_version == '3.11' @@ -10979,8 +10979,8 @@ packages: - kind: pypi name: pandas version: 2.2.2 - url: https://files.pythonhosted.org/packages/db/7c/9a60add21b96140e22465d9adf09832feade45235cd22f4cb1668a25e443/pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl - sha256: e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce + url: https://files.pythonhosted.org/packages/dd/49/de869130028fb8d90e25da3b7d8fb13e40f5afa4c4af1781583eb1ff3839/pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl + sha256: 9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef requires_dist: - numpy>=1.22.4 ; python_version < '3.11' - numpy>=1.23.2 ; python_version == '3.11' @@ -11071,8 +11071,8 @@ packages: - kind: pypi name: pandas version: 2.2.2 - url: https://files.pythonhosted.org/packages/dd/49/de869130028fb8d90e25da3b7d8fb13e40f5afa4c4af1781583eb1ff3839/pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl - sha256: 9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef + url: https://files.pythonhosted.org/packages/40/10/79e52ef01dfeb1c1ca47a109a01a248754ebe990e159a844ece12914de83/pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + sha256: eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad requires_dist: - numpy>=1.22.4 ; python_version < '3.11' - numpy>=1.23.2 ; python_version == '3.11' diff --git a/pyproject.toml b/pyproject.toml index 8abb86d..8df5d76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ test-gpu = {features = ["test", "cuda"], solve-group = "cuda"} [tool.ruff] target-version = "py312" fix = true +exclude = ["src/lcm/sandbox"] [tool.ruff.lint] select = ["ALL"] @@ -215,6 +216,9 @@ extend-ignore = [ # long messages outside the exception class "TRY003", + + # Missing docstring in magic method + "D105", ] [tool.ruff.lint.per-file-ignores] @@ -223,6 +227,7 @@ extend-ignore = [ "examples/*" = ["INP001"] "explanations/*" = ["INP001", "B018", "T201", "E402", "PD008"] "scripts/*" = ["INP001", "D101", "RET503"] +"**/*.ipynb" = ["FBT003", "E402"] [tool.ruff.lint.pydocstyle] convention = "google" @@ -238,7 +243,6 @@ black = "pyproject.toml" black = 1 [tool.nbqa.exclude] -ruff = "src/lcm/sandbox" black = "src/lcm/sandbox" # ====================================================================================== diff --git a/scripts/check_nbqa_version_mismatch.py b/scripts/check_nbqa_version_mismatch.py deleted file mode 100644 index 2f0a044..0000000 --- a/scripts/check_nbqa_version_mismatch.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Make sure that black and ruff versions used with nbQA match the primary versions. - -For example, we could use ruff==v0.4.6 in the primary pre-commit config, but nbQA could -be using ruff==v0.3.2. In this case, we raise an error. - -""" - -from pathlib import Path -from typing import NotRequired, TypedDict - -import yaml - - -class PreCommitRepo(TypedDict): - repo: str - hooks: list[dict] - rev: NotRequired[str] - - -class PreCommitConfig(TypedDict): - repos: list[PreCommitRepo] - ci: dict - - -class NbQAHook(TypedDict): - id: str - additional_dependencies: list[str] - - -class NbQARepo(PreCommitRepo): - repo: str - hooks: list[NbQAHook] - rev: str - - -TOOL_TO_REPO = { - "black": "https://github.com/psf/black", - "ruff": "https://github.com/astral-sh/ruff-pre-commit", - "nbQA": "https://github.com/nbQA-dev/nbQA", -} - - -def read_pre_commit_config() -> PreCommitConfig: - """Read a YAML file.""" - with Path(".pre-commit-config.yaml").open() as stream: - try: - return yaml.safe_load(stream) - except yaml.YAMLError as error: - raise ValueError("Failed to parse .pre-commit-config.yaml file") from error - - -def get_nbqa_repo(pre_commit_config: PreCommitConfig) -> NbQARepo: - """Get the nbQA repo from the pre-commit config. - - Args: - pre_commit_config: The pre-commit config. - - Returns: - The nbQA repo if found, otherwise None. - - """ - for repo in pre_commit_config["repos"]: - if repo["repo"] == TOOL_TO_REPO["nbQA"]: - return repo - - -def get_primary_version(pre_commit_config: PreCommitConfig, tool: str) -> str: - """Get the primary version of the tool used in the pre-commit config. - - Args: - pre_commit_config: The pre-commit config. - tool: The tool to get the primary version for. For example, ruff or black. - - Returns: - The primary version of the tool, otherwise None. - - """ - for repo in pre_commit_config["repos"]: - if repo["repo"] == TOOL_TO_REPO[tool]: - return repo["rev"] - - -def check_for_version_mismatch( - pre_commit_config: PreCommitConfig, - nbqa_repo: NbQARepo, -) -> None: - """Check for version mismatch between primary versions and versions used by nbQA. - - Args: - pre_commit_config: The pre-commit config. - nbqa_repo: The nbQA repo. - - Raises: - ValueError: If there is a version mismatch. - - """ - version_mismatch = {} - - for hook in nbqa_repo["hooks"]: - tool = hook["id"].removeprefix("nbqa-") - primary_version = get_primary_version(pre_commit_config, tool) - tool_dependency = hook["additional_dependencies"][0] - version_used_by_nbqa = tool_dependency.removeprefix(f"{tool}==") - - if primary_version != version_used_by_nbqa: - version_mismatch[tool] = { - "primary": primary_version, - "nbQA": version_used_by_nbqa, - } - - if version_mismatch: - raise ValueError(f"Versions mismatch in nbQA repo: {version_mismatch}") - - -if __name__ == "__main__": - pre_commit_config = read_pre_commit_config() - nbqa_repo = get_nbqa_repo(pre_commit_config) - check_for_version_mismatch(pre_commit_config, nbqa_repo) diff --git a/src/lcm/__init__.py b/src/lcm/__init__.py index 948d15a..ef81605 100644 --- a/src/lcm/__init__.py +++ b/src/lcm/__init__.py @@ -1,5 +1,4 @@ from lcm import mark -from lcm.user_grids import DiscreteGrid, LinspaceGrid, LogspaceGrid -from lcm.user_model import Model +from lcm.user_input import DiscreteGrid, LinspaceGrid, LogspaceGrid, Model __all__ = ["mark", "Model", "LinspaceGrid", "LogspaceGrid", "DiscreteGrid"] diff --git a/src/lcm/create_params_template.py b/src/lcm/create_params_template.py index f7f0b14..92304cd 100644 --- a/src/lcm/create_params_template.py +++ b/src/lcm/create_params_template.py @@ -7,7 +7,7 @@ from jax import Array from lcm.typing import Params, ScalarUserInput -from lcm.user_model import Model +from lcm.user_input import Model def create_params_template( @@ -129,7 +129,6 @@ def _create_stochastic_transition_params( invalid_dependencies = {} for var in stochastic_variables: - # Retrieve corresponding next function and its arguments next_var = user_model.functions[f"next_{var}"] dependencies = list(inspect.signature(next_var).parameters) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index b25285e..4f8ac3c 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -18,7 +18,7 @@ from lcm.simulate import simulate from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space -from lcm.user_model import Model +from lcm.user_input import Model def get_lcm_function( diff --git a/src/lcm/process_model.py b/src/lcm/process_model.py index fb58c86..42f26bc 100644 --- a/src/lcm/process_model.py +++ b/src/lcm/process_model.py @@ -15,13 +15,19 @@ from lcm.interfaces import ( ContinuousGridInfo, ContinuousGridSpec, + ContinuousGridType, DiscreteGridSpec, GridSpec, InternalModel, ) from lcm.typing import Params -from lcm.user_grids import ContinuousGrid, DiscreteGrid, LinspaceGrid, LogspaceGrid -from lcm.user_model import Model +from lcm.user_input import ( + ContinuousGrid, + DiscreteGrid, + LinspaceGrid, + LogspaceGrid, + Model, +) def process_model(user_model: Model) -> InternalModel: @@ -225,7 +231,7 @@ def _get_gridspecs( for name, spec in raw_variables.items(): if isinstance(spec, ContinuousGrid): if isinstance(spec, LinspaceGrid): - kind = "linspace" + kind: ContinuousGridType = "linspace" elif isinstance(spec, LogspaceGrid): kind = "logspace" else: diff --git a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb b/src/lcm/sandbox/state_space_jax_versus_numba.ipynb index eed6277..784cf9c 100644 --- a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb +++ b/src/lcm/sandbox/state_space_jax_versus_numba.ipynb @@ -7,22 +7,20 @@ "metadata": {}, "outputs": [], "source": [ - "import math\n", "import itertools\n", - "import numpy as np\n", - "from numba import njit\n", - "\n", - "import matplotlib.pyplot as plt\n", + "import math\n", "\n", "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "from jax import jit\n", "from jax.config import config\n", + "from numba import njit\n", "\n", "config.update(\"jax_enable_x64\", True)\n", "\n", "from dags import concatenate_functions\n", "from lcm.dispatchers import gridmap, productmap\n", - "\n", "from numpy.testing import assert_array_almost_equal as aaae" ] }, @@ -136,7 +134,8 @@ "}\n", "\n", "utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption], targets=\"_utility\"\n", + " functions=[_utility, _leisure, _consumption],\n", + " targets=\"_utility\",\n", ")\n", "_decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", "decorated_func_jit = jit(_decorated_func)\n", @@ -184,7 +183,8 @@ "\n", "def get_jax_runtime_2d():\n", " utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption], targets=\"_utility\"\n", + " functions=[_utility, _leisure, _consumption],\n", + " targets=\"_utility\",\n", " )\n", " _decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", " decorated_func_jit = jit(_decorated_func)\n", @@ -394,7 +394,7 @@ " np.linspace(5, 25, 5),\n", " np.linspace(1, 5, 5),\n", " np.linspace(2, 10, 5),\n", - " )\n", + " ),\n", ")" ] }, @@ -513,7 +513,9 @@ ")\n", "\n", "_decorated_func = gridmap(\n", - " utility_concat_4d, list(dense_variables), list(contingent_variables)\n", + " utility_concat_4d,\n", + " list(dense_variables),\n", + " list(contingent_variables),\n", ")\n", "decorated_func_jit = jit(_decorated_func)\n", "rslt_jax = decorated_func_jit(**dense_variables, **contingent_variables)\n", @@ -561,7 +563,8 @@ "\n", "def get_jax_runtime(len_grids_arr):\n", " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d], targets=\"_utility_4d\"\n", + " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", + " targets=\"_utility_4d\",\n", " )\n", "\n", " _decorated_func = gridmap(utility_concat_4d, [\"a\", \"b\"], [\"c\", \"d\"])\n", @@ -741,7 +744,8 @@ "\n", "def get_jax_runtime_4d(len_grids_arr):\n", " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d], targets=\"_utility_4d\"\n", + " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", + " targets=\"_utility_4d\",\n", " )\n", "\n", " _decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", @@ -775,7 +779,7 @@ " np.linspace(5, 25, 5),\n", " np.linspace(1, 5, 5),\n", " np.linspace(2, 10, 5),\n", - " )\n", + " ),\n", ")" ] }, diff --git a/src/lcm/user_grids.py b/src/lcm/user_grids.py deleted file mode 100644 index 6bb3dc6..0000000 --- a/src/lcm/user_grids.py +++ /dev/null @@ -1,192 +0,0 @@ -import dataclasses as dc -from collections.abc import Iterable -from typing import Self, get_args - -from lcm.typing import ScalarUserInput - - -class Grid: - """LCM Grid base class.""" - - -@dc.dataclass(frozen=True) -class DiscreteGrid(Grid): - """A grid of discrete values. - - Attributes: - options: The options in the grid. Must be an iterable of scalar int or float - values. - - """ - - options: Iterable[ScalarUserInput] - - def __post_init__(self) -> None: - if not isinstance(self.options, Iterable): - raise LcmGridInitializationError( - "options must be an iterable of scalar int or float values", - ) - - errors = _validate_discrete_grid(self.options) - if errors: - raise LcmGridInitializationError(_format_errors(errors)) - - def replace(self, options: Iterable[ScalarUserInput]) -> "DiscreteGrid": - """Replace the grid with new values. - - Args: - options: The new options in the grid. - - Returns: - The updated grid. - - """ - return dc.replace(self, options=options) - - -@dc.dataclass(frozen=True) -class ContinuousGrid(Grid): - """LCM Continuous Grid base class.""" - - start: ScalarUserInput - stop: ScalarUserInput - n_points: int - - def __post_init__(self) -> None: - errors = _validate_continuous_grid( - start=self.start, - stop=self.stop, - n_points=self.n_points, - ) - if errors: - raise LcmGridInitializationError(_format_errors(errors)) - - def replace(self, **kwargs) -> Self: - """Replace the grid with new values. - - Args: - **kwargs: - - start: The new start value of the grid. - - stop: The new stop value of the grid. - - n_points: The new number of points in the grid. - - Returns: - The updated grid. - - - """ - return dc.replace(self, **kwargs) - - -class LinspaceGrid(ContinuousGrid): - """A linear grid of continuous values. - - Attributes: - start: The start value of the grid. Must be a scalar int or float value. - stop: The stop value of the grid. Must be a scalar int or float value. - n_points: The number of points in the grid. Must be an int greater than 0. - - """ - - -class LogspaceGrid(ContinuousGrid): - """A logarithmic grid of continuous values. - - Attributes: - start: The start value of the grid. Must be a scalar int or float value. - stop: The stop value of the grid. Must be a scalar int or float value. - n_points: The number of points in the grid. Must be an int greater than 0. - - """ - - -# ====================================================================================== -# Validate user input -# ====================================================================================== - - -class LcmGridInitializationError(Exception): - """Raised when there is an error in the grid initialization.""" - - -def _format_errors(errors: list[str]) -> str: - """Convert list of error messages into a single string. - - If list is empty, returns the empty string. - - """ - if len(errors) == 0: - formatted = "" - elif len(errors) == 1: - formatted = errors[0] - else: - enumerated = "\n\n".join([f"{i}. {error}" for i, error in enumerate(errors, 1)]) - formatted = f"The following errors occurred:\n\n{enumerated}" - return formatted - - -# Discrete grid -# ====================================================================================== - - -def _validate_discrete_grid(options: list[ScalarUserInput]) -> list[str]: - """Validate the discrete grid options. - - Args: - options: The user options to validate. - - Returns: - list[str]: A list of error messages. - - """ - error_messages = [] - - if not len(options) > 0: - error_messages.append("options must contain at least one element") - - if not all(isinstance(option, get_args(ScalarUserInput)) for option in options): - error_messages.append("options must contain only scalar int or float values") - - if len(options) != len(set(options)): - error_messages.append("options must contain unique values") - - return error_messages - - -# Continuous grid -# ====================================================================================== - - -def _validate_continuous_grid( - start: ScalarUserInput, - stop: ScalarUserInput, - n_points: int, -) -> list[str]: - """Validate the continuous grid parameters. - - Args: - start: The start value of the grid. - stop: The stop value of the grid. - n_points: The number of points in the grid. - - Returns: - list[str]: A list of error messages. - - """ - error_messages = [] - - if not (valid_start_type := isinstance(start, get_args(ScalarUserInput))): - error_messages.append("start must be a scalar int or float value") - - if not (valid_stop_type := isinstance(stop, get_args(ScalarUserInput))): - error_messages.append("stop must be a scalar int or float value") - - if not isinstance(n_points, int) or n_points < 1: - error_messages.append( - f"n_points must be an int greater than 0 but is {n_points}", - ) - - if valid_start_type and valid_stop_type and start >= stop: - error_messages.append("start must be less than stop") - - return error_messages diff --git a/src/lcm/user_input.py b/src/lcm/user_input.py new file mode 100644 index 0000000..5bf4c43 --- /dev/null +++ b/src/lcm/user_input.py @@ -0,0 +1,322 @@ +"""Collection of classes that are used by the user to define the model and grids.""" + +import dataclasses as dc +from collections.abc import Callable, Collection +from dataclasses import KW_ONLY, InitVar, dataclass, field +from typing import Self, get_args + +from lcm.typing import ScalarUserInput + + +class Grid: + """LCM Grid base class.""" + + +@dataclass(frozen=True) +class Model: + """A user model which can be processed into an internal model. + + Attributes: + description: Description of the model. + n_periods: Number of periods in the model. + functions: Dictionary of user provided functions that define the functional + relationships between model variables. It must include at least a function + called 'utility'. + choices: Dictionary of user provided choices. + states: Dictionary of user provided states. + + """ + + description: str | None = None + _: KW_ONLY + n_periods: int + functions: dict[str, Callable] = field(default_factory=dict) + choices: dict[str, Grid] = field(default_factory=dict) + states: dict[str, Grid] = field(default_factory=dict) + _skip_checks: InitVar[bool] = False + + def __post_init__(self, _skip_checks: bool) -> None: + if _skip_checks: + return + + type_errors = _validate_model_attribute_types(self) + if type_errors: + raise LcmModelInitializationError(_format_errors(type_errors)) + + logical_errors = _validate_logical_consistency_model(self) + if logical_errors: + raise LcmModelInitializationError(_format_errors(logical_errors)) + + def replace(self, **kwargs) -> "Model": + """Replace the attributes of the model. + + Args: + **kwargs: Keyword arguments to replace the attributes of the model. + + Returns: + A new model with the replaced attributes. + + """ + return dc.replace(self, **kwargs) + + +@dataclass(frozen=True) +class DiscreteGrid(Grid): + """A grid of discrete values. + + Attributes: + options: The options in the grid. Must be an iterable of scalar int or float + values. + + """ + + options: Collection[ScalarUserInput] + + def __post_init__(self) -> None: + if not isinstance(self.options, Collection): + raise LcmGridInitializationError( + "options must be a collection of scalar int or float values, e.g., a ", + "list or tuple", + ) + + errors = _validate_discrete_grid(self.options) + if errors: + raise LcmGridInitializationError(_format_errors(errors)) + + def replace(self, options: Collection[ScalarUserInput]) -> "DiscreteGrid": + """Replace the grid with new values. + + Args: + options: The new options in the grid. + + Returns: + The updated grid. + + """ + return dc.replace(self, options=options) + + +@dataclass(frozen=True) +class ContinuousGrid(Grid): + """LCM Continuous Grid base class.""" + + start: ScalarUserInput + stop: ScalarUserInput + n_points: int + + def __post_init__(self) -> None: + errors = _validate_continuous_grid( + start=self.start, + stop=self.stop, + n_points=self.n_points, + ) + if errors: + raise LcmGridInitializationError(_format_errors(errors)) + + def replace(self, **kwargs) -> Self: + """Replace the grid with new values. + + Args: + **kwargs: + - start: The new start value of the grid. + - stop: The new stop value of the grid. + - n_points: The new number of points in the grid. + + Returns: + The updated grid. + + + """ + return dc.replace(self, **kwargs) + + +class LinspaceGrid(ContinuousGrid): + """A linear grid of continuous values. + + Example: + -------- + Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 50.5, 100]`. + + Attributes: + start: The start value of the grid. Must be a scalar int or float value. + stop: The stop value of the grid. Must be a scalar int or float value. + n_points: The number of points in the grid. Must be an int greater than 0. + + """ + + +class LogspaceGrid(ContinuousGrid): + """A logarithmic grid of continuous values. + + Example: + -------- + Let `start = 1`, `stop = 100`, and `n_points = 3`. The grid is `[1, 10, 100]`. + + Attributes: + start: The start value of the grid. Must be a scalar int or float value. + stop: The stop value of the grid. Must be a scalar int or float value. + n_points: The number of points in the grid. Must be an int greater than 0. + + """ + + +# ====================================================================================== +# Validate user input +# ====================================================================================== + + +class LcmModelInitializationError(Exception): + """Raised when there is an error in the model initialization.""" + + +class LcmGridInitializationError(Exception): + """Raised when there is an error in the grid initialization.""" + + +def _format_errors(errors: list[str]) -> str: + """Convert list of error messages into a single string. + + If list is empty, returns the empty string. + + """ + if len(errors) == 0: + formatted = "" + elif len(errors) == 1: + formatted = errors[0] + else: + enumerated = "\n\n".join([f"{i}. {error}" for i, error in enumerate(errors, 1)]) + formatted = f"The following errors occurred:\n\n{enumerated}" + return formatted + + +# Model +# ====================================================================================== + + +def _validate_model_attribute_types(model: Model) -> list[str]: + """Validate the types of the model attributes.""" + error_messages = [] + + # Validate types of states and choices + # ---------------------------------------------------------------------------------- + for attr_name in ("choices", "states"): + attr = getattr(model, attr_name) + if not isinstance(attr, dict): + error_messages.append(f"{attr_name} must be a dictionary.") + else: + for k, v in attr.items(): + if not isinstance(k, str): + error_messages.append(f"{attr_name} key {k} must be a string.") + if not isinstance(v, Grid): + error_messages.append(f"{attr_name} value {v} must be a LCM grid.") + + # Validate types of functions + # ---------------------------------------------------------------------------------- + if not isinstance(model.functions, dict): + error_messages.append("functions must be a dictionary.") + else: + for k, v in model.functions.items(): + if not isinstance(k, str): + error_messages.append(f"functions key {k} must be a string.") + if not callable(v): + error_messages.append(f"functions value {v} must be a callable.") + + return error_messages + + +def _validate_logical_consistency_model(model: Model) -> list[str]: + """Validate the logical consistency of the model.""" + error_messages = [] + + if model.n_periods < 1: + error_messages.append("Number of periods must be a positive integer.") + + if "utility" not in model.functions: + error_messages.append( + "Utility function is not defined. LCM expects a function called 'utility'" + "in the functions dictionary.", + ) + + if states_without_next_func := [ + state for state in model.states if f"next_{state}" not in model.functions + ]: + error_messages.append( + "Each state must have a corresponding next state function. For the " + "following states, no next state function was found: " + f"{states_without_next_func}.", + ) + + if states_and_choices_overlap := set(model.states) & set(model.choices): + error_messages.append( + "States and choices cannot have overlapping names. The following names " + f"are used in both states and choices: {states_and_choices_overlap}.", + ) + + return error_messages + + +# Discrete grid +# ====================================================================================== + + +def _validate_discrete_grid(options: Collection[ScalarUserInput]) -> list[str]: + """Validate the discrete grid options. + + Args: + options: The user options to validate. + + Returns: + list[str]: A list of error messages. + + """ + error_messages = [] + + if not len(options) > 0: + error_messages.append("options must contain at least one element") + + if not all(isinstance(option, get_args(ScalarUserInput)) for option in options): + error_messages.append("options must contain only scalar int or float values") + + if len(options) != len(set(options)): + error_messages.append("options must contain unique values") + + return error_messages + + +# Continuous grid +# ====================================================================================== + + +def _validate_continuous_grid( + start: ScalarUserInput, + stop: ScalarUserInput, + n_points: int, +) -> list[str]: + """Validate the continuous grid parameters. + + Args: + start: The start value of the grid. + stop: The stop value of the grid. + n_points: The number of points in the grid. + + Returns: + list[str]: A list of error messages. + + """ + error_messages = [] + + if not (valid_start_type := isinstance(start, get_args(ScalarUserInput))): + error_messages.append("start must be a scalar int or float value") + + if not (valid_stop_type := isinstance(stop, get_args(ScalarUserInput))): + error_messages.append("stop must be a scalar int or float value") + + if not isinstance(n_points, int) or n_points < 1: + error_messages.append( + f"n_points must be an int greater than 0 but is {n_points}", + ) + + if valid_start_type and valid_stop_type and start >= stop: + error_messages.append("start must be less than stop") + + return error_messages diff --git a/src/lcm/user_model.py b/src/lcm/user_model.py deleted file mode 100644 index 4aa9549..0000000 --- a/src/lcm/user_model.py +++ /dev/null @@ -1,140 +0,0 @@ -import dataclasses as dc -from collections.abc import Callable -from dataclasses import KW_ONLY, InitVar, dataclass, field - -from lcm.user_grids import Grid - - -@dataclass(frozen=True) -class Model: - """A user model which can be processed into an internal model. - - Attributes: - description: Description of the model. - n_periods: Number of periods in the model. - functions: Dictionary of user provided functions that define the functional - relationships between model variables. It must include at least a function - called 'utility'. - choices: Dictionary of user provided choices. - states: Dictionary of user provided states. - - """ - - description: str | None = None - _: KW_ONLY - n_periods: int - functions: dict[str, Callable] = field(default_factory=dict) - choices: dict[str, Grid] = field(default_factory=dict) - states: dict[str, Grid] = field(default_factory=dict) - _skip_checks: InitVar[bool] = False - - def __post_init__(self, _skip_checks: bool) -> None: - if _skip_checks: - return - - type_errors = _validate_model_attribute_types(self) - if type_errors: - raise LcmModelInitializationError(_format_errors(type_errors)) - - logical_errors = _validate_logical_consistency_model(self) - if logical_errors: - raise LcmModelInitializationError(_format_errors(logical_errors)) - - def replace(self, **kwargs) -> "Model": - """Replace the attributes of the model. - - Args: - **kwargs: Keyword arguments to replace the attributes of the model. - - Returns: - A new model with the replaced attributes. - - """ - return dc.replace(self, **kwargs) - - -# ====================================================================================== -# Validate user input -# ====================================================================================== - - -class LcmModelInitializationError(Exception): - """Raised when there is an error in the model initialization.""" - - -def _format_errors(errors: list[str]) -> str: - """Convert list of error messages into a single string. - - If list is empty, returns the empty string. - - """ - if len(errors) == 0: - formatted = "" - elif len(errors) == 1: - formatted = errors[0] - else: - enumerated = "\n\n".join([f"{i}. {error}" for i, error in enumerate(errors, 1)]) - formatted = f"The following errors occurred:\n\n{enumerated}" - return formatted - - -def _validate_model_attribute_types(model: Model) -> list[str]: - """Validate the types of the model attributes.""" - error_messages = [] - - # Validate types of states and choices - # ---------------------------------------------------------------------------------- - for attr_name in ("choices", "states"): - attr = getattr(model, attr_name) - if not isinstance(attr, dict): - error_messages.append(f"{attr_name} must be a dictionary.") - else: - for k, v in attr.items(): - if not isinstance(k, str): - error_messages.append(f"{attr_name} key {k} must be a string.") - if not isinstance(v, Grid): - error_messages.append(f"{attr_name} value {v} must be a LCM grid.") - - # Validate types of functions - # ---------------------------------------------------------------------------------- - if not isinstance(model.functions, dict): - error_messages.append("functions must be a dictionary.") - else: - for k, v in model.functions.items(): - if not isinstance(k, str): - error_messages.append(f"functions key {k} must be a string.") - if not callable(v): - error_messages.append(f"functions value {v} must be a callable.") - - return error_messages - - -def _validate_logical_consistency_model(model: Model) -> list[str]: - """Validate the logical consistency of the model.""" - error_messages = [] - - if model.n_periods < 1: - error_messages.append("Number of periods must be a positive integer.") - - if "utility" not in model.functions: - error_messages.append( - "Utility function is not defined. LCM expects a function called 'utility'" - "in the functions dictionary.", - ) - - if states_without_next_func := [ - state for state in model.states if f"next_{state}" not in model.functions - ]: - error_messages.append( - "Each state must have a corresponding next state function. For the " - "following states, no next state function was found: " - f"{states_without_next_func}.", - ) - - if states_and_choices_overlap := set(model.states) & set(model.choices): - error_messages.append( - "States and choices cannot have overlapping names. The following names " - f"are used in both states and choices: {states_and_choices_overlap}.", - ) - - return error_messages diff --git a/tests/test_create_params.py b/tests/test_create_params.py index a3dc96f..5f04b13 100644 --- a/tests/test_create_params.py +++ b/tests/test_create_params.py @@ -6,7 +6,7 @@ _create_stochastic_transition_params, create_params_template, ) -from lcm.user_model import Model +from lcm.user_input import Model from numpy.testing import assert_equal diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 32b54c7..6d77ac8 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -160,7 +160,6 @@ def test_simulate_using_get_lcm_function( assert_array_equal(res.loc[last_period_index, :]["retirement"], 1) for period in range(n_periods): - # assert that higher wealth leads to higher consumption in each period assert (res.loc[period]["consumption"].diff()[1:] >= 0).all() @@ -255,7 +254,6 @@ def test_effect_of_disutility_of_work(): # Asserting # ================================================================================== for period in range(5): - # We expect that individuals with lower disutility of work, work (weakly) more # and thus consume (weakly) more assert (