Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix find_common_dtype and values to handle complex dtypes #12537

Merged
merged 12 commits into from
Mar 23, 2023
17 changes: 14 additions & 3 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,20 @@ def get_column_values_na(col):
)

if dtype is None:
dtype = find_common_type(
[col.dtype for col in self._data.values()]
)
dtypes = [col.dtype for col in self._data.values()]
for dtype in dtypes:
if isinstance(
dtype,
(
cudf.ListDtype,
cudf.core.dtypes.DecimalDtype,
cudf.StructDtype,
),
):
raise NotImplementedError(
f"{dtype} cannot be exposed as a cupy array"
)
dtype = find_common_type(dtypes)

matrix = make_empty_matrix(
shape=(len(self), ncol), dtype=dtype, order="F"
Expand Down
13 changes: 13 additions & 0 deletions python/cudf/cudf/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,3 +1869,16 @@ def test_concat_invalid_axis(axis):
s = gd.Series([1, 2, 3])
with pytest.raises(ValueError):
gd.concat([s], axis=axis)


@pytest.mark.parametrize(
    "s1,s2",
    [
        ([1, 2], [[1, 2], [3, 4]]),
    ],
)
def test_concat_mixed_list_types_error(s1, s2):
    """Concatenating a scalar-valued series with a list-valued series
    is expected to raise ``NotImplementedError``.
    """
    scalar_series = gd.Series(s1)
    list_series = gd.Series(s2)

    # Build both inputs before entering the context so that only the
    # concat call itself is checked for the exception.
    with pytest.raises(NotImplementedError):
        gd.concat([scalar_series, list_series], ignore_index=True)
14 changes: 14 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10031,6 +10031,20 @@ def test_dataframe_transpose_complex_types(data):
assert_eq(expected, actual)


@pytest.mark.parametrize(
    "data",
    [
        # column of struct-like dict values, with a null entry
        {"col": [{"a": 1.1}, {"a": 2.1}, {"a": 10.0}, {"a": 11.2323}, None]},
        # column of lists of dicts, alternating with nulls
        {"a": [[{"b": 567}], None] * 10},
        # column of decimal.Decimal values, with a null entry
        {"a": [decimal.Decimal(10), decimal.Decimal(20), None]},
    ],
)
def test_dataframe_values_complex_types(data):
    """Accessing ``.values`` on a frame built from each parametrized
    payload is expected to raise ``NotImplementedError``.
    """
    frame = cudf.DataFrame(data)

    # Only the attribute access is wrapped: a failure while constructing
    # the frame must not be mistaken for the expected error.
    with pytest.raises(NotImplementedError):
        frame.values


def test_dataframe_from_arrow_slice():
table = pa.Table.from_pandas(
pd.DataFrame.from_dict(
Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,27 @@ def find_common_type(dtypes):
)
else:
return cudf.dtype("O")
if any(cudf.api.types.is_list_dtype(dtype) for dtype in dtypes):
if len(dtypes) == 1:
return dtypes.get(0)
else:
# TODO: As list dtypes allow casting
# to identical types, improve this logic of returning a
# common dtype, for example:
# ListDtype(int64) & ListDtype(int32) common
# dtype could be ListDtype(int64).
raise NotImplementedError(
"Finding a common type for `ListDtype` is currently "
"not supported"
)
if any(cudf.api.types.is_struct_dtype(dtype) for dtype in dtypes):
if len(dtypes) == 1:
return dtypes.get(0)
else:
raise NotImplementedError(
"Finding a common type for `StructDtype` is currently "
"not supported"
)

# Corner case 1:
# Resort to np.result_type to handle "M" and "m" types separately
Expand Down