From a3ff33e4af7888ab1759a1cb33e3a7cffbe6e61d Mon Sep 17 00:00:00 2001 From: Sebastien Ehrhardt Date: Thu, 9 May 2024 20:05:38 +0100 Subject: [PATCH] finish implementation --- .../test_modeling_tf_vision_encoder_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index 38b3019e1985fc..171f33d6802a4a 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -492,7 +492,9 @@ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict): with tempfile.TemporaryDirectory() as tmpdirname: tf_model.save_pretrained(tmpdirname, safe_serialization=False) - pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True, attn_implementation="eager") + pt_model = VisionEncoderDecoderModel.from_pretrained( + tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation + ) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)