Skip to content

Commit

Permalink
Add struct accessor to dask-cudf (#8874)
Browse files Browse the repository at this point in the history
This PR implements 'Struct Accessor' requested feature in dask-cudf (Issue [#8658](#8658))

StructMethod class implemented to expose 'field(key)' method in dask-cudf

        Examples
        --------
        >>> s = cudf.Series([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
        >>> ds = dask_cudf.from_cudf(s, 2)
        >>> ds.struct.field(0).compute()
        0    1
        1    3
        dtype: int64
        >>> ds.struct.field('a').compute()
        0    1
        1    3
        dtype: int64

Authors:
  - https://github.com/NV-jpt
  - https://github.com/shaneding

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - Ashwin Srinath (https://github.com/shwina)

URL: #8874
  • Loading branch information
NV-jpt authored Aug 24, 2021
1 parent c271ce2 commit abba33f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 1 deletion.
37 changes: 37 additions & 0 deletions python/dask_cudf/dask_cudf/accessors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,43 @@
# Copyright (c) 2021, NVIDIA CORPORATION.


class StructMethods:
def __init__(self, d_series):
self.d_series = d_series

def field(self, key):
"""
Extract children of the specified struct column
in the Series
Parameters
----------
key: int or str
index/position or field name of the respective
struct column
Returns
-------
Series
Examples
--------
>>> s = cudf.Series([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
>>> ds = dask_cudf.from_cudf(s, 2)
>>> ds.struct.field(0).compute()
0 1
1 3
dtype: int64
>>> ds.struct.field('a').compute()
0 1
1 3
dtype: int64
"""
typ = self.d_series._meta.struct.field(key).dtype

return self.d_series.map_partitions(
lambda s: s.struct.field(key),
meta=self.d_series._meta._constructor([], dtype=typ),
)


class ListMethods:
def __init__(self, d_series):
self.d_series = d_series
Expand Down
6 changes: 5 additions & 1 deletion python/dask_cudf/dask_cudf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from cudf import _lib as libcudf

from dask_cudf import sorting
from dask_cudf.accessors import ListMethods
from dask_cudf.accessors import ListMethods, StructMethods

DASK_VERSION = LooseVersion(dask.__version__)

Expand Down Expand Up @@ -414,6 +414,10 @@ def groupby(self, *args, **kwargs):
def list(self):
return ListMethods(self)

@property
def struct(self):
return StructMethods(self)


class Index(Series, dd.core.Index):
_partition_type = cudf.Index
Expand Down
62 changes: 62 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,65 @@ def test_sorting(data, ascending, na_position, ignore_index):
.reset_index(drop=True)
)
assert_eq(expect, got)


#############################################################################
# Struct Accessor #
#############################################################################
struct_accessor_data_params = [
[{"a": 5, "b": 10}, {"a": 3, "b": 7}, {"a": -3, "b": 11}],
[{"a": None, "b": 1}, {"a": None, "b": 0}, {"a": -3, "b": None}],
[{"a": 1, "b": 2}],
[{"a": 1, "b": 3, "c": 4}],
]


@pytest.mark.parametrize(
"data", struct_accessor_data_params,
)
def test_create_struct_series(data):
expect = pd.Series(data)
ds_got = dgd.from_cudf(Series(data), 2)
assert_eq(expect, ds_got.compute())


@pytest.mark.parametrize(
"data", struct_accessor_data_params,
)
def test_struct_field_str(data):
for test_key in ["a", "b"]:
expect = Series(data).struct.field(test_key)
ds_got = dgd.from_cudf(Series(data), 2).struct.field(test_key)
assert_eq(expect, ds_got.compute())


@pytest.mark.parametrize(
"data", struct_accessor_data_params,
)
def test_struct_field_integer(data):
for test_key in [0, 1]:
expect = Series(data).struct.field(test_key)
ds_got = dgd.from_cudf(Series(data), 2).struct.field(test_key)
assert_eq(expect, ds_got.compute())


@pytest.mark.parametrize(
"data", struct_accessor_data_params,
)
def test_dask_struct_field_Key_Error(data):
got = dgd.from_cudf(Series(data), 2)

# import pdb; pdb.set_trace()
with pytest.raises(KeyError):
got.struct.field("notakey").compute()


@pytest.mark.parametrize(
"data", struct_accessor_data_params,
)
def test_dask_struct_field_Int_Error(data):
# breakpoint()
got = dgd.from_cudf(Series(data), 2)

with pytest.raises(IndexError):
got.struct.field(1000).compute()

0 comments on commit abba33f

Please sign in to comment.