From dc40e5fba1c9ace6da3de14158bb6195bed6fc58 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 8 Jan 2024 16:49:14 +0100 Subject: [PATCH] GH-39217: [Python] RecordBatchReader.from_stream constructor for objects implementing the Arrow PyCapsule protocol (#39218) ### Rationale for this change In contrast to Array, RecordBatch and Schema, for the C Stream (mapping to RecordBatchReader) we don't have an equivalent factory function that can accept any Arrow-compatible object and turn it into a pyarrow object through the PyCapsule Protocol. For that reason, this proposes an explicit constructor class method for this: `RecordBatchReader.from_stream` (this is quite a generic name, so other name suggestions are certainly welcome). ### Are these changes tested? Yes, `test_record_batch_reader_from_arrow_stream` is added to `python/pyarrow/tests/test_ipc.py`. * Closes: #39217 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/ipc.pxi | 43 +++++++++++++++++++++++++++++ python/pyarrow/tests/test_array.py | 4 +-- python/pyarrow/tests/test_ipc.py | 44 ++++++++++++++++++++++++++++++ python/pyarrow/tests/test_table.py | 12 ++++---- 4 files changed, 95 insertions(+), 8 deletions(-) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index ae52f5cf34e8b..da9636dfc86e1 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -883,6 +883,49 @@ cdef class RecordBatchReader(_Weakrefable): self.reader = c_reader return self + @staticmethod + def from_stream(data, schema=None): + """ + Create RecordBatchReader from a Arrow-compatible stream object. + + This accepts objects implementing the Arrow PyCapsule Protocol for + streams, i.e. objects that have a ``__arrow_c_stream__`` method. + + Parameters + ---------- + data : Arrow-compatible stream object + Any object that implements the Arrow PyCapsule Protocol for + streams. + schema : Schema, default None + The schema to which the stream should be casted, if supported + by the stream object. 
+ + Returns + ------- + RecordBatchReader + """ + + if not hasattr(data, "__arrow_c_stream__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "streams (i.e. having a `__arrow_c_stream__` method), " + f"got {type(data)!r}." + ) + + if schema is not None: + if not hasattr(schema, "__arrow_c_schema__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "schema (i.e. having a `__arrow_c_schema__` method), " + f"got {type(schema)!r}." + ) + requested = schema.__arrow_c_schema__() + else: + requested = None + + capsule = data.__arrow_c_stream__(requested) + return RecordBatchReader._import_from_c_capsule(capsule) + @staticmethod def from_batches(Schema schema not None, batches): """ diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index d598630dc2103..3dcbf399f3aca 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -3351,8 +3351,8 @@ class ArrayWrapper: def __init__(self, data): self.data = data - def __arrow_c_array__(self, requested_type=None): - return self.data.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return self.data.__arrow_c_array__(requested_schema) # Can roundtrip through the C array protocol arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64())) diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 450d26e3b771c..f75ec8158a9da 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1194,3 +1194,47 @@ def make_batches(): with pytest.raises(TypeError): reader = pa.RecordBatchReader.from_batches(None, batches) pass + + +def test_record_batch_reader_from_arrow_stream(): + + class StreamWrapper: + def __init__(self, batches): + self.batches = batches + + def __arrow_c_stream__(self, requested_schema=None): + reader = pa.RecordBatchReader.from_batches( + self.batches[0].schema, self.batches) 
+ return reader.__arrow_c_stream__(requested_schema) + + data = [ + pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']), + pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']) + ] + wrapper = StreamWrapper(data) + + # Can roundtrip a pyarrow stream-like object + expected = pa.Table.from_batches(data) + reader = pa.RecordBatchReader.from_stream(expected) + assert reader.read_all() == expected + + # Can roundtrip through the wrapper. + reader = pa.RecordBatchReader.from_stream(wrapper) + assert reader.read_all() == expected + + # Passing schema works if already that schema + reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema) + assert reader.read_all() == expected + + # If schema doesn't match, raises NotImplementedError + with pytest.raises(NotImplementedError): + pa.RecordBatchReader.from_stream( + wrapper, schema=pa.schema([pa.field('a', pa.int32())]) + ) + + # Proper type errors for wrong input + with pytest.raises(TypeError): + pa.RecordBatchReader.from_stream(data[0]['a']) + + with pytest.raises(TypeError): + pa.RecordBatchReader.from_stream(expected, schema=data[0]) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index a678f521e38d5..ee036f136c77b 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -558,8 +558,8 @@ class BatchWrapper: def __init__(self, batch): self.batch = batch - def __arrow_c_array__(self, requested_type=None): - return self.batch.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return self.batch.__arrow_c_array__(requested_schema) data = pa.record_batch([ pa.array([1, 2, 3], type=pa.int64()) @@ -586,8 +586,8 @@ class BatchWrapper: def __init__(self, batch): self.batch = batch - def __arrow_c_array__(self, requested_type=None): - return self.batch.__arrow_c_array__(requested_type) + def __arrow_c_array__(self, requested_schema=None): + return 
self.batch.__arrow_c_array__(requested_schema) data = pa.record_batch([ pa.array([1, 2, 3], type=pa.int64()) @@ -615,10 +615,10 @@ class StreamWrapper: def __init__(self, batches): self.batches = batches - def __arrow_c_stream__(self, requested_type=None): + def __arrow_c_stream__(self, requested_schema=None): reader = pa.RecordBatchReader.from_batches( self.batches[0].schema, self.batches) - return reader.__arrow_c_stream__(requested_type) + return reader.__arrow_c_stream__(requested_schema) data = [ pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),