diff --git a/awq/models/base.py b/awq/models/base.py
index f2396bb5..2da5095d 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -663,18 +663,10 @@ def _load_quantized_modules(
                 elif version == "gemv_fast":
                     q_linear_module = WQLinear_GEMVFast
-                if use_ipex:
-                    q_linear = q_linear_module.from_linear(
-                        module,
-                        quant_config.w_bit,
-                        quant_config.q_group_size,
-                        True,
-                        has_zero_points=quant_config.zero_point,
-                    )
-                else:
-                    q_linear = q_linear_module.from_linear(
-                        module, quant_config.w_bit, quant_config.q_group_size, True
-                    )
+
+                q_linear = q_linear_module.from_linear(
+                    module, quant_config.w_bit, quant_config.q_group_size, True
+                )
 
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
 
diff --git a/awq/modules/linear/gemm_ipex.py b/awq/modules/linear/gemm_ipex.py
index 399b98d3..dd4a996e 100644
--- a/awq/modules/linear/gemm_ipex.py
+++ b/awq/modules/linear/gemm_ipex.py
@@ -11,7 +11,7 @@ class WQLinear_IPEX(nn.Module):
 
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
         assert IPEX_INSTALLED, \
             "Please install IPEX package with `pip install intel_extension_for_pytorch`."
@@ -23,7 +23,6 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_poin
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
-        self.zero_point = zero_point
         self.scale_dtype = torch.float32
 
         # quick sanity check (make sure aligment)
@@ -35,9 +34,9 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_poin
             "qzeros",
             torch.zeros(
                 (in_features // self.group_size, out_features // self.pack_num),
-                dtype=torch.int8,
+                dtype=torch.int32,
                 device=dev,
-            ) if self.zero_point else None,
+            ),
         )
         self.register_buffer(
             "scales",
@@ -66,14 +65,13 @@ def post_init(self):
                                                                      self.group_size, None, 0, 1)
 
     @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):
         awq_linear = cls(
             w_bit,
             group_size,
             linear.in_features,
             linear.out_features,
             linear.bias is not None,
-            has_zero_points,
             linear.weight.device,
         )
         if init_only:  # just prepare for loading sd
diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py
index a78fead1..dcc90ae6 100644
--- a/awq/utils/fused_utils.py
+++ b/awq/utils/fused_utils.py
@@ -82,25 +82,14 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
     elif isinstance(q_proj, WQLinear_IPEX):
         q_linear = WQLinear_IPEX
 
-    if isinstance(q_proj, WQLinear_IPEX):
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            q_proj.zero_point,
-            next(iter(module.state_dict().values())).device,
-        )
-    else:
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device,
-        )
+    qkv_layer = q_linear(
+        q_proj.w_bit,
+        q_proj.group_size,
+        q_proj.in_features,
+        q_proj.out_features + k_proj.out_features + v_proj.out_features,
+        q_proj.bias is not None,
+        next(iter(module.state_dict().values())).device,
+    )
 
     if isinstance(q_proj, WQLinear_GEMV):
         qkv_layer.qweight = torch.cat(
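
As a sanity check of the unified constructor path, here is a minimal sketch that is not part of the patch: the 4096x4096 layer shape, group size 128, and the init_only flow are illustrative assumptions, and it only runs where intel_extension_for_pytorch is installed (WQLinear_IPEX asserts this at construction).

    import torch
    import torch.nn as nn
    from awq.modules.linear.gemm_ipex import WQLinear_IPEX

    # Hypothetical projection layer; the shape is chosen only for illustration.
    linear = nn.Linear(4096, 4096, bias=False)

    # After the patch the IPEX module is built with the same positional arguments
    # as the other kernel backends, so callers no longer branch on use_ipex.
    q_linear = WQLinear_IPEX.from_linear(linear, w_bit=4, group_size=128, init_only=True)

    # qzeros is now unconditionally allocated as int32, packing 32 // w_bit = 8
    # zero points per word, matching the layout of the other AWQ linear modules.
    assert q_linear.qzeros.dtype == torch.int32
    assert q_linear.qzeros.shape == (4096 // 128, 4096 // 8)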