From ed2168efa27579de07bdca0cdedf5bf6ab347d14 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Sat, 16 Mar 2024 16:38:00 +0530 Subject: [PATCH] Simple `dtype` argument addition (#680) * add dtype and format code * add a simple test for checking dtype other than float32 * fix default dtype and format code * refine documentation for the dtype argument --- equinox/nn/_linear.py | 15 +++++++++++++-- tests/test_nn.py | 4 ++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/equinox/nn/_linear.py b/equinox/nn/_linear.py index fdf524ff..450ee663 100644 --- a/equinox/nn/_linear.py +++ b/equinox/nn/_linear.py @@ -6,6 +6,7 @@ import jax.random as jrandom from jaxtyping import Array, PRNGKeyArray +from .._misc import default_floating_dtype from .._module import field, Module @@ -23,6 +24,7 @@ def __init__( in_features: Union[int, Literal["scalar"]], out_features: Union[int, Literal["scalar"]], use_bias: bool = True, + dtype=None, *, key: PRNGKeyArray, ): @@ -33,6 +35,9 @@ def __init__( - `out_features`: The output size. The output from the layer will be a vector of shape `(out_features,)`. - `use_bias`: Whether to add on a bias as well. + - `dtype`: The dtype to use for the weight and the bias in this layer. + Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending + on whether JAX is in 64-bit mode. - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.) @@ -46,11 +51,17 @@ def __init__( in_features_ = 1 if in_features == "scalar" else in_features out_features_ = 1 if out_features == "scalar" else out_features lim = 1 / math.sqrt(in_features_) + + if dtype is None: + dtype = default_floating_dtype() + self.weight = jrandom.uniform( - wkey, (out_features_, in_features_), minval=-lim, maxval=lim + wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype ) if use_bias: - self.bias = jrandom.uniform(bkey, (out_features_,), minval=-lim, maxval=lim) + self.bias = jrandom.uniform( + bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype + ) else: self.bias = None diff --git a/tests/test_nn.py b/tests/test_nn.py index 95c795df..4afe5ee6 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -44,6 +44,10 @@ def test_linear(getkey): x = jrandom.normal(getkey(), (2,)) assert linear(x).shape == () + linear = eqx.nn.Linear(2, "scalar", key=getkey(), dtype=jnp.float16) + x = jrandom.normal(getkey(), (2,), dtype=jnp.float16) + assert linear(x).dtype == jnp.float16 + def test_identity(getkey): identity1 = eqx.nn.Identity()