Convert to DynamicCache and back
kddubey committed Oct 28, 2024
1 parent c97b761 commit aa75991
Showing 2 changed files with 49 additions and 10 deletions.
47 changes: 43 additions & 4 deletions src/cappr/huggingface/classify.py
@@ -28,6 +28,18 @@
from cappr.huggingface._utils import BatchEncodingPT, ModelForCausalLM


try:
from transformers import DynamicCache
except ImportError: # pragma: no cover
# This is an "empty" type. isinstance and issubclass will always be False
DynamicCache = type("DynamicCache", (object,), {})
dynamic_cache_from_tuple = lambda past_key_values: past_key_values
_IS_DYNAMIC_CACHE_AVAILABLE = False
else:
dynamic_cache_from_tuple = DynamicCache.from_legacy_cache
_IS_DYNAMIC_CACHE_AVAILABLE = True


@classify._token_logprobs
@_batch.flatten
@_batch.batchify(batchable_arg="texts", progress_bar_desc="log-probs")
@@ -191,7 +203,7 @@ def _past_key_values_get(


########################################################################################
############################# KV caching + batch inference #############################
###################################### KV caching ######################################
########################################################################################


@@ -253,8 +265,29 @@ def forward(
encodings = {"input_ids": input_ids, "attention_mask": attention_mask}

if self._cappr.past is None:
# A later version of transformers will deprecate omitting past_key_values or
# passing past_key_values=None during inference, so we need to supply a Cache
# object here.
does_model_support_cache_object = (
"GPT2LMHeadModel" not in self._cappr.model.config.architectures
)
# I don't know why the gpt2 implementation still assumes that
# past_key_values is a tuple. TODO: are there more models to exclude? I
# didn't see anything in the model.config which indicates this.
if _IS_DYNAMIC_CACHE_AVAILABLE and does_model_support_cache_object:
past_key_values = DynamicCache()
else:
past_key_values = None

with hf._utils.set_up_model(self._cappr.model):
out: CausalLMOutputWithPast = self._cappr.model(**encodings)
out: CausalLMOutputWithPast = self._cappr.model(
**encodings, past_key_values=past_key_values
)

if isinstance(out.past_key_values, DynamicCache):
# Currently, the KV cache logic assumes KVs are immutable, so we'll
# always keep the cache in the legacy/_PastKeyValues type
out.past_key_values = out.past_key_values.to_legacy_cache()
self._cappr.past = encodings, out
return out

@@ -311,13 +344,15 @@ def forward(
input_ids=input_ids,
attention_mask=attention_mask_past_cat_present,
position_ids=position_ids_present,
past_key_values=past_key_values,
past_key_values=dynamic_cache_from_tuple(past_key_values),
)
if isinstance(out.past_key_values, DynamicCache):
out.past_key_values = out.past_key_values.to_legacy_cache()

# past_key_values is already concatenated in out.past_key_values
if self._cappr.logits_all:
out.logits = torch.cat([out_past.logits[batch_idxs], out.logits], dim=1)
if self._cappr.update_cache:
# past_key_values is already concatenated in out.past_key_values
# Concatenate encodings for future model calls
input_ids_past = encodings_past["input_ids"][batch_idxs]
encodings = {
@@ -556,6 +591,10 @@ def cache(
assert torch.allclose(logits3, logits_correct(["a b c 1 2 3"]), atol=atol)
assert torch.allclose(logits4, logits_correct(["a b c d"]), atol=atol)
"""
# TODO: Alternate implementation is to use a DynamicCache which, on exit,
# resets/truncates the KV cache to the original set of tokens. Such an interface
# would mimic that of llamacpp. It'd be more memory efficient than keeping two sets
# of KVs. The common data would be shared instead of copied.
try:
past = getattr(getattr(model_and_tokenizer[0], "_cappr"), "past")
except AttributeError:
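
For context, here is a minimal, self-contained sketch (not part of this commit) of the round trip the diff above implements: hand the model a DynamicCache for the forward pass, convert the returned cache to the legacy tuple-of-tuples format that cappr's existing KV slicing and concatenation logic expects, and convert back with from_legacy_cache right before the next call. The checkpoint name below is an illustrative assumption; any causal LM whose forward pass accepts Cache objects works (gpt2 is excluded in the diff above).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

encodings = tokenizer(["a b c"], return_tensors="pt")
with torch.no_grad():
    # Supply a Cache object instead of omitting past_key_values
    out = model(**encodings, past_key_values=DynamicCache())

# Convert to the legacy tuple-of-tuples format so downstream code that slices and
# concatenates KVs keeps working unchanged
legacy_past = out.past_key_values.to_legacy_cache()

# ...and convert back right before the next forward pass
past_for_next_call = DynamicCache.from_legacy_cache(legacy_past)
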
12 changes: 6 additions & 6 deletions tests/huggingface/test_huggingface_classify.py
@@ -376,11 +376,11 @@ def test__logits_completions_given_prompts(
# Test logits for better debuggability
# For this helper function, prompts can't be a single string
if isinstance(prompts, str):
return
slow_out = classify_no_cache._logits_completions_given_prompts(
prompts = [prompts]
slow_out = self.module_correct._logits_completions_given_prompts(
model, tokenizer, prompts, completions
)
fast_out = classify._logits_completions_given_prompts(
fast_out = self.module._logits_completions_given_prompts(
model, tokenizer, prompts, completions
)
_test_encodings(*slow_out, *fast_out)
@@ -433,11 +433,11 @@ def test__logits_completions_given_prompts_examples(
# Test logits for better debuggability
# For this helper function, examples can't be an Example
if isinstance(examples, Example):
return
slow_out = classify_no_cache._logits_completions_given_prompts_examples(
examples = [examples]
slow_out = self.module_correct._logits_completions_given_prompts_examples(
model, tokenizer, examples
)
fast_out = classify._logits_completions_given_prompts_examples(
fast_out = self.module._logits_completions_given_prompts_examples(
model, tokenizer, examples
)
_test_encodings(*slow_out, *fast_out)

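
On the TODO near the end of cache() in classify.py: a rough sketch, assuming the installed transformers version provides DynamicCache.get_seq_length and DynamicCache.crop, of what a llama.cpp-style "truncate on exit" interface could look like. The helper cache_reset_on_exit is hypothetical and not part of cappr.

from contextlib import contextmanager

from transformers import DynamicCache


@contextmanager
def cache_reset_on_exit(past_key_values: DynamicCache):
    # Remember how many tokens are cached before any new text is appended
    original_length = past_key_values.get_seq_length()
    try:
        yield past_key_values
    finally:
        # Truncate back to the original tokens so the common KVs are shared, not copied
        past_key_values.crop(original_length)
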