From b6c1e5ee68c53d1253834561ee09c69495f232c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com>
Date: Fri, 3 Jan 2025 10:56:46 +0100
Subject: [PATCH 1/4] Added pre-commit hooks for pyright and mypy (only API
 tests)

---
 .pre-commit-config.yaml | 25 +++++++++++++++++++++++++
 test/types/__init__.py  |  0
 test/types/decorator.py | 31 +++++++++++++++++++++++++++++++
 3 files changed, 56 insertions(+)
 create mode 100644 test/types/__init__.py
 create mode 100644 test/types/decorator.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a396ed..1d7f42b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,3 +26,28 @@ repos:
       - id: ruff  # linter
         types_or: [ python, pyi, jupyter ]
         args: [ --fix ]
+  - repo: https://github.com/RobertCraigie/pyright-python
+    rev: v1.1.391
+    hooks:
+    - id: pyright
+      files: ^test/types/
+      # must match the Python version used in CI
+      language_version: python3.11
+      additional_dependencies:
+          [
+            beartype,
+            tensorflow,
+          ]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.14.1  
+    hooks:
+      - id: mypy
+        files: ^test/types/
+        # must match the Python version used in CI
+        language_version: python3.11
+        additional_dependencies:
+          [
+            beartype,
+            tensorflow,
+          ]
+        args: ["--ignore-missing-imports", "--follow-imports=skip"]
\ No newline at end of file
diff --git a/test/types/__init__.py b/test/types/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/types/decorator.py b/test/types/decorator.py
new file mode 100644
index 0000000..c587359
--- /dev/null
+++ b/test/types/decorator.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+
+import tensorflow as tf
+from beartype import beartype
+
+from jaxtyping import Float, Int, jaxtyped
+
+
+@jaxtyped(typechecker=beartype)
+@dataclass
+class User:
+    name: str
+    age: int
+    items: Float[tf.Tensor, "N"]  # noqa: F821
+    timestamps: Int[tf.Tensor, "N"]  # noqa: F821
+
+
+@jaxtyped(typechecker=beartype)
+def transform_user(user: User, increment_age: int = 1) -> User:
+    user.age += increment_age
+    return user
+
+
+user = User(
+    name="John",
+    age=20,
+    items=tf.random.normal([10]),
+    timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32),
+)
+
+new_user = transform_user(user, increment_age=2)

From 315798c63243d781001de3f557d366f22eb5d20e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com>
Date: Tue, 7 Jan 2025 10:40:01 +0100
Subject: [PATCH 2/4] Removed unnecessary noqa

---
 test/types/decorator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/types/decorator.py b/test/types/decorator.py
index c587359..102314e 100644
--- a/test/types/decorator.py
+++ b/test/types/decorator.py
@@ -11,8 +11,8 @@
 class User:
     name: str
     age: int
-    items: Float[tf.Tensor, "N"]  # noqa: F821
-    timestamps: Int[tf.Tensor, "N"]  # noqa: F821
+    items: Float[tf.Tensor, " N"]
+    timestamps: Int[tf.Tensor, " N"]
 
 
 @jaxtyped(typechecker=beartype)

From 02e3c19f0a2c23d9e80db1fbdcde3ed361a0d712 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= <marti@amalfianalytics.com>
Date: Tue, 7 Jan 2025 10:47:57 +0100
Subject: [PATCH 3/4] Replaced tensorflow for numpy in type tests

---
 .pre-commit-config.yaml |  4 ++--
 test/types/decorator.py | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1d7f42b..4587e7d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
       additional_dependencies:
           [
             beartype,
-            tensorflow,
+            numpy<2,
           ]
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.14.1  
@@ -48,6 +48,6 @@ repos:
         additional_dependencies:
           [
             beartype,
-            tensorflow,
+            numpy<2,
           ]
         args: ["--ignore-missing-imports", "--follow-imports=skip"]
\ No newline at end of file
diff --git a/test/types/decorator.py b/test/types/decorator.py
index 102314e..1d1ccda 100644
--- a/test/types/decorator.py
+++ b/test/types/decorator.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 
-import tensorflow as tf
+import numpy as np
 from beartype import beartype
 
 from jaxtyping import Float, Int, jaxtyped
@@ -11,8 +11,8 @@
 class User:
     name: str
     age: int
-    items: Float[tf.Tensor, " N"]
-    timestamps: Int[tf.Tensor, " N"]
+    items: Float[np.ndarray, " N"]
+    timestamps: Int[np.ndarray, " N"]
 
 
 @jaxtyped(typechecker=beartype)
@@ -24,8 +24,8 @@ def transform_user(user: User, increment_age: int = 1) -> User:
 user = User(
     name="John",
     age=20,
-    items=tf.random.normal([10]),
-    timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32),
+    items=np.random.normal(size=10),
+    timestamps=np.random.randint(0, 100, size=10),
 )
 
 new_user = transform_user(user, increment_age=2)

From 0a0f1ae840ce84c66d61a632cba101c081a3dd50 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Wed, 8 Jan 2025 12:32:33 +0100
Subject: [PATCH 4/4] Try tweaking pre-commit-config

---
 .pre-commit-config.yaml | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4587e7d..682df2b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,23 +31,13 @@ repos:
     hooks:
     - id: pyright
       files: ^test/types/
-      # must match the Python version used in CI
-      language_version: python3.11
       additional_dependencies:
-          [
-            beartype,
-            numpy<2,
-          ]
+          [beartype, numpy<2]
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.14.1  
     hooks:
       - id: mypy
         files: ^test/types/
-        # must match the Python version used in CI
-        language_version: python3.11
         additional_dependencies:
-          [
-            beartype,
-            numpy<2,
-          ]
+          [beartype, numpy<2]
         args: ["--ignore-missing-imports", "--follow-imports=skip"]
\ No newline at end of file