Restore _prune_signature

Rocketknight1 committed Jun 13, 2023
1 parent 664b2ae commit 3d12414
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions src/transformers/modeling_tf_utils.py
@@ -1122,7 +1122,8 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
             `Dict[str, tf.Tensor]`: The dummy inputs.
         """
         dummies = {}
-        for key, spec in self.input_signature.items():
+        sig = self._prune_signature(self.input_signature)
+        for key, spec in sig.items():
             # 2 is the most correct arbitrary size. I will not be taking questions
             dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
             if spec.shape[0] is None:
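
For context, the shape-filling rule that `dummy_inputs` applies to each surviving spec can be sketched in isolation (a standalone snippet, not part of the diff):

```python
import tensorflow as tf

# A spec with unknown batch and sequence dimensions, as in a typical input_signature
spec = tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids")

# Unknown dimensions are replaced with the arbitrary size 2, mirroring the line above
dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
dummy = tf.ones(dummy_shape, dtype=spec.dtype)
print(dummy.shape)  # (2, 2)
```
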
@@ -1171,7 +1172,7 @@ def __init__(self, config, *inputs, **kwargs):
         self.name_or_path = config.name_or_path
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
-        self._set_save_spec(self.input_signature)
+        self._set_save_spec(self._prune_signature(self.input_signature))
 
     def get_config(self):
         return self.config.to_dict()
@@ -1282,6 +1283,11 @@ def input_signature(self) -> Dict[str, tf.TensorSpec]:
             raise NotImplementedError("Audio models need a manually defined input_signature")
         return sig
 
+    def _prune_signature(self, signature):
+        """Keeps only the keys of a given input signature that are valid for this model."""
+        model_inputs = list(inspect.signature(self.call).parameters)
+        return {key: val for key, val in signature.items() if key in model_inputs}
+
     def serving_output(self, output):
         """
         Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
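
The restored helper is the heart of this commit: `inspect.signature(self.call).parameters` lists the argument names that `call` actually accepts, and any signature key outside that list is dropped. A self-contained sketch of the behaviour (the toy class and its inputs are hypothetical, not transformers code):

```python
import inspect

import tensorflow as tf


class ToyModel:
    """Hypothetical stand-in for a TFPreTrainedModel subclass."""

    input_signature = {
        "input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids"),
        "token_type_ids": tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
    }

    def call(self, input_ids, attention_mask=None):  # note: no token_type_ids parameter
        return input_ids

    def _prune_signature(self, signature):
        """Keeps only the keys of a given input signature that are valid for this model."""
        model_inputs = list(inspect.signature(self.call).parameters)
        return {key: val for key, val in signature.items() if key in model_inputs}


model = ToyModel()
print(list(model._prune_signature(model.input_signature)))
# ['input_ids'] -- token_type_ids is not accepted by call(), so it is pruned
```

Without this pruning, a model whose `call` does not accept every key of the class-level `input_signature` would be handed dummy inputs and save specs it cannot consume.
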
@@ -2399,13 +2405,14 @@ def save_pretrained(
         if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
             self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
         if signatures is None:
-            serving_default = self.serving.get_concrete_function(self.input_signature)
-            if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
+            sig = self._prune_signature(self.input_signature)
+            serving_default = self.serving.get_concrete_function(sig)
+            if any(spec.dtype == tf.int32 for spec in sig.values()):
                 int64_spec = {
                     key: tf.TensorSpec(
                         shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
                     )
-                    for key, spec in self.input_signature.items()
+                    for key, spec in sig.items()
                 }
                 int64_serving = self.serving.get_concrete_function(int64_spec)
                 signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
2 changes: 1 addition & 1 deletion templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
-        self.model._set_save_spec(inputs=self.input_signature)
+        self.model._set_save_spec(self._prune_signature(self.input_signature))
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
         self.bias_layer = BiasLayer(
