Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Column.can_cast_safely instead of some ad-hoc dtype functions in .where #16303

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions python/cudf/cudf/core/_internals/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
import cudf
from cudf.api.types import _is_non_decimal_numeric_dtype, is_scalar
from cudf.core.dtypes import CategoricalDtype
from cudf.utils.dtypes import (
_can_cast,
_dtype_can_hold_element,
find_common_type,
is_mixed_with_object_dtype,
)
from cudf.utils.dtypes import find_common_type, is_mixed_with_object_dtype

if TYPE_CHECKING:
from cudf._typing import ScalarLike
Expand Down Expand Up @@ -44,6 +39,8 @@ def _check_and_cast_columns_with_other(
inplace: bool,
) -> tuple[ColumnBase, ScalarLike | ColumnBase]:
# Returns type-casted `source_col` & `other` based on `inplace`.
from cudf.core.column import as_column

source_dtype = source_col.dtype
if isinstance(source_dtype, CategoricalDtype):
return _normalize_categorical(source_col, other)
Expand Down Expand Up @@ -84,17 +81,9 @@ def _check_and_cast_columns_with_other(
)
return _normalize_categorical(source_col, other.astype(source_dtype))

if (
_is_non_decimal_numeric_dtype(source_dtype)
and not other_is_scalar # can-cast fails for Python scalars
and _can_cast(other, source_dtype)
):
common_dtype = source_dtype
elif (
isinstance(source_col, cudf.core.column.NumericalColumn)
and other_is_scalar
and _dtype_can_hold_element(source_dtype, other)
):
if _is_non_decimal_numeric_dtype(source_dtype) and as_column(
other
).can_cast_safely(source_dtype):
common_dtype = source_dtype
else:
common_dtype = find_common_type(
Expand Down Expand Up @@ -130,3 +119,58 @@ def _make_categorical_like(result, column):
ordered=column.ordered,
)
return result


def _can_cast(from_dtype, to_dtype):
"""
Utility function to determine if we can cast
from `from_dtype` to `to_dtype`. This function primarily calls
`np.can_cast` but with some special handling around
cudf specific dtypes.
"""
if cudf.utils.utils.is_na_like(from_dtype):
return True
if isinstance(from_dtype, type):
from_dtype = cudf.dtype(from_dtype)
if isinstance(to_dtype, type):
to_dtype = cudf.dtype(to_dtype)

# TODO : Add precision & scale checking for
# decimal types in future

if isinstance(from_dtype, cudf.core.dtypes.DecimalDtype):
if isinstance(to_dtype, cudf.core.dtypes.DecimalDtype):
return True
elif isinstance(to_dtype, np.dtype):
if to_dtype.kind in {"i", "f", "u", "U", "O"}:
return True
else:
return False
elif isinstance(from_dtype, np.dtype):
if isinstance(to_dtype, np.dtype):
return np.can_cast(from_dtype, to_dtype)
elif isinstance(to_dtype, cudf.core.dtypes.DecimalDtype):
if from_dtype.kind in {"i", "f", "u", "U", "O"}:
return True
else:
return False
elif isinstance(to_dtype, cudf.core.types.CategoricalDtype):
return True
else:
return False
elif isinstance(from_dtype, cudf.core.dtypes.ListDtype):
# TODO: Add level based checks too once casting of
# list columns is supported
if isinstance(to_dtype, cudf.core.dtypes.ListDtype):
return np.can_cast(from_dtype.leaf_type, to_dtype.leaf_type)
else:
return False
elif isinstance(from_dtype, cudf.core.dtypes.CategoricalDtype):
if isinstance(to_dtype, cudf.core.dtypes.CategoricalDtype):
return True
elif isinstance(to_dtype, np.dtype):
return np.can_cast(from_dtype._categories.dtype, to_dtype)
else:
return False
else:
return np.can_cast(from_dtype, to_dtype)
96 changes: 1 addition & 95 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from pandas.core.dtypes.common import infer_dtype_from_object

import cudf
from cudf._typing import DtypeObj
from cudf.api.types import is_bool, is_float, is_integer

"""Map numpy dtype to pyarrow types.
Note that np.bool_ bitwidth (8) is different from pa.bool_ (1). Special
Expand Down Expand Up @@ -584,61 +582,6 @@ def _dtype_pandas_compatible(dtype):
return dtype


def _can_cast(from_dtype, to_dtype):
"""
Utility function to determine if we can cast
from `from_dtype` to `to_dtype`. This function primarily calls
`np.can_cast` but with some special handling around
cudf specific dtypes.
"""
if cudf.utils.utils.is_na_like(from_dtype):
return True
if isinstance(from_dtype, type):
from_dtype = cudf.dtype(from_dtype)
if isinstance(to_dtype, type):
to_dtype = cudf.dtype(to_dtype)

# TODO : Add precision & scale checking for
# decimal types in future

if isinstance(from_dtype, cudf.core.dtypes.DecimalDtype):
if isinstance(to_dtype, cudf.core.dtypes.DecimalDtype):
return True
elif isinstance(to_dtype, np.dtype):
if to_dtype.kind in {"i", "f", "u", "U", "O"}:
return True
else:
return False
elif isinstance(from_dtype, np.dtype):
if isinstance(to_dtype, np.dtype):
return np.can_cast(from_dtype, to_dtype)
elif isinstance(to_dtype, cudf.core.dtypes.DecimalDtype):
if from_dtype.kind in {"i", "f", "u", "U", "O"}:
return True
else:
return False
elif isinstance(to_dtype, cudf.core.types.CategoricalDtype):
return True
else:
return False
elif isinstance(from_dtype, cudf.core.dtypes.ListDtype):
# TODO: Add level based checks too once casting of
# list columns is supported
if isinstance(to_dtype, cudf.core.dtypes.ListDtype):
return np.can_cast(from_dtype.leaf_type, to_dtype.leaf_type)
else:
return False
elif isinstance(from_dtype, cudf.core.dtypes.CategoricalDtype):
if isinstance(to_dtype, cudf.core.dtypes.CategoricalDtype):
return True
elif isinstance(to_dtype, np.dtype):
return np.can_cast(from_dtype._categories.dtype, to_dtype)
else:
return False
else:
return np.can_cast(from_dtype, to_dtype)


def _maybe_convert_to_default_type(dtype):
"""Convert `dtype` to default if specified by user.

Expand All @@ -661,44 +604,7 @@ def _maybe_convert_to_default_type(dtype):
return dtype


def _dtype_can_hold_range(rng: range, dtype: np.dtype) -> bool:
if not len(rng):
return True
return np.can_cast(rng[0], dtype) and np.can_cast(rng[-1], dtype)


def _dtype_can_hold_element(dtype: np.dtype, element) -> bool:
if dtype.kind in {"i", "u"}:
if isinstance(element, range):
if _dtype_can_hold_range(element, dtype):
return True
return False

elif is_integer(element) or (
is_float(element) and element.is_integer()
):
info = np.iinfo(dtype)
if info.min <= element <= info.max:
return True
return False

elif dtype.kind == "f":
if is_integer(element) or is_float(element):
casted = dtype.type(element)
if np.isnan(casted) or casted == element:
return True
# otherwise e.g. overflow see TestCoercionFloat32
return False

elif dtype.kind == "b":
if is_bool(element):
return True
return False

raise NotImplementedError(f"Unsupported dtype: {dtype}")


def _get_base_dtype(dtype: DtypeObj) -> DtypeObj:
def _get_base_dtype(dtype: pd.DatetimeTZDtype) -> np.dtype:
# TODO: replace the use of this function with just `dtype.base`
# when Pandas 2.1.0 is the minimum version we support:
# https://github.com/pandas-dev/pandas/pull/52706
Expand Down
Loading