Skip to content

Commit

Permalink
Add support for converting to/from pyarrow Arrays (#23894)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNeuralBit authored Nov 1, 2022
1 parent c6f64bb commit 2bf0795
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
34 changes: 34 additions & 0 deletions sdks/python/apache_beam/typehints/arrow_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,37 @@ def _from_serialized_schema(serialized_schema):
def __reduce__(self):
return self._from_serialized_schema, (
self._beam_schema.SerializeToString(), )


class PyarrowArrayBatchConverter(BatchConverter):
def __init__(self, element_type: type):
super().__init__(pa.Array, element_type)
self._element_type = element_type
beam_fieldtype = typing_to_runner_api(element_type)
self._arrow_type = _arrow_type_from_beam_fieldtype(beam_fieldtype)

@staticmethod
@BatchConverter.register
def from_typehints(element_type,
batch_type) -> Optional['PyarrowArrayBatchConverter']:
if batch_type == pa.Array:
return PyarrowArrayBatchConverter(element_type)

return None

def produce_batch(self, elements):
return pa.array(list(elements), type=self._arrow_type)

def explode_batch(self, batch: pa.Array):
"""Convert an instance of B to Generator[E]."""
for val in batch:
yield val.as_py()

def combine_batches(self, batches: List[pa.Array]):
return pa.concat_arrays(batches)

def get_length(self, batch: pa.Array):
return batch.num_rows

def estimate_byte_size(self, batch: pa.Array):
return batch.nbytes
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,23 @@ def test_beam_schema_survives_roundtrip(self, beam_schema):
]),
}),
},
{
'batch_typehint': pa.Array,
'element_typehint': int,
'batch': pa.array(range(100), type=pa.int64()),
},
{
'batch_typehint': pa.Array,
'element_typehint': row_type.RowTypeConstraint.from_fields([
("bar", Optional[float]), # noqa: F821
("baz", Optional[str]), # noqa: F821
]),
'batch': pa.array([
{
'bar': i / 100, 'baz': str(i)
} if i % 7 else None for i in range(100)
]),
}
])
@pytest.mark.uses_pyarrow
class ArrowBatchConverterTest(unittest.TestCase):
Expand All @@ -93,7 +110,10 @@ def setUp(self):
self.element_typehint)

def equality_check(self, left, right):
self.assertEqual(left, right)
if isinstance(left, pa.Array):
self.assertTrue(left.equals(right))
else:
self.assertEqual(left, right)

def test_typehint_validates(self):
typehints.validate_composite_type_param(self.batch_typehint, '')
Expand Down

0 comments on commit 2bf0795

Please sign in to comment.