Length masking for batch inputs #1173
Conversation
I do wonder if we should revisit our design of the model interface to include an optional mask. For cases that require flexible masks it will be a lot more .. well .. flexible. The other added benefit is it is more functional and easier to compile / export if we decide to go that route. Other than the tedious aspect of changing the model interface.. how would that work for your use case?
That would work great too and would definitely be easier to follow in the code. The only gotcha will be making sure that the mask you pass has the same dtype as the rest of the model so you don't accidentally upcast.
It's a good point and is easy to do. If we do go that route (and maybe either way) we should require the mask type to be the same type as keys/queries/values in the
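For illustration, a minimal sketch of the kind of dtype guard being discussed here; the helper name and placement are hypothetical and not part of the actual change:

```python
import mlx.core as mx

def cast_mask(mask, x):
    # Hypothetical helper: keep an additive attention mask in the same dtype
    # as the hidden states so attention doesn't silently upcast to float32.
    # Boolean masks are left alone since they carry no precision.
    if mask is not None and mask.dtype != mx.bool_ and mask.dtype != x.dtype:
        mask = mask.astype(x.dtype)
    return mask
```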
+1 on "revisit our design of model interface to include an optional mask."
I'll redraft this to add a mask input everywhere. I'll make a PR in the core repo too for the
OK all done. I'm a little worried we're going to break anyone using
Nice, thanks a ton for making that change!!
It's an easy fix for them to make it a kwarg. Let's land it and see what happens.
Looks great, thanks!!
So, just a question to help my understanding of the implications of this: you no longer need to 'remove' any prefix that was cached beforehand (as said here)?
Right now this doesn't change any behavior in mlx-lm for training or fine-tuning. It's just some additional functionality that lets you specify a mask and some lengths to the mask creation function. For cases like the one referenced above, you could use a mask instead, but in general I would not advise it since the cache will grow and you will be doing a whole bunch of wasted computation.
This masks out the most recent tokens (right-padding), right? `print(create_causal_mask(3, 1, lengths=mx.array([1, 2, 3, 1])))`
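For what it's worth, here is a rough sketch of the kind of length-aware causal mask being described. It is an illustration, not the library's implementation, and it assumes that key positions at or beyond each sequence's length are treated as padding (i.e. right padding):

```python
import mlx.core as mx

def length_causal_mask_sketch(N, offset=0, lengths=None):
    # Boolean mask: True means "may attend", False means "masked out".
    rinds = mx.arange(offset + N)           # key/value positions
    linds = mx.arange(offset, offset + N)   # query positions
    mask = linds[:, None] >= rinds[None]    # causal: each query sees keys at or before it
    if lengths is not None:
        # Key positions at or beyond a sequence's length count as padding and
        # are masked out for every query in that batch element, which
        # corresponds to right-padded inputs.
        mask = mask & (rinds[None] < lengths[:, None, None, None])
    return mask

print(length_causal_mask_sketch(3, 1, lengths=mx.array([1, 2, 3, 1])))
```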
Allow passing an array of lengths to `make_prompt_cache`. `create_attention_mask` then does the correct masking when you call an mlx-lm model with a batch of padded variable-length inputs and the cache you created above. Very open to other ways of doing this, but I went with this because it allows us to leave the model code untouched.
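A rough usage sketch of the workflow the description outlines. It assumes `make_prompt_cache` accepts a `lengths` array as described above (the keyword name is taken from the description and may not match the final merged API), and the model name and pad token id are just placeholders:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example model

# Right-pad a batch of prompts to a common length and record the true lengths.
prompts = ["Hello", "Write a haiku about the sea"]
tokens = [tokenizer.encode(p) for p in prompts]
lengths = mx.array([len(t) for t in tokens])
max_len = max(len(t) for t in tokens)
padded = mx.array([t + [0] * (max_len - len(t)) for t in tokens])  # 0 as a stand-in pad id

# Per the description above, pass the lengths when creating the cache so the
# attention mask accounts for the padding (keyword name is illustrative).
cache = make_prompt_cache(model, lengths=lengths)
logits = model(padded, cache=cache)
```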