Skip to content

Commit

Permalink
Refactoring column logic Part 1 (#8081)
Browse files Browse the repository at this point in the history
This PR is a first pass at refactoring `ColumnBase` and its subclasses to reduce redundancy and improve performance by avoiding runtime type checking. Many functions are implemented in the top-level class but dispatch on dtype, which can instead be accomplished via ducktyping. Additionally, other parts of `cudf` require various methods to be implemented by a column, but `ColumnBase` does not currently clearly delineate an interface, making it difficult to know what to rely on in classes like `Frame` and pushing dynamic type dispatch upstream in the call stack where it is even less efficient and causes substantial code duplication. This PR moves specialized implementations of certain methods into the appropriate subclasses of `ColumnBase` and establishes a base API in the parent class where appropriate. Since this change will be large, I plan to split it into a few different PRs. This PR primarily modifies `to_pandas`, `to_arrow`, and `__cuda_array_interface__`, along with a few other minor improvements.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Keith Kraus (https://github.com/kkraus14)

URL: #8081
  • Loading branch information
vyasr authored Apr 28, 2021
1 parent 663457b commit 0ca0e69
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 130 deletions.
24 changes: 20 additions & 4 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
from numba import cuda

import cudf
Expand Down Expand Up @@ -1075,10 +1076,7 @@ def __cuda_array_interface__(self) -> Mapping[str, Any]:
" if you need this functionality."
)

def to_pandas(
self, index: ColumnLike = None, nullable: bool = False, **kwargs
) -> pd.Series:

def to_pandas(self, index: pd.Index = None, **kwargs) -> pd.Series:
if self.categories.dtype.kind == "f":
new_mask = bools_to_mask(self.notnull())
col = column.build_categorical_column(
Expand All @@ -1099,6 +1097,24 @@ def to_pandas(
)
return pd.Series(data, index=index)

def to_arrow(self) -> pa.Array:
"""Convert to PyArrow Array."""
# arrow doesn't support unsigned codes
signed_type = (
min_signed_type(self.codes.max())
if self.codes.size > 0
else np.int8
)
codes = self.codes.astype(signed_type)
categories = self.categories

out_indices = codes.to_arrow()
out_dictionary = categories.to_arrow()

return pa.DictionaryArray.from_arrays(
out_indices, out_dictionary, ordered=self.ordered,
)

@property
def values_host(self) -> np.ndarray:
"""
Expand Down
123 changes: 40 additions & 83 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -44,9 +43,7 @@
from cudf.core.dtypes import CategoricalDtype, IntervalDtype
from cudf.utils import ioutils, utils
from cudf.utils.dtypes import (
NUMERIC_TYPES,
check_cast_unsupported_dtype,
cudf_dtypes_to_pandas_dtypes,
get_time_unit,
is_categorical_dtype,
is_decimal_dtype,
Expand All @@ -56,7 +53,6 @@
is_scalar,
is_string_dtype,
is_struct_dtype,
min_signed_type,
min_unsigned_type,
np_to_pa_dtype,
)
Expand Down Expand Up @@ -116,22 +112,16 @@ def __repr__(self):
f"dtype: {self.dtype}"
)

def to_pandas(
self, index: ColumnLike = None, nullable: bool = False, **kwargs
) -> "pd.Series":
if nullable and self.dtype in cudf_dtypes_to_pandas_dtypes:
pandas_nullable_dtype = cudf_dtypes_to_pandas_dtypes[self.dtype]
arrow_array = self.to_arrow()
pandas_array = pandas_nullable_dtype.__from_arrow__(arrow_array)
pd_series = pd.Series(pandas_array, copy=False)
elif str(self.dtype) in NUMERIC_TYPES and self.null_count == 0:
pd_series = pd.Series(cupy.asnumpy(self.values), copy=False)
elif is_interval_dtype(self.dtype):
pd_series = pd.Series(
pd.IntervalDtype().__from_arrow__(self.to_arrow())
)
else:
pd_series = self.to_arrow().to_pandas(**kwargs)
def to_pandas(self, index: pd.Index = None, **kwargs) -> "pd.Series":
"""Convert object to pandas type.
The default implementation falls back to PyArrow for the conversion.
"""
# This default implementation does not handle nulls in any meaningful
# way, but must consume the parameter to avoid passing it to PyArrow
# (which does not recognize it).
kwargs.pop("nullable", None)
pd_series = self.to_arrow().to_pandas(**kwargs)

if index is not None:
pd_series.index = index
Expand Down Expand Up @@ -333,46 +323,14 @@ def to_arrow(self) -> pa.Array:
4
]
"""
if isinstance(self, cudf.core.column.CategoricalColumn):
# arrow doesn't support unsigned codes
signed_type = (
min_signed_type(self.codes.max())
if self.codes.size > 0
else np.int8
)
codes = self.codes.astype(signed_type)
categories = self.categories

out_indices = codes.to_arrow()
out_dictionary = categories.to_arrow()

return pa.DictionaryArray.from_arrays(
out_indices, out_dictionary, ordered=self.ordered,
)

if isinstance(self, cudf.core.column.StringColumn) and (
self.null_count == len(self)
):
return pa.NullArray.from_buffers(
pa.null(), len(self), [pa.py_buffer((b""))]
)

result = libcudf.interop.to_arrow(
return libcudf.interop.to_arrow(
libcudf.table.Table(
cudf.core.column_accessor.ColumnAccessor({"None": self})
),
[["None"]],
keep_index=False,
)["None"].chunk(0)

if isinstance(self.dtype, cudf.Decimal64Dtype):
result = result.view(
pa.decimal128(
scale=result.type.scale, precision=self.dtype.precision
)
)
return result

@classmethod
def from_arrow(cls, array: pa.Array) -> ColumnBase:
"""
Expand Down Expand Up @@ -838,7 +796,7 @@ def find_last_value(self, value: ScalarLike, closest: bool = False) -> int:
return indices[-1]

def append(self, other: ColumnBase) -> ColumnBase:
return ColumnBase._concat([self, as_column(other)])
return self.__class__._concat([self, as_column(other)])

def quantile(
self,
Expand Down Expand Up @@ -890,9 +848,6 @@ def isin(self, values: Sequence) -> ColumnBase:
result: Column
Column of booleans indicating if each element is in values.
"""
lhs = self
rhs = None

try:
lhs, rhs = self._process_values_for_isin(values)
res = lhs._isin_earlystop(rhs)
Expand Down Expand Up @@ -1167,32 +1122,26 @@ def argsort(
)
return sorted_indices

@property
def __cuda_array_interface__(self) -> Mapping[builtins.str, Any]:
output = {
"shape": (len(self),),
"strides": (self.dtype.itemsize,),
"typestr": self.dtype.str,
"data": (self.data_ptr, False),
"version": 1,
}

if self.nullable and self.has_nulls:

# Create a simple Python object that exposes the
# `__cuda_array_interface__` attribute here since we need to modify
# some of the attributes from the numba device array
mask = SimpleNamespace(
__cuda_array_interface__={
"shape": (len(self),),
"typestr": "<t1",
"data": (self.mask_ptr, True),
"version": 1,
}
)
output["mask"] = mask
def __arrow_array__(self, type=None):
raise TypeError(
"Implicit conversion to a host PyArrow Array via __arrow_array__ "
"is not allowed, To explicitly construct a PyArrow Array, "
"consider using .to_arrow()"
)

return output
def __array__(self, dtype=None):
raise TypeError(
"Implicit conversion to a host NumPy array via __array__ is not "
"allowed. To explicitly construct a host array, consider using "
".to_array()"
)

@property
def __cuda_array_interface__(self):
raise NotImplementedError(
f"dtype {self.dtype} is not yet supported via "
"`__cuda_array_interface__`"
)

def __add__(self, other):
return self.binary_operator("add", other)
Expand Down Expand Up @@ -1291,10 +1240,18 @@ def deserialize(cls, header: dict, frames: list) -> ColumnBase:
data=data, dtype=dtype, mask=mask, size=header.get("size", None)
)

def unary_operator(self, unaryop: builtins.str):
raise TypeError(
f"Operation {unaryop} not supported for dtype {self.dtype}."
)

def binary_operator(
self, op: builtins.str, other: BinaryOperand, reflect: bool = False
) -> ColumnBase:
raise NotImplementedError
raise TypeError(
f"Operation {op} not supported between dtypes {self.dtype} and "
f"{other.dtype}."
)

def min(self, skipna: bool = None, dtype: Dtype = None):
result_col = self._process_for_reduction(skipna=skipna)
Expand Down
44 changes: 35 additions & 9 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import builtins
import datetime as dt
import re
from numbers import Number
from typing import Any, Sequence, Union, cast
from types import SimpleNamespace
from typing import Any, Mapping, Sequence, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -133,21 +135,18 @@ def weekday(self) -> ColumnBase:
return self.get_dt_field("weekday")

def to_pandas(
self, index: "cudf.Index" = None, nullable: bool = False, **kwargs
self, index: pd.Index = None, nullable: bool = False, **kwargs
) -> "cudf.Series":
# Workaround until following issue is fixed:
# https://issues.apache.org/jira/browse/ARROW-9772

# Pandas supports only `datetime64[ns]`, hence the cast.
pd_series = pd.Series(
self.astype("datetime64[ns]").to_array("NAT"), copy=False
return pd.Series(
self.astype("datetime64[ns]").to_array("NAT"),
copy=False,
index=index,
)

if index is not None:
pd_series.index = index

return pd_series

def get_dt_field(self, field: str) -> ColumnBase:
return libcudf.datetime.extract_datetime_component(self, field)

Expand Down Expand Up @@ -202,6 +201,33 @@ def as_numerical(self) -> "cudf.core.column.NumericalColumn":
),
)

@property
def __cuda_array_interface__(self) -> Mapping[builtins.str, Any]:
output = {
"shape": (len(self),),
"strides": (self.dtype.itemsize,),
"typestr": self.dtype.str,
"data": (self.data_ptr, False),
"version": 1,
}

if self.nullable and self.has_nulls:

# Create a simple Python object that exposes the
# `__cuda_array_interface__` attribute here since we need to modify
# some of the attributes from the numba device array
mask = SimpleNamespace(
__cuda_array_interface__={
"shape": (len(self),),
"typestr": "<t1",
"data": (self.mask_ptr, True),
"version": 1,
}
)
output["mask"] = mask

return output

def as_datetime_column(self, dtype: Dtype, **kwargs) -> DatetimeColumn:
dtype = np.dtype(dtype)
if dtype == self.dtype:
Expand Down
12 changes: 12 additions & 0 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
import pandas as pd
import pyarrow as pa

import cudf
from cudf.core.column import StructColumn
from cudf.core.dtypes import IntervalDtype
Expand Down Expand Up @@ -110,3 +112,13 @@ def as_interval_column(self, dtype, **kwargs):
)
else:
raise ValueError("dtype must be IntervalDtype")

def to_pandas(self, index: pd.Index = None, **kwargs) -> "pd.Series":
# Note: This does not handle null values in the interval column.
# However, this exact sequence (calling __from_arrow__ on the output of
# self.to_arrow) is currently the best known way to convert interval
# types into pandas (trying to convert the underlying numerical columns
# directly is problematic), so we're stuck with this for now.
return pd.Series(
pd.IntervalDtype().__from_arrow__(self.to_arrow()), index=index
)
Loading

0 comments on commit 0ca0e69

Please sign in to comment.