Skip to content

Commit

Permalink
Allow merging index column with data column using keyword "on" (rapid…
Browse files Browse the repository at this point in the history
  • Loading branch information
skirui-source authored and shwina committed Apr 7, 2021
1 parent 9fa8679 commit 6338e9c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 15 deletions.
45 changes: 30 additions & 15 deletions python/cudf/cudf/core/join/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ def perform_merge(self) -> Frame:

def _compute_join_keys(self):
# Computes self._keys
left_keys = []
right_keys = []
if (
self.left_index
or self.right_index
or self.left_on
or self.right_on
):
left_keys = []
right_keys = []
if self.left_index:
left_keys.extend(
[
Expand Down Expand Up @@ -234,14 +234,25 @@ def _compute_join_keys(self):
for on in _coerce_to_tuple(self.right_on)
]
)
elif self.on:
on_names = _coerce_to_tuple(self.on)
for on in on_names:
# If `on` is provided, Merge on columns if present,
# otherwise default to indexes.
if on in self.lhs._data:
left_keys.append(_Indexer(name=on, column=True))
else:
left_keys.append(_Indexer(name=on, index=True))
if on in self.rhs._data:
right_keys.append(_Indexer(name=on, column=True))
else:
right_keys.append(_Indexer(name=on, index=True))

else:
# Use `on` if provided. Otherwise,
# implicitly use identically named columns as the key columns:
on_names = (
_coerce_to_tuple(self.on)
if self.on is not None
else set(self.lhs._data) & set(self.rhs._data)
)
# if `on` is not provided and we're not merging
# index with column or on both indexes, then use
# the intersection of columns in both frames
on_names = set(self.lhs._data) & set(self.rhs._data)
left_keys = [_Indexer(name=on, column=True) for on in on_names]
right_keys = [_Indexer(name=on, column=True) for on in on_names]

Expand Down Expand Up @@ -384,12 +395,16 @@ def _validate_merge_params(
if how not in {"left", "inner", "outer", "leftanti", "leftsemi"}:
raise NotImplementedError(f"{how} merge not supported yet")

# Passing 'on' with 'left_on' or 'right_on' is ambiguous
if on and (left_on or right_on):
raise ValueError(
'Can only pass argument "on" OR "left_on" '
'and "right_on", not a combination of both.'
)
if on:
if left_on or right_on:
# Passing 'on' with 'left_on' or 'right_on' is ambiguous
raise ValueError(
'Can only pass argument "on" OR "left_on" '
'and "right_on", not a combination of both.'
)
else:
# the validity of 'on' being checked by _Indexer
return

# Can't merge on unnamed Series
if (isinstance(lhs, cudf.Series) and not lhs.name) or (
Expand Down
53 changes: 53 additions & 0 deletions python/cudf/cudf/tests/test_joining.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,3 +1869,56 @@ def test_join_renamed_index():
)
got = df.merge(df, left_index=True, right_index=True, how="inner")
assert_join_results_equal(expect, got, how="inner")


@pytest.mark.parametrize(
"lhs_col, lhs_idx, rhs_col, rhs_idx, on",
[
(["A", "B"], "L0", ["B", "C"], "L0", ["B"]),
(["A", "B"], "L0", ["B", "C"], "L0", ["L0"]),
(["A", "B"], "L0", ["B", "C"], "L0", ["B", "L0"]),
(["A", "B"], "L0", ["C", "L0"], "A", ["A"]),
(["A", "B"], "L0", ["C", "L0"], "A", ["L0"]),
(["A", "B"], "L0", ["C", "L0"], "A", ["A", "L0"]),
],
)
@pytest.mark.parametrize(
"how", ["left", "inner", "right", "outer", "leftanti", "leftsemi"]
)
def test_join_merge_with_on(lhs_col, lhs_idx, rhs_col, rhs_idx, on, how):
lhs_data = {col_name: [4, 5, 6] for col_name in lhs_col}
lhs_index = cudf.Index([0, 1, 2], name=lhs_idx)

rhs_data = {col_name: [4, 5, 6] for col_name in rhs_col}
rhs_index = cudf.Index([2, 3, 4], name=rhs_idx)

gd_left = cudf.DataFrame(lhs_data, lhs_index)
gd_right = cudf.DataFrame(rhs_data, rhs_index)
pd_left = gd_left.to_pandas()
pd_right = gd_right.to_pandas()

expect = pd_left.merge(pd_right, on=on).sort_index(axis=1, ascending=False)
got = gd_left.merge(gd_right, on=on).sort_index(axis=1, ascending=False)

assert_join_results_equal(expect, got, how=how)


@pytest.mark.parametrize(
"on", ["A", "L0"],
)
@pytest.mark.parametrize(
"how", ["left", "inner", "right", "outer", "leftanti", "leftsemi"]
)
def test_join_merge_invalid_keys(on, how):
gd_left = cudf.DataFrame(
{"A": [1, 2, 3], "B": [4, 5, 6]}, index=cudf.Index([0, 1, 2], name="C")
)
gd_right = cudf.DataFrame(
{"D": [2, 3, 4], "E": [7, 8, 0]}, index=cudf.Index([0, 2, 4], name="F")
)
pd_left = gd_left.to_pandas()
pd_right = gd_right.to_pandas()

with pytest.raises(KeyError):
pd_left.merge(pd_right, on=on)
gd_left.merge(gd_right, on=on)

0 comments on commit 6338e9c

Please sign in to comment.