From 95a4bb696a2388f34d50c6873ff93911d6b235de Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 27 Sep 2023 11:42:57 +0200
Subject: [PATCH] Support eetq weight only quantization (#1068)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

---------

Co-authored-by: zhaosida
---
 Cargo.lock                                    |  8 ++--
 launcher/src/main.rs                          | 32 +++++++++++--
 server/.gitignore                             |  1 +
 server/Makefile                               |  1 +
 server/Makefile-eetq                          | 13 +++++
 server/text_generation_server/cli.py          |  1 +
 server/text_generation_server/utils/layers.py | 47 ++++++++++++++++++-
 7 files changed, 92 insertions(+), 11 deletions(-)
 create mode 100644 server/Makefile-eetq

diff --git a/Cargo.lock b/Cargo.lock
index 5fa01f0c070..8fa7b726665 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2896,18 +2896,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.48"
+version = "1.0.49"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7"
+checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.48"
+version = "1.0.49"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
+checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 09e32f8944b..b4fc86b7bee 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -21,11 +21,32 @@ mod env_runtime;
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Quantization {
+    /// 4 bit quantization. Requires a specific AWQ quantized model:
+    /// https://hf.co/models?search=awq.
+    /// Should replace GPTQ models wherever possible because of the better latency
+    Awq,
+    /// 8 bit quantization, doesn't require a specific model.
+    /// Should be a drop-in replacement for bitsandbytes with much better performance.
+    /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
+    Eetq,
+    /// 4 bit quantization. Requires a specific GPTQ quantized model: https://hf.co/models?search=gptq.
+    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
+    /// triton kernels (wider support) when it's not.
+    /// AWQ has faster kernels.
+    Gptq,
+    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
+    /// but it is known that the model will be much slower to run than the native f16.
+    #[deprecated(
+        since = "1.1.0",
+        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    )]
     Bitsandbytes,
+    /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
+    /// but it is known that the model will be much slower to run than the native f16.
     BitsandbytesNF4,
+    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
+    /// perplexity performance for your model
     BitsandbytesFP4,
-    Gptq,
-    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -47,6 +68,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
         }
     }
 }
@@ -127,9 +151,7 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
-    /// quantization on the fly, or `gptq`. 4bit quantization is available through
-    /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
+    /// Whether you want the model to be quantized.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
diff --git a/server/.gitignore b/server/.gitignore
index 40c37f48b9e..dcb8fe6743b 100644
--- a/server/.gitignore
+++ b/server/.gitignore
@@ -160,3 +160,4 @@ flash-attention/
 flash-attention-v2/
 vllm/
 llm-awq/
+eetq/
diff --git a/server/Makefile b/server/Makefile
index b21d79d49fd..52543e3d215 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-awq
+include Makefile-eetq
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
diff --git a/server/Makefile-eetq b/server/Makefile-eetq
new file mode 100644
index 00000000000..5e8e9830e8a
--- /dev/null
+++ b/server/Makefile-eetq
@@ -0,0 +1,13 @@
+eetq_commit := 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit)
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 7464934f6bc..cf9596c9148 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -18,6 +18,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index fb27764cd41..14cb55cc96c 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -5,6 +5,8 @@
 from torch import nn
 from torch.nn import functional as F
 from typing import List
+from loguru import logger
+from functools import lru_cache
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -42,6 +44,13 @@
 from typing import Optional
 
+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+
+    HAS_EETQ = True
+except ImportError:
+    pass
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -120,6 +129,30 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)
 
 
+class EETQLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+        if bias is not None:
+            bias = bias.cuda(device)
+        else:
+            bias = None
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
@@ -211,10 +244,20 @@ def forward(self, x: torch.Tensor):
         return out
 
 
+@lru_cache(1)
+def warn_deprecate_bnb():
+    logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance")
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
     elif quantize == "bitsandbytes":
+        warn_deprecate_bnb()
         linear = Linear8bitLt(
             weight,
             bias,
@@ -298,8 +341,8 @@ def load(config, prefix: str, weights):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ and AWQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq"]:
+        # GPTQ, AWQ, and EETQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq", "eetq"]:
             quantize = None
         else:
             quantize = config.quantize
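
Usage note (not part of the patch): end to end, the feature is enabled with the new `--quantize eetq` launcher flag, after installing the kernels with `make install-eetq` in `server/`. Because EETQ quantizes the fp16 weights on the fly, no special checkpoint is required, which is what makes it a drop-in replacement for `bitsandbytes`. Below is a minimal, illustrative sketch of how a reviewer could exercise the new `EETQLinear` path in isolation; it assumes EETQ is installed and a CUDA GPU is available, and the layer size and tolerance check are arbitrary.

```python
# Illustrative sketch only, not part of the patch.
import torch
from torch import nn

from text_generation_server.utils.layers import get_linear

torch.manual_seed(0)

# A regular fp16 linear layer, standing in for a weight loaded from a checkpoint.
linear = nn.Linear(4096, 4096, bias=True).half().cuda()

# get_linear is the entry point the models go through; with quantize="eetq"
# it wraps the weight in the new EETQLinear layer, which quantizes the weight
# to int8 on the fly and keeps activations in fp16.
eetq_linear = get_linear(linear.weight.data, linear.bias.data, quantize="eetq")

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    ref = linear(x)
    out = eetq_linear(x)

# Weight-only int8 quantization should stay close to the fp16 reference.
print((out - ref).abs().max())
```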