Feature Request: Implement Static Cache and Quantization Techniques in CTranslate2 #1717
Comments
Hello, thank you for the information. I will take a look at how HQQ works. Regarding the cache: the implementation in CTranslate2 reallocates the cache depending on the sequence length. I'm not sure how much a static cache (as I understand it, a cache pre-allocated up front at the maximum size) would speed up CTranslate2; some benchmarks are needed to confirm it, and some design changes would be required.
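For context, here is a minimal Python sketch of the difference being discussed. It is purely illustrative: CTranslate2's cache lives in C++, and the class and method names below are hypothetical.

import torch

class DynamicKVCache:
    """Grows with the sequence: every append concatenates into a new buffer."""
    def __init__(self, num_heads, head_dim, dtype=torch.float16):
        self.k = torch.empty(0, num_heads, head_dim, dtype=dtype)
        self.v = torch.empty(0, num_heads, head_dim, dtype=dtype)

    def append(self, k_new, v_new):
        # Reallocation happens here on every decoding step.
        self.k = torch.cat([self.k, k_new], dim=0)
        self.v = torch.cat([self.v, v_new], dim=0)

class StaticKVCache:
    """Pre-allocated once at max_len: appends only write into the buffer."""
    def __init__(self, max_len, num_heads, head_dim, dtype=torch.float16):
        self.k = torch.zeros(max_len, num_heads, head_dim, dtype=dtype)
        self.v = torch.zeros(max_len, num_heads, head_dim, dtype=dtype)
        self.pos = 0

    def append(self, k_new, v_new):
        n = k_new.shape[0]
        self.k[self.pos:self.pos + n] = k_new
        self.v[self.pos:self.pos + n] = v_new
        self.pos += n

The static variant keeps tensor shapes fixed, which is what makes it friendly to graph compilation, at the cost of reserving memory for the maximum sequence length up front.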
@minhthuc2502 Thanks. This blog post should help with HQQ: https://mobiusml.github.io/hqq_blog/
Hello, I tried to implement HQQ for 4-bit quantization. It works, but I think it has to be combined with a low-bit matmul kernel to speed up inference. Do you have any reference for this?
@minhthuc2502 We use the int4mm kernel from torchao: https://github.com/mobiusml/hqq/blob/master/hqq/backends/torchao.py
As I understand it, to use int4mm you have to convert the HQQ quantized weights to the format accepted by int4mm (w_q, scale, …). Do you think that will reduce performance? Similar to using the int4mm kernel, it looks like when we dequantize the weight and make the …
The conversion is only done once. Different int4 kernels require different input formats, which is why we do it via patching, so we can support many backends, not just the torchao int4 kernel. The int4mm kernel should be faster than an fp16 matmul with or without torch.compile. However, it is a GEMV kernel optimized for decoding only, so the prefill phase with this kernel is actually slower. That's why we don't use it in the encoder and only use it in the decoding phase. Unfortunately, there's no way so far to efficiently dequantize() and do an fp16 matmul in the prefill phase, which should be faster. But decoding one token at a time with int4mm is definitely faster. There are other options, by the way, like https://github.com/microsoft/BitBLAS/ ; they also support A16W2 matmul, which should be even faster for larger models.
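To illustrate the prefill/decode split described above, here is a rough, self-contained sketch. The function names and the simple group-wise dequantization are illustrative only, not CTranslate2 or torchao code; the decode branch merely emulates what a fused int4 GEMV kernel (such as torchao's int4mm) would do.

import torch

def dequantize_int4(w_q, scales, zeros, group_size=64):
    # Group-wise affine dequantization: w = (w_q - zero) * scale.
    # Assumed shapes: w_q is (out_features, in_features) with codes in [0, 15];
    # scales and zeros are (out_features, in_features // group_size).
    out_f, in_f = w_q.shape
    w_g = w_q.reshape(out_f, in_f // group_size, group_size).float()
    return ((w_g - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)).reshape(out_f, in_f)

def quantized_linear(x, w_q, scales, zeros, group_size=64):
    if x.shape[-2] > 1:
        # Prefill: many tokens at once. The int4 GEMV kernel is slow here,
        # so dequantize once and fall back to a regular matmul.
        w = dequantize_int4(w_q, scales, zeros, group_size).to(x.dtype)
        return x @ w.T
    # Decode: a single token per step. This is where a fused low-bit GEMV
    # kernel would be called; emulated here with the same dequantize path.
    w = dequantize_int4(w_q, scales, zeros, group_size).to(x.dtype)
    return x @ w.T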
@minhthuc2502 Did you figure out how to speed up the HQQ implementation in CTranslate2? This would be a useful add-on for large encoder-decoder models.
I tried to implement HQQ, quantization only, but I have not done the benchmark yet. The thing preventing me from making it work correctly is patching the HQQ format. As I understand it, I will patch the HQQ format at conversion time and then do the inference with the new format. Anyway, I'll keep trying to implement it.
I tried hqq quant in this PR. I tried using hqq quant first and then converting directly to the torch quant format to get the scales, zeros, and weights. Then, before inference, I convert the weight to int4pack and use int4mm for the matmul.
@minhthuc2502 You have to quantize with hqq. The link you shared is just doing RTN quantization, which will give bad quality, especially at lower bits.
I followed this to add HQQ: https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py. Did I miss something?
@minhthuc2502 Yes, it's not performing the optimization: https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py#L115-L122, which is the actual HQQ algorithm: https://github.com/mobiusml/hqq/blob/master/hqq/core/optimize.py#L194-L243

import torch
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from hqq.backends.torchao import patch_hqq_to_aoint4

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)
hqq_ao_layer = patch_hqq_to_aoint4(HQQLinear(your_linear_layer, quant_config=quant_config, compute_dtype=torch.bfloat16, device='cuda'), None)
out = hqq_ao_layer.forward(x)
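For readers following along, the optimization step that plain RTN skips looks roughly like this. It is a simplified paraphrase of hqq's proximal solver linked above, not the exact code; the function name and default values here are approximate.

import torch

def hqq_optimize_zero(W, scale, zero, nbits=4, axis=1,
                      beta=10.0, kappa=1.01, iters=20, lp_norm=0.7):
    # Alternately shrink the quantization error (a proximal step on an lp norm)
    # and re-fit the zero-point, instead of a single round-to-nearest pass.
    W = W.float()
    lo, hi = 0, 2 ** nbits - 1
    for _ in range(iters):
        W_q = torch.round(W * scale + zero).clamp_(lo, hi)
        W_r = (W_q - zero) / scale                      # dequantized weights
        E = W - W_r                                     # reconstruction error
        W_e = torch.sign(E) * torch.relu(E.abs() - (1.0 / beta) * E.abs().pow(lp_norm - 1))
        zero = torch.mean(W_q - (W - W_e) * scale, dim=axis, keepdim=True)
        beta *= kappa
    return scale, zero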
Thank you for your suggestion. I just use directly the …
Glad to hear it worked better! It should work fine with 4-bit and a group size of 64, as suggested in the code above. Which model did you try the summarization with? Do you have a code snippet I can run to investigate?
I tested with Llama2. You can do a simple generation with Llama2 and make a prompt like this: … The result I got: … I applied int4mm for both the prefill step and the generation step.
I tried with Llama2-7B and it's working fine:

import torch, os
cache_path = '.'
model_id = "meta-llama/Llama-2-7b-chat-hf"
compute_dtype = torch.bfloat16 #int4 kernel only works with bfloat16
device = 'cuda:0'
##########################################################################################################################################################
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import *
#Load
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)
#No quantize
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map = device)
#Quantize
model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
#Set default backends, to compare with int4mm
if quant_config['weight_quant_params']['axis'] == 0:
    HQQLinear.set_backend(HQQBackend.ATEN)
else:
    HQQLinear.set_backend(HQQBackend.PYTORCH)
##########################################################################################################################################################
from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model, backend="torchao_int4")
#Import custom HF generator
from hqq.utils.generation_hf import HFGenerator
#Generate
#gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile="partial").warmup()
gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=False, compile=None)
prompt = "Summarize this paragraph: Roger Federer (born 8 August 1981) is a Swiss former professional tennis player. Federer was ranked world No. 1 in singles by the Association of Tennis Professionals (ATP) for 310 weeks, including a record 237 consecutive weeks, and finished as the year-end No. 1 five times. He won 103 singles titles on the ATP Tour, the second most of all time, including 20 major men's singles titles (among which a record eight men's singles Wimbledon titles, and an Open Era joint-record five men's singles US Open titles) and six year-end championships."
out = gen.generate(prompt, print_tokens=True)

Outputs: …
It's weird. It seems like I used exactly the same parameters for HQQ quantization + int4mm.
@minhthuc2502 Could you please let us know if you have any updates on this?
@minhthuc2502 @alexlnkp
Description
What type of cache is currently implemented in CTranslate2? Is it static or dynamic? Could we achieve a speed-up if the cache implementation is changed for the decoder in encoder-decoder models?
Also, it would be great to implement recent popular quantization techniques such as HQQ (https://github.com/mobiusml/hqq) in the CTranslate2 format.
Motivation
Given that a static cache (see this PR) can significantly speed up processing in PyTorch encoder-decoder models via torch compilation, can we enable this in CTranslate2? This enhancement can improve decoding speed for projects utilizing CTranslate2 models, such as Faster Whisper.
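For reference, the PyTorch-side pattern looks roughly like the sketch below, following recent Hugging Face transformers releases; a decoder-only model is used for brevity, and exact option names may differ between versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # example checkpoint only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

# Ask generate() to pre-allocate a fixed-size ("static") KV cache so that
# tensor shapes stay constant across decoding steps...
model.generation_config.cache_implementation = "static"
# ...which lets torch.compile specialize and reuse the compiled decode step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))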
References
Speed-up achieved for PyTorch-based Whisper: Blog Post
Benefits
Implementing static caching and recent quantization techniques in CTranslate2 could lead to significant performance improvements in model decoding speeds and efficiency.
Thank you for considering this feature request!