TF-OPT attention mask fixes #25238

Merged: 3 commits, Sep 6, 2023
Changes from 2 commits
45 changes: 26 additions & 19 deletions src/transformers/models/opt/modeling_tf_opt.py
@@ -61,17 +61,19 @@
 LARGE_NEGATIVE = -1e8


 # Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz = input_ids_shape[0]
     tgt_len = input_ids_shape[1]
-    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
-    mask_cond = tf.range(shape_list(mask)[-1])
-
-    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+    # mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+    # mask_cond = tf.range(shape_list(mask)[-1])
+    #
+    # mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+    # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
+    mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
+    mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)

     if past_key_values_length > 0:
         mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
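As a side note, the band_part construction above can be sanity-checked against a plain numpy triu with k = 1. The snippet below is a standalone sketch, not part of the PR, with an arbitrary illustrative size:

import numpy as np
import tensorflow as tf

LARGE_NEGATIVE = -1e8
tgt_len = 5  # illustrative size only

# The hack from the diff: keep the upper triangle, then subtract the diagonal,
# leaving LARGE_NEGATIVE strictly above the diagonal and 0 everywhere else.
mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)

# Reference: numpy triu with k=1 over the same constant matrix.
reference = np.triu(np.full((tgt_len, tgt_len), LARGE_NEGATIVE, dtype=np.float32), k=1)
np.testing.assert_allclose(mask.numpy(), reference)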
@@ -93,7 +95,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     return (one_cst - expanded_mask) * LARGE_NEGATIVE


-class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+class TFOPTLearnedPositionalEmbedding(tf.keras.layers.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
     """
@@ -516,16 +518,22 @@ def get_input_embeddings(self):
     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
         # create causal mask
         # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-        else:
-            combined_attention_mask = _expand_mask(
-                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
-            )
+        _, seq_length = input_shape
+        tf.debugging.assert_equal(
+            seq_length + past_key_values_length,
+            shape_list(attention_mask)[1],
+            message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
+            f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
+            f" {past_key_values_length}.",
+        )

Review thread on the assert above:

Collaborator: Is this check robust? From the diff it looks like attention_mask can be None.

Member (author): Yes! The TFOPTDecoder layer checks for None attention masks and replaces them with tf.ones. That happens before _prepare_decoder_attention_mask is called. The earlier code had an if attention_mask is not None branch that was just always taken as a result.

-        if attention_mask is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        if seq_length > 1:
+            combined_attention_mask = (
+                _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
+            )
+        else:
+            combined_attention_mask = expanded_attn_mask

         return combined_attention_mask
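To make the review exchange above concrete, here is a minimal standalone sketch of the ordering the author describes: a missing mask is replaced with tf.ones over past plus current tokens before the shape assert and the causal/padding combination run. The function names and the simplified mask-expansion helper are illustrative, not the library's code:

import tensorflow as tf

LARGE_NEGATIVE = -1e8


def expand_mask_sketch(mask, tgt_len):
    # [bsz, src_len] -> [bsz, 1, tgt_len, src_len]: 0 where attended, LARGE_NEGATIVE where masked.
    mask = tf.cast(mask, tf.float32)
    expanded = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
    return (1.0 - expanded) * LARGE_NEGATIVE


def prepare_decoder_attention_mask_sketch(attention_mask, input_shape, past_key_values_length):
    bsz, seq_length = input_shape
    # In the real model, the decoder's call() does this substitution *before* invoking
    # _prepare_decoder_attention_mask, which is why the assert never sees None.
    if attention_mask is None:
        attention_mask = tf.ones((bsz, seq_length + past_key_values_length), dtype=tf.bool)

    tf.debugging.assert_equal(seq_length + past_key_values_length, tf.shape(attention_mask)[1])

    expanded = expand_mask_sketch(attention_mask, tgt_len=seq_length)
    if seq_length > 1:
        # Strictly upper-triangular causal part, built with the same band_part trick as above.
        causal = tf.fill((seq_length, seq_length), tf.cast(LARGE_NEGATIVE, tf.float32))
        causal = tf.linalg.band_part(causal, 0, -1) - tf.linalg.band_part(causal, 0, 0)
        if past_key_values_length > 0:
            causal = tf.concat([tf.zeros((seq_length, past_key_values_length)), causal], axis=-1)
        return causal[None, None, :, :] + expanded
    return expanded


# Batch of 2, 3 new tokens, 4 cached tokens -> combined mask of shape (2, 1, 3, 7).
print(prepare_decoder_attention_mask_sketch(None, (2, 3), past_key_values_length=4).shape)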

@@ -615,17 +623,16 @@ def call(
         inputs_embeds = self.embed_tokens(input_ids)

         if attention_mask is None:
-            attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
         else:
             tf.debugging.assert_equal(
-                tf.shape(attention_mask)[1],
+                shape_list(attention_mask)[1],
                 past_key_values_length + input_shape[1],
                 message=(
-                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
                     f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                 ),
             )

         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

         attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
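Finally, a hedged end-to-end usage sketch of what this hunk enforces; the checkpoint name and the single greedy decoding step are illustrative only. When cached key/values are passed back in, the attention mask has to span past plus current tokens, which the tightened assert above now reports explicitly if violated:

import tensorflow as tf
from transformers import AutoTokenizer, TFOPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello, my dog is", return_tensors="tf")
out = model(**inputs, use_cache=True)

# One greedy step with the cache: the mask passed in must cover past + current tokens.
next_token = tf.cast(tf.argmax(out.logits[:, -1, :], axis=-1)[:, None], tf.int32)
full_mask = tf.concat(
    [inputs["attention_mask"], tf.ones((1, 1), dtype=inputs["attention_mask"].dtype)], axis=-1
)
out2 = model(
    input_ids=next_token,
    attention_mask=full_mask,
    past_key_values=out.past_key_values,
    use_cache=True,
)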