Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple dtype argument addition #680

Merged
merged 4 commits into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions equinox/nn/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field, Module


Expand All @@ -23,6 +24,7 @@ def __init__(
in_features: Union[int, Literal["scalar"]],
out_features: Union[int, Literal["scalar"]],
use_bias: bool = True,
dtype=None,
Copy link
Contributor

@Artur-Galstyan Artur-Galstyan Mar 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of dtype is missing here. I don't have my laptop for a couple of days but I think the type of dtype is jax.numpy.dtype. And it should also be something like "dtype: Optional[jax.numpy.dtype] = None"

*,
key: PRNGKeyArray,
):
Expand All @@ -33,6 +35,8 @@ def __init__(
- `out_features`: The output size. The output from the layer will be a vector
of shape `(out_features,)`.
- `use_bias`: Whether to add on a bias as well.
- `dtype`: The dtype to use. Defaults to either `jax.numpy.float32` or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also tell the user that it applies to all trainable parameters (not just the weights but also the bias) IMO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated this

`jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)

Expand All @@ -46,11 +50,17 @@ def __init__(
in_features_ = 1 if in_features == "scalar" else in_features
out_features_ = 1 if out_features == "scalar" else out_features
lim = 1 / math.sqrt(in_features_)

if dtype is None:
dtype = default_floating_dtype()

self.weight = jrandom.uniform(
wkey, (out_features_, in_features_), minval=-lim, maxval=lim
wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype
)
if use_bias:
self.bias = jrandom.uniform(bkey, (out_features_,), minval=-lim, maxval=lim)
self.bias = jrandom.uniform(
bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype
)
else:
self.bias = None

Expand Down
4 changes: 4 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_linear(getkey):
x = jrandom.normal(getkey(), (2,))
assert linear(x).shape == ()

linear = eqx.nn.Linear(2, "scalar", key=getkey(), dtype=jnp.float16)
x = jrandom.normal(getkey(), (2,), dtype=jnp.float16)
assert linear(x).dtype == jnp.float16


def test_identity(getkey):
identity1 = eqx.nn.Identity()
Expand Down