enable awq ipex linear in transformers (casper-hansen#610)
jiqing-feng authored Sep 13, 2024
1 parent eab1a4a commit 7954766
Showing 3 changed files with 16 additions and 37 deletions.
awq/models/base.py (16 changes: 4 additions & 12 deletions)

@@ -663,18 +663,10 @@ def _load_quantized_modules(
                 elif version == "gemv_fast":
                     q_linear_module = WQLinear_GEMVFast

-                if use_ipex:
-                    q_linear = q_linear_module.from_linear(
-                        module,
-                        quant_config.w_bit,
-                        quant_config.q_group_size,
-                        True,
-                        has_zero_points=quant_config.zero_point,
-                    )
-                else:
-                    q_linear = q_linear_module.from_linear(
-                        module, quant_config.w_bit, quant_config.q_group_size, True
-                    )

+                q_linear = q_linear_module.from_linear(
+                    module, quant_config.w_bit, quant_config.q_group_size, True
+                )
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)

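With the IPEX-specific branch removed, the loader constructs every kernel through the same positional call, from_linear(module, w_bit, group_size, init_only). A minimal sketch of that unified path, assuming AutoAWQ at or after this commit plus intel_extension_for_pytorch are installed (WQLinear_IPEX asserts the latter at construction); the dummy nn.Linear and the 4-bit / group-size-128 settings are illustrative, not taken from the commit:

import torch.nn as nn

from awq.modules.linear.gemm_ipex import WQLinear_IPEX

# Stand-in for a linear module that _load_quantized_modules would pick up.
linear = nn.Linear(4096, 4096, bias=False)

# Same four positional arguments the loader now passes for every kernel.
# init_only=True only allocates the quantized-weight, qzeros, and scales
# buffers so a quantized state dict can be loaded into them afterwards.
q_linear = WQLinear_IPEX.from_linear(linear, 4, 128, True)

print(q_linear.qzeros.shape, q_linear.qzeros.dtype)  # packed int32 zero points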
awq/modules/linear/gemm_ipex.py (10 changes: 4 additions & 6 deletions)

@@ -11,7 +11,7 @@

 class WQLinear_IPEX(nn.Module):

-    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
         assert IPEX_INSTALLED, \
             "Please install IPEX package with `pip install intel_extension_for_pytorch`."
@@ -23,7 +23,6 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
-        self.zero_point = zero_point
         self.scale_dtype = torch.float32

         # quick sanity check (make sure aligment)
@@ -35,9 +34,9 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
             "qzeros",
             torch.zeros(
                 (in_features // self.group_size, out_features // self.pack_num),
-                dtype=torch.int8,
+                dtype=torch.int32,
                 device=dev,
-            ) if self.zero_point else None,
+            ),
         )
         self.register_buffer(
             "scales",
@@ -66,14 +65,13 @@ def post_init(self):
            self.group_size, None, 0, 1)

     @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):
         awq_linear = cls(
             w_bit,
             group_size,
             linear.in_features,
             linear.out_features,
             linear.bias is not None,
-            has_zero_points,
             linear.weight.device,
         )
         if init_only: # just prepare for loading sd
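On the module side, the constructor and from_linear drop their zero_point/has_zero_points parameters, and the qzeros buffer is now always registered and stored as packed int32 instead of being an optional int8 tensor. A small self-contained sketch of the resulting buffer shape in plain PyTorch with no IPEX dependency; the layer sizes are made up, and pack_num = 32 // w_bit is the usual AWQ int32 packing (eight 4-bit zero points per word), stated here as an assumption rather than something visible in this hunk:

import torch

w_bit = 4
group_size = 128
in_features, out_features = 4096, 4096

pack_num = 32 // w_bit  # 8 four-bit values packed into each int32 word

# Mirrors the qzeros allocation in WQLinear_IPEX.__init__ after this commit:
# one row per quantization group, out_features packed along the columns.
qzeros = torch.zeros(
    (in_features // group_size, out_features // pack_num),
    dtype=torch.int32,
)
print(qzeros.shape)  # torch.Size([32, 512])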
awq/utils/fused_utils.py (27 changes: 8 additions & 19 deletions)

@@ -82,25 +82,14 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
     elif isinstance(q_proj, WQLinear_IPEX):
         q_linear = WQLinear_IPEX

-    if isinstance(q_proj, WQLinear_IPEX):
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            q_proj.zero_point,
-            next(iter(module.state_dict().values())).device,
-        )
-    else:
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device,
-        )
+    qkv_layer = q_linear(
+        q_proj.w_bit,
+        q_proj.group_size,
+        q_proj.in_features,
+        q_proj.out_features + k_proj.out_features + v_proj.out_features,
+        q_proj.bias is not None,
+        next(iter(module.state_dict().values())).device,
+    )

     if isinstance(q_proj, WQLinear_GEMV):
         qkv_layer.qweight = torch.cat(
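In fuse_qkv, the IPEX case now goes through the same constructor call as every other kernel: the fused layer keeps the input width and sums the three output widths, and only the class chosen earlier differs. A rough, self-contained sketch of that shape bookkeeping, using a plain nn.Linear and made-up grouped-query sizes as stand-ins for the quantized classes and real model dimensions:

import torch
import torch.nn as nn

hidden = 4096
q_out, k_out, v_out = 4096, 1024, 1024  # hypothetical GQA projection widths

q_proj = nn.Linear(hidden, q_out, bias=False)
k_proj = nn.Linear(hidden, k_out, bias=False)
v_proj = nn.Linear(hidden, v_out, bias=False)

# Same arithmetic as the fused out_features argument in fuse_qkv above.
fused_out = q_proj.out_features + k_proj.out_features + v_proj.out_features
qkv = nn.Linear(hidden, fused_out, bias=False)

# Stacking the three weight matrices row-wise is a float analogue of the
# qweight concatenation fuse_qkv performs for the quantized kernels.
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(2, hidden)
q, k, v = qkv(x).split([q_out, k_out, v_out], dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-5)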
