From 296eddc84adeb1236fa10d5c628f44453986c384 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Fri, 6 Aug 2021 05:39:30 -0700 Subject: [PATCH] address review --- python/cudf/cudf/core/frame.py | 17 ++++++++--------- python/cudf/cudf/tests/test_interpolate.py | 5 ++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 8d4b66bcbc7..cbd92920b33 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1490,28 +1490,27 @@ def interpolate( ) data = self - columns = {} if not isinstance(data._index, cudf.RangeIndex): perm_sort = data._index.argsort() data = data._gather(perm_sort) interpolator = cudf.core.algorithms.get_column_interpolator(method) + columns = {} for colname, col in data._data.items(): if col.nullable: col = col.astype("float64").fillna(np.nan) # Interpolation methods may or may not need the index - result = interpolator(col, index=data._index) - columns[colname] = result - - result = self.__class__(ColumnAccessor(columns), index=data._index) + columns[colname] = interpolator(col, index=data._index) - if not isinstance(data._index, cudf.RangeIndex): - # that which was once sorted, now is not - result = result._gather(perm_sort.argsort()) + result = self._from_data(columns, index=data._index) - return result + return ( + result + if isinstance(data._index, cudf.RangeIndex) + else result._gather(perm_sort.argsort()) + ) def _quantiles( self, diff --git a/python/cudf/cudf/tests/test_interpolate.py b/python/cudf/cudf/tests/test_interpolate.py index e9b9e03891e..66556c48828 100644 --- a/python/cudf/cudf/tests/test_interpolate.py +++ b/python/cudf/cudf/tests/test_interpolate.py @@ -16,7 +16,10 @@ @pytest.mark.parametrize("method", ["linear"]) @pytest.mark.parametrize("axis", [0]) def test_interpolate_dataframe(data, method, axis): - # doesn't seem to work with NAs just yet + # Pandas interpolate methods do not seem to work + # with nullable dtypes yet, so this method treats + # NAs as NaNs + # https://github.com/pandas-dev/pandas/issues/40252 gdf = cudf.DataFrame(data) pdf = gdf.to_pandas()