From 7d25a6b48bf2ee9b736e05d19790927a45bcd626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora?= Date: Wed, 8 Jan 2025 12:39:52 +0100 Subject: [PATCH] Pre-commit hooks for pyright and mypy (only API tests) (#283) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added pre-commit hooks for pyright and mypy (only API tests) * Removed unnecessary noqa * Replaced tensorflow for numpy in type tests * Try tweaking pre-commit-config --------- Co-authored-by: Martí Zamora Casals Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> --- .pre-commit-config.yaml | 15 +++++++++++++++ test/types/__init__.py | 0 test/types/decorator.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 test/types/__init__.py create mode 100644 test/types/decorator.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a396ed..682df2b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,3 +26,18 @@ repos: - id: ruff # linter types_or: [ python, pyi, jupyter ] args: [ --fix ] + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.391 + hooks: + - id: pyright + files: ^test/types/ + additional_dependencies: + [beartype, numpy<2] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.14.1 + hooks: + - id: mypy + files: ^test/types/ + additional_dependencies: + [beartype, numpy<2] + args: ["--ignore-missing-imports", "--follow-imports=skip"] \ No newline at end of file diff --git a/test/types/__init__.py b/test/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/types/decorator.py b/test/types/decorator.py new file mode 100644 index 0000000..1d1ccda --- /dev/null +++ b/test/types/decorator.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +import numpy as np +from beartype import beartype + +from jaxtyping import Float, Int, jaxtyped + + +@jaxtyped(typechecker=beartype) +@dataclass +class User: + name: str + age: int + items: Float[np.ndarray, " N"] + timestamps: Int[np.ndarray, " N"] + + +@jaxtyped(typechecker=beartype) +def transform_user(user: User, increment_age: int = 1) -> User: + user.age += increment_age + return user + + +user = User( + name="John", + age=20, + items=np.random.normal(size=10), + timestamps=np.random.randint(0, 100, size=10), +) + +new_user = transform_user(user, increment_age=2)