Skip to content

Commit

Permalink
finish implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed May 14, 2024
1 parent e20b7b7 commit a3ff33e
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,9 @@ def check_tf_to_pt_equivalence(self, config, decoder_config, tf_inputs_dict):

with tempfile.TemporaryDirectory() as tmpdirname:
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True, attn_implementation="eager")
pt_model = VisionEncoderDecoderModel.from_pretrained(
tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation
)

self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

Expand Down

0 comments on commit a3ff33e

Please sign in to comment.