-
-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support complex dtypes in networks #765
Conversation
It seems that the Now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks excellent. Minor comments aside I'd be very happy to merge this as-is.
Thank you in particular for fixing the sharding test! We were discussing this over in #755. I think I like your fix!
equinox/_misc.py
Outdated
) -> jax.Array: | ||
if jnp.issubdtype(dtype, jnp.complexfloating): | ||
# only two possible complex dtypes, jnp.complex64 or jnp.complex128 | ||
real_dtype = jnp.float32 if dtype == jnp.complex64 else jnp.float64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a neater trick is this: https://github.com/patrick-kidger/lineax/blob/1909d190c1963d5f2d991508c1b2714f2266048b/lineax/_misc.py#L92-L93
equinox/_misc.py
Outdated
@@ -18,3 +21,17 @@ def default_floating_dtype(): | |||
return jnp.float64 | |||
else: | |||
return jnp.float32 | |||
|
|||
|
|||
def default_init( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, this should go in equinox.nn._misc
!
The revision is done. It looks good to merge now. |
Yup, this all looks great to me! Thank you for the excellent contribution, I'm really happy to have this in :) |
* add support for complex networks * add tests for complex networks * support for older version of jax * replace PositionalSharding by NamedSharding * move default_init to nn._misc --------- Co-authored-by: ChenAo-Phys <[email protected]>
As discussed in #763, this PR adds support of complex dtypes for
Linear
,Conv
,GRUCell
, andLSTMCell
.A default initializer is defined in
eqx._misc
and called in initializations. Is it better to define it ineqx.nn._misc
?Some tests are added in
tests/test_nn.py
for complex dtypes and passed.A minor problem that makes the test fail in my
jax==0.4.25
is also fixed intests/test_debug.py
.