Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix find_common_dtype and values to handle complex dtypes #12537

Merged
merged 12 commits into from
Mar 23, 2023
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ def get_column_values_na(col):
),
):
raise NotImplementedError(
f"{dtype} is not yet supported via "
"`__cuda_array_interface__`"
f"{dtype} is not yet supported to be exported to"
"a cupy array"
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
)
dtype = find_common_type(dtypes)

Expand Down
11 changes: 5 additions & 6 deletions python/cudf/cudf/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,14 +1872,13 @@ def test_concat_invalid_axis(axis):


@pytest.mark.parametrize(
"s1,s2,expected",
"s1,s2",
[
([1, 2], [[1, 2], [3, 4]], ["1", "2", "[1, 2]", "[3, 4]"]),
([1, 2], [[1, 2], [3, 4]]),
],
)
def test_concat_mixed_list_types(s1, s2, expected):
def test_concat_mixed_list_types_error(s1, s2):
s1, s2 = gd.Series(s1), gd.Series(s2)
expected = pd.Series(expected)
actual = gd.concat([s1, s2], ignore_index=True)

assert_eq(expected, actual, check_dtype=False)
with pytest.raises(NotImplementedError):
gd.concat([s1, s2], ignore_index=True)
3 changes: 2 additions & 1 deletion python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10030,6 +10030,7 @@ def test_dataframe_transpose_complex_types(data):

assert_eq(expected, actual)


@pytest.mark.parametrize(
"data",
[
Expand All @@ -10043,6 +10044,7 @@ def test_dataframe_values_complex_types(data):
with pytest.raises(NotImplementedError):
gdf.values


def test_dataframe_from_arrow_slice():
table = pa.Table.from_pandas(
pd.DataFrame.from_dict(
Expand All @@ -10055,4 +10057,3 @@ def test_dataframe_from_arrow_slice():
actual = cudf.DataFrame.from_arrow(table_slice)

assert_eq(expected, actual)

5 changes: 4 additions & 1 deletion python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,10 @@ def find_common_type(dtypes):
# common dtype, for example:
# ListDtype(int64) & ListDtype(int32) common
# dtype could be ListDtype(int64).
return cudf.dtype("O")
raise NotImplementedError(
"Finding a common type for `ListDtype` is currently "
"not supported"
)
if any(cudf.api.types.is_struct_dtype(dtype) for dtype in dtypes):
if len(dtypes) == 1:
return dtypes.get(0)
Expand Down