Skip to content

Commit

Permalink
ci/test: patch unstable test_bleu_score_functional (#2533)
Browse files Browse the repository at this point in the history
(cherry picked from commit d76b82e)
  • Loading branch information
Borda committed May 15, 2024
1 parent d44b729 commit df45079
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tests/unittests/_helpers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

ALLOW_SKIP_IF_OUT_OF_MEMORY = os.getenv("ALLOW_SKIP_IF_OUT_OF_MEMORY", "0") == "1"
ALLOW_SKIP_IF_BAD_CONNECTION = os.getenv("ALLOW_SKIP_IF_BAD_CONNECTION", "0") == "1"
_ERROR_CONNECTION_MESSAGE_PATTERNS = (
"We couldn't connect to",
"Connection error",
"Can't load",
"`nltk` resource `punkt` is",
)


def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."):
Expand Down Expand Up @@ -33,18 +39,25 @@ def skip_on_connection_issues(reason: str = "Unable to load checkpoints from Hug
The tests run normally if no connection issue arises, and they're marked as skipped otherwise.
"""
_error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"]

def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
@wraps(function)
def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
from urllib.error import URLError

try:
return function(*args, **kwargs)
except URLError as ex:
if "Error 403: Forbidden" not in str(ex) or not ALLOW_SKIP_IF_BAD_CONNECTION:
raise ex
pytest.skip(reason)
except (OSError, ValueError) as ex:
if all(msg_start not in str(ex) for msg_start in _error_msg_starts):
if (
all(msg_start not in str(ex) for msg_start in _ERROR_CONNECTION_MESSAGE_PATTERNS)
or not ALLOW_SKIP_IF_BAD_CONNECTION
):
raise ex
if ALLOW_SKIP_IF_BAD_CONNECTION:
pytest.skip(reason)
pytest.skip(reason)

return run_test

Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/text/test_sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score
from torchmetrics.text.sacre_bleu import SacreBLEUScore

from unittests._helpers import skip_on_connection_issues
from unittests.text._helpers import TextTester
from unittests.text._inputs import _inputs_multiple_references

Expand Down Expand Up @@ -69,6 +70,7 @@ def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase):
metric_args=metric_args,
)

@skip_on_connection_issues(reason="could not download model or tokenizer")
def test_bleu_score_functional(self, preds, targets, tokenize, lowercase):
"""Test functional implementation of metric."""
if _should_skip_tokenizer(tokenize):
Expand Down

0 comments on commit df45079

Please sign in to comment.