Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Add dot product binary op #8909

Merged
merged 20 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
aaf84d7
First pass at adding dot binop
charlesbluca Jul 30, 2021
ba17655
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 12, 2021
d88d24d
Add support for Sequences, move implementation to Frame
charlesbluca Aug 12, 2021
9fff2bc
Add reference to dot() in docstrings
charlesbluca Aug 12, 2021
cb9b3b1
Remove dot product from missing func test
charlesbluca Aug 12, 2021
c896391
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 13, 2021
a6fd9c3
Fix incorrect reflect computation
charlesbluca Aug 17, 2021
7987136
Use item() to get the scalar from a 0-d array
charlesbluca Aug 17, 2021
c5de1fe
Improve docstring description for 'reflect'
charlesbluca Aug 17, 2021
042e7d6
Add more testing for dot, xfail null cases
charlesbluca Aug 17, 2021
a59a35d
Remove missing dot func case for indexes
charlesbluca Aug 17, 2021
6246cf3
Apply suggestions from code review
charlesbluca Aug 19, 2021
c16b393
Add support for pandas operands, update tests
charlesbluca Aug 19, 2021
f9df302
Add tests for array function dot product
charlesbluca Aug 20, 2021
e7f38e1
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 31, 2021
b023413
Special case for np.dot of two dataframes
charlesbluca Aug 31, 2021
05c18c9
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 31, 2021
0bb9ca1
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Sep 7, 2021
971ce33
Add Python link to reverse operators for reflect docstrings
charlesbluca Sep 7, 2021
f9a9200
Merge remote-tracking branch 'upstream/branch-22.10' into add-dot-op
charlesbluca Sep 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,8 +1502,9 @@ def add(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `radd`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1656,8 +1657,9 @@ def radd(self, other, axis=1, level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `add`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1709,8 +1711,9 @@ def sub(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rsub`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1762,8 +1765,9 @@ def rsub(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `sub`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1820,8 +1824,9 @@ def mul(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rmul`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1875,8 +1880,9 @@ def rmul(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `mul`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1930,8 +1936,9 @@ def mod(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rmod`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1983,8 +1990,9 @@ def rmod(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `mod`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2036,8 +2044,9 @@ def pow(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rpow`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2089,8 +2098,9 @@ def rpow(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `pow`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2142,8 +2152,9 @@ def floordiv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rfloordiv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2195,8 +2206,9 @@ def rfloordiv(self, other, axis="columns", level=None, fill_value=None):
a fill_value for missing data in one of the inputs. With reverse
version, `floordiv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2258,8 +2270,9 @@ def truediv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rtruediv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2319,8 +2332,9 @@ def rtruediv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `truediv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down
64 changes: 64 additions & 0 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3594,6 +3594,64 @@ def _colwise_binop(

return output

def dot(self, other, reflect=False):
"""
Get dot product of frame and other, (binary operator `dot`).

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
other : sequence, Series, or DataFrame
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
Any multiple element data structure, or list-like object.
reflect : bool, default False
If ``True`` the operation is reflected (i.e whether to swap the
skirui-source marked this conversation as resolved.
Show resolved Hide resolved
left and right operands).

Returns
-------
scalar, Series, or DataFrame
The result of the operation.

Examples
--------
>>> import cudf
>>> df = cudf.DataFrame([[1, 2, 3, 4],
... [5, 6, 7, 8]])
>>> df @ df.T
0 1
0 30 70
1 70 174
>>> s = cudf.Series([1, 1, 1, 1])
>>> df @ s
0 10
1 26
dtype: int64
>>> [1, 2, 3, 4] @ s
array(10)
skirui-source marked this conversation as resolved.
Show resolved Hide resolved
"""
lhs = self.values
if isinstance(other, Frame):
rhs = other.values
elif isinstance(other, cupy.ndarray):
rhs = other
elif isinstance(other, abc.Sequence):
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
rhs = cupy.asarray(other)
else:
return NotImplemented
if reflect:
lhs, rhs = rhs.T, lhs.T

result = lhs.dot(rhs)
if len(result.shape) == 1:
result = cudf.Series(result)
elif len(result.shape) == 2:
result = cudf.DataFrame(result)

skirui-source marked this conversation as resolved.
Show resolved Hide resolved
return result

# Binary arithmetic operations.
def __add__(self, other):
return self._binaryop(other, "add")
Expand All @@ -3607,6 +3665,12 @@ def __sub__(self, other):
def __rsub__(self, other):
return self._binaryop(other, "sub", reflect=True)

def __matmul__(self, other):
return self.dot(other)

def __rmatmul__(self, other):
return self.dot(other, reflect=True)

def __mul__(self, other):
return self._binaryop(other, "mul")

Expand Down
1 change: 0 additions & 1 deletion python/cudf/cudf/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_array_func_cudf_dataframe(pd_df, func):
"func",
[
lambda x: np.cov(x, x),
lambda x: np.dot(x, x),
vyasr marked this conversation as resolved.
Show resolved Hide resolved
lambda x: np.linalg.norm(x),
lambda x: np.linalg.det(x),
],
Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,3 +2913,24 @@ def test_empty_column(binop, data, scalar):
expected = binop(pdf, scalar)

utils.assert_eq(expected, got)


@pytest.mark.parametrize(
"lhs",
[pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]), pd.Series([1, 1, 2, 1])],
skirui-source marked this conversation as resolved.
Show resolved Hide resolved
)
@pytest.mark.parametrize(
"rhs",
[
pd.DataFrame([[0, 1], [1, 2], [-1, -1], [2, 0]]),
pd.Series([1, 1, 2, 1]),
],
)
def test_binops_dot(lhs, rhs):
glhs = cudf.from_pandas(lhs)
grhs = cudf.from_pandas(rhs)

expected = lhs.dot(rhs)
got = glhs.dot(grhs)

utils.assert_eq(expected, got)