From db5a41fd106d907a4b9fae430d1c3dfb647be5b3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 2 Feb 2024 17:34:18 -0800 Subject: [PATCH 1/5] Move as_column dtype= logic out of try except --- python/cudf/cudf/core/column/column.py | 150 +++++++---------------- python/cudf/cudf/core/column/interval.py | 3 +- python/cudf/cudf/tests/test_series.py | 16 +++ 3 files changed, 59 insertions(+), 110 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 9143c7f5e9e..3afbd35263a 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -53,8 +53,6 @@ from cudf._typing import ColumnLike, Dtype, ScalarLike from cudf.api.types import ( _is_categorical_dtype, - _is_datetime64tz_dtype, - _is_interval_dtype, _is_non_decimal_numeric_dtype, _is_pandas_nullable_extension_dtype, infer_dtype, @@ -63,7 +61,6 @@ is_decimal_dtype, is_dtype_equal, is_integer_dtype, - is_list_dtype, is_scalar, is_string_dtype, ) @@ -2226,51 +2223,51 @@ def as_column( ) elif isinstance(arbitrary, cudf.Scalar): data = ColumnBase.from_scalar(arbitrary, length if length else 1) + # Start of arbitrary that's not handed above but dtype provided + elif isinstance(dtype, pd.DatetimeTZDtype): + raise NotImplementedError( + "Use `tz_localize()` to construct " "timezone aware data." + ) + elif isinstance(dtype, cudf.core.dtypes.DecimalDtype): + # Arrow throws a type error if the input is of + # mixed-precision and cannot fit into the provided + # decimal type properly, see: + # https://github.com/apache/arrow/pull/9948 + # Hence we should let the exception propagate to + # the user. + data = pa.array( + arbitrary, + type=pa.decimal128(precision=dtype.precision, scale=dtype.scale), + ) + if isinstance(dtype, cudf.core.dtypes.Decimal128Dtype): + return cudf.core.column.Decimal128Column.from_arrow(data) + elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): + return cudf.core.column.Decimal64Column.from_arrow(data) + elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): + return cudf.core.column.Decimal32Column.from_arrow(data) + else: + raise NotImplementedError(f"{dtype} not implemented") + elif isinstance( + dtype, + ( + pd.CategoricalDtype, + cudf.CategoricalDtype, + pd.IntervalDtype, + cudf.IntervalDtype, + ), + ) or dtype in {"category", "interval", "str", str, np.str_}: + if isinstance(dtype, (cudf.CategoricalDtype, cudf.IntervalDtype)): + dtype = dtype.to_pandas() + ser = pd.Series(arbitrary, dtype=dtype) + return as_column(ser, nan_as_null=nan_as_null) + elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)): + data = pa.array(arbitrary, type=dtype.to_arrow()) + return as_column(data, nan_as_null=nan_as_null) else: - if dtype is not None: - # Arrow throws a type error if the input is of - # mixed-precision and cannot fit into the provided - # decimal type properly, see: - # https://github.com/apache/arrow/pull/9948 - # Hence we should let the exception propagate to - # the user. - if isinstance(dtype, cudf.core.dtypes.Decimal128Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal128Column.from_arrow(data) - elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal64Column.from_arrow(data) - elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal32Column.from_arrow(data) - pa_type = None - np_type = None try: if dtype is not None: - if _is_categorical_dtype(dtype) or _is_interval_dtype(dtype): - raise TypeError - if _is_datetime64tz_dtype(dtype): - raise NotImplementedError( - "Use `tz_localize()` to construct " - "timezone aware data." - ) - elif is_datetime64_dtype(dtype): + if is_datetime64_dtype(dtype): # Error checking only, actual construction happens # below. pa_array = pa.array(arbitrary) @@ -2282,42 +2279,6 @@ def as_column( "cuDF does not yet support timezone-aware " "datetimes" ) - if is_list_dtype(dtype): - data = pa.array(arbitrary) - if type(data) not in (pa.ListArray, pa.NullArray): - raise ValueError( - "Cannot create list column from given data" - ) - return as_column(data, nan_as_null=nan_as_null) - elif isinstance(dtype, cudf.StructDtype) and not isinstance( - dtype, cudf.IntervalDtype - ): - data = pa.array(arbitrary, type=dtype.to_arrow()) - return as_column(data, nan_as_null=nan_as_null) - elif isinstance(dtype, cudf.core.dtypes.Decimal128Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal128Column.from_arrow(data) - elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal64Column.from_arrow(data) - elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal32Column.from_arrow(data) if is_bool_dtype(dtype): # Need this special case handling for bool dtypes, # since 'boolean' & 'pd.BooleanDtype' are not @@ -2330,7 +2291,6 @@ def as_column( raise NotImplementedError( f"{dtype=} is not supported." ) - np_type = np_dtype.type pa_type = np_to_pa_dtype(np_dtype) else: # By default cudf constructs a 64-bit column. Setting @@ -2353,15 +2313,6 @@ def as_column( _maybe_convert_to_default_type("float") ) - if ( - cudf.get_option("mode.pandas_compatible") - and isinstance( - arbitrary, (pd.Index, pd.api.extensions.ExtensionArray) - ) - and _is_pandas_nullable_extension_dtype(arbitrary.dtype) - ): - raise NotImplementedError("not supported") - pyarrow_array = pa.array( arbitrary, type=pa_type, @@ -2382,16 +2333,6 @@ def as_column( dtype = cudf.dtype("str") pyarrow_array = pyarrow_array.cast(np_to_pa_dtype(dtype)) - if ( - isinstance(arbitrary, pd.Index) - and arbitrary.dtype == cudf.dtype("object") - and ( - cudf.dtype(pyarrow_array.type.to_pandas_dtype()) - != cudf.dtype(arbitrary.dtype) - ) - ): - raise MixedTypeError("Cannot create column with mixed types") - if ( cudf.get_option("mode.pandas_compatible") and pa.types.is_integer(pyarrow_array.type) @@ -2407,15 +2348,6 @@ def as_column( except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError) as e: if isinstance(e, MixedTypeError): raise TypeError(str(e)) - if _is_categorical_dtype(dtype): - sr = pd.Series(arbitrary, dtype="category") - data = as_column(sr, nan_as_null=nan_as_null, dtype=dtype) - elif np_type == np.str_: - sr = pd.Series(arbitrary, dtype="str") - data = as_column(sr, nan_as_null=nan_as_null) - elif _is_interval_dtype(dtype): - sr = pd.Series(arbitrary, dtype="interval") - data = as_column(sr, nan_as_null=nan_as_null, dtype=dtype) elif ( isinstance(arbitrary, Sequence) and len(arbitrary) > 0 diff --git a/python/cudf/cudf/core/column/interval.py b/python/cudf/cudf/core/column/interval.py index f5d527ad201..253d0dbfa54 100644 --- a/python/cudf/cudf/core/column/interval.py +++ b/python/cudf/cudf/core/column/interval.py @@ -122,8 +122,9 @@ def to_pandas( # directly is problematic), so we're stuck with this for now. if nullable: raise NotImplementedError(f"{nullable=} is not implemented.") + pd_type = self.dtype.to_pandas() return pd.Series( - self.dtype.to_pandas().__from_arrow__(self.to_arrow()), index=index + pd_type.__from_arrow__(self.to_arrow()), index=index, dtype=pd_type ) def element_indexing(self, index: int): diff --git a/python/cudf/cudf/tests/test_series.py b/python/cudf/cudf/tests/test_series.py index 14006f90b45..ca0b754bd3a 100644 --- a/python/cudf/cudf/tests/test_series.py +++ b/python/cudf/cudf/tests/test_series.py @@ -2659,6 +2659,22 @@ def test_series_duplicate_index_reindex(): ) +def test_list_category_like_maintains_dtype(): + dtype = cudf.CategoricalDtype(categories=[1, 2, 3, 4], ordered=True) + data = [1, 2, 3] + result = cudf.Series(cudf.core.column.as_column(data, dtype=dtype)) + expected = pd.Series(data, dtype=dtype.to_pandas()) + assert_eq(result, expected) + + +def test_list_interval_like_maintains_dtype(): + dtype = cudf.IntervalDtype(subtype=np.int8) + data = [pd.Interval(1, 2)] + result = cudf.Series(cudf.core.column.as_column(data, dtype=dtype)) + expected = pd.Series(data, dtype=dtype.to_pandas()) + assert_eq(result, expected) + + @pytest.mark.parametrize( "klass", [cudf.Series, cudf.Index, pd.Series, pd.Index] ) From e2da6af064d4a0a5d31244986d4e255e51c03ce3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:10:06 -0800 Subject: [PATCH 2/5] Fix astypeing with Interval --- python/cudf/cudf/core/column/interval.py | 4 +++- python/cudf/cudf/tests/indexes/test_interval.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/column/interval.py b/python/cudf/cudf/core/column/interval.py index 5ee249828a9..e0caa1e2073 100644 --- a/python/cudf/cudf/core/column/interval.py +++ b/python/cudf/cudf/core/column/interval.py @@ -99,7 +99,9 @@ def as_interval_column(self, dtype): mask=self.mask, offset=self.offset, null_count=self.null_count, - children=self.children, + children=tuple( + child.astype(dtype.subtype) for child in self.children + ), ) else: raise ValueError("dtype must be IntervalDtype") diff --git a/python/cudf/cudf/tests/indexes/test_interval.py b/python/cudf/cudf/tests/indexes/test_interval.py index 5a6155ece29..87e7f4d47a2 100644 --- a/python/cudf/cudf/tests/indexes/test_interval.py +++ b/python/cudf/cudf/tests/indexes/test_interval.py @@ -91,8 +91,18 @@ def test_interval_range_freq_basic_dtype(start_t, end_t, freq_t): gindex = cudf.interval_range( start=start, end=end, freq=freq, closed="left" ) + if gindex.dtype.subtype.kind == "f": + gindex = gindex.astype( + cudf.IntervalDtype(subtype="float64", closed=gindex.dtype.closed) + ) + elif gindex.dtype.subtype.kind == "i": + gindex = gindex.astype( + cudf.IntervalDtype(subtype="int64", closed=gindex.dtype.closed) + ) - assert_eq(pindex, gindex) + # pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268 + # using Series to use check_dtype + assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False) @pytest.mark.parametrize("closed", ["left", "right", "both", "neither"]) @@ -197,7 +207,9 @@ def test_interval_range_periods_freq_start_dtype(periods_t, freq_t, start_t): start=start, freq=freq, periods=periods, closed="left" ) - assert_eq(pindex, gindex) + # pandas upcasts to 64 bit https://github.com/pandas-dev/pandas/issues/57268 + # using Series to use check_dtype + assert_eq(pd.Series(pindex), cudf.Series(gindex), check_dtype=False) @pytest.mark.parametrize("closed", ["right", "left", "both", "neither"]) From b65f0ed8b8ec715413e47fff36ffa3fb343186b4 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:11:20 -0800 Subject: [PATCH 3/5] Fix quotes --- python/cudf/cudf/core/column/column.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 3afbd35263a..92259ee857f 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2226,7 +2226,7 @@ def as_column( # Start of arbitrary that's not handed above but dtype provided elif isinstance(dtype, pd.DatetimeTZDtype): raise NotImplementedError( - "Use `tz_localize()` to construct " "timezone aware data." + "Use `tz_localize()` to construct timezone aware data." ) elif isinstance(dtype, cudf.core.dtypes.DecimalDtype): # Arrow throws a type error if the input is of From 0ac707c335f25a5476dc767be1bff373c4feba01 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 5 Feb 2024 19:04:35 -0800 Subject: [PATCH 4/5] Fix nested list --- python/cudf/cudf/core/column/column.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 92259ee857f..bfce841f7e1 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2261,7 +2261,13 @@ def as_column( ser = pd.Series(arbitrary, dtype=dtype) return as_column(ser, nan_as_null=nan_as_null) elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)): - data = pa.array(arbitrary, type=dtype.to_arrow()) + try: + data = pa.array(arbitrary, type=dtype.to_arrow()) + except (pa.ArrowInvalid, pa.ArrowTypeError): + if isinstance(dtype, cudf.ListDtype): + # e.g. test_cudf_list_struct_write + return cudf.core.column.ListColumn.from_sequences(arbitrary) + raise return as_column(data, nan_as_null=nan_as_null) else: pa_type = None @@ -2355,6 +2361,9 @@ def as_column( cudf.utils.dtypes.is_column_like(arb) for arb in arbitrary ) ): + # TODO: I think can be removed; covered by + # elif isinstance(dtype, (cudf.StructDtype, cudf.ListDtype)): + # above return cudf.core.column.ListColumn.from_sequences(arbitrary) elif isinstance(arbitrary, abc.Iterable) or isinstance( arbitrary, abc.Sequence From 9507ff71a58fa4ca6e8b717ca082a899d6c6aab5 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 5 Mar 2024 15:00:08 -0800 Subject: [PATCH 5/5] Remove covered case --- python/cudf/cudf/core/column/column.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 25e6b2384a0..072c7d4aa33 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2160,8 +2160,6 @@ def as_column( return as_column( np.asarray(view), dtype=dtype, nan_as_null=nan_as_null ) - elif isinstance(arbitrary, cudf.Scalar): - data = ColumnBase.from_scalar(arbitrary, length if length else 1) # Start of arbitrary that's not handed above but dtype provided elif isinstance(dtype, pd.DatetimeTZDtype): raise NotImplementedError(