From 431d3171eb19514c995bddd5000162d86cad2d39 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Sun, 7 Feb 2021 22:51:31 -0600 Subject: [PATCH 1/2] handle cupy array in setitem of dataframe --- python/cudf/cudf/core/dataframe.py | 11 ++++++++++- python/cudf/cudf/tests/test_dataframe.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 8e4e54cafdb..a78e43aace8 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -777,7 +777,16 @@ def __setitem__(self, arg, value): self.insert(len(self._data), arg, value) elif isinstance( - arg, (list, np.ndarray, pd.Series, Series, Index, pd.Index) + arg, + ( + list, + np.ndarray, + cupy.ndarray, + pd.Series, + Series, + Index, + pd.Index, + ), ): mask = arg if isinstance(mask, list): diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 39d2a980e87..d8005911fcd 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8315,3 +8315,15 @@ def test_dataframe_roundtrip_arrow_struct_dtype(gdf): expected = gd.DataFrame.from_arrow(table) assert_eq(gdf, expected) + + +def test_dataframe_setitem_cupy_array(): + np.random.seed(0) + pdf = pd.DataFrame(np.random.randn(10, 2)) + gdf = gd.from_pandas(pdf) + + gpu_array = cupy.array([True, False] * 5) + pdf[gpu_array.get()] = 1.5 + gdf[gpu_array] = 1.5 + + assert_eq(pdf, gdf) From 71ba7e5755da08c5159ebb68b3b6d2d9d27e6a92 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Mon, 8 Feb 2021 09:13:23 -0800 Subject: [PATCH 2/2] use utility functions --- python/cudf/cudf/core/dataframe.py | 31 +++++------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index a78e43aace8..59e7a0a7a8a 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -39,6 +39,7 @@ from cudf.utils import applyutils, docutils, ioutils, queryutils, utils from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( + can_convert_to_column, cudf_dtype_from_pydata_dtype, find_common_type, is_categorical_dtype, @@ -683,20 +684,9 @@ def __getitem__(self, arg): elif isinstance(arg, slice): return self._slice(arg) - elif isinstance( - arg, - ( - list, - cupy.ndarray, - np.ndarray, - pd.Series, - Series, - Index, - pd.Index, - ), - ): + elif can_convert_to_column(arg): mask = arg - if isinstance(mask, list): + if is_list_like(mask): mask = pd.Series(mask) if mask.dtype == "bool": return self._apply_boolean_mask(mask) @@ -776,20 +766,9 @@ def __setitem__(self, arg, value): # pandas raises key error here self.insert(len(self._data), arg, value) - elif isinstance( - arg, - ( - list, - np.ndarray, - cupy.ndarray, - pd.Series, - Series, - Index, - pd.Index, - ), - ): + elif can_convert_to_column(arg): mask = arg - if isinstance(mask, list): + if is_list_like(mask): mask = np.array(mask) if mask.dtype == "bool":