Skip to content

Commit

Permalink
Use isinstance over is_foo_dtype internally (#14638)
Browse files Browse the repository at this point in the history
Internally, IMO we should prefer `isinstance` over `is_*_dtype` when we know the input is a dtype instance, in order to be stricter/faster than the `is_*_dtype` checks.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)

URL: #14638
  • Loading branch information
mroeschke authored Dec 15, 2023
1 parent 0762fbe commit cbcfa67
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 85 deletions.
43 changes: 12 additions & 31 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,13 @@
is_categorical_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_decimal32_dtype,
is_decimal64_dtype,
is_decimal128_dtype,
is_decimal_dtype,
is_dtype_equal,
is_integer_dtype,
is_interval_dtype,
is_list_dtype,
is_scalar,
is_string_dtype,
is_struct_dtype,
)
from cudf.core._compat import PANDAS_GE_150
from cudf.core.abc import Serializable
Expand Down Expand Up @@ -1023,21 +1019,15 @@ def astype(self, dtype: Dtype, **kwargs) -> ColumnBase:
"`.astype('str')` instead."
)
return self.as_string_column(dtype, **kwargs)
elif is_list_dtype(dtype):
elif isinstance(dtype, (ListDtype, StructDtype)):
if not self.dtype == dtype:
raise NotImplementedError(
"Casting list columns not currently supported"
f"Casting {self.dtype} columns not currently supported"
)
return self
elif is_struct_dtype(dtype):
if not self.dtype == dtype:
raise NotImplementedError(
"Casting struct columns not currently supported"
)
return self
elif is_interval_dtype(self.dtype):
elif isinstance(dtype, IntervalDtype):
return self.as_interval_column(dtype, **kwargs)
elif is_decimal_dtype(dtype):
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
return self.as_decimal_column(dtype, **kwargs)
elif np.issubdtype(cast(Any, dtype), np.datetime64):
return self.as_datetime_column(dtype, **kwargs)
Expand Down Expand Up @@ -1578,7 +1568,7 @@ def build_column(
)
return col

if is_categorical_dtype(dtype):
if isinstance(dtype, CategoricalDtype):
if not len(children) == 1:
raise ValueError(
"Must specify exactly one child column for CategoricalColumn"
Expand All @@ -1604,7 +1594,7 @@ def build_column(
offset=offset,
null_count=null_count,
)
elif is_datetime64tz_dtype(dtype):
elif isinstance(dtype, pd.DatetimeTZDtype):
if data is None:
raise TypeError("Must specify data buffer")
return cudf.core.column.datetime.DatetimeTZColumn(
Expand Down Expand Up @@ -1634,7 +1624,7 @@ def build_column(
children=children,
null_count=null_count,
)
elif is_list_dtype(dtype):
elif isinstance(dtype, ListDtype):
return cudf.core.column.ListColumn(
size=size,
dtype=dtype,
Expand All @@ -1643,7 +1633,7 @@ def build_column(
null_count=null_count,
children=children,
)
elif is_interval_dtype(dtype):
elif isinstance(dtype, IntervalDtype):
return cudf.core.column.IntervalColumn(
dtype=dtype,
mask=mask,
Expand All @@ -1652,7 +1642,7 @@ def build_column(
children=children,
null_count=null_count,
)
elif is_struct_dtype(dtype):
elif isinstance(dtype, StructDtype):
if size is None:
raise TypeError("Must specify size")
return cudf.core.column.StructColumn(
Expand All @@ -1664,7 +1654,7 @@ def build_column(
null_count=null_count,
children=children,
)
elif is_decimal64_dtype(dtype):
elif isinstance(dtype, cudf.Decimal64Dtype):
if size is None:
raise TypeError("Must specify size")
return cudf.core.column.Decimal64Column(
Expand All @@ -1676,7 +1666,7 @@ def build_column(
null_count=null_count,
children=children,
)
elif is_decimal32_dtype(dtype):
elif isinstance(dtype, cudf.Decimal32Dtype):
if size is None:
raise TypeError("Must specify size")
return cudf.core.column.Decimal32Column(
Expand All @@ -1688,7 +1678,7 @@ def build_column(
null_count=null_count,
children=children,
)
elif is_decimal128_dtype(dtype):
elif isinstance(dtype, cudf.Decimal128Dtype):
if size is None:
raise TypeError("Must specify size")
return cudf.core.column.Decimal128Column(
Expand All @@ -1700,15 +1690,6 @@ def build_column(
null_count=null_count,
children=children,
)
elif is_interval_dtype(dtype):
return cudf.core.column.IntervalColumn(
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
children=children,
)
else:
raise TypeError(f"Unrecognized dtype: {dtype}")

Expand Down
5 changes: 2 additions & 3 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pyarrow as pa

import cudf
from cudf.api.types import is_interval_dtype
from cudf.core.column import StructColumn
from cudf.core.dtypes import CategoricalDtype, IntervalDtype

Expand Down Expand Up @@ -101,11 +100,11 @@ def copy(self, deep=True):
)

def as_interval_column(self, dtype, **kwargs):
if is_interval_dtype(dtype):
if isinstance(dtype, IntervalDtype):
if isinstance(self.dtype, CategoricalDtype):
new_struct = self._get_decategorized_column()
return IntervalColumn.from_struct_column(new_struct)
if is_interval_dtype(dtype):
else:
# a user can directly input the string `interval` as the dtype
# when creating an interval series or interval dataframe
if dtype == "interval":
Expand Down
12 changes: 4 additions & 8 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
from cudf._lib.strings.convert.convert_lists import format_list_column
from cudf._lib.types import size_type_dtype
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_list_dtype,
is_scalar,
)
from cudf.api.types import _is_non_decimal_numeric_dtype, is_scalar
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethods, ParentType
from cudf.core.dtypes import ListDtype
Expand Down Expand Up @@ -298,7 +294,7 @@ class ListMethods(ColumnMethods):
_column: ListColumn

def __init__(self, parent: ParentType):
if not is_list_dtype(parent.dtype):
if not isinstance(parent.dtype, ListDtype):
raise AttributeError(
"Can only use .list accessor with a 'list' dtype"
)
Expand Down Expand Up @@ -589,7 +585,7 @@ def unique(self) -> ParentType:
dtype: list
"""

if is_list_dtype(self._column.children[1].dtype):
if isinstance(self._column.children[1].dtype, ListDtype):
raise NotImplementedError("Nested lists unique is not supported.")

return self._return_or_inplace(
Expand Down Expand Up @@ -642,7 +638,7 @@ def sort_values(
raise NotImplementedError("`kind` not currently implemented.")
if na_position not in {"first", "last"}:
raise ValueError(f"Unknown `na_position` value {na_position}")
if is_list_dtype(self._column.children[1].dtype):
if isinstance(self._column.children[1].dtype, ListDtype):
raise NotImplementedError("Nested lists sort is not supported.")

return self._return_or_inplace(
Expand Down
9 changes: 2 additions & 7 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,7 @@
from cudf._lib import string_casting as str_cast, strings as libstrings
from cudf._lib.column import Column
from cudf._lib.types import size_type_dtype
from cudf.api.types import (
is_integer,
is_list_dtype,
is_scalar,
is_string_dtype,
)
from cudf.api.types import is_integer, is_scalar, is_string_dtype
from cudf.core.buffer import Buffer
from cudf.core.column import column, datetime
from cudf.core.column.column import ColumnBase
Expand Down Expand Up @@ -126,7 +121,7 @@ class StringMethods(ColumnMethods):
def __init__(self, parent):
value_type = (
parent.dtype.leaf_type
if is_list_dtype(parent.dtype)
if isinstance(parent.dtype, cudf.ListDtype)
else parent.dtype
)
if not is_string_dtype(value_type):
Expand Down
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import cudf
from cudf._typing import Dtype
from cudf.api.types import is_struct_dtype
from cudf.core.column import ColumnBase, build_struct_column
from cudf.core.column.methods import ColumnMethods
from cudf.core.dtypes import StructDtype
Expand Down Expand Up @@ -158,7 +157,7 @@ class StructMethods(ColumnMethods):
_column: StructColumn

def __init__(self, parent=None):
if not is_struct_dtype(parent.dtype):
if not isinstance(parent.dtype, StructDtype):
raise AttributeError(
"Can only use .struct accessor with a 'struct' dtype"
)
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@
is_datetime_dtype,
is_dict_like,
is_dtype_equal,
is_list_dtype,
is_list_like,
is_numeric_dtype,
is_object_dtype,
is_scalar,
is_string_dtype,
is_struct_dtype,
)
from cudf.core import column, df_protocol, indexing_utils, reshape
from cudf.core.abc import Serializable
Expand Down Expand Up @@ -1825,7 +1823,9 @@ def _clean_nulls_from_dataframe(self, df):
filling with `<NA>` values.
"""
for col in df._data:
if is_list_dtype(df._data[col]) or is_struct_dtype(df._data[col]):
if isinstance(
df._data[col].dtype, (cudf.StructDtype, cudf.ListDtype)
):
# TODO we need to handle this
pass
elif df._data[col].has_nulls():
Expand Down
10 changes: 8 additions & 2 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,10 @@ def is_list_dtype(obj):
or type(obj) is cudf.core.column.ListColumn
or obj is cudf.core.column.ListColumn
or (isinstance(obj, str) and obj == cudf.core.dtypes.ListDtype.name)
or (hasattr(obj, "dtype") and is_list_dtype(obj.dtype))
or (
hasattr(obj, "dtype")
and isinstance(obj.dtype, cudf.core.dtypes.ListDtype)
)
)


Expand All @@ -1076,7 +1079,10 @@ def is_struct_dtype(obj):
isinstance(obj, cudf.core.dtypes.StructDtype)
or obj is cudf.core.dtypes.StructDtype
or (isinstance(obj, str) and obj == cudf.core.dtypes.StructDtype.name)
or (hasattr(obj, "dtype") and is_struct_dtype(obj.dtype))
or (
hasattr(obj, "dtype")
and isinstance(obj.dtype, cudf.core.dtypes.StructDtype)
)
)


Expand Down
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
is_bool_dtype,
is_decimal_dtype,
is_dict_like,
is_list_dtype,
is_list_like,
is_scalar,
)
Expand Down Expand Up @@ -4101,7 +4100,7 @@ def _explode(self, explode_column: Any, ignore_index: bool):
# specified nested column. Other columns' corresponding rows are
# duplicated. If ignore_index is set, the original index is not
# exploded and will be replaced with a `RangeIndex`.
if not is_list_dtype(self._data[explode_column].dtype):
if not isinstance(self._data[explode_column].dtype, ListDtype):
data = self._data.copy(deep=True)
idx = None if ignore_index else self._index.copy(deep=True)
return self.__class__._from_data(data, index=idx)
Expand Down
15 changes: 7 additions & 8 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@
_is_non_decimal_numeric_dtype,
_is_scalar_or_zero_d_array,
is_bool_dtype,
is_decimal_dtype,
is_dict_like,
is_float_dtype,
is_integer,
is_integer_dtype,
is_list_dtype,
is_scalar,
is_string_dtype,
is_struct_dtype,
)
from cudf.core import indexing_utils
from cudf.core.abc import Serializable
Expand Down Expand Up @@ -1502,12 +1499,14 @@ def __repr__(self):
if (
preprocess.nullable
and not isinstance(
preprocess._column, cudf.core.column.CategoricalColumn
preprocess.dtype,
(
cudf.CategoricalDtype,
cudf.ListDtype,
cudf.StructDtype,
cudf.core.dtypes.DecimalDtype,
),
)
and not is_list_dtype(preprocess.dtype)
and not is_struct_dtype(preprocess.dtype)
and not is_decimal_dtype(preprocess.dtype)
and not is_struct_dtype(preprocess.dtype)
) or isinstance(
preprocess._column,
cudf.core.column.timedelta.TimeDeltaColumn,
Expand Down
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/tools/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_datetime_dtype,
is_list_dtype,
is_string_dtype,
is_struct_dtype,
is_timedelta_dtype,
)
from cudf.core.column import as_column
Expand Down Expand Up @@ -132,7 +130,7 @@ def to_numeric(arg, errors="raise", downcast=None):
return arg
else:
raise e
elif is_list_dtype(dtype) or is_struct_dtype(dtype):
elif isinstance(dtype, (cudf.ListDtype, cudf.StructDtype)):
raise ValueError("Input does not support nested datatypes")
elif _is_non_decimal_numeric_dtype(dtype):
pass
Expand Down
16 changes: 9 additions & 7 deletions python/cudf/cudf/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
from cudf._lib.unary import is_nan
from cudf.api.types import (
is_categorical_dtype,
is_decimal_dtype,
is_list_dtype,
is_numeric_dtype,
is_string_dtype,
is_struct_dtype,
)
from cudf.core.missing import NA, NaT

Expand All @@ -26,10 +23,15 @@ def dtype_can_compare_equal_to_other(dtype):
# as equal to equal values of a different dtype
return not (
is_string_dtype(dtype)
or is_list_dtype(dtype)
or is_struct_dtype(dtype)
or is_decimal_dtype(dtype)
or isinstance(dtype, cudf.IntervalDtype)
or isinstance(
dtype,
(
cudf.IntervalDtype,
cudf.ListDtype,
cudf.StructDtype,
cudf.core.dtypes.DecimalDtype,
),
)
)


Expand Down
Loading

0 comments on commit cbcfa67

Please sign in to comment.