diff --git a/README.md b/README.md index 7e16de7d..838521de 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,8 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up - Your CUDA version must be CUDA 11.8 or later. - AMD: - Your ROCm version must be compatible with Triton. +- Intel CPU and Intel GPU: + - Your torch and intel_extension_for_pytorch package version should at least 2.4 for optimized performance. ### Install from PyPi @@ -60,6 +62,10 @@ There are a few ways to install AutoAWQ: - `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git` - NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels +3. From main branch for Intel CPU and Intel XPU optimized performance: + - `pip install intel_extension_for_pytorch` + - `pip install git+https://github.com/casper-hansen/AutoAWQ.git` + ## Usage Under examples, you can find examples of how to quantize, run inference, and benchmark AutoAWQ models. @@ -132,6 +138,9 @@ print(f'Model is quantized and saved at "{quant_path}"') ```python from awq import AutoAWQForCausalLM from transformers import AutoTokenizer, TextStreamer +from awq.utils.utils import get_best_device + +device = get_best_device() quant_path = "TheBloke/zephyr-7B-beta-AWQ" @@ -155,7 +164,7 @@ prompt = "You're standing on the surface of the Earth. "\ tokens = tokenizer( prompt_template.format(prompt=prompt), return_tensors='pt' -).input_ids.cuda() +).input_ids.to(device) # Generate output generation_output = model.generate( @@ -229,18 +238,18 @@ GPU: 2x NVIDIA GeForce RTX 4090 ### CPU - CPU: 48 cores SPR (Intel 4th Gen Xeon CPU) -- Command: `python examples/benchmark.py --model_path --batch_size 1` +- Command: `python examples/benchmark.py --model_path --batch_size 1 --generator hf` | Model | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory | |-------|---------|------------|----------------|---------------|-------------------|------------------|---------------| -| Llama 2 7B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) | -| Llama 2 7B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) | -| Falcon | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) | -| Falcon | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) | -| Mistral | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) | -| Mistral | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) | -| Vicuna | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) | -| Vicuna | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) | +| TinyLlama 1B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) | +| TinyLlama 1B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) | +| Falcon 7B | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) | +| Falcon 7B | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) | +| Mistral 7B | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) | +| Mistral 7B | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) | +| Vicuna 7B | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) | +| Vicuna 7B | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) | | Llama 2 13B | gemm | 1 | 32 | 32 | 220.79 | 18.14 | 17.46 GB (0.02%) | | Llama 2 13B | gemm | 1 | 2048 | 2048 | 650.94 | 6.54 | 19.84 GB (0.02%) | | DeepSeek Coder 33B | gemm | 1 | 32 | 32 | 101.61 | 8.58 | 40.80 GB (0.04%) | diff --git a/awq/models/base.py b/awq/models/base.py index a5fbf4c3..c70b0363 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -447,7 +447,7 @@ def from_quantized( bool, Doc("Whether to map the weights to ExLlamaV2 kernels.") ] = False, use_ipex: Annotated[ - bool, Doc("Whether to map the weights to ipex kernels for CPU device.") + bool, Doc("Whether to map the weights to ipex kernels for CPU and XPU device.") ] = False, device_map: Annotated[ Union[str, Dict], @@ -500,8 +500,9 @@ def from_quantized( trust_remote_code=trust_remote_code, ) - use_cpu_ipex = use_ipex or get_best_device() == "cpu" - if use_cpu_ipex and not ipex_available: + best_device = get_best_device() + use_ipex = use_ipex or best_device in ["cpu", "xpu:0"] + if use_ipex and not ipex_available: raise ImportError( "Please install intel_extension_for_pytorch with " "`pip install intel_extension_for_pytorch` for 'ipex' kernel!" @@ -514,7 +515,7 @@ def from_quantized( quant_config.version, use_exllama=use_exllama, use_exllama_v2=use_exllama_v2, - use_ipex=use_cpu_ipex, + use_ipex=use_ipex, ) model.tie_weights() @@ -534,14 +535,12 @@ def from_quantized( # Dispath to devices awq_ext, msg = try_import("awq_ext") if fuse_layers: - if awq_ext is None: + if best_device in ["mps", "cuda:0"] and awq_ext is None: warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg) else: self.fuse_layers(model) - if use_cpu_ipex: - dtype = torch.bfloat16 - model.to(dtype=dtype, device="cpu") + if use_ipex: # repack qweight to match the ipex kernel. model = ipex_post_init(model) elif quant_config.version == "marlin": diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index d4ce6c4e..9e348b51 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -29,7 +29,7 @@ def __init__(self, head_dim, max_seq_len, device, rope_theta): super(RoPE, self).__init__() self.freqs_cis = nn.Parameter( - self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device), + self.precompute_freqs_cis(head_dim, max_seq_len, rope_theta).to(device), requires_grad=False, ) @@ -137,8 +137,8 @@ def __init__( self.use_alibi = use_alibi self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1")) - if kwargs.get("max_new_tokens") is not None: - max_seq_len = kwargs["max_new_tokens"] + if kwargs.get("max_length") is not None: + max_seq_len = kwargs["max_length"] self.max_seq_len = max_seq_len self.is_hf_transformers = False diff --git a/awq/modules/linear/gemm_ipex.py b/awq/modules/linear/gemm_ipex.py index dd4a996e..68fafa5d 100644 --- a/awq/modules/linear/gemm_ipex.py +++ b/awq/modules/linear/gemm_ipex.py @@ -1,18 +1,20 @@ import torch import torch.nn as nn +from .gemm import WQLinear_GEMM +from awq.utils.packing_utils import dequantize_gemm try: - from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear - assert hasattr(WeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" + from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear + assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" IPEX_INSTALLED = True except: IPEX_INSTALLED = False -class WQLinear_IPEX(nn.Module): +class WQLinear_IPEX(WQLinear_GEMM): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() + def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): + nn.Module.__init__(self) assert IPEX_INSTALLED, \ "Please install IPEX package with `pip install intel_extension_for_pytorch`." assert w_bit == 4, "Only 4 bit are supported for now." @@ -24,12 +26,15 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.w_bit = w_bit self.group_size = group_size if group_size != -1 else in_features self.scale_dtype = torch.float32 + self.training = training # quick sanity check (make sure aligment) assert self.in_features % self.group_size == 0 assert out_features % (32 // self.w_bit) == 0 self.pack_num = 32 // self.w_bit + self.init_ipex = False + self.register_buffer( "qzeros", torch.zeros( @@ -59,10 +64,13 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.register_buffer("qweight", qweight) def post_init(self): - assert self.qweight.device.type == "cpu" - self.ipex_linear = WeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ - self.in_features, self.out_features, None, self.bias, \ - self.group_size, None, 0, 1) + assert self.qweight.device.type in ("cpu", "xpu") + + def init_ipex_linear(self): + if not self.training: + self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ + self.in_features, self.out_features, None, self.bias, \ + self.group_size, None, quant_method=1, dtype=0) @classmethod def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None): @@ -79,16 +87,31 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None): raise NotImplementedError("Only inference is supported for IPEX kernels") - @torch.no_grad() def forward(self, x): assert IPEX_INSTALLED, ( "IPEX kernels could not be loaded. " "Please install with `pip install intel_extension_for_pytorch` and " "refer to the detial https://github.com/intel/intel-extension-for-pytorch/tree/main") - outputs = self.ipex_linear(x) + if not self.init_ipex: + self.init_ipex_linear() + self.init_ipex = True + + if hasattr(self, "ipex_linear"): + with torch.no_grad(): + outputs = self.ipex_linear(x) + else: + outputs = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(x.dtype) + outputs = torch.matmul(x, outputs) return outputs + + def backward(self, grad_output): + weights = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(grad_output.dtype) + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None def extra_repr(self) -> str: return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py index dcc90ae6..577fc94d 100644 --- a/awq/utils/fused_utils.py +++ b/awq/utils/fused_utils.py @@ -67,7 +67,9 @@ def fuse_qkv(module, q_proj, k_proj, v_proj): else None ) - if isinstance(q_proj, WQLinear_GEMV): + if isinstance(q_proj, WQLinear_IPEX): + q_linear = WQLinear_IPEX + elif isinstance(q_proj, WQLinear_GEMV): q_linear = WQLinear_GEMV elif isinstance(q_proj, WQLinear_GEMM): q_linear = WQLinear_GEMM @@ -79,8 +81,6 @@ def fuse_qkv(module, q_proj, k_proj, v_proj): q_linear = WQLinear_Marlin elif isinstance(q_proj, WQLinear_GEMVFast): q_linear = WQLinear_GEMVFast - elif isinstance(q_proj, WQLinear_IPEX): - q_linear = WQLinear_IPEX qkv_layer = q_linear( q_proj.w_bit, diff --git a/awq/utils/utils.py b/awq/utils/utils.py index 7553c5df..3eb8faca 100644 --- a/awq/utils/utils.py +++ b/awq/utils/utils.py @@ -91,6 +91,8 @@ def get_best_device(): return "mps" elif torch.cuda.is_available(): return "cuda:0" + elif torch.xpu.is_available(): + return "xpu:0" else: return "cpu" diff --git a/examples/benchmark.py b/examples/benchmark.py index f3ff44ec..370de9ab 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -11,9 +11,9 @@ from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList DEVICE = get_best_device() -if DEVICE == "cpu": +if DEVICE in ["cpu", "xpu:0"]: if ipex_available: - torch_dtype = torch.bfloat16 + torch_dtype = torch.bfloat16 if DEVICE == "cpu" else torch.float16 else: raise ImportError("Please import intel_extension_for_pytorch " "by `pip install intel_extension_for_pytorch`") @@ -29,8 +29,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): """The logit processor is called after the model forward.""" # cuda runs async operates, so we synchronize for accurate time measurement - if DEVICE != "cpu": + if DEVICE == "cuda:0": torch.cuda.synchronize() + elif DEVICE == "xpu:0": + torch.xpu.synchronize() # measure time start_time = time.time() @@ -56,8 +58,10 @@ def generate_torch(model, input_ids, n_generate): with torch.inference_mode(): for i in range(n_generate): - if DEVICE != "cpu": + if DEVICE == "cuda:0": torch.cuda.synchronize() + elif DEVICE == "xpu:0": + torch.xpu.synchronize() start = time.time() if i == 0: @@ -69,8 +73,10 @@ def generate_torch(model, input_ids, n_generate): out = model(inputs, use_cache=True) - if DEVICE != "cpu": + if DEVICE == "cuda:0": torch.cuda.synchronize() + elif DEVICE == "xpu:0": + torch.xpu.synchronize() token = out[0][:, -1].max(1)[1].unsqueeze(1) if i == 0: @@ -102,7 +108,7 @@ def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate): return context_time, generate_time -def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors, pretrained): +def run_round(generator, model_path, quant_file, n_generate, context, input_ids, batch_size, no_safetensors, pretrained): print(f" -- Loading model...") if pretrained: @@ -114,7 +120,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si ) else: model = AutoAWQForCausalLM.from_quantized( - model_path, quant_file, max_seq_len=n_generate, batch_size=batch_size, safetensors=not no_safetensors + model_path, quant_file, max_seq_len=n_generate+context, batch_size=batch_size, safetensors=not no_safetensors ) print(f" -- Warming up...") @@ -149,6 +155,12 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si memory_pct = mem_info.rss / memory_info.total total_memory_used = float(mem_info.rss) / (1024 ** 3) print(f" ** Max Memory (device: {DEVICE}): {total_memory_used:.2f} GB ({memory_pct:.2f}%)") + elif DEVICE == "xpu:0": + for device in range(torch.xpu.device_count()): + memory_used = torch.xpu.max_memory_allocated(device) / (1024 ** 3) + total_memory_used += memory_used + memory_pct = memory_used / (torch.xpu.get_device_properties(device).total_memory / (1024 ** 3)) * 100 + print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)") else: for device in range(torch.cuda.device_count()): memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) @@ -197,14 +209,17 @@ def main(args): for settings in rounds: input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])) - if DEVICE != "cpu": + if DEVICE == "cuda:0": input_ids = input_ids.cuda() + elif DEVICE == "xpu:0": + input_ids = input_ids.to("xpu:0") stats, model_version = run_round( generator, args.model_path, args.quant_file, settings["n_generate"], + settings["context"], input_ids, args.batch_size, args.no_safetensors, @@ -218,8 +233,10 @@ def main(args): df = pd.DataFrame(all_stats) print('Device:', DEVICE) - if DEVICE != "cpu": + if DEVICE == "cuda:0": print('GPU:', torch.cuda.get_device_name()) + elif DEVICE == "xpu:0": + print('XPU:', torch.xpu.get_device_name()) print('Model:', args.model_path) print('Version:', model_version) print(df.to_markdown(index=False)) diff --git a/examples/generate.py b/examples/generate.py index 803e5b9d..23935dd4 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,7 +1,9 @@ import torch from awq import AutoAWQForCausalLM from transformers import AutoTokenizer, TextStreamer +from awq.utils.utils import get_best_device +device = get_best_device() model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4" tokenizer = AutoTokenizer.from_pretrained(model_id) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) @@ -26,7 +28,7 @@ add_generation_prompt=True, return_tensors="pt", return_dict=True, -).to("cuda") +).to(device) outputs = model.generate( **inputs,