# Speculative Decoding

Speculative decoding describes a set of methods for speeding up next token generation for autoregressive language models
by attempting to "guess" the next N tokens of the base model. These guesses can be generated in a number of different ways,
including:

- An additional, smaller "draft" model (e.g., Llama-70b and Llama-7b)
- An adapter that extends the sequence dimension of the logits (e.g., Medusa)
- A heuristic (e.g., looking for recurring sequences in the prompt)

LoRAX implements some of these approaches, with a particular emphasis on supporting adapter-based methods like Medusa
that can be applied per request for task-level speedups.

## Process

Almost all of the above speculative decoding methods consist of the same two phases: a "draft" phase that generates
candidate tokens and a "verification" phase that accepts some subset of the candidates to add to the response.

### Draft

For methods other than assisted generation via a draft model, the *draft step* happens at the end of the normal next token
selection phase, after the logits have been generated. Given the logits for the next token and all the tokens that have been
processed previously (input or output), a number of speculative tokens are generated and added to the batch state
for verification in the next inference step.
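
A minimal sketch of what a draft step produces (illustrative only, not LoRAX internals), with a hypothetical
`propose_draft` standing in for whichever drafting strategy is configured (a Medusa head, a smaller draft model, a
heuristic, etc.):

```python
def propose_draft(token_history, num_speculative):
    # Placeholder for the configured drafting strategy; a real method would use
    # a Medusa head, a smaller draft model, or an n-gram heuristic.
    # Here we naively repeat the last token.
    return [token_history[-1]] * num_speculative

def draft_step(token_history, next_token_id, num_speculative=3):
    # The token chosen by the base model this step is committed as usual...
    committed = token_history + [next_token_id]
    # ...and the draft strategy proposes candidates to verify in the next step.
    speculative = propose_draft(committed, num_speculative)
    return committed, speculative

committed, speculative = draft_step([1, 5, 9], next_token_id=7)
print(committed, speculative)  # [1, 5, 9, 7] [7, 7, 7]
```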

### Verification

Once the speculative logits have been generated, a separate *verification step* is performed whereby the most likely next `S` tokens
are passed through the model again (as part of the normal decoding process) to check for correctness. If any prefix of the `S` tokens
is deemed *correct*, those tokens can be appended to the response directly. The remaining incorrect speculative tokens are discarded.
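
A minimal sketch of greedy verification (illustrative, not LoRAX internals): the longest prefix of the speculative
tokens that matches the base model's own greedy choices is accepted, and the rest is thrown away.

```python
def verify(speculative_tokens, model_choices):
    # model_choices[i] is the token the base model itself picks at the position
    # where speculative_tokens[i] was proposed (computed in one forward pass).
    accepted = []
    for proposed, chosen in zip(speculative_tokens, model_choices):
        if proposed != chosen:
            break
        accepted.append(proposed)
    return accepted  # speculative_tokens[len(accepted):] are discarded

print(verify([7, 4, 2], [7, 4, 9]))  # [7, 4] -> two extra tokens from one decoding step
```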

Note that this process adds some compute overhead to the normal decoding step. As such, it will only confer benefits when:

1. The decoding step is *memory bound* (generally true for most LLMs on modern GPUs).
2. The speculation process is able to consistently predict future tokens correctly.
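
As a rough back-of-envelope on point 2 (an illustrative assumption, not a measured LoRAX benchmark): if each
speculative token is accepted independently with probability `p`, then with `S` speculative tokens the expected
number of tokens emitted per decoding step is `1 + p + p^2 + ... + p^S`.

```python
# Hypothetical acceptance model: independent per-token acceptance probability p.
def expected_tokens_per_step(p, s):
    return sum(p**i for i in range(s + 1))

print(expected_tokens_per_step(0.7, 3))  # ~2.53 tokens/step vs. 1 without speculation
```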

## Options

### Medusa

See the [Medusa](../models/adapters/medusa.md) guide for details on how this method works and how to use it.

### Prompt Lookup Decoding

[Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file) is a simple
heuristic method that uses string matching on the input + previously generated tokens to find candidate n-grams. This
method is particularly useful if your generation task will reuse many similar phrases from the input (e.g., in
retrieval augmented generation where citing the input is important). If there is no need to repeat anything from the
input, there will be no speedup and performance may decrease.
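
A minimal sketch of the idea (simplified from the approach in the linked repository; illustrative only): find the
most recent earlier occurrence of the trailing n-gram and propose the tokens that followed it as the continuation.

```python
def prompt_lookup(tokens, ngram_size=3, num_speculative=3):
    # Match the trailing n-gram against earlier positions in the sequence
    # (input + tokens generated so far), scanning from most recent to oldest.
    pattern = tokens[-ngram_size:]
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == pattern:
            continuation = tokens[start + ngram_size:start + ngram_size + num_speculative]
            if continuation:
                return continuation  # speculative candidates for verification
    return []  # no match: nothing to speculate this step

# The earlier occurrence of token 10 was followed by 11, 12, 13.
print(prompt_lookup([10, 11, 12, 13, 10], ngram_size=1))  # [11, 12, 13]
```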

#### Usage

Launch LoRAX with the `--speculative-tokens` parameter. This controls the length of the sequence LoRAX will attempt
to match against in the input and suggest as the continuation of the current token:

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \
    ghcr.io/predibase/lorax:main \
    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
    --speculative-tokens 3
```

Increasing this value will yield greater speedups when there are long common sequences, but will slow things down if there
is little overlap.
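
Requests are then issued exactly as usual; nothing changes client-side. For example, a minimal sketch using the
Python `requests` package (assuming the port mapping above and the standard `/generate` route):

```python
import requests

# Plain text-generation request; speculation is applied server-side transparently.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Summarize the following document:\n...",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=60,
)
print(resp.json()["generated_text"])
```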

Note that this method is not compatible with per-request Medusa adapters.

# LoRA

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a popular adapter method for fine-tuning models to improve response quality.

LoRAX supports LoRA adapters trained using frameworks like [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/).

## How it works

``` mermaid
graph BT
  I{{X}} --> W;
  I --> A[/LoRA A\];
  A --> B[\LoRA B/];
  W --> P((+));
  B --> P;
  P --> O{{Y}}
```

LoRA works by targeting specific layers of the base model and inserting a new low-rank pair of weights `LoRA A` and `LoRA B` alongside each base model
param `W`. The input `X` is passed through both the original weights and the LoRA weights, and then the activations are summed together
to produce the final layer output `Y`.
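
A minimal sketch of this forward pass for a single linear layer (illustrative shapes, not LoRAX internals; the usual
`alpha / r` scaling on the LoRA path is omitted for brevity):

```python
import numpy as np

d_in, d_out, rank = 16, 16, 4
W = np.random.randn(d_out, d_in)         # frozen base model weight
A = np.random.randn(rank, d_in) * 0.01   # LoRA A: low-rank down-projection
B = np.zeros((d_out, rank))              # LoRA B: up-projection (zero-initialized)

X = np.random.randn(d_in)                # layer input
Y = W @ X + B @ (A @ X)                  # base path + LoRA path, summed
print(Y.shape)                           # (16,)
```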

## Usage

### Supported Target Modules

When training a LoRA adapter, you can specify which of the model's layers (or "modules") you wish to target for adaptation. Typically
these are the projection layers in the attention blocks (`q` and `v`, sometimes `k` and `o` as well for LLaMA-like models), but they can
usually be any linear layer. (See the PEFT sketch after the per-architecture lists below for how this is specified at training time.)

Here is a list of supported target modules for each architecture in LoRAX. If your adapter contains target
modules that LoRAX does not support, LoRAX will ignore those layers and emit a warning on the backend.

#### Llama

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Mistral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Mixtral

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `lm_head`

#### Gemma

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`

#### Phi-3

- `qkv_proj`
- `o_proj`
- `gate_up_proj`
- `down_proj`
- `lm_head`

#### Phi-2

- `q_proj`
- `k_proj`
- `v_proj`
- `dense`
- `fc1`
- `fc2`
- `lm_head`

#### Qwen2

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### Qwen

- `c_attn`
- `c_proj`
- `w1`
- `w2`
- `lm_head`

#### Command-R

- `q_proj`
- `k_proj`
- `v_proj`
- `o_proj`
- `gate_proj`
- `up_proj`
- `down_proj`
- `lm_head`

#### DBRX

- `Wqkv`
- `out_proj`
- `lm_head`

#### GPT2

- `c_attn`
- `c_proj`
- `c_fc`

#### Bloom

- `query_key_value`
- `dense`
- `dense_h_to_4h`
- `dense_4h_to_h`
- `lm_head`
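
For example, when training with [PEFT](https://github.com/huggingface/peft), the modules to adapt are passed via
`target_modules` on the `LoraConfig`; a minimal sketch using the Llama/Mistral-style names listed above (rank and
dropout are illustrative values):

```python
from peft import LoraConfig

# Target the attention projections; any supported linear layer name could be listed.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
```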

## How to train

LoRA is a very popular fine-tuning method for LLMs, and as such there are a number of options for creating adapters
from your data, including the following (non-exhaustive) list.

### Open Source

- [PEFT](https://github.com/huggingface/peft)
- [Ludwig](https://ludwig.ai/)
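
With PEFT, for instance, wrapping a base model with a `LoraConfig` like the one shown earlier takes a few lines; a
minimal sketch (model choice and hyperparameters are illustrative, and the actual training loop is omitted):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")

model = get_peft_model(base, config)   # only the LoRA weights are trainable
model.print_trainable_parameters()     # e.g. "trainable params: ... || all params: ..."
# ...train with your framework of choice, then model.save_pretrained("./my-adapter")
```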

### Commercial

- [Predibase](https://predibase.com/)