From 73f2c44511bebad745e8f696e6d0fe8641fac839 Mon Sep 17 00:00:00 2001 From: shaneding Date: Thu, 15 Jul 2021 12:51:20 -0400 Subject: [PATCH] Implement `__setitem__` for `StructColumn` (#8737) Closes #8558 Authors: - https://github.com/shaneding Approvers: - Ashwin Srinath (https://github.com/shwina) - https://github.com/brandon-b-miller URL: https://github.com/rapidsai/cudf/pull/8737 --- python/cudf/cudf/_lib/scalar.pyx | 2 +- python/cudf/cudf/core/column/struct.py | 9 ++++++ python/cudf/cudf/core/indexing.py | 7 ++-- python/cudf/cudf/core/scalar.py | 8 ++--- python/cudf/cudf/core/series.py | 1 + python/cudf/cudf/tests/test_struct.py | 45 ++++++++++++++++++++++++++ 6 files changed, 64 insertions(+), 8 deletions(-) diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx index 9e50f42d625..5b447513c95 100644 --- a/python/cudf/cudf/_lib/scalar.pyx +++ b/python/cudf/cudf/_lib/scalar.pyx @@ -326,7 +326,7 @@ cdef _set_struct_from_pydict(unique_ptr[scalar]& s, else: pyarrow_table = pa.Table.from_arrays( [ - pa.array([], from_pandas=True, type=f.type) + pa.array([cudf.NA], from_pandas=True, type=f.type) for f in arrow_schema ], names=columns diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index 6bcc594ab22..df9a601d6ef 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -89,6 +89,15 @@ def __getitem__(self, args): } return result + def __setitem__(self, key, value): + if isinstance(value, dict): + # filling in fields not in dict + for field in self.dtype.fields: + value[field] = value.get(field, cudf.NA) + + value = cudf.Scalar(value, self.dtype) + super().__setitem__(key, value) + def copy(self, deep=True): result = super().copy(deep=deep) if deep: diff --git a/python/cudf/cudf/core/indexing.py b/python/cudf/cudf/core/indexing.py index 933fd768d7c..a4a69a4e084 100755 --- a/python/cudf/cudf/core/indexing.py +++ b/python/cudf/cudf/core/indexing.py @@ -111,10 +111,13 @@ def __setitem__(self, key, value): if is_scalar(value): value = to_cudf_compatible_scalar(value) elif not ( - isinstance(value, list) - and isinstance(self._sr._column.dtype, cudf.ListDtype) + isinstance(value, (list, dict)) + and isinstance( + self._sr._column.dtype, (cudf.ListDtype, cudf.StructDtype) + ) ): value = column.as_column(value) + if ( not isinstance( self._sr._column.dtype, diff --git a/python/cudf/cudf/core/scalar.py b/python/cudf/cudf/core/scalar.py index db9bc6d6c85..c6663a25684 100644 --- a/python/cudf/cudf/core/scalar.py +++ b/python/cudf/cudf/core/scalar.py @@ -133,15 +133,13 @@ def _preprocess_host_value(self, value, dtype): return NA, dtype if isinstance(value, dict): - if dtype is not None: - raise TypeError("dict may not be cast to a different dtype") - else: + if dtype is None: dtype = StructDtype.from_arrow( pa.infer_type([value], from_pandas=True) ) - return value, dtype + return value, dtype elif isinstance(dtype, StructDtype): - if value is not None: + if value not in {None, NA}: raise ValueError(f"Can not coerce {value} to StructDType") else: return NA, dtype diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 68f2b42483b..cb18cd0de6d 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -1238,6 +1238,7 @@ def __repr__(self): preprocess._column, cudf.core.column.CategoricalColumn ) and not is_list_dtype(preprocess.dtype) + and not is_struct_dtype(preprocess.dtype) and not is_decimal_dtype(preprocess.dtype) ) or isinstance( preprocess._column, cudf.core.column.timedelta.TimeDeltaColumn diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index 4f3bb9bda92..b38fd3a5b6c 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -101,6 +101,51 @@ def test_struct_getitem(series, expected): assert sr[0] == expected +@pytest.mark.parametrize( + "data, item", + [ + ( + [ + {"a": "Hello world", "b": []}, + {"a": "CUDF", "b": [1, 2, 3], "c": cudf.NA}, + {"a": "abcde", "b": [4, 5, 6], "c": 9}, + ], + {"a": "Hello world", "b": [], "c": cudf.NA}, + ), + ( + [ + {"a": "Hello world", "b": []}, + {"a": "CUDF", "b": [1, 2, 3], "c": cudf.NA}, + {"a": "abcde", "b": [4, 5, 6], "c": 9}, + ], + {}, + ), + ( + [ + {"a": "Hello world", "b": []}, + {"a": "CUDF", "b": [1, 2, 3], "c": cudf.NA}, + {"a": "abcde", "b": [4, 5, 6], "c": 9}, + ], + cudf.NA, + ), + ( + [ + {"a": "Hello world", "b": []}, + {"a": "CUDF", "b": [1, 2, 3], "c": cudf.NA}, + {"a": "abcde", "b": [4, 5, 6], "c": 9}, + ], + {"a": "Second element", "b": [1, 2], "c": 1000}, + ), + ], +) +def test_struct_setitem(data, item): + sr = cudf.Series(data) + sr[1] = item + data[1] = item + expected = cudf.Series(data) + assert sr.to_arrow() == expected.to_arrow() + + @pytest.mark.parametrize( "data", [