Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify ColumnAccessor methods; avoid unnecessary validations #14758

Merged
merged 18 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 50 additions & 91 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from __future__ import annotations

Expand All @@ -21,7 +21,6 @@
import pandas as pd
from packaging.version import Version
from pandas.api.types import is_bool
from typing_extensions import Self

import cudf
from cudf.core import column
Expand Down Expand Up @@ -66,7 +65,7 @@ def __getitem__(self, key):
return super().__getitem__(key)


def _to_flat_dict_inner(d, parents=()):
def _to_flat_dict_inner(d: dict, parents: tuple = ()):
for k, v in d.items():
if not isinstance(v, d.__class__):
if parents:
Expand All @@ -76,14 +75,6 @@ def _to_flat_dict_inner(d, parents=()):
yield from _to_flat_dict_inner(d=v, parents=parents + (k,))


def _to_flat_dict(d):
    """
    Convert the given nested dictionary to a flat dictionary
    with tuple keys.
    """
    # _to_flat_dict_inner already yields (key, value) pairs, so feed
    # them straight to dict(); the dict-comprehension wrapper
    # ``{k: v for k, v in ...}`` was redundant (ruff C416).
    return dict(_to_flat_dict_inner(d))


class ColumnAccessor(abc.MutableMapping):
"""
Parameters
Expand All @@ -103,6 +94,9 @@ class ColumnAccessor(abc.MutableMapping):
label_dtype : Dtype, optional
What dtype should be returned in `to_pandas_index`
(default=None).
verify : bool, optional
For non ColumnAccessor inputs, whether to verify
column length and type
"""

_data: "Dict[Any, ColumnBase]"
Expand All @@ -116,6 +110,7 @@ def __init__(
level_names=None,
rangeindex: bool = False,
label_dtype: Dtype | None = None,
verify: bool = True,
):
self.rangeindex = rangeindex
self.label_dtype = label_dtype
Expand All @@ -133,9 +128,9 @@ def __init__(
else:
# This code path is performance-critical for copies and should be
# modified with care.
self._data = {}
if data:
data = dict(data)
data = dict(data)
if data and verify:
result = {}
# Faster than next(iter(data.values()))
column_length = len(data[next(iter(data))])
for k, v in data.items():
Expand All @@ -146,30 +141,14 @@ def __init__(
v = column.as_column(v)
if len(v) != column_length:
raise ValueError("All columns must be of equal length")
self._data[k] = v
result[k] = v
self._data = result
else:
self._data = data

self.multiindex = multiindex
self._level_names = level_names

@classmethod
def _create_unsafe(
    cls,
    data: Dict[Any, ColumnBase],
    multiindex: bool = False,
    level_names=None,
    rangeindex: bool = False,
    label_dtype: Dtype | None = None,
) -> ColumnAccessor:
    """Build a ColumnAccessor while skipping column type/length checks.

    Fast path for internal callers that already hold validated columns.
    """
    accessor = cls()
    accessor.rangeindex = rangeindex
    accessor.label_dtype = label_dtype
    accessor._data = data
    accessor.multiindex = multiindex
    accessor._level_names = level_names
    return accessor

def __iter__(self):
    # Iterate over the column labels (keys of the underlying dict),
    # in insertion order.
    return iter(self._data)

Expand Down Expand Up @@ -217,7 +196,7 @@ def nlevels(self) -> int:
def name(self) -> Any:
    # Return the last entry of ``level_names`` — the innermost level's
    # label.
    return self.level_names[-1]

@property
@cached_property
def nrows(self) -> int:
if len(self._data) == 0:
return 0
Expand All @@ -243,25 +222,14 @@ def _grouped_data(self) -> abc.MutableMapping:
else:
return self._data

@cached_property
vyasr marked this conversation as resolved.
Show resolved Hide resolved
def _column_length(self):
try:
return len(self._data[next(iter(self._data))])
except StopIteration:
return 0

def _clear_cache(self):
cached_properties = ("columns", "names", "_grouped_data")
cached_properties = ("columns", "names", "_grouped_data", "nrows")
for attr in cached_properties:
try:
self.__delattr__(attr)
except AttributeError:
pass

# Column length should only be cleared if no data is present.
vyasr marked this conversation as resolved.
Show resolved Hide resolved
if len(self._data) == 0 and hasattr(self, "_column_length"):
del self._column_length

def to_pandas_index(self) -> pd.Index:
"""Convert the keys of the ColumnAccessor to a Pandas Index object."""
if self.multiindex and len(self.level_names) > 0:
Expand Down Expand Up @@ -345,11 +313,8 @@ def insert(
if loc == len(self._data):
if validate:
value = column.as_column(value)
if len(self._data) > 0:
if len(value) != self._column_length:
raise ValueError("All columns must be of equal length")
else:
self._column_length = len(value)
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")
self._data[name] = value
else:
new_keys = self.names[:loc] + (name,) + self.names[loc:]
Expand All @@ -362,15 +327,16 @@ def copy(self, deep=False) -> ColumnAccessor:
Make a copy of this ColumnAccessor.
"""
if deep or cudf.get_option("copy_on_write"):
return self.__class__(
{k: v.copy(deep=deep) for k, v in self._data.items()},
multiindex=self.multiindex,
level_names=self.level_names,
)
data = {k: v.copy(deep=deep) for k, v in self._data.items()}
else:
data = self._data.copy()
return self.__class__(
self._data.copy(),
data=data,
multiindex=self.multiindex,
level_names=self.level_names,
rangeindex=self.rangeindex,
label_dtype=self.label_dtype,
verify=False,
)

def select_by_label(self, key: Any) -> ColumnAccessor:
Expand Down Expand Up @@ -508,22 +474,12 @@ def set_by_label(self, key: Any, value: Any, validate: bool = True):
key = self._pad_key(key)
if validate:
value = column.as_column(value)
if len(self._data) > 0:
if len(value) != self._column_length:
raise ValueError("All columns must be of equal length")
else:
self._column_length = len(value)
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")

self._data[key] = value
self._clear_cache()

def _select_by_names(self, names: abc.Sequence) -> Self:
    """Return a new accessor containing only the columns in *names*.

    The multiindex flag and level names are carried over unchanged.
    """
    chosen = {label: self[label] for label in names}
    return self.__class__(
        chosen,
        multiindex=self.multiindex,
        level_names=self.level_names,
    )

def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
# Might be a generator
key = tuple(key)
Expand All @@ -541,7 +497,7 @@ def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
else:
data = {k: self._grouped_data[k] for k in key}
if self.multiindex:
data = _to_flat_dict(data)
data = dict(_to_flat_dict_inner(data))
return self.__class__(
data,
multiindex=self.multiindex,
Expand All @@ -550,11 +506,16 @@ def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:

def _select_by_label_grouped(self, key: Any) -> ColumnAccessor:
result = self._grouped_data[key]
if isinstance(result, cudf.core.column.ColumnBase):
return self.__class__({key: result}, multiindex=self.multiindex)
if isinstance(result, column.ColumnBase):
# self._grouped_data[key] = self._data[key] so skip validation
return self.__class__(
data={key: result},
multiindex=self.multiindex,
verify=False,
)
else:
if self.multiindex:
result = _to_flat_dict(result)
result = dict(_to_flat_dict_inner(result))
if not isinstance(key, tuple):
key = (key,)
return self.__class__(
Expand All @@ -575,26 +536,28 @@ def _select_by_label_slice(self, key: slice) -> ColumnAccessor:
start = self._pad_key(start, slice(None))
stop = self._pad_key(stop, slice(None))
for idx, name in enumerate(self.names):
if _compare_keys(name, start):
if _keys_equal(name, start):
start_idx = idx
break
for idx, name in enumerate(reversed(self.names)):
if _compare_keys(name, stop):
if _keys_equal(name, stop):
stop_idx = len(self.names) - idx
break
keys = self.names[start_idx:stop_idx]
return self.__class__(
{k: self._data[k] for k in keys},
multiindex=self.multiindex,
level_names=self.level_names,
verify=False,
)

def _select_by_label_with_wildcard(self, key: Any) -> ColumnAccessor:
key = self._pad_key(key, slice(None))
return self.__class__(
{k: self._data[k] for k in self._data if _compare_keys(k, key)},
{k: self._data[k] for k in self._data if _keys_equal(k, key)},
multiindex=self.multiindex,
level_names=self.level_names,
verify=False,
)

def _pad_key(self, key: Any, pad_value="") -> Any:
Expand Down Expand Up @@ -639,6 +602,7 @@ def rename_levels(
to the given mapper and level.

"""
new_col_names: abc.Iterable
if self.multiindex:

def rename_column(x):
Expand All @@ -655,12 +619,7 @@ def rename_column(x):
"Renaming columns with a MultiIndex and level=None is"
"not supported"
)
new_names = map(rename_column, self.keys())
ca = ColumnAccessor(
dict(zip(new_names, self.values())),
level_names=self.level_names,
multiindex=self.multiindex,
)
new_col_names = (rename_column(k) for k in self.keys())

else:
if level is None:
Expand All @@ -680,13 +639,13 @@ def rename_column(x):
if len(new_col_names) != len(set(new_col_names)):
raise ValueError("Duplicate column names are not allowed")

ca = ColumnAccessor(
dict(zip(new_col_names, self.values())),
level_names=self.level_names,
multiindex=self.multiindex,
)

return self.__class__(ca)
data = dict(zip(new_col_names, self.values()))
return self.__class__(
data=data,
level_names=self.level_names,
multiindex=self.multiindex,
verify=False,
)

def droplevel(self, level):
# drop the nth level
Expand All @@ -708,7 +667,7 @@ def droplevel(self, level):
self._clear_cache()


def _compare_keys(target: Any, key: Any) -> bool:
def _keys_equal(target: Any, key: Any) -> bool:
"""
Compare `key` to `target`.

Expand Down
14 changes: 12 additions & 2 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,22 @@ def __getitem__(self, arg):
index = self._frame.index
if col_is_scalar:
s = Series._from_data(
ca._select_by_names(column_names), index=index
data=ColumnAccessor(
{key: ca._data[key] for key in column_names},
multiindex=ca.multiindex,
level_names=ca.level_names,
),
index=index,
)
return s._getitem_preprocessed(row_spec)
if column_names != list(self._frame._column_names):
frame = self._frame._from_data(
ca._select_by_names(column_names), index=index
data=ColumnAccessor(
{key: ca._data[key] for key in column_names},
multiindex=ca.multiindex,
level_names=ca.level_names,
),
index=index,
)
else:
frame = self._frame
Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,13 @@ def astype(self, dtype, copy=False, **kwargs):
else:
result_data[col_name] = col.copy() if copy else col

return ColumnAccessor._create_unsafe(
return ColumnAccessor(
data=result_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
verify=False,
)

@_cudf_nvtx_annotate
Expand Down Expand Up @@ -883,12 +884,13 @@ def fillna(

return self._mimic_inplace(
self._from_data(
data=ColumnAccessor._create_unsafe(
data=ColumnAccessor(
data=filled_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
verify=False,
)
),
inplace=inplace,
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def names(self, value):
# to unexpected behavior in some cases. This is
# definitely buggy, but we can't disallow non-unique
# names either...
self._data = self._data.__class__._create_unsafe(
self._data = self._data.__class__(
dict(zip(value, self._data.values())),
level_names=self._data.level_names,
verify=False,
)
self._names = pd.core.indexes.frozen.FrozenList(value)

Expand Down