This repository has been archived by the owner on Oct 19, 2024. It is now read-only.
To achieve state-of-the-art serving performance on OPT/GPT, we need to develop the following features, sorted by priority.
Task 1: Align single-GPU decoding performance with FasterTransformer.
Task 1.1: generate kernel performance table
In terms of single-GPU decoding performance, on an OPT model with 2.7B parameters with bs=1, JAX achieves ~65 tokens/s on an A100, while FT achieves ~78 tokens/s on a V100.
Many of our inference components are under-optimized. Several TODOs:
Compare our kernel time vs. FT kernel time and identify the gap in kernel implementations.
Port several of the handcrafted kernels used in FT into XLA, e.g., the masked attention kernel hand-crafted in FT.
Task 1.2: Optimizing the masked attention computation
Our masked attention performs unnecessary computation: at each decoding step, we attend to the full seq_len instead of only the previous tokens. This is due to the static compiler design. We need to:
Benchmark how much this full attention costs in real inference.
If it is substantial (>20% of the overall inference cost), find ways to reduce the extra computation as much as possible within a static compiler architecture.
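To make the extra work in Task 1.2 concrete, here is a minimal sketch in plain Python (not JAX, so it stays self-contained). The function names `attend` and `wasted_fraction` are hypothetical and not part of our codebase. It shows that masking over a fixed seq_len gives the same result as attending only to the valid prefix, and it estimates what fraction of the scored key positions is masked out over a decoding run.

```python
import math


def attend(query, keys, values, valid_len=None):
    """Single-query dot-product attention over plain Python lists.

    If valid_len is given, scores for positions >= valid_len are still
    computed and then masked to -inf, mimicking a static-shape executable.
    valid_len must be at least 1.
    """
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    if valid_len is not None:
        scores = [s if i < valid_len else -math.inf for i, s in enumerate(scores)]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # masked positions become 0.0
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def wasted_fraction(seq_len, prompt_len):
    """Fraction of scored key positions that are masked out during decoding.

    Assumes one decoding step for each context length t in
    [prompt_len, seq_len), where only t positions are useful but
    seq_len positions are scored.
    """
    steps = seq_len - prompt_len
    useful = sum(range(prompt_len, seq_len))
    return 1.0 - useful / (seq_len * steps)
```

Under these assumptions, `wasted_fraction(2048, 512)` comes out at a bit over one third of the scored positions. That only counts attention scores, so the share of end-to-end time still has to be measured.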
Task 1.3: Use training-like computation on the prompt (prefix)
We should compile at least two executables: one that performs a training-like forward computation over the whole prompt, and one for decoding. Currently, we run a decoding step for every token in the prompt, which is slow.
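The two-executable idea can be sketched with a toy stand-in for the compiled model. Everything here (`ToyDecoder`, `prefill`, `decode_step`, the fake "cache") is hypothetical and only counts executable launches; it is not how our runtime is structured.

```python
class ToyDecoder:
    """Stand-in for a compiled model that counts executable launches."""

    def __init__(self):
        self.launches = 0

    @staticmethod
    def _next_token(cache):
        # Toy rule standing in for the real model output.
        return sum(cache) % 50

    def prefill(self, prompt_ids):
        """Training-like forward pass: consume the whole prompt in one launch."""
        self.launches += 1
        cache = list(prompt_ids)  # toy "KV cache": the tokens seen so far
        return cache, self._next_token(cache)

    def decode_step(self, cache, token_id):
        """Decoding executable: consume exactly one token per launch."""
        self.launches += 1
        cache.append(token_id)
        return self._next_token(cache)


def generate_token_by_token(model, prompt_ids, n_new):
    """Current behavior: one decoding launch per prompt token."""
    cache, nxt = [], None
    for tok in prompt_ids:
        nxt = model.decode_step(cache, tok)
    out = []
    for i in range(n_new):
        out.append(nxt)
        if i + 1 < n_new:
            nxt = model.decode_step(cache, nxt)
    return out


def generate_with_prefill(model, prompt_ids, n_new):
    """Proposed behavior: one prefill launch, then one launch per new token."""
    cache, nxt = model.prefill(prompt_ids)
    out = []
    for i in range(n_new):
        out.append(nxt)
        if i + 1 < n_new:
            nxt = model.decode_step(cache, nxt)
    return out
```

In this toy, both paths return the same tokens, but the prompt costs one launch instead of one launch per prompt token.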
Task 2: Align intra-op parallelism performance in decoding with FasterTransformer.
Assuming Task 1 is done, we then align the performance with FT when auto-sharding is turned on.
As a reference, on the 2.7B model with bs = 1:
FT achieves a 2.8x latency reduction when using 8 V100s, compared to 1 V100.
We see a 1.3x latency increase when using 2 A100s, compared to 1 A100.
This might be a simple fallback in the auto-sharding solver at batch size = 1, or a more serious problem. We need to fix it and match or outperform FT in terms of latency reduction when adding more GPUs.
Task 3: Enable full inter-op parallelism and #mb > 1
3.1 Adding some basic features
Currently, we can do device-placement-like inter-op parallelism at #microbatches = 1. We need additional engineering work to support full inter-op parallelism and #mb > 1.
This development may not improve latency but will boost throughput significantly.
3.2 Reduce ray scheduling overheads as much as possible
We have identified many Ray scheduling overheads. They are not substantial in training but become critical in inference. We should reduce these scheduling overheads as much as possible.
I have enumerated all sources of Ray overhead that I can think of below. We need to benchmark the severity of these overheads and see which ones are substantial in our targeted use cases (e.g., 175B, long decoding steps, long prompts, etc.), because not all of them are easy to hide.
Repeatedly calling shard_args at each decoding step. Note that in training, shard_args is called only once per iteration.
Transferring decoded tokens from the last stage, which owns the decoding layer, to the stage that owns the input layer (when inter-op parallelism is on).
Hugging Face-based tokenization, top-p sampling, and beam search: are they sufficiently efficient?
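One way to benchmark the overhead sources listed above is a small per-section timer wrapped around each part of a decoding step. The sketch below uses only the Python standard library; `StepProfiler` and the section names are hypothetical, and the workloads in the usage note are placeholders rather than real Ray or shard_args calls.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StepProfiler:
    """Accumulates wall-clock time per named section across decoding steps."""

    def __init__(self):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def section(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self):
        """Return (name, total_seconds, fraction_of_measured_time), largest first."""
        grand = sum(self.totals.values()) or 1.0
        rows = [(name, t, t / grand) for name, t in self.totals.items()]
        return sorted(rows, key=lambda row: row[1], reverse=True)
```

Usage would look like `with profiler.section("shard_args"): ...` around each candidate overhead inside the decoding loop, followed by `profiler.report()` after a full generation. For asynchronous device work, the timed region has to include a synchronization point, otherwise the numbers only reflect launch time.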
Task 4: Support beam search, output scores/hidden states, and other related text generation utilities
Currently, we only support top-p sampling with a single output sentence. We need to support beam_size > 1.
Related near-complete PR for beam search: #491
Task 5: Study and improve the current batching mechanism
The batching in serving is complicated because it requires at least two levels of batching:
Batching incoming requests
Currently, we set two thresholds:
TIMEOUT: a time window
MAX_TOKENS: a maximal number of tokens allowed to be processed in a "job launch".
The server listens for requests until either TIMEOUT elapses or the collected tokens reach MAX_TOKENS, and then sends them as one job to the backend for computation.
There are potentially other, better mechanisms for batching user requests that arrive stochastically; I'll post several related papers later.
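The TIMEOUT / MAX_TOKENS policy described above can be sketched as follows. This is a minimal illustration using the standard library, not the server's actual code; `collect_batch` and its parameter names are hypothetical, and requests are assumed to be lists of token ids arriving on a `queue.Queue`.

```python
import queue
import time


def collect_batch(requests, timeout_s, max_tokens):
    """Collect requests for one "job launch".

    Stops when timeout_s has elapsed or when the collected token count
    reaches max_tokens. The last request may push the total slightly past
    max_tokens; a stricter policy would hold it back for the next job.
    """
    batch, n_tokens = [], 0
    deadline = time.monotonic() + timeout_s
    while n_tokens < max_tokens:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            request = requests.get(timeout=remaining)
        except queue.Empty:
            break  # the time window closed while waiting
        batch.append(request)
        n_tokens += len(request)
    return batch
```

With a small timeout_s, latency stays low but jobs are small; with a large max_tokens, throughput improves at the cost of waiting longer under light load.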
Dynamically batched computation
Suppose we are given a batch of sentences to decode. The sentences in the batch have different lengths and different user-requested parameters (top_p, beam_size, max_decoding_length, etc.).
How should we enable batched computation of these sentences?
Our current inference system status:
We compile one executable, which runs with batch_size = 1.
Given the sentence batch, we perform inference on one sentence at a time.
Some rough ideas for improvement:
Compile executables for different batch sizes: batch_size = 1, 2, 4, etc.
Design a batching mechanism that (1) groups sentences of similar lengths into multiple running batches, and (2) dispatches these batches to the executable compiled with the closest batch size (this requires some padding).
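The two rough ideas above can be sketched as a simple planner. This is only an illustration under stated assumptions: `closest_batch_size` and `plan_batches` are hypothetical names, sentences are lists of token ids, "closest" is taken to mean the smallest compiled size that fits, and per-request parameters such as top_p or beam_size are ignored.

```python
def closest_batch_size(n, compiled_sizes):
    """Smallest compiled batch size that fits n sentences, else the largest."""
    for size in sorted(compiled_sizes):
        if size >= n:
            return size
    return max(compiled_sizes)


def plan_batches(sentences, compiled_sizes=(1, 2, 4), pad_id=0):
    """Group sentences of similar length and pad each group for an executable.

    Returns a list of (batch_size, rows), where rows has exactly batch_size
    entries of equal length. Short sentences are right-padded with pad_id,
    and missing rows are filled with all-padding dummy sentences.
    """
    max_size = max(compiled_sizes)
    ordered = sorted(sentences, key=len)  # neighbors have similar lengths
    plans = []
    for start in range(0, len(ordered), max_size):
        group = ordered[start:start + max_size]
        size = closest_batch_size(len(group), compiled_sizes)
        width = max(len(s) for s in group)
        rows = [list(s) + [pad_id] * (width - len(s)) for s in group]
        rows += [[pad_id] * width for _ in range(size - len(group))]
        plans.append((size, rows))
    return plans
```

Sorting by length before chunking keeps padding within a group small; a real scheduler would also need to track which rows are dummies and map outputs back to the original request order.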
A small trick to boost performance: in beam search, all beam branches share the same prompt context, so that part does not need to be shuffled at every timestep. This both reduces the shuffle overhead and saves some memory.
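A minimal sketch of this trick, assuming a hypothetical `BeamCache` structure where cache entries are plain Python objects: the prompt part is stored once and shared, and only the per-beam generated suffixes are reordered after each beam selection.

```python
class BeamCache:
    """Toy cache layout: one shared prompt part plus one suffix per beam."""

    def __init__(self, prompt_cache, beam_size):
        self.prompt = list(prompt_cache)  # stored once, never shuffled
        self.suffixes = [[] for _ in range(beam_size)]

    def append(self, entries):
        """Add one new cache entry per beam after a decoding step."""
        for suffix, entry in zip(self.suffixes, entries):
            suffix.append(entry)

    def reorder(self, parent_ids):
        """Beam i continues from beam parent_ids[i]; only suffixes are copied."""
        self.suffixes = [list(self.suffixes[p]) for p in parent_ids]

    def context(self, beam):
        """Full context seen by one beam: shared prompt plus its own suffix."""
        return self.prompt + self.suffixes[beam]

    def stored_entries(self):
        """Number of cache entries held, counting the prompt only once."""
        return len(self.prompt) + sum(len(s) for s in self.suffixes)
```

Compared with keeping a full copy of the context per beam, the stored size grows with prompt_len + beam_size * generated_len instead of beam_size * (prompt_len + generated_len), and each reorder touches only the generated part.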
For Task 4, one good reference is OpenAI's text generation interface and its features.