Skip to content

Commit

Permalink
fix issue with scalar broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Mar 1, 2022
1 parent 78b316c commit c786258
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ def __getitem__(self, arg):
@annotate("DATAFRAME_SETITEM", color="blue", domain="cudf_python")
def __setitem__(self, arg, value):
"""Add/set column by *arg or DataFrame*"""
# import pdb;pdb.set_trace()
if isinstance(arg, DataFrame):
# not handling set_item where arg = df & value = df
if isinstance(value, DataFrame):
Expand Down Expand Up @@ -1161,7 +1162,9 @@ def __setitem__(self, arg, value):
allow_non_unique=True,
)
if is_scalar(value):
self._data[arg][:] = value
self._data[arg] = utils.scalar_broadcast_to(
value, len(self)
)
else:
value = as_column(value)
self._data[arg] = value
Expand Down
9 changes: 7 additions & 2 deletions python/cudf/cudf/tests/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ def test_dataframe_setitem_scaler_bool():
assert_eq(df, gdf)


@pytest.mark.parametrize("df", [pd.DataFrame({"a": [1, 2, 3]})])
@pytest.mark.parametrize(
"df",
[pd.DataFrame({"a": [1, 2, 3]}), pd.DataFrame({"a": ["x", "y", "z"]})],
)
@pytest.mark.parametrize("arg", [["a"], "a", "b"])
@pytest.mark.parametrize("value", [-10, pd.DataFrame({"a": [-1, -2, -3]})])
@pytest.mark.parametrize(
"value", [-10, pd.DataFrame({"a": [-1, -2, -3]}), "abc"]
)
def test_dataframe_setitem_columns(df, arg, value):
gdf = cudf.from_pandas(df)
cudf_replace_value = value
Expand Down

0 comments on commit c786258

Please sign in to comment.