diff --git a/python/cudf/cudf/tests/test_categorical.py b/python/cudf/cudf/tests/test_categorical.py index 03314bd7fdf..19a5cd4a49d 100644 --- a/python/cudf/cudf/tests/test_categorical.py +++ b/python/cudf/cudf/tests/test_categorical.py @@ -34,6 +34,15 @@ def _hide_deprecated_pandas_categorical_inplace_warnings(function_name): yield +@contextmanager +def _hide_cudf_safe_casting_warning(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Can't safely cast column", category=UserWarning, + ) + yield + + @pytest.fixture def pd_str_cat(): categories = list("abc") @@ -606,10 +615,7 @@ def test_categorical_set_categories_categoricals(data, new_categories): gd_data = cudf.from_pandas(pd_data) expected = pd_data.cat.set_categories(new_categories=new_categories) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "Can't safely cast column", category=UserWarning, - ) + with _hide_cudf_safe_casting_warning(): actual = gd_data.cat.set_categories(new_categories=new_categories) assert_eq(expected, actual) @@ -617,10 +623,7 @@ def test_categorical_set_categories_categoricals(data, new_categories): expected = pd_data.cat.set_categories( new_categories=pd.Series(new_categories, dtype="category") ) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "Can't safely cast column", category=UserWarning, - ) + with _hide_cudf_safe_casting_warning(): actual = gd_data.cat.set_categories( new_categories=cudf.Series(new_categories, dtype="category") ) @@ -733,11 +736,9 @@ def test_add_categories(data, add): gds = cudf.Series(data, dtype="category") expected = pds.cat.add_categories(add) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "Can't safely cast column", category=UserWarning, - ) + with _hide_cudf_safe_casting_warning(): actual = gds.cat.add_categories(add) + assert_eq( expected.cat.codes, actual.cat.codes.astype(expected.cat.codes.dtype) )