Commit: Improve poptorch cond checks

katalinic-gc committed Apr 27, 2023
1 parent 77244c5 commit 284a930
Showing 1 changed file with 22 additions and 2 deletions.
optimum/graphcore/generation/attention_mixin.py (22 additions, 2 deletions)
@@ -17,10 +17,23 @@

import torch

+import poptorch
+from transformers.utils.versions import require_version


FLOAT16_LIMIT = 1e4


+def assert_poptorch_supports_cond(context: Optional[str] = None):
+    context = context or ""
+    require_version("poptorch>=3.3", "Require poptorch>=3.3 for `poptorch.cond`. " + context)
+    if not hasattr(poptorch, "cond"):
+        raise AttributeError(
+            "`poptorch.cond` appears to be missing, perhaps you are using a candidate release "
+            "which does not support it yet? " + context
+        )
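For reference, a minimal sketch of how the new guard behaves when called directly; the function name and import path are taken from this diff, while the context string below is illustrative:

    from optimum.graphcore.generation.attention_mixin import assert_poptorch_supports_cond

    # No-op when the installed poptorch is >= 3.3 and exposes `poptorch.cond`.
    # Otherwise it raises: via require_version for a too-old poptorch, or
    # AttributeError for a build that lacks the `cond` attribute.
    assert_poptorch_supports_cond(context="Cross-attention KV caching requested.")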


class IPUAttentionMixin:
"""
The aim of this class is to provide common, model-agnostic functionality such as KV caching and attention
@@ -43,7 +56,7 @@ class IPUAttentionMixin:
    @property
    def kv_cache_initialised(self) -> bool:
        return self._kv_cache_initialised

    @property
    def cross_kv_cache_initialised(self) -> bool:
        return self._cross_kv_cache_initialised
@@ -110,6 +123,9 @@ def from_model(
                uses_beams=num_beams > 1,
            )
        if use_cross_cache:
+            assert_poptorch_supports_cond(
+                context="Cross-attention KV caching has been enabled with `use_cross_cache=True`."
+            )
            clone._create_cross_kv_cache(
                (batch_size * num_beams, clone.num_heads, encoder_max_length, clone.head_dim),
                dtype=dtype,
@@ -193,7 +209,7 @@ def update_attention_mask(self, attention_mask: Optional[torch.Tensor] = None):
            mask = mask + attention_mask

        return mask

    def add_to_cross_kv_cache(
        self,
        cross_input: torch.Tensor,
@@ -213,6 +229,10 @@ def add_to_cross_kv_cache(
-        if not hasattr(poptorch, "cond"):
-            raise AttributeError("Cross KV caching requires `poptorch.cond` which appears to be missing.")
+        assert_poptorch_supports_cond(
+            context="Cross-attention KV caching has been enabled with `use_cross_cache=True`."
+        )

        # For now assume that generation will always start from step 0.
        reset_kv_cache = self._generation_step == 0
        self._cross_k_cache *= 1 - reset_kv_cache.to(self._cross_k_cache.dtype)
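As an aside, the reset line above zeroes the cache with arithmetic on a boolean tensor rather than Python control flow, which keeps the reset inside the compiled graph. A minimal standalone sketch of the same trick (tensor names here are illustrative, not from the source):

    import torch

    cache = torch.ones(2, 3)
    generation_step = torch.tensor(0)
    reset = generation_step == 0          # bool tensor: True on the first step
    cache *= 1 - reset.to(cache.dtype)    # multiply by 0 on step 0, by 1 afterwards
    # cache is now all zeros on step 0 and left unchanged on later steps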
