Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support complex dtypes in networks #765

Merged
merged 5 commits into from
Jun 24, 2024

Conversation

ChenAo-Phys
Copy link
Contributor

As discussed in #763, this PR adds support of complex dtypes for Linear, Conv, GRUCell, and LSTMCell.

A default initializer is defined in eqx._misc and called in initializations. Is it better to define it in eqx.nn._misc?

Some tests are added in tests/test_nn.py for complex dtypes and passed.

A minor problem that makes the test fail in my jax==0.4.25 is also fixed in tests/test_debug.py.

@ChenAo-Phys
Copy link
Contributor Author

It seems that the eqx.filter_shard test case in tests/test_sharding.py is no longer compatible with the newest jax==0.4.30, which is also seen in #342.

Now PositionalSharding requires all arrays in the pytree to have the same number of dimensions. But NamedSharding doesn't need it, so I modify the test and let it pass. The change of PositionalSharding could be a bug in the newest jax?

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks excellent. Minor comments aside I'd be very happy to merge this as-is.

Thank you in particular for fixing the sharding test! We were discussing this over in #755. I think I like your fix!

equinox/_misc.py Outdated
) -> jax.Array:
if jnp.issubdtype(dtype, jnp.complexfloating):
# only two possible complex dtypes, jnp.complex64 or jnp.complex128
real_dtype = jnp.float32 if dtype == jnp.complex64 else jnp.float64
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

equinox/_misc.py Outdated
@@ -18,3 +21,17 @@ def default_floating_dtype():
return jnp.float64
else:
return jnp.float32


def default_init(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, this should go in equinox.nn._misc!

@ChenAo-Phys
Copy link
Contributor Author

The revision is done. It looks good to merge now.

@patrick-kidger patrick-kidger merged commit b962ca7 into patrick-kidger:main Jun 24, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

Yup, this all looks great to me!

Thank you for the excellent contribution, I'm really happy to have this in :)

Artur-Galstyan pushed a commit to Artur-Galstyan/equinox that referenced this pull request Jun 27, 2024
* add support for complex networks

* add tests for complex networks

* support for older version of jax

* replace PositionalSharding by NamedSharding

* move default_init to nn._misc

---------

Co-authored-by: ChenAo-Phys <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants