Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix qbits issue #153

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions auto_round/auto_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def convert_model(self, model: nn.Module):
return model

def _dynamic_import_inference_linear(self, bits, backend):
if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
try:
from intel_extension_for_transformers import qbits # pylint: disable=E0401
except Exception as e:
raise ImportError("Please install Intel Extension for Transformers via 'pip install "
"intel-extension-for-transformers' to inference on X86 CPU")
return qlinear_qbits.QuantLinear
if bits == 4 and self.exllama2_available and "exllamav2" in backend:
from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
else:
Expand All @@ -341,9 +348,10 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
data_type = config["data_type"]
if not (bits <= 8 and data_type == "int"):
continue
QuantLinear = self._dynamic_import_inference_linear(bits, backend)

layer = get_module(module, layer_name)
device = get_device(layer)
QuantLinear = self._dynamic_import_inference_linear(bits, backend)
if isinstance(layer, nn.Linear):
in_features = layer.in_features
out_features = layer.out_features
Expand All @@ -363,24 +371,13 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
weight_dtype=layer.weight.dtype,
)

if new_layer.qweight.device.type == "cpu": # fallback to qbits linear when qweight on cpu device
QuantLinear = qlinear_qbits.QuantLinear
new_layer = QuantLinear( # pylint: disable=E1123
bits,
group_size,
in_features,
out_features,
bias,
weight_dtype=layer.weight.dtype,
)

new_layer.device = device
set_module(module, layer_name, new_layer)

def qbits_post_init(self, model):
dep_check = True
for layer in model.modules():
if isinstance(layer,qlinear_qbits.QuantLinear):
if isinstance(layer, qlinear_qbits.QuantLinear):
if dep_check:
layer.req_check()
layer.post_init()
Expand Down Expand Up @@ -408,7 +405,7 @@ class StoreAttr(object):
model = autoround_post_init(model)
# there are no side-effects after call qbits_post_init when model quant-type not equal to qbits.
model = self.qbits_post_init(model)

return model

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
Expand Down
14 changes: 6 additions & 8 deletions auto_round_extension/qbits/qlinear_qbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
import torch.nn as nn
from auto_round.utils import convert_dtype_torch2str, logger
QBITS_AVAILABLE = True
try:
wenhuach21 marked this conversation as resolved.
Show resolved Hide resolved
from intel_extension_for_transformers import qbits # noqa: F401
except Exception as e:
QBITS_AVAILABLE = False
logger.warning(
"qlinear_qbits should be used with Intel Extension for Transformers.")

BITS_DTYPE_MAPPING = {
2: "int2_clip",
Expand Down Expand Up @@ -62,6 +56,7 @@ def __init__(
self.maxq = 2**self.bits - 1
self.weight_dtype = weight_dtype
self.asym = True
self.qbits = None

self.register_buffer(
"qweight",
Expand Down Expand Up @@ -98,6 +93,7 @@ def __init__(
def req_check(self):
torch_version = str(torch.__version__)
if QBITS_AVAILABLE:
import intel_extension_for_transformers
itrex_version = str(intel_extension_for_transformers.__version__)
version_match_map = {"1.4": "2.2.0+cpu",
"1.4.1": "2.2.0+cpu", "1.4.2": "2.3.0+cpu"}
Expand All @@ -111,6 +107,8 @@ def req_check(self):
exit(1)

def post_init(self):
import intel_extension_for_transformers
self.qbits = intel_extension_for_transformers.qbits
assert self.qweight.device.type == "cpu"
if self.bias is not None:
self.bias = self.bias.to(dtype=torch.float32)
Expand Down Expand Up @@ -142,7 +140,7 @@ def post_init(self):

logger.info(
f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}")
self.qweight = qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0),
self.qweight = self.qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0),
# weight_dtype
BITS_DTYPE_MAPPING[self.bits],
# scale_dtype
Expand All @@ -167,7 +165,7 @@ def forward(self, x: torch.Tensor):
bias = self.bias if self.bias is not None else torch.empty(
0, dtype=torch.float)

qbits.woq_linear(x, self.qweight, bias, outputs,
self.qbits.woq_linear(x, self.qweight, bias, outputs,
convert_dtype_torch2str(torch.float), # compute_dtype
BITS_DTYPE_MAPPING[self.bits], # weight_dtype
"fp32", # scale_dtype
Expand Down