From 02e3c19f0a2c23d9e80db1fbdcde3ed361a0d712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= Date: Tue, 7 Jan 2025 10:47:57 +0100 Subject: [PATCH] Replaced tensorflow for numpy in type tests --- .pre-commit-config.yaml | 4 ++-- test/types/decorator.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d7f42b..4587e7d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: additional_dependencies: [ beartype, - tensorflow, + numpy<2, ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.14.1 @@ -48,6 +48,6 @@ repos: additional_dependencies: [ beartype, - tensorflow, + numpy<2, ] args: ["--ignore-missing-imports", "--follow-imports=skip"] \ No newline at end of file diff --git a/test/types/decorator.py b/test/types/decorator.py index 102314e..1d1ccda 100644 --- a/test/types/decorator.py +++ b/test/types/decorator.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -import tensorflow as tf +import numpy as np from beartype import beartype from jaxtyping import Float, Int, jaxtyped @@ -11,8 +11,8 @@ class User: name: str age: int - items: Float[tf.Tensor, " N"] - timestamps: Int[tf.Tensor, " N"] + items: Float[np.ndarray, " N"] + timestamps: Int[np.ndarray, " N"] @jaxtyped(typechecker=beartype) @@ -24,8 +24,8 @@ def transform_user(user: User, increment_age: int = 1) -> User: user = User( name="John", age=20, - items=tf.random.normal([10]), - timestamps=tf.random.uniform([10], minval=0, maxval=100, dtype=tf.int32), + items=np.random.normal(size=10), + timestamps=np.random.randint(0, 100, size=10), ) new_user = transform_user(user, increment_age=2)