From 09abc44b52d31fb9f69c6790f1fc29a813ba5e8a Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Fri, 14 Jan 2022 09:27:28 -0800 Subject: [PATCH] fix dataframe setitem --- python/cudf/cudf/core/dataframe.py | 18 ++++++++++++++++-- python/cudf/cudf/tests/test_dataframe.py | 11 +++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 123f86cc200..747b6460ffc 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -1123,7 +1123,15 @@ def __setitem__(self, arg, value): for col_name in self._data: self._data[col_name][mask] = value else: - if isinstance(value, DataFrame): + if isinstance(value, (cupy.ndarray, np.ndarray)): + _setitem_with_dataframe( + input_df=self, + replace_df=cudf.DataFrame(value), + input_cols=arg, + mask=None, + ignore_index=True, + ) + elif isinstance(value, DataFrame): _setitem_with_dataframe( input_df=self, replace_df=value, @@ -6393,6 +6401,7 @@ def _setitem_with_dataframe( replace_df: DataFrame, input_cols: Any = None, mask: Optional[cudf.core.column.ColumnBase] = None, + ignore_index: bool = False, ): """ This function sets item dataframes relevant columns with replacement df @@ -6400,6 +6409,7 @@ def _setitem_with_dataframe( :param replace_df: Replacement DataFrame to replace values with :param input_cols: columns to replace in the input dataframe :param mask: boolean mask in case of masked replacing + :param ignore_index: Whether to conduct index equality and reindex """ if input_cols is None: @@ -6410,7 +6420,11 @@ def _setitem_with_dataframe( "Number of Input Columns must be same replacement Dataframe" ) - if len(input_df) != 0 and not input_df.index.equals(replace_df.index): + if ( + not ignore_index + and len(input_df) != 0 + and not input_df.index.equals(replace_df.index) + ): replace_df = replace_df.reindex(input_df.index) for col_1, col_2 in zip(input_cols, replace_df.columns): diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index e5b298a8448..372587ba677 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -9030,3 +9030,14 @@ def test_dataframe_add_suffix(): expected = pdf.add_suffix("_item") assert_eq(got, expected) + + +def test_dataframe_assign_cp_np_array(): + m, n = 5, 3 + cp_ndarray = cupy.random.randn(m, n) + pdf = pd.DataFrame({f"f_{i}": range(m) for i in range(n)}) + gdf = cudf.DataFrame({f"f_{i}": range(m) for i in range(n)}) + pdf[[f"f_{i}" for i in range(n)]] = cupy.asnumpy(cp_ndarray) + gdf[[f"f_{i}" for i in range(n)]] = cp_ndarray + + assert_eq(pdf, gdf)