Support gqa smooth (#279)
* Support gqa smooth

* Support gqa smooth

* Support gqa smooth

* Fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

---------

Co-authored-by: gushiqiao <[email protected]>
gushiqiao authored Dec 30, 2024
1 parent 75f8d1d commit e320149
Showing 4 changed files with 488 additions and 133 deletions.
156 changes: 123 additions & 33 deletions llmc/compression/quantization/awq.py
@@ -23,6 +23,16 @@ def __init__(self, model, quant_config, input, padding_mask, config):
self.trans = special_config.get('trans', True)
self.trans_version = special_config.get('trans_version', 'v2')
self.save_scale = special_config.get('save_scale', False)
self.awq_bs = special_config.get('awq_bs', None)

@torch.no_grad()
def scaling_weight(self, w, scales, is_gqa):
if is_gqa:
scales_tmp = self.repeat_gqa_scales(scales)
else:
scales_tmp = scales
w_tmp = w.mul_(scales_tmp.view(1, -1))
return w_tmp

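The new scaling_weight helper multiplies a linear layer's weight columns by the smoothing scales; under GQA the scales are computed per key/value head, so they are first expanded to cover every query head in the group (via repeat_gqa_scales). A minimal standalone sketch of that idea, with hypothetical head counts that are not taken from this diff:

import torch

# Hypothetical GQA geometry: 8 query heads share 2 KV heads.
num_heads, num_kv_heads, head_dim = 8, 2, 16
groups = num_heads // num_kv_heads          # 4 query heads per KV head

# One smoothing scale per KV-projection output channel.
scales = torch.rand(num_kv_heads * head_dim) + 0.5

# Expand so every query head in a group reuses its KV head's scales,
# mirroring what repeat_gqa_scales does.
scales_tmp = scales.view(1, num_kv_heads, head_dim)
scales_tmp = torch.repeat_interleave(scales_tmp, repeats=groups, dim=1).reshape(-1)

# Scale the columns (input channels) of a downstream o_proj-like weight.
w = torch.randn(4096, num_heads * head_dim)
w_scaled = w * scales_tmp.view(1, -1)
print(w_scaled.shape)  # torch.Size([4096, 128])
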
@torch.no_grad()
def get_weight_scale(self, layers_dict):
@@ -49,20 +59,82 @@ def get_weight_scale(self, layers_dict):
torch.cuda.empty_cache()
return scale

@torch.no_grad()
def get_act_scale(self, x):
-        return x.abs().view(-1, x.shape[-1]).mean(0)
+        batch_means = []
+        b_num = x.shape[0] // self._bs
+        for num in range(b_num):
+            batch_x = x[num * self._bs:(num + 1) * self._bs]
+            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+            batch_means.append(batch_mean)
+        final_mean = sum(batch_means) / len(batch_means)
+        return final_mean

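get_act_scale now averages per-channel absolute activation means over fixed-size batches instead of reducing one large view in a single step, which bounds peak memory on big calibration sets. A standalone sketch of the same computation, with illustrative shapes:

import torch

def chunked_act_scale(x: torch.Tensor, bs: int) -> torch.Tensor:
    # x: [n_samples, seq_len, hidden]; returns one mean |activation| per channel.
    means = []
    for start in range(0, x.shape[0] - x.shape[0] % bs, bs):
        chunk = x[start:start + bs]
        means.append(chunk.abs().view(-1, chunk.shape[-1]).mean(0))
    return torch.stack(means).mean(0)

x = torch.randn(16, 128, 64)
full = x.abs().view(-1, x.shape[-1]).mean(0)
chunked = chunked_act_scale(x, bs=4)
print(torch.allclose(full, chunked, atol=1e-6))  # True when bs divides n_samples
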
@torch.no_grad()
def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
if is_gqa:
x_tmp = prev_op(x)
w_tmp = self.get_weight_scale({'prev_op': prev_op})
else:
x_tmp = x
w_tmp = w_max

x_tmp = self.get_act_scale(x_tmp)

if self.trans_version == 'v1':
scales = (
(x_tmp.pow(ratio) / w_tmp.pow(1 - ratio))
.clamp(min=1e-4)
.view(-1)
)
elif self.trans_version == 'v2':
scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1)

scales = scales / (scales.max() * scales.min()).sqrt()
return scales

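get_scales folds the ratio handling into one helper: on the GQA path it recomputes activation and weight statistics from prev_op so the scale vector lives in the KV projection's channel space, then applies the AWQ-style formula (v1: act^ratio / weight^(1-ratio), v2: act^ratio) and normalises by sqrt(max * min) so the scales straddle 1. A small numeric sketch of that normalisation, with made-up values:

import torch

act = torch.tensor([0.1, 1.0, 4.0])
weight = torch.tensor([0.5, 0.5, 0.5])
ratio = 0.5

# v1-style scale, then normalise so the geometric middle of the range is 1.
scales = (act.pow(ratio) / weight.pow(1 - ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()
print(scales)                        # roughly [0.40, 1.26, 2.51]
print(scales.max() * scales.min())   # ~1.0 after normalisation
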
def inspect_module_forward(self, x, inspect_module, kwargs):
outs = []
b_num = x.shape[0] // self._bs
for num in range(b_num):
_x = x[num * self._bs:(num + 1) * self._bs]
out = inspect_module(_x, **kwargs)
if isinstance(out, tuple):
out = out[0]
outs.append(out)
return torch.cat(outs, dim=0)

@torch.no_grad()
def get_original_out(self, x, inspect_module, subset_kwargs):
-        with torch.no_grad():
-            org_out = inspect_module(x, **subset_kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
+        org_out = self.inspect_module_forward(x, inspect_module, subset_kwargs)
return org_out

def calculate_loss(self, org_out, out):
total_loss = 0.0
b_num = org_out.shape[0] // self._bs
for num in range(b_num):
_org_out = org_out[num * self._bs:(num + 1) * self._bs]
_out = out[num * self._bs:(num + 1) * self._bs]
single_loss = (_org_out - _out).float().pow(2).mean().item()
total_loss += single_loss
return total_loss / b_num

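calculate_loss compares original and perturbed outputs chunk by chunk and averages the per-chunk MSE, which matches the plain MSE whenever the chunks are equally sized. A quick sketch of that equivalence:

import torch

org = torch.randn(8, 32, 64)
out = org + 0.01 * torch.randn_like(org)

full_mse = (org - out).float().pow(2).mean().item()

bs = 2
chunk_mse = sum(
    (org[i:i + bs] - out[i:i + bs]).float().pow(2).mean().item()
    for i in range(0, org.shape[0], bs)
) / (org.shape[0] // bs)

print(abs(full_mse - chunk_mse) < 1e-6)  # True: equal-size chunks give the same mean
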
@torch.no_grad()
-    def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
+    def search_scale_subset(
+        self,
+        prev_op,
+        layers_dict,
+        input,
+        inspect_module,
+        is_gqa,
+        subset_kwargs
+    ):

if self.awq_bs is None:
self._bs = input[0].shape[0]
else:
self._bs = self.awq_bs

w_max = self.get_weight_scale(layers_dict)
# grid search for ratio
best_error = float('inf')
@@ -89,18 +161,10 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
x_max = self.get_act_scale(x)

ratio = n * 1 / n_grid
-                if self.trans_version == 'v1':
-                    scales = (
-                        (x_max.pow(ratio) / w_max.pow(1 - ratio))
-                        .clamp(min=1e-4)
-                        .view(-1)
-                    )
-                elif self.trans_version == 'v2':
-                    scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
-                scales = scales / (scales.max() * scales.min()).sqrt()
+                scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
for layer_name in layers_dict:
fc = layers_dict[layer_name]
-                    fc.weight.mul_(scales.view(1, -1))
+                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)

fc.weight.data = get_wquantizer(
self.block_idx,
@@ -110,31 +174,39 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
self.wquantizer,
).fake_quant_weight_dynamic(fc.weight.data)

-                x_tmp = x / scales.view(1, -1)
+                del x_max
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                x_tmp = self.scaling_input(x, scales, is_gqa)

if not check_w_only(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.w_only,
):
-                    x_tmp = get_aquantizer(
-                        self.block_idx,
-                        list(layers_dict.keys())[0],
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.aquantizer,
-                    ).fake_quant_act_dynamic(x_tmp)
-                out = inspect_module(x_tmp, **kwargs)
-
-                if isinstance(out, tuple):
-                    out = out[0]
+                    outs = []
+                    for i in range(x_tmp.shape[0]):
+                        _x = x_tmp[i]
+                        _x = get_aquantizer(
+                            self.block_idx,
+                            list(layers_dict.keys())[0],
+                            self.mix_bits_map,
+                            self.quantizer_mix_bits,
+                            self.aquantizer,
+                        ).fake_quant_act_dynamic(_x)
+                        outs.append(_x)
+                    x_tmp = torch.stack(outs)
+
+                out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

-                loss = (org_out - out).float().pow(2).mean().item()
+                loss = self.calculate_loss(org_out, out)

if len(input) == 1:
n_samples = x.shape[0]
@@ -149,6 +221,11 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
best_error = loss_mean
best_scales = scales_mean

del org_out
del out
gc.collect()
torch.cuda.empty_cache()

# Synchronize across ranks
best_error_tensor = torch.tensor([best_error], device='cuda')
dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
@@ -248,15 +325,28 @@ def subset_transform(
and prev_op[0].out_features != layers[0].in_features * 2
and prev_op[0].out_features != layers[0].in_features
):
-            logger.info('Cannot apply scale. Do not transform this subset.')
-            return

+            if self.has_gqa:
+                is_gqa = True
+                input_keys = list(input_feat.keys())
+                input_name = input_keys[input_keys.index(input_name) - 1]
+            else:
+                logger.info('Cannot apply scale. Do not transform this subset.')
+                return
+        else:
+            is_gqa = False

scale = self.search_scale_subset(
-            layers_dict, input_feat[input_name], inspect_module, subset_kwargs
+            prev_op[0],
+            layers_dict,
+            input_feat[input_name],
+            inspect_module,
+            is_gqa,
+            subset_kwargs
)

self.apply_scale(scale, prev_op, layers)
-        self.update_input_feat(scale, input_feat, layers_dict)
+        self.update_input_feat(scale, input_feat, layers_dict, is_gqa)

if self.save_scale:
for n in layers_dict:
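Taken together, the awq.py changes keep the same grid search: for each candidate ratio, build scales (expanded for GQA when needed), scale and fake-quantise the weights, rescale the inputs, and keep the ratio with the lowest output MSE. A heavily simplified, self-contained sketch of that loop; the per-tensor fake quantiser below is a stand-in, not llmc's quantizer:

import torch

def fake_quant(w, n_bits=4):
    # Stand-in per-tensor quantiser, only for illustration.
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.abs().max() / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

def search_ratio(x, fc, n_grid=20):
    # x: [n_tokens, in_features] activations, fc: [out_features, in_features] weight.
    org_out = x @ fc.t()
    w_max = fc.abs().mean(dim=0)           # per-input-channel weight magnitude
    x_max = x.abs().mean(dim=0)            # per-channel activation magnitude
    best_err, best_scales = float('inf'), None
    for n in range(n_grid):
        ratio = n / n_grid
        scales = x_max.pow(ratio).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()
        w_q = fake_quant(fc * scales.view(1, -1))
        out = (x / scales.view(1, -1)) @ w_q.t()
        err = (org_out - out).pow(2).mean().item()
        if err < best_err:
            best_err, best_scales = err, scales
    return best_scales, best_err

x = torch.randn(256, 64)
fc = torch.randn(128, 64)
scales, err = search_ratio(x, fc)
print(scales.shape, err)
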
56 changes: 48 additions & 8 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -283,15 +283,26 @@ def set_quant_config(self):
assert self.config['model']['type'] in ['Opt', 'Llama']

self.hidden_size = self.model.model_config.hidden_size
if self.online_rotate:
self.num_heads = self.model.model_config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.intermediate_size = self.model.model_config.intermediate_size
self.fp32_had = special_config.get('fp32_had', False)

self.set_model_config()
self.quant_objects = self.quant_config.get('quant_objects', ['language'])
logger.info(f'self.quant_objects : {self.quant_objects}')

def set_model_config(self):
self.hidden_size = self.model.model_config.hidden_size
self.num_heads = self.model.model_config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
if hasattr(self.model.model_config, 'intermediate_size'):
self.intermediate_size = self.model.model_config.intermediate_size
if hasattr(self.model.model_config, 'num_key_value_heads'):
self.num_key_value_heads = self.model.model_config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
if self.num_key_value_groups > 1:
self.has_gqa = True
else:
self.has_gqa = False
else:
self.has_gqa = False

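set_model_config reads the head geometry from the model config once and flags GQA whenever there are more attention heads than key/value heads. A sketch of the same check against an HF-style config object (attribute names assumed):

from types import SimpleNamespace

# Assumed HF-style fields, e.g. a Llama-style layout.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)

head_dim = cfg.hidden_size // cfg.num_attention_heads
num_kv_groups = cfg.num_attention_heads // getattr(cfg, 'num_key_value_heads',
                                                   cfg.num_attention_heads)
has_gqa = num_kv_groups > 1
print(head_dim, num_kv_groups, has_gqa)  # 128 4 True
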
def replace_rotate_linears(self, block):
for n, m in block.named_modules():
if isinstance(m, nn.Linear) and (
@@ -581,6 +592,12 @@ def register_act_qparams(self, layers_dict, act_tensors):
layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())

@torch.no_grad()
def repeat_gqa_scales(self, scales):
scales = scales.view(1, self.num_key_value_heads, self.head_dim)
scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups)
return scales

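repeat_gqa_scales reshapes the flat scale vector to (1, num_key_value_heads, head_dim) and repeats it along the head axis, so each KV head's block of scales is duplicated once per query head in its group. With toy numbers (2 KV heads, group size 2, head_dim 3):

import torch

num_key_value_heads, num_key_value_groups, head_dim = 2, 2, 3
scales = torch.arange(1, 7, dtype=torch.float32)          # [1, 2, 3, 4, 5, 6]

out = scales.view(1, num_key_value_heads, head_dim)
out = torch.repeat_interleave(out, dim=1, repeats=num_key_value_groups)
print(out.reshape(-1))  # tensor([1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6.])
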
@torch.no_grad()
def apply_scale(self, scales, prev_op, layers):
assert (
@@ -652,6 +669,14 @@ def scale_fc_fc(self, fc1, fc2, scales):
fc1.bias.div_(scales.view(-1))

fc1.weight.div_(scales.view(-1, 1))
elif self.has_gqa:
if hasattr(fc1, 'bias') and fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
fc1.weight.div_(scales.view(-1, 1))

if fc1.out_features != fc2.in_features:
logger.info('GQA scale this fc-fc.')
scales = self.repeat_gqa_scales(scales)
else:
logger.error(f'fc1.out_features: {fc1.out_features}')
logger.error(f'fc2.in_features: {fc2.in_features}')
@@ -795,11 +820,26 @@ def bake_mean_into_fc(self, fc):
fc.bias.data = fc.bias.data.to(fc_dtype)

@torch.no_grad()
-    def update_input_feat(self, scale, input_feat, layers_dict):
+    def scaling_input(self, x, scales, is_gqa):
+        if is_gqa:
+            scales_tmp = self.repeat_gqa_scales(scales)
+        else:
+            scales_tmp = scales
+
+        x_tmp = torch.empty_like(x)
+        for i, batch in enumerate(x):
+            batch_scale = scales_tmp.view(1, -1)
+            x_tmp[i] = batch / batch_scale
+
+        return x_tmp
+
+    @torch.no_grad()
+    def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
for layer_name in layers_dict:
for i in range(len(input_feat[layer_name])):
inp = input_feat[layer_name][i]
-                inp.div_(scale.view(1, -1).to(inp.device))
+                scale = scale.to(inp.device)
+                inp = self.scaling_input(inp, scale, is_gqa)

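scaling_input divides the cached activations by the (possibly GQA-expanded) scales one sample at a time into a pre-allocated buffer, and update_input_feat now routes through it instead of an in-place div_. A sketch of the per-sample division and its equivalence to the one-shot version, with illustrative shapes:

import torch

def scale_inputs(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # x: [n_samples, seq_len, channels]; scales: [channels]
    x_scaled = torch.empty_like(x)
    for i, sample in enumerate(x):
        x_scaled[i] = sample / scales.view(1, -1)
    return x_scaled

x = torch.randn(4, 16, 8)
scales = torch.rand(8) + 0.5
assert torch.allclose(scale_inputs(x, scales), x / scales)  # same result as one shot
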
@torch.no_grad()
def set_non_linear_mode(self, quant_format, module, mode):
