From 842e99f1b9ee2a0fa239997ef695c5ed0bd77195 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 6 Sep 2023 13:37:27 +0100
Subject: [PATCH] TF-OPT attention mask fixes (#25238)

* stash commit

* More OPT updates

* Update src/transformers/models/opt/modeling_tf_opt.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .../models/opt/modeling_tf_opt.py | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py
index cabed90e918778..6c48d6e629273c 100644
--- a/src/transformers/models/opt/modeling_tf_opt.py
+++ b/src/transformers/models/opt/modeling_tf_opt.py
@@ -61,17 +61,15 @@
 LARGE_NEGATIVE = -1e8
 
 
-# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz = input_ids_shape[0]
     tgt_len = input_ids_shape[1]
-    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
-    mask_cond = tf.range(shape_list(mask)[-1])
-
-    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+    # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
+    mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
+    mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)
 
     if past_key_values_length > 0:
         mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
@@ -93,7 +91,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     return (one_cst - expanded_mask) * LARGE_NEGATIVE
 
 
-class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+class TFOPTLearnedPositionalEmbedding(tf.keras.layers.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
     """
@@ -516,16 +514,22 @@ def get_input_embeddings(self):
     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
         # create causal mask
         # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-        else:
-            combined_attention_mask = _expand_mask(
-                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
-            )
+        _, seq_length = input_shape
+        tf.debugging.assert_equal(
+            seq_length + past_key_values_length,
+            shape_list(attention_mask)[1],
+            message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
+            f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
+            f" {past_key_values_length}.",
+        )
 
-        if attention_mask is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        if seq_length > 1:
+            combined_attention_mask = (
+                _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
+            )
+        else:
+            combined_attention_mask = expanded_attn_mask
 
         return combined_attention_mask
 
@@ -615,17 +619,16 @@ def call(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if attention_mask is None:
-            attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
         else:
             tf.debugging.assert_equal(
-                tf.shape(attention_mask)[1],
+                shape_list(attention_mask)[1],
                 past_key_values_length + input_shape[1],
                 message=(
-                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
                     f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                 ),
             )
-
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
         attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)