From 357261786548c64e16225acfe2c275733588d3fc Mon Sep 17 00:00:00 2001 From: Nir David Date: Thu, 25 Jul 2024 10:30:10 +0300 Subject: [PATCH] [SW-194177] - Integrate new vllm-PA algo with HQT Change-Id: I94c9679f0aff7c2f9a86a802da825bfd6d0772ad --- .../algorithms/fp8_quant/_quant_common/helper_modules.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index 435d4389199..79399c90815 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -530,12 +530,10 @@ def forward_measure(self, input, cache, block_indices, block_offset): measure_output((output_cache), self._mod_extra_config.outputs) return output_cache - def fetch_from_cache(self, cache, blocks, permutations): + def fetch_from_cache(self, cache, blocks): quant_cache = self.quant_input(cache) - output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations) - for i in range(len(output_cache)): - output_cache[i]=self.quant_output(output_cache[i]) - return output_cache + output_cache = self.orig_fetch_from_cache(quant_cache, blocks) + return self.quant_output(output_cache) class PatchedConv2d(nn.Conv2d):