Skip to content

Commit

Permalink
Fix Dataframe indexer setitem when array is passed (#9006)
Browse files Browse the repository at this point in the history
Fixes: #8672 

This PR handles `ndarray` inputs in the `_DataFrameLocIndexer.__setitem__`

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Marlene  (https://github.com/marlenezw)

URL: #9006
  • Loading branch information
galipremsagar authored Aug 17, 2021
1 parent b3c1caf commit c76ec13
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
25 changes: 22 additions & 3 deletions python/cudf/cudf/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def _setitem_tuple_arg(self, key, value):
)

try:
columns = self._get_column_selection(key[1])
columns_df = self._get_column_selection(key[1])
except KeyError:
if not self._df.empty and isinstance(key[0], slice):
pos_range = get_label_range_or_mask(
Expand All @@ -457,8 +457,27 @@ def _setitem_tuple_arg(self, key, value):
)
self._df._data.insert(key[1], new_col)
else:
for col in columns:
self._df[col].loc[key[0]] = value
if isinstance(value, (cp.ndarray, np.ndarray)):
value_df = cudf.DataFrame(value)
if value_df.shape[1] != columns_df.shape[1]:
if value_df.shape[1] == 1:
value_cols = (
value_df._data.columns * columns_df.shape[1]
)
else:
raise ValueError(
f"shape mismatch: value array of shape "
f"{value_df.shape} could not be "
f"broadcast to indexing result of shape "
f"{columns_df.shape}"
)
else:
value_cols = value_df._data.columns
for i, col in enumerate(columns_df._column_names):
self._df[col].loc[key[0]] = value_cols[i]
else:
for col in columns_df._column_names:
self._df[col].loc[key[0]] = value

def _get_column_selection(self, arg):
return self._df._get_columns_by_label(arg)
Expand Down
37 changes: 37 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8759,6 +8759,43 @@ def test_frame_series_where():
assert_eq(expected, actual)


@pytest.mark.parametrize(
"array,is_error",
[
(cupy.arange(20, 40).reshape(-1, 2), False),
(cupy.arange(20, 50).reshape(-1, 3), True),
(np.arange(20, 40).reshape(-1, 2), False),
(np.arange(20, 30).reshape(-1, 1), False),
(cupy.arange(20, 30).reshape(-1, 1), False),
],
)
def test_dataframe_indexing_setitem_np_cp_array(array, is_error):
gdf = cudf.DataFrame({"a": range(10), "b": range(10)})
pdf = gdf.to_pandas()
if not is_error:
gdf.loc[:, ["a", "b"]] = array
pdf.loc[:, ["a", "b"]] = cupy.asnumpy(array)

assert_eq(gdf, pdf)
else:
assert_exceptions_equal(
lfunc=pdf.loc.__setitem__,
rfunc=gdf.loc.__setitem__,
lfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), cupy.asnumpy(array)],
{},
),
rfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), array],
{},
),
compare_error_message=False,
expected_error_message="shape mismatch: value array of shape "
"(10, 3) could not be broadcast to indexing "
"result of shape (10, 2)",
)


@pytest.mark.parametrize(
"data", [{"a": [1, 2, 3], "b": [1, 1, 0]}],
)
Expand Down

0 comments on commit c76ec13

Please sign in to comment.