From 357ec8ec213beca6715e88a1857e28ace13ff4bf Mon Sep 17 00:00:00 2001 From: derrick chambers Date: Sat, 17 Aug 2024 15:57:51 -0700 Subject: [PATCH] update linter --- .pre-commit-config.yaml | 45 ++++++++------------- pyproject.toml | 73 ++++++++++++++++++++++++++++++++++ scripts/profile_dbscan1d.ipynb | 41 +++++++++++-------- src/dbscan1d/core.py | 9 +++-- src/dbscan1d/version.py | 5 ++- tests/test_dbscan1d.py | 14 ++++--- 6 files changed, 130 insertions(+), 57 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01a7efb..a12d5c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,38 +1,27 @@ +exclude: scripts/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: check-yaml - id: end-of-file-fixer - - id: trailing-whitespace + - id: check-merge-conflict - id: mixed-line-ending args: ['--fix=lf'] -- repo: https://github.com/psf/black - rev: 22.6.0 + + # Ruff is a replacement for flake8 and many other linters (much faster too) +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.6.1 hooks: - - id: black -- repo: https://github.com/PyCQA/flake8 - rev: 3.8.3 + - id: ruff + args: ["--fix"] + # Run the formatter. + - id: ruff-format + + # ensures __future__ import annotations at top of files which require it + # for the typing features they are using. +- repo: https://github.com/frostming/fix-future-annotations + rev: 0.5.0 hooks: - - id: flake8 - additional_dependencies: - - flake8-black - - flake8-breakpoint - - flake8-docstrings -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) - args: ["--profile", "black"] - - id: isort - name: isort (cython) - types: [cython] - - id: isort - name: isort (pyi) - types: [pyi] -- repo: https://github.com/kynan/nbstripout - rev: 0.3.9 - hooks: - - id: nbstripout - files: ".ipynb" + - id: fix-future-annotations diff --git a/pyproject.toml b/pyproject.toml index b8af7b2..b88af4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,3 +64,76 @@ dev = ["dbscan1d[test]"] "Bug Tracker" = "https://github.com/d-chambers/dbscan1d/issues" "Documentation" = "https://github.com/d-chambers/dbscan1d" "Homepage" = "https://github.com/d-chambers/dbscan1d" + +# --- formatting + +[tool.ruff] + +line-length = 88 + +# enable certain types of linting +lint.select = [ + "E", + "F", + "UP", + "RUF", + "I001", + "D", + "FA", + "T", + "N", + "NPY", + "NPY201", +] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "__init__.py" +] + +# lowest python version supported +target-version = "py310" + +lint.fixable = ["ALL"] + +# List of codes to ignore +lint.ignore = ["D105", "D107", "D401", "D205", "D200", "D400", "N803", "N806"] + +[tool.ruff.lint.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 10 + +# config for docstring parsing +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.pytest.ini_options] +filterwarnings = [ + # Ignore hdf5 warnings from pytables, See pytables #1035 + 'ignore::Warning:tables:' +] + +[tool.ruff.format] +# Use `\n` line endings for all files +line-ending = "lf" diff --git a/scripts/profile_dbscan1d.ipynb b/scripts/profile_dbscan1d.ipynb index ba04bd2..97d614d 100644 --- a/scripts/profile_dbscan1d.ipynb +++ b/scripts/profile_dbscan1d.ipynb @@ -18,13 +18,20 @@ "\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", - "\n", - "from dbscan1d import DBSCAN1D\n", "from sklearn.cluster import DBSCAN\n", "from sklearn.datasets import make_blobs\n", "\n", + "from dbscan1d import DBSCAN1D\n", "\n", - "n_points = [10, 100, 1_000, 10_000, 20_000, 30_000, 40_000,]\n", + "n_points = [\n", + " 10,\n", + " 100,\n", + " 1_000,\n", + " 10_000,\n", + " 20_000,\n", + " 30_000,\n", + " 40_000,\n", + "]\n", "centers = 2" ] }, @@ -46,8 +53,8 @@ "outputs": [], "source": [ "# Profile\n", - "db1 = DBSCAN1D(.5, 4)\n", - "db2 = DBSCAN(.5, 4)" + "db1 = DBSCAN1D(0.5, 4)\n", + "db2 = DBSCAN(0.5, 4)" ] }, { @@ -57,16 +64,16 @@ "outputs": [], "source": [ "# profile each stream type with each function\n", - "df = pd.DataFrame(columns=['dbscan', 'dbscan1d'], index=n_points)\n", + "df = pd.DataFrame(columns=[\"dbscan\", \"dbscan1d\"], index=n_points)\n", "for n_point in n_points:\n", - " print(f'on {n_point}')\n", + " print(f\"on {n_point}\")\n", " X = create_blobs(n_point, centers)\n", - " print('starting dbscan1d')\n", + " print(\"starting dbscan1d\")\n", " ti1 = %timeit -o db1.fit_predict(X)\n", - " df.loc[n_point, 'dbscan1d'] = ti1.best\n", - " print('starting dbscan')\n", + " df.loc[n_point, \"dbscan1d\"] = ti1.best\n", + " print(\"starting dbscan\")\n", " ti2 = %timeit -o db2.fit_predict(X)\n", - " df.loc[n_point, 'dbscan'] = ti2.best\n", + " df.loc[n_point, \"dbscan\"] = ti2.best\n", " print()\n", " print()" ] @@ -89,20 +96,20 @@ }, "outputs": [], "source": [ - "out_path = Path(__file__).parent / 'profile_results.png'\n", + "out_path = Path(__file__).parent / \"profile_results.png\"\n", "\n", "x = df.index.values\n", - "plt.loglog(x, df['dbscan'].values, label='dbscan', color='r')\n", - "plt.loglog(x, df['dbscan1d'].values, label='dbscan1d', color='b')\n", + "plt.loglog(x, df[\"dbscan\"].values, label=\"dbscan\", color=\"r\")\n", + "plt.loglog(x, df[\"dbscan1d\"].values, label=\"dbscan1d\", color=\"b\")\n", "\n", - "plt.xlabel('number of points')\n", - "plt.ylabel('run time (s)')\n", + "plt.xlabel(\"number of points\")\n", + "plt.ylabel(\"run time (s)\")\n", "\n", "plt.legend()\n", "\n", "plt.savefig(out_path)\n", "\n", - "plt.show()\n" + "plt.show()" ] } ], diff --git a/src/dbscan1d/core.py b/src/dbscan1d/core.py index 8bfccd3..c33ad29 100644 --- a/src/dbscan1d/core.py +++ b/src/dbscan1d/core.py @@ -3,7 +3,8 @@ It should be *much* more efficient for large datasets. """ -from typing import Optional + +from __future__ import annotations import numpy as np @@ -17,9 +18,9 @@ class DBSCAN1D: """ # params that change upon fit/training - core_sample_indices_: Optional[np.ndarray] = None - components_: Optional[np.ndarray] = None - labels_: Optional[np.ndarray] = None + core_sample_indices_: np.ndarray | None = None + components_: np.ndarray | None = None + labels_: np.ndarray | None = None def __init__(self, eps: float = 0.5, min_samples: int = 5, metric="euclidean"): self.eps = eps diff --git a/src/dbscan1d/version.py b/src/dbscan1d/version.py index 9826595..b1023cd 100644 --- a/src/dbscan1d/version.py +++ b/src/dbscan1d/version.py @@ -1,8 +1,9 @@ """Module for reporting the version of dbscan1d.""" + from importlib.metadata import PackageNotFoundError, version try: __version__ = version("dbscan1d") # package is not installed -except PackageNotFoundError: # NOQA - __version__ = "0.0.0" # NOQA +except PackageNotFoundError: + __version__ = "0.0.0" diff --git a/tests/test_dbscan1d.py b/tests/test_dbscan1d.py index cb37fb0..eb2ff5f 100644 --- a/tests/test_dbscan1d.py +++ b/tests/test_dbscan1d.py @@ -3,6 +3,7 @@ Requires sklearn. """ + import copy from itertools import product from pathlib import Path @@ -122,7 +123,8 @@ def generate_test_data(num_points, centers=None): num_points, n_features=1, centers=centers, random_state=13 ) X = blobs.flatten() - np.random.shuffle(X) + rng = np.random.default_rng() + rng.shuffle(X) return X, blob_labels @@ -139,18 +141,18 @@ class TestSKleanEquivalent: # define a small range of dbscan input params over which tests will # be parametrized - eps_values = [0.0001, 0.1, 0.5, 1, 2] - min_samples_values = [1, 2, 5, 15] - db_params = list(product(eps_values, min_samples_values)) + eps_values = (0.0001, 0.1, 0.5, 1, 2) + min_samples_values = (1, 2, 5, 15) + db_params = tuple(product(eps_values, min_samples_values)) - centers = [ + centers = ( np.array([0, 5, 10]), np.arange(10), np.array([1, 2, 3, 4, 5, 10]), np.array([1, 1.1, 1.2, 1.3, 1.4, 1.5]), 2, 7, - ] + ) @pytest.fixture(scope="class", params=centers) def blobs(self, request):