Skip to content

Commit

Permalink
ENH: Add all warnings check to the assert_produces_warnings, and sepa…
Browse files Browse the repository at this point in the history
…rate messages for each warning. (pandas-dev#57222)

* ENH: Add all warnings check to the `assert_produces_warnings`, and separate messages for each warning.

* Fix typing errors

* Fix typing errors

* Remove unnecessary documentation

* Change `assert_produces_warning behavior` to check for all warnings by default

* Refactor typing

* Fix tests expecting a Warning that is not raised

* Adjust `raises_chained_assignment_error` and its dependencies to the new API of `assert_produces_warning`

* Fix `_assert_caught_expected_warning` typing not including tuple of warnings

* fixup! Refactor typing

* fixup! Fix `_assert_caught_expected_warning` typing not including tuple of warnings

* fixup! Fix `_assert_caught_expected_warning` typing not including tuple of warnings

* Add tests

---------

Co-authored-by: Richard Shadrach <[email protected]>
  • Loading branch information
Jorewin and rhshadrach authored Apr 14, 2024
1 parent 72fa623 commit d8c7e85
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 31 deletions.
66 changes: 50 additions & 16 deletions pandas/_testing/_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import (
TYPE_CHECKING,
Literal,
Union,
cast,
)
import warnings
Expand All @@ -32,7 +33,8 @@ def assert_produces_warning(
] = "always",
check_stacklevel: bool = True,
raise_on_extra_warnings: bool = True,
match: str | None = None,
match: str | tuple[str | None, ...] | None = None,
must_find_all_warnings: bool = True,
) -> Generator[list[warnings.WarningMessage], None, None]:
"""
Context manager for running code expected to either raise a specific warning,
Expand Down Expand Up @@ -68,8 +70,15 @@ class for all warnings. To raise multiple types of exceptions,
raise_on_extra_warnings : bool, default True
Whether extra warnings not of the type `expected_warning` should
cause the test to fail.
match : str, optional
Match warning message.
match : {str, tuple[str, ...]}, optional
Match warning message. If it's a tuple, it has to be the size of
`expected_warning`. If additionally `must_find_all_warnings` is
True, each expected warning's message gets matched with a respective
match. Otherwise, multiple values get treated as an alternative.
must_find_all_warnings : bool, default True
If True and `expected_warning` is a tuple, each expected warning
type must get encountered. Otherwise, even one expected warning
results in success.
Examples
--------
Expand Down Expand Up @@ -97,13 +106,35 @@ class for all warnings. To raise multiple types of exceptions,
yield w
finally:
if expected_warning:
expected_warning = cast(type[Warning], expected_warning)
_assert_caught_expected_warning(
caught_warnings=w,
expected_warning=expected_warning,
match=match,
check_stacklevel=check_stacklevel,
)
if isinstance(expected_warning, tuple) and must_find_all_warnings:
match = (
match
if isinstance(match, tuple)
else (match,) * len(expected_warning)
)
for warning_type, warning_match in zip(expected_warning, match):
_assert_caught_expected_warnings(
caught_warnings=w,
expected_warning=warning_type,
match=warning_match,
check_stacklevel=check_stacklevel,
)
else:
expected_warning = cast(
Union[type[Warning], tuple[type[Warning], ...]],
expected_warning,
)
match = (
"|".join(m for m in match if m)
if isinstance(match, tuple)
else match
)
_assert_caught_expected_warnings(
caught_warnings=w,
expected_warning=expected_warning,
match=match,
check_stacklevel=check_stacklevel,
)
if raise_on_extra_warnings:
_assert_caught_no_extra_warnings(
caught_warnings=w,
Expand All @@ -123,17 +154,22 @@ def maybe_produces_warning(
return nullcontext()


def _assert_caught_expected_warning(
def _assert_caught_expected_warnings(
*,
caught_warnings: Sequence[warnings.WarningMessage],
expected_warning: type[Warning],
expected_warning: type[Warning] | tuple[type[Warning], ...],
match: str | None,
check_stacklevel: bool,
) -> None:
"""Assert that there was the expected warning among the caught warnings."""
saw_warning = False
matched_message = False
unmatched_messages = []
warning_name = (
tuple(x.__name__ for x in expected_warning)
if isinstance(expected_warning, tuple)
else expected_warning.__name__
)

for actual_warning in caught_warnings:
if issubclass(actual_warning.category, expected_warning):
Expand All @@ -149,13 +185,11 @@ def _assert_caught_expected_warning(
unmatched_messages.append(actual_warning.message)

if not saw_warning:
raise AssertionError(
f"Did not see expected warning of class {expected_warning.__name__!r}"
)
raise AssertionError(f"Did not see expected warning of class {warning_name!r}")

if match and not matched_message:
raise AssertionError(
f"Did not see warning {expected_warning.__name__!r} "
f"Did not see warning {warning_name!r} "
f"matching '{match}'. The emitted warning messages are "
f"{unmatched_messages}"
)
Expand Down
4 changes: 2 additions & 2 deletions pandas/_testing/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=()
elif PYPY and extra_warnings:
return assert_produces_warning(
extra_warnings,
match="|".join(extra_match),
match=extra_match,
)
else:
if using_copy_on_write():
Expand All @@ -190,5 +190,5 @@ def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=()
warning = (warning, *extra_warnings) # type: ignore[assignment]
return assert_produces_warning(
warning,
match="|".join((match, *extra_match)),
match=(match, *extra_match),
)
4 changes: 3 additions & 1 deletion pandas/tests/indexing/test_chaining_and_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def test_detect_chained_assignment_changing_dtype(self):
with tm.raises_chained_assignment_error():
df.loc[2]["C"] = "foo"
tm.assert_frame_equal(df, df_original)
with tm.raises_chained_assignment_error(extra_warnings=(FutureWarning,)):
with tm.raises_chained_assignment_error(
extra_warnings=(FutureWarning,), extra_match=(None,)
):
df["C"][2] = "foo"
tm.assert_frame_equal(df, df_original)

Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/io/parser/common/test_read_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def test_warn_bad_lines(all_parsers):
expected_warning = ParserWarning
if parser.engine == "pyarrow":
match_msg = "Expected 1 columns, but found 3: 1,2,3"
expected_warning = (ParserWarning, DeprecationWarning)

with tm.assert_produces_warning(
expected_warning, match=match_msg, check_stacklevel=False
Expand Down Expand Up @@ -315,7 +314,6 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers):
expected_warning = ParserWarning
if parser.engine == "pyarrow":
match_msg = "Expected 2 columns, but found 3: a,b,c"
expected_warning = (ParserWarning, DeprecationWarning)

with tm.assert_produces_warning(
expected_warning, match=match_msg, check_stacklevel=False
Expand Down
14 changes: 7 additions & 7 deletions pandas/tests/io/parser/test_parse_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_multiple_date_col(all_parsers, keep_date_col, request):
"names": ["X0", "X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8"],
}
with tm.assert_produces_warning(
(DeprecationWarning, FutureWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
result = parser.read_csv(StringIO(data), **kwds)

Expand Down Expand Up @@ -724,7 +724,7 @@ def test_multiple_date_col_name_collision(all_parsers, data, parse_dates, msg):
)
with pytest.raises(ValueError, match=msg):
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
parser.read_csv(StringIO(data), parse_dates=parse_dates)

Expand Down Expand Up @@ -1248,14 +1248,14 @@ def test_multiple_date_col_named_index_compat(all_parsers):
"Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated"
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
with_indices = parser.read_csv(
StringIO(data), parse_dates={"nominal": [1, 2]}, index_col="nominal"
)

with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
with_names = parser.read_csv(
StringIO(data),
Expand All @@ -1280,13 +1280,13 @@ def test_multiple_date_col_multiple_index_compat(all_parsers):
"Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated"
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
result = parser.read_csv(
StringIO(data), index_col=["nominal", "ID"], parse_dates={"nominal": [1, 2]}
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
expected = parser.read_csv(StringIO(data), parse_dates={"nominal": [1, 2]})

Expand Down Expand Up @@ -2267,7 +2267,7 @@ def test_parse_dates_dict_format_two_columns(all_parsers, key, parse_dates):
"Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated"
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
result = parser.read_csv(
StringIO(data), date_format={key: "%d- %m-%Y"}, parse_dates=parse_dates
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/io/parser/usecols/test_parse_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_usecols_with_parse_dates4(all_parsers):
"Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated"
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
result = parser.read_csv(
StringIO(data),
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_usecols_with_parse_dates_and_names(all_parsers, usecols, names, request
"Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated"
)
with tm.assert_produces_warning(
(FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False
FutureWarning, match=depr_msg, check_stacklevel=False
):
result = parser.read_csv(
StringIO(s), names=names, parse_dates=parse_dates, usecols=usecols
Expand Down
39 changes: 38 additions & 1 deletion pandas/tests/util/test_assert_produces_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def f():
warnings.warn("f2", RuntimeWarning)


@pytest.mark.filterwarnings("ignore:f1:FutureWarning")
def test_assert_produces_warning_honors_filter():
# Raise by default.
msg = r"Caused unexpected warning\(s\)"
Expand Down Expand Up @@ -180,6 +179,44 @@ def test_match_multiple_warnings():
warnings.warn("Match this too", UserWarning)


def test_must_match_multiple_warnings():
# https://github.com/pandas-dev/pandas/issues/56555
category = (FutureWarning, UserWarning)
msg = "Did not see expected warning of class 'UserWarning'"
with pytest.raises(AssertionError, match=msg):
with tm.assert_produces_warning(category, match=r"^Match this"):
warnings.warn("Match this", FutureWarning)


def test_must_match_multiple_warnings_messages():
# https://github.com/pandas-dev/pandas/issues/56555
category = (FutureWarning, UserWarning)
msg = r"The emitted warning messages are \[UserWarning\('Not this'\)\]"
with pytest.raises(AssertionError, match=msg):
with tm.assert_produces_warning(category, match=r"^Match this"):
warnings.warn("Match this", FutureWarning)
warnings.warn("Not this", UserWarning)


def test_allow_partial_match_for_multiple_warnings():
# https://github.com/pandas-dev/pandas/issues/56555
category = (FutureWarning, UserWarning)
with tm.assert_produces_warning(
category, match=r"^Match this", must_find_all_warnings=False
):
warnings.warn("Match this", FutureWarning)


def test_allow_partial_match_for_multiple_warnings_messages():
# https://github.com/pandas-dev/pandas/issues/56555
category = (FutureWarning, UserWarning)
with tm.assert_produces_warning(
category, match=r"^Match this", must_find_all_warnings=False
):
warnings.warn("Match this", FutureWarning)
warnings.warn("Not this", UserWarning)


def test_right_category_wrong_match_raises(pair_different_warnings):
target_category, other_category = pair_different_warnings
with pytest.raises(AssertionError, match="Did not see warning.*matching"):
Expand Down

0 comments on commit d8c7e85

Please sign in to comment.