
Sharding - shard eqx.Module as well as inputs? #688

Closed

homerjed opened this issue Mar 21, 2024 · 19 comments
Labels
documentation (Improvements or additions to documentation)

Comments

@homerjed
Contributor

Hi all,

In the JAX example for 8-way batch data parallelism sharding, the author shards the parameters of their model as well as their batch of inputs.

I am guessing these parameters allude to the "init, apply fn" style of libraries.

I see that in your example of auto-parallelism you don't do this. Is there a reason for this? I ask because I have had some issues with sharding in JAX before, and it is not obvious how to replicate an equinox model across the devices (maybe an eqx.filter(...), due to the non-arrays?).

Training an equinox.Module without explicitly sharding copies of the model to the devices seems to work - could this just be that the parameters are naturally sharded by their operations with sharded arrays?
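
For concreteness, this is roughly what I imagine explicit replication would look like (a minimal sketch, assuming jax.device_put with a replicated sharding, filtering out the non-array leaves first):

# Minimal sketch: explicitly replicate the array leaves of an eqx.Module across
# all local devices; non-array leaves are set aside via eqx.partition.
import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.random as jr

model = eqx.nn.MLP(2, 2, 8, 1, key=jr.PRNGKey(0))

devices = mesh_utils.create_device_mesh((len(jax.devices()), 1))
replicated = jax.sharding.PositionalSharding(devices).replicate()

params, static = eqx.partition(model, eqx.is_array)
params = jax.device_put(params, replicated)
model = eqx.combine(params, static)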

Cheers,
Jed

@patrick-kidger
Owner

The reason is just that I should really get around to updating the Equinox doc :)

I do actually have an update written for it in this PR but I wasn't completely satisfied with the new APIs it offers, so this has been languishing a bit. (I'm wondering whether we can unify device_put and with_sharding_constraint into a single API; they both basically do the same thing, just from outside/inside JIT.)
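
(A minimal sketch of that overlap, assuming a single-device sharding just to keep it backend-agnostic: both calls accept a sharding, one applied from outside JIT and one from inside.)

# Sketch of the overlap: device_put applies a sharding from outside JIT,
# with_sharding_constraint applies the same sharding from inside JIT.
import jax
import jax.numpy as jnp

shardings = jax.sharding.SingleDeviceSharding(jax.devices()[0])

x = jax.device_put(jnp.ones((8,)), shardings)  # outside JIT

@jax.jit
def f(x):
    return jax.lax.with_sharding_constraint(x, shardings)  # inside JIT

f(x)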

If you haven't seen it yet then you may also like Levanter, which is built on top of Equinox to perform training of large-scale models. I believe it has a number of far-more-carefully-thought-out APIs for parallelism!

patrick-kidger added the documentation label on Mar 21, 2024
@homerjed
Contributor Author

homerjed commented Mar 22, 2024

Ah cool, sounds good :)

I had a quick look. I see one should use jax.lax.with_sharding_constraint over device_put inside JIT.

I guess one way to combine the Equinox implementations of device_put and with_sharding_constraint in your PR would be to check whether the input (device or shardings, respectively) is a device or a sharding.

Putting a device into a jax.lax.with_sharding_constraint call throws an error, but passing a sharding into a jax.device_put call doesn't.

import jax
import jax.numpy as jnp
from jax import lax

[device] = jax.devices()
# Error: shardings leaf specifications are expected to be PartitionSpec instances or None
lax.with_sharding_constraint(jnp.ones((2,)), device)

You could correct this for any sharding or device someone might try to filter with:

[device] = jax.devices()
# If someone wants to put the arrays (that would be filtered from a PyTree) onto a device, convert it to a sharding first
device = jax.sharding.SingleDeviceSharding(device)
# No error
lax.with_sharding_constraint(jnp.ones((2,)), device)

which negates the use of jax.device_put, since you can do this second operation outside of JIT anyway.

Here is a small example:

from typing import Any
import jax
import jaxlib
from jax import lax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as sharding
import equinox as eqx
from jaxtyping import PyTree

data_dim = 2
hidden_size = 8
depth = 1

key = jr.PRNGKey(0)
key_u, key_model = jr.split(key)

x = jr.uniform(key_u, (data_dim,)) * 2. * jnp.pi
y = jnp.sin(x)

model = eqx.nn.MLP(
  data_dim, data_dim, hidden_size, depth, key=key_model
)

def filter_with_sharding_constraint(
    x: PyTree[Any],
    device_or_shardings: jaxlib.xla_extension.Device | jax.sharding.Sharding
):
    # Normalise a bare device to a single-device sharding, then constrain only
    # the array leaves; non-array leaves pass through untouched.
    if isinstance(device_or_shardings, jaxlib.xla_extension.Device):
        shardings = jax.sharding.SingleDeviceSharding(device_or_shardings)
    else:
        shardings = device_or_shardings
    dynamic, static = eqx.partition(x, eqx.is_array)
    dynamic = lax.with_sharding_constraint(dynamic, shardings)
    return eqx.combine(dynamic, static)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

# Replicate the model's arrays across devices (outside JIT).
model = filter_with_sharding_constraint(model, replicated)

@eqx.filter_jit(donate="all-except-first")
def evaluate(model, x, y, sharding):
    # Inside JIT: replicate the model across devices and shard the data.
    replicated = sharding.replicate()
    model = filter_with_sharding_constraint(model, replicated)
    x, y = filter_with_sharding_constraint((x, y), sharding)
    return jnp.mean((y - model(x)) ** 2.)

evaluate(model, x, y, sharding)

# Separately... 
model = eqx.nn.MLP(
  data_dim, data_dim, hidden_size, depth, key=key
)
[device] = jax.devices() # Single device case
filter_with_sharding_constraint(model, device)

In this example, calling .replicate() would throw an error if the argument is a device rather than a sharding. That seems fair for the user to have to know, since at that point they have entered the domain of sharding anyway (whereas someone who just wants to put the model onto a specific device would use filter_with_sharding_constraint as defined above, with no replication involved). I think the use cases are basically only "put the arrays in this pytree on this device" or "shard the arrays of this pytree across this sharding pattern", both of which the function covers.

The function may need a different name, something a bit less clunky than filter_with_device_or_sharding?

@patrick-kidger
Owner

Oh, that's exceptionally elegant. I really like that!

So I've just checked the equivalent Levanter (Haliax) function, which is implemented here:

https://github.com/stanford-crfm/haliax/blob/b6ecf159f4912ba89e714c552ee1a476e2248e6e/src/haliax/partitioning.py#L101

(Note that the mapping stuff is a Haliax abstraction that I believe we can ignore.)
The outside-of-JIT part is these lines:

https://github.com/stanford-crfm/haliax/blob/b6ecf159f4912ba89e714c552ee1a476e2248e6e/src/haliax/partitioning.py#L147-L164

Notably they do seem to do something more complicated, although I'm not sure why. Do we miss anything by using with_sharding_constraint when outside of JIT?

As for the name, agreed on something less clunky! Maybe just filter_shard?

@homerjed
Contributor Author

Great - it felt almost too simple - which it may be, barring anything I might be missing with using with_sharding_constraint outside of JIT.

I think I see what Levanter is doing; it seems to be a kind of filtering, except for named arrays. They check if x is a tracer, and the operations after this check are all Haliax-specific, so no worries for us there. They simply return x if it isn't an array, which I guess our dynamic, static = ... deals with, so again, no worries there. Everything else is Haliax-specific.

So I think _do_device_put() is perhaps solely for the Haliax / Levanter way of doing things (named everything). I assume we don't need that functionality (a user would just use Levanter), and I doubt its complexity would be worth it given how simple the above filter_with_device_or_sharding is.

They're also assessing whether the equivalent device_or_sharding (from my function above) is addressable; this may be something to think about. For me, this feels like another threshold that a user knowingly crosses, and for a typical N-device sharding this addressability will always be the trivial case, so I think only advanced users need to worry about it.
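
(If I remember right, JAX exposes this directly via the Sharding.is_fully_addressable property, which is trivially True in a single-process setup; a tiny sketch:)

# Checking addressability; single-process setups are trivially fully addressable.
import jax
from jax.sharding import SingleDeviceSharding

shardings = SingleDeviceSharding(jax.devices()[0])
print(shardings.is_fully_addressable)  # True when running a single process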

I've looked into the jax.lax.with_sharding_constraint source and can't see anything obvious that would trip up a user doing the two cases I mentioned outside of JIT. I could communicate with the jax people about the use of the function outside of JIT, but I've tested it inside and out of JIT and it seems to be nothing too crazy. I've checked out all the jax GitHub issues related to jax.lax.with_sharding_constraint and nothing jumps out that would be relevant to this function either.

Yes, filter_shard sounds better; of course, a device is simply a single-device sharding anyway.

def filter_shard(
    x: PyTree[Any],
    device_or_shardings: jaxlib.xla_extension.Device | jax.sharding.Sharding
):
    if isinstance(device_or_shardings, jaxlib.xla_extension.Device):
        shardings = jax.sharding.SingleDeviceSharding(device_or_shardings)
    else:
        shardings = device_or_shardings
    dynamic, static = eqx.partition(x, eqx.is_array)
    dynamic = lax.with_sharding_constraint(dynamic, shardings)
    return eqx.combine(dynamic, static)
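
For example, reusing the setup from my example above (same imports and variables, nothing new beyond what's already defined there), this covers both cases:

# Usage sketch for the filter_shard draft above.
model = eqx.nn.MLP(data_dim, data_dim, hidden_size, depth, key=key_model)

# Outside JIT: put every array leaf of the model onto a single device.
[device] = jax.devices()  # single-device case
model = filter_shard(model, device)

# Inside JIT: constrain the model to a replicated sharding.
devices = mesh_utils.create_device_mesh((len(jax.devices()), 1))
replicated = jax.sharding.PositionalSharding(devices).replicate()

@eqx.filter_jit
def forward(model, x):
    model = filter_shard(model, replicated)
    return model(x)

forward(model, jnp.ones((data_dim,)))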

@patrick-kidger
Owner

patrick-kidger commented Mar 23, 2024

Great!
I'll tag @dlwh here just in case he has any thoughts on the addressability point.

As for the rest of it -- if you feel so inclined then I'd be very happy to take a PR (against the dev branch) adding this in!

@dlwh
Contributor

dlwh commented Mar 26, 2024

At some point I'm pretty sure wsc didn't work outside of jit and I missed the memo on when it started working! Or I could have just totally misunderstood. Looks like I can simplify things on my end!

Partitioning is indeed more elegant than what I was doing. I will switch to that. (I'll probably just call out to eqx.filter_shard once it lands!)

Aside from that, the name mapping stuff is Haliax-specific, but the rest probably needn't be. All hax.shard does is:

  • compute the logical-to-physical name mapping (similar to what's done in Flax), giving PartitionSpecs/NamedShardings (roughly as sketched below)
  • apply a sharding primitive depending on the context (jit or not) and the kind of arg (I didn't think to just use partition, dumb)
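
(For concreteness, the first bullet amounts to something like the following; this is not Haliax's actual code, and axis_mapping / sharding_for are made-up names for illustration.)

# Hypothetical illustration of a logical-to-physical axis mapping (not Haliax's code).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Made-up mapping: the "batch" logical axis is sharded over the "data" mesh axis,
# the "embed" logical axis is replicated.
axis_mapping = {"batch": "data", "embed": None}

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def sharding_for(*logical_axes):
    spec = PartitionSpec(*(axis_mapping.get(name) for name in logical_axes))
    return NamedSharding(mesh, spec)

# Batch dimension must be divisible by the number of devices along "data".
x = jax.device_put(np.zeros((8, 4)), sharding_for("batch", "embed"))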

I think in the bad-old-days of pre-jax.Array, wsc didn't work outside of jit, and I never bothered to look at the change, and a lot of the logic around shardings was trying to make that function work when everything was still in flux.

Or, again, could have been operating under a misapprehension all this time.

Re addressability: you should test to make sure your function works in multi-process mode, e.g. on a multihost TPU, or multi-GPU (when using one process per GPU), or multihost GPU. It's very easy to get bitten by something that looks right on a single machine but suddenly isn't once you go multi-machine.

The is_equivalent_to check was to get around array_from_callback when the sharding was a no-op.

@dlwh
Contributor

dlwh commented Mar 26, 2024

Just tested it out; using wsc instead of my complexity works just fine. Thanks for finding that!

@homerjed
Contributor Author

Forgot to link the commit here...

add eqx.filter_shard; test + examples/parallelism.ipynb updated #691

@patrick-kidger
Owner

Awesome stuff! Thanks for your help @dlwh and for the PR @homerjed, I'm excited to see this updated :)

@dlwh
Contributor

dlwh commented Mar 27, 2024

FYI I forgot one of the reasons why I did it this way: wsc doesn't work if the source array doesn't share a backend with the target sharding. Arguably those should be different functions, but I kind of like them being the same.

@patrick-kidger
Owner

Hmmm, bother, that might be annoying. IIUC this is an issue specifically when outside of JIT, since within JIT everything is necessarily on the same backend. Maybe we should just do a device_put when outside of JIT? Or check this explicitly with a call to array.devices()? Or call both device_put and wsc? Honestly I'm spitballing here; it's not super clear to me how best to accomplish this.

@dlwh
Contributor

dlwh commented Mar 29, 2024

I think jit is actually going to start supporting at least cross-device device_put (there's work on CPU offloading). I'm gonna bother the jax team and see what I can get.

@dlwh
Contributor

dlwh commented Mar 29, 2024

I still need to test distributed/multihost with this patch, but this fixes cpu->gpu:

stanford-crfm/haliax@958fe28

@dlwh
Contributor

dlwh commented Mar 29, 2024

I think they don't particularly like the way I'm doing it (they don't like me sniffing out jit vs not jit), but I get the sense that ^^ should work?

It seems like they're thinking about merging the two functions jax-ml/jax#19670 (comment)

@patrick-kidger
Owner

So my main concern with performing is_in_jit() checks is how much they will slow things down when you're not in JIT. I'm expecting that the jnp.zeros call it wraps will probably take a long time when outside of JIT, given how poorly JAX optimises this case.

Whilst I've not tried benchmarking the effect here, as a rule of thumb I try to only perform such checks outside of performance-critical contexts.

(If this is indeed a performance concern, then FWIW at least one improvement would be to move the is_in_jit() check to happen just once outside of the tree-map. Another approach might be to look inside JAX's private APIs and see what is sitting in the trace stack. If they didn't want us sniffing JIT-vs-non-JIT then they shouldn't have given us operations that do different things in each context... !)

@dlwh
Contributor

dlwh commented Mar 30, 2024

Yeah I think sniffing the trace out is probably the cheapest thing. It's almost enough to check whether the array is a tracer via isinstance, but I think you maybe still want wsc and not device_put if the arg is a captured/literal array?
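
(Something like this hypothetical per-leaf helper, just to make the idea concrete; neither library does exactly this.)

# Hypothetical sketch of the tracer check being discussed (not Equinox/Haliax code).
import jax
from jax import lax

def _shard_leaf(x, shardings):
    if isinstance(x, jax.core.Tracer):
        # We're under a trace (e.g. inside jit): use with_sharding_constraint.
        return lax.with_sharding_constraint(x, shardings)
    else:
        # Concrete array outside jit: device_put also handles cross-backend transfers.
        # (As noted above, a captured/literal concrete array inside jit would also take
        # this branch, which may not be what you want.)
        return jax.device_put(x, shardings)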

I don't usually worry about perf outside of jit (or the wrapper code to get to inside jit), since except for like dataloading it's never inner-loop, at least for my use cases...

I think really only device_put behaves differently inside jit vs outside. It's just that there's no general way to do a cross-backend transfer inside jit right now. (device_put seems to be growing these capabilities for cpu offload, but it's in flux.)

@patrick-kidger
Owner

> I don't usually worry about perf outside of jit (or the wrapper code to get to inside jit), since except for like dataloading it's never inner-loop, at least for my use cases...

Sure, but device-put'ing/sharding is something that happens as part of dataloading?

for data in dataloader:
    data = device_put_or_shard_etc(data, device)  # Your equivalent of this line calls `is_in_jit()` many times.
    grads = jitted_jax_computation(data)

etc.

Or am I missing a trick here?


FWIW at least in terms of #691, I'm thinking that maybe we just shouldn't worry too much about the cross-backend thing, so that the current filter_shard implementation is fine. As long as it supports moving NumPy arrays to the GPU (?), then I think we should already support the most common use cases?

@homerjed
Contributor Author

Moving a numpy array straight to the GPU works fine:

import jax
import equinox as eqx
import numpy as np

x = np.ones((64, 2))

[device] = jax.devices() # device = cuda:0

y = eqx.filter_shard(x, device)

y.devices() # {cuda(id=0)}

assuming that is what you meant.

@patrick-kidger
Owner

Okay, #691 is merged! I'm closing this. Thank you everyone for helping to make this a reality.
