Skip to content

Commit

Permalink
🧹 Fix pre-commit issues outside of result saving
Browse files Browse the repository at this point in the history
  • Loading branch information
s-weigand committed Oct 18, 2024
1 parent 8c52c70 commit fdec279
Show file tree
Hide file tree
Showing 16 changed files with 58 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,15 @@ repos:
- "--fix"
- "--select=I"
alias: isort
types_or: [python, pyi]
- id: ruff-format
types_or: [python, pyi]
# Commands above are both formatters an not linters
# See also: https://github.com/astral-sh/ruff/discussions/7310#discussioncomment-7102010
- id: ruff
name: "ruff lint"
exclude: "docs/source/conf.py"
types_or: [python, pyi]

# - repo: https://github.com/PyCQA/flake8
# rev: 6.0.0
Expand Down
5 changes: 4 additions & 1 deletion .ruff-notebooks.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

extend = ".ruff.toml"

# Enable using ruff with notebooks
include = [ "*.ipynb", "**/pyproject.toml" ]

[lint]
extend-ignore = [ "D", "E402", "F404", "N811", "E703", "T201" ]
extend-ignore = [ "D", "E402", "N811", "E703", "T201" ]

[lint.isort]
required-imports = [ ]
Expand Down
5 changes: 2 additions & 3 deletions .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
extend-exclude = [ "venv", "docs/conf.py" ]
include = [ "*.py", "*.pyi", "**/pyproject.toml" ]

line-length = 99

# Assume Python 3.10.
target-version = "py310"

# Enable using ruff with notebooks
extend-include = [ "*.ipynb" ]

[lint]
select = [
Expand Down Expand Up @@ -68,7 +67,7 @@ unfixable = [ "F401" ]
"*/test_*.py" = [ "ARG001", "RUF012", "N811", "T20", "PIE804" ]
"*/__init__.py" = [ "F401" ]
"glotaran/builtin/io/netCDF/netCDF.py" = [ "N999" ]
"docs/*" = [ "INP001" ]
"docs/*" = [ "INP001", "E501", "BLE001" ]
# Needs a full rewrite anyway
"glotaran/builtin/io/ascii/wavelength_time_explicit_file.py" = [ "PTH" ]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"\n",
"First you need to import all needed libraries and functions.\n",
"\n",
"- `from __future__ import annotations`: needed to write python 3.10 typing syntax (`|`), even with a lower python version\n",
"- `json`,`xarray`: Needed for reading and writing itself\n",
"- `DataIoInterface`: needed to subclass from, this way you get the proper type and especially signature checking\n",
"- `register_data_io`: registers the DataIo plugin under the given `format_name`s"
Expand All @@ -58,8 +57,6 @@
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import json\n",
"from pathlib import Path\n",
"\n",
Expand Down Expand Up @@ -168,8 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"from glotaran.io import load_dataset\n",
"from glotaran.io import save_dataset\n",
"from glotaran.io import load_dataset, save_dataset\n",
"from glotaran.testing.simulated_data.sequential_spectral_decay import DATASET as dataset"
]
},
Expand Down Expand Up @@ -326,7 +322,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand Down
7 changes: 5 additions & 2 deletions glotaran/builtin/items/activation/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING
from typing import cast

import attr
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -77,7 +76,11 @@ def create_result(
if activation.dispersion_center is not None
else activation.center * global_axis.size
)
props = [asdict(p) for p in activation.parameters()]
# Since we don't pass the ``global_axis`` the type ambiguity is resolved
props = [
asdict(p)
for p in cast(list[GaussianActivationParameters], activation.parameters())
]
result[key] = xr.Dataset(
{
"trace": xr.DataArray(
Expand Down
8 changes: 6 additions & 2 deletions glotaran/model/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import cast
from uuid import uuid4

import xarray as xr # noqa: TCH002
Expand Down Expand Up @@ -201,8 +202,11 @@ class DataModel(Item):
@staticmethod
def create_class_for_elements(elements: set[type[Element]]) -> type[DataModel]:
data_model_cls_name = f"GlotaranDataModel_{str(uuid4()).replace('-','_')}"
data_models = (
*tuple({e.data_model_type for e in elements if e.data_model_type is not None}),
data_models: tuple[type[DataModel], ...] = (
*cast(
tuple[type[DataModel], ...],
tuple({e.data_model_type for e in elements if e.data_model_type is not None}),
),
DataModel,
)
return create_model(data_model_cls_name, __base__=data_models)
Expand Down
7 changes: 3 additions & 4 deletions glotaran/model/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar

import xarray as xr
from pydantic import ConfigDict
from pydantic import Field

Expand All @@ -18,6 +15,8 @@
from glotaran.plugin_system.element_registration import register_element

if TYPE_CHECKING:
import xarray as xr

from glotaran.model.data_model import DataModel
from glotaran.typing.types import ArrayLike

Expand All @@ -37,7 +36,7 @@ def _sanitize_json_schema(json_schema: dict[str, Any]) -> None:
class Element(TypedItem, abc.ABC):
"""Subclasses must overwrite :method:`glotaran.model.Element.calculate_matrix`."""

data_model_type: ClassVar[type | None] = None
data_model_type: ClassVar[type[DataModel] | None] = None # type: ignore[valid-type]
is_exclusive: ClassVar[bool] = False
is_unique: ClassVar[bool] = False
register_as: ClassVar[str | None] = None
Expand Down
26 changes: 17 additions & 9 deletions glotaran/optimization/objective.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(self, model: ExperimentModel):
def calculate_matrices(self) -> list[OptimizationMatrix]:
if isinstance(self._data, OptimizationData):
return OptimizationMatrix.from_data(self._data).as_global_list(self._data.global_axis)
return OptimizationMatrix.from_linked_data(self._data) # type:ignore[arg-type]
return OptimizationMatrix.from_linked_data(self._data)

def calculate_reduced_matrices(
self, matrices: list[OptimizationMatrix]
Expand Down Expand Up @@ -106,10 +107,12 @@ def resolve_estimations(
]

def calculate_global_penalty(self) -> ArrayLike:
_, _, matrix = OptimizationMatrix.from_global_data(self._data) # type:ignore[arg-type]
assert isinstance(self._data, OptimizationData)
assert self._data.flat_data is not None
_, _, matrix = OptimizationMatrix.from_global_data(self._data)
return OptimizationEstimation.calculate(
matrix.array,
self._data.flat_data, # type:ignore[attr-defined]
self._data.flat_data,
self._model.residual_function,
).residual

Expand Down Expand Up @@ -157,6 +160,7 @@ def create_result_dataset(self, label: str, data: OptimizationData) -> xr.Datase

def create_global_result(self) -> OptimizationObjectiveResult:
label = next(iter(self._model.datasets.keys()))
assert isinstance(self._data, OptimizationData)
result_dataset = self.create_result_dataset(label, self._data)

global_dim = result_dataset.attrs["global_dimension"]
Expand All @@ -165,12 +169,14 @@ def create_global_result(self) -> OptimizationObjectiveResult:
model_axis = result_dataset.coords[model_dim]

matrix = OptimizationMatrix.from_data(self._data).to_data_array(
global_dim, global_axis, model_dim, model_axis
global_dim, global_axis.to_numpy(), model_dim, model_axis.to_numpy()
)
global_matrix = OptimizationMatrix.from_data(self._data, global_matrix=True).to_data_array(
model_dim, model_axis, global_dim, global_axis
model_dim, model_axis.to_numpy(), global_dim, global_axis.to_numpy()
)
_, _, full_matrix = OptimizationMatrix.from_global_data(self._data)

assert self._data.flat_data is not None
estimation = OptimizationEstimation.calculate(
full_matrix.array,
self._data.flat_data,
Expand Down Expand Up @@ -458,16 +464,18 @@ def create_data_model_results(
model_dim: str,
amplitudes: xr.DataArray,
concentrations: xr.DataArray,
) -> xr.Dataset:
result = {}
) -> dict[str, xr.Dataset]:
result: dict[str, xr.Dataset] = {}
data_model = self._model.datasets[label]
assert any(isinstance(e, str) for _, e in iterate_data_model_elements(data_model)) is False
for data_model_cls in {
e.__class__.data_model_type
for _, e in cast(tuple[Any, Element], iterate_data_model_elements(data_model))
for _, e in cast(
Iterable[tuple[Any, Element]], iterate_data_model_elements(data_model)
)
if e.__class__.data_model_type is not None
}:
result = result | data_model_cls.create_result(
result = result | cast(type[DataModel], data_model_cls).create_result(
data_model,
global_dim,
model_dim,
Expand Down
2 changes: 1 addition & 1 deletion glotaran/optimization/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def run(self) -> tuple[Parameters, dict[str, xr.Dataset], OptimizationInfo]:
)
termination_reason = ls_result.message
# No matter the error we want to behave gracefully
except Exception as e: # noqa: BLE001
except Exception as e:
if self._raise:
raise e
warn(f"Optimization failed:\n\n{e}")
Expand Down
2 changes: 1 addition & 1 deletion glotaran/optimization/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def from_least_squares_result(
def calculate_parameter_errors(
optimization_info: OptimizationInfo, parameters: Parameters
) -> None:
"""Calculate and assign standard errors to parameters in place based on optimization information.
"""Calculate and assign standard errors to parameters in place based on ``optimization_info``.
This function calculates the standard errors for the free parameters
based on the provided optimization information and assigns these errors
Expand Down
2 changes: 1 addition & 1 deletion glotaran/project/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from glotaran.model.errors import GlotaranUserError
from glotaran.model.experiment_model import ExperimentModel # noqa: TCH001
from glotaran.optimization import OptimizationInfo # noqa: TCH001
from glotaran.optimization.objective import OptimizationResult
from glotaran.optimization.objective import OptimizationResult # noqa: TCH001
from glotaran.parameter import Parameters # noqa: TCH001

if TYPE_CHECKING:
Expand Down
10 changes: 5 additions & 5 deletions glotaran/utils/sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def sanity_scientific_notation_conversion(d: dict[str, Any] | list[Any]):
"""
if not isinstance(d, (dict, list)):
return
for k, v in d.items() if isinstance(d, dict) else enumerate(d):
for k, v in d.items() if isinstance(d, dict) else enumerate(d): # type: ignore[union-attr]
if isinstance(v, (list, dict)):
sanity_scientific_notation_conversion(v)
if isinstance(v, str):
d[k] = convert_scientific_to_float(v) # type: ignore[index,call-overload]
d[k] = convert_scientific_to_float(v) # type: ignore[index]


def sanitize_dict_values(d: dict[str, Any] | list[Any]):
Expand All @@ -125,20 +125,20 @@ def sanitize_dict_values(d: dict[str, Any] | list[Any]):
"""
if not isinstance(d, (dict, list)):
return
for k, v in d.items() if isinstance(d, dict) else enumerate(d):
for k, v in d.items() if isinstance(d, dict) else enumerate(d): # type: ignore[union-attr]
if isinstance(v, list):
leaf = all(isinstance(el, (str, tuple, float)) for el in v)
if leaf:
if "(" in str(v):
d[k] = list_string_to_tuple( # type: ignore[index,call-overload]
d[k] = list_string_to_tuple( # type: ignore[index]
sanitize_list_with_broken_tuples(v)
)
else:
sanitize_dict_values(v)
if isinstance(v, dict):
sanitize_dict_values(v)
if isinstance(v, str):
d[k] = string_to_tuple(v) # type: ignore[index,call-overload]
d[k] = string_to_tuple(v) # type: ignore[index]


def string_to_tuple(
Expand Down
2 changes: 1 addition & 1 deletion glotaran/utils/tee.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __enter__(self) -> TeeContext:
TeeContext
Instance that can be read from.
"""
sys.stdout = self # type:ignore[assignment]
sys.stdout = self
return self

def __exit__(
Expand Down
21 changes: 6 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
[build-system]
build-backend = "hatchling.build"
requires = [
"hatchling",
]
requires = [ "hatchling" ]

[project]
name = "pyglotaran"
Expand Down Expand Up @@ -30,9 +28,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Chemistry",
"Topic :: Scientific/Engineering :: Physics",
]
dynamic = [
"version",
]
dynamic = [ "version" ]
dependencies = [
"asteval!=0.9.28,>=0.9.22",
"netcdf4>=1.5.7",
Expand All @@ -48,9 +44,7 @@ dependencies = [
"tabulate>=0.8.9",
"xarray>=2022.3",
]
optional-dependencies.dev = [
"pyglotaran[docs,test]",
]
optional-dependencies.dev = [ "pyglotaran[docs,test]" ]
optional-dependencies.docs = [
# documentation dependencies
"jinja2<3.2",
Expand All @@ -67,12 +61,8 @@ optional-dependencies.docs = [
"sphinx-rtd-theme>=1.2",
"sphinxcontrib-jquery>=4.1", # Needed for the search to work Ref.: https://github.com/readthedocs/sphinx_rtd_theme/issues/1434
]
optional-dependencies.extras = [
"pyglotaran-extras>=0.5",
]
optional-dependencies.full = [
"pyglotaran[extras]",
]
optional-dependencies.extras = [ "pyglotaran-extras>=0.5" ]
optional-dependencies.full = [ "pyglotaran[extras]" ]
optional-dependencies.test = [
"coverage[toml]",
"ipython>=7.2",
Expand Down Expand Up @@ -172,6 +162,7 @@ show_error_codes = true
warn_unused_configs = true
warn_unused_ignores = true
check_untyped_defs = true
python_version = "3.10"

[[tool.mypy.overrides]]
module = "glotaran.builtin.megacomplexes.*"
Expand Down
2 changes: 1 addition & 1 deletion tests/deprecation/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def deprecation_warning_on_call_test_helper(
except OverDueDeprecationError as error:
raise error

except Exception as error: # noqa: BLE001
except Exception as error:
if raise_exception:
raise error
return record, None
2 changes: 0 additions & 2 deletions tests/model/test_experiment_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from glotaran.model.clp_constraint import OnlyConstraint
from glotaran.model.clp_constraint import ZeroConstraint
from glotaran.model.data_model import DataModel
from glotaran.model.experiment_model import ExperimentModel
from tests.model.test_data_model import MockDataModel
Expand Down

0 comments on commit fdec279

Please sign in to comment.