diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index eca5fd658e8..8d457f6e704 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
         self.fetch_from_cache = mod.fetch_from_cache
         self.forward = self.forward_measure
 
-    def forward(self, input, cache, block_indices, block_offset):
+    def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
         qinput = self.quant_input(input)
-        output_cache = self.forward_orig(qinput, cache, block_indices, block_offset)
+        output_cache = self.forward_orig(qinput, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
         return self.quant_output(output_cache)
 
-    def forward_measure(self, input, cache, block_indices, block_offset):
+    def forward_measure(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
         measure_input((input), self._mod_extra_config.inputs)
-        output_cache = self.forward_orig(input, cache, block_indices, block_offset)
+        output_cache = self.forward_orig(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
         measure_output((output_cache), self._mod_extra_config.outputs)
         return output_cache
 
-    def fetch_from_cache(self, cache, blocks):
+    def fetch_from_cache(self, cache, blocks, permutations=None):
         quant_cache = self.quant_input(cache)
+        if permutations:
+            output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations)
+            for i in range(len(output_cache)):
+                output_cache[i]=self.quant_output(output_cache[i])
+            return output_cache
         output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
         return self.quant_output(output_cache)
 
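
For context, here is a minimal standalone sketch of the new `fetch_from_cache` dispatch introduced in the hunk above: when `permutations` is supplied, the original fetch is treated as returning a list of tensors that must be de-quantized element-wise; otherwise the legacy single-tensor path is kept. The helpers below (`quant_input`, `quant_output`, `orig_fetch_from_cache`) are identity/placeholder stand-ins, not the real INC quant/dequant or vLLM cache ops; only the control flow mirrors the patch.

```python
# Sketch only: the three helpers are placeholders assumed for illustration;
# the branch structure is what the patched fetch_from_cache adds.

def quant_input(x):
    return x  # placeholder for the module's input quantizer


def quant_output(x):
    return x  # placeholder for the module's output de-quantizer


def orig_fetch_from_cache(cache, blocks, permutations=None):
    # placeholder: with permutations, return one slice per block; else a single slice
    if permutations is not None:
        return [cache[b] for b in blocks]
    return cache[blocks[0]]


def fetch_from_cache(cache, blocks, permutations=None):
    quant_cache = quant_input(cache)
    if permutations:
        # new path: the original fetch returns a list, so de-quantize each element
        output_cache = orig_fetch_from_cache(quant_cache, blocks, permutations)
        return [quant_output(o) for o in output_cache]
    # legacy path: single tensor in, single tensor out
    return quant_output(orig_fetch_from_cache(quant_cache, blocks))


if __name__ == "__main__":
    cache = {0: "k0", 1: "k1"}
    print(fetch_from_cache(cache, [0]))                       # legacy path  -> 'k0'
    print(fetch_from_cache(cache, [0, 1], permutations=[1]))  # permuted path -> ['k0', 'k1']
```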