From 6338e9c399e7c3bf6edb34c8fda77c20cc1aab31 Mon Sep 17 00:00:00 2001 From: skirui-source <71867292+skirui-source@users.noreply.github.com> Date: Thu, 1 Apr 2021 22:19:23 -0700 Subject: [PATCH] Allow merging index column with data column using keyword "on" (#7736) fixes #5014 replaces PR #7569 Authors: - https://github.com/skirui-source - Ashwin Srinath (https://github.com/shwina) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Michael Wang (https://github.com/isVoid) - Keith Kraus (https://github.com/kkraus14) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/7736 --- python/cudf/cudf/core/join/join.py | 45 ++++++++++++++-------- python/cudf/cudf/tests/test_joining.py | 53 ++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 15 deletions(-) diff --git a/python/cudf/cudf/core/join/join.py b/python/cudf/cudf/core/join/join.py index 1a4826d0570..3f5776b4ea4 100644 --- a/python/cudf/cudf/core/join/join.py +++ b/python/cudf/cudf/core/join/join.py @@ -196,14 +196,14 @@ def perform_merge(self) -> Frame: def _compute_join_keys(self): # Computes self._keys + left_keys = [] + right_keys = [] if ( self.left_index or self.right_index or self.left_on or self.right_on ): - left_keys = [] - right_keys = [] if self.left_index: left_keys.extend( [ @@ -234,14 +234,25 @@ def _compute_join_keys(self): for on in _coerce_to_tuple(self.right_on) ] ) + elif self.on: + on_names = _coerce_to_tuple(self.on) + for on in on_names: + # If `on` is provided, Merge on columns if present, + # otherwise default to indexes. + if on in self.lhs._data: + left_keys.append(_Indexer(name=on, column=True)) + else: + left_keys.append(_Indexer(name=on, index=True)) + if on in self.rhs._data: + right_keys.append(_Indexer(name=on, column=True)) + else: + right_keys.append(_Indexer(name=on, index=True)) + else: - # Use `on` if provided. Otherwise, - # implicitly use identically named columns as the key columns: - on_names = ( - _coerce_to_tuple(self.on) - if self.on is not None - else set(self.lhs._data) & set(self.rhs._data) - ) + # if `on` is not provided and we're not merging + # index with column or on both indexes, then use + # the intersection of columns in both frames + on_names = set(self.lhs._data) & set(self.rhs._data) left_keys = [_Indexer(name=on, column=True) for on in on_names] right_keys = [_Indexer(name=on, column=True) for on in on_names] @@ -384,12 +395,16 @@ def _validate_merge_params( if how not in {"left", "inner", "outer", "leftanti", "leftsemi"}: raise NotImplementedError(f"{how} merge not supported yet") - # Passing 'on' with 'left_on' or 'right_on' is ambiguous - if on and (left_on or right_on): - raise ValueError( - 'Can only pass argument "on" OR "left_on" ' - 'and "right_on", not a combination of both.' - ) + if on: + if left_on or right_on: + # Passing 'on' with 'left_on' or 'right_on' is ambiguous + raise ValueError( + 'Can only pass argument "on" OR "left_on" ' + 'and "right_on", not a combination of both.' + ) + else: + # the validity of 'on' being checked by _Indexer + return # Can't merge on unnamed Series if (isinstance(lhs, cudf.Series) and not lhs.name) or ( diff --git a/python/cudf/cudf/tests/test_joining.py b/python/cudf/cudf/tests/test_joining.py index 2dae2bf1e97..183385bacc1 100644 --- a/python/cudf/cudf/tests/test_joining.py +++ b/python/cudf/cudf/tests/test_joining.py @@ -1869,3 +1869,56 @@ def test_join_renamed_index(): ) got = df.merge(df, left_index=True, right_index=True, how="inner") assert_join_results_equal(expect, got, how="inner") + + +@pytest.mark.parametrize( + "lhs_col, lhs_idx, rhs_col, rhs_idx, on", + [ + (["A", "B"], "L0", ["B", "C"], "L0", ["B"]), + (["A", "B"], "L0", ["B", "C"], "L0", ["L0"]), + (["A", "B"], "L0", ["B", "C"], "L0", ["B", "L0"]), + (["A", "B"], "L0", ["C", "L0"], "A", ["A"]), + (["A", "B"], "L0", ["C", "L0"], "A", ["L0"]), + (["A", "B"], "L0", ["C", "L0"], "A", ["A", "L0"]), + ], +) +@pytest.mark.parametrize( + "how", ["left", "inner", "right", "outer", "leftanti", "leftsemi"] +) +def test_join_merge_with_on(lhs_col, lhs_idx, rhs_col, rhs_idx, on, how): + lhs_data = {col_name: [4, 5, 6] for col_name in lhs_col} + lhs_index = cudf.Index([0, 1, 2], name=lhs_idx) + + rhs_data = {col_name: [4, 5, 6] for col_name in rhs_col} + rhs_index = cudf.Index([2, 3, 4], name=rhs_idx) + + gd_left = cudf.DataFrame(lhs_data, lhs_index) + gd_right = cudf.DataFrame(rhs_data, rhs_index) + pd_left = gd_left.to_pandas() + pd_right = gd_right.to_pandas() + + expect = pd_left.merge(pd_right, on=on).sort_index(axis=1, ascending=False) + got = gd_left.merge(gd_right, on=on).sort_index(axis=1, ascending=False) + + assert_join_results_equal(expect, got, how=how) + + +@pytest.mark.parametrize( + "on", ["A", "L0"], +) +@pytest.mark.parametrize( + "how", ["left", "inner", "right", "outer", "leftanti", "leftsemi"] +) +def test_join_merge_invalid_keys(on, how): + gd_left = cudf.DataFrame( + {"A": [1, 2, 3], "B": [4, 5, 6]}, index=cudf.Index([0, 1, 2], name="C") + ) + gd_right = cudf.DataFrame( + {"D": [2, 3, 4], "E": [7, 8, 0]}, index=cudf.Index([0, 2, 4], name="F") + ) + pd_left = gd_left.to_pandas() + pd_right = gd_right.to_pandas() + + with pytest.raises(KeyError): + pd_left.merge(pd_right, on=on) + gd_left.merge(gd_right, on=on)