diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index b23d2bae661053..73ba70806ca6ca 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -1171,12 +1171,8 @@ def __init__(self, config, *inputs, **kwargs):
         self.config = config
         self.name_or_path = config.name_or_path
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
-        if not hasattr(self, "serving"):  # Don't overwrite existing serving signatures
-            self.serving = tf.function(
-                self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
-            )
         # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
-        self._set_save_spec(self.serving.input_signature[0])
+        self._set_save_spec(self._prune_signature(self.input_signature))

     def get_config(self):
         return self.config.to_dict()
@@ -1226,15 +1222,31 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
         head_mask = tf.cast(head_mask, tf.float32)  # switch to float if need + fp16 compatibility
         return head_mask

+    @tf.function
+    def serving(self, inputs):
+        """
+        Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
+        functions when saving with `save_pretrained`.
+        Args:
+            inputs (`Dict[str, tf.Tensor]`):
+                The input of the saved model as a dictionary of tensors.
+        """
+        output = self.call(inputs)
+
+        return self.serving_output(output)
+
     def eager_serving(self, inputs):
         """
-        Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
-        it to generate multiple signatures later.
+        Method used for serving the model. This method is deprecated and will be removed.

         Args:
             inputs (`Dict[str, tf.Tensor]`):
                 The input of the saved model as a dictionary of tensors.
""" + warnings.warn( + "The function `eager_serving` is deprecated and will be removed in version 4.32.0 of Transformers", + FutureWarning, + ) output = self.call(inputs) return self.serving_output(output) @@ -2409,17 +2421,19 @@ def save_pretrained( if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] if signatures is None: - if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()): + sig = self._prune_signature(self.input_signature) + serving_default = self.serving.get_concrete_function(sig) + if any(spec.dtype == tf.int32 for spec in sig.values()): int64_spec = { key: tf.TensorSpec( shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name ) - for key, spec in self.serving.input_signature[0].items() + for key, spec in sig.items() } - int64_serving = tf.function(self.eager_serving, input_signature=[int64_spec]) - signatures = {"serving_default": self.serving, "int64_serving": int64_serving} + int64_serving = self.serving.get_concrete_function(int64_spec) + signatures = {"serving_default": serving_default, "int64_serving": int64_serving} else: - signatures = self.serving + signatures = serving_default saved_model_dir = os.path.join(save_directory, "saved_model", str(version)) self.save(saved_model_dir, include_optimizer=False, signatures=signatures) logger.info(f"Saved model created in {saved_model_dir}") diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index f7035751901daa..9f0bdb89b577f0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1875,13 +1875,6 @@ def preprocess_string(string, skip_cuda_tests): if not is_cuda_found: modified_string = "".join(codeblocks) - if ">>>" in modified_string: - lines = modified_string.split("\n") - indent = len(lines[-1]) - len(lines[-1].lstrip()) - - cleanup = ">>> import gc; gc.collect() # doctest: +IGNORE_RESULT" - modified_string += "\n" + " " * indent + cleanup - return modified_string diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 982a5807d20621..c7c69aa18fd46c 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model") - self.model._set_save_spec(inputs=self.serving.input_signature) + self.model._set_save_spec(self._prune_signature(self.input_signature)) self.use_cache = config.use_cache # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. 
         self.bias_layer = BiasLayer(
diff --git a/tests/models/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py
index a7edf6e0f1b7f6..b4720f7c7f0dde 100644
--- a/tests/models/rag/test_modeling_tf_rag.py
+++ b/tests/models/rag/test_modeling_tf_rag.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import gc
 import json
 import os
 import shutil
@@ -551,11 +550,6 @@ def config_and_inputs(self):
 @require_sentencepiece
 @require_tokenizers
 class TFRagModelIntegrationTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     @cached_property
     def token_model(self):
         return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py
index 863d4bd842b9db..a14b9912867126 100644
--- a/tests/models/sam/test_modeling_tf_sam.py
+++ b/tests/models/sam/test_modeling_tf_sam.py
@@ -17,7 +17,6 @@

 from __future__ import annotations

-import gc
 import inspect
 import unittest

@@ -431,11 +430,6 @@ def prepare_dog_img():
 @require_tf
 @slow
 class TFSamModelIntegrationTest(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     def test_inference_mask_generation_no_point(self):
         model = TFSamModel.from_pretrained("facebook/sam-vit-base")
         processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
diff --git a/tests/models/xglm/test_modeling_tf_xglm.py b/tests/models/xglm/test_modeling_tf_xglm.py
index 3582209cc7d136..e2b8cc2e6cbcfd 100644
--- a/tests/models/xglm/test_modeling_tf_xglm.py
+++ b/tests/models/xglm/test_modeling_tf_xglm.py
@@ -15,7 +15,6 @@

 from __future__ import annotations

-import gc
 import unittest

 from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
@@ -191,11 +190,6 @@ def test_resize_token_embeddings(self):

 @require_tf
 class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     @slow
     def test_lm_generate_xglm(self, verify_outputs=True):
         model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 586a2a761dc1c2..c411e7cdc42291 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1684,14 +1684,10 @@ def test_int_support(self):
                 if tensor.dtype.is_integer:
                     self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")

-            # Also confirm that the serving sig uses int32
-            if hasattr(model, "serving"):
-                serving_sig = model.serving.input_signature
-                for key, tensor_spec in serving_sig[0].items():
-                    if tensor_spec.dtype.is_integer:
-                        self.assertTrue(
-                            tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
-                        )
+            # Also confirm that the input_signature uses int32
+            for key, tensor_spec in model.input_signature.items():
+                if tensor_spec.dtype.is_integer:
+                    self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")

     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py
index ea5bc26986ae8b..17d68a4de59a55 100644
--- a/tests/utils/test_modeling_tf_core.py
+++ b/tests/utils/test_modeling_tf_core.py
@@ -217,17 +217,18 @@ def test_saved_model_creation_extended(self):
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config)
+            class_sig = model._prune_signature(model.input_signature)
             num_out = len(model(class_inputs_dict))

             for key in list(class_inputs_dict.keys()):
                 # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
-                if key not in model.serving.input_signature[0]:
+                if key not in class_sig:
                     del class_inputs_dict[key]
                 # Check it's a tensor, in case the inputs dict has some bools in it too
                 elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
                     class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)

-            if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
+            if set(class_inputs_dict.keys()) != set(class_sig.keys()):
                 continue  # Some models have inputs that the preparation functions don't create, we skip those

             with tempfile.TemporaryDirectory() as tmpdirname:
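
For reference, the export pattern the diff moves to can be illustrated outside of Transformers with a toy Keras model: a `tf.function` serving method with no fixed `input_signature` is specialized with `get_concrete_function` only at save time, once for the int32 signature and once for an int64 variant. `ToyModel`, `toy_signature`, and the export directory below are illustrative stand-ins, not Transformers APIs; in the library itself this logic lives in `TFPreTrainedModel.serving` and `save_pretrained` as shown above.

import tempfile

import tensorflow as tf


class ToyModel(tf.keras.Model):
    # Illustrative stand-in for a TF model that takes a dict of tensors as input.
    def __init__(self):
        super().__init__()
        self.embed = tf.keras.layers.Embedding(input_dim=100, output_dim=4)

    def call(self, inputs):
        # `inputs` is a dict of tensors, mirroring the serving input of the TF models above
        return self.embed(inputs["input_ids"])

    @tf.function  # no input_signature here: concrete functions are built at export time
    def serving(self, inputs):
        return {"last_hidden_state": self.call(inputs)}


model = ToyModel()
toy_signature = {"input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="input_ids")}

# Specialize the generic `serving` function once per desired signature, only when saving,
# mirroring the `serving_default` / `int64_serving` signatures built in `save_pretrained` above.
serving_default = model.serving.get_concrete_function(toy_signature)
int64_spec = {
    key: tf.TensorSpec(shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name)
    for key, spec in toy_signature.items()
}
int64_serving = model.serving.get_concrete_function(int64_spec)

with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        model, saved_model_dir, signatures={"serving_default": serving_default, "int64_serving": int64_serving}
    )

Deferring specialization this way keeps a single `serving` definition per model while still letting the saved model expose multiple concrete signatures, which is why the fixed `self.serving.input_signature[0]` lookups in the tests above are replaced with the pruned `input_signature` dict.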