Support complex dtypes in networks #765

Merged: 5 commits, Jun 24, 2024
Changes from 4 commits
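Summary of the change, as read from the diff: weight and bias initialisation is factored out into a new `default_init` helper in `equinox/_misc.py`, which draws uniform samples in `[-lim, lim)` and, for complex dtypes, draws the real and imaginary parts independently. `Conv`/`ConvTranspose`, `Linear`, `GRUCell`, and `LSTMCell` are updated to use it, with new tests for complex `Linear` and `Conv2d`, plus small test fixes for JAX-version differences.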
19 changes: 18 additions & 1 deletion equinox/_misc.py
@@ -1,7 +1,10 @@
from typing import Any

import jax
import jax.core
import jax.numpy as jnp
from jaxtyping import Array
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray


def left_broadcast_to(arr: Array, shape: tuple[int, ...]) -> Array:
@@ -18,3 +21,17 @@ def default_floating_dtype():
        return jnp.float64
    else:
        return jnp.float32


Owner (review comment on `default_init`): Agreed, this should go in equinox.nn._misc!

def default_init(
    key: PRNGKeyArray, shape: tuple[int, ...], dtype: Any, lim: float
) -> jax.Array:
    if jnp.issubdtype(dtype, jnp.complexfloating):
        # only two possible complex dtypes, jnp.complex64 or jnp.complex128
        real_dtype = jnp.float32 if dtype == jnp.complex64 else jnp.float64
        rkey, ikey = jrandom.split(key, 2)
        real = jrandom.uniform(rkey, shape, real_dtype, minval=-lim, maxval=lim)
        imag = jrandom.uniform(ikey, shape, real_dtype, minval=-lim, maxval=lim)
        return real.astype(dtype) + 1j * imag.astype(dtype)
    else:
        return jrandom.uniform(key, shape, dtype, minval=-lim, maxval=lim)
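For illustration (not part of the diff), a minimal sketch of how the new helper behaves, assuming it is imported from its current location in `equinox._misc`:

import jax.numpy as jnp
import jax.random as jrandom

from equinox._misc import default_init  # internal helper; may move to equinox.nn._misc per the review comment above

key = jrandom.PRNGKey(0)
# Real dtype: a single uniform draw in [-lim, lim).
w = default_init(key, (3, 4), jnp.float32, 0.5)
# Complex dtype: real and imaginary parts are drawn independently in [-lim, lim).
z = default_init(key, (3, 4), jnp.complex64, 0.5)
assert w.dtype == jnp.float32 and z.dtype == jnp.complex64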
52 changes: 13 additions & 39 deletions equinox/nn/_conv.py
@@ -10,7 +10,7 @@
import numpy as np
from jaxtyping import Array, PRNGKeyArray

from .._misc import default_floating_dtype
from .._misc import default_floating_dtype, default_init
from .._module import field, Module
from ._misc import all_sequences

@@ -153,8 +153,6 @@ def __init__(
        case the padding is an odd number, then the extra padding is added at
        the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.
        """
        dtype = default_floating_dtype() if dtype is None else dtype
        wkey, bkey = jrandom.split(key, 2)

        parse = _ntuple(num_spatial_dims)
        kernel_size = parse(kernel_size)
@@ -167,25 +165,14 @@
f"by `groups` (={groups})."
)

        dtype = default_floating_dtype() if dtype is None else dtype
        wkey, bkey = jrandom.split(key, 2)
        grouped_in_channels = in_channels // groups
        lim = 1 / np.sqrt(grouped_in_channels * math.prod(kernel_size))
        self.weight = jrandom.uniform(
            wkey,
            (out_channels, grouped_in_channels) + kernel_size,
            minval=-lim,
            maxval=lim,
            dtype=dtype,
        )
        if use_bias:
            self.bias = jrandom.uniform(
                bkey,
                (out_channels,) + (1,) * num_spatial_dims,
                minval=-lim,
                maxval=lim,
                dtype=dtype,
            )
        else:
            self.bias = None
        lim = 1 / math.sqrt(grouped_in_channels * math.prod(kernel_size))
        wshape = (out_channels, grouped_in_channels) + kernel_size
        self.weight = default_init(wkey, wshape, dtype, lim)
        bshape = (out_channels,) + (1,) * num_spatial_dims
        self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
@@ -520,24 +507,11 @@ def __init__(
raise ValueError("Must have `output_padding < stride` (elementwise).")

        grouped_in_channels = in_channels // groups
        lim = 1 / np.sqrt(grouped_in_channels * math.prod(kernel_size))
        self.weight = jrandom.uniform(
            wkey,
            (out_channels, grouped_in_channels) + kernel_size,
            minval=-lim,
            maxval=lim,
            dtype=dtype,
        )
        if use_bias:
            self.bias = jrandom.uniform(
                bkey,
                (out_channels,) + (1,) * num_spatial_dims,
                minval=-lim,
                maxval=lim,
                dtype=dtype,
            )
        else:
            self.bias = None
        lim = 1 / math.sqrt(grouped_in_channels * math.prod(kernel_size))
        wshape = (out_channels, grouped_in_channels) + kernel_size
        self.weight = default_init(wkey, wshape, dtype, lim)
        bshape = (out_channels,) + (1,) * num_spatial_dims
        self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

        padding = _padding_init(padding, num_spatial_dims)
        padding_mode = _padding_mode_init(padding_mode)
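As a usage sketch of the updated layer (mirroring the tests added below): passing a complex `dtype` now initialises complex weights and biases.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

# Complex-dtype convolution; weight and bias are initialised via default_init.
conv = eqx.nn.Conv2d(1, 1, kernel_size=3, padding=1, dtype=jnp.complex64, key=jrandom.PRNGKey(0))
x = jnp.ones((1, 8, 8), dtype=jnp.complex64)
assert conv.weight.dtype == jnp.complex64
assert conv(x).dtype == jnp.complex64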
20 changes: 6 additions & 14 deletions equinox/nn/_linear.py
@@ -6,7 +6,7 @@
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray

from .._misc import default_floating_dtype
from .._misc import default_floating_dtype, default_init
from .._module import field, Module


@@ -47,23 +47,15 @@ def __init__(
        Likewise `out_features` can also be a string `"scalar"`, in which case the
        output from the layer will have shape `()`.
        """
        dtype = default_floating_dtype() if dtype is None else dtype
        wkey, bkey = jrandom.split(key, 2)
        in_features_ = 1 if in_features == "scalar" else in_features
        out_features_ = 1 if out_features == "scalar" else out_features
        lim = 1 / math.sqrt(in_features_)

        if dtype is None:
            dtype = default_floating_dtype()

        self.weight = jrandom.uniform(
            wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype
        )
        if use_bias:
            self.bias = jrandom.uniform(
                bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype
            )
        else:
            self.bias = None
        wshape = (out_features_, in_features_)
        self.weight = default_init(wkey, wshape, dtype, lim)
        bshape = (out_features_,)
        self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

        self.in_features = in_features
        self.out_features = out_features
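A corresponding sketch for the updated `Linear` (the bound is `1/sqrt(in_features)`, as in the code above):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

linear = eqx.nn.Linear(4, 4, dtype=jnp.complex64, key=jrandom.PRNGKey(0))
# Real and imaginary parts of each weight lie in [-1/sqrt(4), 1/sqrt(4)).
assert linear.weight.dtype == jnp.complex64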
38 changes: 13 additions & 25 deletions equinox/nn/_rnn.py
@@ -7,7 +7,7 @@
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray

from .._misc import default_floating_dtype
from .._misc import default_floating_dtype, default_init
from .._module import field, Module


@@ -66,19 +66,13 @@ def __init__(
        ihkey, hhkey, bkey, bkey2 = jrandom.split(key, 4)
        lim = math.sqrt(1 / hidden_size)

        self.weight_ih = jrandom.uniform(
            ihkey, (3 * hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
        )
        self.weight_hh = jrandom.uniform(
            hhkey, (3 * hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
        )
        ihshape = (3 * hidden_size, input_size)
        self.weight_ih = default_init(ihkey, ihshape, dtype, lim)
        hhshape = (3 * hidden_size, hidden_size)
        self.weight_hh = default_init(hhkey, hhshape, dtype, lim)
        if use_bias:
            self.bias = jrandom.uniform(
                bkey, (3 * hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
            self.bias_n = jrandom.uniform(
                bkey2, (hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
            self.bias = default_init(bkey, (3 * hidden_size,), dtype, lim)
            self.bias_n = default_init(bkey2, (hidden_size,), dtype, lim)
        else:
            self.bias = None
            self.bias_n = None
@@ -172,18 +166,12 @@ def __init__(
        ihkey, hhkey, bkey = jrandom.split(key, 3)
        lim = math.sqrt(1 / hidden_size)

        self.weight_ih = jrandom.uniform(
            ihkey, (4 * hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
        )
        self.weight_hh = jrandom.uniform(
            hhkey, (4 * hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
        )
        if use_bias:
            self.bias = jrandom.uniform(
                bkey, (4 * hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
        else:
            self.bias = None
        ihshape = (4 * hidden_size, input_size)
        self.weight_ih = default_init(ihkey, ihshape, dtype, lim)
        hhshape = (4 * hidden_size, hidden_size)
        self.weight_hh = default_init(hhkey, hhshape, dtype, lim)
        bshape = (4 * hidden_size,)
        self.bias = default_init(bkey, bshape, dtype, lim) if use_bias else None

        self.input_size = input_size
        self.hidden_size = hidden_size
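Likewise for the RNN cells, a sketch under the constructor signatures shown above (the `dtype` argument is assumed to be a `GRUCell` parameter, as implied by its use in `__init__`):

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

cell = eqx.nn.GRUCell(input_size=2, hidden_size=3, dtype=jnp.complex64, key=jrandom.PRNGKey(0))
assert cell.weight_ih.dtype == jnp.complex64
assert cell.weight_ih.shape == (3 * 3, 2)  # (3 * hidden_size, input_size)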
7 changes: 3 additions & 4 deletions tests/test_debug.py
@@ -30,10 +30,9 @@ def f(x, terminate):
    f(jnp.array(1.0), terminate=False)
    jax.effects_barrier()
    text, _ = capfd.readouterr()
    assert (
        text
        == "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n"  # noqa: E501
    )
    # Depending on the JAX version, the printed values may use the NumPy repr
    # (`array(...)`) rather than the JAX repr (`Array(...)`), so accept either.
    out_text1 = "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n"  # noqa: E501
    out_text2 = "foo:\n primals=array(1., dtype=float32)\ncotangents=array(nan, dtype=float32)\n"  # noqa: E501
    assert text in (out_text1, out_text2)

    with pytest.raises(Exception):
        f(jnp.array(1.0), terminate=True)
16 changes: 16 additions & 0 deletions tests/test_nn.py
@@ -48,6 +48,10 @@ def test_linear(getkey):
    x = jrandom.normal(getkey(), (2,), dtype=jnp.float16)
    assert linear(x).dtype == jnp.float16

    linear = eqx.nn.Linear(2, "scalar", key=getkey(), dtype=jnp.complex64)
    x = jrandom.normal(getkey(), (2,), dtype=jnp.complex64)
    assert linear(x).dtype == jnp.complex64


def test_identity(getkey):
    identity1 = eqx.nn.Identity()
@@ -344,6 +348,18 @@ def test_conv2d(getkey):
    answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3)
    assert jnp.allclose(conv(data), answer)

    # Test complex value matches
    conv = eqx.nn.Conv2d(1, 1, 3, padding=1, dtype=jnp.complex64, key=getkey())
    new_weight = jnp.arange(9, dtype=jnp.complex64).reshape(1, 1, 3, 3)
    new_bias = jnp.array([1 + 1j], dtype=jnp.complex64).reshape(1, 1, 1)
    data = (1 + 1j) * jnp.arange(-4, 5, dtype=jnp.complex64).reshape(1, 3, 3)
    assert new_weight.shape == conv.weight.shape
    assert new_bias.shape == conv.bias.shape  # pyright: ignore
    conv = eqx.tree_at(lambda x: (x.weight, x.bias), conv, (new_weight, new_bias))
    answer = jnp.array([-37, -31, -9, 25, 61, 49, 23, 41, 27]).reshape(1, 3, 3)
    answer = (1 + 1j) * answer.astype(jnp.complex64)
    assert jnp.allclose(conv(data), answer)

    # Test groups
    conv = eqx.nn.Conv2d(2, 2, kernel_size=3, padding=1, key=getkey(), groups=2)
    # we will duplicate the weights from the "value matches" case
3 changes: 2 additions & 1 deletion tests/test_sharding.py
@@ -2,10 +2,11 @@
import jax
import jax.random as jr
import jax.tree_util as jtu
from jax.sharding import Mesh, PartitionSpec


[cpu] = jax.local_devices(backend="cpu")
sharding = jax.sharding.PositionalSharding([cpu])
# `PositionalSharding` is deprecated in newer JAX releases; `NamedSharding` is the
# supported replacement.
sharding = jax.sharding.NamedSharding(Mesh([cpu], "x"), PartitionSpec("x"))


def test_sharding():