diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index e0a6b03d..8a4952fc 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -317,6 +317,13 @@ def convert_model(self, model: nn.Module):
         return model

     def _dynamic_import_inference_linear(self, bits, backend):
+        if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
+            try:
+                from intel_extension_for_transformers import qbits  # pylint: disable=E0401
+            except Exception as e:
+                raise ImportError("Please install Intel Extension for Transformers via 'pip install "
+                                  "intel-extension-for-transformers' to inference on X86 CPU")
+            return qlinear_qbits.QuantLinear
         if bits == 4 and self.exllama2_available and "exllamav2" in backend:
             from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
         else:
@@ -341,9 +348,10 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
             data_type = config["data_type"]
             if not (bits <= 8 and data_type == "int"):
                 continue
-            QuantLinear = self._dynamic_import_inference_linear(bits, backend)
+
             layer = get_module(module, layer_name)
             device = get_device(layer)
+            QuantLinear = self._dynamic_import_inference_linear(bits, backend)
             if isinstance(layer, nn.Linear):
                 in_features = layer.in_features
                 out_features = layer.out_features
@@ -363,24 +371,13 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
                 weight_dtype=layer.weight.dtype,
             )

-            if new_layer.qweight.device.type == "cpu":  # fallback to qbits linear when qweight on cpu device
-                QuantLinear = qlinear_qbits.QuantLinear
-                new_layer = QuantLinear(  # pylint: disable=E1123
-                    bits,
-                    group_size,
-                    in_features,
-                    out_features,
-                    bias,
-                    weight_dtype=layer.weight.dtype,
-                )
-
             new_layer.device = device
             set_module(module, layer_name, new_layer)

     def qbits_post_init(self, model):
         dep_check = True
         for layer in model.modules():
-            if isinstance(layer,qlinear_qbits.QuantLinear):
+            if isinstance(layer, qlinear_qbits.QuantLinear):
                 if dep_check:
                     layer.req_check()
                 layer.post_init()
@@ -408,7 +405,7 @@ class StoreAttr(object):
         model = autoround_post_init(model)
         # there are no side-effects after call qbits_post_init when model quant-type not equal to qbits.
         model = self.qbits_post_init(model)
-
+
         return model

     def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
diff --git a/auto_round_extension/qbits/qlinear_qbits.py b/auto_round_extension/qbits/qlinear_qbits.py
index 758e615a..1c567946 100644
--- a/auto_round_extension/qbits/qlinear_qbits.py
+++ b/auto_round_extension/qbits/qlinear_qbits.py
@@ -20,12 +20,6 @@ import torch.nn as nn
 from auto_round.utils import convert_dtype_torch2str, logger

 QBITS_AVAILABLE = True
-try:
-    from intel_extension_for_transformers import qbits  # noqa: F401
-except Exception as e:
-    QBITS_AVAILABLE = False
-    logger.warning(
-        "qlinear_qbits should be used with Intel Extension for Transformers.")

 BITS_DTYPE_MAPPING = {
     2: "int2_clip",
@@ -62,6 +56,7 @@ def __init__(
         self.maxq = 2**self.bits - 1
         self.weight_dtype = weight_dtype
         self.asym = True
+        self.qbits = None

         self.register_buffer(
             "qweight",
@@ -98,6 +93,7 @@ def __init__(
     def req_check(self):
         torch_version = str(torch.__version__)
         if QBITS_AVAILABLE:
+            import intel_extension_for_transformers
             itrex_version = str(intel_extension_for_transformers.__version__)
             version_match_map = {"1.4": "2.2.0+cpu", "1.4.1": "2.2.0+cpu",
                                  "1.4.2": "2.3.0+cpu"}
@@ -111,6 +107,8 @@ def req_check(self):
             exit(1)

     def post_init(self):
+        import intel_extension_for_transformers
+        self.qbits = intel_extension_for_transformers.qbits
         assert self.qweight.device.type == "cpu"
         if self.bias is not None:
             self.bias = self.bias.to(dtype=torch.float32)
@@ -142,7 +140,7 @@ def post_init(self):

         logger.info(
             f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}")
-        self.qweight = qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0),
+        self.qweight = self.qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0),
                                                      # weight_dtype
                                                      BITS_DTYPE_MAPPING[self.bits],
                                                      # scale_dtype
@@ -167,7 +165,7 @@ def forward(self, x: torch.Tensor):
         bias = self.bias if self.bias is not None else torch.empty(
             0, dtype=torch.float)

-        qbits.woq_linear(x, self.qweight, bias, outputs,
+        self.qbits.woq_linear(x, self.qweight, bias, outputs,
                          convert_dtype_torch2str(torch.float),  # compute_dtype
                          BITS_DTYPE_MAPPING[self.bits],  # weight_dtype
                          "fp32",  # scale_dtype
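Note (illustration, not part of the patch): the key behavioral change is that the Intel Extension for Transformers import is now deferred until a CPU/qbits path is actually selected, instead of being attempted at module import time. A minimal standalone sketch of that selection pattern follows; pick_backend and its return labels are hypothetical and are not auto_round API, and the remaining CUDA branch is elided here just as it is cut off in the hunk above.

# sketch.py -- deferred-import backend selection, assuming only torch is installed
import torch

def pick_backend(bits: int, backend: str) -> str:
    """Decide which QuantLinear kernel family an inference layer should use."""
    if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
        try:
            # Imported only when the CPU/qbits path is actually requested, so
            # CUDA-only environments never need the package installed.
            from intel_extension_for_transformers import qbits  # noqa: F401
        except Exception as exc:
            raise ImportError(
                "Please install Intel Extension for Transformers via "
                "'pip install intel-extension-for-transformers' to run inference on x86 CPU"
            ) from exc
        return "qbits"
    if bits == 4 and "exllamav2" in backend:
        return "exllamav2"
    return "cuda-default"  # CUDA fallback path elided in the hunk above

# Example: on a machine without CUDA, pick_backend(4, "auto") returns "qbits"
# once intel-extension-for-transformers is installed, and raises ImportError otherwise.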