Skip to content

Commit

Permalink
Remove sizeof and standardize on memory_usage (#9544)
Browse files Browse the repository at this point in the history
This PR removes implementations of `__sizeof__` from cudf classes. Previously, `__sizeof__` was overridden to return the total GPU memory usage, but this is inconsistent with the standard Python semantics of this function and should be removed. The appropriate way to query for total GPU memory usage is via the `memory_usage` function, which is now standardized across various objects. The sizeof dispatch for dask is set to use `memory_usage` as well to avoid any breakage here.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)
  - Mads R. B. Kristensen (https://github.com/madsbk)

URL: #9544
  • Loading branch information
vyasr authored Nov 12, 2021
1 parent 10cbbd7 commit 79b4f54
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 108 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,12 @@ def memory_usage(self, deep=False):
-------
bytes used
"""
return self._values._memory_usage(deep=deep)
if deep:
warnings.warn(
"The deep parameter is ignored and is only included "
"for pandas compatibility."
)
return self._values.memory_usage()

@classmethod
def from_pandas(cls, index, nan_as_null=None):
Expand Down
11 changes: 2 additions & 9 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,15 +1323,8 @@ def copy(self, deep: bool = True) -> CategoricalColumn:
size=self.size,
)

def __sizeof__(self) -> int:
return self.categories.__sizeof__() + self.codes.__sizeof__()

def _memory_usage(self, **kwargs) -> int:
deep = kwargs.get("deep", False)
if deep:
return self.__sizeof__()
else:
return self.categories._memory_usage() + self.codes._memory_usage()
def memory_usage(self) -> int:
return self.categories.memory_usage() + self.codes.memory_usage()

def _mimic_inplace(
self, other_col: ColumnBase, inplace: bool = False
Expand Down
19 changes: 8 additions & 11 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,6 @@ def any(self, skipna: bool = True) -> bool:

return result_col

def __sizeof__(self) -> int:
n = 0
if self.data is not None:
n += self.data.size
if self.nullable:
n += bitmask_allocation_size_bytes(self.size)
return n

def dropna(self, drop_nan: bool = False) -> ColumnBase:
if drop_nan:
col = self.nans_to_nulls()
Expand Down Expand Up @@ -313,13 +305,18 @@ def _get_mask_as_column(self) -> ColumnBase:
self.base_mask, self.offset, self.offset + len(self)
)

def _memory_usage(self, **kwargs) -> int:
return self.__sizeof__()
def memory_usage(self) -> int:
n = 0
if self.data is not None:
n += self.data.size
if self.nullable:
n += bitmask_allocation_size_bytes(self.size)
return n

def _default_na_value(self) -> Any:
raise NotImplementedError()

# TODO: This method is decpreated and can be removed when the associated
# TODO: This method is deprecated and can be removed when the associated
# Frame methods are removed.
def to_gpu_array(self, fillna=None) -> "cuda.devicearray.DeviceNDArray":
"""Get a dense numba device array for the data.
Expand Down
59 changes: 26 additions & 33 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,41 +42,34 @@ def __init__(
children=children,
)

def __sizeof__(self):
if self._cached_sizeof is None:
n = 0
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(
self.size
)

child0_size = (self.size + 1) * self.base_children[
0
].dtype.itemsize
current_base_child = self.base_children[1]
current_offset = self.offset
def memory_usage(self):
n = 0
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(self.size)

child0_size = (self.size + 1) * self.base_children[0].dtype.itemsize
current_base_child = self.base_children[1]
current_offset = self.offset
n += child0_size
while type(current_base_child) is ListColumn:
child0_size = (
current_base_child.size + 1 - current_offset
) * current_base_child.base_children[0].dtype.itemsize
current_offset = current_base_child.base_children[0][
current_offset
]
n += child0_size
while type(current_base_child) is ListColumn:
child0_size = (
current_base_child.size + 1 - current_offset
) * current_base_child.base_children[0].dtype.itemsize
current_offset = current_base_child.base_children[0][
current_offset
]
n += child0_size
current_base_child = current_base_child.base_children[1]

n += (
current_base_child.size - current_offset
) * current_base_child.dtype.itemsize

if current_base_child.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(
current_base_child.size
)
self._cached_sizeof = n
current_base_child = current_base_child.base_children[1]

n += (
current_base_child.size - current_offset
) * current_base_child.dtype.itemsize

return self._cached_sizeof
if current_base_child.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(
current_base_child.size
)
return n

def __setitem__(self, key, value):
if isinstance(value, list):
Expand Down
33 changes: 14 additions & 19 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5214,26 +5214,21 @@ def end_offset(self) -> int:

return self._end_offset

def __sizeof__(self) -> int:
if self._cached_sizeof is None:
n = 0
if len(self.base_children) == 2:
child0_size = (self.size + 1) * self.base_children[
0
].dtype.itemsize

child1_size = (
self.end_offset - self.start_offset
) * self.base_children[1].dtype.itemsize

n += child0_size + child1_size
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(
self.size
)
self._cached_sizeof = n
def memory_usage(self) -> int:
n = 0
if len(self.base_children) == 2:
child0_size = (self.size + 1) * self.base_children[
0
].dtype.itemsize

child1_size = (
self.end_offset - self.start_offset
) * self.base_children[1].dtype.itemsize

return self._cached_sizeof
n += child0_size + child1_size
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(self.size)
return n

@property
def base_size(self) -> int:
Expand Down
14 changes: 7 additions & 7 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,11 +1138,6 @@ def __delitem__(self, name):
"""
self._drop_column(name)

def __sizeof__(self):
columns = sum(col.__sizeof__() for col in self._data.columns)
index = self._index.__sizeof__()
return columns + index

def _slice(self: T, arg: slice) -> T:
"""
_slice : slice the frame as per the arg
Expand Down Expand Up @@ -1253,12 +1248,17 @@ def memory_usage(self, index=True, deep=False):
>>> df['object'].astype('category').memory_usage(deep=True)
5048
"""
if deep:
warnings.warn(
"The deep parameter is ignored and is only included "
"for pandas compatibility."
)
ind = list(self.columns)
sizes = [col._memory_usage(deep=deep) for col in self._data.columns]
sizes = [col.memory_usage() for col in self._data.columns]
if index:
ind.append("Index")
ind = cudf.Index(ind, dtype="str")
sizes.append(self.index.memory_usage(deep=deep))
sizes.append(self.index.memory_usage())
return Series(sizes, index=ind)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down
10 changes: 6 additions & 4 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,12 @@ def get_slice_bound(self, label, side, kind=None):
pos = search_range(start, stop, label, step, side=side)
return pos

def memory_usage(self, **kwargs):
def memory_usage(self, deep=False):
if deep:
warnings.warn(
"The deep parameter is ignored and is only included "
"for pandas compatibility."
)
return 0

def unique(self):
Expand Down Expand Up @@ -1022,9 +1027,6 @@ def get_loc(self, key, method=None, tolerance=None):
mask[true_inds] = True
return mask

def __sizeof__(self):
return self._values.__sizeof__()

def __repr__(self):
max_seq_items = get_option("max_seq_items") or len(self)
mr = 0
Expand Down
12 changes: 7 additions & 5 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,6 @@ def set_mask(self, mask, null_count=None):
{self.name: self._column.set_mask(mask)}, self._index
)

def __sizeof__(self):
return self._column.__sizeof__() + self._index.__sizeof__()

def memory_usage(self, index=True, deep=False):
"""
Return the memory usage of the Series.
Expand Down Expand Up @@ -1020,9 +1017,14 @@ def memory_usage(self, index=True, deep=False):
>>> s.memory_usage(index=False)
24
"""
n = self._column._memory_usage(deep=deep)
if deep:
warnings.warn(
"The deep parameter is ignored and is only included "
"for pandas compatibility."
)
n = self._column.memory_usage()
if index:
n += self._index.memory_usage(deep=deep)
n += self._index.memory_usage()
return n

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down
21 changes: 4 additions & 17 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3649,19 +3649,6 @@ def test_empty_dataframe_any(axis):
assert_eq(got, expected, check_index_type=False)


@pytest.mark.parametrize("indexed", [False, True])
def test_dataframe_sizeof(indexed):
rows = int(1e6)
index = list(i for i in range(rows)) if indexed else None

gdf = cudf.DataFrame({"A": [8] * rows, "B": [32] * rows}, index=index)

for c in gdf._data.columns:
assert gdf._index.__sizeof__() == gdf._index.__sizeof__()
cols_sizeof = sum(c.__sizeof__() for c in gdf._data.columns)
assert gdf.__sizeof__() == (gdf._index.__sizeof__() + cols_sizeof)


@pytest.mark.parametrize("a", [[], ["123"]])
@pytest.mark.parametrize("b", ["123", ["123"]])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -5394,8 +5381,8 @@ def test_memory_usage_cat():
gdf = cudf.from_pandas(df)

expected = (
gdf.B._column.categories.__sizeof__()
+ gdf.B._column.codes.__sizeof__()
gdf.B._column.categories.memory_usage()
+ gdf.B._column.codes.memory_usage()
)

# Check cat column
Expand All @@ -5408,8 +5395,8 @@ def test_memory_usage_cat():
def test_memory_usage_list():
df = cudf.DataFrame({"A": [[0, 1, 2, 3], [4, 5, 6], [7, 8], [9]]})
expected = (
df.A._column.offsets._memory_usage()
+ df.A._column.elements._memory_usage()
df.A._column.offsets.memory_usage()
+ df.A._column.elements.memory_usage()
)
assert expected == df.A.memory_usage()

Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def test_pickle_dataframe_categorical():
check_serialization(df)


def test_sizeof_dataframe():
def test_memory_usage_dataframe():
np.random.seed(0)
df = DataFrame()
nelem = 1000
df["keys"] = hkeys = np.arange(nelem, dtype=np.float64)
df["vals"] = hvals = np.random.random(nelem)

nbytes = hkeys.nbytes + hvals.nbytes
sizeof = sys.getsizeof(df)
sizeof = df.memory_usage().sum()
assert sizeof >= nbytes

serialized_nbytes = len(pickle.dumps(df, protocol=pickle.HIGHEST_PROTOCOL))
Expand Down
11 changes: 11 additions & 0 deletions python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
is_scalar,
make_meta_obj,
)
from dask.sizeof import sizeof as sizeof_dispatch

import cudf
from cudf.api.types import is_string_dtype
Expand Down Expand Up @@ -345,3 +346,13 @@ def group_split_cudf(df, c, k, ignore_index=False):
),
)
)


@sizeof_dispatch.register(cudf.DataFrame)
def sizeof_cudf_dataframe(df):
return int(df.memory_usage().sum())


@sizeof_dispatch.register((cudf.Series, cudf.BaseIndex))
def sizeof_cudf_series_index(obj):
return obj.memory_usage()

0 comments on commit 79b4f54

Please sign in to comment.