Fix bugs #295

Merged (1 commit) on Jan 14, 2025
131 changes: 78 additions & 53 deletions llmc/compression/quantization/awq.py
@@ -31,8 +31,8 @@ def scaling_weight(self, w, scales, is_gqa):
scales_tmp = self.repeat_gqa_scales(scales)
else:
scales_tmp = scales
-        w_tmp = w.mul_(scales_tmp.view(1, -1))
-        return w_tmp
+        w.mul_(scales_tmp.view(1, -1))
+        return w

@torch.no_grad()
def get_weight_scale(self, layers_dict):
@@ -60,14 +60,17 @@ def get_weight_scale(self, layers_dict):
return scale

def get_act_scale(self, x):
-        batch_means = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            batch_x = x[num * self._bs:(num + 1) * self._bs]
-            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
-            batch_means.append(batch_mean)
-        final_mean = sum(batch_means) / len(batch_means)
-        return final_mean
+        if x.shape[0] == self._bs:
+            return x.abs().view(-1, x.shape[-1]).mean(0)
+        else:
+            batch_means = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                batch_x = x[num * self._bs:(num + 1) * self._bs]
+                batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+                batch_means.append(batch_mean)
+            final_mean = sum(batch_means) / len(batch_means)
+            return final_mean

@torch.no_grad()
def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
@@ -93,15 +96,22 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
return scales

def inspect_module_forward(self, x, inspect_module, kwargs):
-        outs = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            _x = x[num * self._bs:(num + 1) * self._bs]
-            out = inspect_module(_x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
-            outs.append(out)
-        return torch.cat(outs, dim=0)
+        if self._bs == x.shape[0]:
+            with torch.no_grad():
+                out = inspect_module(x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+            return out
+        else:
+            outs = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                _x = x[num * self._bs:(num + 1) * self._bs]
+                out = inspect_module(_x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+                outs.append(out)
+            return torch.cat(outs, dim=0)

@torch.no_grad()
def get_original_out(self, x, inspect_module, subset_kwargs):
@@ -110,14 +120,53 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
return org_out

def calculate_loss(self, org_out, out):
-        total_loss = 0.0
-        b_num = org_out.shape[0] // self._bs
-        for num in range(b_num):
-            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
-            _out = out[num * self._bs:(num + 1) * self._bs]
-            single_loss = (_org_out - _out).float().pow(2).mean().item()
-            total_loss += single_loss
-        return total_loss / b_num
+        if out.shape[0] == self._bs:
+            return (org_out - out).float().pow(2).mean().item()
+        else:
+            total_loss = 0.0
+            b_num = org_out.shape[0] // self._bs
+            for num in range(b_num):
+                _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+                _out = out[num * self._bs:(num + 1) * self._bs]
+                single_loss = (_org_out - _out).float().pow(2).mean().item()
+                total_loss += single_loss
+            return total_loss / b_num
+
+    def fake_quantize_weight(self, weight, scales, is_gqa, layer_name):
+        weight = self.scaling_weight(weight, scales, is_gqa)
+        weight.data = get_wquantizer(
+            self.block_idx,
+            layer_name,
+            self.mix_bits_map,
+            self.quantizer_mix_bits,
+            self.wquantizer,
+        ).fake_quant_weight_dynamic(weight.data)
+
+        return weight
+
+    def fake_quantize_input(self, x_tmp, layers_dict):
+        if self._bs == x_tmp.shape[0]:
+            x_tmp = get_aquantizer(
+                self.block_idx,
+                list(layers_dict.keys())[0],
+                self.mix_bits_map,
+                self.quantizer_mix_bits,
+                self.aquantizer,
+            ).fake_quant_act_dynamic(x_tmp)
+        else:
+            outs = []
+            for i in range(x_tmp.shape[0]):
+                _x = x_tmp[i]
+                _x = get_aquantizer(
+                    self.block_idx,
+                    list(layers_dict.keys())[0],
+                    self.mix_bits_map,
+                    self.quantizer_mix_bits,
+                    self.aquantizer,
+                ).fake_quant_act_dynamic(_x)
+                outs.append(_x)
+            x_tmp = torch.stack(outs)
+        return x_tmp

@torch.no_grad()
def search_scale_subset(
@@ -158,25 +207,12 @@ def search_scale_subset(
else:
org_out = self.get_original_out(x, inspect_module, kwargs)
org_out_dict[i] = org_out
-                x_max = self.get_act_scale(x)

ratio = n * 1 / n_grid
scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
for layer_name in layers_dict:
fc = layers_dict[layer_name]
-                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)
-
-                    fc.weight.data = get_wquantizer(
-                        self.block_idx,
-                        layer_name,
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.wquantizer,
-                    ).fake_quant_weight_dynamic(fc.weight.data)
-
-                del x_max
-                gc.collect()
-                torch.cuda.empty_cache()
+                    fc.weight = self.fake_quantize_weight(fc.weight, scales, is_gqa, layer_name)

x_tmp = self.scaling_input(x, scales, is_gqa)

@@ -187,18 +223,7 @@
self.quantizer_mix_bits,
self.w_only,
):
-                    outs = []
-                    for i in range(x_tmp.shape[0]):
-                        _x = x_tmp[i]
-                        _x = get_aquantizer(
-                            self.block_idx,
-                            list(layers_dict.keys())[0],
-                            self.mix_bits_map,
-                            self.quantizer_mix_bits,
-                            self.aquantizer,
-                        ).fake_quant_act_dynamic(_x)
-                        outs.append(_x)
-                    x_tmp = torch.stack(outs)
+                    x_tmp = self.fake_quantize_input(x_tmp, layers_dict)

out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

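Note on the awq.py changes above: they all apply the same batching pattern, namely that when the leading dimension of the calibration tensor already equals self._bs, the per-batch loop is skipped and the tensor is processed in a single call; otherwise the original chunked loop is kept. A minimal standalone sketch of that pattern follows (the function name and shapes are illustrative, not code from this PR):

import torch

def act_scale(x: torch.Tensor, bs: int) -> torch.Tensor:
    # Per-channel mean absolute activation, chunked only when x holds
    # more samples than one calibration batch.
    if x.shape[0] == bs:
        return x.abs().view(-1, x.shape[-1]).mean(0)
    chunk_means = []
    for i in range(x.shape[0] // bs):
        chunk = x[i * bs:(i + 1) * bs]
        chunk_means.append(chunk.abs().view(-1, chunk.shape[-1]).mean(0))
    return sum(chunk_means) / len(chunk_means)

x = torch.randn(8, 16, 128)
# The fast path and the chunked path agree up to floating-point error,
# so the change is a pure optimization when the batch fits in one pass.
print(torch.allclose(act_scale(x, 8), act_scale(x, 4), atol=1e-6))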
23 changes: 13 additions & 10 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -243,8 +243,9 @@ def set_quant_config(self):
if self.act_static:
act_static_cfg.update(self.config.calib.n_sample)
act_static_cfg.update(self.config.calib.bs)
+            kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                self.quant_type, self.quant_config['kvcache'],
+                kv_quant_type, self.quant_config['kvcache'],
self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
)
self.quant_kvcache = True
@@ -287,8 +288,9 @@ def set_quant_config(self):
# set online-rotation config
self.online_rotate = special_config.get('online_rotate', False)
if self.online_rotate:
-            assert self.config['model']['type'] in ['Opt', 'Llama']
-
+            assert (
+                self.config['model']['type'] in ['Opt', 'Llama']
+            ), 'Please set online_rotate=False'
self.hidden_size = self.model.model_config.hidden_size
self.set_model_config()
self.modality = self.quant_config.modality
@@ -832,12 +834,13 @@ def scaling_input(self, x, scales, is_gqa):
scales_tmp = self.repeat_gqa_scales(scales)
else:
scales_tmp = scales
-
-        x_tmp = torch.empty_like(x)
-        for i, batch in enumerate(x):
-            batch_scale = scales_tmp.view(1, -1)
-            x_tmp[i] = batch / batch_scale
-
+        if hasattr(self, '_bs') and self._bs < x.shape[0]:
+            x_tmp = torch.empty_like(x)
+            for i, batch in enumerate(x):
+                batch_scale = scales_tmp.view(1, -1)
+                x_tmp[i] = batch / batch_scale
+        else:
+            x_tmp = x / scales.view(1, -1)
return x_tmp

@torch.no_grad()
Expand All @@ -846,7 +849,7 @@ def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
for i in range(len(input_feat[layer_name])):
inp = input_feat[layer_name][i]
scale = scale.to(inp.device)
-                inp = self.scaling_input(inp, scale, is_gqa)
+                input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)

@torch.no_grad()
def set_non_linear_mode(self, quant_format, module, mode):
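Two of the fixes above are easy to miss: update_input_feat now writes the scaled activation back into input_feat (previously the result of scaling_input was assigned to a local and discarded), and scaling_input only falls back to the per-batch buffer when the input holds more samples than self._bs. The scaling itself relies on the usual AWQ equivalence, sketched below with made-up shapes (not code from this repository): dividing the activations by the scales while multiplying the weight by the same scales leaves the layer output unchanged, so only the quantization error is affected.

import torch

torch.manual_seed(0)
x = torch.randn(4, 64)       # activations
w = torch.randn(128, 64)     # linear weight (out_features, in_features)
s = torch.rand(64) + 0.5     # per-input-channel scales

y_ref = x @ w.t()
y_scaled = (x / s.view(1, -1)) @ (w * s.view(1, -1)).t()
print(torch.allclose(y_ref, y_scaled, atol=1e-5))  # True: outputs match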
56 changes: 13 additions & 43 deletions llmc/compression/quantization/quant.py
@@ -1,6 +1,14 @@
import torch
from loguru import logger

+try:
+    from qtorch.quant import float_quantize
+except Exception:
+    logger.warning(
+        'qtorch not found, please install qtorch.'
+        'Please install qtorch (pip install qtorch).'
+    )
+    float_quantize = None

class BaseQuantizer(object):
def __init__(self, bit, symmetric, granularity, **kwargs):
@@ -36,8 +44,6 @@ def __init__(self, bit, symmetric, granularity, **kwargs):

# hist config
self.bins = self.kwargs.get('bins', 2048)
-        self.hist_threshold = self.kwargs.get('hist_threshold', 1)
-        self.dst_nbins = 2**bit if isinstance(bit, int) else None
self.upsample_rate = (
16 # used to reduce quantization errors when upscaling histogram
)
@@ -87,36 +93,6 @@ def get_tensor_range(self, tensor, args={}):
else:
return self.get_minmax_range(tensor)

-    def get_hist_range(self, stats_min_max, act_stats_hist):
-        clip_val = {}
-        for input_idx, hist in act_stats_hist.items():
-            hist = hist.float() / hist.sum()
-            data_max = max(
-                -torch.min(stats_min_max[input_idx]['min']),
-                torch.max(stats_min_max[input_idx]['max']),
-            )
-            accum = 0
-            for i in range(len(hist)):
-                accum += hist[i]
-                if accum >= self.hist_threshold:
-                    clip_value = (i + 0.5) * (data_max / self.bins)
-                    clip_val[input_idx] = [
-                        max(-clip_value, torch.min(stats_min_max[input_idx]['min'])),
-                        min(clip_value, torch.max(stats_min_max[input_idx]['max'])),
-                    ]
-                    break
-            if input_idx not in clip_val:
-                clip_val[input_idx] = [
-                    torch.min(stats_min_max[input_idx]['min']),
-                    torch.max(stats_min_max[input_idx]['max']),
-                ]
-
-        moving_min_vals, moving_max_vals = [], []
-        for input_idx, tensor_range in clip_val.items():
-            moving_min_vals.append(tensor_range[0])
-            moving_max_vals.append(tensor_range[1])
-        return moving_min_vals, moving_max_vals
-
def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor':
max_val = torch.max(tensor)
@@ -556,7 +532,7 @@ def get_batch_tensors_qparams(self, act_tensors, alpha=0.01, args={}):
if self.calib_algo == 'static_hist':
assert (
self.sym is True and self.granularity == 'per_tensor'
-            ), 'Only support per tensor static symmetric.'
+            ), 'Only support per tensor static symmetric int quantize.'
min_vals, max_vals = self.get_static_hist_range(act_tensors)
elif self.calib_algo == 'static_minmax':
min_vals, max_vals = self.get_static_minmax_range(act_tensors)
@@ -657,6 +633,7 @@ def __init__(self, bit, symmetric, granularity, **kwargs):

self.qmin = torch.tensor(self.qmin)
self.qmax = torch.tensor(self.qmax)
+        self.dst_nbins = 2**bit

def get_hqq_qparams(self, tensor, args):
tensor = tensor.float()
@@ -947,17 +924,10 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
self.sign_bits = 1
self.num_bits = self.e_bits + self.m_bits + self.sign_bits
self.default_bias = 2 ** (self.e_bits - 1)
-
+        self.dst_nbins = 2**self.num_bits
self.use_qtorch = self.kwargs.get('use_qtorch')
if self.use_qtorch:
-            try:
-                from qtorch.quant import float_quantize
-            except ImportError:
-                logger.error('qtorch not found, please install qtorch.')
-                raise ImportError('Please install qtorch (pip install qtorch).')
-
-            self.float_quantize = float_quantize
-
+            assert float_quantize is not None, 'Please install qtorch (pip install qtorch). Or set use_qtorch=False'
if 'float_range' in self.kwargs:
self.qmin, self.qmax = self.kwargs['float_range']
else:
@@ -1045,7 +1015,7 @@ def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros
if self.use_qtorch:
org_dtype = scaled_tensor.dtype
-            q_tensor = self.float_quantize(
+            q_tensor = float_quantize(
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor.to(org_dtype)
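The quant.py hunks replace the per-instance qtorch import (a try/except inside the float quantizer's __init__) with a single module-level import guarded by a try/except and a float_quantize = None sentinel, asserting at use time that the dependency is present; they also move dst_nbins out of BaseQuantizer into the quantizer subclasses' initializers and remove the get_hist_range / hist_threshold code. A condensed sketch of the optional-dependency pattern (class name and constructor are illustrative, not the repository's actual API):

from loguru import logger

try:
    from qtorch.quant import float_quantize
except Exception:
    logger.warning('qtorch not found; float quantization via qtorch is unavailable.')
    float_quantize = None

class FloatQuantSketch:
    def __init__(self, e_bits, m_bits, use_qtorch=True):
        self.e_bits, self.m_bits, self.use_qtorch = e_bits, m_bits, use_qtorch
        if self.use_qtorch:
            # Fail at construction time instead of deep inside quant().
            assert float_quantize is not None, (
                'Please install qtorch (pip install qtorch) or set use_qtorch=False'
            )

    def quant(self, scaled_tensor):
        # qtorch path only in this sketch; round to the nearest value
        # representable with e_bits / m_bits.
        return float_quantize(scaled_tensor.float(), self.e_bits,
                              self.m_bits, rounding='nearest')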