diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cefba1577ab3bf..32f6abcbe3aad1 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -64,6 +64,7 @@ ) from transformers.testing_utils import ( CaptureLogger, + is_flaky, is_pt_flax_cross_test, is_pt_tf_cross_test, require_accelerate, @@ -381,6 +382,7 @@ def test_gradient_checkpointing_enable_disable(self): m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" ) + @is_flaky(description="low likelihood of failure, reason not yet discovered") def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: