Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Add dot product binary op #8909

Merged
merged 20 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
aaf84d7
First pass at adding dot binop
charlesbluca Jul 30, 2021
ba17655
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 12, 2021
d88d24d
Add support for Sequences, move implementation to Frame
charlesbluca Aug 12, 2021
9fff2bc
Add reference to dot() in docstrings
charlesbluca Aug 12, 2021
cb9b3b1
Remove dot product from missing func test
charlesbluca Aug 12, 2021
c896391
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 13, 2021
a6fd9c3
Fix incorrect reflect computation
charlesbluca Aug 17, 2021
7987136
Use item() to get the scalar from a 0-d array
charlesbluca Aug 17, 2021
c5de1fe
Improve docstring description for 'reflect'
charlesbluca Aug 17, 2021
042e7d6
Add more testing for dot, xfail null cases
charlesbluca Aug 17, 2021
a59a35d
Remove missing dot func case for indexes
charlesbluca Aug 17, 2021
6246cf3
Apply suggestions from code review
charlesbluca Aug 19, 2021
c16b393
Add support for pandas operands, update tests
charlesbluca Aug 19, 2021
f9df302
Add tests for array function dot product
charlesbluca Aug 20, 2021
e7f38e1
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 31, 2021
b023413
Special case for np.dot of two dataframes
charlesbluca Aug 31, 2021
05c18c9
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Aug 31, 2021
0bb9ca1
Merge remote-tracking branch 'upstream/branch-21.10' into add-dot-op
charlesbluca Sep 7, 2021
971ce33
Add Python link to reverse operators for reflect docstrings
charlesbluca Sep 7, 2021
f9a9200
Merge remote-tracking branch 'upstream/branch-22.10' into add-dot-op
charlesbluca Sep 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def binary_operator(
for lists concatenation functions

reflect : boolean, default False
If ``reflect`` is ``True``, swap the order of
the operands.
If ``True``, the operation is reflected (i.e. ``other`` is used as
the left operand instead of the right). This is enabled when using
a binary operation with a left operand that does not implement it.

Returns
-------
Expand Down
77 changes: 49 additions & 28 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,13 @@ def __array_function__(self, func, types, args, kwargs):
# Handle case if cudf_func is same as numpy function
if cudf_func is func:
return NotImplemented
# numpy returns an array from the dot product of two dataframes
elif (
func is np.dot
and isinstance(args[0], (DataFrame, pd.DataFrame))
and isinstance(args[1], (DataFrame, pd.DataFrame))
):
return cudf_func(*args, **kwargs).values
else:
return cudf_func(*args, **kwargs)
else:
Expand Down Expand Up @@ -1657,8 +1664,9 @@ def add(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `radd`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1803,8 +1811,9 @@ def radd(self, other, axis=1, level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `add`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1856,8 +1865,9 @@ def sub(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rsub`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1909,8 +1919,9 @@ def rsub(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `sub`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -1967,8 +1978,9 @@ def mul(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rmul`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2022,8 +2034,9 @@ def rmul(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `mul`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2077,8 +2090,9 @@ def mod(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rmod`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2130,8 +2144,9 @@ def rmod(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `mod`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2183,8 +2198,9 @@ def pow(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rpow`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2236,8 +2252,9 @@ def rpow(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `pow`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2289,8 +2306,9 @@ def floordiv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rfloordiv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2342,8 +2360,9 @@ def rfloordiv(self, other, axis="columns", level=None, fill_value=None):
a fill_value for missing data in one of the inputs. With reverse
version, `floordiv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2405,8 +2424,9 @@ def truediv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `rtruediv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down Expand Up @@ -2466,8 +2486,9 @@ def rtruediv(self, other, axis="columns", level=None, fill_value=None):
fill_value for missing data in one of the inputs. With reverse
version, `truediv`.

Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`) to
arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`.
Among flexible wrappers (`add`, `sub`, `mul`, `div`, `mod`, `pow`,
`dot`) to arithmetic operators: `+`, `-`, `*`, `/`, `//`, `%`, `**`,
`@`.

Parameters
----------
Expand Down
71 changes: 69 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3406,6 +3406,66 @@ def _colwise_binop(

return output

def dot(self, other, reflect=False):
    """
    Compute the matrix/dot product of this frame with ``other``.

    This is the flexible-wrapper counterpart of the ``@`` operator, in
    the same family as `add`, `sub`, `mul`, `div`, `mod` and `pow` for
    `+`, `-`, `*`, `/`, `//`, `%` and `**`.

    Parameters
    ----------
    other : Sequence, Series, or DataFrame
        The second operand. Frames contribute their ``.values``; cupy
        arrays are used as-is; sequences, numpy arrays and pandas
        Series/DataFrames are converted with ``cupy.asarray``.
    reflect : bool, default False
        If ``True``, the operation is reflected (i.e. ``other`` is used
        as the left operand instead of the right). This is enabled when
        using a binary operation with a left operand that does not
        implement it.

    Returns
    -------
    scalar, Series, or DataFrame
        A Series for a 1-D product, a DataFrame for a 2-D product, and
        a plain scalar otherwise. ``NotImplemented`` is returned when
        ``other`` is of an unsupported type.

    Examples
    --------
    >>> import cudf
    >>> df = cudf.DataFrame([[1, 2, 3, 4],
    ...                      [5, 6, 7, 8]])
    >>> df @ df.T
        0    1
    0  30   70
    1  70  174
    >>> s = cudf.Series([1, 1, 1, 1])
    >>> df @ s
    0    10
    1    26
    dtype: int64
    >>> [1, 2, 3, 4] @ s
    10
    """
    # NOTE: both operands are reduced to bare device arrays, so the
    # result is built from the raw product rather than from the
    # operands' labels.
    own_values = self.values

    if isinstance(other, Frame):
        other_values = other.values
    elif isinstance(other, cupy.ndarray):
        other_values = other
    elif isinstance(
        other, (abc.Sequence, np.ndarray, pd.DataFrame, pd.Series)
    ):
        other_values = cupy.asarray(other)
    else:
        # Let Python try the other operand's implementation (or raise).
        return NotImplemented

    left, right = (
        (other_values, own_values) if reflect else (own_values, other_values)
    )
    product = left.dot(right)

    # Pick the container that matches the dimensionality of the product.
    rank = len(product.shape)
    if rank == 1:
        return cudf.Series(product)
    if rank == 2:
        return cudf.DataFrame(product)
    # 0-d array: unwrap to a host scalar.
    return product.item()

# Binary arithmetic operations.
def __add__(self, other):
    """Implement ``self + other`` by delegating to ``_binaryop("add")``."""
    return self._binaryop(other, "add")
Expand All @@ -3419,6 +3479,12 @@ def __sub__(self, other):
def __rsub__(self, other):
    """Implement reflected subtraction via ``_binaryop("sub", reflect=True)``."""
    return self._binaryop(other, "sub", reflect=True)

def __matmul__(self, other):
    """Implement ``self @ other`` by delegating to ``dot``."""
    return self.dot(other)

def __rmatmul__(self, other):
    """Implement ``other @ self`` by calling ``dot`` with ``reflect=True``."""
    return self.dot(other, reflect=True)

def __mul__(self, other):
    """Implement ``self * other`` by delegating to ``_binaryop("mul")``."""
    return self._binaryop(other, "mul")

Expand Down Expand Up @@ -4923,8 +4989,9 @@ def _make_operands_for_binop(
The value to replace null values with. If ``None``, nulls are not
filled before the operation.
reflect : bool, default False
If ``True`` the operation is reflected (i.e whether to swap the
left and right operands).
If ``True``, the operation is reflected (i.e. ``other`` is used as
the left operand instead of the right). This is enabled when using
a binary operation with a left operand that does not implement it.

Returns
-------
Expand Down
11 changes: 7 additions & 4 deletions python/cudf/cudf/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def test_array_func_cudf_series(np_ar, func):
)
@pytest.mark.parametrize(
"func",
[lambda x: np.mean(x), lambda x: np.sum(x), lambda x: np.var(x, ddof=1)],
[
lambda x: np.mean(x),
lambda x: np.sum(x),
lambda x: np.var(x, ddof=1),
lambda x: np.dot(x, x.transpose()),
],
)
def test_array_func_cudf_dataframe(pd_df, func):
cudf_df = cudf.from_pandas(pd_df)
Expand All @@ -60,7 +65,6 @@ def test_array_func_cudf_dataframe(pd_df, func):
"func",
[
lambda x: np.cov(x, x),
lambda x: np.dot(x, x),
vyasr marked this conversation as resolved.
Show resolved Hide resolved
lambda x: np.linalg.norm(x),
lambda x: np.linalg.det(x),
],
Expand All @@ -74,7 +78,7 @@ def test_array_func_missing_cudf_dataframe(pd_df, func):
# we only implement sum among all numpy non-ufuncs
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.parametrize("np_ar", [np.random.random(100)])
@pytest.mark.parametrize("func", [lambda x: np.sum(x)])
@pytest.mark.parametrize("func", [lambda x: np.sum(x), lambda x: np.dot(x, x)])
def test_array_func_cudf_index(np_ar, func):
cudf_index = cudf.core.index.as_index(cudf.Series(np_ar))
expect = func(np_ar)
Expand All @@ -88,7 +92,6 @@ def test_array_func_cudf_index(np_ar, func):
"func",
[
lambda x: np.cov(x, x),
lambda x: np.dot(x, x),
lambda x: np.linalg.norm(x),
lambda x: np.linalg.det(x),
],
Expand Down
40 changes: 40 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2915,3 +2915,43 @@ def test_empty_column(binop, data, scalar):
expected = binop(pdf, scalar)

utils.assert_eq(expected, got)


@pytest.mark.parametrize(
    "df",
    [
        cudf.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]]),
        pytest.param(
            cudf.DataFrame([[1, None, None, 4], [5, 6, 7, None]]),
            marks=pytest.mark.xfail(
                reason="Cannot access Frame.values if frame contains nulls"
            ),
        ),
        cudf.DataFrame([[1.2, 2.3, 3.4, 4.5], [5.6, 6.7, 7.8, 8.9]]),
        cudf.Series([14, 15, 16, 17]),
        cudf.Series([14.15, 15.16, 16.17, 17.18]),
    ],
)
@pytest.mark.parametrize(
    "other",
    [
        cudf.DataFrame([[9, 10], [11, 12], [13, 14], [15, 16]]),
        cudf.DataFrame(
            [[9.4, 10.5], [11.6, 12.7], [13.8, 14.9], [15.1, 16.2]]
        ),
        cudf.Series([5, 6, 7, 8]),
        cudf.Series([5.6, 6.7, 7.8, 8.9]),
        pd.DataFrame([[9, 10], [11, 12], [13, 14], [15, 16]]),
        pd.Series([5, 6, 7, 8]),
        np.array([5, 6, 7, 8]),
        [25.5, 26.6, 27.7, 28.8],
    ],
)
def test_binops_dot(df, other):
    """Check that ``@`` on cudf objects agrees with the pandas result."""
    pandas_left = df.to_pandas()

    # Only cudf operands need converting; pandas/numpy/list operands are
    # passed to pandas unchanged.
    if hasattr(other, "to_pandas"):
        pandas_right = other.to_pandas()
    else:
        pandas_right = other

    utils.assert_eq(pandas_left @ pandas_right, df @ other)