-
-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
eqx.filter_shard
; test + update examples/parallelism.ipynb
#691
Conversation
eqx.filter_shard
; test + example
eqx.filter_shard
; test + exampleeqx.filter_shard
; test + update examples/parallelism.ipynb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nits aside this LGTM! Thank you for putting this together, it looks awesome :)
equinox/_sharding.py
Outdated
**Arguments:** | ||
- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. | ||
They will have their shardings constrained. | ||
- `device_or_shardings`: Either a singular device (e.g. CPU or GPU) or PyTree of | ||
sharding specifications. The structure should be a prefix of `x`. | ||
**Returns:** | ||
A copy of `x` with the specified sharding constraints. | ||
!!! Example | ||
See also the [autoparallelism example](../../../examples/parallelism). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: this will need many newlines in between to format correctly in the docs:
**Arguments:**
- `x`: ...
- `device_or_shardings`: ...
**Returns:**
A copy ...
!!! Example
See also ...
(check with mkdocs serve
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this, I've shortened the docstring a bit and formatted it correctly, looks less horrific on mkdocs serve
.
|
||
@eqx.filter_jit | ||
def f(x): | ||
return eqx.filter_shard(x, sharding) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that no-op computations are special-cased by XLA, so this might not actually test anything. I think this should just do +1
or something simple.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed the test here to actually do something, basically just filters the params out and adds one to them, then filter_shard
s them in the JIT'd function.
Alright, this LGTM! That is, modulo David's concern back in #688. Let's double check with things over there before we merge this. |
Okay, looks like the discussion has settled, so merging! Thank you for putting this together, and I'm excited to have this in the next release of Equinox! :) |
Cheers Patrick! Hope to contribute again :) |
* eqx.filter_shard; test + example * fixed line lengths * fixed? * double checking..
Hi Patrick,
Please pardon the terrible
git
practice if I've made some mistakes. I tend to work alone in my daily projects :)I've
equinox/_sharding.py
witheqx.filter_shard
as per the issue thread from the other day,tests/test_sharding.py
that tests the function on sharding aneqx.nn.MLP
in and outside ofeqx.filter_jit
,examples/parallelism.ipynb
example to have the same content as the same file in your PR you weren't sure about, with the corresponding changes made.Let me know what you think, again, sorry if the
git
stuff is wrong!Cheers,
Jed