From b6c1e5ee68c53d1253834561ee09c69495f232c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com> Date: Fri, 3 Jan 2025 10:56:46 +0100 Subject: [PATCH 1/4] Added pre-commit hooks for pyright and mypy (only API tests) --- .pre-commit-config.yaml | 25 +++++++++++++++++++++++++ test/types/__init__.py | 0 test/types/decorator.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 test/types/__init__.py create mode 100644 test/types/decorator.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a396ed..1d7f42b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,3 +26,28 @@ repos: - id: ruff # linter types_or: [ python, pyi, jupyter ] args: [ --fix ] + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.391 + hooks: + - id: pyright + files: ^test/types/ + # must match the Python version used in CI + language_version: python3.11 + additional_dependencies: + [ + beartype, + tensorflow, + ] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.14.1 + hooks: + - id: mypy + files: ^test/types/ + # must match the Python version used in CI + language_version: python3.11 + additional_dependencies: + [ + beartype, + tensorflow, + ] + args: ["--ignore-missing-imports", "--follow-imports=skip"] \ No newline at end of file diff --git a/test/types/__init__.py b/test/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/types/decorator.py b/test/types/decorator.py new file mode 100644 index 0000000..c587359 --- /dev/null +++ b/test/types/decorator.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + +import tensorflow as tf +from beartype import beartype + +from jaxtyping import Float, Int, jaxtyped + + +@jaxtyped(typechecker=beartype) +@dataclass +class User: + name: str + age: int + items: Float[tf.Tensor, "N"] # noqa: F821 + timestamps: Int[tf.Tensor, "N"] # noqa: F821 + + +@jaxtyped(typechecker=beartype) +def transform_user(user: User, increment_age: int = 1) -> User: + user.age += increment_age + return user + + +user = User( + name="John", + age=20, + items=tf.random.normal([10]), + timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32), +) + +new_user = transform_user(user, increment_age=2) From 315798c63243d781001de3f557d366f22eb5d20e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com> Date: Tue, 7 Jan 2025 10:40:01 +0100 Subject: [PATCH 2/4] Removed unnecessary noqa --- test/types/decorator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/types/decorator.py b/test/types/decorator.py index c587359..102314e 100644 --- a/test/types/decorator.py +++ b/test/types/decorator.py @@ -11,8 +11,8 @@ class User: name: str age: int - items: Float[tf.Tensor, "N"] # noqa: F821 - timestamps: Int[tf.Tensor, "N"] # noqa: F821 + items: Float[tf.Tensor, " N"] + timestamps: Int[tf.Tensor, " N"] @jaxtyped(typechecker=beartype) From 02e3c19f0a2c23d9e80db1fbdcde3ed361a0d712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com> Date: Tue, 7 Jan 2025 10:47:57 +0100 Subject: [PATCH 3/4] Replaced tensorflow for numpy in type tests --- .pre-commit-config.yaml | 4 ++-- test/types/decorator.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d7f42b..4587e7d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: additional_dependencies: [ beartype, - tensorflow, + numpy<2, ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.14.1 @@ -48,6 +48,6 @@ repos: additional_dependencies: [ beartype, - tensorflow, + numpy<2, ] args: ["--ignore-missing-imports", "--follow-imports=skip"] \ No newline at end of file diff --git a/test/types/decorator.py b/test/types/decorator.py index 102314e..1d1ccda 100644 --- a/test/types/decorator.py +++ b/test/types/decorator.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -import tensorflow as tf +import numpy as np from beartype import beartype from jaxtyping import Float, Int, jaxtyped @@ -11,8 +11,8 @@ class User: name: str age: int - items: Float[tf.Tensor, " N"] - timestamps: Int[tf.Tensor, " N"] + items: Float[np.ndarray, " N"] + timestamps: Int[np.ndarray, " N"] @jaxtyped(typechecker=beartype) @@ -24,8 +24,8 @@ def transform_user(user: User, increment_age: int = 1) -> User: user = User( name="John", age=20, - items=tf.random.normal([10]), - timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32), + items=np.random.normal(size=10), + timestamps=np.random.randint(0, 100, size=10), ) new_user = transform_user(user, increment_age=2) From 0a0f1ae840ce84c66d61a632cba101c081a3dd50 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 8 Jan 2025 12:32:33 +0100 Subject: [PATCH 4/4] Try tweaking pre-commit-config --- .pre-commit-config.yaml | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4587e7d..682df2b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,23 +31,13 @@ repos: hooks: - id: pyright files: ^test/types/ - # must match the Python version used in CI - language_version: python3.11 additional_dependencies: - [ - beartype, - numpy<2, - ] + [beartype, numpy<2] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.14.1 hooks: - id: mypy files: ^test/types/ - # must match the Python version used in CI - language_version: python3.11 additional_dependencies: - [ - beartype, - numpy<2, - ] + [beartype, numpy<2] args: ["--ignore-missing-imports", "--follow-imports=skip"] \ No newline at end of file