Skip to content

Commit

Permalink
Better TF docstring types (#23477)
Browse files Browse the repository at this point in the history
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor

* Rework TF type hints to use | None instead of Optional[] for tf.Tensor

* Don't forget the imports

* Add the imports to tests too

* make fixup

* Refactor tests that depended on get_type_hints

* Better test refactor

* Fix an old hidden bug in the test_keras_fit input creation code

* Fix for the Deit tests
  • Loading branch information
Rocketknight1 authored May 24, 2023
1 parent 767e6b5 commit f8b2574
Show file tree
Hide file tree
Showing 139 changed files with 2,907 additions and 2,621 deletions.
190 changes: 96 additions & 94 deletions src/transformers/modeling_tf_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -43,8 +45,8 @@ class TFBaseModelOutput(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -96,8 +98,8 @@ class TFBaseModelOutputWithPooling(ModelOutput):

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -164,10 +166,10 @@ class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -201,9 +203,9 @@ class TFBaseModelOutputWithPast(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -234,9 +236,9 @@ class TFBaseModelOutputWithCrossAttentions(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -276,10 +278,10 @@ class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -333,13 +335,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -365,10 +367,10 @@ class TFCausalLMOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -400,11 +402,11 @@ class TFCausalLMOutputWithPast(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -442,12 +444,12 @@ class TFCausalLMOutputWithCrossAttentions(ModelOutput):
`past_key_values` input) to speed up sequential decoding.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -473,10 +475,10 @@ class TFMaskedLMOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -527,15 +529,15 @@ class TFSeq2SeqLMOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -562,10 +564,10 @@ class TFNextSentencePredictorOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -591,10 +593,10 @@ class TFSequenceClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -642,15 +644,15 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -684,10 +686,10 @@ class TFSemanticSegmenterOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -716,9 +718,9 @@ class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -742,10 +744,10 @@ class TFImageClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -773,10 +775,10 @@ class TFMultipleChoiceModelOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -802,10 +804,10 @@ class TFTokenClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -833,11 +835,11 @@ class TFQuestionAnsweringModelOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -884,15 +886,15 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -924,11 +926,11 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -947,7 +949,7 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
feature maps) of the model at the output of each stage.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None

Expand All @@ -974,10 +976,10 @@ class TFMaskedImageModelingOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
reconstruction: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None

@property
def logits(self):
Expand Down
Loading

0 comments on commit f8b2574

Please sign in to comment.