From ce8be2148ca71ffd77bd516cd841a2e5d60af3b6 Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Tue, 17 Dec 2024 16:44:33 +0800
Subject: [PATCH] Add gptq chunk support

---
 .../combination/quarot_comb_gptq/w8a8/step_2_gptq.yml |  1 +
 llmc/compression/quantization/gptq.py                 | 11 +++++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/configs/quantization/combination/quarot_comb_gptq/w8a8/step_2_gptq.yml b/configs/quantization/combination/quarot_comb_gptq/w8a8/step_2_gptq.yml
index a43305ab..2f12cbfe 100644
--- a/configs/quantization/combination/quarot_comb_gptq/w8a8/step_2_gptq.yml
+++ b/configs/quantization/combination/quarot_comb_gptq/w8a8/step_2_gptq.yml
@@ -42,6 +42,7 @@ quant:
         static_groups: True
         percdamp: 0.01
         blocksize: 128
+        chunk_num: 4
         true_sequential: True
         online_rotate: True
         fp32_had: True
diff --git a/llmc/compression/quantization/gptq.py b/llmc/compression/quantization/gptq.py
index 18d54822..263d9db0 100644
--- a/llmc/compression/quantization/gptq.py
+++ b/llmc/compression/quantization/gptq.py
@@ -42,6 +42,7 @@ def add_quant_config(self):
         self.blocksize = special_config['blocksize']
         self.owq = special_config.get('owq', False)
+        self.chunk_num = special_config.get('chunk_num', 1)
         if self.owq:
             self.n_outs = special_config['n_outs']
@@ -275,12 +276,18 @@ def add_batch(self, layer, name, inp, out):
             inp = inp.permute([1, 0, 2])
         inp = inp.flatten(1)
 
+        assert inp.shape[1] % self.chunk_num == 0, \
+            f'Error: inp.shape[1] ({inp.shape[1]}) cannot be evenly divided by chunk_num.'
+        chunks = torch.chunk(inp, self.chunk_num, dim=1)
+
         self.layers_cache[name]['H'] *= self.layers_cache[name]['nsamples'] / (
             self.layers_cache[name]['nsamples'] + tmp
         )
         self.layers_cache[name]['nsamples'] += tmp
-        inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
-        self.layers_cache[name]['H'] += inp.matmul(inp.t())
+
+        for chunk in chunks:
+            chunk = math.sqrt(2 / self.layers_cache[name]['nsamples']) * chunk.float()
+            self.layers_cache[name]['H'] += chunk.matmul(chunk.t())
 
         dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
         dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
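
Note (illustration, not part of the patch): the chunked loop in add_batch accumulates the same Hessian as the single matmul it replaces, since splitting the token dimension of inp into column chunks gives X @ X.T = sum_i C_i @ C_i.T, while only one chunk is upcast to float32 at a time. Below is a minimal standalone sketch of that equivalence; the sizes hidden_dim, n_tokens, nsamples and the chunk_num value are hypothetical example values, not taken from the patch.

# Standalone sketch: chunked vs. single-matmul Hessian accumulation.
import math

import torch

hidden_dim, n_tokens, chunk_num = 64, 512, 4
inp = torch.randn(hidden_dim, n_tokens, dtype=torch.float16)  # features x calibration tokens
nsamples = 8  # stands in for the running sample count kept by add_batch

# Before the patch: upcast and scale the whole activation, then one large matmul.
scaled = math.sqrt(2 / nsamples) * inp.float()
H_ref = scaled.matmul(scaled.t())

# After the patch: split the token dimension and accumulate partial outer products.
assert n_tokens % chunk_num == 0
H = torch.zeros(hidden_dim, hidden_dim)
for chunk in torch.chunk(inp, chunk_num, dim=1):
    chunk = math.sqrt(2 / nsamples) * chunk.float()  # only this slice is held in float32
    H += chunk.matmul(chunk.t())

print(torch.allclose(H, H_ref, atol=1e-3))  # True: same Hessian, lower peak memory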