Skip to content

Commit

Permalink
encoder - decoder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed Apr 29, 2024
1 parent 941552b commit 40b12eb
Showing 1 changed file with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from transformers import DonutProcessor, NougatProcessor, TrOCRProcessor
from transformers.testing_utils import (
_run_slow_tests,
require_levenshtein,
require_nltk,
require_sentencepiece,
Expand Down Expand Up @@ -561,6 +562,29 @@ def prepare_config_and_inputs(self):
"labels": decoder_token_labels,
}

def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs,
):
if not _run_slow_tests:
return

super().check_encoder_decoder_model_output_attentions(
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs,
)


@require_torch
class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
Expand Down Expand Up @@ -677,6 +701,29 @@ def prepare_config_and_inputs(self):
"labels": decoder_input_ids,
}

def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs,
):
if not _run_slow_tests:
return

super().check_encoder_decoder_model_output_attentions(
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs,
)

# there are no published pretrained TrOCR checkpoints for now
def test_real_model_save_load_from_pretrained(self):
pass
Expand Down

0 comments on commit 40b12eb

Please sign in to comment.