Skip to content

Commit

Permalink
Simple dtype argument addition (#680)
Browse files Browse the repository at this point in the history
* add dtype and format code

* add a simple test for checking dtype other than float32

* fix default dtype and format code

* refine documentation for the dtype argument
  • Loading branch information
AakashKumarNain authored Mar 16, 2024
1 parent 3337800 commit 3061c18
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
15 changes: 13 additions & 2 deletions equinox/nn/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field, Module


Expand All @@ -23,6 +24,7 @@ def __init__(
in_features: Union[int, Literal["scalar"]],
out_features: Union[int, Literal["scalar"]],
use_bias: bool = True,
dtype=None,
*,
key: PRNGKeyArray,
):
Expand All @@ -33,6 +35,9 @@ def __init__(
- `out_features`: The output size. The output from the layer will be a vector
of shape `(out_features,)`.
- `use_bias`: Whether to add on a bias as well.
- `dtype`: The dtype to use for the weight and the bias in this layer.
Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending
on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
Expand All @@ -46,11 +51,17 @@ def __init__(
in_features_ = 1 if in_features == "scalar" else in_features
out_features_ = 1 if out_features == "scalar" else out_features
lim = 1 / math.sqrt(in_features_)

if dtype is None:
dtype = default_floating_dtype()

self.weight = jrandom.uniform(
wkey, (out_features_, in_features_), minval=-lim, maxval=lim
wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype
)
if use_bias:
self.bias = jrandom.uniform(bkey, (out_features_,), minval=-lim, maxval=lim)
self.bias = jrandom.uniform(
bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype
)
else:
self.bias = None

Expand Down
4 changes: 4 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_linear(getkey):
x = jrandom.normal(getkey(), (2,))
assert linear(x).shape == ()

linear = eqx.nn.Linear(2, "scalar", key=getkey(), dtype=jnp.float16)
x = jrandom.normal(getkey(), (2,), dtype=jnp.float16)
assert linear(x).dtype == jnp.float16


def test_identity(getkey):
identity1 = eqx.nn.Identity()
Expand Down

0 comments on commit 3061c18

Please sign in to comment.