Batched inference on a single node with multiple GPUs #1473

Open
antareson opened this issue Jun 9, 2024 · 9 comments
Labels
enhancement New feature or request

Comments

@antareson

antareson commented Jun 9, 2024

How can I run inference on a batch of encoded tensors (shape = (B, T)) across 4 GPUs and get 3~4x the tokens/s throughput of a single GPU? (This is for a small model that fits in a single GPU's memory.)

I've tried launching Fabric with strategy='dp', 'ddp', and 'fsdp', as in commit 7130a36 (2023/12/14, #818), but each failed for various reasons.

Meanwhile, generate/sequentially.py is slower than a single GPU, and tp.py doesn't work for batched inputs out of the box.

@antareson antareson changed the title How to run batched inference on a single node with multiple GPUs? Batched inference on a single node with multiple GPUs Jun 9, 2024
@rasbt
Collaborator

rasbt commented Jun 9, 2024

Hi there,

thanks for the suggestion. Batched inference is not yet supported in LitGPT (batching is currently only supported in training), but it's on our list of things we want to add.

@rasbt rasbt added the enhancement New feature or request label Jun 9, 2024
@antareson
Author

Thank you for the quick response.

I'd like to try implementing batched inference. Could you provide some guidance and suggest a starting point?

For example, should I start from the commit just before 13fa12c (which dropped FSDP support, #813), or should I copy and trim the batched training code?

Thanks

@rasbt
Collaborator

rasbt commented Jun 9, 2024

Thanks for your interest and for offering to contribute. I would honestly start with the most recent code, because it has changed quite a bit over time. I would probably start with the generate base function in https://github.com/Lightning-AI/litgpt/tree/main/litgpt/generate (and then maybe later do the same for the chat function: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/chat/base.py).

@Andrei-Aksionov
Collaborator

There is a PR, #886, that started work on implementing batched inference.
If you want, you can continue that work.

@antareson
Author

Got it. To make batched inference work on multiple GPUs, would it be recommended to begin with DDP instead of FSDP?

It would be a great help if you could point me to any relevant documentation or code examples. Thanks.

@rasbt
Collaborator

rasbt commented Jun 10, 2024

I would even start with single GPU, and then we could think about implementing data or model parallelism later.
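For readers following along, here is a minimal sketch of what the later data-parallel step could look like for a model that fits on one GPU: each process holds a full copy of the model and works on its own slice of the prompts. The helper names `shard_for_rank` and `merge_from_ranks` are hypothetical and not part of LitGPT or Fabric; the sketch only covers splitting the work and restoring the original order, not the generation itself.

```python
from typing import List, Sequence


def shard_for_rank(items: Sequence, world_size: int, rank: int) -> List:
    """Return the slice of `items` that process `rank` should handle."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    # Strided split: rank r takes items r, r + world_size, r + 2 * world_size, ...
    return list(items[rank::world_size])


def merge_from_ranks(shards: Sequence[Sequence], total: int) -> List:
    """Undo the strided split so outputs line up with the original order."""
    world_size = len(shards)
    merged = [None] * total
    for rank, shard in enumerate(shards):
        for i, item in enumerate(shard):
            merged[rank + i * world_size] = item
    return merged


# Hypothetical example: 10 prompts spread over 4 processes (one per GPU).
prompts = [f"prompt-{i}" for i in range(10)]
shards = [shard_for_rank(prompts, world_size=4, rank=r) for r in range(4)]
restored = merge_from_ranks(shards, total=len(prompts))
```

In a real run each process would only compute its own shard, and the results would have to be gathered across processes before merging; how to do that is left open here.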

@Andrei-Aksionov
Collaborator

Single device --> DDP --> FSDP.

Unfortunately I'm not familiar with the problem, so I cannot provide any docs.
But I've been planning to do this anyway.
Hopefully in a couple of weeks I'll be back at my computer and able to assist.
In the meantime, try to do as much as you can on your own. The task should be interesting.
Have fun 😊

@FlimFlamm

FlimFlamm commented Jun 14, 2024

Just wanted to chime in with some support! (Sadly, I have been absolutely swamped and haven't had the time to return to my original PR, which is probably only useful for starting hints at this point, given how much has changed.)

Helpful tip: Ignore my build_mask_cache function

The implementation I used originally was mostly correct, but something is fishy about the mask cache. I was using my edited version to mask out padding tokens by building a custom (B, 1, T, T)-sized cache, as opposed to the default (1, 1, T, T)-sized triangular mask cache that just gets broadcast onto all the inputs. Apparently this interferes with something in the scaled_dot_product_attention function (still sorting that out). At present I'm just using the original build_mask_cache function and letting the unmasked padding tokens do their thing (which Llama 3 very much prefers, it seems).
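To make the padding discussion concrete, here is a minimal sketch of how prompts of different lengths could be turned into a rectangular batch plus a padding mask. It uses plain Python lists so it is easy to check by hand. The function name `left_pad` and the `pad_id` value are assumptions for illustration, not LitGPT names, and left padding is just one common choice (it keeps the last position of every row a real token).

```python
from typing import List, Sequence, Tuple


def left_pad(
    sequences: Sequence[Sequence[int]], pad_id: int
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad token-id sequences on the left to a common length.

    Returns the padded batch and a mask where True marks a real token
    and False marks padding.
    """
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append([pad_id] * n_pad + list(seq))
        mask.append([False] * n_pad + [True] * len(seq))
    return padded, mask


# Two prompts of different lengths; pad_id=0 is an arbitrary placeholder.
batch, padding_mask = left_pad([[5, 6, 7], [8]], pad_id=0)
```

Here `batch` has shape (B, T) = (2, 3), and `padding_mask` has the same shape, which is the kind of per-sequence mask the custom cache would need as input.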

Batched inference is more and more desirable as synthetic data generation becomes more commonplace; looking forward to official support!

Edit: Realized my mistake while taking a shower:

I was inserting the padding mask into the triangular mask in a way that overwrote False values in the upper-right triangle (which should all be False), so the attention values coming out were really messed up. What needs to happen is inserting False values into the lower part of the triangle where appropriate, while not inserting True values into the upper part... Doh!

It's not pretty, but it works:

from typing import Optional

import torch


def build_mask_cache(
    max_seq_length: int,
    device: Optional[torch.device] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Without a padding mask, fall back to the plain causal mask of shape (1, 1, T, T)
    if padding_mask is None:
        ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    # (B, max_seq_length, max_seq_length) sized tensor of True's
    ones = torch.ones(
        (padding_mask.size(0), max_seq_length, max_seq_length),
        device=device,
        dtype=torch.bool,
    )

    # insert the padding mask (False marks padding) into the key columns of every query row
    ones[:, :, : padding_mask.size(1)] = padding_mask.unsqueeze(1)

    # insert False/0 into the upper triangle of the tensor, and add a dimension for the head
    mask = torch.tril(ones).unsqueeze(1)

    return mask
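As a sanity check on the rule described above, here is a small list-based sketch of the same combination, without torch and without the head dimension: a query position may attend to a key position only if the key is not in the future and the key is not padding. The function name `combined_mask` is hypothetical, and the convention that True marks a real token is an assumption taken from the explanation above.

```python
from typing import List, Sequence


def combined_mask(
    padding_mask: Sequence[Sequence[bool]], max_seq_length: int
) -> List[List[List[bool]]]:
    """Return mask[b][q][k]: True if query q may attend to key k in sequence b."""
    out = []
    for row in padding_mask:
        # Key positions beyond the padding mask's length are treated as real tokens,
        # mirroring the tensor of True's that the padding mask is copied into.
        key_ok = list(row) + [True] * (max_seq_length - len(row))
        out.append(
            [
                # k <= q is the causal (lower-triangular) part; key_ok[k] removes padding.
                [k <= q and key_ok[k] for k in range(max_seq_length)]
                for q in range(max_seq_length)
            ]
        )
    return out


# One unpadded sequence and one with a single padding token on the left.
mask = combined_mask([[True, True, True], [False, True, True]], max_seq_length=3)
```

Note that the row for the padded query position in the second sequence has no allowed keys at all. Some attention implementations can turn such fully masked rows into NaNs, so whether that matters depends on whether the outputs at padded positions are ever read.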

@Andrei-Aksionov
Collaborator

Thanks @FlimFlamm for the info.
