diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index 09d99be2..1ea48b33 100644 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -23,6 +23,16 @@ def __init__(self, model, quant_config, input, padding_mask, config): self.trans = special_config.get('trans', True) self.trans_version = special_config.get('trans_version', 'v2') self.save_scale = special_config.get('save_scale', False) + self.awq_bs = special_config.get('awq_bs', None) + + @torch.no_grad() + def scaling_weight(self, w, scales, is_gqa): + if is_gqa: + scales_tmp = self.repeat_gqa_scales(scales) + else: + scales_tmp = scales + w_tmp = w.mul_(scales_tmp.view(1, -1)) + return w_tmp @torch.no_grad() def get_weight_scale(self, layers_dict): @@ -49,20 +59,82 @@ def get_weight_scale(self, layers_dict): torch.cuda.empty_cache() return scale - @torch.no_grad() def get_act_scale(self, x): - return x.abs().view(-1, x.shape[-1]).mean(0) + batch_means = [] + b_num = x.shape[0] // self._bs + for num in range(b_num): + batch_x = x[num * self._bs:(num + 1) * self._bs] + batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0) + batch_means.append(batch_mean) + final_mean = sum(batch_means) / len(batch_means) + return final_mean + + @torch.no_grad() + def get_scales(self, prev_op, x, w_max, is_gqa, ratio): + if is_gqa: + x_tmp = prev_op(x) + w_tmp = self.get_weight_scale({'prev_op': prev_op}) + else: + x_tmp = x + w_tmp = w_max + + x_tmp = self.get_act_scale(x_tmp) + + if self.trans_version == 'v1': + scales = ( + (x_tmp.pow(ratio) / w_tmp.pow(1 - ratio)) + .clamp(min=1e-4) + .view(-1) + ) + elif self.trans_version == 'v2': + scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1) + + scales = scales / (scales.max() * scales.min()).sqrt() + return scales + + def inspect_module_forward(self, x, inspect_module, kwargs): + outs = [] + b_num = x.shape[0] // self._bs + for num in range(b_num): + _x = x[num * self._bs:(num + 1) * self._bs] + out = inspect_module(_x, **kwargs) + if isinstance(out, tuple): + out = out[0] + outs.append(out) + return torch.cat(outs, dim=0) @torch.no_grad() def get_original_out(self, x, inspect_module, subset_kwargs): with torch.no_grad(): - org_out = inspect_module(x, **subset_kwargs) - if isinstance(org_out, tuple): - org_out = org_out[0] + org_out = self.inspect_module_forward(x, inspect_module, subset_kwargs) return org_out + def calculate_loss(self, org_out, out): + total_loss = 0.0 + b_num = org_out.shape[0] // self._bs + for num in range(b_num): + _org_out = org_out[num * self._bs:(num + 1) * self._bs] + _out = out[num * self._bs:(num + 1) * self._bs] + single_loss = (_org_out - _out).float().pow(2).mean().item() + total_loss += single_loss + return total_loss / b_num + @torch.no_grad() - def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs): + def search_scale_subset( + self, + prev_op, + layers_dict, + input, + inspect_module, + is_gqa, + subset_kwargs + ): + + if self.awq_bs is None: + self._bs = input[0].shape[0] + else: + self._bs = self.awq_bs + w_max = self.get_weight_scale(layers_dict) # grid search for ratio best_error = float('inf') @@ -89,18 +161,10 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) x_max = self.get_act_scale(x) ratio = n * 1 / n_grid - if self.trans_version == 'v1': - scales = ( - (x_max.pow(ratio) / w_max.pow(1 - ratio)) - .clamp(min=1e-4) - .view(-1) - ) - elif self.trans_version == 'v2': - scales = x_max.pow(ratio).clamp(min=1e-4).view(-1) - 
scales = scales / (scales.max() * scales.min()).sqrt() + scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio) for layer_name in layers_dict: fc = layers_dict[layer_name] - fc.weight.mul_(scales.view(1, -1)) + fc.weight = self.scaling_weight(fc.weight, scales, is_gqa) fc.weight.data = get_wquantizer( self.block_idx, @@ -110,7 +174,12 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) self.wquantizer, ).fake_quant_weight_dynamic(fc.weight.data) - x_tmp = x / scales.view(1, -1) + del x_max + gc.collect() + torch.cuda.empty_cache() + + x_tmp = self.scaling_input(x, scales, is_gqa) + if not check_w_only( self.block_idx, list(layers_dict.keys())[0], @@ -118,23 +187,26 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) self.quantizer_mix_bits, self.w_only, ): - x_tmp = get_aquantizer( - self.block_idx, - list(layers_dict.keys())[0], - self.mix_bits_map, - self.quantizer_mix_bits, - self.aquantizer, - ).fake_quant_act_dynamic(x_tmp) - out = inspect_module(x_tmp, **kwargs) - - if isinstance(out, tuple): - out = out[0] + outs = [] + for i in range(x_tmp.shape[0]): + _x = x_tmp[i] + _x = get_aquantizer( + self.block_idx, + list(layers_dict.keys())[0], + self.mix_bits_map, + self.quantizer_mix_bits, + self.aquantizer, + ).fake_quant_act_dynamic(_x) + outs.append(_x) + x_tmp = torch.stack(outs) + + out = self.inspect_module_forward(x_tmp, inspect_module, kwargs) if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]: org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device) - loss = (org_out - out).float().pow(2).mean().item() + loss = self.calculate_loss(org_out, out) if len(input) == 1: n_samples = x.shape[0] @@ -149,6 +221,11 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) best_error = loss_mean best_scales = scales_mean + del org_out + del out + gc.collect() + torch.cuda.empty_cache() + # Synchronize across ranks best_error_tensor = torch.tensor([best_error], device='cuda') dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN) @@ -248,15 +325,28 @@ def subset_transform( and prev_op[0].out_features != layers[0].in_features * 2 and prev_op[0].out_features != layers[0].in_features ): - logger.info('Cannot apply scale. Do not transform this subset.') - return + + if self.has_gqa: + is_gqa = True + input_keys = list(input_feat.keys()) + input_name = input_keys[input_keys.index(input_name) - 1] + else: + logger.info('Cannot apply scale. 
Do not transform this subset.') + return + else: + is_gqa = False scale = self.search_scale_subset( - layers_dict, input_feat[input_name], inspect_module, subset_kwargs + prev_op[0], + layers_dict, + input_feat[input_name], + inspect_module, + is_gqa, + subset_kwargs ) self.apply_scale(scale, prev_op, layers) - self.update_input_feat(scale, input_feat, layers_dict) + self.update_input_feat(scale, input_feat, layers_dict, is_gqa) if self.save_scale: for n in layers_dict: diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 3c9c56d7..bacf1c9f 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -283,15 +283,26 @@ def set_quant_config(self): assert self.config['model']['type'] in ['Opt', 'Llama'] self.hidden_size = self.model.model_config.hidden_size - if self.online_rotate: - self.num_heads = self.model.model_config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.intermediate_size = self.model.model_config.intermediate_size - self.fp32_had = special_config.get('fp32_had', False) - + self.set_model_config() self.quant_objects = self.quant_config.get('quant_objects', ['language']) logger.info(f'self.quant_objects : {self.quant_objects}') + def set_model_config(self): + self.hidden_size = self.model.model_config.hidden_size + self.num_heads = self.model.model_config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + if hasattr(self.model.model_config, 'intermediate_size'): + self.intermediate_size = self.model.model_config.intermediate_size + if hasattr(self.model.model_config, 'num_key_value_heads'): + self.num_key_value_heads = self.model.model_config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + if self.num_key_value_groups > 1: + self.has_gqa = True + else: + self.has_gqa = False + else: + self.has_gqa = False + def replace_rotate_linears(self, block): for n, m in block.named_modules(): if isinstance(m, nn.Linear) and ( @@ -581,6 +592,12 @@ def register_act_qparams(self, layers_dict, act_tensors): layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda()) layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda()) + @torch.no_grad() + def repeat_gqa_scales(self, scales): + scales = scales.view(1, self.num_key_value_heads, self.head_dim) + scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups) + return scales + @torch.no_grad() def apply_scale(self, scales, prev_op, layers): assert ( @@ -652,6 +669,14 @@ def scale_fc_fc(self, fc1, fc2, scales): fc1.bias.div_(scales.view(-1)) fc1.weight.div_(scales.view(-1, 1)) + elif self.has_gqa: + if hasattr(fc1, 'bias') and fc1.bias is not None: + fc1.bias.div_(scales.view(-1)) + fc1.weight.div_(scales.view(-1, 1)) + + if fc1.out_features != fc2.in_features: + logger.info('GQA scale this fc-fc.') + scales = self.repeat_gqa_scales(scales) else: logger.error(f'fc1.out_features: {fc1.out_features}') logger.error(f'fc2.in_features: {fc2.in_features}') @@ -795,11 +820,26 @@ def bake_mean_into_fc(self, fc): fc.bias.data = fc.bias.data.to(fc_dtype) @torch.no_grad() - def update_input_feat(self, scale, input_feat, layers_dict): + def scaling_input(self, x, scales, is_gqa): + if is_gqa: + scales_tmp = self.repeat_gqa_scales(scales) + else: + scales_tmp = scales + + x_tmp = torch.empty_like(x) + for i, batch in enumerate(x): + batch_scale = scales_tmp.view(1, -1) 
+ x_tmp[i] = batch / batch_scale + + return x_tmp + + @torch.no_grad() + def update_input_feat(self, scale, input_feat, layers_dict, is_gqa): for layer_name in layers_dict: for i in range(len(input_feat[layer_name])): inp = input_feat[layer_name][i] - inp.div_(scale.view(1, -1).to(inp.device)) + scale = scale.to(inp.device) + inp = self.scaling_input(inp, scale, is_gqa) @torch.no_grad() def set_non_linear_mode(self, quant_format, module, mode): diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index af3b7bba..472f1579 100644 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -17,8 +17,6 @@ def __init__(self, bit, symmetric, granularity, **kwargs): elif self.granularity == 'per_head': self.head_num = self.kwargs['head_num'] - self.mse_b_num = self.kwargs.get('mse_b_num', 1) - if self.kwargs.get('ste', False): self.round_func = lambda x: (x.round() - x).detach() + x else: @@ -33,12 +31,17 @@ def __init__(self, bit, symmetric, granularity, **kwargs): self.sigmoid = torch.nn.Sigmoid() # mse config + self.mse_b_num = self.kwargs.get('mse_b_num', 1) self.maxshrink = self.kwargs.get('maxshrink', 0.8) self.mse_grid = self.kwargs.get('mse_grid', 100) # hist config self.bins = self.kwargs.get('bins', 2048) self.hist_threshold = self.kwargs.get('hist_threshold', 1) + self.dst_nbins = 2**bit + self.upsample_rate = ( + 16 # used to reduce quantization errors when upscaling histogram + ) # hqq config self.lp_norm = self.kwargs.get('lp_norm', 0.7) @@ -83,7 +86,7 @@ def get_tensor_range(self, tensor, args={}): elif self.calib_algo == 'learnable': return self.get_learnable_range(tensor, **args) else: - raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}') + return self.get_minmax_range(tensor) def get_hist_range(self, stats_min_max, act_stats_hist): clip_val = {} @@ -108,11 +111,12 @@ def get_hist_range(self, stats_min_max, act_stats_hist): torch.min(stats_min_max[input_idx]['min']), torch.max(stats_min_max[input_idx]['max']), ] - runing_min_vals, runing_max_vals = [], [] + + moving_min_vals, moving_max_vals = [], [] for input_idx, tensor_range in clip_val.items(): - runing_min_vals.append(tensor_range[0]) - runing_max_vals.append(tensor_range[1]) - return runing_min_vals, runing_max_vals + moving_min_vals.append(tensor_range[0]) + moving_max_vals.append(tensor_range[1]) + return moving_min_vals, moving_max_vals def get_minmax_range(self, tensor): if self.granularity == 'per_tensor': @@ -233,11 +237,9 @@ def get_minmax_stats(self, act_tensors): return stats_min_max def get_static_minmax_range(self, act_tensors): - act_tensors = self.reshape_batch_tensors(act_tensors) stats_min_max = self.get_minmax_stats(act_tensors) min_vals, max_vals = [], [] - for input_idx, tensor_range in stats_min_max.items(): min_val = tensor_range['min'].mean() max_val = tensor_range['max'].mean() @@ -246,101 +248,292 @@ def get_static_minmax_range(self, act_tensors): return min_vals, max_vals - def get_static_hist_range(self, act_tensors): - - act_tensors = self.reshape_batch_tensors(act_tensors) + def get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. 
+ + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def get_quantization_error(self, histogram, min_val, max_val, next_start_bin, next_end_bin): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (max_val.item() - min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode='floor'), + 0, + self.dst_nbins - 1, + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width - stats_min_max = stats_min_max = self.get_minmax_stats(act_tensors) - act_stats_hist = {} - for input_idx, tensors in enumerate(act_tensors): - for tensor in tensors: - data_max = max( - torch.max(stats_min_max[input_idx]['max']), - -torch.min(stats_min_max[input_idx]['min']), - ) - hist = torch.histc( - torch.abs(tensor), bins=int(self.bins), min=0, max=data_max - ) + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode='floor'), + 0, + self.dst_nbins - 1, + ) + density = histogram / bin_width - if input_idx not in act_stats_hist: - act_stats_hist[input_idx] = [hist] - else: - act_stats_hist[input_idx].append(hist) + norm = torch.zeros(self.bins, device=histogram.device) - for input_idx, hist in act_stats_hist.items(): - act_stats_hist[input_idx] = torch.stack(hist).sum(0) + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self.get_norm( + delta_begin, + torch.ones(self.bins, device=histogram.device) * delta_end, + density, + ) - runing_min_vals, runing_max_vals = self.get_hist_range( - stats_min_max, act_stats_hist + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self.get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density ) - return runing_min_vals, runing_max_vals - def get_static_runing_minmax_range(self, act_tensors, alpha): - act_tensors = self.reshape_batch_tensors(act_tensors) - runing_min_vals, runing_max_vals = [], [] - for tensors in act_tensors: - runing_min_val, runing_max_val = None, None - for tensor in tensors: - tensor = self.reshape_tensor(tensor) - tensor_range = self.get_minmax_range(tensor) - min_val, max_val = tensor_range[0], tensor_range[1] + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self.get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _upscale_histogram(self, histogram, orig_min, orig_max, update_min, update_max): + # this turns the histogram into a more fine-coarsed histogram to reduce + # bin quantization errors + histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate + bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate) + mid_points_histogram = ( + torch.linspace( + orig_min, + orig_max, + self.bins * self.upsample_rate + 1, + 
device=orig_min.device, + )[:-1].to(histogram.device) + + 0.5 * bin_size + ) + boundaries_new_histogram = torch.linspace( + update_min, update_max, self.bins + 1, device=update_min.device + ).to(histogram.device) + # this maps the mid-poits of the histogram to the new histogram's space + bucket_assignments = ( + torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) + - 1 + ) + # this then maps the histogram mid-points in the new space, + # weighted by the original histogram's values + # this is just the old histogram in the new histogram's space - if runing_min_val is None or runing_max_val is None: - runing_min_val = min_val - runing_max_val = max_val - else: - runing_min_val = runing_min_val + alpha * (min_val - runing_min_val) - runing_max_val = runing_max_val + alpha * (max_val - runing_max_val) - runing_min_vals.append(runing_min_val) - runing_max_vals.append(runing_max_val) + # In case due to numerical issues the values land higher/lower than the maximum/minimum + bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1 + bucket_assignments[bucket_assignments < 0] = 0 - return runing_min_vals, runing_max_vals + update_histogram = torch.bincount( + bucket_assignments, weights=histogram, minlength=self.bins + ) + return update_histogram + + def _combine_histograms( + self, orig_hist, orig_min, orig_max, update_hist, update_min, update_max + ): + # If the new min and max are the same as the current min and max, + # we can just add the new histogram to the original histogram + if update_min == orig_min and update_max == orig_max: + return orig_hist + update_hist + + # If the orig hist only has one value (i.e., the min and max are the same) + # we can just add it into new histogram + if orig_min == orig_max: + bin_value = torch.sum(update_hist) + transformed_orig_hist = ( + torch.histc(orig_min, + bins=self.bins, + min=update_min, + max=update_max) # type: ignore[arg-type] + * bin_value + ) + return transformed_orig_hist + update_hist + + # We assume the update_hist is already in the target range, we will map the orig_max to it + assert update_min <= orig_min + assert update_max >= orig_max + + # Now we need to turn the old_histogram, into the range of the new histogram + transformed_orig_hist = self._upscale_histogram( + orig_hist, + orig_min, + orig_max, + update_min, + update_max, + ) - def get_static_mse_range(self, act_tensors, norm=2.4): - act_tensors = self.reshape_batch_tensors(act_tensors) - stats_min_max = self.get_minmax_stats(act_tensors) - best_min_vals, best_max_vals = [], [] + return update_hist + transformed_orig_hist + + def get_hist_threshold(self, histogram, min_val, max_val): + + assert histogram.size()[0] == self.bins, 'bins mismatch' + bin_width = (max_val - min_val) / self.bins + + # cumulative sum + total = torch.sum(histogram).item() + cSum = torch.cumsum(histogram, dim=0) + + stepsize = 1e-8 + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float('inf') + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + left = start_bin + right = end_bin + while left < end_bin and cSum[left] < next_alpha * total: + left = left + 1 + while right > start_bin and cSum[right] > next_beta * total: + right = right - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (left - start_bin) > (end_bin - right): + # move the start bin + 
next_start_bin = left + alpha = next_alpha + else: + # move the end bin + next_end_bin = right + beta = next_beta - for input_idx, tensor_range in stats_min_max.items(): - _min_val = tensor_range['min'].mean() - _max_val = tensor_range['max'].mean() - _tensor = torch.stack(act_tensors[input_idx]).float() + if next_start_bin == start_bin and next_end_bin == end_bin: + continue - best = float('inf') - best_min_val, best_max_val = _min_val, _max_val - dev = _tensor.device + # calculate the quantization error using next_start_bin and next_end_bin + norm = self.get_quantization_error(histogram, + min_val, + max_val, + next_start_bin, + next_end_bin) - for i in range(int(self.maxshrink * self.mse_grid)): - p = 1 - i / self.mse_grid + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin - xmin = p * _min_val - xmax = p * _max_val + new_min = min_val + bin_width * start_bin + new_max = min_val + bin_width * (end_bin + 1) + return new_min, new_max - if self.quant_type == 'float-quant' and not self.use_qtorch: - clip_tensor, scales = self.get_float_qparams( - _tensor, (xmin, xmax), dev - ) - zeros, qmin, qmax = 0, None, None - q_tensor = self.quant_dequant( - clip_tensor, scales, zeros, qmax, qmin + def get_static_hist_range(self, act_tensors): + act_tensors = self.reshape_batch_tensors(act_tensors) + stats_min_max = self.get_minmax_stats(act_tensors) + min_vals, max_vals = [], [] + histograms = [] + for input_idx, tensors in enumerate(act_tensors): + min_val, max_val = None, None + histogram = torch.zeros(self.bins) + tensor_range = stats_min_max[input_idx] + for idx, tensor in enumerate(tensors): + tensor = tensor.float() + x_min, x_max = tensor_range['min'][idx], tensor_range['max'][idx] + if min_val is None or max_val is None: + new_histogram = torch.histc( + tensor, self.bins, min=x_min.item(), max=x_max.item() ) + histogram.detach_().resize_(new_histogram.shape) + histogram.copy_(new_histogram) + min_val, max_val = x_min, x_max else: - scales, zeros, qmax, qmin = self.get_qparams((xmin, xmax), dev) - q_tensor = self.quant_dequant(_tensor, scales, zeros, qmax, qmin) + current_min, current_max = min_val, max_val + update_min, update_max = x_min, x_max + new_min = torch.min(current_min, update_min) + new_max = torch.max(current_max, update_max) + + update_histogram = torch.histc( + tensor, self.bins, min=new_min.item(), max=new_max.item() + ).to(histogram.device) + + if new_min == current_min and new_max == current_max: + combined_histogram = histogram + update_histogram + histogram.detach_().resize_(combined_histogram.shape) + histogram.copy_(combined_histogram) + else: + combined_histogram = self._combine_histograms( + histogram, + current_min, + current_max, + update_histogram, + new_min, + new_max, + ) + histogram.detach_().resize_(combined_histogram.shape) + histogram.copy_(combined_histogram) + + min_val, max_val = new_min, new_max - q_tensor -= _tensor - q_tensor.abs_() - q_tensor.pow_(norm) - err = torch.sum(q_tensor) + min_vals.append(min_val) + max_vals.append(max_val) + histograms.append(histogram) - if err < best: - best_min_val, best_max_val = xmin, xmax + new_min_vals, new_max_vals = [], [] + for i in range(len(histograms)): + histogram = histograms[i] + min_val, max_val = min_vals[i], max_vals[i] + new_min, new_max = self.get_hist_threshold( + histogram, min_val, max_val + ) + new_min_vals.append(new_min) + new_max_vals.append(new_max) - best_min_vals.append(best_min_val) - best_max_vals.append(best_max_val) + return new_min_vals, 
new_max_vals - return best_min_vals, best_max_vals + def get_static_moving_minmax_range(self, act_tensors, alpha): + act_tensors = self.reshape_batch_tensors(act_tensors) + moving_min_vals, moving_max_vals = [], [] + for tensors in act_tensors: + moving_min_val, moving_max_val = None, None + for tensor in tensors: + tensor = self.reshape_tensor(tensor) + tensor_range = self.get_minmax_range(tensor) + min_val, max_val = tensor_range[0], tensor_range[1] + + if moving_min_val is None or moving_max_val is None: + moving_min_val = min_val + moving_max_val = max_val + else: + moving_min_val = moving_min_val + alpha * (min_val - moving_min_val) + moving_max_val = moving_max_val + alpha * (max_val - moving_max_val) + moving_min_vals.append(moving_min_val) + moving_max_vals.append(moving_max_val) + + return moving_min_vals, moving_max_vals def get_qparams(self, tensor_range, device): min_val, max_val = tensor_range[0], tensor_range[1] @@ -361,14 +554,17 @@ def get_qparams(self, tensor_range, device): def get_batch_tensors_qparams(self, act_tensors, alpha=0.01, args={}): scales_list, zeros_list, qmin_list, qmax_list = [], [], [], [] - if self.calib_algo == 'hist': + if self.calib_algo == 'static_hist': + assert ( + self.sym is True and self.granularity == 'per_tensor' + ), 'Only support per tensor static symmetric.' min_vals, max_vals = self.get_static_hist_range(act_tensors) - elif self.calib_algo == 'minmax': + elif self.calib_algo == 'static_minmax': min_vals, max_vals = self.get_static_minmax_range(act_tensors) - elif self.calib_algo == 'runing_minmax': - min_vals, max_vals = self.get_static_runing_minmax_range(act_tensors, alpha) - elif self.calib_algo == 'mse': - min_vals, max_vals = self.get_static_mse_range(act_tensors) + elif self.calib_algo == 'static_moving_minmax': + min_vals, max_vals = self.get_static_moving_minmax_range(act_tensors, alpha) + else: + raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}') for i in range(len(min_vals)): min_val, max_val = min_vals[i], max_vals[i] @@ -396,8 +592,6 @@ def optimize_weights_proximal(self, tensor, scales, zeros, qmax, qmin): current_beta *= current_kappa current_error = float(torch.abs(tensor - W_r).mean()) - logger.info(f'iter : {i}, error : {current_error}') - if current_error < best_error: best_error = current_error else: diff --git a/llmc/compression/quantization/tesseraq.py b/llmc/compression/quantization/tesseraq.py index e128373f..f81f0d76 100644 --- a/llmc/compression/quantization/tesseraq.py +++ b/llmc/compression/quantization/tesseraq.py @@ -1,3 +1,4 @@ +import copy import functools import gc import math @@ -144,6 +145,36 @@ def get_original_out(self, block): if self.aug_loss: self.ori_out2 = self.block_forward(block) + @torch.no_grad() + def collect_block_qparams(self, block, input_feat): + named_linears = self.model.get_block_linears(block) + for n, m in named_linears.items(): + args = {} + if hasattr(m, 'buf_lowbound_factor'): + args['lowbound_factor'] = m.buf_lowbound_factor + if hasattr(m, 'buf_upbound_factor'): + args['upbound_factor'] = m.buf_upbound_factor + ( + tensor, + scales, + zeros, + max_int, + min_int, + ) = self.wquantizer.get_tensor_qparams(m.weight.data, args=args) + m.register_buffer('buf_scales', scales) + m.register_buffer('buf_zeros', zeros) + m.register_buffer('buf_qmax', torch.tensor(max_int).to(self.dev)) + m.register_buffer('buf_qmin', torch.tensor(min_int).to(self.dev)) + + if self.act_static: + subsets = self.model.get_subsets_in_block(block) + for index, subset in enumerate(subsets): 
+ layers_dict = subset['layers'] + input_name = subset['input'][0] + input_tensors = copy.deepcopy(input_feat[input_name]) + self.register_act_qparams(layers_dict, input_tensors) + del input_tensors + @torch.no_grad() def block_transform(self, block, input_feat, block_kwargs): logger.info(f'Start transform the {self.block_idx+1}-th block') @@ -163,7 +194,7 @@ def block_transform(self, block, input_feat, block_kwargs): if self.weight_clip: self.tesseraq_weight_clip(block, input_feat) - self.collect_block_qparams(block) # collect quant range after transformation + self.collect_block_qparams(block, input_feat) # collect quant range after transformation self.register_tesseraq_parameters(block) self.tesseraq_train(block)
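
Notes on the patch (hedged sketches for review, not part of the diff; any name or shape not in the patch is illustrative).

The batched statistics added to `get_act_scale`, `inspect_module_forward`, and `calculate_loss` trade one large reduction for a loop over chunks of `awq_bs` samples. With equal-sized chunks the mean of per-chunk means equals the original full mean; also note that `b_num = x.shape[0] // self._bs` drops any remainder, so `awq_bs` should evenly divide the number of calibration samples. A minimal standalone sketch of the chunked mean, assuming activations shaped `[n_samples, seq_len, hidden]`:

```python
import torch

def batched_act_scale(x: torch.Tensor, bs: int) -> torch.Tensor:
    # Mean |activation| per channel, computed in chunks of `bs` samples.
    # Matches x.abs().view(-1, x.shape[-1]).mean(0) when bs divides x.shape[0],
    # because every chunk then contributes the same number of rows.
    means = []
    for start in range(0, x.shape[0], bs):
        chunk = x[start:start + bs]
        means.append(chunk.abs().view(-1, chunk.shape[-1]).mean(0))
    return torch.stack(means).mean(0)

x = torch.randn(8, 16, 64)  # [samples, seq_len, hidden] -- illustrative shapes
full = x.abs().view(-1, x.shape[-1]).mean(0)
print(torch.allclose(batched_act_scale(x, bs=2), full, atol=1e-6))  # True
```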
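
The GQA path relies on `repeat_gqa_scales` (added in base_blockwise_quantization.py) to expand per-KV-head scales to the full set of query heads before `scaling_weight` and `scaling_input` apply them. A small sketch of that expansion with made-up head counts:

```python
import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4   # illustrative sizes
groups = num_heads // num_kv_heads            # query heads sharing each KV head

kv_scales = torch.arange(num_kv_heads * head_dim, dtype=torch.float32)

# Same expansion as repeat_gqa_scales: view per KV head, then repeat each
# KV head's scales once for every query head in its group.
expanded = torch.repeat_interleave(
    kv_scales.view(1, num_kv_heads, head_dim), repeats=groups, dim=1
)
print(expanded.shape)        # torch.Size([1, 8, 4]) == [1, num_heads, head_dim]
print(expanded.reshape(-1))  # [0..3] repeated 4x, then [4..7] repeated 4x
```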
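
Scaling the weight columns up and the input down by the same per-channel factors (what `scaling_weight` and `scaling_input` do in the non-GQA path) leaves the float matmul unchanged; the grid search only changes what the fake-quantizer sees. A quick check of that invariant with illustrative shapes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 64)           # activations [tokens, in_features]
w = torch.randn(128, 64)         # weight [out_features, in_features]
scales = torch.rand(64) + 0.5    # per-input-channel scales from one grid point

w_scaled = w * scales.view(1, -1)   # weight columns scaled up
x_scaled = x / scales.view(1, -1)   # activations scaled down

print(torch.allclose(x @ w.t(), x_scaled @ w_scaled.t(), atol=1e-5))  # True
```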
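
The histogram calibration added to quant.py appears to follow the approach of PyTorch's HistogramObserver: keep one running histogram, re-bin the old histogram whenever the observed range widens (`_upscale_histogram` / `_combine_histograms`), then shrink [min, max] by minimizing the simulated quantization error (`get_hist_threshold`). The re-binning step, reduced to a standalone toy with made-up bin counts and ranges:

```python
import torch

bins, upsample = 8, 4  # toy sizes; the patch uses self.bins and upsample_rate = 16

def upscale_histogram(hist, orig_min, orig_max, new_min, new_max):
    # Split each source bin into `upsample` sub-bins, then bucketize the
    # sub-bin mid-points into the destination grid -- i.e. express the old
    # histogram in the new, wider range while preserving total mass.
    hist = hist.repeat_interleave(upsample) / upsample
    fine_width = (orig_max - orig_min) / (bins * upsample)
    mids = orig_min + (torch.arange(bins * upsample) + 0.5) * fine_width
    edges = torch.linspace(new_min, new_max, bins + 1)
    idx = (torch.bucketize(mids, edges, right=True) - 1).clamp(0, bins - 1)
    return torch.bincount(idx, weights=hist, minlength=bins)

old = torch.tensor([0., 0., 4., 4., 4., 4., 0., 0.])  # histogram over [0, 1]
new = upscale_histogram(old, 0.0, 1.0, 0.0, 2.0)      # same data over [0, 2]
print(new, new.sum())  # total mass 16.0 preserved, now in the lower half
```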
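
`get_static_moving_minmax_range` (the renamed `runing_minmax` path, now selected with `calib_algo: static_moving_minmax`) is an exponential moving average of per-batch ranges. A condensed per-tensor sketch that ignores the `reshape_tensor`/granularity handling:

```python
import torch

def moving_minmax(tensors, alpha=0.01):
    # EMA of per-batch min/max, mirroring get_static_moving_minmax_range.
    mv_min = mv_max = None
    for t in tensors:
        b_min, b_max = t.min(), t.max()
        if mv_min is None:
            mv_min, mv_max = b_min, b_max
        else:
            mv_min = mv_min + alpha * (b_min - mv_min)
            mv_max = mv_max + alpha * (b_max - mv_max)
    return mv_min, mv_max

batches = [torch.randn(16, 64) for _ in range(10)]
lo, hi = moving_minmax(batches, alpha=0.1)
print(lo.item(), hi.item())
```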