From 8da21872beecdb6e1c170f22f9c79c8d58f4357e Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Jun 2023 17:40:07 +0100 Subject: [PATCH 01/12] Stop storing references to bound methods in tf.functions --- src/transformers/modeling_tf_utils.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index b23d2bae6610..0af109596529 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1173,7 +1173,7 @@ def __init__(self, config, *inputs, **kwargs): self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None if not hasattr(self, "serving"): # Don't overwrite existing serving signatures self.serving = tf.function( - self.eager_serving, input_signature=[self._prune_signature(self.input_signature)] + self._get_eager_serving_fn(), input_signature=[self._prune_signature(self.input_signature)] ) # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec self._set_save_spec(self.serving.input_signature[0]) @@ -1226,18 +1226,21 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility return head_mask - def eager_serving(self, inputs): + def _get_eager_serving_fn(self): """ - Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use - it to generate multiple signatures later. + Returns a function used for serving the model. Note that we do not make this function a method of the class + with a tf.function decorator because we won't know what its signature should be until we have the model config. - Args: - inputs (`Dict[str, tf.Tensor]`): - The input of the saved model as a dictionary of tensors. + We also don't directly make this an eager method because compiling a method bound to a class greatly + complicates cleanup of that class. """ - output = self.call(inputs) - return self.serving_output(output) + def eager_serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + return eager_serving @property def input_signature(self) -> Dict[str, tf.TensorSpec]: @@ -2416,7 +2419,7 @@ def save_pretrained( ) for key, spec in self.serving.input_signature[0].items() } - int64_serving = tf.function(self.eager_serving, input_signature=[int64_spec]) + int64_serving = tf.function(self._get_eager_serving_fn(), input_signature=[int64_spec]) signatures = {"serving_default": self.serving, "int64_serving": int64_serving} else: signatures = self.serving From 5bb70c12268c25e8d29c3f47c28676aef261b474 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Jun 2023 17:57:55 +0100 Subject: [PATCH 02/12] Remove the gc.collect calls now that we resolved the underlying problem --- tests/models/rag/test_modeling_tf_rag.py | 6 ------ tests/models/sam/test_modeling_tf_sam.py | 5 ----- tests/models/xglm/test_modeling_tf_xglm.py | 6 ------ 3 files changed, 17 deletions(-) diff --git a/tests/models/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py index a7edf6e0f1b7..b4720f7c7f0d 100644 --- a/tests/models/rag/test_modeling_tf_rag.py +++ b/tests/models/rag/test_modeling_tf_rag.py @@ -1,6 +1,5 @@ from __future__ import annotations -import gc import json import os import shutil @@ -551,11 +550,6 @@ def config_and_inputs(self): @require_sentencepiece @require_tokenizers class TFRagModelIntegrationTests(unittest.TestCase): - def tearDown(self): - super().tearDown() - # clean-up as much as possible GPU memory occupied by PyTorch - gc.collect() - @cached_property def token_model(self): return TFRagTokenForGeneration.from_pretrained_question_encoder_generator( diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index 863d4bd842b9..b4f934a4b738 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -431,11 +431,6 @@ def prepare_dog_img(): @require_tf @slow class TFSamModelIntegrationTest(unittest.TestCase): - def tearDown(self): - super().tearDown() - # clean-up as much as possible GPU memory occupied by PyTorch - gc.collect() - def test_inference_mask_generation_no_point(self): model = TFSamModel.from_pretrained("facebook/sam-vit-base") processor = SamProcessor.from_pretrained("facebook/sam-vit-base") diff --git a/tests/models/xglm/test_modeling_tf_xglm.py b/tests/models/xglm/test_modeling_tf_xglm.py index 3582209cc7d1..e2b8cc2e6cbc 100644 --- a/tests/models/xglm/test_modeling_tf_xglm.py +++ b/tests/models/xglm/test_modeling_tf_xglm.py @@ -15,7 +15,6 @@ from __future__ import annotations -import gc import unittest from transformers import XGLMConfig, XGLMTokenizer, is_tf_available @@ -191,11 +190,6 @@ def test_resize_token_embeddings(self): @require_tf class TFXGLMModelLanguageGenerationTest(unittest.TestCase): - def tearDown(self): - super().tearDown() - # clean-up as much as possible GPU memory occupied by PyTorch - gc.collect() - @slow def test_lm_generate_xglm(self, verify_outputs=True): model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M") From c5466d6ae772591b68629977d41cff84b32f49ab Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Jun 2023 19:38:15 +0100 Subject: [PATCH 03/12] Remove the default signature from model.serving entirely, big cleanup --- src/transformers/modeling_tf_utils.py | 35 +++++++++------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 0af109596529..4c973cb3e41a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1171,12 +1171,8 @@ def __init__(self, config, *inputs, **kwargs): self.config = config self.name_or_path = config.name_or_path self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None - if not hasattr(self, "serving"): # Don't overwrite existing serving signatures - self.serving = tf.function( - self._get_eager_serving_fn(), input_signature=[self._prune_signature(self.input_signature)] - ) # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec - self._set_save_spec(self.serving.input_signature[0]) + self._set_save_spec(self.input_signature) def get_config(self): return self.config.to_dict() @@ -1226,21 +1222,11 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility return head_mask - def _get_eager_serving_fn(self): - """ - Returns a function used for serving the model. Note that we do not make this function a method of the class - with a tf.function decorator because we won't know what its signature should be until we have the model config. - - We also don't directly make this an eager method because compiling a method bound to a class greatly - complicates cleanup of that class. - """ - - def eager_serving(self, inputs): - output = self.call(inputs) - - return self.serving_output(output) + @tf.function + def serving(self, inputs): + output = self.call(inputs) - return eager_serving + return self.serving_output(output) @property def input_signature(self) -> Dict[str, tf.TensorSpec]: @@ -2412,17 +2398,18 @@ def save_pretrained( if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] if signatures is None: - if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()): + serving_default = self.serving.get_concrete_function(self.input_signature) + if any(spec.dtype == tf.int32 for spec in self.input_signature.values()): int64_spec = { key: tf.TensorSpec( shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name ) - for key, spec in self.serving.input_signature[0].items() + for key, spec in self.input_signature.items() } - int64_serving = tf.function(self._get_eager_serving_fn(), input_signature=[int64_spec]) - signatures = {"serving_default": self.serving, "int64_serving": int64_serving} + int64_serving = self.serving.get_concrete_function(int64_spec) + signatures = {"serving_default": serving_default, "int64_serving": int64_serving} else: - signatures = self.serving + signatures = serving_default saved_model_dir = os.path.join(save_directory, "saved_model", str(version)) self.save(saved_model_dir, include_optimizer=False, signatures=signatures) logger.info(f"Saved model created in {saved_model_dir}") From 03c9061172b3d815d331c1db1a739e94097a8d80 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Jun 2023 19:41:49 +0100 Subject: [PATCH 04/12] Remove _prune_signature as self.input_signature can prune itself --- src/transformers/modeling_tf_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 4c973cb3e41a..b5a109a02f54 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1122,8 +1122,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: `Dict[str, tf.Tensor]`: The dummy inputs. """ dummies = {} - sig = self._prune_signature(self.input_signature) - for key, spec in sig.items(): + for key, spec in self.input_signature.items(): # 2 is the most correct arbitrary size. I will not be taking questions dummy_shape = [dim if dim is not None else 2 for dim in spec.shape] if spec.shape[0] is None: @@ -1276,11 +1275,6 @@ def input_signature(self) -> Dict[str, tf.TensorSpec]: raise NotImplementedError("Audio models need a manually defined input_signature") return sig - def _prune_signature(self, signature): - """Keeps only the keys of a given input signature that are valid for this model.""" - model_inputs = list(inspect.signature(self.call).parameters) - return {key: val for key, val in signature.items() if key in model_inputs} - def serving_output(self, output): """ Prepare the output of the saved model. Can be overridden if specific serving modifications are required. From d8c32e4b8d3a40e7052e2b320190715051f50179 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Jun 2023 19:50:53 +0100 Subject: [PATCH 05/12] Restore serving docstring --- src/transformers/modeling_tf_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index b5a109a02f54..cbd6d3dd8ba0 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1223,6 +1223,13 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): @tf.function def serving(self, inputs): + """ + Method used for serving the model. Does not have a specific signature, but will be specialized + as concrete functions when saving with `save_pretrained`. + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ output = self.call(inputs) return self.serving_output(output) From 0bc46e06f567866e137a5b48e86af7056104159d Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Jun 2023 19:51:39 +0100 Subject: [PATCH 06/12] Update int support test to check the input signature --- tests/test_modeling_tf_common.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 586a2a761dc1..7ecee69c0f01 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1684,14 +1684,12 @@ def test_int_support(self): if tensor.dtype.is_integer: self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!") - # Also confirm that the serving sig uses int32 - if hasattr(model, "serving"): - serving_sig = model.serving.input_signature - for key, tensor_spec in serving_sig[0].items(): - if tensor_spec.dtype.is_integer: - self.assertTrue( - tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!" - ) + # Also confirm that the input_signature uses int32 + for key, tensor_spec in model.input_signature.items(): + if tensor_spec.dtype.is_integer: + self.assertTrue( + tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!" + ) def test_generate_with_headmasking(self): attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] From a68b2ac99cd8335bb3469afdb2c9da726b265c7b Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Jun 2023 19:59:44 +0100 Subject: [PATCH 07/12] Make sure other tests also use model.input_signature and not serving.input_signature --- src/transformers/modeling_tf_utils.py | 4 ++-- .../modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 2 +- tests/test_modeling_tf_common.py | 4 +--- tests/utils/test_modeling_tf_core.py | 4 ++-- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index cbd6d3dd8ba0..46c7556df6ce 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1224,9 +1224,9 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): @tf.function def serving(self, inputs): """ - Method used for serving the model. Does not have a specific signature, but will be specialized - as concrete functions when saving with `save_pretrained`. Args: + Method used for serving the model. Does not have a specific signature, but will be specialized as concrete + functions when saving with `save_pretrained`. inputs (`Dict[str, tf.Tensor]`): The input of the saved model as a dictionary of tensors. """ diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 982a5807d206..b158881d96bd 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model") - self.model._set_save_spec(inputs=self.serving.input_signature) + self.model._set_save_spec(inputs=self.input_signature) self.use_cache = config.use_cache # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. self.bias_layer = BiasLayer( diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 7ecee69c0f01..c411e7cdc422 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1687,9 +1687,7 @@ def test_int_support(self): # Also confirm that the input_signature uses int32 for key, tensor_spec in model.input_signature.items(): if tensor_spec.dtype.is_integer: - self.assertTrue( - tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!" - ) + self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!") def test_generate_with_headmasking(self): attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py index ea5bc26986ae..135db86d4d55 100644 --- a/tests/utils/test_modeling_tf_core.py +++ b/tests/utils/test_modeling_tf_core.py @@ -221,13 +221,13 @@ def test_saved_model_creation_extended(self): for key in list(class_inputs_dict.keys()): # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them - if key not in model.serving.input_signature[0]: + if key not in model.input_signature: del class_inputs_dict[key] # Check it's a tensor, in case the inputs dict has some bools in it too elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer: class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32) - if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()): + if set(class_inputs_dict.keys()) != set(model.input_signature.keys()): continue # Some models have inputs that the preparation functions don't create, we skip those with tempfile.TemporaryDirectory() as tmpdirname: From e371f62e0f108c7dd449c32440cf23a39475c2c8 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jun 2023 13:40:47 +0100 Subject: [PATCH 08/12] Restore _prune_signature --- src/transformers/modeling_tf_utils.py | 17 ++++++++++++----- ...g_tf_{{cookiecutter.lowercase_modelname}}.py | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 46c7556df6ce..a4521fb8a63c 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1122,7 +1122,8 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: `Dict[str, tf.Tensor]`: The dummy inputs. """ dummies = {} - for key, spec in self.input_signature.items(): + sig = self._prune_signature(self.input_signature) + for key, spec in sig.items(): # 2 is the most correct arbitrary size. I will not be taking questions dummy_shape = [dim if dim is not None else 2 for dim in spec.shape] if spec.shape[0] is None: @@ -1171,7 +1172,7 @@ def __init__(self, config, *inputs, **kwargs): self.name_or_path = config.name_or_path self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec - self._set_save_spec(self.input_signature) + self._set_save_spec(self._prune_signature(self.input_signature)) def get_config(self): return self.config.to_dict() @@ -1282,6 +1283,11 @@ def input_signature(self) -> Dict[str, tf.TensorSpec]: raise NotImplementedError("Audio models need a manually defined input_signature") return sig + def _prune_signature(self, signature): + """Keeps only the keys of a given input signature that are valid for this model.""" + model_inputs = list(inspect.signature(self.call).parameters) + return {key: val for key, val in signature.items() if key in model_inputs} + def serving_output(self, output): """ Prepare the output of the saved model. Can be overridden if specific serving modifications are required. @@ -2399,13 +2405,14 @@ def save_pretrained( if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] if signatures is None: - serving_default = self.serving.get_concrete_function(self.input_signature) - if any(spec.dtype == tf.int32 for spec in self.input_signature.values()): + sig = self._prune_signature(self.input_signature) + serving_default = self.serving.get_concrete_function(sig) + if any(spec.dtype == tf.int32 for spec in sig.values()): int64_spec = { key: tf.TensorSpec( shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name ) - for key, spec in self.input_signature.items() + for key, spec in sig.items() } int64_serving = self.serving.get_concrete_function(int64_spec) signatures = {"serving_default": serving_default, "int64_serving": int64_serving} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index b158881d96bd..c7c69aa18fd4 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model") - self.model._set_save_spec(inputs=self.input_signature) + self.model._set_save_spec(self._prune_signature(self.input_signature)) self.use_cache = config.use_cache # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. self.bias_layer = BiasLayer( From f98b7b7415f53ceae992568a259f15a627d3f522 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jun 2023 13:46:03 +0100 Subject: [PATCH 09/12] Remove the doctest GC now it's no longer needed --- src/transformers/testing_utils.py | 7 ------- tests/models/sam/test_modeling_tf_sam.py | 1 - 2 files changed, 8 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index f7035751901d..9f0bdb89b577 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1875,13 +1875,6 @@ def preprocess_string(string, skip_cuda_tests): if not is_cuda_found: modified_string = "".join(codeblocks) - if ">>>" in modified_string: - lines = modified_string.split("\n") - indent = len(lines[-1]) - len(lines[-1].lstrip()) - - cleanup = ">>> import gc; gc.collect() # doctest: +IGNORE_RESULT" - modified_string += "\n" + " " * indent + cleanup - return modified_string diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py index b4f934a4b738..a14b99128671 100644 --- a/tests/models/sam/test_modeling_tf_sam.py +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -17,7 +17,6 @@ from __future__ import annotations -import gc import inspect import unittest From 8af7857c4a386f9d83fe47475a8d6ef73cefd88d Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jun 2023 14:37:11 +0100 Subject: [PATCH 10/12] Correct core tests to use the pruned sig --- tests/utils/test_modeling_tf_core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py index 135db86d4d55..ebbd63f4c742 100644 --- a/tests/utils/test_modeling_tf_core.py +++ b/tests/utils/test_modeling_tf_core.py @@ -216,18 +216,19 @@ def test_saved_model_creation_extended(self): for model_class in self.all_model_classes: class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + class_sig = model._prune_signature(model.input_signature) model = model_class(config) num_out = len(model(class_inputs_dict)) for key in list(class_inputs_dict.keys()): # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them - if key not in model.input_signature: + if key not in class_sig: del class_inputs_dict[key] # Check it's a tensor, in case the inputs dict has some bools in it too elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer: class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32) - if set(class_inputs_dict.keys()) != set(model.input_signature.keys()): + if set(class_inputs_dict.keys()) != set(class_sig.keys()): continue # Some models have inputs that the preparation functions don't create, we skip those with tempfile.TemporaryDirectory() as tmpdirname: From 5939c126861d2780719494f8a18aa2fee10b41f8 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jun 2023 15:06:22 +0100 Subject: [PATCH 11/12] order lines correctly in core tests --- tests/utils/test_modeling_tf_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py index ebbd63f4c742..17d68a4de59a 100644 --- a/tests/utils/test_modeling_tf_core.py +++ b/tests/utils/test_modeling_tf_core.py @@ -216,8 +216,8 @@ def test_saved_model_creation_extended(self): for model_class in self.all_model_classes: class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - class_sig = model._prune_signature(model.input_signature) model = model_class(config) + class_sig = model._prune_signature(model.input_signature) num_out = len(model(class_inputs_dict)) for key in list(class_inputs_dict.keys()): From 8764c3713a5db5d921bb444f5163dfaed423e270 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jun 2023 17:48:18 +0100 Subject: [PATCH 12/12] Add eager_serving back with a deprecation warning --- src/transformers/modeling_tf_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a4521fb8a63c..73ba70806ca6 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1235,6 +1235,22 @@ def serving(self, inputs): return self.serving_output(output) + def eager_serving(self, inputs): + """ + Method used for serving the model. This method is deprecated, and will be removed. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + warnings.warn( + "The function `eager_serving` is deprecated and will be removed in version 4.32.0 of Transformers", + FutureWarning, + ) + output = self.call(inputs) + + return self.serving_output(output) + @property def input_signature(self) -> Dict[str, tf.TensorSpec]: """