Skip to content

Commit

Permalink
Simplify ColumnAccessor methods; avoid unnecessary validations (#14758)
Browse files Browse the repository at this point in the history
For methods that essentially do:

```python
def select_by_foo(self, ...):
    ...
    return self.__class__(data={subset of self._data})
```

The `return` would perform validation on the returned subset of columns, but I think that's unnecessary since that validation was already done during initialization.

Additionally:
* Removed `_create_unsafe` in favor of a `verify=True|False` keyword in the constructor
* `_column_length` == `nrows`, so removed `_column_length`
* Renamed `_compare_keys` to `_keys_equal`
* Removed seldom-used/unnecessary methods

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14758
  • Loading branch information
mroeschke authored Jan 23, 2024
1 parent c949abe commit 67a36a9
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 94 deletions.
141 changes: 52 additions & 89 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from __future__ import annotations

Expand All @@ -21,7 +21,6 @@
import pandas as pd
from packaging.version import Version
from pandas.api.types import is_bool
from typing_extensions import Self

import cudf
from cudf.core import column
Expand Down Expand Up @@ -66,7 +65,7 @@ def __getitem__(self, key):
return super().__getitem__(key)


def _to_flat_dict_inner(d, parents=()):
def _to_flat_dict_inner(d: dict, parents: tuple = ()):
for k, v in d.items():
if not isinstance(v, d.__class__):
if parents:
Expand All @@ -76,14 +75,6 @@ def _to_flat_dict_inner(d, parents=()):
yield from _to_flat_dict_inner(d=v, parents=parents + (k,))


def _to_flat_dict(d):
"""
Convert the given nested dictionary to a flat dictionary
with tuple keys.
"""
return {k: v for k, v in _to_flat_dict_inner(d)}


class ColumnAccessor(abc.MutableMapping):
"""
Parameters
Expand All @@ -103,6 +94,9 @@ class ColumnAccessor(abc.MutableMapping):
label_dtype : Dtype, optional
What dtype should be returned in `to_pandas_index`
(default=None).
verify : bool, optional
For non ColumnAccessor inputs, whether to verify
column length and type
"""

_data: "Dict[Any, ColumnBase]"
Expand All @@ -116,6 +110,7 @@ def __init__(
level_names=None,
rangeindex: bool = False,
label_dtype: Dtype | None = None,
verify: bool = True,
):
self.rangeindex = rangeindex
self.label_dtype = label_dtype
Expand All @@ -133,9 +128,9 @@ def __init__(
else:
# This code path is performance-critical for copies and should be
# modified with care.
self._data = {}
if data:
data = dict(data)
data = dict(data)
if data and verify:
result = {}
# Faster than next(iter(data.values()))
column_length = len(data[next(iter(data))])
for k, v in data.items():
Expand All @@ -146,30 +141,14 @@ def __init__(
v = column.as_column(v)
if len(v) != column_length:
raise ValueError("All columns must be of equal length")
self._data[k] = v
result[k] = v
self._data = result
else:
self._data = data

self.multiindex = multiindex
self._level_names = level_names

@classmethod
def _create_unsafe(
cls,
data: Dict[Any, ColumnBase],
multiindex: bool = False,
level_names=None,
rangeindex: bool = False,
label_dtype: Dtype | None = None,
) -> ColumnAccessor:
# create a ColumnAccessor without verifying column
# type or size
obj = cls()
obj._data = data
obj.multiindex = multiindex
obj._level_names = level_names
obj.rangeindex = rangeindex
obj.label_dtype = label_dtype
return obj

def __iter__(self):
return iter(self._data)

Expand Down Expand Up @@ -217,7 +196,7 @@ def nlevels(self) -> int:
def name(self) -> Any:
return self.level_names[-1]

@property
@cached_property
def nrows(self) -> int:
if len(self._data) == 0:
return 0
Expand All @@ -243,13 +222,6 @@ def _grouped_data(self) -> abc.MutableMapping:
else:
return self._data

@cached_property
def _column_length(self):
try:
return len(self._data[next(iter(self._data))])
except StopIteration:
return 0

def _clear_cache(self):
cached_properties = ("columns", "names", "_grouped_data")
for attr in cached_properties:
Expand All @@ -258,9 +230,9 @@ def _clear_cache(self):
except AttributeError:
pass

# Column length should only be cleared if no data is present.
if len(self._data) == 0 and hasattr(self, "_column_length"):
del self._column_length
# nrows should only be cleared if no data is present.
if len(self._data) == 0 and hasattr(self, "nrows"):
del self.nrows

def to_pandas_index(self) -> pd.Index:
"""Convert the keys of the ColumnAccessor to a Pandas Index object."""
Expand Down Expand Up @@ -345,11 +317,8 @@ def insert(
if loc == len(self._data):
if validate:
value = column.as_column(value)
if len(self._data) > 0:
if len(value) != self._column_length:
raise ValueError("All columns must be of equal length")
else:
self._column_length = len(value)
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")
self._data[name] = value
else:
new_keys = self.names[:loc] + (name,) + self.names[loc:]
Expand All @@ -362,15 +331,16 @@ def copy(self, deep=False) -> ColumnAccessor:
Make a copy of this ColumnAccessor.
"""
if deep or cudf.get_option("copy_on_write"):
return self.__class__(
{k: v.copy(deep=deep) for k, v in self._data.items()},
multiindex=self.multiindex,
level_names=self.level_names,
)
data = {k: v.copy(deep=deep) for k, v in self._data.items()}
else:
data = self._data.copy()
return self.__class__(
self._data.copy(),
data=data,
multiindex=self.multiindex,
level_names=self.level_names,
rangeindex=self.rangeindex,
label_dtype=self.label_dtype,
verify=False,
)

def select_by_label(self, key: Any) -> ColumnAccessor:
Expand Down Expand Up @@ -508,22 +478,12 @@ def set_by_label(self, key: Any, value: Any, validate: bool = True):
key = self._pad_key(key)
if validate:
value = column.as_column(value)
if len(self._data) > 0:
if len(value) != self._column_length:
raise ValueError("All columns must be of equal length")
else:
self._column_length = len(value)
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")

self._data[key] = value
self._clear_cache()

def _select_by_names(self, names: abc.Sequence) -> Self:
return self.__class__(
{key: self[key] for key in names},
multiindex=self.multiindex,
level_names=self.level_names,
)

def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
# Might be a generator
key = tuple(key)
Expand All @@ -541,7 +501,7 @@ def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
else:
data = {k: self._grouped_data[k] for k in key}
if self.multiindex:
data = _to_flat_dict(data)
data = dict(_to_flat_dict_inner(data))
return self.__class__(
data,
multiindex=self.multiindex,
Expand All @@ -550,11 +510,16 @@ def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:

def _select_by_label_grouped(self, key: Any) -> ColumnAccessor:
result = self._grouped_data[key]
if isinstance(result, cudf.core.column.ColumnBase):
return self.__class__({key: result}, multiindex=self.multiindex)
if isinstance(result, column.ColumnBase):
# self._grouped_data[key] = self._data[key] so skip validation
return self.__class__(
data={key: result},
multiindex=self.multiindex,
verify=False,
)
else:
if self.multiindex:
result = _to_flat_dict(result)
result = dict(_to_flat_dict_inner(result))
if not isinstance(key, tuple):
key = (key,)
return self.__class__(
Expand All @@ -575,26 +540,28 @@ def _select_by_label_slice(self, key: slice) -> ColumnAccessor:
start = self._pad_key(start, slice(None))
stop = self._pad_key(stop, slice(None))
for idx, name in enumerate(self.names):
if _compare_keys(name, start):
if _keys_equal(name, start):
start_idx = idx
break
for idx, name in enumerate(reversed(self.names)):
if _compare_keys(name, stop):
if _keys_equal(name, stop):
stop_idx = len(self.names) - idx
break
keys = self.names[start_idx:stop_idx]
return self.__class__(
{k: self._data[k] for k in keys},
multiindex=self.multiindex,
level_names=self.level_names,
verify=False,
)

def _select_by_label_with_wildcard(self, key: Any) -> ColumnAccessor:
key = self._pad_key(key, slice(None))
return self.__class__(
{k: self._data[k] for k in self._data if _compare_keys(k, key)},
{k: self._data[k] for k in self._data if _keys_equal(k, key)},
multiindex=self.multiindex,
level_names=self.level_names,
verify=False,
)

def _pad_key(self, key: Any, pad_value="") -> Any:
Expand Down Expand Up @@ -639,6 +606,7 @@ def rename_levels(
to the given mapper and level.
"""
new_col_names: abc.Iterable
if self.multiindex:

def rename_column(x):
Expand All @@ -655,12 +623,7 @@ def rename_column(x):
"Renaming columns with a MultiIndex and level=None is"
"not supported"
)
new_names = map(rename_column, self.keys())
ca = ColumnAccessor(
dict(zip(new_names, self.values())),
level_names=self.level_names,
multiindex=self.multiindex,
)
new_col_names = (rename_column(k) for k in self.keys())

else:
if level is None:
Expand All @@ -680,13 +643,13 @@ def rename_column(x):
if len(new_col_names) != len(set(new_col_names)):
raise ValueError("Duplicate column names are not allowed")

ca = ColumnAccessor(
dict(zip(new_col_names, self.values())),
level_names=self.level_names,
multiindex=self.multiindex,
)

return self.__class__(ca)
data = dict(zip(new_col_names, self.values()))
return self.__class__(
data=data,
level_names=self.level_names,
multiindex=self.multiindex,
verify=False,
)

def droplevel(self, level):
# drop the nth level
Expand All @@ -708,7 +671,7 @@ def droplevel(self, level):
self._clear_cache()


def _compare_keys(target: Any, key: Any) -> bool:
def _keys_equal(target: Any, key: Any) -> bool:
"""
Compare `key` to `target`.
Expand Down
14 changes: 12 additions & 2 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,22 @@ def __getitem__(self, arg):
index = self._frame.index
if col_is_scalar:
s = Series._from_data(
ca._select_by_names(column_names), index=index
data=ColumnAccessor(
{key: ca._data[key] for key in column_names},
multiindex=ca.multiindex,
level_names=ca.level_names,
),
index=index,
)
return s._getitem_preprocessed(row_spec)
if column_names != list(self._frame._column_names):
frame = self._frame._from_data(
ca._select_by_names(column_names), index=index
data=ColumnAccessor(
{key: ca._data[key] for key in column_names},
multiindex=ca.multiindex,
level_names=ca.level_names,
),
index=index,
)
else:
frame = self._frame
Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,13 @@ def astype(self, dtype, copy: bool = False):
for col_name, col in self._data.items()
}

return ColumnAccessor._create_unsafe(
return ColumnAccessor(
data=result_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
verify=False,
)

@_cudf_nvtx_annotate
Expand Down Expand Up @@ -881,12 +882,13 @@ def fillna(

return self._mimic_inplace(
self._from_data(
data=ColumnAccessor._create_unsafe(
data=ColumnAccessor(
data=filled_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
verify=False,
)
),
inplace=inplace,
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def names(self, value):
# to unexpected behavior in some cases. This is
# definitely buggy, but we can't disallow non-unique
# names either...
self._data = self._data.__class__._create_unsafe(
self._data = self._data.__class__(
dict(zip(value, self._data.values())),
level_names=self._data.level_names,
verify=False,
)
self._names = pd.core.indexes.frozen.FrozenList(value)

Expand Down

0 comments on commit 67a36a9

Please sign in to comment.