Skip to content

Commit

Permalink
Create contextmanager for hiding cudf safe casting warning.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Feb 24, 2022
1 parent be29895 commit efa7b61
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def _hide_deprecated_pandas_categorical_inplace_warnings(function_name):
yield


@contextmanager
def _hide_cudf_safe_casting_warning():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Can't safely cast column", category=UserWarning,
)
yield


@pytest.fixture
def pd_str_cat():
categories = list("abc")
Expand Down Expand Up @@ -606,21 +615,15 @@ def test_categorical_set_categories_categoricals(data, new_categories):
gd_data = cudf.from_pandas(pd_data)

expected = pd_data.cat.set_categories(new_categories=new_categories)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Can't safely cast column", category=UserWarning,
)
with _hide_cudf_safe_casting_warning():
actual = gd_data.cat.set_categories(new_categories=new_categories)

assert_eq(expected, actual)

expected = pd_data.cat.set_categories(
new_categories=pd.Series(new_categories, dtype="category")
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Can't safely cast column", category=UserWarning,
)
with _hide_cudf_safe_casting_warning():
actual = gd_data.cat.set_categories(
new_categories=cudf.Series(new_categories, dtype="category")
)
Expand Down Expand Up @@ -733,11 +736,9 @@ def test_add_categories(data, add):
gds = cudf.Series(data, dtype="category")

expected = pds.cat.add_categories(add)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Can't safely cast column", category=UserWarning,
)
with _hide_cudf_safe_casting_warning():
actual = gds.cat.add_categories(add)

assert_eq(
expected.cat.codes, actual.cat.codes.astype(expected.cat.codes.dtype)
)
Expand Down

0 comments on commit efa7b61

Please sign in to comment.