From 3b8fee0635faf945ec9c1a43f023e62718f93204 Mon Sep 17 00:00:00 2001 From: anjakefala Date: Tue, 19 Jul 2022 14:50:22 -0700 Subject: [PATCH 1/2] ARROW-17131: [Python] add StructType().field(): returns a field by name or index --- python/pyarrow/tests/test_types.py | 10 ++++++ python/pyarrow/types.pxi | 55 ++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 8cb7cea684274..cabf69ed07af0 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -577,14 +577,24 @@ def test_struct_type(): assert ty['b'] == ty[2] + assert ty['b'] == ty.field('b') + + assert ty[2] == ty.field(2) + # Not found with pytest.raises(KeyError): ty['c'] + with pytest.raises(KeyError): + ty.field('c') + # Neither integer nor string with pytest.raises(TypeError): ty[None] + with pytest.raises(TypeError): + ty.field(None) + for a, b in zip(ty, fields): a == b diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 8407f95c984c3..1dae52f2fef81 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -429,12 +429,23 @@ cdef class StructType(DataType): Examples -------- >>> import pyarrow as pa + + Accessing fields using direct indexing: + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) >>> struct_type[0] pyarrow.Field<x: int32> >>> struct_type['y'] pyarrow.Field<y: string> + Accessing fields using ``field()``: + + >>> struct_type.field(1) + pyarrow.Field<y: string> + >>> struct_type.field('x') + pyarrow.Field<x: int32> + + # Creating a schema from the struct type's fields: >>> pa.schema(list(struct_type)) x: int32 y: string @@ -494,6 +505,41 @@ cdef class StructType(DataType): """ return self.struct_type.GetFieldIndex(tobytes(name)) + def field(self, i): + """ + Select a field by its column name or numeric index. 
+ + Parameters + ---------- + i : int or str + + Returns + ------- + pyarrow.Field + + Examples + -------- + + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + + Select the second field: + + >>> struct_type.field(1) + pyarrow.Field<y: string> + + Select the field named 'x': + + >>> struct_type.field('x') + pyarrow.Field<x: int32> + """ + if isinstance(i, (bytes, str)): + return self.field_by_name(i) + elif isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer or string index') + def get_all_field_indices(self, name): """ Return sorted list of indices for the fields with the given name. @@ -525,13 +571,10 @@ cdef class StructType(DataType): def __getitem__(self, i): """ Return the struct field with the given index or name. + + Alias of ``field``. """ - if isinstance(i, (bytes, str)): - return self.field_by_name(i) - elif isinstance(i, int): - return self.field(i) - else: - raise TypeError('Expected integer or string index') + return self.field(i) def __reduce__(self): return struct, (list(self),) From add10bdc632f7a064120489fe37f3d950cd4683b Mon Sep 17 00:00:00 2001 From: anjakefala Date: Wed, 20 Jul 2022 12:04:50 -0700 Subject: [PATCH 2/2] ARROW-17131: [Python] add public-facing field() to UnionType --- python/pyarrow/tests/test_types.py | 1 + python/pyarrow/types.pxi | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index cabf69ed07af0..0ef9f5a86ec6f 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -644,6 +644,7 @@ def test_union_type(): def check_fields(ty, fields): assert ty.num_fields == len(fields) assert [ty[i] for i in range(ty.num_fields)] == fields + assert [ty.field(i) for i in range(ty.num_fields)] == fields fields = [pa.field('x', pa.list_(pa.int32())), pa.field('y', pa.binary())] diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 
1dae52f2fef81..1babbc41549c7 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -622,9 +622,28 @@ cdef class UnionType(DataType): for i in range(len(self)): yield self[i] + def field(self, i): + """ + Return a child field by its numeric index. + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Field + """ + if isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer') + def __getitem__(self, i): """ Return a child field by its index. + + Alias of ``field``. """ return self.field(i)