[WIP][quantize] add gptq post-quantization #58

Draft · wants to merge 2 commits into base: main
7 changes: 7 additions & 0 deletions main.py
@@ -11,6 +11,7 @@
from parallel_utils import map_layers_to_multi_gpus, get_lowest_occupied_gpu
import torch.nn as nn
from quantize.omniquant import omniquant
from quantize.gptq import gptq
from tqdm import tqdm
import utils
from pathlib import Path
@@ -218,6 +219,7 @@ def main():
parser.add_argument("--aug_loss", default=False, action="store_true", help="calculate additional loss with same input")
parser.add_argument("--symmetric",default=False, action="store_true", help="symmetric quantization")
parser.add_argument("--disable_zero_point",default=False, action="store_true", help="quantization without zero_point")
parser.add_argument("--gptq", default=False, action="store_true", help="use gptq for further compensation")
parser.add_argument("--a_dynamic_method", type=str, default="per_token", choices=["per_token"])
parser.add_argument("--w_dynamic_method", type=str, default="per_channel", choices=["per_channel"])
parser.add_argument("--limit", type=int, default=-1)
@@ -351,6 +353,11 @@ def main():
logger,
)
logger.info(time.time() - tick)
if args.gptq:
tick = time.time()
with torch.no_grad():
gptq(lm, args, dataloader, logger)
logger.info(time.time() - tick)
if args.save_dir:
# delete omni parameters
for name, module in lm.model.named_modules():
204 changes: 204 additions & 0 deletions quantize/gptq.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2024-01-23] <[email protected], Xingchen Song>

import torch
import math

from auto_gptq.quantization.gptq import GPTQ
from auto_gptq.modeling._utils import find_layers

from quantize.quantizer import FixedScaleQuantizer
from quantize.int_linear import QuantLinear
from quantize.utils import set_quant_state, smooth_and_quant_inplace


def gptq(lm, args, dataloader, logger):
logger.info("Starting GPTQ...")

# 1. move embedding layer and first layer to target device
model = lm.model
dev = lm.device
logger.info("model.device: {}".format(dev))
use_cache = model.config.use_cache
model.config.use_cache = False
is_llama = False
# TODO(xcsong): support other archs
if "llama" in args.net.lower():
is_llama = True
layers = model.model.layers
model.model.embed_tokens = model.model.embed_tokens.to(dev)
model.model.norm = model.model.norm.to(dev)
layers_block_name = "model.layers"
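        # Modules in the same group consume the same layer input and are
        # quantized together; this grouping mirrors auto_gptq's
        # inside_layer_modules ordering for Llama.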
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"],
["mlp.down_proj"]
]
else:
raise ValueError("Only support for Llama-2 now")
layers[0] = layers[0].to(dev)

# 2. catch the first layer input
dtype = next(iter(model.parameters())).dtype
logger.info("model.dtype: {}".format(dtype))
inps = torch.zeros(
(args.nsamples, lm.seqlen, model.config.hidden_size),
dtype=dtype, device=dev
)
outs = torch.zeros_like(inps)
cache = {"i": 0, 'attention_mask': None}

class Catcher(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
self.is_llama = False

def forward(self, inp, **kwargs):
inps[cache["i"]] = inp
cache["i"] += 1
cache["attention_mask"] = kwargs["attention_mask"]
if self.is_llama:
cache["position_ids"] = kwargs["position_ids"]
raise ValueError

layers[0] = Catcher(layers[0])
layers[0].is_llama = is_llama
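    # Each calibration sample runs only through the embedding and the wrapped
    # first layer: Catcher records the layer input (plus attention mask and
    # position ids) and raises ValueError to abort the rest of the forward pass.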

for batch in dataloader:
if cache["i"] >= args.nsamples:
break
try:
model(batch[0].to(dev))
except ValueError:
pass

# 3. move embedding layer and first layer to cpu
layers[0] = layers[0].module
layers[0] = layers[0].cpu()
if "llama" in args.net.lower():
model.model.embed_tokens = model.model.embed_tokens.cpu()
model.model.norm = model.model.norm.cpu()
else:
raise ValueError("Only support for Llama-2 now")
torch.cuda.empty_cache()

# 4. get additional inputs (mask, pos, ..., etc)
attention_mask = cache["attention_mask"]
if is_llama:
position_ids = cache["position_ids"]
else:
position_ids = None

# 5. start gptq quantization
quantizers = {}
for i in range(len(layers)):
logger.info(f"=== Start quantize layer {i} with GPTQ ===")
layer = layers[i].to(dev)

# 5.1 get layers which should be quantized
full = find_layers(layer, layers=[QuantLinear])
for names in inside_layer_modules:
# NOTE(xcsong): type(subset[name]) == QuantLinear
# i.e. subset["self_attn.k_proj"] = QuantLinear(**someargs)
# type(gptq[name]) == GPTQ
# i.e. gptq["self_attn.k_proj"] = GPTQ(subset["self_attn.k_proj"]) # noqa
subset = {n: full[n] for n in names if n in full}
gptq = {}
# 5.1.1 init gptq
for name in subset:
gptq[name] = GPTQ(subset[name])
                # NOTE(xcsong): Overwrite GPTQ().quantizer, using the fixed
                # scale and zero obtained from omniquant's quantizer
_ = subset[name].weight_quantizer(subset[name].weight)
subset[name].weight_quantizer.register_scales_and_zeros()
scale = subset[name].weight_quantizer.scales
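                # zeros is None when the zero point is disabled
                # (--disable_zero_point); fall back to an all-zero tensor.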
zero = subset[name].weight_quantizer.zeros \
if subset[name].weight_quantizer.zeros is not None \
else torch.zeros_like(scale)
gptq[name].quantizer = FixedScaleQuantizer(
scale=scale, zero=zero,
**args.weight_quant_params,
shape=subset[name].weight.shape
)

# 5.1.2 init gptq.H
            # NOTE(xcsong): Reimplement GPTQ().add_batch() as a local hook, since
            # 1. type(gptq[name].layer) is QuantLinear, not nn.Linear,
            #    making it incompatible with the original implementation.
            # 2. We feed fake-quantized activations into the calculation of H,
            #    so the Hessian reflects the inputs the layer actually sees.
def add_batch(name):
def tmp(_, inp, out):
                    # apply fake_quant to activation
inp = subset[name].act_quantizer(inp[0].data)
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
batch = inp.shape[0]
if isinstance(gptq[name].layer, QuantLinear):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
else:
raise NotImplementedError()
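                    # Running average of H = 2 * E[x x^T]: shrink the old
                    # estimate by n/(n+b), then add the sqrt(2/(n+b))-scaled
                    # columns, so H ends up as (2/N) * sum_i x_i x_i^T.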
gptq[name].H *= gptq[name].nsamples / (gptq[name].nsamples + batch)
gptq[name].nsamples += batch
inp = math.sqrt(2 / gptq[name].nsamples) * inp.float()
gptq[name].H += inp.matmul(inp.t())
return tmp

handles = []
for name in subset:
set_quant_state(subset[name], weight_quant=False,
act_quant=False)
subset[name].use_temporary_parameter = False
handles.append(subset[name].register_forward_hook(
add_batch(name)))
for j in range(args.nsamples):
layer(inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids)[0]
for h in handles:
h.remove()

# 5.1.3 do gptq-algorithm and update weight in-place
for name in subset:
logger.info(f'Quantize {name} in layer {i + 1}/{len(layers)}')
scale, zero, g_idx = gptq[name].fasterquant(
blocksize=128, # same as gptq
percdamp=0.01, # same as gptq, always choose 1% of the average diagonal value # noqa
group_size=-1 if args.group_size is None else args.group_size, # group = None means per-channel # noqa
actorder=False,
static_groups=False
)
quantizers[f'{layers_block_name}.{i}.{name}'] = (
gptq[name].quantizer.cpu(), scale.cpu(),
zero.cpu(), g_idx.cpu()
)
gptq[name].free()
torch.cuda.empty_cache()

# 5.2 quantize weight optimized by gptq
        # NOTE(xcsong): After GPTQ quantization, activations are
        # fake-quantized online (via set_quant_state) while weights are
        # fake-quantized offline and in place (via smooth_and_quant_inplace).
layer.half()
set_quant_state(layer, weight_quant=False, act_quant=True)
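        # quantize/utils.py skips the in-place weight fake-quant while
        # args.gptq is set, so clear both flags here to bake the GPTQ-updated
        # weights in; args.let is cleared as well so the smoothing transforms
        # are not applied a second time.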
prev_let, prev_gptq = args.let, args.gptq
args.let, args.gptq = False, False
smooth_and_quant_inplace(layer, args)
args.let, args.gptq = prev_let, prev_gptq

# 5.3 get output of current layer, treat it as input for next layer
for j in range(args.nsamples):
outs[j] = layer(inps[j].unsqueeze(0),
attention_mask=attention_mask,
position_ids=position_ids)[0]
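        # swap buffers: this layer's outputs become the next layer's inputs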
inps, outs = outs, inps
layers[i] = layer.cpu()
torch.cuda.empty_cache()

model.config.use_cache = use_cache
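
For intuition, the add_batch hook above builds the GPTQ Hessian H = (2/N) * sum_i x_i x_i^T as a running average rather than in one shot. A minimal, repo-independent sketch of the same update (stand-in shapes, counting tokens rather than sequences for simplicity):

import torch

d = 16
xs = [torch.randn(8, d) for _ in range(4)]  # 4 calibration batches, 8 tokens each

H, n = torch.zeros(d, d), 0
for x in xs:
    b = x.shape[0]                   # samples in this batch
    H *= n / (n + b)                 # shrink the old running estimate
    n += b
    cols = (2.0 / n) ** 0.5 * x.t()  # (d, b) columns, pre-scaled by sqrt(2/n)
    H += cols @ cols.t()

# same thing computed in one shot over all n samples
H_ref = (2.0 / n) * sum(x.t() @ x for x in xs)
assert torch.allclose(H, H_ref, atol=1e-5)
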
57 changes: 55 additions & 2 deletions quantize/quantizer.py
@@ -10,8 +10,6 @@
CLIPMIN = 1e-5




def round_ste(x: torch.Tensor):
"""
Implement Straight-Through Estimator for rounding operation.
@@ -163,3 +161,58 @@ def register_scales_and_zeros(self):
self.register_buffer('zeros', self.round_zero_point)
del self.scale
del self.round_zero_point


class FixedScaleQuantizer(UniformAffineQuantizer):
def __init__(
self,
scale,
zero,
n_bits: int = 8,
symmetric: bool = False,
per_channel_axes=[],
metric="minmax",
dynamic=False,
dynamic_method="per_cluster",
group_size=None,
shape=None,
lwc=False,
disable_zero_point=False
):
UniformAffineQuantizer.__init__(
self,
n_bits,
symmetric,
per_channel_axes,
metric,
dynamic,
dynamic_method,
group_size,
shape,
lwc,
disable_zero_point
)
# Init scale & zero
self.scale = scale
self.zero = zero

    # NOTE(xcsong): Override AutoGptqQuantizer.find_params() as a no-op,
    # since there is no need to re-compute scale and zero
def find_params(self, x, weight=False):
pass

    # NOTE(xcsong): Override AutoGptqQuantizer.ready(), since there is
    # no need to re-compute scale and zero
def ready(self):
return True

    # NOTE(xcsong): Override AutoGptqQuantizer.quantize() since we have a
    # slightly different quantization process
def quantize(self, x):
if self.n_bits >= 16:
return x
if self.metric == "fix0to1":
return x.mul_(2**self.n_bits-1).round_().div_(2**self.n_bits-1)

x_dequant = self.fake_quant(x, self.scale, self.zero)
return x_dequant
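
For the common (non-fix0to1) path, quantize() reduces to a uniform affine round trip with the frozen scale and zero point. A standalone sketch of that round trip (stand-in per-row scale and zero point; the real behaviour lives in UniformAffineQuantizer.fake_quant):

import torch

def fixed_scale_fake_quant(x, scale, zero, n_bits=4):
    # quantize with a frozen (e.g. OmniQuant-learned) scale/zero, then dequantize
    qmax = 2 ** n_bits - 1
    q = torch.clamp(torch.round(x / scale) + zero, 0, qmax)
    return (q - zero) * scale

w = torch.randn(8, 32)
scale = w.abs().amax(dim=1, keepdim=True) / (2 ** 3)  # stand-in per-row scale
zero = torch.full_like(scale, 2 ** 3)                 # stand-in mid-range zero point
w_hat = fixed_scale_fake_quant(w, scale, zero)        # fake-quantized weights
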
11 changes: 7 additions & 4 deletions quantize/utils.py
@@ -111,10 +111,13 @@ def smooth_and_quant_inplace(model, args):
model.out_smooth_scale, model.out_smooth_shift)
smooth_q_k_inplace(model.self_attn.q_proj, model.self_attn.k_proj,
model.qkt_smooth_scale)
for name, module in model.named_modules():
if isinstance(module, QuantLinear):
module.weight = module.weight_quantizer(module.weight)
module.use_temporary_parameter=False
    if not args.gptq:
        # NOTE(xcsong): GPTQ needs the original (unquantized) weights to
        # compute its reconstruction error, so skip the in-place weight
        # fake-quant when a GPTQ pass follows.
        for _, module in model.named_modules():
            if isinstance(module, QuantLinear):
                module.weight = module.weight_quantizer(module.weight)
                module.use_temporary_parameter = False

def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
# setting weight quantization here does not affect actual forward pass