This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[Roadmap] Inference performance (for OPT/GPT) #534

Open · 2 of 8 tasks
zhisbug opened this issue Jun 21, 2022 · 2 comments
Labels: enhancement (New feature)

@zhisbug (Member) commented on Jun 21, 2022

In order to achieve state-of-the-art serving performance on OPT/GPT, we need to develop the following features, sorted by priority.

Task 1: Align single-GPU decoding performance with FasterTransformer.

Task 1.1: Generate a kernel performance table

For single-GPU decoding on an OPT-2.7B model with bs = 1, JAX achieves ~65 tokens/s on an A100, while FT achieves ~78 tokens/s on a V100.

Many of our components are under-optimized when running inference. Several TODOs:

  • Compare our kernel times against FT's kernel times and identify the gaps in kernel implementations (see the profiling sketch after this list).
  • Port several of the handcrafted kernels used in FT into XLA, e.g., FT's hand-written masked attention kernel.
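
A minimal profiling sketch for building the kernel table, assuming we only want per-kernel times from a trace that can be compared against FT's kernels in TensorBoard/Perfetto. The `decode_step` body and shapes below are placeholders, not the real OPT executable:

```python
import jax
import jax.numpy as jnp

# Placeholder for the real decoding executable; only the profiling pattern matters here.
@jax.jit
def decode_step(x, w):
    return jax.nn.softmax(x @ w)

x, w = jnp.ones((1, 2560)), jnp.ones((2560, 2560))
decode_step(x, w).block_until_ready()          # compile outside the trace

# Capture one traced decoding step; open the trace in TensorBoard/Perfetto and
# compare per-kernel durations against FasterTransformer's kernel timings.
with jax.profiler.trace("/tmp/jax-decode-trace"):
    decode_step(x, w).block_until_ready()
```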

Task 1.2: Optimize the masked attention computation

Our masked attention performs unnecessary computation: at each decoding step, we attend over the full seq_len instead of only the previously generated tokens. This is a consequence of the static compiler design. We need to:

  • Benchmark how much this full attention costs in real inference.
  • If it is substantial (>20% of the overall time), think of ways to reduce the extra computation as much as possible within a static compilation architecture (see the bucketing sketch after this list).
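
One possible way to cut the extra work while keeping shapes static, sketched below: compile the masked-attention executable for a few cache-length buckets and dispatch to the smallest bucket that covers the current length. The bucket sizes, shapes, and function names are illustrative assumptions, not part of the current codebase:

```python
from functools import partial

import jax
import jax.numpy as jnp

CACHE_BUCKETS = (256, 512, 1024, 2048)   # assumed bucket sizes

@partial(jax.jit, static_argnums=(3,))
def masked_attention(q, k_cache, v_cache, bucket_len, cur_len):
    # Attend only over the first `bucket_len` cache slots instead of the full 2048;
    # positions beyond `cur_len` inside the bucket are still masked out as before.
    k, v = k_cache[:, :bucket_len], v_cache[:, :bucket_len]
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    mask = jnp.arange(bucket_len) < cur_len
    probs = jax.nn.softmax(jnp.where(mask, scores, -1e9), axis=-1)
    return jnp.einsum("hqk,hkd->hqd", probs, v)

def attend(q, k_cache, v_cache, cur_len):
    # Pick the smallest static bucket covering the current length; each distinct
    # bucket triggers one compilation, so only a handful of sizes are ever built.
    bucket_len = next(b for b in CACHE_BUCKETS if b >= cur_len)
    return masked_attention(q, k_cache, v_cache, bucket_len, cur_len)

q = jnp.ones((32, 1, 80))                       # (heads, query_len=1, head_dim)
k_cache = v_cache = jnp.ones((32, 2048, 80))
out = attend(q, k_cache, v_cache, cur_len=300)  # runs the 512-slot bucket, not 2048
```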

Task 1.3: Use training-like computation on the prompt (prefix)

We should compile at least two executables: one that performs a training-like forward pass over the prompt, and one for decoding. Currently, we run the decoding executable for every token in the prompt, which is slow.
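
A minimal runnable sketch of the split, using a toy single-layer "model" with a shared cache (the real version would run the full OPT stack; all names and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

HIDDEN, MAX_LEN = 64, 128                      # toy sizes

@jax.jit
def prefill(prompt_embeds, cache):
    # Training-like forward over the whole prompt: all prompt positions are written
    # into the cache in a single call instead of one decode step per prompt token.
    return cache.at[: prompt_embeds.shape[0]].set(prompt_embeds)

@jax.jit
def decode_one(token_embed, cache, step):
    # Incremental decoding: append one position, then attend over the cached prefix.
    cache = cache.at[step].set(token_embed)
    mask = jnp.arange(MAX_LEN) <= step
    scores = jnp.where(mask, cache @ token_embed, -1e9)
    context = jax.nn.softmax(scores) @ cache
    return context, cache

cache = jnp.zeros((MAX_LEN, HIDDEN))
cache = prefill(jnp.ones((16, HIDDEN)), cache)        # one call for a 16-token prompt
out, cache = decode_one(jnp.ones(HIDDEN), cache, 16)  # then one call per new token
```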

Task 2: Align intra-op parallelism performance in decoding with FasterTransformer.

Assuming Task 1 is done, we then try to match FT's performance when auto-sharding is turned on.

As a reference, on the 2.7B model with bs = 1:

  • FT achieves a 2.8x latency reduction when using 8 V100s, compared to 1 V100.
  • We see a 1.3x latency increase when using 2 A100s, compared to 1 A100.

This might be a simple fallback in the auto-sharding solver at batch_size = 1, or a more serious problem. We need to fix it and match or outperform FT's latency reduction when adding more GPUs.
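
As a manual reference point (hand-written Megatron-style tensor parallelism, not Alpa's auto-sharding), the sketch below splits the QKV projection over a "model" mesh axis using the plain jax.sharding API; shapes are illustrative. Timing this per-token matmul on 1 vs. 2 GPUs gives a lower bound on the latency reduction the solver should be able to find:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
hidden = 2560                                   # ~2.7B-class hidden size (illustrative)

w_qkv = jax.device_put(
    jnp.ones((hidden, 3 * hidden)),
    NamedSharding(mesh, P(None, "model")),      # column-parallel: split the output dim
)
x = jax.device_put(jnp.ones((1, hidden)), NamedSharding(mesh, P()))  # bs = 1 token

@jax.jit
def qkv_proj(x, w):
    return x @ w                                # XLA partitions the matmul over "model"

qkv_proj(x, w_qkv).block_until_ready()          # warm up, then time this per-token cost
```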

Task 3: Enable full inter-op parallelism and #mb > 1

3.1 Add some basic features

Currently, we can only do device-placement-like inter-op parallelism with #microbatches = 1; we need some additional engineering to support full inter-op parallelism and #mb > 1.

This development may not improve latency but will boost throughput significantly.

3.2 Reduce Ray scheduling overheads as much as possible

We have identified many Ray scheduling overheads; they are not substantial in training but become critical in inference. We should find ways to reduce them as much as possible.

I have enumerated all the sources of Ray overhead I can think of below. We need to benchmark the severity of these overheads and see which ones are substantial in our targeted use cases (e.g., 175B models, long decoding, long prompts), because not all of them are easy to hide.

  • Repeatedly calling shard_args at each decoding step; note that in training, shard_args is called only once per iteration (see the micro-benchmark sketch after this list).
  • Moving decoded tokens from the last stage, which owns the decoding layer, to the stage that owns the input layer (when inter-op parallelism is on).
  • Destructing remotebuffs too frequently (see here).
  • HuggingFace-based tokenization, top-p sampling, and beam search: are they sufficiently efficient?
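
A micro-benchmark sketch for the first bullet, using plain jax.device_put as a stand-in for the real shard_args path (the step function and shapes are placeholders):

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def step(tok, w):
    return jnp.argmax(tok @ w)                   # stand-in for one decoding step

w = jnp.ones((2560, 2560))
host_tok = np.ones((1, 2560), dtype=np.float32)
step(jnp.asarray(host_tok), w).block_until_ready()          # compile once

start = time.perf_counter()
for _ in range(100):
    step(jax.device_put(host_tok), w).block_until_ready()   # host->device every step
t_transfer = time.perf_counter() - start

dev_tok = jax.device_put(host_tok)
start = time.perf_counter()
for _ in range(100):
    step(dev_tok, w).block_until_ready()                    # buffer stays device-resident
t_resident = time.perf_counter() - start

print(f"per-step transfer: {t_transfer * 10:.3f} ms, resident: {t_resident * 10:.3f} ms")
```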

Task 4: Support beam search, output scores/hidden states, and other related text generation utilities

Currently, we only support top-p sampling with a single output sentence. We need to support beam_size > 1.
Related near-complete PR for beam search: #491

One good reference is OpenAI's text generation interface and features.
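
For reference, the core of one beam-search expansion step is a top-k over the combined beam/token scores; a minimal sketch, independent of the implementation in #491:

```python
import jax
import jax.numpy as jnp

def beam_step(beam_scores, logprobs, beam_size):
    """One beam-search expansion step.

    beam_scores: (beam,) running log-probability of each beam.
    logprobs:    (beam, vocab) next-token log-probabilities for each beam.
    """
    total = beam_scores[:, None] + logprobs                 # (beam, vocab) candidates
    top_scores, flat_idx = jax.lax.top_k(total.reshape(-1), beam_size)
    parent_beam = flat_idx // logprobs.shape[-1]            # which beam each survivor extends
    next_token = flat_idx % logprobs.shape[-1]              # which token extends it
    return top_scores, parent_beam, next_token
```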

Task 5: Study and improve the current batching mechanism

The batching in serving is complicated because it requires at least two levels of batching:

Batching incoming requests

Currently, we set two thresholds:

  1. TIMEOUT: a time window
  2. MAX_TOKENS: the maximum number of tokens allowed to be processed in one "job launch".

The server listens for requests until either TIMEOUT is reached or the collected tokens reach MAX_TOKENS, then sends them as one job to the backend for computation.
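
A minimal sketch of this policy; the TIMEOUT / MAX_TOKENS values and the request schema are assumptions for illustration:

```python
import time
from queue import Empty, Queue

TIMEOUT = 0.1        # seconds to wait for more requests (assumed value)
MAX_TOKENS = 2048    # token budget per "job launch" (assumed value)

def collect_batch(request_queue: Queue):
    """Collect requests until TIMEOUT expires or the token budget is reached."""
    batch, num_tokens = [], 0
    deadline = time.time() + TIMEOUT
    while time.time() < deadline and num_tokens < MAX_TOKENS:
        try:
            req = request_queue.get(timeout=max(deadline - time.time(), 1e-3))
        except Empty:
            break
        batch.append(req)
        num_tokens += len(req["input_ids"])   # hypothetical request schema
    return batch                              # sent to the backend as one job
```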

There are potentially other and better mechanisms for batching user requests that arrive stochastically; I'll post several related papers later.

Dynamically batched computation

Suppose we are given a batch of sentences to decode. Each sentence in the batch has a different length and different user-requested parameters (top_p, beam_size, max_decoding_length, etc.).

How should we enable batched computation of these sentences?

Our current inference system status:

  • We compile one executable that runs with batch_size = 1.
  • Given the sentence batch, we perform inference one sentence at a time.

Some rough ideas for improvement:

  • Compile executables for different batch sizes (batch_size = 1, 2, 4, etc.).
  • Design a batching mechanism that (1) groups sentences of similar lengths into multiple running batches, and (2) dispatches these batches to the executables compiled with the closest batch size, padding where necessary (see the sketch after this list).
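
A rough sketch of that dispatch logic, assuming a fixed set of pre-compiled batch sizes and a hypothetical request schema:

```python
COMPILED_BATCH_SIZES = (1, 2, 4, 8)   # assumed set of pre-compiled executables

def pick_batch_size(n):
    # Smallest compiled batch size that can hold n sentences (capped at the largest).
    for b in COMPILED_BATCH_SIZES:
        if b >= n:
            return b
    return COMPILED_BATCH_SIZES[-1]

def make_batches(sentences, bucket_width=64):
    # (1) Group sentences of similar length into buckets.
    buckets = {}
    for s in sentences:
        buckets.setdefault(len(s["input_ids"]) // bucket_width, []).append(s)

    # (2) Split each bucket into chunks matching a compiled batch size, padding short chunks.
    jobs = []
    for group in buckets.values():
        while group:
            bs = pick_batch_size(len(group))
            chunk, group = group[:bs], group[bs:]
            chunk += [None] * (bs - len(chunk))   # dummy slots to pad up to bs
            jobs.append((bs, chunk))
    return jobs
```
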
@merrymercy (Member) commented:

Notes on possible improvements:

  • cache should be donated (see the sketch below)
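
In JAX terms, donating the cache argument to the jitted decode step lets XLA reuse its buffer for the updated cache instead of allocating a fresh one every step. A minimal sketch with placeholder shapes and update logic:

```python
from functools import partial

import jax
import jax.numpy as jnp

MAX_LEN, HIDDEN, VOCAB = 2048, 2560, 50272     # illustrative OPT-2.7B-like sizes

@partial(jax.jit, donate_argnums=(1,))          # donate the cache buffer
def decode_step(token_embed, kv_cache, step):
    # Placeholder cache update and logits; the real step runs the full model.
    new_cache = kv_cache.at[step].set(token_embed)
    logits = jnp.zeros((VOCAB,))
    return logits, new_cache

cache = jnp.zeros((MAX_LEN, HIDDEN))
logits, cache = decode_step(jnp.ones(HIDDEN), cache, 0)   # reuse `cache` next step
```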

@zhuohan123 (Member) commented:

A small trick to boost performance: in beam search, all beam branches share the same prompt context, so the prompt part of the cache doesn't need to be shuffled at every timestep. This reduces the shuffle overhead and saves some memory.
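
A sketch of that layout: split the cache into a shared prompt part that is never reordered and a per-beam decode part that is gathered by parent-beam index. Names and shapes are illustrative:

```python
import jax.numpy as jnp

def reorder_beams(prompt_cache, decode_cache, parent_beam):
    # prompt_cache: (1, prompt_len, heads, dim)     -- shared by all beams, never shuffled
    # decode_cache: (beam, decoded_len, heads, dim) -- only this part is gathered
    return prompt_cache, jnp.take(decode_cache, parent_beam, axis=0)
```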

@merrymercy merrymercy changed the title Inference performance (for OPT/GPT) roadmap [Roadmap] Inference performance (for OPT/GPT) Aug 6, 2022
@merrymercy merrymercy added the enhancement New feature label Dec 20, 2022