Skip to content

Commit

Permalink
Fix ColumnAccessor caching of nrows if empty previously (#15710)
Browse files Browse the repository at this point in the history
#14758 may have propagated a caching invalidation bug of the number of rows in a `ColumnAccessor`

Previously the number of rows was cached and cleared only if an operation caused the `ColumnAccessor` to have no more columns.

However, if the `ColumnAccessor` was empty and operation added new columns, the cached number of rows should have also been cleared.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #15710
  • Loading branch information
mroeschke authored May 13, 2024
1 parent bff3015 commit b4bdea2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
47 changes: 33 additions & 14 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,10 @@ def __setitem__(self, key: Any, value: Any):
self.set_by_label(key, value)

def __delitem__(self, key: Any):
old_ncols = len(self._data)
del self._data[key]
self._clear_cache()
new_ncols = len(self._data)
self._clear_cache(old_ncols, new_ncols)

def __len__(self) -> int:
return len(self._data)
Expand Down Expand Up @@ -253,17 +255,30 @@ def _grouped_data(self) -> abc.MutableMapping:
else:
return self._data

def _clear_cache(self):
def _clear_cache(self, old_ncols: int, new_ncols: int):
"""
Clear cached attributes.
Parameters
----------
old_ncols: int
len(self._data) before self._data was modified
new_ncols: int
len(self._data) after self._data was modified
"""
cached_properties = ("columns", "names", "_grouped_data")
for attr in cached_properties:
try:
self.__delattr__(attr)
except AttributeError:
pass

# nrows should only be cleared if no data is present.
if len(self._data) == 0 and hasattr(self, "nrows"):
del self.nrows
# nrows should only be cleared if empty before/after the op.
if (old_ncols == 0) ^ (new_ncols == 0):
try:
del self.nrows
except AttributeError:
pass

def to_pandas_index(self) -> pd.Index:
"""Convert the keys of the ColumnAccessor to a Pandas Index object."""
Expand Down Expand Up @@ -321,27 +336,27 @@ def insert(
"""
name = self._pad_key(name)

ncols = len(self._data)
old_ncols = len(self._data)
if loc == -1:
loc = ncols
if not (0 <= loc <= ncols):
loc = old_ncols
if not (0 <= loc <= old_ncols):
raise ValueError(
"insert: loc out of bounds: must be 0 <= loc <= ncols"
f"insert: loc out of bounds: must be 0 <= loc <= {old_ncols}"
)
# TODO: we should move all insert logic here
if name in self._data:
raise ValueError(f"Cannot insert '{name}', already exists")
if loc == len(self._data):
if loc == old_ncols:
if validate:
value = column.as_column(value)
if len(self._data) > 0 and len(value) != self.nrows:
if old_ncols > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")
self._data[name] = value
else:
new_keys = self.names[:loc] + (name,) + self.names[loc:]
new_values = self.columns[:loc] + (value,) + self.columns[loc:]
self._data = self._data.__class__(zip(new_keys, new_values))
self._clear_cache()
self._clear_cache(old_ncols, old_ncols + 1)

def copy(self, deep=False) -> ColumnAccessor:
"""
Expand Down Expand Up @@ -498,8 +513,10 @@ def set_by_label(self, key: Any, value: Any, validate: bool = True):
if len(self._data) > 0 and len(value) != self.nrows:
raise ValueError("All columns must be of equal length")

old_ncols = len(self._data)
self._data[key] = value
self._clear_cache()
new_ncols = len(self._data)
self._clear_cache(old_ncols, new_ncols)

def _select_by_label_list_like(self, key: Any) -> ColumnAccessor:
# Might be a generator
Expand Down Expand Up @@ -673,10 +690,12 @@ def droplevel(self, level):
if level < 0:
level += self.nlevels

old_ncols = len(self._data)
self._data = {
_remove_key_level(key, level): value
for key, value in self._data.items()
}
new_ncols = len(self._data)
self._level_names = (
self._level_names[:level] + self._level_names[level + 1 :]
)
Expand All @@ -685,7 +704,7 @@ def droplevel(self, level):
len(self._level_names) == 1
): # can't use nlevels, as it depends on multiindex
self.multiindex = False
self._clear_cache()
self._clear_cache(old_ncols, new_ncols)


def _keys_equal(target: Any, key: Any) -> bool:
Expand Down
14 changes: 14 additions & 0 deletions python/cudf/cudf/tests/test_column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,17 @@ def test_replace_level_values_MultiColumn():

got = ca.rename_levels(mapper={"a": "f"}, level=0)
check_ca_equal(expect, got)


def test_clear_nrows_empty_before():
ca = ColumnAccessor({})
assert ca.nrows == 0
ca.insert("new", [1])
assert ca.nrows == 1


def test_clear_nrows_empty_after():
ca = ColumnAccessor({"new": [1]})
assert ca.nrows == 1
del ca["new"]
assert ca.nrows == 0

0 comments on commit b4bdea2

Please sign in to comment.