Skip to content

Commit

Permalink
Use explicit construction of column subclass instead of `build_column…
Browse files Browse the repository at this point in the history
…` when type is known (#16470)

When we need to construct a column with a specific type, we do not need to go through the indirection of `build_column`, which matches a column subclass to a passed type, and instead construct directly from the class instead

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Thomas Li (https://github.com/lithomas1)

URL: #16470
  • Loading branch information
mroeschke authored Aug 2, 2024
1 parent a8a3670 commit cc19d8a
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 65 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/_internals/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _make_categorical_like(result, column):
if isinstance(column, cudf.core.column.CategoricalColumn):
result = cudf.core.column.build_categorical_column(
categories=column.categories,
codes=cudf.core.column.build_column(
codes=cudf.core.column.NumericalColumn(
result.base_data, dtype=result.dtype
),
mask=result.base_mask,
Expand Down
46 changes: 28 additions & 18 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,10 @@ def children(self) -> tuple[NumericalColumn]:
codes_column = self.base_children[0]
start = self.offset * codes_column.dtype.itemsize
end = start + self.size * codes_column.dtype.itemsize
codes_column = cast(
cudf.core.column.NumericalColumn,
column.build_column(
data=codes_column.base_data[start:end],
dtype=codes_column.dtype,
size=self.size,
),
codes_column = cudf.core.column.NumericalColumn(
data=codes_column.base_data[start:end],
dtype=codes_column.dtype,
size=self.size,
)
self._children = (codes_column,)
return self._children
Expand Down Expand Up @@ -660,8 +657,9 @@ def slice(self, start: int, stop: int, stride: int | None = None) -> Self:
Self,
cudf.core.column.build_categorical_column(
categories=self.categories,
codes=cudf.core.column.build_column(
codes.base_data, dtype=codes.dtype
codes=cudf.core.column.NumericalColumn(
codes.base_data, # type: ignore[arg-type]
dtype=codes.dtype,
),
mask=codes.base_mask,
ordered=self.ordered,
Expand Down Expand Up @@ -734,7 +732,10 @@ def sort_values(
codes = self.codes.sort_values(ascending, na_position)
col = column.build_categorical_column(
categories=self.dtype.categories._values,
codes=column.build_column(codes.base_data, dtype=codes.dtype),
codes=cudf.core.column.NumericalColumn(
codes.base_data, # type: ignore[arg-type]
dtype=codes.dtype,
),
mask=codes.base_mask,
size=codes.size,
ordered=self.dtype.ordered,
Expand Down Expand Up @@ -842,7 +843,10 @@ def unique(self) -> CategoricalColumn:
codes = self.codes.unique()
return column.build_categorical_column(
categories=self.categories,
codes=column.build_column(codes.base_data, dtype=codes.dtype),
codes=cudf.core.column.NumericalColumn(
codes.base_data, # type: ignore[arg-type]
dtype=codes.dtype,
),
mask=codes.base_mask,
offset=codes.offset,
size=codes.size,
Expand Down Expand Up @@ -980,7 +984,9 @@ def find_and_replace(

result = column.build_categorical_column(
categories=new_cats["cats"],
codes=column.build_column(output.base_data, dtype=output.dtype),
codes=cudf.core.column.NumericalColumn(
output.base_data, dtype=output.dtype
),
mask=output.base_mask,
offset=output.offset,
size=output.size,
Expand Down Expand Up @@ -1176,8 +1182,9 @@ def _concat(

return column.build_categorical_column(
categories=column.as_column(cats),
codes=column.build_column(
codes_col.base_data, dtype=codes_col.dtype
codes=cudf.core.column.NumericalColumn(
codes_col.base_data, # type: ignore[arg-type]
dtype=codes_col.dtype,
),
mask=codes_col.base_mask,
size=codes_col.size,
Expand All @@ -1190,8 +1197,9 @@ def _with_type_metadata(
if isinstance(dtype, CategoricalDtype):
return column.build_categorical_column(
categories=dtype.categories._values,
codes=column.build_column(
self.codes.base_data, dtype=self.codes.dtype
codes=cudf.core.column.NumericalColumn(
self.codes.base_data, # type: ignore[arg-type]
dtype=self.codes.dtype,
),
mask=self.codes.base_mask,
ordered=dtype.ordered,
Expand Down Expand Up @@ -1339,7 +1347,7 @@ def _set_categories(
Self,
column.build_categorical_column(
categories=new_cats,
codes=column.build_column(
codes=cudf.core.column.NumericalColumn(
new_codes.base_data, dtype=new_codes.dtype
),
mask=new_codes.base_mask,
Expand Down Expand Up @@ -1472,7 +1480,9 @@ def pandas_categorical_as_column(

return column.build_categorical_column(
categories=categorical.categories,
codes=column.build_column(codes.base_data, codes.dtype),
codes=cudf.core.column.NumericalColumn(
codes.base_data, dtype=codes.dtype
),
size=codes.size,
mask=mask,
ordered=categorical.ordered,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,7 +1506,7 @@ def column_empty(
elif isinstance(dtype, CategoricalDtype):
data = None
children = (
build_column(
cudf.core.column.NumericalColumn(
data=as_buffer(
rmm.DeviceBuffer(
size=row_count
Expand Down
10 changes: 5 additions & 5 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ def as_timedelta_column(self, dtype: Dtype) -> None: # type: ignore[override]

def as_numerical_column(
self, dtype: Dtype
) -> "cudf.core.column.NumericalColumn":
col = column.build_column(
data=self.base_data,
dtype=np.int64,
) -> cudf.core.column.NumericalColumn:
col = cudf.core.column.NumericalColumn(
data=self.base_data, # type: ignore[arg-type]
dtype=np.dtype(np.int64),
mask=self.base_mask,
offset=self.offset,
size=self.size,
)
return cast("cudf.core.column.NumericalColumn", col.astype(dtype))
return cast(cudf.core.column.NumericalColumn, col.astype(dtype))

def strftime(self, format: str) -> cudf.core.column.StringColumn:
if len(self) == 0:
Expand Down
43 changes: 17 additions & 26 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
from cudf import _lib as libcudf
from cudf._lib import pylibcudf
from cudf.api.types import is_integer, is_scalar
from cudf.core.column import (
ColumnBase,
as_column,
build_column,
column,
string,
)
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.dtypes import CategoricalDtype
from cudf.core.mixins import BinaryOperand
from cudf.errors import MixedTypeError
Expand Down Expand Up @@ -338,29 +332,23 @@ def as_string_column(self) -> cudf.core.column.StringColumn:
def as_datetime_column(
self, dtype: Dtype
) -> cudf.core.column.DatetimeColumn:
return cast(
"cudf.core.column.DatetimeColumn",
build_column(
data=self.astype("int64").base_data,
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
size=self.size,
),
return cudf.core.column.DatetimeColumn(
data=self.astype("int64").base_data, # type: ignore[arg-type]
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
size=self.size,
)

def as_timedelta_column(
self, dtype: Dtype
) -> cudf.core.column.TimeDeltaColumn:
return cast(
"cudf.core.column.TimeDeltaColumn",
build_column(
data=self.astype("int64").base_data,
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
size=self.size,
),
return cudf.core.column.TimeDeltaColumn(
data=self.astype("int64").base_data, # type: ignore[arg-type]
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
size=self.size,
)

def as_decimal_column(
Expand Down Expand Up @@ -637,7 +625,10 @@ def _with_type_metadata(self: ColumnBase, dtype: Dtype) -> ColumnBase:
if isinstance(dtype, CategoricalDtype):
return column.build_categorical_column(
categories=dtype.categories._values,
codes=build_column(self.base_data, dtype=self.dtype),
codes=cudf.core.column.NumericalColumn(
self.base_data, # type: ignore[arg-type]
dtype=self.dtype,
),
mask=self.base_mask,
ordered=dtype.ordered,
size=self.size,
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5934,9 +5934,9 @@ def view(self, dtype) -> "cudf.core.column.ColumnBase":

n_bytes_to_view = str_end_byte_offset - str_byte_offset

to_view = column.build_column(
self.base_data,
dtype=cudf.api.types.dtype("int8"),
to_view = cudf.core.column.NumericalColumn(
self.base_data, # type: ignore[arg-type]
dtype=np.dtype(np.int8),
offset=str_byte_offset,
size=n_bytes_to_view,
)
Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def round(self, freq: str) -> ColumnBase:

def as_numerical_column(
self, dtype: Dtype
) -> "cudf.core.column.NumericalColumn":
col = column.build_column(
data=self.base_data,
dtype=np.int64,
) -> cudf.core.column.NumericalColumn:
col = cudf.core.column.NumericalColumn(
data=self.base_data, # type: ignore[arg-type]
dtype=np.dtype(np.int64),
mask=self.base_mask,
offset=self.offset,
size=self.size,
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@
from cudf.core.column import (
CategoricalColumn,
ColumnBase,
NumericalColumn,
StructColumn,
as_column,
build_categorical_column,
build_column,
column_empty,
concat_columns,
)
Expand Down Expand Up @@ -8543,7 +8543,7 @@ def _reassign_categories(categories, cols, col_idxs):
if idx in categories:
cols[name] = build_categorical_column(
categories=categories[idx],
codes=build_column(
codes=NumericalColumn(
cols[name].base_data, dtype=cols[name].dtype
),
mask=cols[name].base_mask,
Expand Down
8 changes: 3 additions & 5 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2434,12 +2434,10 @@ def to_pandas(
return result

@_performance_tracking
def _get_dt_field(self, field):
def _get_dt_field(self, field: str) -> Index:
"""Return an Index of a numerical component of the DatetimeIndex."""
out_column = self._values.get_dt_field(field)
# column.column_empty_like always returns a Column object
# but we need a NumericalColumn for Index..
# how should this be handled?
out_column = column.build_column(
out_column = NumericalColumn(
data=out_column.base_data,
dtype=out_column.dtype,
mask=out_column.base_mask,
Expand Down

0 comments on commit cc19d8a

Please sign in to comment.