Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG FIX] In-place updates with loc or iloc don't work correctly when the LHS has more than one column #9918

Merged
merged 100 commits into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from 99 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
8fb15f2
create new PR
skirui-source Dec 16, 2021
8b284e6
Merge branch 'branch-22.02' of https://github.com/rapidsai/cudf into …
skirui-source Jan 6, 2022
5a08fa9
Merge branch 'branch-22.02' of https://github.com/rapidsai/cudf into …
skirui-source Jan 7, 2022
07aa29f
Merge branch 'branch-22.02' of https://github.com/rapidsai/cudf into …
skirui-source Jan 12, 2022
bed1802
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Jan 19, 2022
7c534a3
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Jan 19, 2022
5510500
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Feb 22, 2022
c08852c
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Feb 23, 2022
b4b263b
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 2, 2022
47cd7a7
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 2, 2022
2531ceb
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 2, 2022
24ea9f9
paired with Mike on these edits
skirui-source Mar 3, 2022
66a46d5
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 8, 2022
b359402
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 9, 2022
f898d57
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 11, 2022
bd7b9a6
implemented for DataFrameIlocIndexer as well
skirui-source Mar 12, 2022
7e47433
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 12, 2022
a73ee17
preliminary adding tests
skirui-source Mar 12, 2022
1f83014
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 16, 2022
73a9776
added standard 3 tests
skirui-source Mar 16, 2022
f7d0ce8
added more tests. ready for initial review
skirui-source Mar 16, 2022
0618827
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 16, 2022
42d7a05
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 17, 2022
bc304d0
Update python/cudf/cudf/core/dataframe.py
skirui-source Mar 17, 2022
da0520e
Update python/cudf/cudf/core/dataframe.py
skirui-source Mar 17, 2022
a4ef228
Merge branch 'i-loc.bug' of github.com:skirui-source/cudf into i-loc.bug
skirui-source Mar 17, 2022
94ba1d0
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 18, 2022
de9a8c2
use string Template to formate Except errors
skirui-source Mar 18, 2022
5f29051
parameterize tests using dict and list of indices
skirui-source Mar 18, 2022
7f1cc51
fixed merge conflict in dataframe.py
skirui-source Mar 22, 2022
3a15a2f
Merge branch 'branch-22.04' of https://github.com/rapidsai/cudf into …
skirui-source Mar 23, 2022
20d9231
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Mar 23, 2022
a5351ad
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Mar 25, 2022
8d01231
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Mar 30, 2022
48325a6
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 4, 2022
0134dc8
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 4, 2022
be230fc
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 5, 2022
83ae853
added support for non-numeric index labels
skirui-source Apr 5, 2022
a70ee3d
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 5, 2022
08ea6c8
replace with select_by_index in setitem for ilocindexer
skirui-source Apr 5, 2022
7df403d
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 6, 2022
942db3e
addressed michael's reviews
skirui-source Apr 6, 2022
5c9518a
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 6, 2022
5c50091
reverted changes, after discussion with mike
skirui-source Apr 7, 2022
b0ac194
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 8, 2022
90d2e5b
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 8, 2022
2a7b195
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 8, 2022
46b9290
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 11, 2022
e3f5993
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 12, 2022
c63a1e4
added check for scalar value
skirui-source Apr 13, 2022
827a741
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 13, 2022
5e0fe69
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 14, 2022
0d6000f
removed indexing key by 0 for iloc
skirui-source Apr 14, 2022
84c54fe
resolved conflict in test_dataframe.py
skirui-source Apr 18, 2022
b7b6699
WIP: pairing with michael to fix edge cases - slice keys and range va…
skirui-source Apr 18, 2022
7ec03b3
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 18, 2022
54f6b8e
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 19, 2022
92642bc
Move indexing tests to `test_indexing.py` and add a couple more tests…
isVoid Apr 20, 2022
47a5527
apply is_scalar check to column axis
isVoid Apr 20, 2022
fdee979
scalar axis indexer test case on `iloc`
isVoid Apr 20, 2022
bec434c
Minor cleanups
isVoid Apr 20, 2022
42ef238
Merge branch 'i-loc.bug' of https://github.com/skirui-source/cudf int…
isVoid Apr 20, 2022
0e94717
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 20, 2022
5db598d
Merge branch 'i-loc.bug' of github.com:skirui-source/cudf into i-loc.bug
skirui-source Apr 20, 2022
d063d18
remove breakpoint in indexed_frame.py
skirui-source Apr 20, 2022
cfe210b
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 20, 2022
93915d1
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 20, 2022
d41400a
Merge branch 'i-loc.bug' of github.com:skirui-source/cudf into i-loc.bug
skirui-source Apr 20, 2022
4a6a3bd
addressed michael's review comments
skirui-source Apr 20, 2022
862ca2e
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 21, 2022
ca05ae8
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 21, 2022
b43745d
added check for scalar key in loc and iloc indexer
skirui-source Apr 22, 2022
7318627
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 25, 2022
2922ce2
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 26, 2022
7d43640
Use `cupy.asarray` instead of `np.array` to convert as device data an…
isVoid Apr 26, 2022
973befa
Removes branch that handles ndarray and add column-broadcasting logic…
isVoid Apr 26, 2022
d1976ed
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 26, 2022
7196b0c
Update python/cudf/cudf/core/dataframe.py
skirui-source Apr 26, 2022
cb10ad7
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 27, 2022
924017b
Merge branch 'i-loc.bug' of github.com:skirui-source/cudf into i-loc.bug
skirui-source Apr 27, 2022
89e6373
fixed style check failures
skirui-source Apr 27, 2022
7a40a5e
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 27, 2022
c8ccb9c
moved test from test_datframe.py > test_indexing.py
skirui-source Apr 27, 2022
f918ed9
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 28, 2022
aa7b1ec
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 28, 2022
dbf60e4
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 29, 2022
94bd3d6
renamed to shape_mismatch_error_msg
skirui-source Apr 29, 2022
a4a22d2
added helper function shape_mismatch_error
skirui-source Apr 29, 2022
8d87676
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 29, 2022
9aa42a8
added few more shape mismatch tests
skirui-source Apr 30, 2022
80e5c11
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source Apr 30, 2022
c379a5b
separated the shape mismatch tests
skirui-source May 1, 2022
cf9dfdb
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source May 2, 2022
9440827
parameterize iloc/loc tests for shape mismatch
skirui-source May 3, 2022
bedf6ff
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source May 3, 2022
bea3f52
separate shape mismatch iloc/loc test for RHS df
skirui-source May 3, 2022
ecee045
.
skirui-source May 3, 2022
03c2617
Merge branch 'branch-22.06' of https://github.com/rapidsai/cudf into …
skirui-source May 3, 2022
e7f36cc
remove raise ValueError call to the shape_mismatch_error function
skirui-source May 3, 2022
194920f
internal method
galipremsagar May 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 100 additions & 23 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@
}


def shape_mismatch_error(x, y):
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"shape mismatch: value array of shape {x} "
f"could not be broadcast to indexing result of "
f"shape {y}"
)


class _DataFrameIndexer(_FrameIndexer):
def __getitem__(self, arg):
if (
Expand Down Expand Up @@ -342,28 +350,58 @@ def _setitem_tuple_arg(self, key, value):
)
self._frame._data.insert(key[1], new_col)
else:
if isinstance(value, (cupy.ndarray, np.ndarray)):
value_df = DataFrame(value)
if value_df.shape[1] != columns_df.shape[1]:
if value_df.shape[1] == 1:
value_cols = (
value_df._data.columns * columns_df.shape[1]
)
else:
raise ValueError(
f"shape mismatch: value array of shape "
f"{value_df.shape} could not be "
f"broadcast to indexing result of shape "
f"{columns_df.shape}"
)
else:
value_cols = value_df._data.columns
for i, col in enumerate(columns_df._column_names):
self._frame[col].loc[key[0]] = value_cols[i]
else:
if is_scalar(value):
for col in columns_df._column_names:
self._frame[col].loc[key[0]] = value

elif isinstance(value, cudf.DataFrame):
if value.shape != self._frame.loc[key[0]].shape:
shape_mismatch_error(
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
value.shape,
self._frame.loc[key[0]].shape,
)
value_column_names = set(value._column_names)
scatter_map = _indices_from_labels(self._frame, key[0])
for col in columns_df._column_names:
columns_df[col][scatter_map] = (
value._data[col]
if col in value_column_names
else cudf.NA
)

else:
value = cupy.asarray(value)
if cupy.ndim(value) == 2:
# If the inner dimension is 1, it's broadcastable to
# all columns of the dataframe.
indexed_shape = columns_df.loc[key[0]].shape
if value.shape[1] == 1:
if value.shape[0] != indexed_shape[0]:
shape_mismatch_error(value.shape, indexed_shape)
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
for i, col in enumerate(columns_df._column_names):
self._frame[col].loc[key[0]] = value[:, 0]
else:
if value.shape != indexed_shape:
shape_mismatch_error(value.shape, indexed_shape)
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
for i, col in enumerate(columns_df._column_names):
self._frame[col].loc[key[0]] = value[:, i]
else:
# handle cases where value is 1d object:
# If the key on column axis is a scalar, we indexed
# a single column; The 1d value should assign along
# the columns.
if is_scalar(key[1]):
for col in columns_df._column_names:
self._frame[col].loc[key[0]] = value
# Otherwise, there are two situations. The key on row axis
# can be a scalar or 1d. In either of the situation, the
# ith element in value corresponds to the ith row in
# the indexed object.
# If the key is 1d, a broadcast will happen.
else:
for i, col in enumerate(columns_df._column_names):
self._frame[col].loc[key[0]] = value[i]


class _DataFrameIlocIndexer(_DataFrameIndexer):
"""
Expand Down Expand Up @@ -424,10 +462,49 @@ def _getitem_tuple_arg(self, arg):

@_cudf_nvtx_annotate
def _setitem_tuple_arg(self, key, value):
# TODO: Determine if this usage is prevalent enough to expose this
# selection logic at a higher level than ColumnAccessor.
for col in self._frame._data.get_labels_by_index(key[1]):
self._frame[col].iloc[key[0]] = value
columns_df = self._frame._from_data(
self._frame._data.select_by_index(key[1]), self._frame._index
)

if is_scalar(value):
for col in columns_df._column_names:
self._frame[col].iloc[key[0]] = value

elif isinstance(value, cudf.DataFrame):
if value.shape != self._frame.iloc[key[0]].shape:
shape_mismatch_error(
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
value.shape,
self._frame.loc[key[0]].shape,
)
value_column_names = set(value._column_names)
for col in columns_df._column_names:
columns_df[col][key[0]] = (
value._data[col] if col in value_column_names else cudf.NA
)

else:
# TODO: consolidate code path with identical counterpart
# in `_DataFrameLocIndexer._setitem_tuple_arg`
value = cupy.asarray(value)
if cupy.ndim(value) == 2:
indexed_shape = columns_df.iloc[key[0]].shape
if value.shape[1] == 1:
if value.shape[0] != indexed_shape[0]:
shape_mismatch_error(value.shape, indexed_shape)
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
for i, col in enumerate(columns_df._column_names):
self._frame[col].iloc[key[0]] = value[:, 0]
else:
if value.shape != indexed_shape:
shape_mismatch_error(value.shape, indexed_shape)
for i, col in enumerate(columns_df._column_names):
self._frame._data[col][key[0]] = value[:, i]
else:
if is_scalar(key[1]):
for col in columns_df._column_names:
self._frame[col].iloc[key[0]] = value
else:
for i, col in enumerate(columns_df._column_names):
self._frame[col].iloc[key[0]] = value[i]

def _getitem_scalar(self, arg):
col = self._frame.columns[arg[1]]
Expand Down
1 change: 0 additions & 1 deletion python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _drop_columns(f: Frame, columns: abc.Iterable, errors: str):


def _indices_from_labels(obj, labels):

if not isinstance(labels, cudf.MultiIndex):
labels = cudf.core.column.as_column(labels)

Expand Down
37 changes: 0 additions & 37 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8697,43 +8697,6 @@ def test_frame_series_where():
assert_eq(expected, actual)


@pytest.mark.parametrize(
"array,is_error",
[
(cupy.arange(20, 40).reshape(-1, 2), False),
(cupy.arange(20, 50).reshape(-1, 3), True),
(np.arange(20, 40).reshape(-1, 2), False),
(np.arange(20, 30).reshape(-1, 1), False),
(cupy.arange(20, 30).reshape(-1, 1), False),
],
)
def test_dataframe_indexing_setitem_np_cp_array(array, is_error):
gdf = cudf.DataFrame({"a": range(10), "b": range(10)})
pdf = gdf.to_pandas()
if not is_error:
gdf.loc[:, ["a", "b"]] = array
pdf.loc[:, ["a", "b"]] = cupy.asnumpy(array)

assert_eq(gdf, pdf)
else:
assert_exceptions_equal(
lfunc=pdf.loc.__setitem__,
rfunc=gdf.loc.__setitem__,
lfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), cupy.asnumpy(array)],
{},
),
rfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), array],
{},
),
compare_error_message=False,
expected_error_message="shape mismatch: value array of shape "
"(10, 3) could not be broadcast to indexing "
"result of shape (10, 2)",
)


@pytest.mark.parametrize(
"data",
[{"a": [1, 2, 3], "b": [1, 1, 0]}],
Expand Down
186 changes: 186 additions & 0 deletions python/cudf/cudf/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,3 +1486,189 @@ def test_iloc_decimal():
["4.00", "3.00", "2.00", "1.00"],
).astype(cudf.Decimal64Dtype(scale=2, precision=3))
assert_eq(expect.reset_index(drop=True), got.reset_index(drop=True))


@pytest.mark.parametrize(
("key, value"),
[
(
([0], ["x", "y"]),
[10, 20],
),
(
([0, 2], ["x", "y"]),
[[10, 30], [20, 40]],
),
(
(0, ["x", "y"]),
[10, 20],
),
(
([0, 2], "x"),
[10, 20],
),
],
)
def test_dataframe_loc_inplace_update(key, value):
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
pdf = gdf.to_pandas()

actual = gdf.loc[key] = value
expected = pdf.loc[key] = value

assert_eq(expected, actual)


def test_dataframe_loc_inplace_update_string_index():
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}, index=list("abc"))
pdf = gdf.to_pandas()

actual = gdf.loc[["a"], ["x", "y"]] = [10, 20]
expected = pdf.loc[["a"], ["x", "y"]] = [10, 20]

assert_eq(expected, actual)


@pytest.mark.parametrize(
("key, value"),
[
([0], [10, 20]),
([0, 2], [[10, 30], [20, 40]]),
(([0, 2], [0, 1]), [[10, 30], [20, 40]]),
(([0, 2], 0), [10, 30]),
((0, [0, 1]), [20, 40]),
],
)
def test_dataframe_iloc_inplace_update(key, value):
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
pdf = gdf.to_pandas()

actual = gdf.iloc[key] = value
expected = pdf.iloc[key] = value

assert_eq(expected, actual)


@pytest.mark.parametrize(
"loc_key",
[([0, 2], ["x", "y"])],
)
@pytest.mark.parametrize(
"iloc_key",
[[0, 2]],
)
@pytest.mark.parametrize(
("data, index"),
[
(
{"x": [10, 20], "y": [30, 40]},
[0, 2],
)
],
)
def test_dataframe_loc_iloc_inplace_update_with_RHS_dataframe(
loc_key, iloc_key, data, index
):
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
pdf = gdf.to_pandas()

actual = gdf.loc[loc_key] = cudf.DataFrame(data, index=cudf.Index(index))
expected = pdf.loc[loc_key] = pd.DataFrame(data, index=pd.Index(index))
assert_eq(expected, actual)

actual = gdf.iloc[iloc_key] = cudf.DataFrame(data, index=cudf.Index(index))
expected = pdf.iloc[iloc_key] = pd.DataFrame(data, index=pd.Index(index))
assert_eq(expected, actual)


def test_dataframe_loc_inplace_update_with_invalid_RHS_df_columns():
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
pdf = gdf.to_pandas()

actual = gdf.loc[[0, 2], ["x", "y"]] = cudf.DataFrame(
{"b": [10, 20], "y": [30, 40]}, index=cudf.Index([0, 2])
)
expected = pdf.loc[[0, 2], ["x", "y"]] = pd.DataFrame(
{"b": [10, 20], "y": [30, 40]}, index=pd.Index([0, 2])
)

assert_eq(expected, actual)


@pytest.mark.parametrize(
("key, value"),
[
(([0, 2], ["x", "y"]), [[10, 30, 50], [20, 40, 60]]),
(([0], ["x", "y"]), [[10], [20]]),
],
)
def test_dataframe_loc_inplace_update_shape_mismatch(key, value):
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
with pytest.raises(ValueError, match="shape mismatch:"):
gdf.loc[key] = value


@pytest.mark.parametrize(
("key, value"),
[
([0, 2], [[10, 30, 50], [20, 40, 60]]),
([0], [[10], [20]]),
],
)
def test_dataframe_iloc_inplace_update_shape_mismatch(key, value):
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
with pytest.raises(ValueError, match="shape mismatch:"):
gdf.iloc[key] = value


def test_dataframe_loc_inplace_update_shape_mismatch_RHS_df():
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
with pytest.raises(ValueError, match="shape mismatch:"):
gdf.loc[([0, 2], ["x", "y"])] = cudf.DataFrame(
{"x": [10, 20]}, index=cudf.Index([0, 2])
)


def test_dataframe_iloc_inplace_update_shape_mismatch_RHS_df():
gdf = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
with pytest.raises(ValueError, match="shape mismatch:"):
gdf.iloc[[0, 2]] = cudf.DataFrame(
{"x": [10, 20]}, index=cudf.Index([0, 2])
)


@pytest.mark.parametrize(
"array,is_error",
[
(cupy.arange(20, 40).reshape(-1, 2), False),
(cupy.arange(20, 50).reshape(-1, 3), True),
(np.arange(20, 40).reshape(-1, 2), False),
(np.arange(20, 30).reshape(-1, 1), False),
(cupy.arange(20, 30).reshape(-1, 1), False),
],
)
def test_dataframe_indexing_setitem_np_cp_array(array, is_error):
gdf = cudf.DataFrame({"a": range(10), "b": range(10)})
pdf = gdf.to_pandas()
if not is_error:
gdf.loc[:, ["a", "b"]] = array
pdf.loc[:, ["a", "b"]] = cupy.asnumpy(array)

assert_eq(gdf, pdf)
else:
assert_exceptions_equal(
lfunc=pdf.loc.__setitem__,
rfunc=gdf.loc.__setitem__,
lfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), cupy.asnumpy(array)],
{},
),
rfunc_args_and_kwargs=(
[(slice(None, None, None), ["a", "b"]), array],
{},
),
compare_error_message=False,
expected_error_message="shape mismatch: value array of shape "
"(10, 3) could not be broadcast to indexing "
"result of shape (10, 2)",
)