# Patch overview (descriptive preamble placed before the first `diff --git`
# header; the hunks below are unchanged apart from restored line breaks).
#
# python/cudf/cudf/core/column/categorical.py -- fillna
#   The `if not self.nullable: return self` early return is removed from the
#   top of the method and re-added just before the final
#   `super().fillna(fill_value, method=method)` call, so the `fill_value`
#   handling block now runs even when the column is not nullable.
#
# python/cudf/cudf/core/frame.py -- fillna
#   `should_fill` gains a third alternative: it is now also true when `col`
#   is a `cudf.core.column.CategoricalColumn` and `replace_val` is not a null
#   host scalar, regardless of whether the column has nulls.
#
# python/cudf/cudf/tests/test_categorical.py -- test_cat_groupby_fillna (new)
#   Builds a categorical Series ["a", "b", "c"] in pandas and cudf, groups
#   each by itself inside `pytest.warns(FutureWarning)`, and passes both
#   `groupby(...).fillna` callables with the argument "d" to
#   `assert_exceptions_equal`.
#   NOTE(review): presumably both libraries are expected to raise because "d"
#   is not an existing category -- confirm against `assert_exceptions_equal`.
#
# NOTE(review): indentation of context lines was reconstructed from a
# whitespace-collapsed copy of this patch -- verify against the target files
# before applying.
diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py
index e3e73035046..dc51cd4f28f 100644
--- a/python/cudf/cudf/core/column/categorical.py
+++ b/python/cudf/cudf/core/column/categorical.py
@@ -1045,9 +1045,6 @@ def fillna(
         """
         Fill null values with *fill_value*
         """
-        if not self.nullable:
-            return self
-
         if fill_value is not None:
             fill_is_scalar = np.isscalar(fill_value)
 
@@ -1079,6 +1076,11 @@ def fillna(
                     self.codes.dtype
                 )
 
+        # Validation of `fill_value` will have to be performed
+        # before returning self.
+        if not self.nullable:
+            return self
+
         return super().fillna(fill_value, method=method)
 
     def indices_of(
diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py
index 017190ab5b4..58932db2bda 100644
--- a/python/cudf/cudf/core/frame.py
+++ b/python/cudf/cudf/core/frame.py
@@ -762,10 +762,17 @@ def fillna(
             else:
                 replace_val = None
             should_fill = (
-                col_name in value
-                and col.has_nulls(include_nan=True)
-                and not libcudf.scalar._is_null_host_scalar(replace_val)
-            ) or method is not None
+                (
+                    col_name in value
+                    and col.has_nulls(include_nan=True)
+                    and not libcudf.scalar._is_null_host_scalar(replace_val)
+                )
+                or method is not None
+                or (
+                    isinstance(col, cudf.core.column.CategoricalColumn)
+                    and not libcudf.scalar._is_null_host_scalar(replace_val)
+                )
+            )
             if should_fill:
                 filled_data[col_name] = col.fillna(replace_val, method)
             else:
diff --git a/python/cudf/cudf/tests/test_categorical.py b/python/cudf/cudf/tests/test_categorical.py
index 7aba2e45532..07ce81e3c39 100644
--- a/python/cudf/cudf/tests/test_categorical.py
+++ b/python/cudf/cudf/tests/test_categorical.py
@@ -859,3 +859,19 @@ def test_cat_from_scalar(scalar):
     gs = cudf.Series(scalar, dtype="category")
 
     assert_eq(ps, gs)
+
+
+def test_cat_groupby_fillna():
+    ps = pd.Series(["a", "b", "c"], dtype="category")
+    gs = cudf.from_pandas(ps)
+
+    with pytest.warns(FutureWarning):
+        pg = ps.groupby(ps)
+        gg = gs.groupby(gs)
+
+    assert_exceptions_equal(
+        lfunc=pg.fillna,
+        rfunc=gg.fillna,
+        lfunc_args_and_kwargs=(("d",), {}),
+        rfunc_args_and_kwargs=(("d",), {}),
+    )