enable awq ipex linear in transformers #610

Merged (1 commit) on Sep 13, 2024
awq/models/base.py (16 changes: 4 additions & 12 deletions)
@@ -663,18 +663,10 @@ def _load_quantized_modules(
                 elif version == "gemv_fast":
                     q_linear_module = WQLinear_GEMVFast

-                if use_ipex:
-                    q_linear = q_linear_module.from_linear(
-                        module,
-                        quant_config.w_bit,
-                        quant_config.q_group_size,
-                        True,
-                        has_zero_points=quant_config.zero_point,
-                    )
-                else:
-                    q_linear = q_linear_module.from_linear(
-                        module, quant_config.w_bit, quant_config.q_group_size, True
-                    )
+                q_linear = q_linear_module.from_linear(
+                    module, quant_config.w_bit, quant_config.q_group_size, True
+                )

                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)

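For context, a minimal sketch of what the replacement step in _load_quantized_modules reduces to after this change. This is a simplification, not the library code verbatim: the backend-selection chain is abbreviated, and names such as named_linears and set_op_by_name are assumed from the surrounding base.py rather than shown in the hunk above.

    # Sketch (simplified): every backend, including WQLinear_IPEX, now goes through
    # the same from_linear call, so the IPEX special case disappears.
    def replace_linears(layer, named_linears, q_linear_module, quant_config):
        for name, module in named_linears.items():
            q_linear = q_linear_module.from_linear(
                module,                     # the nn.Linear being replaced
                quant_config.w_bit,         # quantization bit width, e.g. 4
                quant_config.q_group_size,  # quantization group size, e.g. 128
                True,                       # init_only: allocate buffers only; weights are loaded later
            )
            q_linear.to(next(layer.parameters()).device)   # move to the block's device
            set_op_by_name(layer, name, q_linear)          # swap the quantized module in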
awq/modules/linear/gemm_ipex.py (10 changes: 4 additions & 6 deletions)
@@ -11,7 +11,7 @@

 class WQLinear_IPEX(nn.Module):

-    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
         assert IPEX_INSTALLED, \
             "Please install IPEX package with `pip install intel_extension_for_pytorch`."
@@ -23,7 +23,6 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
-        self.zero_point = zero_point
         self.scale_dtype = torch.float32

         # quick sanity check (make sure aligment)
@@ -35,9 +34,9 @@
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // self.pack_num),
dtype=torch.int8,
dtype=torch.int32,
device=dev,
) if self.zero_point else None,
),
)
self.register_buffer(
"scales",
@@ -66,14 +65,13 @@ def post_init(self):
             self.group_size, None, 0, 1)

     @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):
         awq_linear = cls(
             w_bit,
             group_size,
             linear.in_features,
             linear.out_features,
             linear.bias is not None,
-            has_zero_points,
             linear.weight.device,
         )
         if init_only: # just prepare for loading sd
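To show the effect of the signature change, here is a hedged usage sketch. It assumes intel_extension_for_pytorch is installed (the constructor asserts this) and that the package layout matches this PR; the layer size is illustrative.

    # Usage sketch: WQLinear_IPEX.from_linear now takes the same arguments as the
    # other backends -- no zeros/has_zero_points parameters.
    import torch.nn as nn
    from awq.modules.linear.gemm_ipex import WQLinear_IPEX

    linear = nn.Linear(4096, 4096, bias=False)

    # init_only=True only allocates the quantization buffers; real weights are loaded later.
    q_linear = WQLinear_IPEX.from_linear(linear, w_bit=4, group_size=128, init_only=True)

    # qzeros is now always allocated, and as torch.int32 rather than torch.int8.
    print(q_linear.qzeros.dtype, q_linear.qzeros.shape)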
awq/utils/fused_utils.py (27 changes: 8 additions & 19 deletions)
@@ -82,25 +82,14 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
     elif isinstance(q_proj, WQLinear_IPEX):
         q_linear = WQLinear_IPEX

-    if isinstance(q_proj, WQLinear_IPEX):
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            q_proj.zero_point,
-            next(iter(module.state_dict().values())).device,
-        )
-    else:
-        qkv_layer = q_linear(
-            q_proj.w_bit,
-            q_proj.group_size,
-            q_proj.in_features,
-            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-            q_proj.bias is not None,
-            next(iter(module.state_dict().values())).device,
-        )
+    qkv_layer = q_linear(
+        q_proj.w_bit,
+        q_proj.group_size,
+        q_proj.in_features,
+        q_proj.out_features + k_proj.out_features + v_proj.out_features,
+        q_proj.bias is not None,
+        next(iter(module.state_dict().values())).device,
+    )

     if isinstance(q_proj, WQLinear_GEMV):
         qkv_layer.qweight = torch.cat(
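Finally, a hedged sketch of the fused QKV construction that fuse_qkv now performs for every backend, IPEX included. It assumes q_proj, k_proj and v_proj are already-quantized projections of the same WQLinear class; the isinstance dispatch above is condensed into type(q_proj) purely for illustration.

    # Sketch: with zero_point gone from WQLinear_IPEX.__init__, a single constructor
    # call covers IPEX and the other backends alike.
    def make_fused_qkv(module, q_proj, k_proj, v_proj):
        q_linear = type(q_proj)  # e.g. WQLinear_GEMM, WQLinear_GEMV, WQLinear_IPEX
        return q_linear(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,  # fused output width
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device,  # device of the parent block
        )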