From c76ec131dd5eb2517de90a23aaa328f5e9c7b064 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Tue, 17 Aug 2021 18:14:34 -0500 Subject: [PATCH] Fix `Dataframe` indexer setitem when array is passed (#9006) Fixes: #8672 This PR handles `ndarray` inputs in the `_DataFrameLocIndexer.__setitem__` Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Marlene (https://github.com/marlenezw) URL: https://github.com/rapidsai/cudf/pull/9006 --- python/cudf/cudf/core/indexing.py | 25 ++++++++++++++-- python/cudf/cudf/tests/test_dataframe.py | 37 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/indexing.py b/python/cudf/cudf/core/indexing.py index 09cfc6e144a..da999f13fa8 100755 --- a/python/cudf/cudf/core/indexing.py +++ b/python/cudf/cudf/core/indexing.py @@ -432,7 +432,7 @@ def _setitem_tuple_arg(self, key, value): ) try: - columns = self._get_column_selection(key[1]) + columns_df = self._get_column_selection(key[1]) except KeyError: if not self._df.empty and isinstance(key[0], slice): pos_range = get_label_range_or_mask( @@ -457,8 +457,27 @@ def _setitem_tuple_arg(self, key, value): ) self._df._data.insert(key[1], new_col) else: - for col in columns: - self._df[col].loc[key[0]] = value + if isinstance(value, (cp.ndarray, np.ndarray)): + value_df = cudf.DataFrame(value) + if value_df.shape[1] != columns_df.shape[1]: + if value_df.shape[1] == 1: + value_cols = ( + value_df._data.columns * columns_df.shape[1] + ) + else: + raise ValueError( + f"shape mismatch: value array of shape " + f"{value_df.shape} could not be " + f"broadcast to indexing result of shape " + f"{columns_df.shape}" + ) + else: + value_cols = value_df._data.columns + for i, col in enumerate(columns_df._column_names): + self._df[col].loc[key[0]] = value_cols[i] + else: + for col in columns_df._column_names: + self._df[col].loc[key[0]] = value def _get_column_selection(self, arg): return self._df._get_columns_by_label(arg) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index a37a80236c1..a337660b5b0 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8759,6 +8759,43 @@ def test_frame_series_where(): assert_eq(expected, actual) +@pytest.mark.parametrize( + "array,is_error", + [ + (cupy.arange(20, 40).reshape(-1, 2), False), + (cupy.arange(20, 50).reshape(-1, 3), True), + (np.arange(20, 40).reshape(-1, 2), False), + (np.arange(20, 30).reshape(-1, 1), False), + (cupy.arange(20, 30).reshape(-1, 1), False), + ], +) +def test_dataframe_indexing_setitem_np_cp_array(array, is_error): + gdf = cudf.DataFrame({"a": range(10), "b": range(10)}) + pdf = gdf.to_pandas() + if not is_error: + gdf.loc[:, ["a", "b"]] = array + pdf.loc[:, ["a", "b"]] = cupy.asnumpy(array) + + assert_eq(gdf, pdf) + else: + assert_exceptions_equal( + lfunc=pdf.loc.__setitem__, + rfunc=gdf.loc.__setitem__, + lfunc_args_and_kwargs=( + [(slice(None, None, None), ["a", "b"]), cupy.asnumpy(array)], + {}, + ), + rfunc_args_and_kwargs=( + [(slice(None, None, None), ["a", "b"]), array], + {}, + ), + compare_error_message=False, + expected_error_message="shape mismatch: value array of shape " + "(10, 3) could not be broadcast to indexing " + "result of shape (10, 2)", + ) + + @pytest.mark.parametrize( "data", [{"a": [1, 2, 3], "b": [1, 1, 0]}], )