From df4507951cff5ee53eeadbc0fc9b30bd05ab0ec6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 7 May 2024 09:04:21 +0200 Subject: [PATCH] ci/test: patch unstable `test_bleu_score_functional` (#2533) (cherry picked from commit d76b82e7f14cd3004d3198a1e5b031f87663e066) --- tests/unittests/_helpers/wrappers.py | 21 +++++++++++++++++---- tests/unittests/text/test_sacre_bleu.py | 2 ++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/unittests/_helpers/wrappers.py b/tests/unittests/_helpers/wrappers.py index 4029f001b2d..157dd1db2b8 100644 --- a/tests/unittests/_helpers/wrappers.py +++ b/tests/unittests/_helpers/wrappers.py @@ -6,6 +6,12 @@ ALLOW_SKIP_IF_OUT_OF_MEMORY = os.getenv("ALLOW_SKIP_IF_OUT_OF_MEMORY", "0") == "1" ALLOW_SKIP_IF_BAD_CONNECTION = os.getenv("ALLOW_SKIP_IF_BAD_CONNECTION", "0") == "1" +_ERROR_CONNECTION_MESSAGE_PATTERNS = ( + "We couldn't connect to", + "Connection error", + "Can't load", + "`nltk` resource `punkt` is", +) def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."): @@ -33,18 +39,25 @@ def skip_on_connection_issues(reason: str = "Unable to load checkpoints from Hug The tests run normally if no connection issue arises, and they're marked as skipped otherwise. """ - _error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"] def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]: @wraps(function) def run_test(*args: Any, **kwargs: Any) -> Optional[Any]: + from urllib.error import URLError + try: return function(*args, **kwargs) + except URLError as ex: + if "Error 403: Forbidden" not in str(ex) or not ALLOW_SKIP_IF_BAD_CONNECTION: + raise ex + pytest.skip(reason) except (OSError, ValueError) as ex: - if all(msg_start not in str(ex) for msg_start in _error_msg_starts): + if ( + all(msg_start not in str(ex) for msg_start in _ERROR_CONNECTION_MESSAGE_PATTERNS) + or not ALLOW_SKIP_IF_BAD_CONNECTION + ): raise ex - if ALLOW_SKIP_IF_BAD_CONNECTION: - pytest.skip(reason) + pytest.skip(reason) return run_test diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index dbf72269a3e..e8b66012011 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -21,6 +21,7 @@ from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore +from unittests._helpers import skip_on_connection_issues from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references @@ -69,6 +70,7 @@ def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase): metric_args=metric_args, ) + @skip_on_connection_issues(reason="could not download model or tokenizer") def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): """Test functional implementation of metric.""" if _should_skip_tokenizer(tokenize):