From f6ce668822b0ccf2a8c1481cd8afedacb62c0e04 Mon Sep 17 00:00:00 2001 From: TravisHester <34654270+TravisHester@users.noreply.github.com> Date: Wed, 11 Aug 2021 12:06:45 -0500 Subject: [PATCH 1/7] Added Series.dt.is_month_end (#8989) closes #6396 Adds support for Series.dt.is_month_end the same way it is used in pandas. Authors: - https://github.com/TravisHester - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Ashwin Srinath (https://github.com/shwina) - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/8989 --- python/cudf/cudf/_lib/cpp/datetime.pxd | 3 ++ python/cudf/cudf/_lib/datetime.pyx | 10 ++++++ python/cudf/cudf/core/series.py | 46 +++++++++++++++++++++++++ python/cudf/cudf/tests/test_datetime.py | 30 ++++++++++++++++ 4 files changed, 89 insertions(+) diff --git a/python/cudf/cudf/_lib/cpp/datetime.pxd b/python/cudf/cudf/_lib/cpp/datetime.pxd index 26d25e3017e..ef97be3cf9e 100644 --- a/python/cudf/cudf/_lib/cpp/datetime.pxd +++ b/python/cudf/cudf/_lib/cpp/datetime.pxd @@ -18,5 +18,8 @@ cdef extern from "cudf/datetime.hpp" namespace "cudf::datetime" nogil: ) except + cdef unique_ptr[column] day_of_year(const column_view& column) except + cdef unique_ptr[column] is_leap_year(const column_view& column) except + + cdef unique_ptr[column] last_day_of_month( + const column_view& column + ) except + cdef unique_ptr[column] extract_quarter(const column_view& column) except + cdef unique_ptr[column] days_in_month(const column_view& column) except + diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index 51ceb7c0d8a..1b152f1a3b7 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -95,3 +95,13 @@ def days_in_month(Column col): c_result = move(libcudf_datetime.days_in_month(col_view)) return Column.from_unique_ptr(move(c_result)) + + +def last_day_of_month(Column col): + cdef unique_ptr[column] c_result + cdef column_view col_view = 
col.view() + + with nogil: + c_result = move(libcudf_datetime.last_day_of_month(col_view)) + + return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index dd83b69b459..d33d624b266 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -6033,6 +6033,52 @@ def days_in_month(self): name=self.series.name, ) + @property + def is_month_end(self): + """ + Boolean indicator if the date is the last day of the month. + + Returns + ------- + Series + Booleans indicating if dates are the last day of the month. + + Example + ------- + >>> import pandas as pd, cudf + >>> s = cudf.Series( + ... pd.date_range(start='2000-08-26', end='2000-09-03', freq='1D')) + >>> s + 0 2000-08-26 + 1 2000-08-27 + 2 2000-08-28 + 3 2000-08-29 + 4 2000-08-30 + 5 2000-08-31 + 6 2000-09-01 + 7 2000-09-02 + 8 2000-09-03 + dtype: datetime64[ns] + >>> s.dt.is_month_end + 0 False + 1 False + 2 False + 3 False + 4 False + 5 True + 6 False + 7 False + 8 False + dtype: bool + """ # noqa: E501 + last_day = libcudf.datetime.last_day_of_month(self.series._column) + last_day = Series._from_data( + ColumnAccessor({None: last_day}), + index=self.series._index, + name=self.series.name, + ) + return (self.day == last_day.dt.day).fillna(False) + def _get_dt_field(self, field): out_column = self.series._column.get_dt_field(field) return Series( diff --git a/python/cudf/cudf/tests/test_datetime.py b/python/cudf/cudf/tests/test_datetime.py index 904595ad5a5..41ffa9e57a4 100644 --- a/python/cudf/cudf/tests/test_datetime.py +++ b/python/cudf/cudf/tests/test_datetime.py @@ -1379,3 +1379,33 @@ def test_is_month_start(data, dtype): got = gs.dt.is_month_start assert_eq(expect, got) + + +@pytest.mark.parametrize( + "data", + [ + [ + "2020-05-31", + "2020-02-29", + None, + "1999-12-01", + "2000-12-21", + None, + "1900-02-28", + "1800-03-14", + "2100-03-10", + "1970-01-01", + "1969-12-11", + ] + ], +)
+@pytest.mark.parametrize("dtype", ["datetime64[ns]"]) +def test_is_month_end(data, dtype): + # Series + ps = pd.Series(data, dtype=dtype) + gs = cudf.from_pandas(ps) + + expect = ps.dt.is_month_end + got = gs.dt.is_month_end + + assert_eq(expect, got) From c007b1c17295af4fe009c6616f9fe3bf0d6d4c44 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Wed, 11 Aug 2021 13:56:57 -0700 Subject: [PATCH 2/7] Remove _copy_construct factory (#8999) This PR removes the `_copy_construct` factory method for the `Frame` types that were still exposing it, replacing its usage with `_from_data`. The current implementation of `_from_data` is slightly faster, and more importantly this change leaves us with a single fast path for building `Frame` objects (bypassing the slow constructors) so that we only have to maintain and optimize one going forward. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Marlene (https://github.com/marlenezw) URL: https://github.com/rapidsai/cudf/pull/8999 --- python/cudf/cudf/core/_internals/where.py | 2 +- python/cudf/cudf/core/dataframe.py | 4 +- python/cudf/cudf/core/frame.py | 23 ++++------ python/cudf/cudf/core/index.py | 56 +++++++---------------- python/cudf/cudf/core/indexing.py | 5 +- python/cudf/cudf/core/multiindex.py | 3 +- python/cudf/cudf/core/series.py | 51 +++++++++++---------- python/cudf/cudf/core/window/rolling.py | 2 +- 8 files changed, 60 insertions(+), 86 deletions(-) diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index 87dc1d8e01f..176d91ad478 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -378,6 +378,6 @@ def where( if isinstance(frame, Index): result = Index(result, name=frame.name) else: - result = frame._copy_construct(data=result) + result = frame._from_data({frame.name: result}, frame._index) return frame._mimic_inplace(result, inplace=inplace) diff --git a/python/cudf/cudf/core/dataframe.py 
b/python/cudf/cudf/core/dataframe.py index bb6b54a490a..0aafae0a85b 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -10,7 +10,7 @@ import warnings from collections import defaultdict from collections.abc import Iterable, Sequence -from typing import Any, Mapping, Optional, TypeVar +from typing import Any, MutableMapping, Optional, TypeVar import cupy import numpy as np @@ -459,7 +459,7 @@ def _init_from_dict_like(self, data, index=None, columns=None): @classmethod def _from_data( cls, - data: Mapping, + data: MutableMapping, index: Optional[BaseIndex] = None, columns: Any = None, ) -> DataFrame: diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index f06bd9f9024..3c6bc057af1 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -6,7 +6,7 @@ import functools import warnings from collections import abc -from typing import Any, Dict, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, MutableMapping, Optional, Tuple, TypeVar, Union import cupy import numpy as np @@ -65,7 +65,9 @@ def __init_subclass__(cls): @classmethod def _from_data( - cls, data: Mapping, index: Optional[cudf.core.index.BaseIndex] = None, + cls, + data: MutableMapping, + index: Optional[cudf.core.index.BaseIndex] = None, ): obj = cls.__new__(cls) libcudf.table.Table.__init__(obj, data, index) @@ -4229,7 +4231,7 @@ def _reduce( @classmethod def _from_data( cls, - data: Mapping, + data: MutableMapping, index: Optional[cudf.core.index.BaseIndex] = None, name: Any = None, ): @@ -4519,16 +4521,6 @@ def factorize(self, na_sentinel=-1): """ return cudf.core.algorithms.factorize(self, na_sentinel=na_sentinel) - @property - def _copy_construct_defaults(self): - """A default dictionary of kwargs to be used for copy construction.""" - raise NotImplementedError - - def _copy_construct(self, **kwargs): - """Shallow copy this object by replacing certain ctor args. 
- """ - return self.__class__(**{**self._copy_construct_defaults, **kwargs}) - def _binaryop( self, other: T, @@ -4587,8 +4579,9 @@ def _binaryop( result_name: (self._column, other, reflect, fill_value) } - return self._copy_construct( - data=type(self)._colwise_binop(operands, fn)[result_name], + return self._from_data( + data=type(self)._colwise_binop(operands, fn), + index=self._index, name=result_name, ) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 9ed756547bb..d53fc4dd3c3 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -4,7 +4,16 @@ import pickle from numbers import Number -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Dict, + List, + MutableMapping, + Optional, + Tuple, + Type, + Union, +) import cupy import numpy as np @@ -661,39 +670,6 @@ def difference(self, other, sort=None): return difference - def _copy_construct(self, **kwargs): - # Need to override the parent behavior because pandas allows operations - # on unsigned types to return signed values, forcing us to choose the - # right index type here. - data = kwargs.get("data") - cls = self.__class__ - - if data is not None: - if self.dtype != data.dtype: - # TODO: This logic is largely copied from `as_index`. The two - # should be unified via a centralized type dispatching scheme. 
- if isinstance(data, NumericalColumn): - try: - cls = _dtype_to_index[data.dtype.type] - except KeyError: - cls = GenericIndex - elif isinstance(data, StringColumn): - cls = StringIndex - elif isinstance(data, DatetimeColumn): - cls = DatetimeIndex - elif isinstance(data, TimeDeltaColumn): - cls = TimedeltaIndex - elif isinstance(data, CategoricalColumn): - cls = CategoricalIndex - elif cls is RangeIndex: - # RangeIndex must convert to other numerical types for ops - try: - cls = _dtype_to_index[data.dtype.type] - except KeyError: - cls = GenericIndex - - return cls(**{**self._copy_construct_defaults, **kwargs}) - def sort_values(self, return_indexer=False, ascending=True, key=None): """ Return a sorted copy of the index, and optionally return the indices @@ -1299,12 +1275,14 @@ def from_pandas(cls, index, nan_as_null=None): ind.name = index.name return ind - @property - def _copy_construct_defaults(self): - return {"data": self._column, "name": self.name} - @classmethod - def _from_data(cls, data, index=None): + def _from_data( + cls, + data: MutableMapping, + index: Optional[BaseIndex] = None, + name: Any = None, + ) -> BaseIndex: + assert index is None if not isinstance(data, cudf.core.column_accessor.ColumnAccessor): data = cudf.core.column_accessor.ColumnAccessor(data) if len(data) == 0: diff --git a/python/cudf/cudf/core/indexing.py b/python/cudf/cudf/core/indexing.py index a4a69a4e084..09cfc6e144a 100755 --- a/python/cudf/cudf/core/indexing.py +++ b/python/cudf/cudf/core/indexing.py @@ -98,8 +98,9 @@ def __getitem__(self, arg): or _is_null_host_scalar(data) ): return data - index = self._sr.index.take(arg) - return self._sr._copy_construct(data=data, index=index) + return self._sr._from_data( + {self._sr.name: data}, index=cudf.Index(self._sr.index.take(arg)) + ) def __setitem__(self, key, value): from cudf.core.column import column diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index cdc80b6ef32..af6ac5f3dae 100644 
--- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -830,8 +830,7 @@ def _compute_levels_and_codes(self): for name in self._source_data.columns: code, cats = self._source_data[name].factorize() codes[name] = code.astype(np.int64) - cats.name = None - cats = cudf.Series(cats)._copy_construct(name=None) + cats = cudf.Series(cats, name=None) levels.append(cats) self._levels = levels diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index d33d624b266..177208fa921 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -7,7 +7,7 @@ from collections import abc as abc from numbers import Number from shutil import get_terminal_size -from typing import Any, Mapping, Optional +from typing import Any, MutableMapping, Optional from uuid import uuid4 import cupy @@ -270,7 +270,7 @@ def __init__( @classmethod def _from_data( cls, - data: Mapping, + data: MutableMapping, index: Optional[BaseIndex] = None, name: Any = None, ) -> Series: @@ -383,10 +383,6 @@ def deserialize(cls, header, frames): return Series(column, index=index, name=name) - @property - def _copy_construct_defaults(self): - return {"data": self._column, "index": self._index, "name": self.name} - def _get_columns_by_label(self, labels, downcast=False): """Return the column specified by `labels` @@ -699,7 +695,7 @@ def reset_index(self, drop=False, inplace=False): if inplace is True: self._index = RangeIndex(len(self)) else: - return self._copy_construct(index=RangeIndex(len(self))) + return self._from_data(self._data, index=RangeIndex(len(self))) def set_index(self, index): """Returns a new Series with a different index. @@ -734,7 +730,7 @@ def set_index(self, index): dtype: int64 """ index = index if isinstance(index, BaseIndex) else as_index(index) - return self._copy_construct(index=index) + return self._from_data(self._data, index, self.name) def as_index(self): """Returns a new Series with a RangeIndex. 
@@ -851,8 +847,9 @@ def set_mask(self, mask, null_count=None): "in the future.", DeprecationWarning, ) - col = self._column.set_mask(mask) - return self._copy_construct(data=col) + return self._from_data( + {self.name: self._column.set_mask(mask)}, self._index + ) def __sizeof__(self): return self._column.__sizeof__() + self._index.__sizeof__() @@ -1093,8 +1090,9 @@ def take(self, indices, keep_index=True): return self.iloc[indices] else: col_inds = as_column(indices) - data = self._column.take(col_inds, keep_index=False) - return self._copy_construct(data=data, index=None) + return self._from_data( + {self.name: self._column.take(col_inds, keep_index=False)} + ) def head(self, n=5): """ @@ -2723,8 +2721,9 @@ def nans_to_nulls(self): 4 10.0 dtype: float64 """ - result_col = self._column.nans_to_nulls() - return self._copy_construct(data=result_col) + return self._from_data( + {self.name: self._column.nans_to_nulls()}, self._index + ) def all(self, axis=0, bool_only=None, skipna=True, level=None, **kwargs): if bool_only not in (None, True): @@ -3011,8 +3010,9 @@ def astype(self, dtype, copy=False, errors="raise"): try: data = self._column.astype(dtype) - return self._copy_construct( - data=data.copy(deep=True) if copy else data, index=self.index + return self._from_data( + {self.name: (data.copy(deep=True) if copy else data)}, + index=self._index, ) except Exception as e: @@ -3326,8 +3326,8 @@ def _sort(self, ascending=True, na_position="last"): col_keys, col_inds = self._column.sort_by_values( ascending=ascending, na_position=na_position ) - sr_keys = self._copy_construct(data=col_keys) - sr_inds = self._copy_construct(data=col_inds) + sr_keys = self._from_data({self.name: col_keys}, self._index) + sr_inds = self._from_data({self.name: col_inds}, self._index) return sr_keys, sr_inds def replace( @@ -3630,9 +3630,9 @@ def reverse(self): dtype: int64 """ rinds = column.arange((self._column.size - 1), -1, -1, dtype=np.int32) - col = self._column[rinds] - index = 
self.index._values[rinds] - return self._copy_construct(data=col, index=index) + return self._from_data( + {self.name: self._column[rinds]}, self.index._values[rinds] + ) def one_hot_encoding(self, cats, dtype="float64"): """Perform one-hot-encoding @@ -3786,7 +3786,9 @@ def _return_sentinel_series(): codes = codes.merge(value, on="value", how="left") codes = codes.sort_values("order")["code"].fillna(na_sentinel) - return codes._copy_construct(name=None, index=self.index) + codes.name = None + codes.index = self._index + return codes # UDF related @@ -3900,7 +3902,7 @@ def applymap(self, udf, out_dtype=None): """ if not callable(udf): raise ValueError("Input UDF must be a callable object.") - return self._copy_construct(data=self._unaryop(udf)) + return self._from_data({self.name: self._unaryop(udf)}, self._index) # # Stats @@ -4721,7 +4723,8 @@ def scale(self): vmin = self.min() vmax = self.max() scaled = (self - vmin) / (vmax - vmin) - return self._copy_construct(data=scaled) + scaled._index = self._index.copy(deep=False) + return scaled # Absolute def abs(self): diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py index d2f120a7bb9..e3ed15ba2a6 100644 --- a/python/cudf/cudf/core/window/rolling.py +++ b/python/cudf/cudf/core/window/rolling.py @@ -215,7 +215,7 @@ def _apply_agg_series(self, sr, agg_name): self.center, agg_name, ) - return sr._copy_construct(data=result_col) + return sr._from_data({sr.name: result_col}, sr._index) def _apply_agg_dataframe(self, df, agg_name): result_df = cudf.DataFrame({}) From 4968a9687a80c6b32cbc6e588635f407751416ff Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Wed, 11 Aug 2021 16:59:45 -0400 Subject: [PATCH 3/7] Replace allocate with device_uvector for subword_tokenize internal tables (#8952) The `nvtext::subword_tokenize` function uses 2 internal static code-point tables for processing. 
These were being allocated in device memory using `rmm::mr::get_current_device_resource()->allocate()` which is to be deprecated. This PR changes this code logic to use `rmm::device_uvector` instead. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Christopher Harris (https://github.com/cwharris) - MithunR (https://github.com/mythrocks) URL: https://github.com/rapidsai/cudf/pull/8952 --- cpp/src/text/subword/load_hash_file.cu | 34 ++++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/cpp/src/text/subword/load_hash_file.cu b/cpp/src/text/subword/load_hash_file.cu index 3800339a6a2..b2230f95842 100644 --- a/cpp/src/text/subword/load_hash_file.cu +++ b/cpp/src/text/subword/load_hash_file.cu @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -40,11 +41,11 @@ namespace { struct get_codepoint_metadata_init { rmm::cuda_stream_view stream; - codepoint_metadata_type* operator()() const + rmm::device_uvector* operator()() const { - codepoint_metadata_type* table = - static_cast(rmm::mr::get_current_device_resource()->allocate( - codepoint_metadata_size * sizeof(codepoint_metadata_type), stream)); + auto table_vector = + new rmm::device_uvector(codepoint_metadata_size, stream); + auto table = table_vector->data(); thrust::fill(rmm::exec_policy(stream), table + cp_section1_end, table + codepoint_metadata_size, @@ -60,18 +61,18 @@ struct get_codepoint_metadata_init { (cp_section2_end - cp_section2_begin + 1) * sizeof(codepoint_metadata[0]), // 2nd section cudaMemcpyHostToDevice, stream.value())); - return table; + return table_vector; }; }; struct get_aux_codepoint_data_init { rmm::cuda_stream_view stream; - aux_codepoint_data_type* operator()() const + rmm::device_uvector* operator()() const { - aux_codepoint_data_type* table = - static_cast(rmm::mr::get_current_device_resource()->allocate( - aux_codepoint_data_size * sizeof(aux_codepoint_data_type), stream)); + auto table_vector = + new 
rmm::device_uvector(aux_codepoint_data_size, stream); + auto table = table_vector->data(); thrust::fill(rmm::exec_policy(stream), table + aux_section1_end, table + aux_codepoint_data_size, @@ -99,7 +100,7 @@ struct get_aux_codepoint_data_init { (aux_section4_end - aux_section4_begin + 1) * sizeof(aux_codepoint_data[0]), // 4th section cudaMemcpyHostToDevice, stream.value())); - return table; + return table_vector; } }; } // namespace @@ -112,11 +113,11 @@ struct get_aux_codepoint_data_init { */ const codepoint_metadata_type* get_codepoint_metadata(rmm::cuda_stream_view stream) { - static cudf::strings::detail::thread_safe_per_context_cache + static cudf::strings::detail::thread_safe_per_context_cache< + rmm::device_uvector> g_codepoint_metadata; - get_codepoint_metadata_init function = {stream}; - return g_codepoint_metadata.find_or_initialize(function); + return g_codepoint_metadata.find_or_initialize(get_codepoint_metadata_init{stream})->data(); } /** @@ -127,10 +128,11 @@ const codepoint_metadata_type* get_codepoint_metadata(rmm::cuda_stream_view stre */ const aux_codepoint_data_type* get_aux_codepoint_data(rmm::cuda_stream_view stream) { - static cudf::strings::detail::thread_safe_per_context_cache + static cudf::strings::detail::thread_safe_per_context_cache< + rmm::device_uvector> g_aux_codepoint_data; - get_aux_codepoint_data_init function = {stream}; - return g_aux_codepoint_data.find_or_initialize(function); + + return g_aux_codepoint_data.find_or_initialize(get_aux_codepoint_data_init{stream})->data(); } namespace { From 7461b20f99bbb1f931a9ff91318e722210c9311d Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Wed, 11 Aug 2021 19:07:58 -0400 Subject: [PATCH 4/7] Allow `where()` to work with a Series and `other=cudf.NA` (#9019) Fixes #8969. Duplicate of #8977 - some of the checks are erroring and I'm seeing strange messages about the git commits, so I'm re-opening the PR here to see if that fixes it. 
Authors: - Sarah Yurick (https://github.com/sarahyurick) Approvers: - Ashwin Srinath (https://github.com/shwina) URL: https://github.com/rapidsai/cudf/pull/9019 --- python/cudf/cudf/core/_internals/where.py | 16 ++++++++++++---- python/cudf/cudf/tests/test_dataframe.py | 20 ++++++++++++++++++++ python/cudf/cudf/utils/dtypes.py | 2 ++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index 176d91ad478..0688283bc43 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -27,7 +27,9 @@ def _normalize_scalars(col: ColumnBase, other: ScalarLike) -> ScalarLike: f"{type(other).__name__} to {col.dtype.name}" ) - return cudf.Scalar(other, dtype=col.dtype if other is None else None) + return cudf.Scalar( + other, dtype=col.dtype if other in {None, cudf.NA} else None + ) def _check_and_cast_columns_with_other( @@ -234,9 +236,15 @@ def where( if isinstance(frame, DataFrame): if hasattr(cond, "__cuda_array_interface__"): - cond = DataFrame( - cond, columns=frame._column_names, index=frame.index - ) + if isinstance(cond, Series): + cond = DataFrame( + {name: cond for name in frame._column_names}, + index=frame.index, + ) + else: + cond = DataFrame( + cond, columns=frame._column_names, index=frame.index + ) elif ( hasattr(cond, "__array_interface__") and cond.__array_interface__["shape"] != frame.shape diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 8744238a062..14176fd932d 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8731,3 +8731,23 @@ def test_frame_series_where(): expected = gdf.where(gdf.notna(), gdf.mean()) actual = pdf.where(pdf.notna(), pdf.mean(), axis=1) assert_eq(expected, actual) + + +@pytest.mark.parametrize( + "data", [{"a": [1, 2, 3], "b": [1, 1, 0]}], +) +def test_frame_series_where_other(data): + gdf = 
cudf.DataFrame(data) + pdf = gdf.to_pandas() + + expected = gdf.where(gdf["b"] == 1, cudf.NA) + actual = pdf.where(pdf["b"] == 1, pd.NA) + assert_eq( + actual.fillna(-1).values, + expected.fillna(-1).values, + check_dtype=False, + ) + + expected = gdf.where(gdf["b"] == 1, 0) + actual = pdf.where(pdf["b"] == 1, 0) + assert_eq(expected, actual) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index 46bd1b449c4..829a1545365 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -581,6 +581,8 @@ def _can_cast(from_dtype, to_dtype): `np.can_cast` but with some special handling around cudf specific dtypes. """ + if from_dtype in {None, cudf.NA}: + return True if isinstance(from_dtype, type): from_dtype = np.dtype(from_dtype) if isinstance(to_dtype, type): From d22fdcf5102f78d09ca6f0a8d79086ec6ed3d49b Mon Sep 17 00:00:00 2001 From: Karthikeyan <6488848+karthikeyann@users.noreply.github.com> Date: Thu, 12 Aug 2021 17:22:52 +0530 Subject: [PATCH 5/7] Fix libcudf memory errors (#8884) Fixes https://github.com/rapidsai/cudf/issues/8883 All memory errors caught in libcudf unit tests are fixed in this PR. These unit tests are checked with `cuda-memcheck` before and after the fix. 
The following tests FAILED: - [x] 2 - SCALAR_TEST (Failed) (Fixed) - [x] 32 - COPYING_TEST (Failed) (Fixed) - [x] 39 - MERGE_TEST (Failed) (Fixed) - [x] 46 - FACTORIES_TEST (Failed) (Fixed by scalar test fix) Authors: - Karthikeyan (https://github.com/karthikeyann) Approvers: - Christopher Harris (https://github.com/cwharris) - David Wendt (https://github.com/davidwendt) - MithunR (https://github.com/mythrocks) URL: https://github.com/rapidsai/cudf/pull/8884 --- cpp/src/dictionary/detail/merge.cu | 8 ++++++-- cpp/src/scalar/scalar.cpp | 9 +++++---- cpp/src/structs/utilities.cpp | 3 +-- cpp/tests/copying/concatenate_tests.cu | 14 ++++++-------- cpp/tests/utilities/column_utilities.cu | 22 ++++++++-------------- 5 files changed, 26 insertions(+), 30 deletions(-) diff --git a/cpp/src/dictionary/detail/merge.cu b/cpp/src/dictionary/detail/merge.cu index 2ff0a3e0a2a..e972403cad3 100644 --- a/cpp/src/dictionary/detail/merge.cu +++ b/cpp/src/dictionary/detail/merge.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -62,8 +63,11 @@ std::unique_ptr merge(dictionary_column_view const& lcol, return make_dictionary_column( std::make_unique(lcol.keys(), stream, mr), std::move(indices_column), - rmm::device_buffer{ - lcol.has_nulls() || rcol.has_nulls() ? static_cast(merged_size) : 0, stream, mr}, + cudf::detail::create_null_mask( + lcol.has_nulls() || rcol.has_nulls() ? 
static_cast(merged_size) : 0, + mask_state::UNINITIALIZED, + stream, + mr), lcol.null_count() + rcol.null_count()); } diff --git a/cpp/src/scalar/scalar.cpp b/cpp/src/scalar/scalar.cpp index a8d05e98034..f982e7b99f2 100644 --- a/cpp/src/scalar/scalar.cpp +++ b/cpp/src/scalar/scalar.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include #include #include #include @@ -574,12 +574,13 @@ void struct_scalar::superimpose_nulls(rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { // push validity mask down - std::vector host_validity({0}); - auto validity = cudf::detail::make_device_uvector_sync(host_validity, stream, mr); + std::vector host_validity( + cudf::bitmask_allocation_size_bytes(1) / sizeof(bitmask_type), 0); + auto validity = cudf::detail::create_null_mask(1, mask_state::ALL_NULL, stream); auto iter = thrust::make_counting_iterator(0); std::for_each(iter, iter + _data.num_columns(), [&](size_type i) { cudf::structs::detail::superimpose_parent_nulls( - validity.data(), 1, _data.get_column(i), stream, mr); + static_cast(validity.data()), 1, _data.get_column(i), stream, mr); }); } diff --git a/cpp/src/structs/utilities.cpp b/cpp/src/structs/utilities.cpp index 80bea2ab55e..aa32c555324 100644 --- a/cpp/src/structs/utilities.cpp +++ b/cpp/src/structs/utilities.cpp @@ -187,8 +187,7 @@ void superimpose_parent_nulls(bitmask_type const* parent_null_mask, { if (!child.nullable()) { // Child currently has no null mask. Copy parent's null mask. - child.set_null_mask(rmm::device_buffer{ - parent_null_mask, cudf::bitmask_allocation_size_bytes(child.size()), stream, mr}); + child.set_null_mask(cudf::detail::copy_bitmask(parent_null_mask, 0, child.size(), stream, mr)); child.set_null_count(parent_null_count); } else { // Child should have a null mask. 
diff --git a/cpp/tests/copying/concatenate_tests.cu b/cpp/tests/copying/concatenate_tests.cu index 5237c75e4d4..c48f7ad4dbc 100644 --- a/cpp/tests/copying/concatenate_tests.cu +++ b/cpp/tests/copying/concatenate_tests.cu @@ -48,8 +48,6 @@ using Table = cudf::table; template struct TypedColumnTest : public cudf::test::BaseFixture { - static std::size_t data_size() { return 1000; } - static std::size_t mask_size() { return 100; } cudf::data_type type() { return cudf::data_type{cudf::type_to_id()}; } TypedColumnTest(rmm::cuda_stream_view stream = rmm::cuda_stream_default) @@ -58,14 +56,14 @@ struct TypedColumnTest : public cudf::test::BaseFixture { { auto typed_data = static_cast(data.data()); auto typed_mask = static_cast(mask.data()); - std::vector h_data(data_size()); + std::vector h_data(data.size()); std::iota(h_data.begin(), h_data.end(), char{0}); - std::vector h_mask(mask_size()); + std::vector h_mask(mask.size()); std::iota(h_mask.begin(), h_mask.end(), char{0}); CUDA_TRY(cudaMemcpyAsync( - typed_data, h_data.data(), data_size(), cudaMemcpyHostToDevice, stream.value())); + typed_data, h_data.data(), data.size(), cudaMemcpyHostToDevice, stream.value())); CUDA_TRY(cudaMemcpyAsync( - typed_mask, h_mask.data(), mask_size(), cudaMemcpyHostToDevice, stream.value())); + typed_mask, h_mask.data(), mask.size(), cudaMemcpyHostToDevice, stream.value())); stream.synchronize(); } @@ -484,7 +482,7 @@ TEST_F(OverflowTest, Presliced) auto offset_gen = cudf::detail::make_counting_transform_iterator( 0, [string_size](size_type index) { return index * string_size; }); cudf::test::fixed_width_column_wrapper offsets(offset_gen, offset_gen + num_rows + 1); - auto many_chars = cudf::make_fixed_width_column(data_type{type_id::INT8}, num_rows); + auto many_chars = cudf::make_fixed_width_column(data_type{type_id::INT8}, total_chars_size); auto col = cudf::make_strings_column( num_rows, offsets.release(), std::move(many_chars), 0, rmm::device_buffer{}); @@ -515,7 +513,7 @@ 
TEST_F(OverflowTest, Presliced) offsets->view().begin(), offsets->view().end(), offsets->mutable_view().begin()); - auto many_chars = cudf::make_fixed_width_column(data_type{type_id::INT8}, num_rows); + auto many_chars = cudf::make_fixed_width_column(data_type{type_id::INT8}, total_chars_size); auto col = cudf::make_strings_column( num_rows, std::move(offsets), std::move(many_chars), 0, rmm::device_buffer{}); diff --git a/cpp/tests/utilities/column_utilities.cu b/cpp/tests/utilities/column_utilities.cu index 88e9e3d1384..f3002bc4b1a 100644 --- a/cpp/tests/utilities/column_utilities.cu +++ b/cpp/tests/utilities/column_utilities.cu @@ -114,14 +114,6 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, // // result = [6, 1, 11, 1, 1] // - auto validity_iter = cudf::detail::make_counting_transform_iterator( - 0, - [row_indices = row_indices.begin(), - validity = c.null_mask(), - offset = c.offset()] __device__(int index) { - auto const true_index = row_indices[index] + offset; - return !validity || cudf::bit_is_set(validity, true_index) ? 
1 : 0; - }); auto output_row_iter = cudf::detail::make_counting_transform_iterator( 0, [row_indices = row_indices.begin(), @@ -136,8 +128,9 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, output_row_iter, output_row_iter + row_indices.size(), output_row_start->view().begin(), - validity_iter, - result->mutable_view().begin()); + row_size_iter, + result->mutable_view().begin(), + [] __device__(auto row_size) { return row_size != 0; }); // generate keys for each output row // @@ -150,11 +143,12 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, keys->mutable_view().end(), [] __device__() { return 0; }); thrust::scatter_if(rmm::exec_policy(), - validity_iter, - validity_iter + row_indices.size(), + row_size_iter, + row_size_iter + row_indices.size(), output_row_start->view().begin(), - validity_iter, - keys->mutable_view().begin()); + row_size_iter, + keys->mutable_view().begin(), + [] __device__(auto row_size) { return row_size != 0; }); thrust::inclusive_scan(rmm::exec_policy(), keys->view().begin(), keys->view().end(), From 59b84f3e8d7cf0d9a15e92bdc8bcafbc01c9bfec Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Thu, 12 Aug 2021 15:17:07 +0200 Subject: [PATCH 6/7] Series datetime is_year_end and is_year_start (#8954) This PR aims to allow users to be able to determine whether a datetime is the beginning or end of a year. 
This PR closes #8680 Authors: - Marlene (https://github.com/marlenezw) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - https://github.com/brandon-b-miller URL: https://github.com/rapidsai/cudf/pull/8954 --- python/cudf/cudf/core/series.py | 70 +++++++++++++++++++++++++ python/cudf/cudf/tests/test_datetime.py | 66 +++++++++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 177208fa921..bc6242646a0 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -6082,6 +6082,76 @@ def is_month_end(self): ) return (self.day == last_day.dt.day).fillna(False) + @property + def is_year_start(self): + """ + Boolean indicator if the date is the first day of the year. + + Returns + ------- + Series + Booleans indicating if dates are the first day of the year. + + Example + ------- + >>> import pandas as pd, cudf + >>> dates = cudf.Series(pd.date_range("2017-12-30", periods=3)) + >>> dates + 0 2017-12-30 + 1 2017-12-31 + 2 2018-01-01 + dtype: datetime64[ns] + >>> dates.dt.is_year_start + 0 False + 1 False + 2 True + dtype: bool + """ + outcol = self.series._column.get_dt_field( + "day_of_year" + ) == cudf.Scalar(1) + return Series._from_data( + {None: outcol.fillna(False)}, + index=self.series._index, + name=self.series.name, + ) + + @property + def is_year_end(self): + """ + Boolean indicator if the date is the last day of the year. + + Returns + ------- + Series + Booleans indicating if dates are the last day of the year. 
+ + Example + ------- + >>> import pandas as pd, cudf + >>> dates = cudf.Series(pd.date_range("2017-12-30", periods=3)) + >>> dates + 0 2017-12-30 + 1 2017-12-31 + 2 2018-01-01 + dtype: datetime64[ns] + >>> dates.dt.is_year_end + 0 False + 1 True + 2 False + dtype: bool + """ + day_of_year = self.series._column.get_dt_field("day_of_year") + leap_dates = libcudf.datetime.is_leap_year(self.series._column) + + leap = day_of_year == cudf.Scalar(366) + non_leap = day_of_year == cudf.Scalar(365) + result = cudf._lib.copying.copy_if_else(leap, non_leap, leap_dates) + result = result.fillna(False) + return Series._from_data( + {None: result}, index=self.series._index, name=self.series.name, + ) + def _get_dt_field(self, field): out_column = self.series._column.get_dt_field(field) return Series( diff --git a/python/cudf/cudf/tests/test_datetime.py b/python/cudf/cudf/tests/test_datetime.py index 41ffa9e57a4..6f30f04d6e1 100644 --- a/python/cudf/cudf/tests/test_datetime.py +++ b/python/cudf/cudf/tests/test_datetime.py @@ -1409,3 +1409,69 @@ def test_is_month_end(data, dtype): got = gs.dt.is_month_end assert_eq(expect, got) + + +@pytest.mark.parametrize( + "data", + [ + [ + "2020-05-31", + None, + "1999-12-01", + "2000-12-21", + None, + "1900-01-01", + "1800-03-14", + "2100-03-10", + "1970-01-01", + "1969-12-11", + "2017-12-30", + "2017-12-31", + "2018-01-01", + ] + ], +) +@pytest.mark.parametrize("dtype", ["datetime64[ns]"]) +def test_is_year_start(data, dtype): + ps = pd.Series(data, dtype=dtype) + gs = cudf.from_pandas(ps) + + expect = ps.dt.is_year_start + got = gs.dt.is_year_start + + assert_eq(expect, got) + + +@pytest.mark.parametrize( + "data", + [ + [ + "2020-05-31", + None, + "1999-12-01", + "2000-12-21", + None, + "1900-12-31", + "1800-03-14", + "2017-12-30", + "2017-12-31", + "2020-12-31 08:00:00", + None, + "1999-12-31 18:40:00", + "2000-12-31 04:00:00", + None, + "1800-12-14 07:30:00", + "2100-12-14 07:30:00", + "2020-05-31", + ] + ], +) 
+@pytest.mark.parametrize("dtype", ["datetime64[ns]"]) +def test_is_year_end(data, dtype): + ps = pd.Series(data, dtype=dtype) + gs = cudf.from_pandas(ps) + + expect = ps.dt.is_year_end + got = gs.dt.is_year_end + + assert_eq(expect, got) From 2c5a2ad842f1d412c2b8afc86eb49be8c1e0c681 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Fri, 13 Aug 2021 03:08:35 -0500 Subject: [PATCH 7/7] Upgrade `arrow` & `pyarrow` to `5.0.0` (#8908) This PR upgrades arrow to `5.0.0`. - [x] Upgrade & test arrow 5.0.0. - [x] Fix pytest failures related to decimal arrays. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Ashwin Srinath (https://github.com/shwina) - Mark Sadang (https://github.com/msadang) - Dillon Cullinan (https://github.com/dillon-cullinan) URL: https://github.com/rapidsai/cudf/pull/8908 --- conda/environments/cudf_dev_cuda11.0.yml | 4 +-- conda/environments/cudf_dev_cuda11.2.yml | 4 +-- conda/recipes/cudf/meta.yaml | 2 +- conda/recipes/libcudf/meta.yaml | 2 +- cpp/cmake/thirdparty/CUDF_GetArrow.cmake | 2 +- python/cudf/cudf/core/column/column.py | 44 +++++++++++++----------- python/cudf/cudf/tests/test_binops.py | 40 ++++++++++----------- python/cudf/cudf/tests/test_decimal.py | 10 +++++- 8 files changed, 60 insertions(+), 48 deletions(-) diff --git a/conda/environments/cudf_dev_cuda11.0.yml b/conda/environments/cudf_dev_cuda11.0.yml index 692ebe71794..2c0984569db 100644 --- a/conda/environments/cudf_dev_cuda11.0.yml +++ b/conda/environments/cudf_dev_cuda11.0.yml @@ -17,7 +17,7 @@ dependencies: - numba>=0.53.1 - numpy - pandas>=1.0,<1.3.0dev0 - - pyarrow=4.0.1=*cuda + - pyarrow=5.0.0=*cuda - fastavro>=0.22.9 - notebook>=0.5.0 - cython>=0.29,<0.30 @@ -42,7 +42,7 @@ dependencies: - dask>=2021.6.0 - distributed>=2021.6.0 - streamz - - arrow-cpp=4.0.1 + - arrow-cpp=5.0.0 - dlpack>=0.5,<0.6.0a0 - arrow-cpp-proc * cuda - double-conversion diff --git a/conda/environments/cudf_dev_cuda11.2.yml 
b/conda/environments/cudf_dev_cuda11.2.yml index ce82b870e16..766d85e957b 100644 --- a/conda/environments/cudf_dev_cuda11.2.yml +++ b/conda/environments/cudf_dev_cuda11.2.yml @@ -17,7 +17,7 @@ dependencies: - numba>=0.53.1 - numpy - pandas>=1.0,<1.3.0dev0 - - pyarrow=4.0.1=*cuda + - pyarrow=5.0.0=*cuda - fastavro>=0.22.9 - notebook>=0.5.0 - cython>=0.29,<0.30 @@ -42,7 +42,7 @@ dependencies: - dask>=2021.6.0 - distributed>=2021.6.0 - streamz - - arrow-cpp=4.0.1 + - arrow-cpp=5.0.0 - dlpack>=0.5,<0.6.0a0 - arrow-cpp-proc * cuda - double-conversion diff --git a/conda/recipes/cudf/meta.yaml b/conda/recipes/cudf/meta.yaml index 9023e89c2f5..ca36acccfbb 100644 --- a/conda/recipes/cudf/meta.yaml +++ b/conda/recipes/cudf/meta.yaml @@ -30,7 +30,7 @@ requirements: - setuptools - numba >=0.53.1 - dlpack>=0.5,<0.6.0a0 - - pyarrow 4.0.1 *cuda + - pyarrow 5.0.0 *cuda - libcudf {{ version }} - rmm {{ minor_version }} - cudatoolkit {{ cuda_version }} diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index 88065ef49e0..c1ba2b495eb 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -37,7 +37,7 @@ requirements: host: - librmm {{ minor_version }}.* - cudatoolkit {{ cuda_version }}.* - - arrow-cpp 4.0.1 *cuda + - arrow-cpp 5.0.0 *cuda - arrow-cpp-proc * cuda - dlpack>=0.5,<0.6.0a0 run: diff --git a/cpp/cmake/thirdparty/CUDF_GetArrow.cmake b/cpp/cmake/thirdparty/CUDF_GetArrow.cmake index 5f6ff9651a2..38a5d8da44a 100644 --- a/cpp/cmake/thirdparty/CUDF_GetArrow.cmake +++ b/cpp/cmake/thirdparty/CUDF_GetArrow.cmake @@ -177,7 +177,7 @@ function(find_and_configure_arrow VERSION BUILD_STATIC ENABLE_S3 ENABLE_ORC ENAB endfunction() -set(CUDF_VERSION_Arrow 4.0.1) +set(CUDF_VERSION_Arrow 5.0.0) find_and_configure_arrow( ${CUDF_VERSION_Arrow} diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 8aeaf08273f..b95a4495a69 100644 --- a/python/cudf/cudf/core/column/column.py +++ 
b/python/cudf/cudf/core/column/column.py @@ -2016,6 +2016,29 @@ def as_column( memoryview(arbitrary), dtype=dtype, nan_as_null=nan_as_null ) except TypeError: + if dtype is not None: + # Arrow throws a type error if the input is of + # mixed-precision and cannot fit into the provided + # decimal type properly, see: + # https://github.com/apache/arrow/pull/9948 + # Hence we should let the exception propagate to + # the user. + if isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): + data = pa.array( + arbitrary, + type=pa.decimal128( + precision=dtype.precision, scale=dtype.scale + ), + ) + return cudf.core.column.Decimal64Column.from_arrow(data) + if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): + data = pa.array( + arbitrary, + type=pa.decimal128( + precision=dtype.precision, scale=dtype.scale + ), + ) + return cudf.core.column.Decimal32Column.from_arrow(data) pa_type = None np_type = None try: @@ -2034,26 +2057,7 @@ def as_column( ) and not isinstance(dtype, cudf.IntervalDtype): data = pa.array(arbitrary, type=dtype.to_arrow()) return as_column(data, nan_as_null=nan_as_null) - if isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal64Column.from_arrow( - data - ) - if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): - data = pa.array( - arbitrary, - type=pa.decimal128( - precision=dtype.precision, scale=dtype.scale - ), - ) - return cudf.core.column.Decimal32Column.from_arrow( - data - ) + dtype = pd.api.types.pandas_dtype(dtype) np_type = np.dtype(dtype).type if np_type == np.bool_: diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 8277b8e7b32..f8fd2502a7d 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1758,16 +1758,16 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.add, ["1.5", "2.0"], - 
cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), - ["3.0", "4.0"], cudf.Decimal64Dtype(scale=2, precision=3), + ["3.0", "4.0"], + cudf.Decimal64Dtype(scale=2, precision=4), ), ( operator.add, ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", "3.005"], @@ -1785,7 +1785,7 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.sub, ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=1, precision=2), ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], @@ -1794,7 +1794,7 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.sub, ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=1, precision=2), ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], @@ -1812,11 +1812,11 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.mul, ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["1.5", "3.0"], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", "6.0"], - cudf.Decimal64Dtype(scale=5, precision=7), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul, @@ -1866,16 +1866,16 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.add, ["1.5", None, "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=1, precision=2), ["1.5", None, "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=1, precision=2), ["3.0", None, "4.0"], - cudf.Decimal64Dtype(scale=2, precision=3), + cudf.Decimal64Dtype(scale=1, precision=3), ), ( operator.add, ["1.5", None], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, 
precision=4), ["3.75", None], @@ -1884,7 +1884,7 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.sub, ["1.5", None], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], @@ -1893,7 +1893,7 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.sub, ["1.5", "2.0"], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], @@ -1902,11 +1902,11 @@ def test_binops_with_NA_consistent(dtype, op): ( operator.mul, ["1.5", None], - cudf.Decimal64Dtype(scale=2, precision=2), + cudf.Decimal64Dtype(scale=2, precision=3), ["1.5", None], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", None], - cudf.Decimal64Dtype(scale=5, precision=7), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul, @@ -2432,10 +2432,10 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): ( operator.truediv, ["100", "200"], - cudf.Decimal64Dtype(scale=2, precision=4), + cudf.Decimal64Dtype(scale=2, precision=5), decimal.Decimal(2), ["50", "100"], - cudf.Decimal64Dtype(scale=2, precision=6), + cudf.Decimal64Dtype(scale=2, precision=7), False, ), ( @@ -2459,10 +2459,10 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): ( operator.truediv, ["100", "200"], - cudf.Decimal64Dtype(scale=2, precision=3), + cudf.Decimal64Dtype(scale=2, precision=5), 1, ["0", "0"], - cudf.Decimal64Dtype(scale=-2, precision=5), + cudf.Decimal64Dtype(scale=-2, precision=7), True, ), ( diff --git a/python/cudf/cudf/tests/test_decimal.py b/python/cudf/cudf/tests/test_decimal.py index d2de44b0c8f..9d93898dcd9 100644 --- a/python/cudf/cudf/tests/test_decimal.py +++ b/python/cudf/cudf/tests/test_decimal.py @@ -24,7 +24,7 @@ [1], [-1], [1, 2, 3, 4], - [42, 1729, 4104], + [42, 17, 41], [1, 2, None, 4], [None, None, None], [], @@ -347,3 
+347,11 @@ def test_serialize_decimal_columns(data): df = cudf.DataFrame(data) recreated = df.__class__.deserialize(*df.serialize()) assert_eq(recreated, df) + + +def test_decimal_invalid_precision(): + with pytest.raises(pa.ArrowInvalid): + _ = cudf.Series([10, 20, 30], dtype=cudf.Decimal64Dtype(2, 2)) + + with pytest.raises(pa.ArrowInvalid): + _ = cudf.Series([Decimal("300")], dtype=cudf.Decimal64Dtype(2, 1))