This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Sparsity #1

Merged
12 commits merged on Feb 1, 2024
133 changes: 31 additions & 102 deletions README.md
@@ -1,112 +1,41 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
</picture>
</p>
## Neural Magic vLLM

<h3 align="center">
Easy, fast, and cheap LLM serving for everyone
</h3>
A fork of vLLM that adds support for sparse model weights.

<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
### To Run

</p>

---

**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**

We are thrilled to announce our second vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
Please register [here](https://lu.ma/ygxbpzhl) and join us!

---

*Latest News* 🔥
- [2023/12] Added ROCm support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

---
## About
vLLM is a fast and easy-to-use library for LLM inference and serving.

vLLM is fast with:

- State-of-the-art serving throughput
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
- Optimized CUDA kernels

vLLM is flexible and easy to use with:

- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs and AMD GPUs

vLLM seamlessly supports many Hugging Face models, including the following architectures:

- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
Clone and install magic_wand:

```bash
pip install vllm
git clone https://github.com/neuralmagic/magic_wand.git
cd magic_wand
export TORCH_CUDA_ARCH_LIST=8.6
pip install -e .
```
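The `TORCH_CUDA_ARCH_LIST` value above targets compute capability 8.6 (Ampere consumer and workstation GPUs such as the RTX 30xx series or A10). If you build on different hardware, set it to match your GPU before building; for example (an untested sketch, adjust to your device):

```bash
# A100 is compute capability 8.0; H100 is 9.0. Pick the value for your GPU
# before running `pip install -e .` in magic_wand.
export TORCH_CUDA_ARCH_LIST=8.0
```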

## Getting Started

Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)

## Contributing
Install this repository:
```bash
cd ../
pip install -e .
```
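As a quick sanity check after both installs (a minimal sketch; the `magic_wand` module name is assumed from the package name):

```bash
python -c "import vllm, magic_wand; print(vllm.__version__)"
```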

We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
### Run Sample

## Citation
Run a 50% sparse model:

If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
```bibtex
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}
```
```python
from vllm import LLM, SamplingParams

model = LLM(
    "nm-testing/Llama-2-7b-pruned50-retrained",
    sparsity="sparse_w16a16",  # If left off, model will be loaded as dense
    enforce_eager=True,        # Does not work with cudagraphs yet
    dtype="float16",
    tensor_parallel_size=1,
    max_model_len=1024,
)

sampling_params = SamplingParams(max_tokens=100, temperature=0)
outputs = model.generate("Hello my name is", sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```
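The same setting is exposed on the command line through the new `--sparsity`/`-s` engine argument (see the `vllm/engine/arg_utils.py` diff below), so a sparse model can also be served through vLLM's OpenAI-compatible server. A sketch, assuming the standard vLLM server entrypoint:

```bash
python -m vllm.entrypoints.openai.api_server \
    --model nm-testing/Llama-2-7b-pruned50-retrained \
    --sparsity sparse_w16a16 \
    --enforce-eager \
    --max-model-len 1024
```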
111 changes: 111 additions & 0 deletions examples/offline_bench.py
@@ -0,0 +1,111 @@
import random
import time
import argparse

from vllm import LLM, SamplingParams

NUM_REQUESTS_DEFAULT = 256
MAX_SEQ_LEN_DEFAULT = 1024
MAX_TOKENS_DEFAULT = 128
SAMPLE_PROMPTS = [
    # "Hello, my name is",
    # "The president of the United States is",
    # "The capital of France is",
    "The future of AI is",
]


def run_bench(model_name,
              model_revision,
              is_sparse,
              quant_method,
              max_seq_len,
              max_tokens,
              num_requests,
              num_gpus,
              num_warmup_iters=1,
              num_bench_iters=5,
              possible_prompts=SAMPLE_PROMPTS,
              enforce_eager=True):
    print("Run bench with:")
    print(f" model_name = {model_name}")
    print(f" model_revision = {model_revision}")
    print(f" is_sparse = {is_sparse}")
    print(f" quant_method = {quant_method}")
    print(f" max_seq_len = {max_seq_len}")
    print(f" max_tokens = {max_tokens}")
    print(f" num_requests = {num_requests}")
    print(f" num_gpus = {num_gpus}")
    print(f" num_warmup_iters = {num_warmup_iters}")
    print(f" num_bench_iters = {num_bench_iters}")

    prompts = []
    for _ in range(num_requests):
        index = random.randint(0, len(possible_prompts) - 1)
        prompts.append(possible_prompts[index])

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=max_tokens)

    # Create LLM
    llm = LLM(
        model=model_name,
        revision=model_revision,
        sparsity="sparse_w16a16" if is_sparse else None,
        enforce_eager=enforce_eager,
        # dtype=torch.bfloat16,
        tensor_parallel_size=num_gpus,
        gpu_memory_utilization=0.9,
        max_model_len=max_seq_len,
        quantization=quant_method,
    )

    for i in range(num_warmup_iters):
        start_time = time.time()
        outputs = llm.generate(prompts, sampling_params)
        elapsed_time = time.time() - start_time
        print(f"Warmup iter {i} time: {elapsed_time} [secs]")

    iter_times = []
    for i in range(num_bench_iters):
        start_time = time.time()
        outputs = llm.generate(prompts, sampling_params)
        iter_times.append(time.time() - start_time)
        print(f"Bench iter {i} time: {iter_times[-1]} [secs]")

    average_iter_time = sum(iter_times) / num_bench_iters
    print(f"Average per iter time: {average_iter_time} [secs]")

    # Print outputs of the last iter
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    return average_iter_time


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--model_revision", type=str, default=None)
    parser.add_argument('--is_sparse', action='store_true')
    parser.add_argument("--quant_method", type=str, default=None)
    parser.add_argument("--max_seq_len", type=int, default=MAX_SEQ_LEN_DEFAULT)
    parser.add_argument("--max_tokens", type=int, default=MAX_TOKENS_DEFAULT)
    parser.add_argument("--num_requests",
                        type=int,
                        default=NUM_REQUESTS_DEFAULT)
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--num_warmup_iters", type=int, default=1)
    parser.add_argument("--num_bench_iters", type=int, default=5)

    args = parser.parse_args()

    run_bench(args.model_name, args.model_revision, args.is_sparse,
              args.quant_method, args.max_seq_len, args.max_tokens,
              args.num_requests, args.num_gpus, args.num_warmup_iters,
              args.num_bench_iters)
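A typical invocation of the benchmark script against the sparse checkpoint used in the README (a sketch; adjust the model name and iteration counts to your setup):

```bash
python examples/offline_bench.py \
    --model_name nm-testing/Llama-2-7b-pruned50-retrained \
    --is_sparse \
    --num_requests 64 \
    --max_tokens 128 \
    --num_gpus 1
```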
27 changes: 27 additions & 0 deletions vllm/config.py
@@ -72,6 +72,7 @@ def __init__(
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        sparsity: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
    ) -> None:
@@ -85,6 +86,7 @@ def __init__(
        self.revision = revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.sparsity = sparsity
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture

@@ -106,6 +108,7 @@ def __init__(
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_sparsity()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
@@ -144,6 +147,30 @@ def _verify_tokenizer_mode(self) -> None:
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_sparsity(self) -> None:
        supported_sparsity = ["sparse_w16a16"]

        if self.quantization is not None:
            raise ValueError("Both sparsity and quantization detected. Only "
                             "one or the other is supported at a time.")

        if self.sparsity is not None and self.sparsity not in supported_sparsity:
            raise ValueError(f"Unknown sparse method: {self.sparsity}. Must "
                             f"be one of {supported_sparsity}.")

        hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None)
        if hf_sparsity_config is not None:
            hf_sparsity_method = str(
                hf_sparsity_config["sparse_method"]).lower()
            if self.sparsity is None:
                self.sparsity = hf_sparsity_method
            elif self.sparsity != hf_sparsity_method:
                raise ValueError(
                    "Sparsity method specified in the model config "
                    f"({hf_sparsity_method}) does not match the sparsity "
                    f"method specified in the `sparsity` argument "
                    f"({self.sparsity}).")

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "gptq", "squeezellm"]
        rocm_not_supported_quantization = ["awq"]
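The net effect of `_verify_sparsity` is a small precedence rule: an explicit `sparsity` argument must be one of the supported methods and must agree with any `sparsity_config.sparse_method` found in the Hugging Face model config; if no argument is given, the value from the model config is adopted. A standalone sketch of that rule (illustrative only, using a stand-in namespace rather than a real `ModelConfig`):

```python
from types import SimpleNamespace

SUPPORTED_SPARSITY = ["sparse_w16a16"]

def resolve_sparsity(arg_sparsity, hf_config):
    """Illustrative mirror of ModelConfig._verify_sparsity's resolution logic."""
    if arg_sparsity is not None and arg_sparsity not in SUPPORTED_SPARSITY:
        raise ValueError(f"Unknown sparse method: {arg_sparsity}")
    hf_cfg = getattr(hf_config, "sparsity_config", None)
    if hf_cfg is None:
        return arg_sparsity
    hf_method = str(hf_cfg["sparse_method"]).lower()
    if arg_sparsity is None:
        return hf_method  # model config wins when no argument is given
    if arg_sparsity != hf_method:
        raise ValueError("sparsity argument disagrees with the model config")
    return arg_sparsity

# Checkpoint whose config.json carries a sparsity_config stanza (layout assumed).
cfg = SimpleNamespace(sparsity_config={"sparse_method": "SPARSE_W16A16"})
print(resolve_sparsity(None, cfg))             # -> "sparse_w16a16"
print(resolve_sparsity("sparse_w16a16", cfg))  # -> "sparse_w16a16"
```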
24 changes: 17 additions & 7 deletions vllm/engine/arg_utils.py
@@ -33,6 +33,7 @@ class EngineArgs:
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    sparsity: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    enable_lora: bool = False
@@ -197,6 +198,16 @@ def add_cli_args(
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
        parser.add_argument(
            '--sparsity',
            '-s',
            type=str,
            choices=['sparse_w16a16', None],
            default=None,
            help='Method used to compress sparse weights. If '
            'None, we first check the `sparsity_config` attribute '
            'in the model config file. If that is None we assume '
            'the model weights are dense')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
@@ -255,13 +266,12 @@ def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               Optional[LoRAConfig]]:
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.revision,
                                   self.tokenizer_revision, self.max_model_len,
                                   self.quantization, self.enforce_eager,
                                   self.max_context_len_to_capture)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.tokenizer_revision,
            self.max_model_len, self.quantization, self.sparsity,
            self.enforce_eager, self.max_context_len_to_capture)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space,
1 change: 1 addition & 0 deletions vllm/engine/llm_engine.py
@@ -83,6 +83,7 @@ def __init__(
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"quantization={model_config.quantization}, "
            f"sparsity={model_config.sparsity}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.
7 changes: 7 additions & 0 deletions vllm/entrypoints/llm.py
@@ -43,6 +43,11 @@ class LLM:
            the `quantization_config` attribute in the model config file. If
            that is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        sparsity: The format of the sparse model weights. Currently,
            we support "sparse_w16a16". If None, we first check the `sparsity`
            attribute in the model config file. If that is None, we assume the
            model weights are dense and use `dtype` to determine the data
            type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
@@ -75,6 +80,7 @@ def __init__(
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        sparsity: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
@@ -94,6 +100,7 @@ def __init__(
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            sparsity=sparsity,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,