Skip to content

Commit

Permalink
Add support for null and non-numeric types in Series.diff and DataFra…
Browse files Browse the repository at this point in the history
…me.diff (#10625)

This PR adds support for null values and non-numeric data types (timestamps and datetime ranges) in `Series.diff` and `DataFrame.diff`. `DataFrame.diff` could already handle datetime ranges because `DataFrame.shift` works with them. `Series.diff`, however, did not use the `Series.shift` implementation, so it had no support for datetime ranges.
```python
import datetime
dti = pd.to_datetime(
    ["1/1/2018", np.datetime64("2018-01-01"), datetime.datetime(2018, 1, 1), datetime.datetime(2020, 1, 1)]
)
df = DataFrame({"dates": dti})
df.diff(periods=periods, axis=axis)
```
closes #10212.

Authors:
  - Matthew Murray (https://github.com/Matt711)
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)

URL: #10625
  • Loading branch information
Matt711 authored Apr 15, 2022
1 parent d5a982b commit 94a5d41
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 68 deletions.
5 changes: 0 additions & 5 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2660,11 +2660,6 @@ def diff(self, periods=1, axis=0):
if axis != 0:
raise NotImplementedError("Only axis=0 is supported.")

if not all(is_numeric_dtype(i) for i in self.dtypes):
raise NotImplementedError(
"DataFrame.diff only supports numeric dtypes"
)

if abs(periods) > len(self):
df = cudf.DataFrame._from_data(
{
Expand Down
52 changes: 17 additions & 35 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import numpy as np
import pandas as pd
from pandas._config import get_option
from pandas.core.dtypes.common import is_float

import cudf
from cudf import _lib as libcudf
from cudf._lib.scalar import _is_null_host_scalar
from cudf._lib.transform import bools_to_mask
from cudf._typing import ColumnLike, DataFrameOrSeries, ScalarLike
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
Expand All @@ -42,7 +42,6 @@
arange,
as_column,
column,
column_empty_like,
full,
)
from cudf.core.column.categorical import (
Expand All @@ -64,7 +63,7 @@
)
from cudf.core.single_column_frame import SingleColumnFrame
from cudf.core.udf.scalar_function import _get_scalar_kernel
from cudf.utils import cudautils, docutils
from cudf.utils import docutils
from cudf.utils.docutils import copy_docstring
from cudf.utils.dtypes import (
can_convert_to_column,
Expand Down Expand Up @@ -2969,19 +2968,22 @@ def digitize(self, bins, right=False):

@_cudf_nvtx_annotate
def diff(self, periods=1):
"""Calculate the difference between values at positions i and i - N in
an array and store the output in a new array.
"""First discrete difference of element.
Calculates the difference of a Series element compared with another
element in the Series (default is element in previous row).
Parameters
----------
periods : int, default 1
Periods to shift for calculating difference,
accepts negative values.
Returns
-------
Series
First differences of the Series.
Notes
-----
Diff currently only supports float and integer dtype columns with
no null values.
Examples
--------
>>> import cudf
Expand Down Expand Up @@ -3028,32 +3030,12 @@ def diff(self, periods=1):
5 <NA>
dtype: int64
"""
if self.has_nulls:
raise AssertionError(
"Diff currently requires columns with no null values"
)

if not np.issubdtype(self.dtype, np.number):
raise NotImplementedError(
"Diff currently only supports numeric dtypes"
)

# TODO: move this libcudf
input_col = self._column
output_col = column_empty_like(input_col)
output_mask = column_empty_like(input_col, dtype="bool")
if output_col.size > 0:
cudautils.gpu_diff.forall(output_col.size)(
input_col, output_col, output_mask, periods
)

output_col = column.build_column(
data=output_col.data,
dtype=output_col.dtype,
mask=bools_to_mask(output_mask),
)
if not is_integer(periods):
if not (is_float(periods) and periods.is_integer()):
raise ValueError("periods must be an integer")
periods = int(periods)

return Series(output_col, name=self.name, index=self.index)
return self - self.shift(periods=periods)

@copy_docstring(SeriesGroupBy)
@_cudf_nvtx_annotate
Expand Down
28 changes: 21 additions & 7 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9098,7 +9098,7 @@ def test_groupby_cov_for_pandas_bug_case():
],
)
@pytest.mark.parametrize("periods", (-5, -1, 0, 1, 5))
def test_diff_dataframe_numeric_dtypes(data, periods):
def test_diff_numeric_dtypes(data, periods):
gdf = cudf.DataFrame(data)
pdf = gdf.to_pandas()

Expand Down Expand Up @@ -9137,7 +9137,7 @@ def test_diff_decimal_dtypes(precision, scale, dtype):
)


def test_diff_dataframe_invalid_axis():
def test_diff_invalid_axis():
gdf = cudf.DataFrame(np.array([1.123, 2.343, 5.890, 0.0]))
with pytest.raises(NotImplementedError, match="Only axis=0 is supported."):
gdf.diff(periods=1, axis=1)
Expand All @@ -9152,16 +9152,30 @@ def test_diff_dataframe_invalid_axis():
"string_col": ["a", "b", "c", "d", "e"],
},
["a", "b", "c", "d", "e"],
[np.nan, None, np.nan, None],
],
)
def test_diff_dataframe_non_numeric_dypes(data):
def test_diff_unsupported_dtypes(data):
gdf = cudf.DataFrame(data)
with pytest.raises(
NotImplementedError,
match="DataFrame.diff only supports numeric dtypes",
TypeError,
match=r"unsupported operand type\(s\)",
):
gdf.diff(periods=2, axis=0)
gdf.diff()


def test_diff_many_dtypes():
    """Compare DataFrame.diff between pandas and cudf on a mixed-dtype frame.

    The frame holds a datetime column, a boolean column, a float column
    containing a NaN, an integer column and a column made only of
    NaN/None. Both the default call and ``periods=2`` are compared.
    """
    columns = {
        "dates": pd.date_range("2020-01-01", "2020-01-06", freq="D"),
        "bools": [True, True, True, False, True, True],
        "floats": [1.0, 2.0, 3.5, np.nan, 5.0, -1.7],
        "ints": [1, 2, 3, 3, 4, 5],
        "nans_nulls": [np.nan, None, None, np.nan, np.nan, None],
    }
    host_frame = pd.DataFrame(columns)
    device_frame = cudf.from_pandas(host_frame)

    # First the default arguments, then an explicit two-row shift.
    for diff_kwargs in ({}, {"periods": 2}):
        assert_eq(
            host_frame.diff(**diff_kwargs),
            device_frame.diff(**diff_kwargs),
        )


def test_dataframe_assign_cp_np_array():
Expand Down
58 changes: 58 additions & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TIMEDELTA_TYPES,
assert_eq,
assert_exceptions_equal,
gen_rand,
)


Expand Down Expand Up @@ -1724,3 +1725,60 @@ def test_isin_categorical(data, values):
got = gsr.isin(values)
expected = psr.isin(values)
assert_eq(got, expected)


@pytest.mark.parametrize("dtype", NUMERIC_TYPES)
@pytest.mark.parametrize("period", [-1, -5, -10, -20, 0, 1, 5, 10, 20])
@pytest.mark.parametrize("data_empty", [False, True])
def test_diff(dtype, period, data_empty):
    """Check cudf ``Series.diff`` against pandas for numeric dtypes.

    Parameters
    ----------
    dtype
        Dtype used to build both the cudf and the pandas Series.
    period : int
        Value passed to ``diff`` on both Series.
    data_empty : bool
        When true the Series are built from ``None`` instead of
        generated data.
    """
    if data_empty:
        values = None
    elif dtype == np.int8:
        # A narrow value range keeps the data in range for int8.
        values = gen_rand(dtype, 100000, low=-2, high=2)
    else:
        values = gen_rand(dtype, 100000)

    expected = pd.Series(values, dtype=dtype).diff(period)
    # Cast the cudf result to the pandas result dtype before comparing.
    actual = cudf.Series(values, dtype=dtype).diff(period).astype(
        expected.dtype
    )

    # The index type check is skipped only for the empty-data case.
    compare_kwargs = {"check_index_type": False} if data_empty else {}
    assert_eq(actual, expected, **compare_kwargs)


@pytest.mark.parametrize(
    "data",
    [
        ["a", "b", "c", "d", "e"],
    ],
)
def test_diff_unsupported_dtypes(data):
    """Expect ``Series.diff`` on string data to raise ``TypeError``."""
    string_series = cudf.Series(data)
    expected_message = r"unsupported operand type\(s\)"
    with pytest.raises(TypeError, match=expected_message):
        string_series.diff()


@pytest.mark.parametrize(
    "data",
    [
        pd.date_range("2020-01-01", "2020-01-06", freq="D"),
        [True, True, True, False, True, True],
        [1.0, 2.0, 3.5, 4.0, 5.0, -1.7],
        [1, 2, 3, 3, 4, 5],
        [np.nan, None, None, np.nan, np.nan, None],
    ],
)
def test_diff_many_dtypes(data):
    """Compare ``Series.diff`` between pandas and cudf for several inputs.

    Covers a datetime range, booleans, floats, integers and an
    all-NaN/None list, with the default arguments and with ``periods=2``.
    """
    host_series = pd.Series(data)
    device_series = cudf.from_pandas(host_series)

    # First the default arguments, then an explicit two-row shift.
    for diff_kwargs in ({}, {"periods": 2}):
        assert_eq(
            host_series.diff(**diff_kwargs),
            device_series.diff(**diff_kwargs),
        )
21 changes: 0 additions & 21 deletions python/cudf/cudf/utils/cudautils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,6 @@
#


@cuda.jit
def gpu_diff(in_col, out_col, out_mask, N):
    """Calculate the difference between values at positions i and i - N in an
    array and store the output in a new array.

    Parameters
    ----------
    in_col : array
        Input values.
    out_col : array
        Receives ``in_col[i] - in_col[i - N]`` for every position ``i``
        whose partner index ``i - N`` lies inside ``in_col``.
    out_mask : array
        Set to True where ``out_col[i]`` was computed and to False where
        the partner index falls outside the array.
    N : int
        Offset between the two positions being subtracted; may be
        negative or zero.

    Notes
    -----
    Positions with no valid partner are only marked False in ``out_mask``;
    ``out_col`` is left untouched there. The previous implementation could
    index ``in_col`` (and, for ``N == 0``, ``out_col`` and ``out_mask``)
    one element past the end when ``i == in_col.size + N``, and read
    ``in_col`` at a negative index for ``i < N``.
    """
    i = cuda.grid(1)

    # Threads beyond the end of the column have nothing to do.
    if i >= in_col.size:
        return

    partner = i - N
    if partner >= 0 and partner < in_col.size:
        out_col[i] = in_col[i] - in_col[partner]
        out_mask[i] = True
    else:
        # No element N positions away: mark the slot invalid without
        # reading outside the input.
        out_mask[i] = False


# Find segments


Expand Down

0 comments on commit 94a5d41

Please sign in to comment.