Skip to content

Commit

Permalink
Raise when multi-dimensional arrays are passed to column constructor (#…
Browse files Browse the repository at this point in the history
…10)

Fixes: #14151

This PR introduces shape validation in the column constructor function for __cuda_array_interface__, cudf only supports 1-dimensional columns.
  • Loading branch information
galipremsagar authored Sep 30, 2023
1 parent 952f2bc commit 477d030
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,9 @@ def as_column(

elif hasattr(arbitrary, "__cuda_array_interface__"):
desc = arbitrary.__cuda_array_interface__
shape = desc["shape"]
if len(shape) > 1:
raise ValueError("Data must be 1-dimensional")
current_dtype = np.dtype(desc["typestr"])

arb_dtype = (
Expand Down
4 changes: 3 additions & 1 deletion python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def indices_of(self, value: ScalarLike) -> NumericalColumn:
and np.isnan(value)
):
return column.as_column(
cp.argwhere(cp.isnan(self.data_array_view(mode="read"))),
cp.argwhere(
cp.isnan(self.data_array_view(mode="read"))
).flatten(),
dtype=size_type_dtype,
)
else:
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,12 @@ def test_series_count_invalid_param():
s.count(skipna=True)


def test_multi_dim_series_error():
arr = cp.array([(1, 2), (3, 4)])
with pytest.raises(ValueError):
cudf.Series(arr)


def test_bool_series_mixed_dtype_error():
ps = pd.Series([True, False, None])
# ps now has `object` dtype, which
Expand Down

0 comments on commit 477d030

Please sign in to comment.