Skip to content

Commit

Permalink
TF-OPT attention mask fixes (huggingface#25238)
Browse files Browse the repository at this point in the history
* stash commit

* More OPT updates

* Update src/transformers/models/opt/modeling_tf_opt.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and EduardoPach committed Nov 18, 2023
1 parent 6bccf4f commit 260c5a3
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions src/transformers/models/opt/modeling_tf_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,15 @@
LARGE_NEGATIVE = -1e8


# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])

mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
# We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)

if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
Expand All @@ -93,7 +91,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFOPTLearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
Expand Down Expand Up @@ -516,16 +514,22 @@ def get_input_embeddings(self):
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
# create causal mask
# # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
_, seq_length = input_shape
tf.debugging.assert_equal(
seq_length + past_key_values_length,
shape_list(attention_mask)[1],
message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}.",
)

if attention_mask is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
if seq_length > 1:
combined_attention_mask = (
_make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
)
else:
combined_attention_mask = expanded_attn_mask

return combined_attention_mask

Expand Down Expand Up @@ -615,17 +619,16 @@ def call(
inputs_embeds = self.embed_tokens(input_ids)

if attention_mask is None:
attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
else:
tf.debugging.assert_equal(
tf.shape(attention_mask)[1],
shape_list(attention_mask)[1],
past_key_values_length + input_shape[1],
message=(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
),
)

pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
Expand Down

0 comments on commit 260c5a3

Please sign in to comment.