
Add batched_next_token() and batched_sample() #1693

Merged: 9 commits into main, Aug 28, 2024

Conversation

apaz-cli (Contributor)

This is part of a v0 implementation of batching. Looking for feedback. See the comments and talk to me for caveats. There are many.

This only works where input_pos can be shared. That is to say, no continuous batching, and both prompts need to be the same length. The second one is a rather onerous requirement. It's not that we can't pad to combine the prompts; we can do that. But we cannot yet pad to combine input_pos. Still figuring that out. It will require rewriting the rope cache and kvcache to work with an input_pos that has a batch dimension. The index_copy_ and index_select operations cannot be used, because they expect 1-D index vectors and don't work with batches of data.
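A minimal sketch of the constraint, with illustrative shapes rather than litgpt's actual cache layout: index_copy_ accepts a shared 1-D input_pos but rejects a per-sequence (B, T) one, which is where a loop over the batch dimension would have to come in.

```python
import torch

# Illustrative shapes only; not litgpt's actual kv-cache layout.
B, T, max_seq = 2, 3, 8
cache = torch.zeros(B, max_seq, 4)   # stand-in for one kv-cache tensor
new = torch.randn(B, T, 4)           # new keys/values for this step

# Shared positions work: index_copy_ takes a single 1-D index vector.
input_pos = torch.arange(T)
cache.index_copy_(1, input_pos, new)

# Per-sequence positions do not: index_copy_ rejects a (B, T) index,
# so a batched input_pos needs something like a loop over the batch dim.
batched_pos = torch.stack([torch.arange(T), torch.arange(2, 2 + T)])
for b in range(B):
    cache[b].index_copy_(0, batched_pos[b], new[b])
```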

So after talking about it, this broken implementation was the v0 that @lantiga and I decided to go with.

Next steps, roughly in order:

  1. Write batched_generate_fn() and mask off the garbage tokens after EOS (a sketch of the masking idea follows this list).
    • For the V0, this may just mean checking the tokenizer's eos id for each item in the batch. Checking for stop sequences greatly complicates things, as evidenced by how long it took to write generate_fn() the first time.
  2. Hook up to litserve, build a demo.
  3. Make batched_generate_fn() work for models with stop sequences.
  4. Make batched_next_token() accept input_pos as a list. This means we'll be able to continuously batch non-prefill tokens. This requires the fix to the kvcache and the mask cache. See the comment for inspiration on how to implement that.
  5. Hook batched_generate_fn() up to the LLM api somehow, then switch the container over to using that frontend function.
  6. Figure out a plan for continuously batched prefill. This gives us a reasonable time to first token.
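One hedged sketch of the EOS masking mentioned in step 1; mask_after_eos is a hypothetical helper, not code from this PR, and the (B, T) token layout and pad_id are assumptions.

```python
import torch

def mask_after_eos(tokens: torch.Tensor, eos_id: int, pad_id: int) -> torch.Tensor:
    """Replace everything after each row's first eos with pad_id.

    tokens is assumed to be a (B, T) tensor of generated token ids.
    """
    is_eos = tokens == eos_id
    seen = is_eos.cumsum(dim=1)              # how many eos tokens seen so far, per row
    garbage = (seen - is_eos.long()) > 0     # positions strictly after the first eos
    return tokens.masked_fill(garbage, pad_id)

tokens = torch.tensor([[5, 7, 2, 9, 9],
                       [5, 7, 7, 2, 9]])
print(mask_after_eos(tokens, eos_id=2, pad_id=0))
# tensor([[5, 7, 2, 0, 0],
#         [5, 7, 7, 2, 0]])
```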

After all of this, we'll be able to serve inference to real users, at reasonable latencies. Let me know if anything in the roadmap seems out of place or out of order.

@rasbt @lantiga @Andrei-Aksionov

(Review thread on litgpt/generate/base.py: outdated, resolved)
@t-vi (Contributor) left a comment


I would prefer if the V0 did:

  • a 2d tensor as input and a 2d tensor as input_pos.
  • apply the above for contiguous tensors, and fall back to a for loop over the batch dim in index_copy_ and index_select when it is not contiguous (a rough sketch follows).
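A rough sketch of that fallback under illustrative names (this is not the referenced index_select_2d example itself, just the loop-over-the-batch-dim idea):

```python
import torch

def batched_index_select(t: torch.Tensor, dim: int, idx: torch.Tensor) -> torch.Tensor:
    """Per-batch index_select: t is (B, ...), idx is (B, k), dim counts within t[b].

    torch.index_select only accepts a 1-D index, so loop over the batch dimension;
    an in-place index_copy_ variant would mirror this with t[b].index_copy_(...).
    """
    return torch.stack([t[b].index_select(dim, idx[b]) for b in range(idx.size(0))])
```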

(Further review threads on litgpt/generate/base.py, resolved)
@apaz-cli (Contributor, Author)

@t-vi The reasoning behind the API taking lists instead of a 2d tensor was that in any given batch not all the slots are going to be filled. We don't get to pass a batch of a size other than the one the kvcache was initialized with, so we need to mask off the garbage. Taking the input as a list was going to make it easier to figure out which outputs need to be ignored. At least, that was the original thought. Then I decided to implement it in batched_generate() and never switched the API back.

I'll do that. It works better with thunder. We're already going to have to compile for a specific batch dimension due to the sizing of the kvcache, so we can bake in the control flow for sampling. But we can't do that inside the batched_next_token() function, because we can't use control flow to check for None. And so it shall return unmasked garbage for now.

Yeah, that index_select_2d looks right. We also need index_copy_2d_. But as you say, keeping it out of this PR for now. Useful for later though.

@apaz-cli (Contributor, Author)

The thunder failure appears to be real, but unrelated to this PR. Nightly was broken, and so we merged without the thunder tests, but it looks like something regressed during that time. Will look into it in a bit.

@t-vi (Contributor) commented Aug 27, 2024

> Yeah, that index_select_2d looks right. We also need index_copy_2d_. But as you say, keeping it out of this PR for now. Useful for later though.

No, I'm saying we need this really now-ish, and the index_copy_2d should work exactly like the index_select_2d.

@apaz-cli (Contributor, Author)

It now takes tensors.

I've been monkeying with it for a while. The example you provided works for index dim=0 and t.ndim=2, but we also need dim=2 and t.ndim=4 for the kvcache. It would also be useful to just have a version that works no matter the arguments. That's probably easier than maintaining four separate new functions.

I tried a bunch of stuff, but haven't been able to get everything hacked together in a working state yet, partly because it took me an hour or so to notice a subtle misunderstanding in how the indexes are taken: I thought index_copy_ could be implemented in terms of index_select, but the index dimensions don't match up, so it doesn't work that way. I'll step through it in a more principled manner in the morning. Worst case, for loops can get us there, but I'd really rather not implement it by hand that way. Hence all the monkeying.
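For the kvcache case (dim=2, t.ndim=4), a hedged sketch of what a "works no matter the arguments" variant might look like. The (B, n_heads, T, head_dim) cache layout and the names here are assumptions for illustration, not litgpt's actual implementation.

```python
import torch

def generic_index_copy_(t: torch.Tensor, dim: int, idx: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    if idx.ndim == 1:
        # shared positions across the batch: the built-in op already works
        return t.index_copy_(dim, idx, val)
    # per-batch positions: loop over the leading batch dimension and shift
    # `dim` by one because t[b] drops that dimension
    for b in range(idx.size(0)):
        t[b].index_copy_(dim - 1, idx[b], val[b])
    return t

# e.g. a kv-cache update with per-sequence positions
B, n_heads, T, head_dim = 2, 4, 8, 16
k_cache = torch.zeros(B, n_heads, T, head_dim)
new_k = torch.randn(B, n_heads, 3, head_dim)
input_pos = torch.tensor([[0, 1, 2], [4, 5, 6]])
generic_index_copy_(k_cache, 2, input_pos, new_k)
```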

@apaz-cli (Contributor, Author) commented Aug 28, 2024

I got something mostly working for index_select. In your example code, t.numel() should be idx.numel().

Turns out that applying the rope cache also needs to change, because every selected index needs its own positional embedding, so the dims of model.sin and model.cos have to change.

The mask cache is less complicated than I thought, though: it turns out it's always of shape (1, 1, n, n). I'm going to have to write another function to handle it, but it's not that bad.
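A hedged sketch of what that mask-cache helper might look like, assuming the (1, 1, n, n) shape above and a (B, T) input_pos; the name and return shape are hypothetical.

```python
import torch

def batched_mask_select(mask_cache: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    """mask_cache is (1, 1, n, n); input_pos is (B, T); returns (B, 1, T, n)."""
    mask = mask_cache.squeeze(0).squeeze(0)                 # (n, n)
    rows = mask.index_select(0, input_pos.reshape(-1))      # (B*T, n), one mask row per position
    return rows.view(input_pos.size(0), 1, input_pos.size(1), -1)
```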

It's been a while, and the scope of this PR has crept significantly, so I'm going to put in another PR once I have a solution for input_pos.ndim > 1.

@apaz-cli merged commit 3c0c479 into main on Aug 28, 2024; 8 of 9 checks passed. The ap/batched_next_token_2 branch was deleted on Aug 28, 2024 at 15:05.
@t-vi (Contributor) commented Aug 28, 2024

> I got something mostly working for index_select. In your example code, t.numel() should be idx.numel().

Nope, it should be t.numel(), because the indexing is of the tensor.

@t-vi (Contributor) commented Aug 28, 2024

> Turns out that applying the rope cache also needs to change, because every selected index needs its own positional embedding, so the dims of model.sin and model.cos have to change.

Why? You can just flatten the index and then view the output of index_select with the shape of the index.
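In code, that flatten-and-view lookup could look roughly like this, assuming 2-D sin/cos caches of shape (max_seq_len, rope_dim) and a (B, T) input_pos; the names are stand-ins, not litgpt code.

```python
import torch

max_seq_len, rope_dim = 16, 8
cos_cache = torch.randn(max_seq_len, rope_dim)   # stand-ins for model.cos / model.sin
sin_cache = torch.randn(max_seq_len, rope_dim)
input_pos = torch.tensor([[0, 1, 2], [4, 5, 6]]) # (B, T)

# Flatten the batched positions, do a single 1-D index_select, then view the
# result back into the index's shape, so the caches themselves stay 2-D.
flat = input_pos.reshape(-1)                                      # (B*T,)
cos = cos_cache.index_select(0, flat).view(*input_pos.shape, -1)  # (B, T, rope_dim)
sin = sin_cache.index_select(0, flat).view(*input_pos.shape, -1)  # (B, T, rope_dim)
```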

@apaz-cli (Contributor, Author)

@t-vi

> Why?

It's a positional embedding. Previously we had one position (input_pos); now we have many different positions that we're calculating the embedding for. So after we flatten and view model.sin and model.cos (exactly as you describe), we have an extra batch dimension. It turns out that apply_rope_cache is not set up to deal with that; it does its own incompatible index trickery. So I'm fixing it.
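For context, a generic rotate-half rope application with the extra batch dimension handled by broadcasting. This is an illustrative sketch, not litgpt's apply_rope_cache; it assumes cos/sin cover the full head dimension and arrive as (B, T, head_dim) after the batched lookup above.

```python
import torch

def apply_rope_batched(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """x is (B, n_heads, T, head_dim); cos/sin are (B, T, head_dim)."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    # unsqueeze a heads dimension so the batched caches broadcast against x
    return x * cos.unsqueeze(1) + rotated * sin.unsqueeze(1)
```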
