Revert "enable awq ipex linear in transformers (casper-hansen#610)"
This reverts commit 7954766.
vanGraan committed Oct 1, 2024
1 parent 7954766 commit 62414b5
Showing 3 changed files with 37 additions and 16 deletions.
awq/models/base.py: 12 additions, 4 deletions
@@ -663,10 +663,18 @@ def _load_quantized_modules(
                 elif version == "gemv_fast":
                     q_linear_module = WQLinear_GEMVFast
 
-
-                q_linear = q_linear_module.from_linear(
-                    module, quant_config.w_bit, quant_config.q_group_size, True
-                )
+                if use_ipex:
+                    q_linear = q_linear_module.from_linear(
+                        module,
+                        quant_config.w_bit,
+                        quant_config.q_group_size,
+                        True,
+                        has_zero_points=quant_config.zero_point,
+                    )
+                else:
+                    q_linear = q_linear_module.from_linear(
+                        module, quant_config.w_bit, quant_config.q_group_size, True
+                    )
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
 
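For context on what this restores: with the if use_ipex branch back in place, the IPEX kernel's from_linear call receives the checkpoint's zero-point setting, while every other kernel keeps the three-argument call. The snippet below is a minimal sketch of that restored call, using a stand-in nn.Linear and made-up quantization settings; it assumes intel_extension_for_pytorch is installed, since WQLinear_IPEX asserts IPEX_INSTALLED on construction.

import torch.nn as nn

from awq.modules.linear.gemm_ipex import WQLinear_IPEX

# Hypothetical stand-ins for what _load_quantized_modules would pass in.
module = nn.Linear(4096, 4096, bias=False)
w_bit, q_group_size, zero_point = 4, 128, True

# Mirrors the restored use_ipex branch: init_only=True only allocates buffers
# (ready for a state-dict load), and has_zero_points carries quant_config.zero_point.
q_linear = WQLinear_IPEX.from_linear(
    module,
    w_bit,
    q_group_size,
    True,  # init_only
    has_zero_points=zero_point,
)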
awq/modules/linear/gemm_ipex.py: 6 additions, 4 deletions
@@ -11,7 +11,7 @@
 
 class WQLinear_IPEX(nn.Module):
 
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         super().__init__()
         assert IPEX_INSTALLED, \
             "Please install IPEX package with `pip install intel_extension_for_pytorch`."
@@ -23,6 +23,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+        self.zero_point = zero_point
         self.scale_dtype = torch.float32
 
         # quick sanity check (make sure aligment)
@@ -34,9 +35,9 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
             "qzeros",
             torch.zeros(
                 (in_features // self.group_size, out_features // self.pack_num),
-                dtype=torch.int32,
+                dtype=torch.int8,
                 device=dev,
-            ),
+            ) if self.zero_point else None,
         )
         self.register_buffer(
             "scales",
@@ -65,13 +66,14 @@ def post_init(self):
                                  self.group_size, None, 0, 1)
 
     @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
         awq_linear = cls(
             w_bit,
             group_size,
             linear.in_features,
             linear.out_features,
             linear.bias is not None,
+            has_zero_points,
             linear.weight.device,
         )
         if init_only: # just prepare for loading sd
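To make the restored constructor behaviour concrete: when zero_point is False the qzeros buffer is registered as None, and when it is True an int8 tensor of shape (in_features // group_size, out_features // pack_num) is allocated. A small sketch with made-up sizes follows; it assumes intel_extension_for_pytorch is installed, because __init__ asserts IPEX_INSTALLED.

from awq.modules.linear.gemm_ipex import WQLinear_IPEX

# Symmetric checkpoint: no zero points, so no qzeros tensor is allocated.
sym = WQLinear_IPEX(4, 128, 4096, 4096, bias=False, zero_point=False, dev="cpu")
assert sym.qzeros is None

# Asymmetric checkpoint: qzeros becomes an int8 tensor of shape
# (in_features // group_size, out_features // pack_num) = (32, 512) here.
asym = WQLinear_IPEX(4, 128, 4096, 4096, bias=False, zero_point=True, dev="cpu")
print(asym.qzeros.shape, asym.qzeros.dtype)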
awq/utils/fused_utils.py: 19 additions, 8 deletions
@@ -82,14 +82,25 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
     elif isinstance(q_proj, WQLinear_IPEX):
         q_linear = WQLinear_IPEX
 
-    qkv_layer = q_linear(
-        q_proj.w_bit,
-        q_proj.group_size,
-        q_proj.in_features,
-        q_proj.out_features + k_proj.out_features + v_proj.out_features,
-        q_proj.bias is not None,
-        next(iter(module.state_dict().values())).device,
-    )
+    if isinstance(q_proj, WQLinear_IPEX):
+        qkv_layer = q_linear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.zero_point,
+            next(iter(module.state_dict().values())).device,
+        )
+    else:
+        qkv_layer = q_linear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            next(iter(module.state_dict().values())).device,
+        )
 
     if isinstance(q_proj, WQLinear_GEMV):
         qkv_layer.qweight = torch.cat(
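As a final illustration, the restored branch means that when fuse_qkv fuses the three attention projections for the IPEX kernel, it forwards q_proj.zero_point to the fused layer's constructor, so symmetric and asymmetric checkpoints keep the right qzeros layout after fusion. The sketch below shows roughly what that branch constructs, with made-up projection sizes and literal values standing in for attributes fuse_qkv reads; it again assumes intel_extension_for_pytorch is installed.

from awq.modules.linear.gemm_ipex import WQLinear_IPEX

# Hypothetical projections, e.g. a GQA model where K/V are narrower than Q.
q_proj = WQLinear_IPEX(4, 128, 4096, 4096, bias=False, zero_point=True, dev="cpu")
k_proj = WQLinear_IPEX(4, 128, 4096, 1024, bias=False, zero_point=True, dev="cpu")
v_proj = WQLinear_IPEX(4, 128, 4096, 1024, bias=False, zero_point=True, dev="cpu")

# What the restored IPEX branch builds: one layer whose out_features is the sum
# of the three projections and whose zero_point is taken from q_proj.
qkv_layer = WQLinear_IPEX(
    q_proj.w_bit,
    q_proj.group_size,
    q_proj.in_features,
    q_proj.out_features + k_proj.out_features + v_proj.out_features,
    False,  # fuse_qkv passes q_proj.bias is not None here
    q_proj.zero_point,
    "cpu",
)
print(qkv_layer.out_features)  # 6144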
