
Big TF test cleanup #24282

Merged: 19 commits merged into main on Jun 16, 2023
Conversation

Rocketknight1 (Member)

Now that we've done a big overhaul of the TF model internals, a lot of tests can be fixed. Several tests were disabled for being buggy or too slow - these are almost all performant now, so I re-enabled them. Runtime for the re-enabled tests was 15-20 seconds on my local machine.

Also, we had a number of TF test failures in the daily CI. I think this PR should fix all of them, except for two cases:

Firstly, some models have issues with resize_token_embeddings. These failures are caused by the transition to TFSharedEmbedding that @gante is currently working on, and I didn't want to interfere! The usual cause is that resize_token_embeddings replaces the new-style TFSharedEmbedding with an old-style tf.Variable.

Secondly, there are a couple of failures in generate tests. I'm also leaving this to @gante because he knows much more about that code than me 😅

HuggingFaceDocBuilderDev commented Jun 14, 2023

The documentation is not available anymore as the PR was closed or merged.

decoder_input_ids: tf.Tensor | None = None,
decoder_attention_mask: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
foutput_attentions: Optional[bool] = None,
Member

[screenshot: 2023-06-14 at 18:34:30]

Member Author

ahahaha

@gante (Member) left a comment

LGTM 👍

gante (Member) commented Jun 14, 2023

(@Rocketknight1 ping me if the gen tests are not sorted after the latest push)

@ydshieh (Collaborator) left a comment

Thank you for all this TF work!

Other than making sure (all) the re-enabled tests pass now (I guess you already checked them), I have just 2 nit comments.

Comment on lines 1160 to 1162
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self._prune_signature(self.input_signature))
Collaborator

Maybe explain a bit why we don't do this in the init method?

Collaborator

+1

Member Author

It's a bit of a long story! _set_save_spec is normally called internally by Keras, and for subclassed models (i.e. all models in transformers) it uses the first input shapes the model sees. This was a huge problem for us, because we'd pass in some tiny dummy inputs and it would just lock in that useless specific shape as the model's save spec. This made exporting/serving a real nightmare!

We avoid that by setting a correct, general save spec before the model has seen any inputs. It doesn't really matter in most cases whether we put that in the __init__ or the build method, as long as it happens before we pass dummy inputs in. However, there is one edge case where it makes a small difference: If the user builds a model from a config, and then passes inputs of a specific shape in. In this case, putting it in build() allows the user to set the save spec with their own inputs, which can be useful in a couple of cases.

I'm not convinced this is a perfect solution, but it resolves an edge case in our in-graph tokenizer test, so it seems a little better than the alternative!
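
For readers unfamiliar with the mechanism, here is a minimal, hedged sketch of the idea (the toy class below is illustrative, not the actual transformers implementation; _set_save_spec is the private Keras API that the quoted diff calls):

import tensorflow as tf


class ToyModel(tf.keras.Model):
    # Illustrative stand-in for transformers' input_signature property: fully general
    # (unknown batch and sequence length) rather than tied to tiny dummy inputs.
    input_signature = {"input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids")}

    def build(self, input_shape=None):
        # Register the general spec *before* any concrete dummy inputs are traced, so
        # Keras doesn't lock in the dummy shapes as the serving/save spec.
        self._set_save_spec(self.input_signature)
        self.built = True

    def call(self, input_ids):
        return tf.shape(input_ids)

With the spec registered up front, the export/serving signature uses the general (None, None) shape instead of whatever dummy batch happened to be seen first.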

Collaborator

So a short version could be

# putting this in build() allows the user to set the save spec with their own inputs.

😄 ...?

Member Author

Sure! I'll comment something like that

Member Author

Done!

@@ -463,19 +463,12 @@ def _prepare_decoder_attention_mask(
) -> tf.Tensor:
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask: tf.Tensor | None = None
if input_shape[-1] > 1:
Collaborator

We don't need to check this condition anymore...?

Collaborator

+1

Member Author

I don't think so! I couldn't see a case where input_shape[-1] == 0 was possible.

Collaborator

I mean > 1 vs == 1, not == 0.

Member Author

Ah, you're totally right, I don't know how I blanked on that! Let me fix it.

Member Author

Done, and sorry for the extremely embarrassing oversight where my eyes kept reading > as >=!
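
To make the boundary case concrete, here is a hedged, self-contained illustration (not the transformers code) of why a query length of 1 needs no causal mask at all:

import tensorflow as tf


def make_causal_mask(tgt_len: int) -> tf.Tensor:
    # 0.0 where attention is allowed, a large negative number where it is masked.
    allowed = tf.linalg.band_part(tf.ones((tgt_len, tgt_len)), -1, 0)
    return (1.0 - allowed) * -1e9


print(make_causal_mask(1))  # [[0.]] -- masks nothing, so the seq_len == 1 branch can skip it
print(make_causal_mask(3))  # strict upper triangle is -1e9, everything else is 0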

Comment on lines -228 to -231
@tooslow
def test_saved_model_creation(self):
pass

Collaborator

This is still a slow test, and I would like to know if this re-enabled test passes now.

Member Author

It passes!

Comment on lines -695 to -698
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation_extended(self):
pass
Collaborator

Would like to know if this re-enabled test passes now.

Member Author

saved_model_creation_extended is now a core test that is only run on a few models because it's very expensive, so this skip is no longer needed.

Comment on lines -295 to -297
def test_xla_mode(self):
# TODO JP: Make LED XLA compliant
pass
Collaborator

Would like to know if this re-enabled test passes now.

Member Author

It passes!

@amyeroberts (Collaborator) left a comment

Thanks for cleaning up!

Overall, changes look good to me.

  • Big +1 to all of @ydshieh's comments
  • For just the affected models, could you run the slow tests this PR changes? In particular, test_saved_model_creation_extended?
  • Could you run a generation test with speech to text to make sure the embeddings reshaping is working?

@@ -490,6 +490,7 @@ def test_model_without_retriever(self):
inputs_dict = self.config_and_inputs
self.check_model_without_retriever(**inputs_dict)

@slow
Collaborator

Are these related to failures in the generate tests, or did they just become slow?

If failures, could you add a unittest.skip decorator instead?

Member Author

These aren't failures! They're just extremely slow generation tests (with retrieval!), and were sometimes triggering the 120s timeout in the live CI.

Comment on lines 1160 to 1162
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self._prune_signature(self.input_signature))
Collaborator

+1

# idempotent. TF doesn't need that caching anyway, since it can just store constants during compilation,
# so we just remove all of that code.
embeddings = self._get_embedding(
self.padding_idx + 1 + seq_len + self.offset + past_key_values_length, self.embedding_dim, self.padding_idx
Collaborator

This isn't exactly the same, as past_key_values_length wasn't added before. This seems more correct, but can we run some tests on generation to make sure this works as expected?

Member Author

Hi @amyeroberts, you're right! It's actually okay to sometimes generate too many embeddings, though, because the embeddings tensor is only transiently created in this function, gathered from and then discarded again. I ran the slow tests for this model and all passed.

Member Author

...though as I write this, I realize that all of this code is just a speed hack because eager Torch code can't optimize or compute things out of order, so really I should just directly transform the position IDs into the embeddings and skip the whole gather!
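
As a hedged sketch of that simplification (illustrative only, not the actual Speech2Text code): compute the sinusoidal embeddings for exactly the positions needed, offset by past_key_values_length, instead of building an oversized table and gathering from it.

import tensorflow as tf


def sinusoidal_embeddings(position_ids: tf.Tensor, embedding_dim: int) -> tf.Tensor:
    # position_ids: 1-D int tensor of positions; embedding_dim assumed even.
    positions = tf.cast(position_ids, tf.float32)[:, None]
    inv_freq = tf.exp(
        tf.cast(tf.range(0, embedding_dim, 2), tf.float32) * -(tf.math.log(10000.0) / embedding_dim)
    )
    angles = positions * inv_freq[None, :]
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)


# e.g. embeddings for the current chunk, offset by the cached length:
# sinusoidal_embeddings(tf.range(past_key_values_length, past_key_values_length + seq_len), dim)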

)
dec_attn_mask = upper_mask + lower_mask
else:
dec_attn_mask = upper_mask
Collaborator

This is a lot easier to understand :)

combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length)
if attention_mask is None:
return combined_attention_mask
else:
Collaborator

ultra nit: if we return in an if statement, we don't need the else

Member Author

Fixed!
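
For illustration, a small self-contained sketch of the early-return pattern being suggested (names are generic, not the transformers helpers):

from typing import Optional

import tensorflow as tf


def combine_masks(causal_mask: tf.Tensor, padding_mask: Optional[tf.Tensor]) -> tf.Tensor:
    if padding_mask is None:
        return causal_mask
    # No `else` needed: the early return above already covered the None case.
    return causal_mask + padding_mask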

@@ -463,19 +463,12 @@ def _prepare_decoder_attention_mask(
) -> tf.Tensor:
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask: tf.Tensor | None = None
if input_shape[-1] > 1:
Collaborator

+1

pass

@slow
def test_saved_model_creation_extended(self):
Collaborator

Does this run now with the default test_saved_model_creation_extended test?

Member Author

test_saved_model_creation_extended is now a core test that is only run on a few models because it's very expensive, so this skip is no longer needed.

@amyeroberts (Collaborator) left a comment

Thanks for iterating!

I just have one question to address re the refactoring of _prepare_decoder_attention_mask before merging

Comment on lines -314 to -315
@slow
def test_keras_fit(self):
Collaborator

Just double-checking that this is now fast for this model?

Member Author

Actually, good question. On my local machine, this test takes 20-40 seconds depending on the model. MobileBERT is one of the slower ones, but it's still inside that range.

However, 20-40 seconds is probably in the range where the whole test should be marked as slow to keep it out of the quick CI, right?

Collaborator

We have some tests (not decorated as slow) that run for more than 40 seconds, but let's not add more such tests. Decorate it as slow and everyone's life is easier 🍺

Collaborator

🍻 cheers to that!
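
For reference, a hedged sketch of the suggested decoration, using the slow marker from transformers.testing_utils (the test class name is illustrative; in the real suite the test body lives in the shared tester mixin):

import unittest

from transformers.testing_utils import slow


class TFSomeModelTest(unittest.TestCase):
    @slow  # skipped unless RUN_SLOW=1, i.e. only the scheduled slow CI runs it
    def test_keras_fit(self):
        ...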

return combined_attention_mask
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length)
combined_attention_mask = tf.cond(
input_shape[-1] > 1, lambda: combined_attention_mask, lambda: tf.ones_like(combined_attention_mask)
Collaborator

Is this completely equivalent?

If I've understood before and after correctly, if attention_mask is not None and input_shape[-1] == 1, then in the old case:

combined_attention_mask = expand_attention_mask

and in the new:

combined_attention_mask = expand_attention_mask + tf.ones_like(combined_attention_mask)

i.e. an additional matrix of 1s is added

Member Author

Ah, I think you're right! Hang on, let me see what I can do.

Member Author

(This issue, like most of our issues, is caused by me assuming that tests passing = no problems)

Member Author

Investigation complete! So basically, I can't actually make this function reproduce the old behaviour when we're compiling with flexible shapes, because it's just totally forbidden in TF to have a conditional where one branch returns None and the other branch returns a tf.Tensor.

However, when I actually followed the code through to self-attention where the attention mask is used, an attention mask of None is just treated as an all-ones mask (i.e. neither affects the attention logits at all). Therefore, returning all-ones instead of None yields the same model outputs, while obeying TF's requirements for compiling conditionals.
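
A hedged sketch of why that substitution is safe (illustrative only; the real attention code uses its own helpers, but the additive-bias idea is the same):

import tensorflow as tf


def mask_bias(attention_mask, dtype=tf.float32):
    # The mask is folded into the attention logits as an additive bias: allowed
    # positions contribute 0, masked positions a large negative number.
    if attention_mask is None:
        return tf.constant(0.0, dtype=dtype)
    return (1.0 - tf.cast(attention_mask, dtype)) * tf.constant(-1e9, dtype=dtype)


# mask_bias(None) is 0 and mask_bias(tf.ones((2, 4))) is all zeros, so the attention
# scores -- and therefore the model outputs -- are identical in both cases.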

Rocketknight1 (Member Author)

I think everything has been addressed now, but I'm not going to merge this one today because there's another PR affecting our tests (#24301) and ideally I'd like to be able to separately view their impact on the CI!

Rocketknight1 merged commit 3403712 into main on Jun 16, 2023
Rocketknight1 deleted the big_tf_test_cleanup branch on Jun 16, 2023 at 14:40
ydshieh (Collaborator) commented Jun 16, 2023

> I think everything has been addressed now, but I'm not going to merge this one today

Nice 👍.

I never merge PRs on Friday evening or early afternoon. I don't want to get a ☎️ ⚡!

ydshieh (Collaborator) commented Jun 16, 2023

Wait, you merged ...!? (but you said you are not going to merge 🤔 )
