From da1ada236eb867b69c663c58904e0a21ad9bcb88 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Mon, 3 Jun 2024 13:14:06 +0800 Subject: [PATCH] Fix WOQ Linear pack slow issue (#1828) Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/modules.py | 55 ++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py index 768aa0c2fdc..fe2c0ac80af 100644 --- a/neural_compressor/torch/algorithms/weight_only/modules.py +++ b/neural_compressor/torch/algorithms/weight_only/modules.py @@ -19,6 +19,7 @@ # since the model classes inherit torch.nn.Module. import math +import numpy as np import torch from torch.autograd import Function from torch.nn import functional as F @@ -270,7 +271,7 @@ def recover(self): fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]] return fp32_weight - def pack_tensor(self, raw_tensor): + def pack_tensor_with_torch(self, raw_tensor): target_len = math.ceil(raw_tensor.shape[1] / self.n_pack) packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device) mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device) @@ -285,7 +286,7 @@ def pack_tensor(self, raw_tensor): accelerator.synchronize() return packed_tensor - def unpack_tensor(self, packed_tensor): + def unpack_tensor_with_torch(self, packed_tensor): target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8 target_len = packed_tensor.shape[1] * self.n_pack unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device) @@ -302,6 +303,56 @@ def unpack_tensor(self, packed_tensor): accelerator.synchronize() return unpacked_tensor + def pack_tensor_with_numpy(self, raw_tensor): + raw_array = raw_tensor.cpu().numpy() + target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int) + torch.int32 + target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype + packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype) + mask = np.uint8(2**self.bits - 1) + for j in range(packed_array.shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = raw_array[:, start:end].astype(target_dtype) + tmp &= mask + for e in range(tmp.shape[1]): + tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e) + packed_array[:, j] |= tmp[:, e] + accelerator.synchronize() + packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device) + return packed_tensor + + def unpack_tensor_with_numpy(self, packed_tensor): + packed_array = packed_tensor.cpu().numpy() + target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8 + target_len = packed_array.shape[1] * self.n_pack + unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype) + mask = np.uint8(2**self.bits - 1) + for j in range(packed_array.shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + tmp = packed_array[:, j] + tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1)) + tmp = np.right_shift(tmp, self.compress_bits - self.bits) + if target_dtype == np.uint8: + tmp &= mask + unpacked_array[:, index] = tmp.astype(target_dtype) + accelerator.synchronize() + unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device) + return unpacked_tensor + + def pack_tensor(self, raw_tensor): + if "cuda" in self.device: + return self.pack_tensor_with_torch(raw_tensor) + else: + return self.pack_tensor_with_numpy(raw_tensor) + + def unpack_tensor(self, packed_tensor): + if "cuda" in self.device: + return self.unpack_tensor_with_torch(packed_tensor) + else: + return self.unpack_tensor_with_numpy(packed_tensor) + def forward(self, input): if not hasattr(self, "weight"): weight = self.recover()