Skip to content

Commit

Permalink
Misc Column cleanups (#15682)
Browse files Browse the repository at this point in the history
* Added some type annotations
* Moved a single-use helper function inline
* Simplified some dtype checks

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #15682
  • Loading branch information
mroeschke authored May 8, 2024
1 parent f965f3c commit 57e534a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
24 changes: 11 additions & 13 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
infer_dtype,
is_bool_dtype,
is_dtype_equal,
is_integer_dtype,
is_scalar,
is_string_dtype,
)
Expand Down Expand Up @@ -606,7 +605,8 @@ def _scatter_by_slice(
start, stop, step = key.indices(len(self))
if start >= stop:
return None
num_keys = len(range(start, stop, step))
rng = range(start, stop, step)
num_keys = len(rng)

self._check_scatter_key_length(num_keys, value)

Expand All @@ -625,7 +625,7 @@ def _scatter_by_slice(

# step != 1, create a scatter map with arange
scatter_map = as_column(
range(start, stop, step),
rng,
dtype=cudf.dtype(np.int32),
)

Expand Down Expand Up @@ -672,18 +672,16 @@ def _scatter_by_column(

def _check_scatter_key_length(
    self, num_keys: int, value: Union[cudf.core.scalar.Scalar, ColumnBase]
) -> None:
    """Validate that ``value`` can be scattered into ``num_keys`` positions.

    ``num_keys`` is the number of keys to scatter. When ``value`` is a
    column, its number of rows must equal ``num_keys``; any other kind of
    ``value`` is not length-checked here.

    Raises
    ------
    ValueError
        If ``value`` is a column whose length differs from ``num_keys``.
    """
    # Only column values carry a length that has to line up with the keys.
    if isinstance(value, ColumnBase) and len(value) != num_keys:
        raise ValueError(
            f"Size mismatch: cannot set value "
            f"of size {len(value)} to indexing result of size "
            f"{num_keys}"
        )

def fillna(
self,
Expand Down Expand Up @@ -820,7 +818,7 @@ def take(

# TODO: For performance, the check and conversion of the gather map should
# be done by the caller. This check will be removed in a future release.
if not is_integer_dtype(indices.dtype):
if indices.dtype.kind not in {"u", "i"}:
indices = indices.astype(libcudf.types.size_type_dtype)
if not libcudf.copying._gather_map_is_valid(
indices, len(self), check_bounds, nullify
Expand Down
22 changes: 8 additions & 14 deletions python/cudf/cudf/core/column/numerical_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.
"""Define an interface for columns that can perform numerical operations."""

from __future__ import annotations
Expand Down Expand Up @@ -112,7 +112,13 @@ def quantile(
),
)
else:
result = self._numeric_quantile(q, interpolation, exact)
# get sorted indices and exclude nulls
indices = libcudf.sort.order_by(
[self], [True], "first", stable=True
).slice(self.null_count, len(self))
result = libcudf.quantiles.quantile(
self, q, interpolation, indices, exact
)
if return_scalar:
scalar_result = result.element_indexing(0)
if interpolation in {"lower", "higher", "nearest"}:
Expand Down Expand Up @@ -178,18 +184,6 @@ def median(self, skipna: Optional[bool] = None) -> NumericalBaseColumn:
return_scalar=True,
)

def _numeric_quantile(
    self, q: np.ndarray, interpolation: str, exact: bool
) -> NumericalBaseColumn:
    """Compute the quantiles ``q`` of this column, ignoring null rows.

    A stable sort order for the column is requested from
    ``libcudf.sort.order_by``, the leading ``null_count`` entries of that
    order are sliced off, and the remaining order is passed to
    ``libcudf.quantiles.quantile`` together with ``q``, ``interpolation``
    and ``exact``.
    """
    sort_order = libcudf.sort.order_by([self], [True], "first", stable=True)
    # Drop the first ``null_count`` entries of the order so that nulls are
    # excluded (presumably "first" places nulls at the front -- confirm
    # against libcudf.sort.order_by).
    non_null_order = sort_order.slice(self.null_count, len(self))
    return libcudf.quantiles.quantile(
        self, q, interpolation, non_null_order, exact
    )

def cov(self, other: NumericalBaseColumn) -> float:
if (
len(self) == 0
Expand Down

0 comments on commit 57e534a

Please sign in to comment.