Skip to content

Commit

Permalink
Use isinstance(..., cudf.CategoricalDtype) instead of is_categorical_…
Browse files Browse the repository at this point in the history
…dtype (rapidsai#14423)

Helps the code base be more-strict about checking `cudf.CategoricalDtype` and not accidentally allowing `pd.CategoricalDtype` or `"category"`

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: rapidsai#14423
  • Loading branch information
mroeschke authored and karthikeyann committed Dec 12, 2023
1 parent a607b03 commit eaba945
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 36 deletions.
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/_internals/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_bool_dtype,
is_categorical_dtype,
is_scalar,
)
from cudf.core.column import ColumnBase
from cudf.core.dtypes import CategoricalDtype
from cudf.utils.dtypes import (
_can_cast,
_dtype_can_hold_element,
Expand Down Expand Up @@ -46,7 +46,7 @@ def _check_and_cast_columns_with_other(
) -> Tuple[ColumnBase, Union[ScalarLike, ColumnBase]]:
# Returns type-casted `source_col` & `other` based on `inplace`.
source_dtype = source_col.dtype
if is_categorical_dtype(source_dtype):
if isinstance(source_dtype, CategoricalDtype):
return _normalize_categorical(source_col, other)

other_is_scalar = is_scalar(other)
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cudf import _lib as libcudf
from cudf._lib.transform import bools_to_mask
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.api.types import is_categorical_dtype, is_interval_dtype
from cudf.api.types import is_interval_dtype
from cudf.core.buffer import Buffer
from cudf.core.column import column
from cudf.core.column.methods import ColumnMethods
Expand Down Expand Up @@ -99,7 +99,7 @@ class CategoricalAccessor(ColumnMethods):
_column: CategoricalColumn

def __init__(self, parent: SeriesOrSingleColumnIndex):
if not is_categorical_dtype(parent.dtype):
if not isinstance(parent.dtype, CategoricalDtype):
raise AttributeError(
"Can only use .cat accessor with a 'category' dtype"
)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2798,7 +2798,7 @@ def concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase:
# ColumnBase._concat so that all subclasses can override necessary
# behavior. However, at the moment it's not clear what that API should look
# like, so CategoricalColumn simply implements a minimal working API.
if all(is_categorical_dtype(o.dtype) for o in objs):
if all(isinstance(o.dtype, CategoricalDtype) for o in objs):
return cudf.core.column.categorical.CategoricalColumn._concat(
cast(
MutableSequence[
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pyarrow as pa

import cudf
from cudf.api.types import is_categorical_dtype, is_interval_dtype
from cudf.api.types import is_interval_dtype
from cudf.core.column import StructColumn
from cudf.core.dtypes import IntervalDtype
from cudf.core.dtypes import CategoricalDtype, IntervalDtype


class IntervalColumn(StructColumn):
Expand Down Expand Up @@ -102,7 +102,7 @@ def copy(self, deep=True):

def as_interval_column(self, dtype, **kwargs):
if is_interval_dtype(dtype):
if is_categorical_dtype(self):
if isinstance(self.dtype, CategoricalDtype):
new_struct = self._get_decategorized_column()
return IntervalColumn.from_struct_column(new_struct)
if is_interval_dtype(dtype):
Expand Down
30 changes: 18 additions & 12 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from cudf.api.types import (
_is_scalar_or_zero_d_array,
is_bool_dtype,
is_categorical_dtype,
is_datetime_dtype,
is_dict_like,
is_dtype_equal,
Expand Down Expand Up @@ -319,7 +318,9 @@ def _getitem_tuple_arg(self, arg):
as_column(
tmp_arg[0],
dtype=self._frame.index.dtype
if is_categorical_dtype(self._frame.index.dtype)
if isinstance(
self._frame.index.dtype, cudf.CategoricalDtype
)
else None,
),
tmp_arg[1],
Expand Down Expand Up @@ -1503,7 +1504,7 @@ def _get_numeric_data(self):
columns = [
c
for c, dt in self.dtypes.items()
if dt != object and not is_categorical_dtype(dt)
if dt != object and not isinstance(dt, cudf.CategoricalDtype)
]
return self[columns]

Expand Down Expand Up @@ -1746,8 +1747,8 @@ def _concat(
out._index._data,
indices[:first_data_column_position],
)
if not isinstance(out._index, MultiIndex) and is_categorical_dtype(
out._index._values.dtype
if not isinstance(out._index, MultiIndex) and isinstance(
out._index._values.dtype, cudf.CategoricalDtype
):
out = out.set_index(
cudf.core.index.as_index(out.index._values)
Expand Down Expand Up @@ -3910,8 +3911,11 @@ def transpose(self):
# No column from index is transposed with libcudf.
source_columns = [*self._columns]
source_dtype = source_columns[0].dtype
if is_categorical_dtype(source_dtype):
if any(not is_categorical_dtype(c.dtype) for c in source_columns):
if isinstance(source_dtype, cudf.CategoricalDtype):
if any(
not isinstance(c.dtype, cudf.CategoricalDtype)
for c in source_columns
):
raise ValueError("Columns must all have the same dtype")
cats = list(c.categories for c in source_columns)
cats = cudf.core.column.concat_columns(cats).unique()
Expand All @@ -3925,7 +3929,7 @@ def transpose(self):

result_columns = libcudf.transpose.transpose(source_columns)

if is_categorical_dtype(source_dtype):
if isinstance(source_dtype, cudf.CategoricalDtype):
result_columns = [
codes._with_type_metadata(
cudf.core.dtypes.CategoricalDtype(categories=cats)
Expand Down Expand Up @@ -4627,8 +4631,8 @@ def apply_rows(
"""
for col in incols:
current_col_dtype = self._data[col].dtype
if is_string_dtype(current_col_dtype) or is_categorical_dtype(
current_col_dtype
if is_string_dtype(current_col_dtype) or isinstance(
current_col_dtype, cudf.CategoricalDtype
):
raise TypeError(
"User defined functions are currently not "
Expand Down Expand Up @@ -6446,7 +6450,8 @@ def select_dtypes(self, include=None, exclude=None):
for dtype in self.dtypes:
for i_dtype in include:
# category handling
if is_categorical_dtype(i_dtype):
if i_dtype == cudf.CategoricalDtype:
# Matches cudf & pandas dtype objects
include_subtypes.add(i_dtype)
elif inspect.isclass(dtype.type):
if issubclass(dtype.type, i_dtype):
Expand All @@ -6457,7 +6462,8 @@ def select_dtypes(self, include=None, exclude=None):
for dtype in self.dtypes:
for e_dtype in exclude:
# category handling
if is_categorical_dtype(e_dtype):
if e_dtype == cudf.CategoricalDtype:
# Matches cudf & pandas dtype objects
exclude_subtypes.add(e_dtype)
elif inspect.isclass(dtype.type):
if issubclass(dtype.type, e_dtype):
Expand Down
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from cudf.api.extensions import no_default
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_categorical_dtype,
is_dtype_equal,
is_integer,
is_interval_dtype,
Expand Down Expand Up @@ -2975,7 +2974,7 @@ def __init__(
if isinstance(data, CategoricalColumn):
data = data
elif isinstance(data, pd.Series) and (
is_categorical_dtype(data.dtype)
isinstance(data.dtype, pd.CategoricalDtype)
):
codes_data = column.as_column(data.cat.codes.values)
data = column.build_categorical_column(
Expand Down
15 changes: 7 additions & 8 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_bool_dtype,
is_categorical_dtype,
is_decimal_dtype,
is_dict_like,
is_list_dtype,
Expand Down Expand Up @@ -171,7 +170,7 @@ def _indices_from_labels(obj, labels):
if not isinstance(labels, cudf.MultiIndex):
labels = cudf.core.column.as_column(labels)

if is_categorical_dtype(obj.index):
if isinstance(obj.index.dtype, cudf.CategoricalDtype):
labels = labels.astype("category")
codes = labels.codes.astype(obj.index._values.codes.dtype)
labels = cudf.core.column.build_categorical_column(
Expand Down Expand Up @@ -5458,21 +5457,21 @@ def _is_same_dtype(lhs_dtype, rhs_dtype):
if lhs_dtype == rhs_dtype:
return True
elif (
is_categorical_dtype(lhs_dtype)
and is_categorical_dtype(rhs_dtype)
isinstance(lhs_dtype, cudf.CategoricalDtype)
and isinstance(rhs_dtype, cudf.CategoricalDtype)
and lhs_dtype.categories.dtype == rhs_dtype.categories.dtype
):
# OK if categories are not all the same
return True
elif (
is_categorical_dtype(lhs_dtype)
and not is_categorical_dtype(rhs_dtype)
isinstance(lhs_dtype, cudf.CategoricalDtype)
and not isinstance(rhs_dtype, cudf.CategoricalDtype)
and lhs_dtype.categories.dtype == rhs_dtype
):
return True
elif (
is_categorical_dtype(rhs_dtype)
and not is_categorical_dtype(lhs_dtype)
isinstance(rhs_dtype, cudf.CategoricalDtype)
and not isinstance(lhs_dtype, cudf.CategoricalDtype)
and rhs_dtype.categories.dtype == lhs_dtype
):
return True
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def melt(

# Error for unimplemented support for datatype
dtypes = [frame[col].dtype for col in id_vars + value_vars]
if any(cudf.api.types.is_categorical_dtype(t) for t in dtypes):
if any(isinstance(typ, cudf.CategoricalDtype) for typ in dtypes):
raise NotImplementedError(
"Categorical columns are not yet supported for function"
)
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/tools/numeric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
# Copyright (c) 2018-2023, NVIDIA CORPORATION.

import warnings

Expand All @@ -10,14 +10,14 @@
from cudf._lib import strings as libstrings
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_categorical_dtype,
is_datetime_dtype,
is_list_dtype,
is_string_dtype,
is_struct_dtype,
is_timedelta_dtype,
)
from cudf.core.column import as_column
from cudf.core.dtypes import CategoricalDtype
from cudf.utils.dtypes import can_convert_to_column


Expand Down Expand Up @@ -110,7 +110,7 @@ def to_numeric(arg, errors="raise", downcast=None):

if is_datetime_dtype(dtype) or is_timedelta_dtype(dtype):
col = col.as_numerical_column(cudf.dtype("int64"))
elif is_categorical_dtype(dtype):
elif isinstance(dtype, CategoricalDtype):
cat_dtype = col.dtype.type
if _is_non_decimal_numeric_dtype(cat_dtype):
col = col.as_numerical_column(cat_dtype)
Expand Down
4 changes: 3 additions & 1 deletion python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6094,7 +6094,9 @@ def test_df_sr_mask_where(data, condition, other, error, inplace):
expect_mask = ps_mask
got_mask = gs_mask

if pd.api.types.is_categorical_dtype(expect_where):
if isinstance(expect_where, pd.Series) and isinstance(
expect_where.dtype, pd.CategoricalDtype
):
np.testing.assert_array_equal(
expect_where.cat.codes,
got_where.cat.codes.astype(expect_where.cat.codes.dtype)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def test_index_where(data, condition, other, error):
gs_other = other

if error is None:
if pd.api.types.is_categorical_dtype(ps):
if isinstance(ps.dtype, pd.CategoricalDtype):
expect = ps.where(ps_condition, other=ps_other)
got = gs.where(gs_condition, other=gs_other)
np.testing.assert_array_equal(
Expand Down

0 comments on commit eaba945

Please sign in to comment.