From 29fdecbbb44ceb8d19c12809af90dc23063becfc Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Wed, 5 Jun 2024 09:24:49 +0800
Subject: [PATCH] Fix dtype of unpacked tensor (#1840)

Signed-off-by: Kaihui-intel
---
 neural_compressor/torch/algorithms/weight_only/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py
index fe2c0ac80af..a69243a8a24 100644
--- a/neural_compressor/torch/algorithms/weight_only/modules.py
+++ b/neural_compressor/torch/algorithms/weight_only/modules.py
@@ -289,7 +289,7 @@ def pack_tensor_with_torch(self, raw_tensor):
     def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
-        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
         mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
         for j in range(packed_tensor.shape[1]):
             for e in range(self.n_pack):