Skip to content

Commit

Permalink
PERF-#7397: Avoid materializing index/columns in shape checks (#7398)
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <[email protected]>
Co-authored-by: Anatoly Myachev <[email protected]>
  • Loading branch information
noloerino and anmyachev authored Sep 21, 2024
1 parent 8b8806e commit cc717a0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 37 deletions.
20 changes: 19 additions & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import abc
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Hashable, List, Optional
from typing import TYPE_CHECKING, Hashable, List, Literal, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -4270,6 +4270,24 @@ def get_axis(self, axis):
"""
return self.index if axis == 0 else self.columns

def get_axis_len(self, axis: Literal[0, 1]) -> int:
"""
Return the length of the specified axis.
A query compiler may choose to override this method if it has a more efficient way
of computing the length of an axis without materializing it.
Parameters
----------
axis : {0, 1}
Axis to return labels on.
Returns
-------
int
"""
return len(self.get_axis(axis))

def take_2d_labels(
self,
index,
Expand Down
20 changes: 19 additions & 1 deletion modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import re
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Hashable, List, Optional
from typing import TYPE_CHECKING, Hashable, List, Literal, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -395,6 +395,24 @@ def from_dataframe(cls, df, data_cls):
index: pandas.Index = property(_get_axis(0), _set_axis(0))
columns: pandas.Index = property(_get_axis(1), _set_axis(1))

def get_axis_len(self, axis: Literal[0, 1]) -> int:
"""
Return the length of the specified axis.
Parameters
----------
axis : {0, 1}
Axis to return labels on.
Returns
-------
int
"""
if axis == 0:
return len(self._modin_frame)
else:
return sum(self._modin_frame.column_widths)

@property
def dtypes(self) -> pandas.Series:
return self._modin_frame.dtypes
Expand Down
12 changes: 7 additions & 5 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def _build_repr_df(
A pandas dataset with `num_rows` or fewer rows and `num_cols` or fewer columns.
"""
# Fast track for empty dataframe.
if len(self.index) == 0 or (self._is_dataframe and len(self.columns) == 0):
if len(self) == 0 or (
self._is_dataframe and self._query_compiler.get_axis_len(1) == 0
):
return pandas.DataFrame(
index=self.index,
columns=self.columns if self._is_dataframe else None,
Expand Down Expand Up @@ -1004,7 +1006,7 @@ def error_raiser(msg, exception):
return result._query_compiler
return result
elif isinstance(func, dict):
if len(self.columns) != len(set(self.columns)):
if self._query_compiler.get_axis_len(1) != len(set(self.columns)):
warnings.warn(
"duplicate column names not supported with apply().",
FutureWarning,
Expand Down Expand Up @@ -2860,7 +2862,7 @@ def sample(
axis_length = len(axis_labels)
else:
# Getting rows requires indices instead of labels. RangeIndex provides this.
axis_labels = pandas.RangeIndex(len(self.index))
axis_labels = pandas.RangeIndex(len(self))
axis_length = len(axis_labels)
if weights is not None:
# Index of the weights Series should correspond to the index of the
Expand Down Expand Up @@ -3217,7 +3219,7 @@ def tail(self, n=5) -> Self: # noqa: PR01, RT01, D200
"""
if n != 0:
return self.iloc[-n:]
return self.iloc[len(self.index) :]
return self.iloc[len(self) :]

def take(self, indices, axis=0, **kwargs) -> Self: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -4149,7 +4151,7 @@ def __len__(self) -> int:
-------
int
"""
return len(self.index)
return self._query_compiler.get_axis_len(0)

@_doc_binary_op(
operation="less than comparison",
Expand Down
53 changes: 30 additions & 23 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,15 @@ def __repr__(self) -> str:
-------
str
"""
num_rows = pandas.get_option("display.max_rows") or len(self.index)
num_cols = pandas.get_option("display.max_columns") or len(self.columns)
num_rows = pandas.get_option("display.max_rows") or len(self)
num_cols = pandas.get_option(
"display.max_columns"
) or self._query_compiler.get_axis_len(1)
result = repr(self._build_repr_df(num_rows, num_cols))
if len(self.index) > num_rows or len(self.columns) > num_cols:
if len(self) > num_rows or self._query_compiler.get_axis_len(1) > num_cols:
# The split here is so that we don't repr pandas row lengths.
return result.rsplit("\n\n", 1)[0] + "\n\n[{0} rows x {1} columns]".format(
len(self.index), len(self.columns)
*self.shape
)
else:
return result
Expand All @@ -293,13 +295,11 @@ def _repr_html_(self) -> str: # pragma: no cover
# We use pandas _repr_html_ to get a string of the HTML representation
# of the dataframe.
result = self._build_repr_df(num_rows, num_cols)._repr_html_()
if len(self.index) > num_rows or len(self.columns) > num_cols:
if len(self) > num_rows or self._query_compiler.get_axis_len(1) > num_cols:
# We split so that we insert our correct dataframe dimensions.
return result.split("<p>")[
0
] + "<p>{0} rows x {1} columns</p>\n</div>".format(
len(self.index), len(self.columns)
)
] + "<p>{0} rows x {1} columns</p>\n</div>".format(*self.shape)
else:
return result

Expand Down Expand Up @@ -365,7 +365,7 @@ def empty(self) -> bool: # noqa: RT01, D200
"""
Indicate whether ``DataFrame`` is empty.
"""
return len(self.columns) == 0 or len(self.index) == 0
return self._query_compiler.get_axis_len(1) == 0 or len(self) == 0

@property
def axes(self) -> list[pandas.Index]: # noqa: RT01, D200
Expand All @@ -379,7 +379,7 @@ def shape(self) -> tuple[int, int]: # noqa: RT01, D200
"""
Return a tuple representing the dimensionality of the ``DataFrame``.
"""
return len(self.index), len(self.columns)
return len(self), self._query_compiler.get_axis_len(1)

def add_prefix(self, prefix, axis=None) -> DataFrame: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -781,7 +781,9 @@ def dot(self, other) -> Union[DataFrame, Series]: # noqa: PR01, RT01, D200
"""
if isinstance(other, BasePandasDataset):
common = self.columns.union(other.index)
if len(common) > len(self.columns) or len(common) > len(other.index):
if len(common) > self._query_compiler.get_axis_len(1) or len(common) > len(
other
):
raise ValueError("Matrices are not aligned")

qc = other.reindex(index=common)._query_compiler
Expand Down Expand Up @@ -1084,7 +1086,7 @@ def insert(
+ f"{len(value.columns)} columns instead."
)
value = value.squeeze(axis=1)
if not self._query_compiler.lazy_row_count and len(self.index) == 0:
if not self._query_compiler.lazy_row_count and len(self) == 0:
if not hasattr(value, "index"):
try:
value = pandas.Series(value)
Expand All @@ -1099,7 +1101,7 @@ def insert(
new_query_compiler = self.__constructor__(
value, index=new_index, columns=new_columns
)._query_compiler
elif len(self.columns) == 0 and loc == 0:
elif self._query_compiler.get_axis_len(1) == 0 and loc == 0:
new_index = self.index
new_query_compiler = self.__constructor__(
data=value,
Expand All @@ -1110,18 +1112,19 @@ def insert(
if (
is_list_like(value)
and not isinstance(value, (pandas.Series, Series))
and len(value) != len(self.index)
and len(value) != len(self)
):
raise ValueError(
"Length of values ({}) does not match length of index ({})".format(
len(value), len(self.index)
len(value), len(self)
)
)
if allow_duplicates is not True and column in self.columns:
raise ValueError(f"cannot insert {column}, already exists")
if not -len(self.columns) <= loc <= len(self.columns):
columns_len = self._query_compiler.get_axis_len(1)
if not -columns_len <= loc <= columns_len:
raise IndexError(
f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}"
f"index {loc} is out of bounds for axis 0 with size {columns_len}"
)
elif loc < 0:
raise ValueError("unbounded slice")
Expand Down Expand Up @@ -2074,9 +2077,11 @@ def squeeze(
Squeeze 1 dimensional axis objects into scalars.
"""
axis = self._get_axis_number(axis) if axis is not None else None
if axis is None and (len(self.columns) == 1 or len(self) == 1):
if axis is None and (
self._query_compiler.get_axis_len(1) == 1 or len(self) == 1
):
return Series(query_compiler=self._query_compiler).squeeze()
if axis == 1 and len(self.columns) == 1:
if axis == 1 and self._query_compiler.get_axis_len(1) == 1:
self._query_compiler._shape_hint = "column"
return Series(query_compiler=self._query_compiler)
if axis == 0 and len(self) == 1:
Expand Down Expand Up @@ -2671,7 +2676,7 @@ def __setitem__(self, key, value) -> None:
return self._setitem_slice(key, value)

if hashable(key) and key not in self.columns:
if isinstance(value, Series) and len(self.columns) == 0:
if isinstance(value, Series) and self._query_compiler.get_axis_len(1) == 0:
# Note: column information is lost when assigning a query compiler
prev_index = self.columns
self._query_compiler = value._query_compiler.copy()
Expand All @@ -2680,7 +2685,9 @@ def __setitem__(self, key, value) -> None:
self.columns = prev_index.insert(0, key)
return
# Do new column assignment after error checks and possible value modifications
self.insert(loc=len(self.columns), column=key, value=value)
self.insert(
loc=self._query_compiler.get_axis_len(1), column=key, value=value
)
return

if not hashable(key):
Expand Down Expand Up @@ -2756,7 +2763,7 @@ def __setitem__(self, key, value) -> None:

new_qc = self._query_compiler.insert_item(
axis=1,
loc=len(self.columns),
loc=self._query_compiler.get_axis_len(1),
value=value._query_compiler,
how="left",
)
Expand All @@ -2783,7 +2790,7 @@ def setitem_unhashable_key(df, value):
if not isinstance(value, (Series, Categorical, np.ndarray, list, range)):
value = list(value)

if not self._query_compiler.lazy_row_count and len(self.index) == 0:
if not self._query_compiler.lazy_row_count and len(self) == 0:
new_self = self.__constructor__({key: value}, columns=self.columns)
self._update_inplace(new_self._query_compiler)
else:
Expand Down
14 changes: 7 additions & 7 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ def __repr__(self) -> str:
name_str = "Name: {}, ".format(str(self.name))
else:
name_str = ""
if len(self.index) > num_rows:
len_str = "Length: {}, ".format(len(self.index))
if len(self) > num_rows:
len_str = "Length: {}, ".format(len(self))
else:
len_str = ""
dtype_str = "dtype: {}".format(
Expand Down Expand Up @@ -966,7 +966,7 @@ def dot(self, other) -> Union[Series, np.ndarray]: # noqa: PR01, RT01, D200
"""
if isinstance(other, BasePandasDataset):
common = self.index.union(other.index)
if len(common) > len(self.index) or len(common) > len(other.index):
if len(common) > len(self) or len(common) > len(other):
raise ValueError("Matrices are not aligned")

qc = other.reindex(index=common)._query_compiler
Expand Down Expand Up @@ -1761,7 +1761,7 @@ def reset_index(
name = 0 if self.name is None else self.name

if drop and level is None:
new_idx = pandas.RangeIndex(len(self.index))
new_idx = pandas.RangeIndex(len(self))
if inplace:
self.index = new_idx
else:
Expand Down Expand Up @@ -1989,7 +1989,7 @@ def squeeze(self, axis=None) -> Union[Series, Scalar]: # noqa: PR01, RT01, D200
if axis is not None:
# Validate `axis`
pandas.Series._get_axis_number(axis)
if len(self.index) == 1:
if len(self) == 1:
return self._reduce_dimension(self._query_compiler)
else:
return self.copy()
Expand Down Expand Up @@ -2307,7 +2307,7 @@ def empty(self) -> bool: # noqa: RT01, D200
"""
Indicate whether Series is empty.
"""
return len(self.index) == 0
return len(self) == 0

@property
def hasnans(self) -> bool: # noqa: RT01, D200
Expand Down Expand Up @@ -2648,7 +2648,7 @@ def _getitem(self, key) -> Union[Series, Scalar]:
if is_bool_indexer(key):
return self.__constructor__(
query_compiler=self._query_compiler.getitem_row_array(
pandas.RangeIndex(len(self.index))[key]
pandas.RangeIndex(len(self))[key]
)
)
# TODO: More efficiently handle `tuple` case for `Series.__getitem__`
Expand Down

0 comments on commit cc717a0

Please sign in to comment.