Skip to content

Commit

Permalink
ARROW-17131: [Python] add StructType().field(): returns a field by na…
Browse files Browse the repository at this point in the history
…me or index
  • Loading branch information
anjakefala committed Jul 20, 2022
1 parent 6e3f26a commit 10bc33e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
10 changes: 10 additions & 0 deletions python/pyarrow/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,24 @@ def test_struct_type():

assert ty['b'] == ty[2]

assert ty['b'] == ty.field('b')

assert ty[2] == ty.field(2)

# Not found
with pytest.raises(KeyError):
ty['c']

with pytest.raises(KeyError):
ty.field('c')

# Neither integer nor string
with pytest.raises(TypeError):
ty[None]

with pytest.raises(TypeError):
ty.field(None)

for a, b in zip(ty, fields):
a == b

Expand Down
53 changes: 47 additions & 6 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,23 @@ cdef class StructType(DataType):
Examples
--------
>>> import pyarrow as pa
Accessing fields using direct indexing:
>>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()})
>>> struct_type[0]
pyarrow.Field<x: int32>
>>> struct_type['y']
pyarrow.Field<y: string>
Accessing fields using ``field()``:
>>> struct_type.field(1)
pyarrow.Field<y: string>
>>> struct_type.field('x')
pyarrow.Field<x: int32>
>>> pa.schema(list(struct_type))
x: int32
y: string
Expand Down Expand Up @@ -494,6 +505,41 @@ cdef class StructType(DataType):
"""
return self.struct_type.GetFieldIndex(tobytes(name))

def field(self, i):
"""
Select a field by its column name or numeric index.
Parameters
----------
i : int or str
Returns
-------
pyarrow.Field
Examples
--------
>>> import pyarrow as pa
>>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()})
Select the second field:
>>> struct_type.field(1)
pyarrow.Field<y: string>
Select the field of the column named 'x':
>>> struct_type.field('x')
pyarrow.Field<x: int32>
"""
if isinstance(i, (bytes, str)):
return self.field_by_name(i)
elif isinstance(i, int):
return DataType.field(self, i)
else:
raise TypeError('Expected integer or string index')

def get_all_field_indices(self, name):
"""
Return sorted list of indices for the fields with the given name.
Expand Down Expand Up @@ -526,12 +572,7 @@ cdef class StructType(DataType):
"""
Return the struct field with the given index or name.
"""
if isinstance(i, (bytes, str)):
return self.field_by_name(i)
elif isinstance(i, int):
return self.field(i)
else:
raise TypeError('Expected integer or string index')
return self.field(i)

def __reduce__(self):
return struct, (list(self),)
Expand Down

0 comments on commit 10bc33e

Please sign in to comment.