Skip to content

Commit

Permalink
Adjust INC to run from vLLM with old PA
Browse files Browse the repository at this point in the history
Change-Id: Ifdea6840aaa22791f478ad10788e5d47fd4a0394
  • Loading branch information
nirda7 committed Aug 1, 2024
1 parent ff114b7 commit c19fcbd
Showing 1 changed file with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
self.fetch_from_cache = mod.fetch_from_cache
self.forward = self.forward_measure

def forward(self, input, cache, block_indices, block_offset):
def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
qinput = self.quant_input(input)
output_cache = self.forward_orig(qinput, cache, block_indices, block_offset)
output_cache = self.forward_orig(qinput, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
return self.quant_output(output_cache)

def forward_measure(self, input, cache, block_indices, block_offset):
def forward_measure(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
measure_input((input), self._mod_extra_config.inputs)
output_cache = self.forward_orig(input, cache, block_indices, block_offset)
output_cache = self.forward_orig(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
measure_output((output_cache), self._mod_extra_config.outputs)
return output_cache

def fetch_from_cache(self, cache, blocks):
def fetch_from_cache(self, cache, blocks, permutations=None):
quant_cache = self.quant_input(cache)
if permutations:
output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations)
for i in range(len(output_cache)):
output_cache[i]=self.quant_output(output_cache[i])
return output_cache
output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
return self.quant_output(output_cache)

Expand Down

0 comments on commit c19fcbd

Please sign in to comment.