
Length masking for batch inputs #1173

Merged: 5 commits merged into main from cache-lengths on Dec 19, 2024

Conversation

barronalex (Collaborator)

Allow passing an array of lengths to make_prompt_cache.

create_attention_mask then does the correct masking when you call an mlx-lm model with a batch of padded, variable-length inputs together with the cache created above.

Very open to other ways of doing this, but I went with this because it allows us to leave the model code untouched.
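
For context, here is a rough sketch of the usage described above. Hedged: the interface evolved later in this thread, and the lengths keyword, model name, and padding value are illustrative assumptions, not code from the PR.

import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

# Illustrative model; any mlx-lm model would work the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Two prompts of different lengths, right-padded to a common width.
prompts = [tokenizer.encode(p) for p in ["Hello", "Hello there, how are you?"]]
lengths = mx.array([len(p) for p in prompts])
max_len = max(len(p) for p in prompts)
batch = mx.array([p + [0] * (max_len - len(p)) for p in prompts])

# Per-sequence lengths handed to the cache (the lengths kwarg is the interface
# described in this comment); create_attention_mask then masks the padding.
cache = make_prompt_cache(model, lengths=lengths)
logits = model(batch, cache=cache)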

awni (Member) commented Dec 18, 2024

> Very open to other ways of doing this, but I went with this because it allows us to leave the model code untouched.

I do wonder if we should revisit the design of the model interface to include an optional mask. For cases that require flexible masks it would be a lot more .. well .. flexible. Another benefit is that it is more functional and easier to compile / export if we decide to go that route.

Other than the tedious aspect of changing the model interface, how would that work for your use case?
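
A minimal sketch of the interface change being floated here, using a toy module rather than the real mlx-lm model code; the class, dimensions, and fallback mask construction are illustrative assumptions.

import mlx.core as mx
import mlx.nn as nn

class ToyModel(nn.Module):
    """Toy model whose forward pass accepts an optional, externally built mask."""

    def __init__(self, vocab=100, dims=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dims)
        self.attn = nn.MultiHeadAttention(dims, num_heads=2)
        self.out = nn.Linear(dims, vocab)

    def __call__(self, inputs, mask=None, cache=None):
        h = self.embed(inputs)
        if mask is None:
            # Fall back to the usual causal mask when the caller passes none,
            # so existing callers keep the old behavior (cache handling omitted).
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)
        h = self.attn(h, h, h, mask=mask)
        return self.out(h)

model = ToyModel()
x = mx.array([[1, 2, 3, 4]])
print(model(x).shape)  # (1, 4, 100)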

barronalex (Collaborator, Author)

That would work great too and would definitely be easier to follow in the code.

The only gotcha will be making sure that the mask you pass has the same dtype as the rest of the model so you don’t accidentally upcast.

awni (Member) commented Dec 18, 2024

> The only gotcha will be making sure that the mask you pass has the same dtype as the rest of the model so you don’t accidentally upcast.

It's a good point and it's easy to do. If we do go that route (and maybe either way), we should require the mask type to be the same as the keys/queries/values in the fast scaled_dot_product_attention implementation. Or we could require that result_type(mask.dtype(), queries.dtype()) == queries.dtype(), or something like that.
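
To make the upcasting pitfall concrete, a small sketch (shapes and values are arbitrary):

import mlx.core as mx

# A float32 mask added to bfloat16 scores silently promotes the whole result.
scores = mx.zeros((1, 8, 8), dtype=mx.bfloat16)
mask = mx.triu(mx.full((8, 8), -float("inf"), dtype=mx.float32), k=1)

print((scores + mask).dtype)                       # float32 -- the accidental upcast
print((scores + mask.astype(scores.dtype)).dtype)  # bfloat16 -- cast the mask first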

chimezie (Contributor)

+1 on "revisit our design of model interface to include an optional mask."

barronalex (Collaborator, Author)

I'll redraft this to add a mask input everywhere. I'll make a PR in the core repo too for the scaled_dot_product_attention change.

barronalex (Collaborator, Author)

OK all done.

I'm a little worried we're going to break anyone using model(x, cache) but model(x, mask, cache) is more consistent with all of the other places we use the mask.

awni (Member) commented Dec 18, 2024

Nice, thanks a ton for making that change!!

> anyone using model(x, cache) but model(x, mask, cache) is more consistent with all of the other places we use the mask.

It's an easy fix for them to make it a kwarg. Let's land it and see what happens.
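
Illustrative only: passing the cache by keyword insulates callers from the positional change. The toy functions below stand in for the old and new model signatures.

def old_model(x, cache=None):             # previous interface: model(x, cache)
    return x

def new_model(x, mask=None, cache=None):  # new interface: model(x, mask, cache)
    return x

x, cache = [1, 2, 3], {}
old_model(x, cache=cache)   # works today
new_model(x, cache=cache)   # still works after the change -- no positional breakage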

Review thread on llms/mlx_lm/models/base.py (outdated, resolved)
awni (Member) left a comment

Looks great, thanks!!

chimezie (Contributor)

So, just a question to help my understanding of the implications of this: you no longer need to 'remove' any prefix that was cached beforehand (as said here)?

awni (Member) commented Dec 19, 2024

> So, just a question to help my understanding of the implications of this: you no longer need to 'remove' any prefix that was cached beforehand

Right now this doesn't change any behavior in mlx-lm for training or fine-tuning. It's just some additional functionality that lets you specify a mask and some lengths to the mask creation function.

For cases like:

  • prefill prefix
  • generate with question A
  • trim cache to prefix length
  • generate with question B

You could use a mask instead, but in general I would not advise it since the cache will keep growing and you will be doing a whole bunch of wasted computation.
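
A sketch of that prefill-and-trim pattern with the mlx_lm cache helpers. Hedged: the model name is illustrative, and the trim_prompt_cache / prompt_cache usage is an assumption about those helpers rather than code from this PR.

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
cache = make_prompt_cache(model)

# 1. Prefill the shared prefix once.
prefix = tokenizer.encode("You are a helpful assistant.\n")
model(mx.array(prefix)[None], cache=cache)

# 2. Generate with question A, reusing the prefilled cache.
generate(model, tokenizer, prompt="Question A?", prompt_cache=cache)

# 3. Trim the cache back to just the prefix...
trim_prompt_cache(cache, cache[0].offset - len(prefix))

# 4. ...and generate with question B from the same starting point.
generate(model, tokenizer, prompt="Question B?", prompt_cache=cache)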

barronalex merged commit d4ef909 into main on Dec 19, 2024
2 checks passed
barronalex deleted the cache-lengths branch on December 19, 2024 03:43
llllvvuu (Contributor)

This masks out the most recent tokens (right-padding), right?

import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

print(create_causal_mask(3, 1, lengths=mx.array([1, 2, 3, 1])))
