From 14868c0900a1f91fe39f138c67156ad66c16b20f Mon Sep 17 00:00:00 2001 From: n1ck-guo <110074967+n1ck-guo@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:18:09 +0800 Subject: [PATCH] Refact gptq to support runing on gaudi (#1700) * gptq support for gaudi Signed-off-by: n1ck-guo --- .../torch/algorithms/weight_only/gptq.py | 21 +++++++++++++++---- .../torch/utils/auto_accelerator.py | 16 ++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 5c5d68a4f72..53bee017076 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -34,6 +34,7 @@ from .modules import WeightOnlyLinear DEBUG = False +accelerator = auto_detect_accelerator() # ================ device related =================== @@ -542,8 +543,10 @@ def forward(layer, *args, **kwargs): if self.run_fn: if self.run_args: self.run_fn(self.model, *self.run_args) + accelerator.mark_step() else: self.run_fn(self.model) + accelerator.mark_step() else: for batch in tqdm(self.dataloader): if not self.use_layer_wise: @@ -663,6 +666,7 @@ def tmp(_, inp, out): for j in range(batch_num): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) + accelerator.mark_step() out = transformer_block(*cache_positional_batch, **cache_keyword_batch) out = self.track_hidden_states(out) self.cache_key_arguments["batch_num"] = batch_num @@ -682,6 +686,9 @@ def tmp(_, inp, out): W = load_value(self.model, full_layer_name + ".weight", model_path) else: W = sub_layers[layer_name].weight.data.clone() + accelerator.mark_step() + if "hpu" in self.device: + W = W.to("cpu") scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( W, blocksize=weight_config_this_layer["block_size"], @@ -854,6 +861,8 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F self.quantizer.find_params(W, weight=True) H = self.H + if "hpu" in self.device: + H = H.to("cpu") del self.H dead = torch.diag(H) == 0 H[dead, dead] = 1 @@ -958,6 +967,10 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F zero.append(self.quantizer.zero) scale = torch.cat(scale, dim=1) zero = torch.cat(zero, dim=1) + if "hpu" in self.device: + scale = scale.to(self.device) + zero = zero.to(self.device) + Q = Q.to(self.device) return scale, zero, Q def free(self): @@ -973,25 +986,25 @@ def free(self): class Quantizer(nn.Module): def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer("maxq", torch.tensor(0)) + self.maxq = 0 self.register_buffer("scale", torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape)) def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False): for k, v in weight_config_this_layer.items(): setattr(self, k, v) - self.maxq = torch.tensor(2**self.bits - 1) + # self.maxq = torch.tensor(2**self.bits - 1) + self.maxq = 2**self.bits - 1 self.scheme = "sym" if self.sym else "asym" self.double_quant_scheme = "sym" if self.double_quant_sym else "asym" self.norm = norm self.grid = grid self.maxshrink = maxshrink if trits: - self.maxq = torch.tensor(-1) + self.maxq = -1 def find_params(self, x, weight=False): dev = x.device - self.maxq = self.maxq.to(dev) # NF4 FP4 if self.dtype != "int": from .utility import quant_tensor diff --git a/neural_compressor/torch/utils/auto_accelerator.py b/neural_compressor/torch/utils/auto_accelerator.py index 2887f9166ec..7e59f00e180 100644 --- a/neural_compressor/torch/utils/auto_accelerator.py +++ b/neural_compressor/torch/utils/auto_accelerator.py @@ -29,8 +29,11 @@ import torch +from neural_compressor.common.utils import LazyImport from neural_compressor.torch.utils import logger +htcore = LazyImport("habana_frameworks.torch.core") + PRIORITY_HPU = 100 PRIORITY_CUDA = 95 PRIORITY_CPU = 90 @@ -133,6 +136,10 @@ def empty_cache(self): def synchronize(self): pass + @abstractmethod + def mark_step(self): + pass + @register_accelerator(name="cpu", priority=PRIORITY_CPU) class CPU_Accelerator(Auto_Accelerator): @@ -167,6 +174,9 @@ def empty_cache(self): def synchronize(self): pass + def mark_step(self): + pass + @register_accelerator(name="cuda", priority=PRIORITY_CUDA) class CUDA_Accelerator(Auto_Accelerator): @@ -203,6 +213,9 @@ def device(self, device_index=None): def empty_cache(self): return torch.cuda.empty_cache() + def mark_step(self): + pass + @register_accelerator(name="hpu", priority=PRIORITY_HPU) class HPU_Accelerator(Auto_Accelerator): @@ -244,6 +257,9 @@ def device(self, device_index=None): def empty_cache(self): return torch.hpu.empty_cache() + def mark_step(self): + return htcore.mark_step() + def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator: # Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...