Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect IntervalDtype and CategoricalDtype objects passed by users #14961

Merged
merged 16 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 50 additions & 117 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
is_datetime64_dtype,
is_dtype_equal,
is_integer_dtype,
is_list_dtype,
is_scalar,
is_string_dtype,
)
Expand Down Expand Up @@ -2161,59 +2160,57 @@ def as_column(
return as_column(
np.asarray(view), dtype=dtype, nan_as_null=nan_as_null
)
# Start of arbitrary that's not handed above but dtype provided
elif isinstance(dtype, pd.DatetimeTZDtype):
raise NotImplementedError(
"Use `tz_localize()` to construct timezone aware data."
)
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
# Arrow throws a type error if the input is of
# mixed-precision and cannot fit into the provided
# decimal type properly, see:
# https://github.com/apache/arrow/pull/9948
# Hence we should let the exception propagate to
# the user.
data = pa.array(
arbitrary,
type=pa.decimal128(precision=dtype.precision, scale=dtype.scale),
)
if isinstance(dtype, cudf.core.dtypes.Decimal128Dtype):
return cudf.core.column.Decimal128Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype):
return cudf.core.column.Decimal64Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype):
return cudf.core.column.Decimal32Column.from_arrow(data)
else:
raise NotImplementedError(f"{dtype} not implemented")
elif isinstance(
dtype,
(
pd.CategoricalDtype,
cudf.CategoricalDtype,
pd.IntervalDtype,
cudf.IntervalDtype,
),
) or dtype in {"category", "interval", "str", str, np.str_}:
if isinstance(dtype, (cudf.CategoricalDtype, cudf.IntervalDtype)):
dtype = dtype.to_pandas()
ser = pd.Series(arbitrary, dtype=dtype)
return as_column(ser, nan_as_null=nan_as_null)
elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)):
try:
data = pa.array(arbitrary, type=dtype.to_arrow())
except (pa.ArrowInvalid, pa.ArrowTypeError):
if isinstance(dtype, cudf.ListDtype):
# e.g. test_cudf_list_struct_write
return cudf.core.column.ListColumn.from_sequences(arbitrary)
raise
return as_column(data, nan_as_null=nan_as_null)
else:
if dtype is not None:
# Arrow throws a type error if the input is of
# mixed-precision and cannot fit into the provided
# decimal type properly, see:
# https://github.com/apache/arrow/pull/9948
# Hence we should let the exception propagate to
# the user.
if isinstance(dtype, cudf.core.dtypes.Decimal128Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal128Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal64Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal32Column.from_arrow(data)

pa_type = None
np_type = None
try:
if dtype is not None:
if dtype in {"category", "interval"} or isinstance(
dtype,
(
cudf.CategoricalDtype,
cudf.IntervalDtype,
pd.IntervalDtype,
pd.CategoricalDtype,
),
):
raise TypeError
if isinstance(dtype, pd.DatetimeTZDtype):
raise NotImplementedError(
"Use `tz_localize()` to construct "
"timezone aware data."
)
elif is_datetime64_dtype(dtype):
if is_datetime64_dtype(dtype):
# Error checking only, actual construction happens
# below.
pa_array = pa.array(arbitrary)
Expand All @@ -2225,42 +2222,6 @@ def as_column(
"cuDF does not yet support timezone-aware "
"datetimes"
)
if is_list_dtype(dtype):
data = pa.array(arbitrary)
if type(data) not in (pa.ListArray, pa.NullArray):
raise ValueError(
"Cannot create list column from given data"
)
return as_column(data, nan_as_null=nan_as_null)
elif isinstance(dtype, cudf.StructDtype) and not isinstance(
dtype, cudf.IntervalDtype
):
data = pa.array(arbitrary, type=dtype.to_arrow())
return as_column(data, nan_as_null=nan_as_null)
elif isinstance(dtype, cudf.core.dtypes.Decimal128Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal128Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal64Column.from_arrow(data)
elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype):
data = pa.array(
arbitrary,
type=pa.decimal128(
precision=dtype.precision, scale=dtype.scale
),
)
return cudf.core.column.Decimal32Column.from_arrow(data)
if is_bool_dtype(dtype):
# Need this special case handling for bool dtypes,
# since 'boolean' & 'pd.BooleanDtype' are not
Expand All @@ -2273,7 +2234,6 @@ def as_column(
raise NotImplementedError(
f"{dtype=} is not supported."
)
np_type = np_dtype.type
pa_type = np_to_pa_dtype(np_dtype)
else:
# By default cudf constructs a 64-bit column. Setting
Expand All @@ -2296,15 +2256,6 @@ def as_column(
_maybe_convert_to_default_type("float")
)

if (
cudf.get_option("mode.pandas_compatible")
and isinstance(
arbitrary, (pd.Index, pd.api.extensions.ExtensionArray)
)
and _is_pandas_nullable_extension_dtype(arbitrary.dtype)
):
raise NotImplementedError("not supported")

pyarrow_array = pa.array(
arbitrary,
type=pa_type,
Expand All @@ -2325,16 +2276,6 @@ def as_column(
dtype = cudf.dtype("str")
pyarrow_array = pyarrow_array.cast(np_to_pa_dtype(dtype))

if (
isinstance(arbitrary, pd.Index)
and arbitrary.dtype == cudf.dtype("object")
and (
cudf.dtype(pyarrow_array.type.to_pandas_dtype())
!= cudf.dtype(arbitrary.dtype)
)
):
raise MixedTypeError("Cannot create column with mixed types")

if (
cudf.get_option("mode.pandas_compatible")
and pa.types.is_integer(pyarrow_array.type)
Expand All @@ -2350,24 +2291,16 @@ def as_column(
except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError) as e:
if isinstance(e, MixedTypeError):
raise TypeError(str(e))
if _is_categorical_dtype(dtype):
sr = pd.Series(arbitrary, dtype="category")
data = as_column(sr, nan_as_null=nan_as_null, dtype=dtype)
elif np_type == np.str_:
sr = pd.Series(arbitrary, dtype="str")
data = as_column(sr, nan_as_null=nan_as_null)
elif dtype == "interval" or isinstance(
dtype, (pd.IntervalDtype, cudf.IntervalDtype)
):
sr = pd.Series(arbitrary, dtype="interval")
data = as_column(sr, nan_as_null=nan_as_null, dtype=dtype)
elif (
isinstance(arbitrary, Sequence)
and len(arbitrary) > 0
and any(
cudf.utils.dtypes.is_column_like(arb) for arb in arbitrary
)
):
# TODO: I think can be removed; covered by
# elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)):
# above
return cudf.core.column.ListColumn.from_sequences(arbitrary)
elif isinstance(arbitrary, abc.Iterable) or isinstance(
arbitrary, abc.Sequence
Expand Down
8 changes: 6 additions & 2 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def as_interval_column(self, dtype):
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=self.children,
children=tuple(
child.astype(dtype.subtype) for child in self.children
),
)
else:
raise ValueError("dtype must be IntervalDtype")
Expand All @@ -124,8 +126,10 @@ def to_pandas(
raise NotImplementedError(f"{nullable=} is not implemented.")
elif arrow_type:
raise NotImplementedError(f"{arrow_type=} is not implemented.")

pd_type = self.dtype.to_pandas()
return pd.Series(
self.dtype.to_pandas().__from_arrow__(self.to_arrow()), index=index
pd_type.__from_arrow__(self.to_arrow()), index=index, dtype=pd_type
)

def element_indexing(self, index: int):
Expand Down
16 changes: 14 additions & 2 deletions python/cudf/cudf/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,18 @@ def test_interval_range_freq_basic_dtype(start_t, end_t, freq_t):
gindex = cudf.interval_range(
start=start, end=end, freq=freq, closed="left"
)
if gindex.dtype.subtype.kind == "f":
gindex = gindex.astype(
cudf.IntervalDtype(subtype="float64", closed=gindex.dtype.closed)
)
elif gindex.dtype.subtype.kind == "i":
gindex = gindex.astype(
cudf.IntervalDtype(subtype="int64", closed=gindex.dtype.closed)
)

assert_eq(pindex, gindex)
# pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268
# using Series to use check_dtype
assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False)


@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
Expand Down Expand Up @@ -221,7 +231,9 @@ def test_interval_range_periods_freq_start_dtype(periods_t, freq_t, start_t):
start=start, freq=freq, periods=periods, closed="left"
)

assert_eq(pindex, gindex)
# pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268
# using Series to use check_dtype
assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False)


@pytest.mark.parametrize("closed", ["right", "left", "both", "neither"])
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2663,6 +2663,22 @@ def test_series_duplicate_index_reindex():
)


def test_list_category_like_maintains_dtype():
dtype = cudf.CategoricalDtype(categories=[1, 2, 3, 4], ordered=True)
data = [1, 2, 3]
result = cudf.Series(cudf.core.column.as_column(data, dtype=dtype))
expected = pd.Series(data, dtype=dtype.to_pandas())
assert_eq(result, expected)


def test_list_interval_like_maintains_dtype():
dtype = cudf.IntervalDtype(subtype=np.int8)
data = [pd.Interval(1, 2)]
result = cudf.Series(cudf.core.column.as_column(data, dtype=dtype))
expected = pd.Series(data, dtype=dtype.to_pandas())
assert_eq(result, expected)


@pytest.mark.parametrize(
"klass", [cudf.Series, cudf.Index, pd.Series, pd.Index]
)
Expand Down
Loading