Skip to content

Commit

Permalink
Remove unneeded methods in Column (#14730)
Browse files Browse the repository at this point in the history
* `valid_count` can be composed of `null_count` or where checked `has_nulls`
* `contains_na_entries` is redundant with `has_nulls`
* Better typing in `searchsorted`

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14730
  • Loading branch information
mroeschke authored Jan 10, 2024
1 parent fa37e13 commit 3f19d04
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 43 deletions.
16 changes: 11 additions & 5 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from __future__ import annotations

import builtins
import pickle
import warnings
from functools import cached_property
from typing import Any, Set, Tuple
from typing import Any, Literal, Set, Tuple

import pandas as pd
from typing_extensions import Self
Expand Down Expand Up @@ -1702,6 +1701,8 @@ def find_label_range(self, loc: slice) -> slice:
start = loc.start
stop = loc.stop
step = 1 if loc.step is None else loc.step
start_side: Literal["left", "right"]
stop_side: Literal["left", "right"]
if step < 0:
start_side, stop_side = "right", "left"
else:
Expand All @@ -1725,9 +1726,9 @@ def find_label_range(self, loc: slice) -> slice:
def searchsorted(
self,
value,
side: builtins.str = "left",
side: Literal["left", "right"] = "left",
ascending: bool = True,
na_position: builtins.str = "last",
na_position: Literal["first", "last"] = "last",
):
"""Find index where elements should be inserted to maintain order
Expand All @@ -1754,7 +1755,12 @@ def searchsorted(
"""
raise NotImplementedError

def get_slice_bound(self, label, side: builtins.str, kind=None) -> int:
def get_slice_bound(
self,
label,
side: Literal["left", "right"],
kind: Literal["ix", "loc", "getitem", None] = None,
) -> int:
"""
Calculate slice bound that corresponds to given label.
Returns leftmost (one-past-the-rightmost if ``side=='right'``) position
Expand Down
4 changes: 3 additions & 1 deletion python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,9 @@ def _concat(
# improved as the concatenation API is solidified.

# Find the first non-null column:
head = next((obj for obj in objs if obj.valid_count), objs[0])
head = next(
(obj for obj in objs if not obj.null_count != len(obj)), objs[0]
)

# Combine and de-dupe the categories
cats = column.concat_columns([o.categories for o in objs]).unique()
Expand Down
20 changes: 6 additions & 14 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Any,
Dict,
List,
Literal,
MutableSequence,
Optional,
Sequence,
Expand Down Expand Up @@ -428,11 +429,6 @@ def _fill(
def shift(self, offset: int, fill_value: ScalarLike) -> ColumnBase:
return libcudf.copying.shift(self, offset, fill_value)

@property
def valid_count(self) -> int:
"""Number of non-null values"""
return len(self) - self.null_count

@property
def nullmask(self) -> Buffer:
"""The gpu buffer for the null-mask"""
Expand Down Expand Up @@ -1159,9 +1155,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
def searchsorted(
self,
value,
side: str = "left",
side: Literal["left", "right"] = "left",
ascending: bool = True,
na_position: str = "last",
na_position: Literal["first", "last"] = "last",
) -> Self:
if not isinstance(value, ColumnBase) or value.dtype != self.dtype:
raise ValueError(
Expand Down Expand Up @@ -1304,10 +1300,6 @@ def _reduce(
return libcudf.reduce.reduce(op, preprocessed, **kwargs)
return preprocessed

@property
def contains_na_entries(self) -> bool:
return self.null_count != 0

def _process_for_reduction(
self, skipna: Optional[bool] = None, min_count: int = 0
) -> Union[ColumnBase, ScalarLike]:
Expand Down Expand Up @@ -2742,7 +2734,7 @@ def concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase:
# If all columns are `NumericalColumn` with different dtypes,
# we cast them to a common dtype.
# Notice, we can always cast pure null columns
not_null_col_dtypes = [o.dtype for o in objs if o.valid_count]
not_null_col_dtypes = [o.dtype for o in objs if o.null_count != len(o)]
if len(not_null_col_dtypes) and all(
_is_non_decimal_numeric_dtype(dtyp)
and np.issubdtype(dtyp, np.datetime64)
Expand All @@ -2754,13 +2746,13 @@ def concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase:
objs = [obj.astype(common_dtype) for obj in objs]

# Find the first non-null column:
head = next((obj for obj in objs if obj.valid_count), objs[0])
head = next((obj for obj in objs if obj.null_count != len(obj)), objs[0])

for i, obj in enumerate(objs):
# Check that all columns are the same type:
if not is_dtype_equal(obj.dtype, head.dtype):
# if all null, cast to appropriate dtype
if obj.valid_count == 0:
if obj.null_count == len(obj):
objs[i] = column_empty_like(
head, dtype=head.dtype, masked=True, newsize=len(obj)
)
Expand Down
6 changes: 1 addition & 5 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def indices_of(self, value: ScalarLike) -> NumericalColumn:
else:
return super().indices_of(value)

def has_nulls(self, include_nan=False):
def has_nulls(self, include_nan: bool = False) -> bool:
return bool(self.null_count != 0) or (
include_nan and bool(self.nan_count != 0)
)
Expand Down Expand Up @@ -425,10 +425,6 @@ def dropna(self, drop_nan: bool = False) -> NumericalColumn:
col = self.nans_to_nulls() if drop_nan else self
return drop_nulls([col])[0]

@property
def contains_na_entries(self) -> bool:
return (self.nan_count != 0) or (self.null_count != 0)

def _process_values_for_isin(
self, values: Sequence
) -> Tuple[ColumnBase, ColumnBase]:
Expand Down
11 changes: 8 additions & 3 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5997,9 +5997,14 @@ def count(self, axis=0, level=None, numeric_only=False, **kwargs):
axis = self._get_axis_from_axis_arg(axis)
if axis != 0:
raise NotImplementedError("Only axis=0 is currently supported.")

length = len(self)
return Series._from_data(
{None: [self._data[col].valid_count for col in self._data.names]},
{
None: [
length - self._data[col].null_count
for col in self._data.names
]
},
as_index(self._data.names),
)

Expand Down Expand Up @@ -8091,7 +8096,7 @@ def _get_non_null_cols_and_dtypes(col_idxs, list_of_columns):
# non-null Column with the same name is found.
if idx not in dtypes:
dtypes[idx] = cols[idx].dtype
if cols[idx].valid_count > 0:
if cols[idx].null_count != len(cols[idx]):
if idx not in non_null_columns:
non_null_columns[idx] = [cols[idx]]
else:
Expand Down
9 changes: 7 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Callable,
Dict,
List,
Literal,
MutableMapping,
Optional,
Tuple,
Expand Down Expand Up @@ -882,7 +883,7 @@ def fillna(
replace_val = None
should_fill = (
col_name in value
and col.contains_na_entries
and col.has_nulls(include_nan=True)
and not libcudf.scalar._is_null_host_scalar(replace_val)
) or method is not None
if should_fill:
Expand Down Expand Up @@ -1354,7 +1355,11 @@ def notna(self):

@_cudf_nvtx_annotate
def searchsorted(
self, values, side="left", ascending=True, na_position="last"
self,
values,
side: Literal["left", "right"] = "left",
ascending: bool = True,
na_position: Literal["first", "last"] = "last",
):
"""Find indices where elements should be inserted to maintain order
Expand Down
11 changes: 6 additions & 5 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

from __future__ import annotations

Expand All @@ -11,6 +11,7 @@
Any,
Dict,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Expand Down Expand Up @@ -233,9 +234,9 @@ def _copy_type_metadata(
def searchsorted(
self,
value: int,
side: str = "left",
side: Literal["left", "right"] = "left",
ascending: bool = True,
na_position: str = "last",
na_position: Literal["first", "last"] = "last",
):
assert (len(self) <= 1) or (
ascending == (self._step > 0)
Expand Down Expand Up @@ -2205,9 +2206,9 @@ def copy(self, name=None, deep=False, dtype=None, names=None):
def searchsorted(
self,
value,
side: str = "left",
side: Literal["left", "right"] = "left",
ascending: bool = True,
na_position: str = "last",
na_position: Literal["first", "last"] = "last",
):
value = self.dtype.type(value)
return super().searchsorted(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,7 +1710,7 @@ def _concat(cls, objs, axis=0, index=True):
@_cudf_nvtx_annotate
def valid_count(self):
"""Number of non-null values"""
return self._column.valid_count
return len(self) - self._column.null_count

@property # type: ignore
@_cudf_nvtx_annotate
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

import operator
import string
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_categorical_masking():
got_masked = sr[got_matches]

assert len(expect_masked) == len(got_masked)
assert len(expect_masked) == got_masked.valid_count
assert got_masked.null_count == 0
assert_eq(got_masked, expect_masked)


Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/tests/test_orc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

import datetime
import decimal
Expand Down Expand Up @@ -812,7 +812,7 @@ def test_orc_write_bool_statistics(tmpdir, datadir, nrows):

if "number_of_values" in file_stats[0][col]:
stats_valid_count = file_stats[0][col]["number_of_values"]
actual_valid_count = gdf[col].valid_count
actual_valid_count = len(gdf[col]) - gdf[col].null_count
assert normalized_equals(actual_valid_count, stats_valid_count)

# compare stripe statistics with actual min/max
Expand All @@ -827,7 +827,9 @@ def test_orc_write_bool_statistics(tmpdir, datadir, nrows):
assert normalized_equals(actual_true_count, stats_true_count)

if "number_of_values" in stripes_stats[stripe_idx][col]:
actual_valid_count = stripe_df[col].valid_count
actual_valid_count = (
len(stripe_df[col]) - stripe_df[col].null_count
)
stats_valid_count = stripes_stats[stripe_idx][col][
"number_of_values"
]
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import datetime
from collections import namedtuple
Expand Down Expand Up @@ -401,7 +401,7 @@ def min_column_type(x, expected_type):

if not isinstance(x, cudf.core.column.NumericalColumn):
raise TypeError("Argument x must be of type column.NumericalColumn")
if x.valid_count == 0:
if x.null_count == len(x):
return x.dtype

if np.issubdtype(x.dtype, np.floating):
Expand Down

0 comments on commit 3f19d04

Please sign in to comment.