Stop storing references to bound methods via tf.function #24146
Conversation
Force-pushed from 50f6831 to b138eea.
No OOM with this PR (for the models involved), but we have some errors in the saved-model tests. One example is test_saved_model_creation for TFGPT2ModelTest:

self = <tests.models.gpt2.test_modeling_tf_gpt2.TFGPT2ModelTest testMethod=test_saved_model_creation>
    @slow
    def test_saved_model_creation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = False
        config.output_attentions = False
        if hasattr(config, "use_cache"):
            config.use_cache = False
        model_class = self.all_model_classes[0]
        class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        model = model_class(config)
        model(class_inputs_dict)
        with tempfile.TemporaryDirectory() as tmpdirname:
>           model.save_pretrained(tmpdirname, saved_model=True)
tests/test_modeling_tf_common.py:268:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_tf_utils.py:2427: in save_pretrained
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70: in error_handler
raise e.with_traceback(filtered_tb) from None
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <tensorflow.python.eager.polymorphic_function.function_spec.FunctionSpec object at 0x7f6e1c4db7f0>
args = ({'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'input_ids': TensorSpec(sha...tf.int32, name='input_ids'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')},), kwargs = {}
    def bind_function_inputs(self, args, kwargs):
      """Bind `args` and `kwargs` into a canonicalized signature args, kwargs."""
      sanitized_kwargs = {
          function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items()
      }
      if len(kwargs) != len(sanitized_kwargs):
        raise ValueError(f"Name collision after sanitization. Please rename "
                         f"tf.function input parameters. Original: "
                         f"{sorted(kwargs.keys())}, Sanitized: "
                         f"{sorted(sanitized_kwargs.keys())}")
      try:
        bound_arguments = self.function_type.bind_with_defaults(
            args, sanitized_kwargs, self.default_values)
      except Exception as e:
>       raise TypeError(
            f"Binding inputs to tf.function `{self._name}` failed due to `{e}`."
            f"Received args: {args} and kwargs: {sanitized_kwargs} for signature:"
            f" {self.function_type}."
        ) from e
E TypeError: Binding inputs to tf.function `eager_serving` failed due to `missing a required argument: 'inputs'`.Received args: ({'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids'), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')},) and kwargs: {} for signature: (self, inputs: Dict(mapping={'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids'), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'token_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='token_type_ids')})).
Looks like the way I'm handling the methods fails when we try to save the model with those signatures. I'll figure it out on Monday!
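For context, the failing call is Keras `model.save(...)` with explicit serving signatures. A minimal sketch of that pattern under TF 2.x / Keras 2 (the toy model and the dict keys are stand-ins mirroring the traceback; the traceback's signature `(self, inputs: Dict(...))` also suggests the wrapped function still expected `self`, i.e. an unbound function got compiled):

```python
import tensorflow as tf

# A toy model standing in for the real TF model in the traceback.
class ToyModel(tf.keras.Model):
    def call(self, inputs):
        return tf.cast(inputs["input_ids"], tf.float32)

model = ToyModel()

# The serving signature is a tf.function with an explicit input_signature;
# Keras forwards it to tf.saved_model.save as the SavedModel signature.
serving_fn = tf.function(
    lambda inputs: model(inputs),
    input_signature=[{"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids")}],
)
model({"input_ids": tf.zeros((1, 3), dtype=tf.int32)})  # build the model once
model.save("saved_model_dir", include_optimizer=False, signatures=serving_fn)
```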
This should be ready for review now! The changes are pretty small, but it took me a while to figure out the details. It turns out anything that looks like `tf.function(self.some_method)` stored as an attribute on the model keeps a hard reference to the bound method, and therefore back to the model itself.
The problem with the construction above, though, is that the resulting circular reference keeps the model alive until the next garbage-collection pass instead of letting it be freed as soon as it goes out of scope. The solution I found is to stop storing the compiled bound method on the model, and instead build the serving function so that it holds no hard reference back to the model (see the sketch below, and the full description in the PR summary at the end).
This should totally preserve functionality and backward compatibility, while resolving the memory cleanup issue and keeping the specific save signatures. The only potentially noticeable change is that `eager_serving` is no longer part of the standard flow (more on that in the review comments below).
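For illustration, the problematic shape looks roughly like this (a minimal sketch, not the actual transformers code; the names are made up):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Wrapping the *bound* method and storing the result on the model
        # creates a cycle: model -> tf.function -> bound method -> model.
        # The model then survives `del` until the next gc pass.
        self.serving = tf.function(self.serve)

    def serve(self, inputs):
        return inputs
```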
OK, I will run a few tests and let you know @Rocketknight1
Overall LGTM. Thank you.
Instead of removing `_prune_signature` and using `self.input_signature` directly, shouldn't we use a modified version like `input_signature_to_use = self._prune_signature(self.input_signature)`?
But you are HF's version of François C., I believe whatever you say!
@ydshieh actually, you're right - I thought it wasn't doing anything anymore, but it's still useful in some cases when we define a broad signature that gets inherited. Let me rework that so we keep it!
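For readers following along, the pruning being discussed could look roughly like this (a sketch, assuming `_prune_signature` simply filters an inherited signature down to the arguments this model's `call` actually accepts):

```python
import inspect

def _prune_signature(self, signature):
    # Drop signature entries that this model's call() doesn't accept, so a
    # broad signature defined on a base class is narrowed per model.
    call_params = set(inspect.signature(self.call).parameters)
    return {key: spec for key, spec in signature.items() if key in call_params}
```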
Force-pushed from 3d12414 to e371f62.
No warning signs after running the tests for the 4 models involved. You are very good at TF!
@ydshieh I finished rebasing and I removed your doctest GC workaround, since it shouldn't be needed anymore. Either way, I think we've finally resolved this one!
It's ok, go ahead. If doctest starts to fail, I will add it back.
Also pinging @amyeroberts for core maintainer review |
LGTM 🔥
To be safe, can you trigger the slow CI on this branch? Most of the TF serialization tests are slow tests :D
Hello @gante! Do you mean enabling slow tests, but for all the models? Or something else?
Good point, actually - this PR isn't specific to any one model, so we'd need to run slow tests for all models. Since it's a long time to the next release, let's just merge this PR (after review) and see if anything fails overnight?
@Rocketknight1 @ydshieh Could we run slow tests on just a handful of models, say ~5 popular ones from different modalities, to make sure any obvious issues have been caught?
@amyeroberts I've been running them locally on the core models - BERT and GPT-2 look good! Are you okay if I try a few more and then merge if there are no issues?
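(For anyone reproducing this locally: transformers gates slow tests behind the `RUN_SLOW` environment variable, so a single-model run looks something like `RUN_SLOW=1 python -m pytest tests/models/bert`; the exact test selection here is just illustrative.)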
Thanks for working on this!
Like all good refactors, the code is so much cleaner ❤️ Overall, looks good. Only thing needed before merge is handling the deprecation of `eager_serving`.
@@ -1226,12 +1222,12 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
         head_mask = tf.cast(head_mask, tf.float32)  # switch to float if need + fp16 compatibility
         return head_mask

-    def eager_serving(self, inputs):
We can't just kill a public method like this, unfortunately. We'll need to have a deprecation cycle and a warning that it's going to be removed in 2 versions' time.
Understood! I'll put it back and mark it as deprecated.
Agree in general, but is `eager_serving` intended to be used by users?
Added it back with a warning!
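For reference, such a deprecation shim typically looks something like this (a sketch; the exact message, and the assumption that `eager_serving` delegates to `call` and `serving_output`, are mine rather than the PR's exact code):

```python
import warnings

def eager_serving(self, inputs):
    # Deprecated: kept only so external callers don't break immediately;
    # slated for removal after a deprecation cycle.
    warnings.warn(
        "`eager_serving` is deprecated and will be removed in a future version.",
        FutureWarning,
    )
    output = self.call(inputs)  # assumed original behavior
    return self.serving_output(output)
```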
@ydshieh Nope, not intended, but it's a public method on all of our TF models, so removing it is a breaking change.
Thanks for iterating!
@Rocketknight1 With just a few extra slow runs on different models, e.g. CLIP and Wav2Vec2, I think we're good to go 👍
Tested BERT, GPT-2, BART, ViT, CLIP and Wav2Vec2 without issues. Merging!
Stop storing references to bound methods via tf.function (#24146)

* Stop storing references to bound methods in tf.functions
* Remove the gc.collect calls now that we resolved the underlying problem
* Remove the default signature from model.serving entirely, big cleanup
* Remove _prune_signature as self.input_signature can prune itself
* Restore serving docstring
* Update int support test to check the input signature
* Make sure other tests also use model.input_signature and not serving.input_signature
* Restore _prune_signature
* Remove the doctest GC now it's no longer needed
* Correct core tests to use the pruned sig
* order lines correctly in core tests
* Add eager_serving back with a deprecation warning
This is (hopefully!) the end of a long saga this week.
@ydshieh noticed that our test runners were going OOM after a couple of PRs I made to dummy inputs. I thought the problem was just that the new dummy inputs were too large, but eventually we figured out that the problem was actually quite complicated!
tl;dr: A circular reference exists, caused by us calling tf.function() on a model method and then storing the result as a model attribute. Because this reference exists, our TF models are not cleaned up immediately when they are deleted, but only after the next Python garbage collection.
I believe the PRs triggered the issue by eliminating unnecessary calls and making TF model building much faster. This left less time for garbage collection to happen, and as a result our test suites started a second test before the first test had been cleaned up, which caused the test runner to go OOM.
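To make the cleanup timing concrete, here's a small self-contained demonstration of the effect (not transformers code; any reference cycle behaves this way):

```python
import gc
import weakref

class Holder:
    def __init__(self):
        # Storing a bound method on the instance creates a cycle:
        # instance -> bound method -> instance.
        self.fn = self.method

    def method(self):
        pass

h = Holder()
ref = weakref.ref(h)
del h
print(ref() is None)  # False: refcounting alone can't free the cycle
gc.collect()
print(ref() is None)  # True: only the cyclic garbage collector frees it
```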
We tried resolving this problem by manually calling `gc.collect()` before each test, but this made some of the test suites much slower! Obviously, the real solution had to be to resolve the circular reference that was slowing down model cleanup.

The solution is to replace `model.eager_serving` with a method `model._get_eager_serving_fn()`. This returns a function that TensorFlow can compile, but which doesn't create a hard reference to a model method in the returned `tf.function`. I confirmed through manual inspection with `gc.get_referrers` that the reference is removed and models are now cleaned up immediately once they go out of scope.

See the update below for a full description of the solution I finally went with!