From c19fcbdbcc5640ea424fd11622fc7d79410dada9 Mon Sep 17 00:00:00 2001 From: Nir David Date: Thu, 1 Aug 2024 19:48:57 +0300 Subject: [PATCH] Adjust INC to run from vLLM with old PA Change-Id: Ifdea6840aaa22791f478ad10788e5d47fd4a0394 --- .../fp8_quant/_quant_common/helper_modules.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index eca5fd658e8..8d457f6e704 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs): self.fetch_from_cache = mod.fetch_from_cache self.forward = self.forward_measure - def forward(self, input, cache, block_indices, block_offset): + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset): qinput = self.quant_input(input) - output_cache = self.forward_orig(qinput, cache, block_indices, block_offset) + output_cache = self.forward_orig(qinput, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset) return self.quant_output(output_cache) - def forward_measure(self, input, cache, block_indices, block_offset): + def forward_measure(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset): measure_input((input), self._mod_extra_config.inputs) - output_cache = self.forward_orig(input, cache, block_indices, block_offset) + output_cache = self.forward_orig(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset) measure_output((output_cache), self._mod_extra_config.outputs) return output_cache - def fetch_from_cache(self, cache, blocks): + def fetch_from_cache(self, cache, blocks, permutations=None): quant_cache = self.quant_input(cache) + if permutations: + output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations) + for i in range(len(output_cache)): + output_cache[i]=self.quant_output(output_cache[i]) + return output_cache output_cache = self.orig_fetch_from_cache(quant_cache, blocks) return self.quant_output(output_cache)