Possible memory leak? #673
Comments
From skimming the code, I suspect the issue might now lie with Equinox, since you're only using ...

Here's what I would do: ...

That would be my first naive experimentation.
Thanks @Artur-Galstyan but I already performed all those analyses before opening the issue. For the A100, it runs OOM after 4 transformer layers, and for the TPUv3-8 it runs OOM after 16 layers. Both devices have enough memory to handle a 7B model; in fact, I can load the entire torch model with no issues at all. There is definitely a memory leak, and I don't know if it's because of ...
I'll try to investigate this too, later. Unfortunately, I don't have an A100, but if my math is correct, I should (theoretically) be able to load 7B models on a 3090 with 16-bit precision. I'll let you know if I find something.
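(Back-of-the-envelope on that math: 7B parameters at 2 bytes each in 16-bit precision is roughly 14 GB for the weights alone, so a 24 GB 3090 or a 40 GB A100 should hold them, albeit with limited headroom left for activations or any temporary full-precision copies.)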
Sure thing. Thanks. BTW you can test it on TPUs on Kaggle for free.
You need a ...
Oh okay. Let me try that. Thanks.
@patrick-kidger this throws ...
Do the dtype conversion after model creation, not inside of it. In this case you can't delete the old buffer whilst inside a vmap.
@patrick-kidger fair enough, but the other one didn't make any difference. BTW, I figured out where the issue lies, and it is definitely with eqx.filter_vmap. In a nutshell, inside the model's __init__, this fails and leads to OOM:

make_tf_layers = lambda k: TransformerBlock(args, key=k)
self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)
self.layers = to_dtype(self.layers, jnp.float16)

This works though:

self.layers = [to_dtype(TransformerBlock(args, key=k), jnp.float16) for k in tf_layers_keys]
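(For anyone following along: to_dtype isn't shown anywhere in this thread. A minimal sketch of such a helper, assuming it simply casts floating-point array leaves and passes everything else through unchanged, might look like this.)

import jax.tree_util as jtu
import equinox as eqx

def to_dtype(tree, dtype):
    # Cast only inexact (floating-point) array leaves; integer arrays,
    # activation functions and other non-array leaves pass through untouched.
    cast = lambda leaf: leaf.astype(dtype) if eqx.is_inexact_array(leaf) else leaf
    return jtu.tree_map(cast, tree)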
Oh! That's super weird. What about when there is no dtype conversion at all? (In all cases, including the vmap'd one.)
Still OOM. And I somewhat have an idea why it is happening when we use eqx.filter_vmap. Quick question: is there any trivial reason for not supporting a dtype for each layer explicitly?
Ah, maybe I misunderstood the memory constraints of your model. Indeed, ...

You're asking about adding an e.g. dtype argument to each layer. ... That said, this does seem like a case where that turns out to be fairly complicated. Hmm, it might be that we should reconsider the above approach, or else find some other way to directly create the weights with the desired precision.
Thanks for the confirmation. Now that the only solution is to create a list of TransformerBlocks in a for loop, what's your recommended way of calling these layers? I want to jit the whole thing, but I am not aware of any other method than what I tried above. How do I call layers created like this?

self.layers = [to_dtype(TransformerBlock(args, key=k), jnp.float16) for k in tf_layers_keys]

Agreed, but I think ...
One simple thing that will work at the moment would be to stack all the layers after creating them:

layers = [to_dtype(...) ...]
self.layers = jtu.tree_map(lambda *x: jnp.stack(x) if eqx.is_array(x) else x, layers)

On dtype: agreed that this seems like it could be thought of differently to the other initialisations. I'm not currently sure whether we should (a) bite the bullet and add a general initialiser, (b) add just a dtype argument, or (c) maybe there is a way to finesse this with the current API. Either way this seems like a problem worth solving.
Thanks for the suggestion. I will try that out.
I can help contribute to it. Do you want to open an issue to track this?

Sure, go ahead and open an issue to track this.
I don't know if this is the intended behavior here, but this doesn't create stacked arrays.

Sorry, should be ...

No worries. Thanks for the clarification. Also, once the weights are stacked, I guess it will be a simple vmapped call?
Probably a jax.lax.scan.
Couple of issues here. This is what I tried:

dynamic_tf_layers, static_tf_layers = eqx.partition(self.layers, eqx.is_array)

def f(_x, _dynamic_tf_layers):
    tf_layer = eqx.combine(_dynamic_tf_layers, static_tf_layers)
    return tf_layer(_x, cos_freq, sin_freq, positions, mask), None

h, _ = jax.lax.scan(f, h, dynamic_tf_layers)

This throws an error. Pardon me if this sounds naive, but isn't there a way to inspect the input/output shapes of each layer in an Equinox module iteratively?
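(On the shape-inspection question: one option, sketched below, is to map array leaves to their .shape to see the parameter shapes, and to use jax.eval_shape to see input/output shapes, since it evaluates abstractly without allocating any buffers. Here model, layer, seq_len and d_model are placeholders standing in for the real module and dimensions.)

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

# Parameter shapes of the whole module; no device memory is touched.
param_shapes = jtu.tree_map(
    lambda leaf: leaf.shape if eqx.is_array(leaf) else leaf, model
)
print(param_shapes)

# Output structure of one layer, evaluated abstractly (shapes/dtypes only).
abstract_x = jax.ShapeDtypeStruct((seq_len, d_model), jnp.float16)
print(jax.eval_shape(layer, abstract_x))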
I figured out what's wrong with the stacking. Here is a minimal example:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

class FeedForward(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, in_features, out_features, key):
        key1, key2 = jax.random.split(key)
        self.layer1 = eqx.nn.Linear(in_features, out_features, key=key1)
        self.layer2 = eqx.nn.Linear(out_features, out_features, key=key2)

    def __call__(self, x):
        x = jax.vmap(self.layer1)(x)
        x = jax.vmap(self.layer2)(x)
        return x

class CustomModel(eqx.Module):
    layers: FeedForward

    def __init__(self, in_features, out_features, num_layers):
        keys = jax.random.split(jax.random.PRNGKey(1), num_layers)
        self.layers = [FeedForward(in_features, out_features, key=k) for k in keys]
        self.layers = jtu.tree_map(lambda *x: jnp.stack(x) if eqx.is_array(x) else x, *self.layers)

    def __call__(self, x):
        ...

If I print out self.layers, you can see that the arrays are not stacked.

PS: If you just use ...
Any suggestions? @patrick-kidger
Ah, the typo is that eqx.is_array(x) is being checked against the whole tuple of leaves rather than each leaf individually. Something like this should do it:

def _stack(*x):
    is_array = {eqx.is_array(xi) for xi in x}
    if len(is_array) == 1:
        if is_array.pop():
            return jnp.stack(x)
        else:
            return x
    else:
        raise ValueError

self.layers = jtu.tree_map(_stack, *layers)

Not terribly pretty, but it works.
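(Putting the pieces together, a sketch of the pattern this thread converges on: build and cast the layers in a loop, stack matching leaves with the _stack helper above, then run them with eqx.partition/eqx.combine inside jax.lax.scan. The names to_dtype, TransformerBlock, args, tf_layers_keys, h, cos_freq, sin_freq, positions and mask are taken from earlier snippets; this is a sketch, not a drop-in fix.)

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

layers = [to_dtype(TransformerBlock(args, key=k), jnp.float16) for k in tf_layers_keys]
stacked_layers = jtu.tree_map(_stack, *layers)

# Array leaves get scanned over; everything else is closed over as static.
dynamic_layers, static_layers = eqx.partition(stacked_layers, eqx.is_array)

def f(carry, dynamic_layer):
    layer = eqx.combine(dynamic_layer, static_layers)
    return layer(carry, cos_freq, sin_freq, positions, mask), None

h, _ = jax.lax.scan(f, h, dynamic_layers)

(One wrinkle: _stack returns the whole tuple x for non-array leaves, so if TransformerBlock has non-array fields that its __call__ actually uses, that branch may need to return a single copy, e.g. x[0], instead.)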
Thanks for the help, but it OOMs out as soon as I try to stack the layers, and I have no idea why! The simple for loop for creating layers works, but stacking them doesn't. Though I can give it one more try, I would take a break here as I am exhausted from debugging this.

Yes, it will save a lot of time, and will make the code concise.
Probably because both versions of the weights (pre-stack and post-stack) need to live in memory at the same time for a short while. In any case, you shouldn't need to worry about this any more anyway, as we've merged your PR over on the dev branch.
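(For a rough sense of the numbers: 7B parameters is about 14 GB in float16, so the pre-stack and post-stack copies together are roughly 28 GB — fine on an A100 40G, but already too much for a single 16 GB TPUv3 core if the model isn't sharded, and in float32 even one copy is about 28 GB. If that transient peak is the problem, one workaround, purely a sketch and not something suggested in this thread, is to stack leaf by leaf and explicitly free each layer's buffers as you go.)

import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

def stack_and_free(*leaves):
    # Stack the matching leaf from every layer, then release the per-layer
    # buffers so pre-stack and post-stack copies only coexist for one leaf
    # at a time rather than for the whole model.
    if all(eqx.is_array(leaf) for leaf in leaves):
        stacked = jnp.stack(leaves)
        stacked.block_until_ready()  # make sure the stacked buffer exists first
        for leaf in leaves:
            leaf.delete()  # assumes the per-layer arrays aren't referenced elsewhere
        return stacked
    return leaves[0]  # assume non-array leaves are identical across layers

stacked_layers = jtu.tree_map(stack_and_free, *layers)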
I tried it on TPUv3-8 with the changes on the dev branch. Given everything is in half-precision, I used ...
Hard to say what's going on from where I'm sitting, I'm afraid. I hope this is something you can debug, and then let us know where the error came from!
I literally have no idea. If it were just a memory issue, the same code should OOM out when we are creating transformer blocks in a loop. The issue lies in replacing the for loop with ...
Will pick this up again next week, and will update with my findings.
This is now solved, but I figured out how to solve it in a much better way. @patrick-kidger the Mistral-7B port is complete. Here is the source code for your reference. Thanks again for all the help.
Oh awesome! I'm really glad that you managed to complete this. I'd be happy to link to your GitHub page from the "advanced examples" in the Equinox documentation, if you'd like? By the way, I see a comment in your code about RMSNorm ...
That would be awesome! Thanks 🍺
Like ... If I initialize RMSNorm with ...

I will send a PR for this.
@patrick-kidger, as you are already aware, I have been working on a Mistral port. I have now put that code in a public repo. There are some TODOs remaining that I will complete anyway, but I am dealing with a big issue. Here is the condensed code implementation for reference.

Given that this model has 7B parameters, I should be able to load it on an A100 40G GPU as well as on a TPUv3-8. The problem is that no matter what device I use, it always errors out with OOM.

I suspected that it might be the case that I am initializing the transformer with full precision and then converting it to float16, so I tried the other way round, where in each block I convert each layer to float16 as soon as it is initialized. Even after that, it errors out with OOM. I suspect there is a memory leak somewhere but I am not able to pinpoint it.

Any suggestions/pointers would be very helpful. Thanks in advance 🙏