Stop storing references to bound methods via tf.function #24146

Merged · 12 commits · Jun 13, 2023
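The core issue, in brief: building a `tf.function` from a bound method in `__init__` and storing it as an instance attribute gives the traced function a reference back to the model, which can keep the model and its traced graphs alive longer than expected. A minimal sketch of the anti-pattern and of the decorator-based alternative this PR moves to (illustrative class and method names, not the actual Transformers code):

```python
import tensorflow as tf


class LeakyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)
        # Anti-pattern: a tf.function wrapping a bound method is stored on the
        # instance. The wrapper references the bound method, the bound method
        # references self, so the model ends up in a reference cycle together
        # with whatever the traced graphs capture.
        self.serving = tf.function(
            self.eager_serving,
            input_signature=[tf.TensorSpec([None, 4], tf.float32)],
        )

    def eager_serving(self, inputs):
        return self.dense(inputs)


class CleanModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    # Preferred: decorate the method at class level. Nothing extra is stored in
    # __init__, and concrete functions are only built when actually needed
    # (for example at SavedModel export time).
    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    def serving(self, inputs):
        return self.dense(inputs)
```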
38 changes: 26 additions & 12 deletions src/transformers/modeling_tf_utils.py
@@ -1171,12 +1171,8 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if not hasattr(self, "serving"): # Don't overwrite existing serving signatures
self.serving = tf.function(
self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
)
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])
self._set_save_spec(self._prune_signature(self.input_signature))

def get_config(self):
return self.config.to_dict()
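Both this hunk and the `save_pretrained` change below call a `_prune_signature` helper whose body is not part of this diff. A rough sketch of what such a helper presumably does — keep only the signature entries the model's `call()` accepts; this is an assumption, not the verified Transformers implementation:

```python
import inspect


def prune_signature(model, signature):
    # Hypothetical stand-in for the private _prune_signature helper: drop any
    # entry of the class-level input signature that the model's call() would
    # not accept as an argument.
    call_params = set(inspect.signature(model.call).parameters)
    return {name: spec for name, spec in signature.items() if name in call_params}
```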
@@ -1226,15 +1222,31 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
return head_mask

@tf.function
def serving(self, inputs):
"""
Args:
Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
functions when saving with `save_pretrained`.
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)

return self.serving_output(output)

def eager_serving(self, inputs):
Collaborator:
We can't just kill a public method like this, unfortunately. We'll need a deprecation cycle and a warning that it's going to be removed in two versions' time.

Member (Author):
Understood! I'll put it back and mark it as deprecated.

Collaborator:
Agree in general, but is eager_serving intended to be used by users?

Member (Author):
Added it back with a warning!

Collaborator:
@ydshieh Nope, not intended, but it's a public method on all of our TF models, so removing it is a breaking change.

"""
Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
it to generate multiple signatures later.
Method used for serving the model. This method is deprecated, and will be removed.

Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
warnings.warn(
"The function `eager_serving` is deprecated and will be removed in version 4.32.0 of Transformers",
FutureWarning,
)
output = self.call(inputs)

return self.serving_output(output)
@@ -2409,17 +2421,19 @@ def save_pretrained(
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
if signatures is None:
if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()):
sig = self._prune_signature(self.input_signature)
serving_default = self.serving.get_concrete_function(sig)
if any(spec.dtype == tf.int32 for spec in sig.values()):
int64_spec = {
key: tf.TensorSpec(
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
)
for key, spec in self.serving.input_signature[0].items()
for key, spec in sig.items()
}
int64_serving = tf.function(self.eager_serving, input_signature=[int64_spec])
signatures = {"serving_default": self.serving, "int64_serving": int64_serving}
int64_serving = self.serving.get_concrete_function(int64_spec)
signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
else:
signatures = self.serving
signatures = serving_default
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
logger.info(f"Saved model created in {saved_model_dir}")
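A hedged sketch of the export pattern above, using a toy model instead of a Transformers class (it assumes TF 2.x with the legacy Keras SavedModel path that `save_pretrained` uses here; under Keras 3 you would reach for `model.export()` instead):

```python
import tensorflow as tf


class ToyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.embed = tf.keras.layers.Embedding(100, 8)

    def call(self, inputs):
        return {"last_hidden_state": self.embed(inputs["input_ids"])}

    @tf.function
    def serving(self, inputs):
        return self.call(inputs)


model = ToyModel()
# Build the model once so Keras can export it.
_ = model({"input_ids": tf.constant([[1, 2, 3]], dtype=tf.int32)})

sig = {"input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids")}

# Concrete function for the default int32 signature.
serving_default = model.serving.get_concrete_function(sig)

# Mirror every int32 spec as int64 so callers that feed int64 ids still match
# a signature instead of failing at serving time.
int64_spec = {
    key: tf.TensorSpec(shape=spec.shape, dtype=tf.int64, name=spec.name)
    for key, spec in sig.items()
}
int64_serving = model.serving.get_concrete_function(int64_spec)

model.save(
    "toy_saved_model",
    include_optimizer=False,
    signatures={"serving_default": serving_default, "int64_serving": int64_serving},
)
```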
7 changes: 0 additions & 7 deletions src/transformers/testing_utils.py
@@ -1875,13 +1875,6 @@ def preprocess_string(string, skip_cuda_tests):
if not is_cuda_found:
modified_string = "".join(codeblocks)

if ">>>" in modified_string:
lines = modified_string.split("\n")
indent = len(lines[-1]) - len(lines[-1].lstrip())

cleanup = ">>> import gc; gc.collect() # doctest: +IGNORE_RESULT"
modified_string += "\n" + " " * indent + cleanup

return modified_string


@@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
self.model._set_save_spec(inputs=self.serving.input_signature)
self.model._set_save_spec(self._prune_signature(self.input_signature))
self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.bias_layer = BiasLayer(
6 changes: 0 additions & 6 deletions tests/models/rag/test_modeling_tf_rag.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import gc
import json
import os
import shutil
@@ -551,11 +550,6 @@ def config_and_inputs(self):
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()

@cached_property
def token_model(self):
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
6 changes: 0 additions & 6 deletions tests/models/sam/test_modeling_tf_sam.py
@@ -17,7 +17,6 @@

from __future__ import annotations

import gc
import inspect
import unittest

@@ -431,11 +430,6 @@ def prepare_dog_img():
@require_tf
@slow
class TFSamModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()

def test_inference_mask_generation_no_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
6 changes: 0 additions & 6 deletions tests/models/xglm/test_modeling_tf_xglm.py
@@ -15,7 +15,6 @@

from __future__ import annotations

import gc
import unittest

from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
@@ -191,11 +190,6 @@ def test_resize_token_embeddings(self):

@require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()

@slow
def test_lm_generate_xglm(self, verify_outputs=True):
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
12 changes: 4 additions & 8 deletions tests/test_modeling_tf_common.py
@@ -1684,14 +1684,10 @@ def test_int_support(self):
if tensor.dtype.is_integer:
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")

# Also confirm that the serving sig uses int32
if hasattr(model, "serving"):
serving_sig = model.serving.input_signature
for key, tensor_spec in serving_sig[0].items():
if tensor_spec.dtype.is_integer:
self.assertTrue(
tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
)
# Also confirm that the input_signature uses int32
for key, tensor_spec in model.input_signature.items():
if tensor_spec.dtype.is_integer:
self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")

def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
5 changes: 3 additions & 2 deletions tests/utils/test_modeling_tf_core.py
@@ -217,17 +217,18 @@ def test_saved_model_creation_extended(self):
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
class_sig = model._prune_signature(model.input_signature)
num_out = len(model(class_inputs_dict))

for key in list(class_inputs_dict.keys()):
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
if key not in model.serving.input_signature[0]:
if key not in class_sig:
del class_inputs_dict[key]
# Check it's a tensor, in case the inputs dict has some bools in it too
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)

if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
continue # Some models have inputs that the preparation functions don't create, we skip those

with tempfile.TemporaryDirectory() as tmpdirname:
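For completeness, a usage sketch of consuming such an export; the path and signature names follow the toy export sketch above and are illustrative only:

```python
import tensorflow as tf

loaded = tf.saved_model.load("toy_saved_model")
print(list(loaded.signatures))  # e.g. ['serving_default', 'int64_serving']

ids32 = tf.constant([[1, 2, 3]], dtype=tf.int32)
ids64 = tf.cast(ids32, tf.int64)

# Each signature is a concrete function taking keyword arguments named after
# the TensorSpecs it was traced with.
out_default = loaded.signatures["serving_default"](input_ids=ids32)
out_int64 = loaded.signatures["int64_serving"](input_ids=ids64)
```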