Sharding - shard eqx.Module as well as inputs? #688
The reason is just that I should really get around to updating the Equinox docs :) I do actually have an update written for it in this PR, but I wasn't completely satisfied with the new APIs it offers, so this has been languishing a bit. (I'm wondering whether we can unify ...)

If you haven't seen it yet then you may also like Levanter, which is built on top of Equinox to perform training of large-scale models. I believe it has a number of far-more-carefully-thought-out APIs for parallelism!
Ah cool, sounds good :) I had a quick look. I see one should use the ... I guess one way to combine the ...

Putting a bare device into with_sharding_constraint throws an error:

[device] = jax.devices()
# Error: shardings leaf specifications are expected to be PartitionSpec instances or None
lax.with_sharding_constraint(jnp.ones((2,)), device)

You could correct this for any sharding or device someone might try to filter with:

[device] = jax.devices()
# If someone wants to put the arrays (that would be filtered from a PyTree) onto a device,
# convert it to a sharding first
device = jax.sharding.SingleDeviceSharding(device)
# No error
lax.with_sharding_constraint(jnp.ones((2,)), device)

which negates the use of ... Here is a small example:

from typing import Any
import jax
import jaxlib
from jax import lax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as sharding
import equinox as eqx
from jaxtyping import PyTree

data_dim = 2
hidden_size = 8
depth = 1

key = jr.PRNGKey(0)
key_u, key_model = jr.split(key)

x = jr.uniform(key_u, (data_dim,)) * 2. * jnp.pi
y = jnp.sin(x)

model = eqx.nn.MLP(
    data_dim, data_dim, hidden_size, depth, key=key
)

def filter_with_sharding_constraint(
    x: PyTree[Any],
    device_or_shardings: jaxlib.xla_extension.Device | jax.sharding.Sharding
):
    if isinstance(device_or_shardings, jaxlib.xla_extension.Device):
        shardings = jax.sharding.SingleDeviceSharding(device_or_shardings)
    else:
        shardings = device_or_shardings
    dynamic, static = eqx.partition(x, eqx.is_array)
    dynamic = lax.with_sharding_constraint(dynamic, shardings)
    return eqx.combine(dynamic, static)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

model = filter_with_sharding_constraint(model, sharding)

@eqx.filter_jit(donate="all-except-first")
def evaluate(model, x, y, sharding):
    replicated = sharding.replicate()
    model = filter_with_sharding_constraint(model, replicated)
    x, y = filter_with_sharding_constraint((x, y), sharding)
    return jnp.mean((y - model(x)) ** 2.)

evaluate(model, x, y, sharding)

# Separately...
model = eqx.nn.MLP(
    data_dim, data_dim, hidden_size, depth, key=key
)
[device] = jax.devices()  # Single device case
filter_with_sharding_constraint(model, device)

In this example, replicating the sharding would throw an error if it is a device. The user should know that, which seems fair since they have entered the domain of sharding etc. at this point (e.g. they would use ...). The function may need a different name, something a bit less clunky than filter_with_sharding_constraint.
Oh, that's exceptionally elegant. I really like that!

So I've just checked the equivalent Levanter (Haliax) function, which is implemented here: ... (Note that the ...) Notably they do seem to try to do something more complicated, although I'm not sure ... Do we miss anything by using with_sharding_constraint?

As for the name, agreed on something less clunky! Maybe just filter_shard?
Great - it felt almost too simple - which it may be, barring anything I might be missing with using with_sharding_constraint.

I think I see what Levanter are doing - it seems to be a kind of filtering, except for named arrays. They check if ... So I think the ... They're also assessing whether the equivalent ... I've looked into the ...

Yes, filter_shard:

def filter_shard(
    x: PyTree[Any],
    device_or_shard: jaxlib.xla_extension.Device | jax.sharding.Sharding
):
    if isinstance(device_or_shard, jaxlib.xla_extension.Device):
        shardings = jax.sharding.SingleDeviceSharding(device_or_shard)
    else:
        shardings = device_or_shard
    dynamic, static = eqx.partition(x, eqx.is_array)
    dynamic = lax.with_sharding_constraint(dynamic, shardings)
    return eqx.combine(dynamic, static)
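For what it's worth, a minimal usage sketch of the filter_shard defined just above (my own illustration, not from the Equinox docs; it assumes a single device, and the loss function and shapes are made up):

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

model = eqx.nn.MLP(2, 2, 8, 1, key=jr.PRNGKey(0))
[device] = jax.devices()

# Outside of JIT: places the model's array leaves on the device
# (modulo the cross-backend caveat discussed further down); non-arrays pass through untouched.
model = filter_shard(model, device)

@eqx.filter_jit
def loss(model, x, y):
    # Inside JIT: acts as a sharding hint for the compiler.
    model = filter_shard(model, device)
    return jnp.mean((model(x) - y) ** 2)

loss(model, jnp.ones((2,)), jnp.ones((2,)))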
Great! As for the rest of it -- if you feel so inclined then I'd be very happy to take a PR (against the dev branch) adding this in!
At some point I'm pretty sure wsc didn't work outside of jit, and I missed the memo on when it started working! Or I could have just totally misunderstood. Looks like I can simplify things on my end! Partitioning is indeed more elegant than what I was doing - I will switch to that. (I'll probably just call out to eqx.filter_shard once it lands!)

Aside from that, the name mapping stuff is Haliax specific, but the rest probably needn't be. All hax.shard does is: ...
I think in the bad old days of pre-jax.Array, wsc didn't work outside of jit, and I never bothered to look at the change - a lot of the logic around shardings was trying to make that function work when everything was still in flux. Or, again, I could have been operating under a misapprehension all this time.

Re addressability: you should test to make sure your function works in multi-process mode, e.g. on a multihost TPU, or multi-GPU (when using one process per GPU), or multihost GPU. It's very easy to get bitten by something that looks right on a single machine but suddenly isn't once you go multi-machine. The is_equivalent_to was to get around array_from_callback when the sharding was a no-op.
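(A tiny aside for anyone following along: the single- vs multi-process distinction above can be seen with a few standard JAX calls. These APIs are real, but the numbers printed depend entirely on your setup, and on a multihost cluster you would call jax.distributed.initialize() before anything else.)

import jax

print("processes:", jax.process_count())             # 1 on a single machine
print("this process:", jax.process_index())
print("addressable devices:", jax.local_device_count())
print("global devices:", jax.device_count())          # exceeds the local count once multi-process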
Just tested it out - using wsc instead of my complexity works just fine. Thanks for finding that!
Forgot to link the commit here... ("add ...")
FYI I forgot one of the reasons why I did it this way: wsc doesn't work if the source array doesn't share a backend with the target sharding. Arguably those should be different functions, but I kind of like them being the same.
Hmmm, bother, that might be annoying. IIUC this is an issue specifically when outside of JIT, since within JIT everything is necessarily on the same backend. Maybe we should just do a device_put in that case?
I think jit is actually going to start supporting at least cross-device device_put (there's work on CPU offloading). I'm gonna bother the jax team and see what I can get.
I still need to test distributed/multihost with this patch, but this fixes cpu->gpu.
I think they don't particularly like the way I'm doing it (they don't like me sniffing out jit vs not jit), but I get the sense that ^^ should work? It seems like they're thinking about merging the two functions jax-ml/jax#19670 (comment) |
So my main concern with performing an am-I-inside-jit check is the runtime overhead of doing it on every call. Whilst I've not tried benchmarking the effect here, as a rule of thumb I try to only perform such checks outside of performance-critical contexts. (If this is indeed a performance concern, then FWIW at least one improvement would be to move the check so it runs once per call rather than once per array.)
Yeah, I think sniffing the trace out is probably the cheapest thing. It's almost enough to check if the array is a tracer or not via isinstance, but I think you maybe still want wsc and not device_put if the arg is a captured/literal array? I don't usually worry about perf outside of jit (or the wrapper code to get to inside jit), since except for like dataloading it's never inner-loop, at least for my use cases... I think really only device_put behaves differently inside jit vs outside. It's just there's no general way to do a cross-backend transfer inside jit right now. (device_put seems to be growing these capabilities for cpu offload, but it's in flux.)
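To make that concrete, here is a rough sketch of the tracer-sniffing idea - this is not Equinox's actual implementation, it ignores the captured/literal-array caveat just mentioned, and filter_shard_sketch / _is_under_trace are made-up names:

from typing import Any

import jax
import jax.tree_util as jtu
from jax import lax
import equinox as eqx
from jaxtyping import PyTree

def _is_under_trace(x: PyTree[Any]) -> bool:
    # Heuristic: if any array leaf is a Tracer, we are under jit/grad/vmap.
    return any(isinstance(leaf, jax.core.Tracer) for leaf in jtu.tree_leaves(x))

def filter_shard_sketch(x: PyTree[Any], shardings):
    dynamic, static = eqx.partition(x, eqx.is_array)
    if _is_under_trace(dynamic):
        # Under a trace: behaves as a compiler hint; stays on the traced backend.
        dynamic = lax.with_sharding_constraint(dynamic, shardings)
    else:
        # Eager: device_put also handles cross-backend moves (e.g. cpu -> gpu).
        dynamic = jax.device_put(dynamic, shardings)
    return eqx.combine(dynamic, static)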
Sure, but device-put'ing/sharding is something that happens as part of dataloading?

for data in dataloader:
    data = device_put_or_shard_etc(data, device)  # Your equivalent of this line calls `is_in_jit()` many times.
    grads = jitted_jax_computation(data)

etc. Or am I missing a trick here?

FWIW at least in terms of #691, I'm thinking that maybe we just shouldn't worry too much about the cross-backend thing, so that the current implementation can go in as-is.
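(One obvious way to keep any such check out of the hot loop, sketched with the same made-up names as the snippet above: resolve the device to a sharding once and call device_put directly, since this loop is statically known to run outside of jit.)

import jax

sharding = jax.sharding.SingleDeviceSharding(device)  # resolved once, outside the loop
for data in dataloader:
    data = jax.device_put(data, sharding)             # no jit-sniffing needed here
    grads = jitted_jax_computation(data)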
Moving a NumPy array onto a (GPU) device works:

import jax
import equinox as eqx
import numpy as np

x = np.ones((64, 2))
[device] = jax.devices()  # device = cuda:0
y = eqx.filter_shard(x, device)
y.devices()  # {cuda(id=0)}

assuming that is what you meant.
Okay, #691 is merged! I'm closing this. Thank you everyone for helping to make this a reality.
Hi all,

In the JAX example for 8-way batch data parallelism sharding, the author shards the parameters of their model as well as their batch of inputs. These parameters, I am guessing, allude to the "init, apply fn" style libraries.

I see that in your example of auto-parallelism you don't do this. Is there a reason for this? I ask because I have had some issues with sharding in JAX before, and it is not obvious how to replicate an Equinox model across the devices (maybe an eqx.filter(...), due to the non-arrays?). Training an equinox.Module without explicitly sharding copies of the model to the devices seems to work - could this just be that the parameters are naturally sharded by their operations with sharded arrays?

Cheers,
Jed
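(For reference, here is a rough sketch of the pattern the discussion above converges on - replicate the parameters, shard the batch - written with the eqx.filter_shard that eventually came out of this issue; the mesh axis name, shapes and MLP sizes are purely illustrative.)

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))
data_sharding = NamedSharding(mesh, PartitionSpec("batch"))  # split along the batch axis
replicated = NamedSharding(mesh, PartitionSpec())            # full copy on every device

model = eqx.nn.MLP(2, 2, 8, 1, key=jr.PRNGKey(0))
model = eqx.filter_shard(model, replicated)                  # replicate the parameters

x = jnp.ones((len(devices) * 8, 2))                          # batch size divisible by device count
x = eqx.filter_shard(x, data_sharding)                       # shard the inputs

@eqx.filter_jit
def loss(model, x):
    return jnp.mean(jax.vmap(model)(x) ** 2)

loss(model, x)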