From 364abe13060aa2dede1d24a92518e5f789545b26 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Mon, 9 Aug 2021 16:32:30 -0700 Subject: [PATCH] fix dataframe setitem --- python/cudf/cudf/core/indexing.py | 25 ++++++++++++++-- python/cudf/cudf/tests/test_dataframe.py | 37 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/indexing.py b/python/cudf/cudf/core/indexing.py index a4a69a4e084..7cd48fafd45 100755 --- a/python/cudf/cudf/core/indexing.py +++ b/python/cudf/cudf/core/indexing.py @@ -431,7 +431,7 @@ def _setitem_tuple_arg(self, key, value): ) try: - columns = self._get_column_selection(key[1]) + columns_df = self._get_column_selection(key[1]) except KeyError: if not self._df.empty and isinstance(key[0], slice): pos_range = get_label_range_or_mask( @@ -456,8 +456,27 @@ def _setitem_tuple_arg(self, key, value): ) self._df._data.insert(key[1], new_col) else: - for col in columns: - self._df[col].loc[key[0]] = value + if isinstance(value, (cp.ndarray, np.ndarray)): + value_df = cudf.DataFrame(value) + if value_df.shape[1] != columns_df.shape[1]: + if value_df.shape[1] == 1: + value_cols = ( + value_df._data.columns * columns_df.shape[1] + ) + else: + raise ValueError( + f"shape mismatch: value array of shape " + f"{value_df.shape} could not be " + f"broadcast to indexing result of shape " + f"{columns_df.shape}" + ) + else: + value_cols = value_df._data.columns + for i, col in enumerate(columns_df._column_names): + self._df[col].loc[key[0]] = value_cols[i] + else: + for col in columns_df._column_names: + self._df[col].loc[key[0]] = value def _get_column_selection(self, arg): return self._df._get_columns_by_label(arg) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 76d24dcd5d2..de416546db1 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8733,3 +8733,40 @@ def test_frame_series_where(): expected = gdf.where(gdf.notna(), gdf.mean()) actual = pdf.where(pdf.notna(), pdf.mean(), axis=1) assert_eq(expected, actual) + + +@pytest.mark.parametrize( + "array,is_error", + [ + (cupy.arange(20, 40).reshape(-1, 2), False), + (cupy.arange(20, 50).reshape(-1, 3), True), + (np.arange(20, 40).reshape(-1, 2), False), + (np.arange(20, 30).reshape(-1, 1), False), + (cupy.arange(20, 30).reshape(-1, 1), False), + ], +) +def test_dataframe_indexing_setitem_np_cp_array(array, is_error): + gdf = cudf.DataFrame({"a": range(10), "b": range(10)}) + pdf = gdf.to_pandas() + if not is_error: + gdf.loc[:, ["a", "b"]] = array + pdf.loc[:, ["a", "b"]] = cupy.asnumpy(array) + + assert_eq(gdf, pdf) + else: + assert_exceptions_equal( + lfunc=pdf.loc.__setitem__, + rfunc=gdf.loc.__setitem__, + lfunc_args_and_kwargs=( + [(slice(None, None, None), ["a", "b"]), cupy.asnumpy(array)], + {}, + ), + rfunc_args_and_kwargs=( + [(slice(None, None, None), ["a", "b"]), array], + {}, + ), + compare_error_message=False, + expected_error_message="shape mismatch: value array of shape " + "(10, 3) could not be broadcast to indexing " + "result of shape (10, 2)", + )