Support complex dtypes in networks #765

Merged (5 commits, Jun 24, 2024)
52 changes: 13 additions & 39 deletions equinox/nn/_conv.py
@@ -12,7 +12,7 @@

from .._misc import default_floating_dtype
from .._module import field, Module
from ._misc import all_sequences
from ._misc import all_sequences, default_init


_T = TypeVar("_T")
@@ -153,8 +153,6 @@ def __init__(
case the padding is an odd number, then the extra padding is added at
the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.
"""
dtype = default_floating_dtype() if dtype is None else dtype
wkey, bkey = jrandom.split(key, 2)

parse = _ntuple(num_spatial_dims)
kernel_size = parse(kernel_size)
@@ -167,25 +165,14 @@
f"by `groups` (={groups})."
)

dtype = default_floating_dtype() if dtype is None else dtype
wkey, bkey = jrandom.split(key, 2)
grouped_in_channels = in_channels // groups
lim = 1 / np.sqrt(grouped_in_channels * math.prod(kernel_size))
self.weight = jrandom.uniform(
wkey,
(out_channels, grouped_in_channels) + kernel_size,
minval=-lim,
maxval=lim,
dtype=dtype,
)
if use_bias:
self.bias = jrandom.uniform(
bkey,
(out_channels,) + (1,) * num_spatial_dims,
minval=-lim,
maxval=lim,
dtype=dtype,
)
else:
self.bias = None
lim = 1 / math.sqrt(grouped_in_channels * math.prod(kernel_size))
wshape = (out_channels, grouped_in_channels) + kernel_size
self.weight = default_init(wkey, wshape, dtype, lim)
bshape = (out_channels,) + (1,) * num_spatial_dims
self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
@@ -520,24 +507,11 @@ def __init__(
raise ValueError("Must have `output_padding < stride` (elementwise).")

grouped_in_channels = in_channels // groups
lim = 1 / np.sqrt(grouped_in_channels * math.prod(kernel_size))
self.weight = jrandom.uniform(
wkey,
(out_channels, grouped_in_channels) + kernel_size,
minval=-lim,
maxval=lim,
dtype=dtype,
)
if use_bias:
self.bias = jrandom.uniform(
bkey,
(out_channels,) + (1,) * num_spatial_dims,
minval=-lim,
maxval=lim,
dtype=dtype,
)
else:
self.bias = None
lim = 1 / math.sqrt(grouped_in_channels * math.prod(kernel_size))
wshape = (out_channels, grouped_in_channels) + kernel_size
self.weight = default_init(wkey, wshape, dtype, lim)
bshape = (out_channels,) + (1,) * num_spatial_dims
self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

padding = _padding_init(padding, num_spatial_dims)
padding_mode = _padding_mode_init(padding_mode)
19 changes: 6 additions & 13 deletions equinox/nn/_linear.py
@@ -8,6 +8,7 @@

from .._misc import default_floating_dtype
from .._module import field, Module
from ._misc import default_init


class Linear(Module, strict=True):
@@ -47,23 +48,15 @@ def __init__(
Likewise `out_features` can also be a string `"scalar"`, in which case the
output from the layer will have shape `()`.
"""
dtype = default_floating_dtype() if dtype is None else dtype
wkey, bkey = jrandom.split(key, 2)
in_features_ = 1 if in_features == "scalar" else in_features
out_features_ = 1 if out_features == "scalar" else out_features
lim = 1 / math.sqrt(in_features_)

if dtype is None:
dtype = default_floating_dtype()

self.weight = jrandom.uniform(
wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype
)
if use_bias:
self.bias = jrandom.uniform(
bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype
)
else:
self.bias = None
wshape = (out_features_, in_features_)
self.weight = default_init(wkey, wshape, dtype, lim)
bshape = (out_features_,)
self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

self.in_features = in_features
self.out_features = out_features
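For context, a minimal usage sketch of what this `Linear` change enables. This is not part of the diff; it assumes the PR is merged and mirrors the new test added in `tests/test_nn.py` below:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)
# Weight and bias are initialised as complex64 via the new default_init helper.
linear = eqx.nn.Linear(4, 2, dtype=jnp.complex64, key=key)

x = jnp.ones(4, dtype=jnp.complex64)
y = linear(x)
assert y.dtype == jnp.complex64
```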
18 changes: 18 additions & 0 deletions equinox/nn/_misc.py
@@ -2,6 +2,11 @@
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING, TypeVar, Union

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import PRNGKeyArray


_T = TypeVar("_T", bound=Sequence)

Expand All @@ -17,3 +22,16 @@ def all_sequences(
# beartype doesn't like StrictTypeGuard
def all_sequences(x: Union[Sequence[Any], Sequence[_T]]) -> bool:
return all(isinstance(xi, Sequence) for xi in x)


def default_init(
key: PRNGKeyArray, shape: tuple[int, ...], dtype: Any, lim: float
) -> jax.Array:
if jnp.issubdtype(dtype, jnp.complexfloating):
real_dtype = jnp.finfo(dtype).dtype
rkey, ikey = jrandom.split(key, 2)
real = jrandom.uniform(rkey, shape, real_dtype, minval=-lim, maxval=lim)
imag = jrandom.uniform(ikey, shape, real_dtype, minval=-lim, maxval=lim)
return real.astype(dtype) + 1j * imag.astype(dtype)
else:
return jrandom.uniform(key, shape, dtype, minval=-lim, maxval=lim)
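A short sketch of how the new helper behaves. The import path `equinox.nn._misc` is an internal module (taken from this diff), so this is illustration only: for real dtypes `default_init` is a plain uniform draw in `[-lim, lim]`; for complex dtypes the real and imaginary parts are drawn independently from that range.

```python
import jax.numpy as jnp
import jax.random as jrandom

from equinox.nn._misc import default_init  # internal helper; path per this diff

key = jrandom.PRNGKey(0)

w_real = default_init(key, (3, 2), jnp.float32, 0.5)    # uniform in [-0.5, 0.5]
w_cplx = default_init(key, (3, 2), jnp.complex64, 0.5)  # real/imag parts each uniform in [-0.5, 0.5]

assert w_real.dtype == jnp.float32
assert w_cplx.dtype == jnp.complex64
```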
37 changes: 13 additions & 24 deletions equinox/nn/_rnn.py
@@ -9,6 +9,7 @@

from .._misc import default_floating_dtype
from .._module import field, Module
from ._misc import default_init


class GRUCell(Module, strict=True):
@@ -66,19 +67,13 @@ def __init__(
ihkey, hhkey, bkey, bkey2 = jrandom.split(key, 4)
lim = math.sqrt(1 / hidden_size)

self.weight_ih = jrandom.uniform(
ihkey, (3 * hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
)
self.weight_hh = jrandom.uniform(
hhkey, (3 * hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
)
ihshape = (3 * hidden_size, input_size)
self.weight_ih = default_init(ihkey, ihshape, dtype, lim)
hhshape = (3 * hidden_size, hidden_size)
self.weight_hh = default_init(hhkey, hhshape, dtype, lim)
if use_bias:
self.bias = jrandom.uniform(
bkey, (3 * hidden_size,), minval=-lim, maxval=lim, dtype=dtype
)
self.bias_n = jrandom.uniform(
bkey2, (hidden_size,), minval=-lim, maxval=lim, dtype=dtype
)
self.bias = default_init(bkey, (3 * hidden_size,), dtype, lim)
self.bias_n = default_init(bkey2, (hidden_size,), dtype, lim)
else:
self.bias = None
self.bias_n = None
@@ -172,18 +167,12 @@ def __init__(
ihkey, hhkey, bkey = jrandom.split(key, 3)
lim = math.sqrt(1 / hidden_size)

self.weight_ih = jrandom.uniform(
ihkey, (4 * hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
)
self.weight_hh = jrandom.uniform(
hhkey, (4 * hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
)
if use_bias:
self.bias = jrandom.uniform(
bkey, (4 * hidden_size,), minval=-lim, maxval=lim, dtype=dtype
)
else:
self.bias = None
ihshape = (4 * hidden_size, input_size)
self.weight_ih = default_init(ihkey, ihshape, dtype, lim)
hhshape = (4 * hidden_size, hidden_size)
self.weight_hh = default_init(hhkey, hhshape, dtype, lim)
bshape = (4 * hidden_size,)
self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

self.input_size = input_size
self.hidden_size = hidden_size
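Likewise for the recurrent cells, a construction-only sketch (not part of the diff): the parameters come out complex, though a complex forward pass is not exercised here, since the diff does not establish that every activation inside the cell is defined for complex inputs.

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

cell = eqx.nn.GRUCell(2, 8, dtype=jnp.complex64, key=jrandom.PRNGKey(0))
assert cell.weight_ih.dtype == jnp.complex64
assert cell.weight_hh.dtype == jnp.complex64
```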
7 changes: 3 additions & 4 deletions tests/test_debug.py
@@ -30,10 +30,9 @@ def f(x, terminate):
f(jnp.array(1.0), terminate=False)
jax.effects_barrier()
text, _ = capfd.readouterr()
assert (
text
== "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n" # noqa: E501
)
out_text1 = "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n" # noqa: E501
out_text2 = "foo:\n primals=array(1., dtype=float32)\ncotangents=array(nan, dtype=float32)\n" # noqa: E501
assert text in (out_text1, out_text2)

with pytest.raises(Exception):
f(jnp.array(1.0), terminate=True)
16 changes: 16 additions & 0 deletions tests/test_nn.py
@@ -48,6 +48,10 @@ def test_linear(getkey):
x = jrandom.normal(getkey(), (2,), dtype=jnp.float16)
assert linear(x).dtype == jnp.float16

linear = eqx.nn.Linear(2, "scalar", key=getkey(), dtype=jnp.complex64)
x = jrandom.normal(getkey(), (2,), dtype=jnp.complex64)
assert linear(x).dtype == jnp.complex64


def test_identity(getkey):
identity1 = eqx.nn.Identity()
@@ -344,6 +348,18 @@ def test_conv2d(getkey):
answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3)
assert jnp.allclose(conv(data), answer)

# Test complex value matches
conv = eqx.nn.Conv2d(1, 1, 3, padding=1, dtype=jnp.complex64, key=getkey())
new_weight = jnp.arange(9, dtype=jnp.complex64).reshape(1, 1, 3, 3)
new_bias = jnp.array([1 + 1j], dtype=jnp.complex64).reshape(1, 1, 1)
data = (1 + 1j) * jnp.arange(-4, 5, dtype=jnp.complex64).reshape(1, 3, 3)
assert new_weight.shape == conv.weight.shape
assert new_bias.shape == conv.bias.shape # pyright: ignore
conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias))
answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3)
answer = (1 + 1j) * answer.astype(jnp.complex64)
assert jnp.allclose(conv(data), answer)

# Test groups
conv = eqx.nn.Conv2d(2, 2, kernel_size=3, padding=1, key=getkey(), groups=2)
# we will duplicate the weights from the "value matches" case
3 changes: 2 additions & 1 deletion tests/test_sharding.py
@@ -2,10 +2,11 @@
import jax
import jax.random as jr
import jax.tree_util as jtu
from jax.sharding import Mesh, PartitionSpec


[cpu] = jax.local_devices(backend="cpu")
sharding = jax.sharding.PositionalSharding([cpu])
sharding = jax.sharding.NamedSharding(Mesh([cpu], "x"), PartitionSpec("x"))


def test_sharding():
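A rough sketch of the sharding swap in `tests/test_sharding.py`. The diff does not state the reason, but it is likely that `PositionalSharding` is being phased out in newer JAX releases; an equivalent single-device sharding can be built from a one-axis `Mesh` and a `PartitionSpec`, as the updated test does:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

[cpu] = jax.local_devices(backend="cpu")

# One CPU device arranged on a single mesh axis named "x".
mesh = Mesh([cpu], "x")
sharding = NamedSharding(mesh, PartitionSpec("x"))

# With only one device on the axis, the array stays effectively unsplit.
x = jax.device_put(jnp.arange(4.0), sharding)
print(x.sharding)  # shows the NamedSharding attached to the array
```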