From 209a03c04215e9fc79359aa47fc41b8986676bc2 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Thu, 9 Feb 2023 16:20:09 +0100 Subject: [PATCH] PERF-#5596: Do not trigger index materialization for '.merge' result (#5619) Signed-off-by: Dmitry Chigarev --- .../storage_formats/pandas/query_compiler.py | 74 +++++++++++++------ modin/pandas/test/dataframe/test_join_sort.py | 73 ++++++++++++++++++ 2 files changed, 123 insertions(+), 24 deletions(-) diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py index c90ddbde8b8..65f87c2a48a 100644 --- a/modin/core/storage_formats/pandas/query_compiler.py +++ b/modin/core/storage_formats/pandas/query_compiler.py @@ -459,6 +459,13 @@ def merge(self, right, **kwargs): kwargs["sort"] = False + # Want to ensure that these are python lists + if left_on is not None and right_on is not None: + left_on = list(left_on) if is_list_like(left_on) else [left_on] + right_on = list(right_on) if is_list_like(right_on) else [right_on] + elif on is not None: + on = list(on) if is_list_like(on) else [on] + def map_func(left, right=right, kwargs=kwargs): return pandas.merge(left, right, **kwargs) @@ -471,34 +478,53 @@ def map_func(left, right=right, kwargs=kwargs): keep_partitioning=False, ) ) - is_reset_index = True - if left_on and right_on: - left_on = left_on if is_list_like(left_on) else [left_on] - right_on = right_on if is_list_like(right_on) else [right_on] - is_reset_index = ( - False - if any(o in new_self.index.names for o in left_on) - and any(o in right.index.names for o in right_on) - else True - ) - if sort: + + # Here we want to understand whether we're joining on a column or on an index level. + # It's cool if indexes are already materialized so we can easily check that, if not + # it's fine too, we can also decide that by columns, which tend to be already + # materialized quite often compared to the indexes. + keep_index = False + if self._modin_frame._index_cache is not None: + if left_on is not None and right_on is not None: + keep_index = any( + o in self.index.names + and o in right_on + and o in right.index.names + for o in left_on + ) + elif on is not None: + keep_index = any( + o in self.index.names and o in right.index.names for o in on + ) + else: + # Have to trigger columns materialization. Hope they're already available at this point. + if left_on is not None and right_on is not None: + keep_index = any( + o not in right.columns + and o in left_on + and o not in self.columns + for o in right_on + ) + elif on is not None: + keep_index = any( + o not in right.columns and o not in self.columns for o in on + ) + + if sort: + if left_on is not None and right_on is not None: new_self = ( - new_self.sort_rows_by_column_values(left_on.append(right_on)) - if is_reset_index - else new_self.sort_index(axis=0, level=left_on.append(right_on)) + new_self.sort_index(axis=0, level=left_on + right_on) + if keep_index + else new_self.sort_rows_by_column_values(left_on + right_on) ) - if on: - on = on if is_list_like(on) else [on] - is_reset_index = not any( - o in new_self.index.names and o in right.index.names for o in on - ) - if sort: + elif on is not None: new_self = ( - new_self.sort_rows_by_column_values(on) - if is_reset_index - else new_self.sort_index(axis=0, level=on) + new_self.sort_index(axis=0, level=on) + if keep_index + else new_self.sort_rows_by_column_values(on) ) - return new_self.reset_index(drop=True) if is_reset_index else new_self + + return new_self if keep_index else new_self.reset_index(drop=True) else: return self.default_to_pandas(pandas.DataFrame.merge, right, **kwargs) diff --git a/modin/pandas/test/dataframe/test_join_sort.py b/modin/pandas/test/dataframe/test_join_sort.py index aadec34aca0..ab3241b8824 100644 --- a/modin/pandas/test/dataframe/test_join_sort.py +++ b/modin/pandas/test/dataframe/test_join_sort.py @@ -329,6 +329,79 @@ def test_merge(test_data, test_data2): modin_df.merge("Non-valid type") +@pytest.mark.parametrize("has_index_cache", [True, False]) +def test_merge_on_index(has_index_cache): + modin_df1, pandas_df1 = create_test_dfs( + { + "idx_key1": [1, 2, 3, 4], + "idx_key2": [2, 3, 4, 5], + "idx_key3": [3, 4, 5, 6], + "data_col1": [10, 2, 3, 4], + "col_key1": [3, 4, 5, 6], + "col_key2": [3, 4, 5, 6], + } + ) + + modin_df1 = modin_df1.set_index(["idx_key1", "idx_key2"]) + pandas_df1 = pandas_df1.set_index(["idx_key1", "idx_key2"]) + + modin_df2, pandas_df2 = create_test_dfs( + { + "idx_key1": [4, 3, 2, 1], + "idx_key2": [5, 4, 3, 2], + "idx_key3": [6, 5, 4, 3], + "data_col2": [10, 2, 3, 4], + "col_key1": [6, 5, 4, 3], + "col_key2": [6, 5, 4, 3], + } + ) + + modin_df2 = modin_df2.set_index(["idx_key2", "idx_key3"]) + pandas_df2 = pandas_df2.set_index(["idx_key2", "idx_key3"]) + + def setup_cache(): + if has_index_cache: + modin_df1.index # triggering index materialization + modin_df2.index + assert modin_df1._query_compiler._modin_frame._index_cache is not None + assert modin_df2._query_compiler._modin_frame._index_cache is not None + else: + # Propagate deferred indices to partitions + modin_df1._query_compiler._modin_frame._propagate_index_objs(axis=0) + modin_df1._query_compiler._modin_frame._index_cache = None + modin_df2._query_compiler._modin_frame._propagate_index_objs(axis=0) + modin_df2._query_compiler._modin_frame._index_cache = None + + for on in ( + ["col_key1", "idx_key1"], + ["col_key1", "idx_key2"], + ["col_key1", "idx_key3"], + ["idx_key1"], + ["idx_key2"], + ["idx_key3"], + ): + setup_cache() + eval_general( + (modin_df1, modin_df2), + (pandas_df1, pandas_df2), + lambda dfs: dfs[0].merge(dfs[1], on=on), + ) + + for left_on, right_on in ( + (["idx_key1"], ["col_key1"]), + (["col_key1"], ["idx_key3"]), + (["idx_key1"], ["idx_key3"]), + (["idx_key2"], ["idx_key2"]), + (["col_key1", "idx_key2"], ["col_key2", "idx_key2"]), + ): + setup_cache() + eval_general( + (modin_df1, modin_df2), + (pandas_df1, pandas_df2), + lambda dfs: dfs[0].merge(dfs[1], left_on=left_on, right_on=right_on), + ) + + @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize( "ascending", bool_arg_values, ids=arg_keys("ascending", bool_arg_keys)