From 2c7723cce69b80aff3d09c78ab60b698e7a35a75 Mon Sep 17 00:00:00 2001 From: skirui-source <71867292+skirui-source@users.noreply.github.com> Date: Tue, 6 Apr 2021 11:38:56 -0700 Subject: [PATCH] Add `StructMethods.field()` API to access field of struct column (#7757) fixes: #7608 Authors: - https://github.com/skirui-source Approvers: - Keith Kraus (https://github.com/kkraus14) URL: https://github.com/rapidsai/cudf/pull/7757 --- python/cudf/cudf/core/column/struct.py | 60 ++++++++++++++++++++++++++ python/cudf/cudf/core/series.py | 6 +++ python/cudf/cudf/tests/test_struct.py | 10 +++++ python/cudf/cudf/utils/dtypes.py | 3 ++ 4 files changed, 79 insertions(+) diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index adaf62ffc25..266e448cdf3 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -5,9 +5,19 @@ import cudf from cudf.core.column import ColumnBase +from cudf.core.column.methods import ColumnMethodsMixin +from cudf.utils.dtypes import is_struct_dtype class StructColumn(ColumnBase): + """ + Column that stores fields of values. + + Every column has n children, where n is + the number of fields in the Struct Dtype. + + """ + dtype: cudf.core.dtypes.StructDtype @property @@ -74,6 +84,9 @@ def copy(self, deep=True): result = result._rename_fields(self.dtype.fields.keys()) return result + def struct(self, parent=None): + return StructMethods(self, parent=parent) + def _rename_fields(self, names): """ Return a StructColumn with the same field values as this StructColumn, @@ -91,3 +104,50 @@ def _rename_fields(self, names): null_count=self.null_count, children=self.base_children, ) + + +class StructMethods(ColumnMethodsMixin): + """ + Struct methods for Series + """ + + def __init__(self, column, parent=None): + if not is_struct_dtype(column.dtype): + raise AttributeError( + "Can only use .struct accessor with a 'struct' dtype" + ) + super().__init__(column=column, parent=parent) + + def field(self, key): + """ + Extract children of the specified struct column + in the Series + + Parameters + ---------- + key: int or str + index/position or field name of the respective + struct column + + Returns + ------- + Series + + Examples + -------- + >>> s = cudf.Series([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) + >>> s.struct.field(0) + 0 1 + 1 3 + dtype: int64 + >>> s.struct.field('a') + 0 1 + 1 3 + dtype: int64 + """ + fields = list(self._column.dtype.fields.keys()) + if key in fields: + pos = fields.index(key) + return self._return_or_inplace(self._column.children[pos]) + else: + return self._return_or_inplace(self._column.children[key]) diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 955519d0b57..55fd510f03a 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -36,6 +36,7 @@ ) from cudf.core.column.lists import ListMethods from cudf.core.column.string import StringMethods +from cudf.core.column.struct import StructMethods from cudf.core.column_accessor import ColumnAccessor from cudf.core.frame import Frame, _drop_rows_by_labels from cudf.core.groupby.groupby import SeriesGroupBy @@ -2675,6 +2676,11 @@ def str(self): def list(self): return ListMethods(column=self._column, parent=self) + @copy_docstring(StructMethods.__init__) # type: ignore + @property + def struct(self): + return StructMethods(column=self._column, parent=self) + @property def dtype(self): """dtype of the Series""" diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index c7efb55c089..3c211951dff 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -34,3 +34,13 @@ def test_struct_of_struct_loc(): df = cudf.DataFrame({"col": [{"a": {"b": 1}}]}) expect = cudf.Series([{"a": {"b": 1}}], name="col") assert_eq(expect, df["col"]) + + +@pytest.mark.parametrize( + "key, expect", [(0, [1, 3]), (1, [2, 4]), ("a", [1, 3]), ("b", [2, 4])] +) +def test_struct_for_field(key, expect): + sr = cudf.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + expect = cudf.Series(expect) + got = sr.struct.field(key) + assert_eq(expect, got) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index 5cb0391d76f..a8ff2177154 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -306,6 +306,9 @@ def cudf_dtype_to_pa_type(dtype): def cudf_dtype_from_pa_type(typ): + """ Given a cuDF pyarrow dtype, converts it into the equivalent + cudf pandas dtype. + """ if pa.types.is_list(typ): return cudf.core.dtypes.ListDtype.from_arrow(typ) elif pa.types.is_struct(typ):