Skip to content

Commit

Permalink
PERF-modin-project#6583: Remove redundant index reassignment in query()
Browse files Browse the repository at this point in the history
Signed-off-by: mvashishtha <[email protected]>
  • Loading branch information
mvashishtha committed Sep 18, 2023
1 parent 40216fa commit 931fb46
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
35 changes: 34 additions & 1 deletion modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,40 @@ def quantile(
# methods and fields we need to use pandas.DataFrame.query
_AXIS_ORDERS = ["index", "columns"]
_get_index_resolvers = pandas.DataFrame._get_index_resolvers
_get_axis_resolvers = pandas.DataFrame._get_axis_resolvers

def _get_axis_resolvers(self, axis: str) -> dict:
# forked from pandas because we only want to update the index if there's more
# than one level of the index.
# index or columns
axis_index = getattr(self, axis)
d = {}
prefix = axis[0]

for i, name in enumerate(axis_index.names):
if name is not None:
key = level = name
else:
# prefix with 'i' or 'c' depending on the input axis
# e.g., you must do ilevel_0 for the 0th level of an unnamed
# multiiindex
key = f"{prefix}level_{i}"
level = i

level_values = axis_index.get_level_values(level)
s = level_values.to_series()
if axis_index.nlevels > 1:
s.index = axis_index
d[key] = s

# put the index/columns itself in the dict
if axis_index.nlevels > 2:
dindex = axis_index
else:
dindex = axis_index.to_series()

d[axis] = dindex
return d

_get_cleaned_column_resolvers = pandas.DataFrame._get_cleaned_column_resolvers

def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200
Expand Down
35 changes: 35 additions & 0 deletions modin/pandas/test/dataframe/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,41 @@ def test_query(data, funcs, engine):
df_equals(modin_result.dtypes, pandas_result.dtypes)


def test_query_named_index():
eval_general(
*(df.set_index("col1") for df in create_test_dfs(test_data["int_data"])),
lambda df: df.query("col1 % 2 == 0 | col2 % 2 == 1"),
# work around https://github.com/modin-project/modin/issues/6016
raising_exceptions=Exception,
)


def test_query_named_multiindex():
eval_general(
*(
df.set_index(["col1", "col2"])
for df in create_test_dfs(test_data["int_data"])
),
lambda df: df.query("col1 % 2 == 1 | col2 % 2 == 1"),
# work around https://github.com/modin-project/modin/issues/6016
raising_exceptions=Exception,
)


def test_query_multiindex_without_names():
def make_df(without_index):
new_df = without_index.set_index(["col1", "col2"])
new_df.index.names = [None, None]
return new_df

eval_general(
*(make_df(df) for df in create_test_dfs(test_data["int_data"])),
lambda df: df.query("ilevel_0 % 2 == 0 | ilevel_1 % 2 == 1 | col3 % 2 == 1"),
# work around https://github.com/modin-project/modin/issues/6016
raising_exceptions=Exception,
)


def test_empty_query():
modin_df = pd.DataFrame([1, 2, 3, 4, 5])

Expand Down
3 changes: 3 additions & 0 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@
"col3 > col4": "col3 > col4",
"col1 == col2": "col1 == col2",
"(col2 > col1) and (col1 < col3)": "(col2 > col1) and (col1 < col3)",
# this is how to query for values of an unnamed index per
# https://pandas.pydata.org/docs/user_guide/indexing.html#multiindex-query-syntax
"ilevel_0 % 2 == 1": "ilevel_0 % 2 == 1",
}
query_func_keys = list(query_func.keys())
query_func_values = list(query_func.values())
Expand Down

0 comments on commit 931fb46

Please sign in to comment.