Add batched_next_token() and batched_sample() #1693

Conversation
Co-authored-by: Sebastian Raschka <[email protected]>
I would prefer if the V0 did:
- take a 2D tensor as input and a 2D tensor as `input_pos`,
- apply the above for contiguous tensors and fall back to a for loop over the batch dim in `index_copy_` and `index_select` when it is not contiguous.
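For illustration, a minimal sketch of what that fallback could look like; the helper name and shapes are mine, not code from this PR:

```python
import torch

def batched_index_copy_(t: torch.Tensor, dim: int, idx: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not necessarily the PR's implementation.
    if idx.dim() == 1:
        # 1-D index: the built-in op already does the right thing.
        return t.index_copy_(dim, idx, val)
    # 2-D index of shape (batch, n): fall back to a loop over the batch dim.
    # `dim` indexes the full tensor, so each per-batch slice uses dim - 1.
    assert idx.dim() == 2 and dim > 0
    for b in range(idx.size(0)):
        t[b].index_copy_(dim - 1, idx[b], val[b])
    return t

# Example: writing one new token per sequence into a kvcache-shaped tensor.
kv = torch.zeros(2, 4, 8, 16)   # (batch, n_heads, max_seq, head_dim)
new = torch.randn(2, 4, 1, 16)
pos = torch.tensor([[3], [5]])  # a different position per sequence
batched_index_copy_(kv, 2, pos, new)
```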
@t-vi The reasoning behind the API being lists instead of a 2D tensor was that in any given batch not all the slots are going to be filled. We don't get to pass a batch of a size other than the size that the kvcache was initialized with, so we need to mask off the garbage. Taking the input as a list was going to make it easier to figure out which outputs need to be ignored. At least, that was the original thought. Then I decided to implement it in …

I'll do that. It works better with thunder. We're already going to have to compile for a specific batch dimension due to the sizing of the kvcache, so we can bake in the control flow for sampling. But we can't do that inside of the …

Yeah, that …
The thunder failure appears to be real, but unrelated to this PR. Nightly was broken, and so we merged without the thunder tests, but it looks like something regressed during that time. Will look into it in a bit.
No, I'm saying we need this really now-ish and the …
It now takes tensors. I've been monkeying with it for a while. The example you provided works for index `dim=0` and `t.ndim=2`, but we also need `dim=2` and `t.ndim=4` for the kvcache. It would also be useful to just have a version that works no matter the arguments. Probably easier than maintaining four new separate functions. I tried a bunch of stuff, but haven't been able to get everything hacked together in a working state yet. Partially because it took me an hour or so to realize a subtle misunderstanding in how the indexes are taken. I thought …
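For the record, here is one way a loop-based fallback could cover the `dim=2`, `t.ndim=4` kvcache case; this is a sketch only, and the name and exact semantics are assumptions rather than the code in this PR:

```python
import torch

def batched_index_select(t: torch.Tensor, dim: int, idx: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the final implementation.
    if idx.dim() == 1:
        # 1-D index: identical to the built-in op.
        return t.index_select(dim, idx)
    # 2-D index of shape (batch, n): select per batch element and re-stack,
    # e.g. a (batch, n_heads, max_seq, head_dim) kvcache with dim=2.
    assert idx.dim() == 2 and dim > 0
    return torch.stack([t[b].index_select(dim - 1, idx[b]) for b in range(idx.size(0))])
```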
I got something mostly working for … Turns out that applying the rope cache also needs to change, because every selected index needs its own positional embedding, so the dim of … The mask cache is less complicated than I thought, though; it turns out that it's always of shape (1, 1, n, n). I'm going to have to write another function to handle it, but it's not that bad. It's been a while, and the scope of this PR has crept significantly, so I'm going to put in another once I have a solution for …
Nope, it should be `t.numel()`, because the indexing is of the tensor.
Why? You can just flatten the index and then view the output of …
It's a positional embedding. So, previously we had one position (…
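To make the flatten-then-view idea concrete, here is a rough sketch for looking up per-sequence positions in a cache that has no batch dimension; the names and shapes are illustrative, not from the PR:

```python
import torch

def select_positions(rope_cache: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # rope_cache: (max_seq_len, rope_dim); input_pos: (batch, n), one row of
    # positions per sequence. Flatten the index, select, then view the output
    # so that each sequence gets its own positional embeddings.
    batch, n = input_pos.shape
    flat = rope_cache.index_select(0, input_pos.reshape(-1))  # (batch * n, rope_dim)
    return flat.view(batch, n, rope_cache.size(-1))

# Example: three sequences decoding at different positions in the same step.
cache = torch.randn(128, 16)
input_pos = torch.tensor([[5], [9], [42]])
print(select_positions(cache, input_pos).shape)  # torch.Size([3, 1, 16])
```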
This is part of a v0 implementation of batching. Looking for feedback. See the comments and talk to me for caveats. There are many.
This only works where `input_pos` can be shared. That is to say, no continual batching, and both prompts need to be the same length. The second one is a rather onerous requirement. It's not that we can't pad to combine the prompts. We can do that. But we cannot yet pad to combine `input_pos`. Still figuring that out. It will require rewriting the rope cache and kvcache to work with an `input_pos` that has a batch dimension. The `index_copy_` and `index_select` operations cannot be used, because they expect vectors and don't work with batches of data.

So after talking about it, this broken implementation was the v0 that @lantiga and I decided to go with.
Next steps, roughly in order:

- In `batched_generate_fn()`, mask off the garbage tokens after EOS.
- … `generate_fn()` the first time.
- Make `batched_generate_fn()` work for models with stop sequences.
- Have `batched_next_token()` accept `input_pos` as a list. This means we'll be able to continuously batch non-prefill tokens. This requires the fix to the kvcache and the mask cache. See the comment for inspiration on how to implement that.
- Hook `batched_generate_fn()` up to the `LLM` API somehow, then switch the container over to using that frontend function.

After all of this, we'll be able to serve inference to real users, at reasonable latencies. Let me know if anything in the roadmap seems out of place or out of order.
@rasbt @lantiga @Andrei-Aksionov
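Since the PR adds `batched_sample()`, here is a rough sketch of what per-sequence sampling over `(batch, vocab)` logits could look like; this is illustrative only, not necessarily how the PR implements it:

```python
from typing import Optional

import torch

def sample_batch(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    # logits: (batch, vocab_size) for the last position of each sequence.
    if top_k is not None:
        # Mask everything below the k-th largest logit per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # (batch, 1)
    return torch.argmax(logits, dim=-1, keepdim=True)   # greedy when temperature == 0
```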