Use as_column instead of arange for range-like inputs (#14689)
1. Allows range-like inputs to `as_column` to short-circuit and avoid materializing their values when creating columns (see the sketch below)
2. Avoids diverging column-construction logic between `column.arange` and `column.as_column`
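
For illustration, a minimal before/after sketch of the call-site pattern this commit standardizes on (assumes the internal `cudf.core.column` import path; not part of the diff below):

```python
import cudf
from cudf.core.column import as_column

# Range-like inputs (range, pd.RangeIndex, cudf.RangeIndex) now short-circuit
# inside as_column: the column is produced by a single libcudf sequence call
# rather than being materialized first.
col = as_column(range(2, 12, 3))
print(cudf.Series(col))

# Pre-change spelling, removed by this commit:
#     col = cudf.core.column.arange(2, 12, 3)
```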

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14689
mroeschke authored Jan 12, 2024
1 parent c0a3cd1 commit 7a42b8b
Showing 12 changed files with 79 additions and 126 deletions.
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/column/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 
 """
 isort: skip_file
@@ -8,7 +8,6 @@
 from cudf.core.column.categorical import CategoricalColumn
 from cudf.core.column.column import (
     ColumnBase,
-    arange,
     as_column,
     build_categorical_column,
     build_column,
12 changes: 8 additions & 4 deletions python/cudf/cudf/core/column/categorical.py
@@ -1159,7 +1159,7 @@ def find_and_replace(
             new_cats_col = new_cats_col.apply_boolean_mask(bmask)
             new_cats = cudf.DataFrame._from_data(
                 {
-                    "index": cudf.core.column.arange(len(new_cats_col)),
+                    "index": column.as_column(range(len(new_cats_col))),
                     "cats": new_cats_col,
                 }
             )
@@ -1531,9 +1531,13 @@ def _set_categories(
         )
         out_code_dtype = min_unsigned_type(max_cat_size)
 
-        cur_order = column.arange(len(cur_codes))
-        old_codes = column.arange(len(cur_cats), dtype=out_code_dtype)
-        new_codes = column.arange(len(new_cats), dtype=out_code_dtype)
+        cur_order = column.as_column(range(len(cur_codes)))
+        old_codes = column.as_column(
+            range(len(cur_cats)), dtype=out_code_dtype
+        )
+        new_codes = column.as_column(
+            range(len(new_cats)), dtype=out_code_dtype
+        )
 
         new_df = cudf.DataFrame._from_data(
             data={"new_codes": new_codes, "cats": new_cats}
99 changes: 23 additions & 76 deletions python/cudf/cudf/core/column/column.py
@@ -554,10 +554,8 @@ def slice(
             ]._with_type_metadata(self.dtype)
         else:
             # Need to create a gather map for given slice with stride
-            gather_map = arange(
-                start=start,
-                stop=stop,
-                step=stride,
+            gather_map = as_column(
+                range(start, stop, stride),
                 dtype=cudf.dtype(np.int32),
             )
             return self.take(gather_map)
@@ -626,10 +624,8 @@ def _scatter_by_slice(
                 )
 
         # step != 1, create a scatter map with arange
-        scatter_map = arange(
-            start=start,
-            stop=stop,
-            step=step,
+        scatter_map = as_column(
+            range(start, stop, step),
             dtype=cudf.dtype(np.int32),
         )
 
@@ -745,7 +741,7 @@ def indices_of(
             assert len(value) == 1
         mask = libcudf.search.contains(value, self)
         return apply_boolean_mask(
-            [arange(0, len(self), dtype=size_type_dtype)], mask
+            [as_column(range(0, len(self)), dtype=size_type_dtype)], mask
         )[0]
 
     def _find_first_and_last(self, value: ScalarLike) -> Tuple[int, int]:
@@ -1379,7 +1375,9 @@ def _return_sentinel_column():
             [self], [cats], how="left"
         )
         codes = libcudf.copying.gather(
-            [arange(len(cats), dtype=dtype)], right_gather_map, nullify=True
+            [as_column(range(len(cats)), dtype=dtype)],
+            right_gather_map,
+            nullify=True,
         )
         del right_gather_map
         # reorder `codes` so that its values correspond to the
@@ -1905,13 +1903,26 @@ def as_column(
     * Objects exposing ``__array_interface__``(e.g., numpy arrays)
     * pyarrow array
     * pandas.Categorical objects
+    * range objects
     """
-    if isinstance(arbitrary, ColumnBase):
+    if isinstance(arbitrary, (range, pd.RangeIndex, cudf.RangeIndex)):
+        column = libcudf.filling.sequence(
+            len(arbitrary),
+            as_device_scalar(arbitrary.start, dtype=cudf.dtype("int64")),
+            as_device_scalar(arbitrary.step, dtype=cudf.dtype("int64")),
+        )
+        if cudf.get_option("default_integer_bitwidth") and dtype is None:
+            dtype = cudf.dtype(
+                f'i{cudf.get_option("default_integer_bitwidth")//8}'
+            )
+        if dtype is not None:
+            column = column.astype(dtype)
+        return column
+    elif isinstance(arbitrary, ColumnBase):
         if dtype is not None:
             return arbitrary.astype(dtype)
         else:
             return arbitrary
-
     elif isinstance(arbitrary, cudf.Series):
         data = arbitrary._column
         if dtype is not None:
@@ -2614,70 +2625,6 @@ def deserialize_columns(headers: List[dict], frames: List) -> List[ColumnBase]:
     return columns
 
 
-def arange(
-    start: Union[int, float],
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
-    dtype=None,
-) -> cudf.core.column.NumericalColumn:
-    """
-    Returns a column with evenly spaced values within a given interval.
-    Values are generated within the half-open interval [start, stop).
-    The first three arguments are mapped like the range built-in function,
-    i.e. start and step are optional.
-    Parameters
-    ----------
-    start : int/float
-        Start of the interval.
-    stop : int/float, default is None
-        Stop of the interval.
-    step : int/float, default 1
-        Step width between each pair of consecutive values.
-    dtype : default None
-        Data type specifier. It is inferred from other arguments by default.
-    Returns
-    -------
-    cudf.core.column.NumericalColumn
-    Examples
-    --------
-    >>> import cudf
-    >>> col = cudf.core.column.arange(2, 7, 1, dtype='int16')
-    >>> col
-    <cudf.core.column.numerical.NumericalColumn object at 0x7ff7998f8b90>
-    >>> cudf.Series(col)
-    0    2
-    1    3
-    2    4
-    3    5
-    4    6
-    dtype: int16
-    """
-    if stop is None:
-        stop = start
-        start = 0
-
-    if step is None:
-        step = 1
-
-    size = len(range(int(start), int(stop), int(step)))
-    if size == 0:
-        if dtype is None:
-            dtype = cudf.dtype("int64")
-        return cast(
-            cudf.core.column.NumericalColumn, column_empty(0, dtype=dtype)
-        )
-
-    return libcudf.filling.sequence(
-        size,
-        as_device_scalar(start, dtype=dtype),
-        as_device_scalar(step, dtype=dtype),
-    )
-
-
 def full(
     size: int, fill_value: ScalarLike, dtype: Optional[Dtype] = None
 ) -> ColumnBase:
10 changes: 8 additions & 2 deletions python/cudf/cudf/core/dataframe.py
@@ -342,10 +342,16 @@ def _getitem_tuple_arg(self, arg):
                     tmp_col_name = (tmp_col_name, *extra)
                     cantor_name = (cantor_name, *extra)
                 other_df = DataFrame(
-                    {tmp_col_name: column.arange(len(tmp_arg[0]))},
+                    {
+                        tmp_col_name: column.as_column(
+                            range(len(tmp_arg[0]))
+                        )
+                    },
                     index=as_index(tmp_arg[0]),
                 )
-                columns_df[cantor_name] = column.arange(len(columns_df))
+                columns_df[cantor_name] = column.as_column(
+                    range(len(columns_df))
+                )
                 df = other_df.join(columns_df, how="inner")
                 # as join is not assigning any names to index,
                 # update it over here
10 changes: 6 additions & 4 deletions python/cudf/cudf/core/groupby/groupby.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 
 import copy
 import itertools
@@ -23,7 +23,7 @@
 from cudf._typing import AggType, DataFrameOrSeries, MultiColumnAggType
 from cudf.api.types import is_bool_dtype, is_float_dtype, is_list_like
 from cudf.core.abc import Serializable
-from cudf.core.column.column import ColumnBase, arange, as_column
+from cudf.core.column.column import ColumnBase, as_column
 from cudf.core.column_accessor import ColumnAccessor
 from cudf.core.join._join_helpers import _match_join_keys
 from cudf.core.mixins import Reducible, Scannable
@@ -761,7 +761,7 @@ def _head_tail(self, n, *, take_head: bool, preserve_order: bool):
             # subsample the gather map from the full input ordering,
             # rather than permuting the gather map of the output.
             _, (ordering,), _ = self._groupby.groups(
-                [arange(0, len(self.obj))]
+                [as_column(range(0, len(self.obj)))]
             )
             # Invert permutation from original order to groups on the
             # subset of entries we want.
@@ -2543,7 +2543,9 @@ def _mimic_pandas_order(
         # result coming back from libcudf has null_count few rows than
         # the input, so we must produce an ordering from the full
         # input range.
-        _, (ordering,), _ = self._groupby.groups([arange(0, len(self.obj))])
+        _, (ordering,), _ = self._groupby.groups(
+            [as_column(range(0, len(self.obj)))]
+        )
         if self._dropna and any(
             c.has_nulls(include_nan=True) > 0
             for c in self.grouping._key_columns
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/index.py
@@ -286,9 +286,7 @@ def _num_rows(self):
     @_cudf_nvtx_annotate
     def _values(self):
         if len(self) > 0:
-            return column.arange(
-                self._start, self._stop, self._step, dtype=self.dtype
-            )
+            return column.as_column(self._range, dtype=self.dtype)
         else:
             return column.column_empty(0, masked=False, dtype=self.dtype)
 
18 changes: 6 additions & 12 deletions python/cudf/cudf/core/indexed_frame.py
@@ -182,12 +182,8 @@ def _indices_from_labels(obj, labels):
     # join is not guaranteed to maintain the index ordering
     # so we will sort it with its initial ordering which is stored
     # in column "__"
-    lhs = cudf.DataFrame(
-        {"__": cudf.core.column.arange(len(labels))}, index=labels
-    )
-    rhs = cudf.DataFrame(
-        {"_": cudf.core.column.arange(len(obj))}, index=obj.index
-    )
+    lhs = cudf.DataFrame({"__": as_column(range(len(labels)))}, index=labels)
+    rhs = cudf.DataFrame({"_": as_column(range(len(obj)))}, index=obj.index)
     return lhs.join(rhs).sort_values(by=["__", "_"])["_"]
 
 
@@ -1897,10 +1893,8 @@ def _slice(self, arg: slice, keep_index: bool = True) -> Self:
         if stride != 1:
             return self._gather(
                 GatherMap.from_column_unchecked(
-                    cudf.core.column.arange(
-                        start,
-                        stop=stop,
-                        step=stride,
+                    as_column(
+                        range(start, stop, stride),
                         dtype=libcudf.types.size_type_dtype,
                     ),
                     len(self),
@@ -2541,9 +2535,9 @@ def _align_to_index(
         # to recover ordering after index alignment.
         sort_col_id = str(uuid4())
         if how == "left":
-            lhs[sort_col_id] = cudf.core.column.arange(len(lhs))
+            lhs[sort_col_id] = as_column(range(len(lhs)))
         elif how == "right":
-            rhs[sort_col_id] = cudf.core.column.arange(len(rhs))
+            rhs[sort_col_id] = as_column(range(len(rhs)))
 
         result = lhs.join(rhs, how=how, sort=sort)
         if how in ("left", "right"):
8 changes: 6 additions & 2 deletions python/cudf/cudf/core/join/join.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 from __future__ import annotations
 
 import itertools
@@ -232,7 +232,11 @@ def _gather_maps(self, left_cols, right_cols):
         key_order = list(
             itertools.chain.from_iterable(
                 libcudf.copying.gather(
-                    [cudf.core.column.arange(n, dtype=size_type_dtype)],
+                    [
+                        cudf.core.column.as_column(
+                            range(n), dtype=size_type_dtype
+                        )
+                    ],
                     map_,
                     nullify=null,
                 )
16 changes: 9 additions & 7 deletions python/cudf/cudf/core/multiindex.py
@@ -501,9 +501,9 @@ def __repr__(self):
             # TODO: Update the following two arange calls to
             # a single arange call once arange has support for
             # a vector start/end points.
-            indices = column.arange(start=0, stop=n, step=1)
+            indices = column.as_column(range(n))
             indices = indices.append(
-                column.arange(start=len(self) - n, stop=len(self), step=1)
+                column.as_column(range(len(self) - n, len(self), 1))
             )
             preprocess = self.take(indices)
         else:
@@ -795,7 +795,7 @@ def _compute_validity_mask(self, index, row_tuple, max_length):
             [
                 frame,
                 cudf.DataFrame(
-                    {"idx": cudf.Series(column.arange(len(frame)))}
+                    {"idx": cudf.Series(column.as_column(range(len(frame))))}
                 ),
             ],
             axis=1,
@@ -807,7 +807,7 @@ def _compute_validity_mask(self, index, row_tuple, max_length):
         # obtain deterministic ordering.
         if cudf.get_option("mode.pandas_compatible"):
             lookup_order = "_" + "_".join(map(str, lookup._data.names))
-            lookup[lookup_order] = column.arange(len(lookup))
+            lookup[lookup_order] = column.as_column(range(len(lookup)))
             postprocess = operator.methodcaller(
                 "sort_values", by=[lookup_order, "idx"]
             )
@@ -840,14 +840,16 @@ def _get_valid_indices_by_tuple(self, index, row_tuple, max_length):
             ):
                 stop = row_tuple.stop or max_length
                 start, stop, step = row_tuple.indices(stop)
-                return column.arange(start, stop, step)
+                return column.as_column(range(start, stop, step))
             start_values = self._compute_validity_mask(
                 index, row_tuple.start, max_length
             )
             stop_values = self._compute_validity_mask(
                 index, row_tuple.stop, max_length
             )
-            return column.arange(start_values.min(), stop_values.max() + 1)
+            return column.as_column(
+                range(start_values.min(), stop_values.max() + 1)
+            )
         elif isinstance(row_tuple, numbers.Number):
             return row_tuple
         return self._compute_validity_mask(index, row_tuple, max_length)
@@ -1024,7 +1026,7 @@ def __getitem__(self, index):
             index = np.array(index)
         elif isinstance(index, slice):
             start, stop, step = index.indices(len(self))
-            index = column.arange(start, stop, step)
+            index = column.as_column(range(start, stop, step))
         result = MultiIndex.from_frame(
             self.to_frame(index=False, name=range(0, self.nlevels)).take(
                 index
9 changes: 6 additions & 3 deletions python/cudf/cudf/core/series.py
@@ -55,7 +55,6 @@
     DatetimeColumn,
     IntervalColumn,
     TimeDeltaColumn,
-    arange,
     as_column,
     full,
 )
@@ -1366,7 +1365,9 @@ def map(self, arg, na_action=None) -> "Series":
                 raise NotImplementedError(
                     "default values in dicts are currently not supported."
                 )
-            lhs = cudf.DataFrame({"x": self, "orig_order": arange(len(self))})
+            lhs = cudf.DataFrame(
+                {"x": self, "orig_order": as_column(range(len(self)))}
+            )
             rhs = cudf.DataFrame(
                 {
                     "x": arg.keys(),
@@ -1386,7 +1387,9 @@ def map(self, arg, na_action=None) -> "Series":
                     "Reindexing only valid with"
                     " uniquely valued Index objects"
                 )
-            lhs = cudf.DataFrame({"x": self, "orig_order": arange(len(self))})
+            lhs = cudf.DataFrame(
+                {"x": self, "orig_order": as_column(range(len(self)))}
+            )
             rhs = cudf.DataFrame(
                 {
                     "x": arg.keys(),