
Stop storing references to bound methods via tf.function #24146

Merged 12 commits into main from fix_tf_model_garbage_collection on Jun 13, 2023

Conversation

@Rocketknight1 (Member) commented Jun 9, 2023

This is (hopefully!) the end of a long saga this week.

@ydshieh noticed that our test runners were going OOM after a couple of PRs I made to dummy inputs. I thought the problem was just that the new dummy inputs were too large, but eventually we figured out that the problem was actually quite complicated!

tl;dr: A circular reference exists, caused by us calling `tf.function()` on a model method and then storing the result as a model attribute. Because this reference exists, our TF models are not cleaned up immediately when they are deleted, but only after the next Python garbage collection.
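The delayed cleanup can be reproduced in plain Python, without TensorFlow: any object that wraps a bound method and is stored back on the instance (which is what `self.serving = tf.function(self.eager_serving)` does) forms a cycle that only the garbage collector can break. A minimal sketch, using illustrative `Model`/`eager_serving` names rather than the actual transformers code:

```python
import gc
import weakref


class Model:
    def __init__(self):
        # Storing a bound method on the instance creates a cycle:
        # self -> self.serving -> bound method -> self.
        # tf.function(self.eager_serving) keeps the same hard reference.
        self.serving = self.eager_serving

    def eager_serving(self, inputs):
        return inputs


model = Model()
ref = weakref.ref(model)
del model

print(ref() is None)  # False: the cycle keeps the model alive
gc.collect()
print(ref() is None)  # True: freed only by the cycle collector
```

This is exactly the window the OOM lived in: between `del model` and the next collector run, the model's memory is still held.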

I believe the PRs triggered the issue by eliminating unnecessary calls and making TF model building much faster. This left less time for garbage collection to happen, and as a result our test suites started a second test before the first test had been cleaned up, which caused the test runner to go OOM.

We tried resolving this problem by manually calling `gc.collect()` before each test, but this made some of the test suites much slower! Obviously the real solution had to be to resolve the circular reference that was slowing down model cleanup.

The solution is to replace `model.eager_serving` with a method `model._get_eager_serving_fn()`. This returns a function that TensorFlow can compile, but which doesn't create a hard reference to a model method in the returned `tf.function`. I confirmed through manual inspection with `gc.get_referrers` that the reference is removed and models are now cleaned up immediately once they go out of scope.

See the update below for a full description of the solution I finally went with!

@Rocketknight1 Rocketknight1 requested a review from amyeroberts June 9, 2023 16:55
@Rocketknight1 Rocketknight1 force-pushed the fix_tf_model_garbage_collection branch from 50f6831 to b138eea on June 9, 2023 16:57
@Rocketknight1 Rocketknight1 requested review from ydshieh and gante June 9, 2023 17:00
@HuggingFaceDocBuilderDev commented Jun 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@ydshieh (Collaborator) commented Jun 9, 2023

No OOM with this PR (for the models involved), but we have some errors like ``TypeError: Binding inputs to tf.function `eager_serving` failed due to `missing a required argument: 'inputs'` `` popping up for several model/tokenizer tests.

One example is

self = <tests.models.gpt2.test_modeling_tf_gpt2.TFGPT2ModelTest testMethod=test_saved_model_creation>

    @slow
    def test_saved_model_creation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = False
        config.output_attentions = False
    
        if hasattr(config, "use_cache"):
            config.use_cache = False
    
        model_class = self.all_model_classes[0]
    
        class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        model = model_class(config)
    
        model(class_inputs_dict)
    
        with tempfile.TemporaryDirectory() as tmpdirname:
>           model.save_pretrained(tmpdirname, saved_model=True)

tests/test_modeling_tf_common.py:268: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_tf_utils.py:2427: in save_pretrained
    self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70: in error_handler
    raise e.with_traceback(filtered_tb) from None
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tensorflow.python.eager.polymorphic_function.function_spec.FunctionSpec object at 0x7f6e1c4db7f0>
args = ({'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'input_ids': TensorSpec(sha...tf.int32, name='input_ids'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')},), kwargs = {}

    def bind_function_inputs(self, args, kwargs):
      """Bind `args` and `kwargs` into a canonicalized signature args, kwargs."""
      sanitized_kwargs = {
          function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items()
      }
      if len(kwargs) != len(sanitized_kwargs):
        raise ValueError(f"Name collision after sanitization. Please rename "
                         f"tf.function input parameters. Original: "
                         f"{sorted(kwargs.keys())}, Sanitized: "
                         f"{sorted(sanitized_kwargs.keys())}")
    
      try:
        bound_arguments = self.function_type.bind_with_defaults(
            args, sanitized_kwargs, self.default_values)
      except Exception as e:
>       raise TypeError(
            f"Binding inputs to tf.function `{self._name}` failed due to `{e}`."
            f"Received args: {args} and kwargs: {sanitized_kwargs} for signature:"
            f" {self.function_type}."
        ) from e
E       TypeError: Binding inputs to tf.function `eager_serving` failed due to `missing a required argument: 'inputs'`.Received args: ({'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids'), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')},) and kwargs: {} for signature: (self, inputs: Dict(mapping={'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids'), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')})).

@Rocketknight1 (Member, Author)

Looks like the way I'm handling the methods fails when we try to save the model with those signatures. I'll figure it out on Monday!

@ydshieh ydshieh mentioned this pull request Jun 9, 2023
@Rocketknight1 (Member, Author)

This should be ready for review now! The changes are pretty small, but it took me a while to figure out the details.

It turns out that anything that looks like `self.serving = tf.function(self.eager_serving)` will create a circular reference between `self.serving` and `self`, inhibiting cleanup. This does not apply to methods defined at the class (rather than instance) level. Something like this is fine and does not block cleanup:

@tf.function(input_signature=...)
def serving(self, inputs):
    ...

The problem with the construction above, though, is that the `tf.function` decorator has to be called with all of its arguments at the class level, before the model has been initialized with a config. This means it can't read any shapes or details from the config, so its signature has to be very general. This is why we transitioned to `self.serving = ...` in the first place.
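The contrast between the two patterns can be shown with a stand-in for `tf.function` (a plain wrapper class, since the cleanup behaviour only depends on who holds a hard reference to whom):

```python
import gc
import weakref


class FakeTfFunction:
    """Stand-in for tf.function: holds a hard reference to a callable."""

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)


class InstanceLevel:
    def __init__(self):
        # self -> self.serving -> bound method -> self: a cycle.
        self.serving = FakeTfFunction(self.eager_serving)

    def eager_serving(self, inputs):
        return inputs


class ClassLevel:
    # Wraps the plain (unbound) function at class definition time;
    # no instance is referenced, so no cycle is created.
    @FakeTfFunction
    def serving(self, inputs):
        return inputs


a, b = InstanceLevel(), ClassLevel()
ref_a, ref_b = weakref.ref(a), weakref.ref(b)
del a, b

print(ref_a() is None)  # False: cleanup blocked until gc.collect()
print(ref_b() is None)  # True: freed immediately
```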

The solution I found is the following:

  • Get rid of all helper methods like `self.eager_serving`. These were only used internally anyway, to allow us to compile multiple serving signatures.
  • Decorate the base `serving` method with `tf.function` and no signature at all.
  • Rely on our control of `self.save_spec` to ensure that base TF methods like `model.save()` will save with the right signature even when we aren't manually defining it (I checked this and it works!)
  • When we want to manually specify signatures, we just call `self.serving.get_concrete_function` with different signatures. No need to keep `eager_serving` around anymore!
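The bare-decorator-plus-concrete-function steps above can be sketched with plain TensorFlow, on a toy `tf.Module` rather than the actual transformers code: decorate with `tf.function` and no `input_signature`, then derive whatever concrete signatures are needed on demand.

```python
import tensorflow as tf


class TinyModel(tf.Module):
    # Bare tf.function at the class level: no input_signature here,
    # and no per-instance reference back to the model.
    @tf.function
    def serving(self, inputs):
        return {"logits": tf.cast(inputs, tf.float32) * 2.0}


model = TinyModel()

# When a specific save signature is needed, derive it on demand instead
# of keeping a separately compiled eager_serving around:
concrete = model.serving.get_concrete_function(
    tf.TensorSpec(shape=(None, 4), dtype=tf.int32, name="input_ids")
)
print(concrete(tf.ones((2, 4), dtype=tf.int32))["logits"].shape)  # (2, 4)
```

The resulting concrete function can be passed to `tf.saved_model.save(..., signatures=...)`, which is how specific save signatures are preserved without any per-instance `tf.function` attribute.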

This should fully preserve functionality and backward compatibility, while resolving the memory cleanup issue and keeping the specific save signatures. The only potentially noticeable change is that `self.serving.input_signature` is no longer defined. We read that value in a couple of tests as a shortcut to find the model input names, so I just replaced it with `self.input_signature` instead. I don't think anyone outside of Hugging Face was using it, and it certainly wasn't part of our public API, so I don't expect any issues!

@Rocketknight1 (Member, Author)

Thanks to @ydshieh for his patience with the tests and to @gante for digging out the old PRs that let me finally understand why a lot of this stuff was ever here in the first place!

@ydshieh (Collaborator) commented Jun 13, 2023

OK, I will run a few tests and let you know @Rocketknight1.
Thank you for trying!

@ydshieh (Collaborator) left a comment

Overall LGTM. Thank you.

Instead of removing `_prune_signature` and using `self.input_signature` directly, shouldn't we use a modified version like `input_signature_to_use = self._prune_signature(self.input_signature)`?

But you are HF's version of François C., I believe whatever you say!

@Rocketknight1 (Member, Author)

@ydshieh actually, you're right - I thought it wasn't doing anything anymore, but it's still useful in some cases when we define a broad signature that gets inherited. Let me rework that so we keep it!

@Rocketknight1 Rocketknight1 force-pushed the fix_tf_model_garbage_collection branch from 3d12414 to e371f62 on June 13, 2023 12:44
@ydshieh (Collaborator) commented Jun 13, 2023

No warning signs after running tests for the 4 involved models. You are very good at TF!

@Rocketknight1 (Member, Author)

@ydshieh I finished rebasing and I removed your `gc.collect()` change to the doctests. Are you okay for me to merge now, or do you want to run any further tests?

Either way, I think we've finally resolved this one!

@ydshieh (Collaborator) commented Jun 13, 2023

It's OK, go ahead. If the doctests start to fail, I will call you.

@Rocketknight1 (Member, Author)

Also pinging @amyeroberts for core maintainer review

@gante (Member) left a comment

LGTM 🔥

To be safe, can you trigger the slow CI on this branch? Most of the TF serialization tests are slow tests :D

@ydshieh (Collaborator) commented Jun 13, 2023

> LGTM 🔥
>
> To be safe, can you trigger the slow CI on this branch? Most of the TF serialization tests are slow tests :D

Hello @gante !

Do you mean enabling slow tests for all the models, or something else?
We can't run slow tests on CircleCI, however, so they need to run on a specific VM.

@Rocketknight1 (Member, Author)

Good point, actually - this PR isn't specific to any one model, so we'd need to run slow tests for all models. Since it's a long time until the next release, let's just merge this PR (after review) and see if anything fails overnight?

@amyeroberts (Collaborator)

@Rocketknight1 @ydshieh Could we run slow tests on just a handful of models (~5 popular ones from different modalities) to make sure any obvious issues have been caught?

@Rocketknight1 (Member, Author)

@amyeroberts I've been running them locally on the core models - BERT and GPT-2 look good! Are you okay if I try a few more and then merge if there are no issues?

@amyeroberts (Collaborator) left a comment

Thanks for working on this!

Like all good refactors, the code is so much cleaner ❤️ Overall, looks good. The only thing needed before merge is handling the deprecation of `eager_serving`.

@@ -1226,12 +1222,12 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
return head_mask

def eager_serving(self, inputs):
Collaborator

We can't just kill a public method like this, unfortunately. We'll need a deprecation cycle and a warning that it's going to be removed in two versions' time.

Member Author

Understood! I'll put it back and mark it as deprecated.

Collaborator

Agree in general, but is `eager_serving` intended to be used by users?

Member Author

Added it back with a warning!

Collaborator

@ydshieh Nope, not intended, but it's a public method on all of our TF models, so removing it is a breaking change.

@amyeroberts (Collaborator) left a comment

Thanks for iterating!

@Rocketknight1 With just a few extra slow runs on different models, e.g. CLIP and Wav2Vec2, I think we're good to go 👍

@Rocketknight1 (Member, Author)

Tested BERT, GPT-2, BART, ViT, CLIP and Wav2Vec2 without issues. Merging!

@Rocketknight1 Rocketknight1 merged commit 3bd1fe4 into main Jun 13, 2023
@Rocketknight1 Rocketknight1 deleted the fix_tf_model_garbage_collection branch June 13, 2023 18:04
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…#24146)

* Stop storing references to bound methods in tf.functions

* Remove the gc.collect calls now that we resolved the underlying problem

* Remove the default signature from model.serving entirely, big cleanup

* Remove _prune_signature as self.input_signature can prune itself

* Restore serving docstring

* Update int support test to check the input signature

* Make sure other tests also use model.input_signature and not serving.input_signature

* Restore _prune_signature

* Remove the doctest GC now it's no longer needed

* Correct core tests to use the pruned sig

* order lines correctly in core tests

* Add eager_serving back with a deprecation warning