diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 358a09aa19..e9ea833c36 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t
 
 cdef extern from "adbc.h" nogil:
     # C ABI
+
+    ctypedef void (*CArrowSchemaRelease)(void*)
+    ctypedef void (*CArrowArrayRelease)(void*)
+
     cdef struct CArrowSchema"ArrowSchema":
-        pass
+        CArrowSchemaRelease release
+
     cdef struct CArrowArray"ArrowArray":
-        pass
+        CArrowArrayRelease release
 
     ctypedef int (*CArrowArrayStreamGetLastError)(void*)
     ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ced8870ec9..91139100bb 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,15 @@ import threading
 import typing
 from typing import List, Tuple
 
+cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport (
+    PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+)
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
-from libc.string cimport memset
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy, memset
 from libcpp.vector cimport vector as c_vector
 
 if typing.TYPE_CHECKING:
@@ -304,9 +309,29 @@ cdef class _AdbcHandle:
                     f"with open {self._child_type}")
 
 
+cdef void pycapsule_schema_deleter(object capsule) noexcept:
+    cdef CArrowSchema* allocated = <CArrowSchema*> PyCapsule_GetPointer(
+        capsule, "arrow_schema"
+    )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
+cdef void pycapsule_stream_deleter(object capsule) noexcept:
+    cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
+        capsule, "arrow_array_stream"
+    )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
 cdef class ArrowSchemaHandle:
     """
     A wrapper for an allocated ArrowSchema.
+
+    This object implements the Arrow PyCapsule interface.
     """
     cdef:
         CArrowSchema schema
@@ -316,23 +341,42 @@ cdef class ArrowSchemaHandle:
         """The address of the ArrowSchema."""
         return <uintptr_t> &self.schema
 
+    def __arrow_c_schema__(self) -> object:
+        """Consume this object to get a PyCapsule."""
+        # Reference:
+        # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
+        cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
+        allocated.release = NULL
+        capsule = PyCapsule_New(
+            <void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
+        )
+        memcpy(allocated, &self.schema, sizeof(CArrowSchema))
+        self.schema.release = NULL
+        return capsule
+
 
 cdef class ArrowArrayHandle:
     """
     A wrapper for an allocated ArrowArray.
+
+    This object implements the Arrow PyCapsule interface.
     """
     cdef:
         CArrowArray array
 
     @property
     def address(self) -> int:
-        """The address of the ArrowArray."""
+        """
+        The address of the ArrowArray.
+        """
         return <uintptr_t> &self.array
 
 
 cdef class ArrowArrayStreamHandle:
     """
     A wrapper for an allocated ArrowArrayStream.
+
+    This object implements the Arrow PyCapsule interface.
""" cdef: CArrowArrayStream stream @@ -342,6 +386,21 @@ cdef class ArrowArrayStreamHandle: """The address of the ArrowArrayStream.""" return &self.stream + def __arrow_c_stream__(self, requested_schema=None) -> object: + """Consume this object to get a PyCapsule.""" + if requested_schema is not None: + raise NotImplementedError("requested_schema") + + cdef CArrowArrayStream* allocated = \ + malloc(sizeof(CArrowArrayStream)) + allocated.release = NULL + capsule = PyCapsule_New( + allocated, "arrow_array_stream", &pycapsule_stream_deleter, + ) + memcpy(allocated, &self.stream, sizeof(CArrowArrayStream)) + self.stream.release = NULL + return capsule + class GetObjectsDepth(enum.IntEnum): ALL = ADBC_OBJECT_DEPTH_ALL @@ -1000,32 +1059,47 @@ cdef class AdbcStatement(_AdbcHandle): connection._open_child() - def bind(self, data, schema) -> None: + def bind(self, data, schema=None) -> None: """ Bind an ArrowArray to this statement. Parameters ---------- - data : int or ArrowArrayHandle - schema : int or ArrowSchemaHandle + data : PyCapsule or int or ArrowArrayHandle + schema : PyCapsule or int or ArrowSchemaHandle """ cdef CAdbcError c_error = empty_error() cdef CArrowArray* c_array cdef CArrowSchema* c_schema - if isinstance(data, ArrowArrayHandle): + if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle): + if schema is not None: + raise ValueError( + "Can not provide a schema when passing Arrow-compatible " + "data that implements the Arrow PyCapsule Protocol" + ) + schema, data = data.__arrow_c_array__() + + if PyCapsule_CheckExact(data): + c_array = PyCapsule_GetPointer(data, "arrow_array") + elif isinstance(data, ArrowArrayHandle): c_array = &( data).array elif isinstance(data, int): c_array = data else: - raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}") - - if isinstance(schema, ArrowSchemaHandle): + raise TypeError( + "data must be Arrow-compatible data (implementing the Arrow PyCapsule " + f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}" + ) + + if PyCapsule_CheckExact(schema): + c_schema = PyCapsule_GetPointer(schema, "arrow_schema") + elif isinstance(schema, ArrowSchemaHandle): c_schema = &( schema).schema elif isinstance(schema, int): c_schema = schema else: - raise TypeError(f"schema must be int or ArrowSchemaHandle, " + raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, " f"not {type(schema)}") with nogil: @@ -1042,17 +1116,27 @@ cdef class AdbcStatement(_AdbcHandle): Parameters ---------- - stream : int or ArrowArrayStreamHandle + stream : PyCapsule or int or ArrowArrayStreamHandle """ cdef CAdbcError c_error = empty_error() cdef CArrowArrayStream* c_stream - if isinstance(stream, ArrowArrayStreamHandle): + if ( + hasattr(stream, "__arrow_c_stream__") + and not isinstance(stream, ArrowArrayStreamHandle) + ): + stream = stream.__arrow_c_stream__() + + if PyCapsule_CheckExact(stream): + c_stream = PyCapsule_GetPointer( + stream, "arrow_array_stream" + ) + elif isinstance(stream, ArrowArrayStreamHandle): c_stream = &( stream).stream elif isinstance(stream, int): c_stream = stream else: - raise TypeError(f"data must be int or ArrowArrayStreamHandle, " + raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, " f"not {type(stream)}") with nogil: diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 8edcdf4f58..4c36ad5cbd 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ 
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,21 @@ def close(self):
         self._closed = True
 
     def _bind(self, parameters) -> None:
-        if isinstance(parameters, pyarrow.RecordBatch):
+        if hasattr(parameters, "__arrow_c_array__"):
+            self._stmt.bind(parameters)
+        elif hasattr(parameters, "__arrow_c_stream__"):
+            self._stmt.bind_stream(parameters)
+        elif isinstance(parameters, pyarrow.RecordBatch):
             arr_handle = _lib.ArrowArrayHandle()
             sch_handle = _lib.ArrowSchemaHandle()
             parameters._export_to_c(arr_handle.address, sch_handle.address)
             self._stmt.bind(arr_handle, sch_handle)
-            return
-        if isinstance(parameters, pyarrow.Table):
-            parameters = parameters.to_reader()
-        stream_handle = _lib.ArrowArrayStreamHandle()
-        parameters._export_to_c(stream_handle.address)
-        self._stmt.bind_stream(stream_handle)
+        else:
+            if isinstance(parameters, pyarrow.Table):
+                parameters = parameters.to_reader()
+            stream_handle = _lib.ArrowArrayStreamHandle()
+            parameters._export_to_c(stream_handle.address)
+            self._stmt.bind_stream(stream_handle)
 
     def _prepare_execute(self, operation, parameters=None) -> None:
         self._results = None
@@ -639,9 +643,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
             # Not all drivers support it
             pass
 
-        if isinstance(
-            parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
-        ):
+        if _is_arrow_data(parameters):
            self._bind(parameters)
         elif parameters:
             rb = pyarrow.record_batch(
@@ -668,7 +670,6 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
         self._prepare_execute(operation, parameters)
         handle, self._rowcount = self._stmt.execute_query()
         self._results = _RowIterator(
-            # pyarrow.RecordBatchReader._import_from_c(handle.address)
             _reader.AdbcRecordBatchReader._import_from_c(handle.address)
         )
 
@@ -683,7 +684,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         operation : bytes or str
             The query to execute. Pass SQL queries as strings,
             (serialized) Substrait plans as bytes.
-        parameters
+        seq_of_parameters
             Parameters to bind. Can be a list of Python sequences, or
             an Arrow record batch, table, or record batch reader. If
             None, then the query will be executed once, else it will
@@ -695,10 +696,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         self._stmt.set_sql_query(operation)
         self._stmt.prepare()
 
-        if isinstance(
-            seq_of_parameters,
-            (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
-        ):
+        if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
             arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -806,7 +804,10 @@ def adbc_ingest(
         table_name
             The table to insert into.
         data
-            The Arrow data to insert.
+            The Arrow data to insert. This can be a pyarrow RecordBatch, Table
+            or RecordBatchReader, or any Arrow-compatible data that implements
+            the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
+            or ``__arrow_c_stream__`` method).
         mode
             How to deal with existing data:
 
@@ -878,7 +879,11 @@ def adbc_ingest(
         except NotSupportedError:
             pass
 
-        if isinstance(data, pyarrow.RecordBatch):
+        if hasattr(data, "__arrow_c_array__"):
+            self._stmt.bind(data)
+        elif hasattr(data, "__arrow_c_stream__"):
+            self._stmt.bind_stream(data)
+        elif isinstance(data, pyarrow.RecordBatch):
             array = _lib.ArrowArrayHandle()
             schema = _lib.ArrowSchemaHandle()
             data._export_to_c(array.address, schema.address)
@@ -1151,3 +1156,13 @@ def _warn_unclosed(name):
         category=ResourceWarning,
         stacklevel=2,
     )
+
+
+def _is_arrow_data(data):
+    return (
+        hasattr(data, "__arrow_c_array__")
+        or hasattr(data, "__arrow_c_stream__")
+        or isinstance(
+            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
+        )
+    )
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 52b8e1316c..20990eff43 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
     assert sqlite.adbc_get_table_types() == ["table", "view"]
 
 
+class ArrayWrapper:
+    def __init__(self, array):
+        self.array = array
+
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+    def __init__(self, stream):
+        self.stream = stream
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
 @pytest.mark.parametrize(
     "data",
     [
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
         lambda: pyarrow.table(
             [[1, 2], ["foo", ""]], names=["ints", "strs"]
         ).to_reader(),
+        lambda: ArrayWrapper(
+            pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
+        lambda: StreamWrapper(
+            pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
     ],
 )
 @pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
         (1.0, 2),
         pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
         pyarrow.table([[1.0], [2]], names=["float", "int"]),
+        ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
+        StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
     ],
 )
 def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
         pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
+        ArrayWrapper(
+            pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
+        ),
+        StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
         ((x, y) for x, y in ((1, "a"), (3, None))),
     ],
 )
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 15d98e5389..98c8721ca0 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -390,3 +390,75 @@ def test_child_tracking(sqlite):
         RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection"
     ):
         db.close()
+
+
+@pytest.mark.sqlite
+def test_pycapsule(sqlite):
+    _, conn = sqlite
+    handle = conn.get_table_types()
+    with pyarrow.RecordBatchReader._import_from_c_capsule(
+        handle.__arrow_c_stream__()
+    ) as reader:
+        reader.read_all()
+
+    # set up some data
+    data = pyarrow.record_batch(
+        [
+            [1, 2, 3, 4],
+            ["a", "b", "c", "d"],
+        ],
names=["ints", "strs"], + ) + table = pyarrow.Table.from_batches([data]) + + with adbc_driver_manager.AdbcStatement(conn) as stmt: + stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"}) + schema_capsule, array_capsule = data.__arrow_c_array__() + stmt.bind(array_capsule, schema_capsule) + stmt.execute_update() + + with adbc_driver_manager.AdbcStatement(conn) as stmt: + stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"}) + stream_capsule = data.__arrow_c_stream__() + stmt.bind_stream(stream_capsule) + stmt.execute_update() + + # importing a schema + + handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo") + assert data.schema == pyarrow.schema(handle) + # ensure consumed schema was marked as such + with pytest.raises(ValueError, match="Cannot import released ArrowSchema"): + pyarrow.schema(handle) + + # smoke test for the capsule calling release + capsule = conn.get_table_schema( + catalog=None, db_schema=None, table_name="foo" + ).__arrow_c_schema__() + del capsule + + # importing a stream + + with adbc_driver_manager.AdbcStatement(conn) as stmt: + stmt.set_sql_query("SELECT * FROM foo") + handle, _ = stmt.execute_query() + + result = pyarrow.table(handle) + assert result == table + + with adbc_driver_manager.AdbcStatement(conn) as stmt: + stmt.set_sql_query("SELECT * FROM bar") + handle, _ = stmt.execute_query() + + result = pyarrow.table(handle) + assert result == table + + # ensure consumed schema was marked as such + with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"): + pyarrow.table(handle) + + # smoke test for the capsule calling release + with adbc_driver_manager.AdbcStatement(conn) as stmt: + stmt.set_sql_query("SELECT * FROM foo") + capsule = stmt.execute_query()[0].__arrow_c_stream__() + del capsule