diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index c199947d261..756f175c238 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -1506,6 +1506,28 @@ def _concat(objs: MutableSequence[CategoricalColumn]) -> CategoricalColumn: offset=codes_col.offset, ) + def _copy_type_metadata( + self: CategoricalColumn, other: ColumnBase + ) -> ColumnBase: + """Copies type metadata from self onto other, returning a new column. + + In addition to the default behavior, if `other` is not a + CategoricalColumn, we assume other is a column of codes, and return a + CategoricalColumn composed of `other` and the categories of `self`. + """ + if not isinstance(other, cudf.core.column.CategoricalColumn): + other = column.build_categorical_column( + categories=self.categories, + codes=column.as_column(other.base_data, dtype=other.dtype), + mask=other.base_mask, + ordered=self.ordered, + size=other.size, + offset=other.offset, + null_count=other.null_count, + ) + # Have to ignore typing here because it misdiagnoses super(). + return super()._copy_type_metadata(other) # type: ignore + def _create_empty_categorical_column( categorical_column: CategoricalColumn, dtype: "CategoricalDtype" diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 4bf4b2b87f2..a58b2eda822 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -8,7 +8,6 @@ from types import SimpleNamespace from typing import ( Any, - Callable, Dict, List, MutableSequence, @@ -310,16 +309,6 @@ def _memory_usage(self, **kwargs) -> int: def default_na_value(self) -> Any: raise NotImplementedError() - def applymap( - self, udf: Callable[[ScalarLike], ScalarLike], out_dtype: Dtype = None - ) -> ColumnBase: - """Apply an element-wise function to the values in the Column.""" - # Subclasses that support applymap must override this behavior. - raise TypeError( - "User-defined functions are currently not supported on data " - f"with dtype {self.dtype}." - ) - def to_gpu_array(self, fillna=None) -> "cuda.devicearray.DeviceNDArray": """Get a dense numba device array for the data. @@ -1139,6 +1128,11 @@ def binary_operator( f"{other.dtype}." ) + def normalize_binop_value( + self, other: ScalarLike + ) -> Union[ColumnBase, ScalarLike]: + raise NotImplementedError + def min(self, skipna: bool = None, dtype: Dtype = None): result_col = self._process_for_reduction(skipna=skipna) if isinstance(result_col, ColumnBase): @@ -1273,46 +1267,18 @@ def scatter_to_table( } ) - def _copy_type_metadata(self: T, other: ColumnBase) -> ColumnBase: + def _copy_type_metadata(self: ColumnBase, other: ColumnBase) -> ColumnBase: """ Copies type metadata from self onto other, returning a new column. - * when `self` is a CategoricalColumn and `other` is not, we assume - other is a column of codes, and return a CategoricalColumn composed - of `other` and the categories of `self`. - * when both `self` and `other` are StructColumns, rename the fields - of `other` to the field names of `self`. - * when both `self` and `other` are DecimalColumns, copy the precision - from self.dtype to other.dtype * when `self` and `other` are nested columns of the same type, recursively apply this function on the children of `self` to the and the children of `other`. * if none of the above, return `other` without any changes """ - if isinstance(self, cudf.core.column.CategoricalColumn) and not ( - isinstance(other, cudf.core.column.CategoricalColumn) - ): - other = build_categorical_column( - categories=self.categories, - codes=as_column(other.base_data, dtype=other.dtype), - mask=other.base_mask, - ordered=self.ordered, - size=other.size, - offset=other.offset, - null_count=other.null_count, - ) - - if isinstance(other, cudf.core.column.StructColumn) and isinstance( - self, cudf.core.column.StructColumn - ): - other = other._rename_fields(self.dtype.fields.keys()) - - if isinstance(other, cudf.core.column.DecimalColumn) and isinstance( - self, cudf.core.column.DecimalColumn - ): - other.dtype.precision = self.dtype.precision - - if type(self) is type(other): + # TODO: This logic should probably be moved to a common nested column + # class. + if isinstance(other, type(self)): if self.base_children and other.base_children: base_children = tuple( self.base_children[i]._copy_type_metadata( diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index e3d88424b8a..907ec23c468 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -1,7 +1,6 @@ # Copyright (c) 2021, NVIDIA CORPORATION. from decimal import Decimal -from numbers import Number from typing import Any, Sequence, Tuple, Union, cast import cupy as cp @@ -22,11 +21,16 @@ from cudf.utils.dtypes import is_scalar from cudf.utils.utils import pa_mask_buffer_to_mask +from .numerical_base import NumericalBaseColumn -class DecimalColumn(ColumnBase): + +class DecimalColumn(NumericalBaseColumn): dtype: Decimal64Dtype def __truediv__(self, other): + # TODO: This override is not sufficient. While it will change the + # behavior of x / y for two decimal columns, it will not affect + # col1.binary_operator(col2), which is how Series/Index will call this. return self.binary_operator("div", other) def __setitem__(self, key, value): @@ -123,39 +127,6 @@ def normalize_binop_value(self, other): else: raise TypeError(f"cannot normalize {type(other)}") - def _apply_scan_op(self, op: str) -> ColumnBase: - result = libcudf.reduce.scan(op, self, True) - return self._copy_type_metadata(result) - - def quantile( - self, q: Union[float, Sequence[float]], interpolation: str, exact: bool - ) -> ColumnBase: - if isinstance(q, Number) or cudf.utils.dtypes.is_list_like(q): - np_array_q = np.asarray(q) - if np.logical_or(np_array_q < 0, np_array_q > 1).any(): - raise ValueError( - "percentiles should all be in the interval [0, 1]" - ) - # Beyond this point, q either being scalar or list-like - # will only have values in range [0, 1] - result = self._decimal_quantile(q, interpolation, exact) - if isinstance(q, Number): - return ( - cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - if result[0] is cudf.NA - else result[0] - ) - return result - - def median(self, skipna: bool = None) -> ColumnBase: - skipna = True if skipna is None else skipna - - if not skipna and self.has_nulls: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - # enforce linear in case the default ever changes - return self.quantile(0.5, interpolation="linear", exact=True) - def _decimal_quantile( self, q: Union[float, Sequence[float]], interpolation: str, exact: bool ) -> ColumnBase: @@ -194,37 +165,6 @@ def as_string_column( "cudf.core.column.StringColumn", as_column([], dtype="object") ) - def reduce(self, op: str, skipna: bool = None, **kwargs) -> Decimal: - min_count = kwargs.pop("min_count", 0) - preprocessed = self._process_for_reduction( - skipna=skipna, min_count=min_count - ) - if isinstance(preprocessed, ColumnBase): - return libcudf.reduce.reduce(op, preprocessed, **kwargs) - else: - return preprocessed - - def sum( - self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> Decimal: - return self.reduce( - "sum", skipna=skipna, dtype=dtype, min_count=min_count - ) - - def product( - self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> Decimal: - return self.reduce( - "product", skipna=skipna, dtype=dtype, min_count=min_count - ) - - def sum_of_squares( - self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> Decimal: - return self.reduce( - "sum_of_squares", skipna=skipna, dtype=dtype, min_count=min_count - ) - def fillna( self, value: Any = None, method: str = None, dtype: Dtype = None ): @@ -269,6 +209,17 @@ def __cuda_array_interface__(self): "Decimals are not yet supported via `__cuda_array_interface__`" ) + def _copy_type_metadata(self: ColumnBase, other: ColumnBase) -> ColumnBase: + """Copies type metadata from self onto other, returning a new column. + + In addition to the default behavior, if `other` is also a decimal + column the precision is copied over. + """ + if isinstance(other, DecimalColumn): + other.dtype.precision = self.dtype.precision # type: ignore + # Have to ignore typing here because it misdiagnoses super(). + return super()._copy_type_metadata(other) # type: ignore + def _binop_scale(l_dtype, r_dtype, op): # This should at some point be hooked up to libcudf's diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 39bbf10c235..e35cc744434 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -2,20 +2,16 @@ from __future__ import annotations -import builtins -from numbers import Number from types import SimpleNamespace -from typing import Any, Callable, Mapping, Sequence, Tuple, Union, cast +from typing import Any, Mapping, Sequence, Tuple, Union, cast import cupy import numpy as np import pandas as pd -from numba import cuda, njit from pandas.api.types import is_integer_dtype import cudf from cudf import _lib as libcudf -from cudf._lib.quantiles import quantile as cpp_quantile from cudf._typing import BinaryOperand, ColumnLike, Dtype, DtypeObj, ScalarLike from cudf.core.buffer import Buffer from cudf.core.column import ( @@ -23,7 +19,6 @@ as_column, build_column, column, - column_empty, string, ) from cudf.core.dtypes import Decimal64Dtype @@ -37,8 +32,10 @@ to_cudf_compatible_scalar, ) +from .numerical_base import NumericalBaseColumn -class NumericalColumn(ColumnBase): + +class NumericalColumn(NumericalBaseColumn): def __init__( self, data: Buffer, @@ -91,7 +88,7 @@ def __contains__(self, item: ScalarLike) -> bool: ).any() @property - def __cuda_array_interface__(self) -> Mapping[builtins.str, Any]: + def __cuda_array_interface__(self) -> Mapping[str, Any]: output = { "shape": (len(self),), "strides": (self.dtype.itemsize,), @@ -168,9 +165,6 @@ def binary_operator( lhs, rhs = (self, rhs) if not reflect else (rhs, self) return libcudf.binaryop.binaryop(lhs, rhs, binop, out_dtype) - def _apply_scan_op(self, op: str) -> ColumnBase: - return libcudf.reduce.scan(op, self, True) - def normalize_binop_value( self, other: ScalarLike ) -> Union[ColumnBase, ScalarLike]: @@ -264,43 +258,6 @@ def as_numerical_column(self, dtype: Dtype) -> NumericalColumn: return self return libcudf.unary.cast(self, dtype) - def reduce(self, op: str, skipna: bool = None, **kwargs) -> float: - min_count = kwargs.pop("min_count", 0) - preprocessed = self._process_for_reduction( - skipna=skipna, min_count=min_count - ) - if isinstance(preprocessed, ColumnBase): - return libcudf.reduce.reduce(op, preprocessed, **kwargs) - else: - return cast(float, preprocessed) - - def sum( - self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> float: - return self.reduce( - "sum", skipna=skipna, dtype=dtype, min_count=min_count - ) - - def product( - self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> float: - return self.reduce( - "product", skipna=skipna, dtype=dtype, min_count=min_count - ) - - def mean(self, skipna: bool = None, dtype: Dtype = np.float64) -> float: - return self.reduce("mean", skipna=skipna, dtype=dtype) - - def var( - self, skipna: bool = None, ddof: int = 1, dtype: Dtype = np.float64 - ) -> float: - return self.reduce("var", skipna=skipna, dtype=dtype, ddof=ddof) - - def std( - self, skipna: bool = None, ddof: int = 1, dtype: Dtype = np.float64 - ) -> float: - return self.reduce("std", skipna=skipna, dtype=dtype, ddof=ddof) - def _process_values_for_isin( self, values: Sequence ) -> Tuple[ColumnBase, ColumnBase]: @@ -317,163 +274,6 @@ def _process_values_for_isin( return lhs, rhs - def sum_of_squares(self, dtype: Dtype = None) -> float: - return libcudf.reduce.reduce("sum_of_squares", self, dtype=dtype) - - def kurtosis(self, skipna: bool = None) -> float: - skipna = True if skipna is None else skipna - - if len(self) == 0 or (not skipna and self.has_nulls): - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - self = self.nans_to_nulls().dropna() # type: ignore - - if len(self) < 4: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - n = len(self) - miu = self.mean() - m4_numerator = ((self - miu) ** self.normalize_binop_value(4)).sum() - V = self.var() - - if V == 0: - return 0 - - term_one_section_one = (n * (n + 1)) / ((n - 1) * (n - 2) * (n - 3)) - term_one_section_two = m4_numerator / (V ** 2) - term_two = ((n - 1) ** 2) / ((n - 2) * (n - 3)) - kurt = term_one_section_one * term_one_section_two - 3 * term_two - return kurt - - def skew(self, skipna: bool = None) -> ScalarLike: - skipna = True if skipna is None else skipna - - if len(self) == 0 or (not skipna and self.has_nulls): - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - self = self.nans_to_nulls().dropna() # type: ignore - - if len(self) < 3: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - n = len(self) - miu = self.mean() - m3 = (((self - miu) ** self.normalize_binop_value(3)).sum()) / n - m2 = self.var(ddof=0) - - if m2 == 0: - return 0 - - unbiased_coef = ((n * (n - 1)) ** 0.5) / (n - 2) - skew = unbiased_coef * m3 / (m2 ** (3 / 2)) - return skew - - def quantile( - self, q: Union[float, Sequence[float]], interpolation: str, exact: bool - ) -> NumericalColumn: - if isinstance(q, Number) or cudf.utils.dtypes.is_list_like(q): - np_array_q = np.asarray(q) - if np.logical_or(np_array_q < 0, np_array_q > 1).any(): - raise ValueError( - "percentiles should all be in the interval [0, 1]" - ) - # Beyond this point, q either being scalar or list-like - # will only have values in range [0, 1] - result = self._numeric_quantile(q, interpolation, exact) - if isinstance(q, Number): - return ( - cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - if result[0] is cudf.NA - else result[0] - ) - return result - - def median(self, skipna: bool = None) -> NumericalColumn: - skipna = True if skipna is None else skipna - - if not skipna and self.has_nulls: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - # enforce linear in case the default ever changes - return self.quantile(0.5, interpolation="linear", exact=True) - - def _numeric_quantile( - self, q: Union[float, Sequence[float]], interpolation: str, exact: bool - ) -> NumericalColumn: - quant = [float(q)] if not isinstance(q, (Sequence, np.ndarray)) else q - # get sorted indices and exclude nulls - sorted_indices = self.as_frame()._get_sorted_inds( - ascending=True, na_position="first" - ) - sorted_indices = sorted_indices[self.null_count :] - - return cpp_quantile(self, quant, interpolation, sorted_indices, exact) - - def cov(self, other: ColumnBase) -> float: - if ( - len(self) == 0 - or len(other) == 0 - or (len(self) == 1 and len(other) == 1) - ): - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - result = (self - self.mean()) * (other - other.mean()) - cov_sample = result.sum() / (len(self) - 1) - return cov_sample - - def corr(self, other: ColumnBase) -> float: - if len(self) == 0 or len(other) == 0: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - - cov = self.cov(other) - lhs_std, rhs_std = self.std(), other.std() - - if not cov or lhs_std == 0 or rhs_std == 0: - return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) - return cov / lhs_std / rhs_std - - def round(self, decimals: int = 0) -> NumericalColumn: - """Round the values in the Column to the given number of decimals. - """ - return libcudf.round.round(self, decimal_places=decimals) - - def applymap( - self, udf: Callable[[ScalarLike], ScalarLike], out_dtype: Dtype = None - ) -> ColumnBase: - """Apply an elementwise function to transform the values in the Column. - - Parameters - ---------- - udf : function - Wrapped by numba jit for call on the GPU as a device function. - out_dtype : numpy.dtype; optional - The dtype for use in the output. - By default, use the same dtype as *self.dtype*. - - Returns - ------- - result : Column - The mask is preserved. - """ - if out_dtype is None: - out_dtype = self.dtype - - core = njit(udf) - - # For non-masked columns - @cuda.jit - def kernel_applymap(values, results): - i = cuda.grid(1) - # in range? - if i < values.size: - # call udf - results[i] = core(values[i]) - - results = column_empty(self.size, dtype=out_dtype) - values = self.data_array_view - kernel_applymap.forall(self.size)(values, results) - return as_column(results) - def default_na_value(self) -> ScalarLike: """Returns the default NA value for this column """ diff --git a/python/cudf/cudf/core/column/numerical_base.py b/python/cudf/cudf/core/column/numerical_base.py new file mode 100644 index 00000000000..fd62b58db9b --- /dev/null +++ b/python/cudf/cudf/core/column/numerical_base.py @@ -0,0 +1,200 @@ +# Copyright (c) 2018-2021, NVIDIA CORPORATION. +"""Define an interface for columns that can perform numerical operations.""" + +from __future__ import annotations + +from numbers import Number +from typing import Sequence, Union + +import numpy as np + +import cudf +from cudf import _lib as libcudf +from cudf._typing import Dtype, ScalarLike +from cudf.core.column import ColumnBase + + +class NumericalBaseColumn(ColumnBase): + """A column composed of numerical data. + + This class encodes a standard interface for different types of columns + containing numerical types of data. In particular, mathematical operations + that make sense whether a column is integral or real, fixed or floating + point, should be encoded here. + """ + + def reduce( + self, op: str, skipna: bool = None, min_count: int = 0, **kwargs + ) -> ScalarLike: + """Perform a reduction operation. + + op : str + The operation to perform. + skipna : bool + Whether or not na values must be + """ + preprocessed = self._process_for_reduction( + skipna=skipna, min_count=min_count + ) + if isinstance(preprocessed, ColumnBase): + return libcudf.reduce.reduce(op, preprocessed, **kwargs) + else: + return preprocessed + + def sum( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> ScalarLike: + return self.reduce( + "sum", skipna=skipna, dtype=dtype, min_count=min_count + ) + + def product( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> ScalarLike: + return self.reduce( + "product", skipna=skipna, dtype=dtype, min_count=min_count + ) + + def mean( + self, skipna: bool = None, dtype: Dtype = np.float64 + ) -> ScalarLike: + return self.reduce("mean", skipna=skipna, dtype=dtype) + + def var( + self, skipna: bool = None, ddof: int = 1, dtype: Dtype = np.float64 + ) -> ScalarLike: + return self.reduce("var", skipna=skipna, dtype=dtype, ddof=ddof) + + def std( + self, skipna: bool = None, ddof: int = 1, dtype: Dtype = np.float64 + ) -> ScalarLike: + return self.reduce("std", skipna=skipna, dtype=dtype, ddof=ddof) + + def sum_of_squares( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> ScalarLike: + return self.reduce( + "sum_of_squares", skipna=skipna, dtype=dtype, min_count=min_count + ) + + def kurtosis(self, skipna: bool = None) -> float: + skipna = True if skipna is None else skipna + + if len(self) == 0 or (not skipna and self.has_nulls): + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + self = self.nans_to_nulls().dropna() # type: ignore + + if len(self) < 4: + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + n = len(self) + miu = self.mean() + m4_numerator = ((self - miu) ** self.normalize_binop_value(4)).sum() + V = self.var() + + if V == 0: + return 0 + + term_one_section_one = (n * (n + 1)) / ((n - 1) * (n - 2) * (n - 3)) + term_one_section_two = m4_numerator / (V ** 2) + term_two = ((n - 1) ** 2) / ((n - 2) * (n - 3)) + kurt = term_one_section_one * term_one_section_two - 3 * term_two + return kurt + + def skew(self, skipna: bool = None) -> ScalarLike: + skipna = True if skipna is None else skipna + + if len(self) == 0 or (not skipna and self.has_nulls): + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + self = self.nans_to_nulls().dropna() # type: ignore + + if len(self) < 3: + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + n = len(self) + miu = self.mean() + m3 = (((self - miu) ** self.normalize_binop_value(3)).sum()) / n + m2 = self.var(ddof=0) + + if m2 == 0: + return 0 + + unbiased_coef = ((n * (n - 1)) ** 0.5) / (n - 2) + skew = unbiased_coef * m3 / (m2 ** (3 / 2)) + return skew + + def quantile( + self, q: Union[float, Sequence[float]], interpolation: str, exact: bool + ) -> NumericalBaseColumn: + if isinstance(q, Number) or cudf.utils.dtypes.is_list_like(q): + np_array_q = np.asarray(q) + if np.logical_or(np_array_q < 0, np_array_q > 1).any(): + raise ValueError( + "percentiles should all be in the interval [0, 1]" + ) + # Beyond this point, q either being scalar or list-like + # will only have values in range [0, 1] + result = self._numeric_quantile(q, interpolation, exact) + if isinstance(q, Number): + return ( + cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + if result[0] is cudf.NA + else result[0] + ) + return result + + def median(self, skipna: bool = None) -> NumericalBaseColumn: + skipna = True if skipna is None else skipna + + if not skipna and self.has_nulls: + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + # enforce linear in case the default ever changes + return self.quantile(0.5, interpolation="linear", exact=True) + + def _numeric_quantile( + self, q: Union[float, Sequence[float]], interpolation: str, exact: bool + ) -> NumericalBaseColumn: + quant = [float(q)] if not isinstance(q, (Sequence, np.ndarray)) else q + # get sorted indices and exclude nulls + sorted_indices = self.as_frame()._get_sorted_inds( + ascending=True, na_position="first" + ) + sorted_indices = sorted_indices[self.null_count :] + + return libcudf.quantiles.quantile( + self, quant, interpolation, sorted_indices, exact + ) + + def cov(self, other: ColumnBase) -> float: + if ( + len(self) == 0 + or len(other) == 0 + or (len(self) == 1 and len(other) == 1) + ): + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + result = (self - self.mean()) * (other - other.mean()) + cov_sample = result.sum() / (len(self) - 1) + return cov_sample + + def corr(self, other: ColumnBase) -> float: + if len(self) == 0 or len(other) == 0: + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + + cov = self.cov(other) + lhs_std, rhs_std = self.std(), other.std() + + if not cov or lhs_std == 0 or rhs_std == 0: + return cudf.utils.dtypes._get_nan_for_dtype(self.dtype) + return cov / lhs_std / rhs_std + + def round(self, decimals: int = 0) -> NumericalBaseColumn: + """Round the values in the Column to the given number of decimals. + """ + return libcudf.round.round(self, decimal_places=decimals) + + def _apply_scan_op(self, op: str) -> ColumnBase: + return self._copy_type_metadata(libcudf.reduce.scan(op, self, True)) diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index c2b820d0b43..3c47f30dd15 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -111,6 +111,19 @@ def __cuda_array_interface__(self): "Structs are not yet supported via `__cuda_array_interface__`" ) + def _copy_type_metadata(self: ColumnBase, other: ColumnBase) -> ColumnBase: + """Copies type metadata from self onto other, returning a new column. + + In addition to the default behavior, if `other` is a StructColumns we + rename the fields of `other` to the field names of `self`. + """ + if isinstance(other, cudf.core.column.StructColumn): + other = other._rename_fields( + self.dtype.fields.keys() # type: ignore + ) + # Have to ignore typing here because it misdiagnoses super(). + return super()._copy_type_metadata(other) # type: ignore + class StructMethods(ColumnMethodsMixin): """ diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index a894baf8235..c5a7b07d778 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -46,13 +46,13 @@ from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( can_convert_to_column, + find_common_type, is_decimal_dtype, is_list_dtype, is_list_like, is_mixed_with_object_dtype, is_scalar, min_scalar_type, - find_common_type, ) from cudf.utils.utils import ( get_appropriate_dispatched_func, @@ -3967,11 +3967,9 @@ def applymap(self, udf, out_dtype=None): 4 105 dtype: int64 """ - if callable(udf): - res_col = self._unaryop(udf) - else: - res_col = self._column.applymap(udf, out_dtype=out_dtype) - return self._copy_construct(data=res_col) + if not callable(udf): + raise ValueError("Input UDF must be a callable object.") + return self._copy_construct(data=self._unaryop(udf)) # # Stats diff --git a/python/cudf/cudf/tests/test_transform.py b/python/cudf/cudf/tests/test_transform.py index 6ec5f88be48..ed409de196e 100644 --- a/python/cudf/cudf/tests/test_transform.py +++ b/python/cudf/cudf/tests/test_transform.py @@ -11,38 +11,25 @@ supported_types = NUMERIC_TYPES -@pytest.mark.parametrize("dtype", supported_types) -def test_applymap(dtype): - - size = 500 - - lhs_arr = np.random.random(size).astype(dtype) - lhs_col = Series(lhs_arr)._column - - def generic_function(a): - return a ** 3 - - out_col = lhs_col.applymap(generic_function) - - result = lhs_arr ** 3 - - np.testing.assert_almost_equal(result, out_col.to_array()) +def _generic_function(a): + return a ** 3 @pytest.mark.parametrize("dtype", supported_types) -def test_applymap_python_lambda(dtype): +@pytest.mark.parametrize( + "udf,testfunc", + [ + (_generic_function, lambda ser: ser ** 3), + (lambda x: x in [1, 2, 3, 4], lambda ser: np.isin(ser, [1, 2, 3, 4])), + ], +) +def test_applymap_python_lambda(dtype, udf, testfunc): size = 500 lhs_arr = np.random.random(size).astype(dtype) lhs_ser = Series(lhs_arr) - # Note that the lambda has to be written this way. - # In other words, the following code does NOT compile with numba: - # test_list = [1, 2, 3, 4] - # out_ser = lhs_ser.applymap(lambda x: x in test_list) - out_ser = lhs_ser.applymap(lambda x: x in [1, 2, 3, 4]) - - result = np.isin(lhs_arr, [1, 2, 3, 4]) - + out_ser = lhs_ser.applymap(udf) + result = testfunc(lhs_arr) np.testing.assert_almost_equal(result, out_ser.to_array())