Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eqx.filter_shard; test + update examples/parallelism.ipynb #691

Merged
merged 4 commits into from
Apr 7, 2024

Conversation

homerjed
Copy link
Contributor

Hi Patrick,

Please pardon the terrible git practice if I've made some mistakes. I tend to work alone in my daily projects :)

I've

  • added equinox/_sharding.py with eqx.filter_shard as per the issue thread from the other day,
  • added a test tests/test_sharding.py that tests the function on sharding an eqx.nn.MLP in and outside of eqx.filter_jit,
  • changed the examples/parallelism.ipynb example to have the same content as the same file in your PR you weren't sure about, with the corresponding changes made.

Let me know what you think, again, sorry if the git stuff is wrong!

Cheers,
Jed

@homerjed homerjed changed the title eqx.filter_shard; test + example #688 eqx.filter_shard; test + example Mar 25, 2024
@homerjed homerjed changed the title eqx.filter_shard; test + example eqx.filter_shard; test + update examples/parallelism.ipynb Mar 25, 2024
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nits aside this LGTM! Thank you for putting this together, it looks awesome :)

Comment on lines 19 to 27
**Arguments:**
- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves.
They will have their shardings constrained.
- `device_or_shardings`: Either a singular device (e.g. CPU or GPU) or PyTree of
sharding specifications. The structure should be a prefix of `x`.
**Returns:**
A copy of `x` with the specified sharding constraints.
!!! Example
See also the [autoparallelism example](../../../examples/parallelism).
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this will need many newlines in between to format correctly in the docs:

**Arguments:**

- `x`: ...
- `device_or_shardings`: ...

**Returns:**

A copy ...

!!! Example

    See also ...

(check with mkdocs serve)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this, I've shortened the docstring a bit and formatted it correctly, looks less horrific on mkdocs serve.


@eqx.filter_jit
def f(x):
return eqx.filter_shard(x, sharding)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that no-op computations are special-cased by XLA, so this might not actually test anything. I think this should just do +1 or something simple.

Copy link
Contributor Author

@homerjed homerjed Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed the test here to actually do something, basically just filters the params out and adds one to them, then filter_shards them in the JIT'd function.

@patrick-kidger
Copy link
Owner

Alright, this LGTM! That is, modulo David's concern back in #688. Let's double check with things over there before we merge this.

@patrick-kidger patrick-kidger merged commit be7e36a into patrick-kidger:dev Apr 7, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

Okay, looks like the discussion has settled, so merging! Thank you for putting this together, and I'm excited to have this in the next release of Equinox! :)

@homerjed
Copy link
Contributor Author

homerjed commented Apr 7, 2024

Cheers Patrick! Hope to contribute again :)

@patrick-kidger patrick-kidger mentioned this pull request Apr 14, 2024
patrick-kidger pushed a commit that referenced this pull request Apr 14, 2024
* eqx.filter_shard; test + example

* fixed line lengths

* fixed?

* double checking..
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants