
feat(python/adbc_driver_manager): export handles and ingest data through python Arrow PyCapsule interface #1346

Merged: 10 commits, Dec 13, 2023
9 changes: 7 additions & 2 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t

cdef extern from "adbc.h" nogil:
# C ABI

ctypedef void (*CArrowSchemaRelease)(void*)
ctypedef void (*CArrowArrayRelease)(void*)

cdef struct CArrowSchema"ArrowSchema":
pass
CArrowSchemaRelease release

cdef struct CArrowArray"ArrowArray":
pass
CArrowArrayRelease release

ctypedef int (*CArrowArrayStreamGetLastError)(void*)
ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
110 changes: 97 additions & 13 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,15 @@ import threading
import typing
from typing import List, Tuple

cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.pycapsule cimport (
PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
)
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.string cimport memset
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
from libcpp.vector cimport vector as c_vector

if typing.TYPE_CHECKING:
@@ -304,9 +309,29 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")


cdef void pycapsule_schema_deleter(object capsule) noexcept:
cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
capsule, "arrow_schema"
)
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)


cdef void pycapsule_stream_deleter(object capsule) noexcept:
cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
capsule, "arrow_array_stream"
)
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)


cdef class ArrowSchemaHandle:
"""
A wrapper for an allocated ArrowSchema.

This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowSchema schema
@@ -316,23 +341,42 @@ cdef class ArrowSchemaHandle:
"""The address of the ArrowSchema."""
return <uintptr_t> &self.schema

def __arrow_c_schema__(self) -> object:
"""Consume this object to get a PyCapsule."""
# Reference:
# https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
allocated.release = NULL
capsule = PyCapsule_New(
<void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
)
memcpy(allocated, &self.schema, sizeof(CArrowSchema))
self.schema.release = NULL
return capsule
Comment on lines +353 to +355 (Member Author):
We are "moving" the schema here, while in nanoarrow I opted for a hard copy for the schema (using nanoarrow's ArrowSchemaDeepCopy).

But I think the only advantage of a hard copy is that this means you can consume it multiple times? (or in the case of nanoarrow-python, that the nanoarrow Schema object is still valid and inspectable after it has been converted to eg a pyarrow.Schema)
For ADBC, I think the use case will be much more "receive handle and convert it directly once", given that the Handle object itself isn't useful at all (in contrast to nanoarrow.Schema), so moving here is probably fine?

Member:
I think moving makes sense here.
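
For illustration, a minimal sketch of what the move semantics mean for a consumer (the `handle` variable is hypothetical here, and this assumes a pyarrow version that accepts objects implementing `__arrow_c_schema__`, i.e. 14.0 or newer):

```python
import pyarrow as pa

# `handle` is an ArrowSchemaHandle previously obtained from the driver manager.
schema = pa.schema(handle)  # consumes the capsule; ownership moves to pyarrow

# The handle's internal ArrowSchema was marked as released when it was
# exported, so converting the same handle a second time fails instead of
# yielding a second valid schema.
```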



cdef class ArrowArrayHandle:
"""
A wrapper for an allocated ArrowArray.

This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArray array

@property
def address(self) -> int:
"""The address of the ArrowArray."""
"""
The address of the ArrowArray.
"""
return <uintptr_t> &self.array


cdef class ArrowArrayStreamHandle:
"""
A wrapper for an allocated ArrowArrayStream.

This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArrayStream stream
@@ -342,6 +386,21 @@ cdef class ArrowArrayStreamHandle:
"""The address of the ArrowArrayStream."""
return <uintptr_t> &self.stream

def __arrow_c_stream__(self, requested_schema=None) -> object:
"""Consume this object to get a PyCapsule."""
if requested_schema is not None:
raise NotImplementedError("requested_schema")

cdef CArrowArrayStream* allocated = \
<CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
allocated.release = NULL
capsule = PyCapsule_New(
<void*>allocated, "arrow_array_stream", &pycapsule_stream_deleter,
)
memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
self.stream.release = NULL
return capsule


class GetObjectsDepth(enum.IntEnum):
ALL = ADBC_OBJECT_DEPTH_ALL
@@ -1000,32 +1059,47 @@ cdef class AdbcStatement(_AdbcHandle):

connection._open_child()

def bind(self, data, schema) -> None:
def bind(self, data, schema=None) -> None:
"""
Bind an ArrowArray to this statement.

Parameters
----------
data : int or ArrowArrayHandle
schema : int or ArrowSchemaHandle
data : PyCapsule or int or ArrowArrayHandle
schema : PyCapsule or int or ArrowSchemaHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema

if isinstance(data, ArrowArrayHandle):
if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle):
if schema is not None:
raise ValueError(
"Can not provide a schema when passing Arrow-compatible "
"data that implements the Arrow PyCapsule Protocol"
)
schema, data = data.__arrow_c_array__()

if PyCapsule_CheckExact(data):
c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
elif isinstance(data, int):
c_array = <CArrowArray*> data
else:
raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")

if isinstance(schema, ArrowSchemaHandle):
raise TypeError(
"data must be Arrow-compatible data (implementing the Arrow PyCapsule "
f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}"
)

if PyCapsule_CheckExact(schema):
c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
elif isinstance(schema, int):
c_schema = <CArrowSchema*> schema
else:
raise TypeError(f"schema must be int or ArrowSchemaHandle, "
raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
f"not {type(schema)}")

with nogil:
@@ -1042,17 +1116,27 @@ cdef class AdbcStatement(_AdbcHandle):

Parameters
----------
stream : int or ArrowArrayStreamHandle
stream : PyCapsule or int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream

if isinstance(stream, ArrowArrayStreamHandle):
if (
hasattr(stream, "__arrow_c_stream__")
and not isinstance(stream, ArrowArrayStreamHandle)
):
stream = stream.__arrow_c_stream__()

if PyCapsule_CheckExact(stream):
c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
stream, "arrow_array_stream"
)
elif isinstance(stream, ArrowArrayStreamHandle):
c_stream = &(<ArrowArrayStreamHandle> stream).stream
elif isinstance(stream, int):
c_stream = <CArrowArrayStream*> stream
else:
raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
f"not {type(stream)}")

with nogil:
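To make the new `bind`/`bind_stream` paths above concrete, a minimal usage sketch (`stmt` stands for an already-created `AdbcStatement` and is hypothetical here; assumes pyarrow >= 14.0, whose `RecordBatch` and `Table` implement the PyCapsule protocol):

```python
import pyarrow as pa

batch = pa.record_batch([[1, 2], ["a", "b"]], names=["ints", "strs"])
table = pa.table([[1, 2], ["a", "b"]], names=["ints", "strs"])

# Anything exposing __arrow_c_array__ can be bound directly; the schema is
# taken from the returned capsule pair, so also passing `schema=` raises.
stmt.bind(batch)

# Anything exposing __arrow_c_stream__ goes through bind_stream().
stmt.bind_stream(table)
```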
51 changes: 33 additions & 18 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,21 @@ def close(self):
self._closed = True

def _bind(self, parameters) -> None:
if isinstance(parameters, pyarrow.RecordBatch):
if hasattr(parameters, "__arrow_c_array__"):
self._stmt.bind(parameters)
elif hasattr(parameters, "__arrow_c_stream__"):
self._stmt.bind_stream(parameters)
elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
parameters._export_to_c(arr_handle.address, sch_handle.address)
self._stmt.bind(arr_handle, sch_handle)
return
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)
else:
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)

def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -639,9 +643,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
# Not all drivers support it
pass

if isinstance(
parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
):
if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
rb = pyarrow.record_batch(
@@ -668,7 +670,6 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()
self._results = _RowIterator(
# pyarrow.RecordBatchReader._import_from_c(handle.address)
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
)

@@ -683,7 +684,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
operation : bytes or str
The query to execute. Pass SQL queries as strings,
(serialized) Substrait plans as bytes.
parameters
seq_of_parameters
Parameters to bind. Can be a list of Python sequences, or
an Arrow record batch, table, or record batch reader. If
None, then the query will be executed once, else it will
@@ -695,10 +696,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
self._stmt.set_sql_query(operation)
self._stmt.prepare()

if isinstance(
seq_of_parameters,
(pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
):
if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -806,7 +804,10 @@ def adbc_ingest(
table_name
The table to insert into.
data
The Arrow data to insert.
The Arrow data to insert. This can be a pyarrow RecordBatch, Table
or RecordBatchReader, or any Arrow-compatible data that implements
the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
or ``__arrow_c_stream__`` method).
mode
How to deal with existing data:

@@ -878,7 +879,11 @@ def adbc_ingest(
except NotSupportedError:
pass

if isinstance(data, pyarrow.RecordBatch):
if hasattr(data, "__arrow_c_array__"):
self._stmt.bind(data)
elif hasattr(data, "__arrow_c_stream__"):
self._stmt.bind_stream(data)
elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
data._export_to_c(array.address, schema.address)
@@ -1151,3 +1156,13 @@ def _warn_unclosed(name):
category=ResourceWarning,
stacklevel=2,
)


def _is_arrow_data(data):
return (
hasattr(data, "__arrow_c_array__")
or hasattr(data, "__arrow_c_stream__")
or isinstance(
data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
)
)
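
At the DB-API level, a sketch of how an arbitrary PyCapsule-producing object can now be ingested (the SQLite driver and the in-memory connection are illustrative assumptions; the wrapper mirrors the test helpers below):

```python
import pyarrow as pa
from adbc_driver_sqlite import dbapi


class StreamWrapper:
    """Stand-in for any third-party object exposing __arrow_c_stream__."""

    def __init__(self, table):
        self.table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self.table.__arrow_c_stream__(requested_schema=requested_schema)


with dbapi.connect() as conn:
    with conn.cursor() as cur:
        data = StreamWrapper(pa.table({"ints": [1, 2, 3]}))
        # Dispatches to bind_stream() via the __arrow_c_stream__ hook.
        cur.adbc_ingest("my_table", data, mode="create")
    conn.commit()
```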
28 changes: 28 additions & 0 deletions python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
assert sqlite.adbc_get_table_types() == ["table", "view"]


class ArrayWrapper:
def __init__(self, array):
self.array = array

def __arrow_c_array__(self, requested_schema=None):
return self.array.__arrow_c_array__(requested_schema=requested_schema)


class StreamWrapper:
def __init__(self, stream):
self.stream = stream

def __arrow_c_stream__(self, requested_schema=None):
return self.stream.__arrow_c_stream__(requested_schema=requested_schema)


@pytest.mark.parametrize(
"data",
[
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
lambda: pyarrow.table(
[[1, 2], ["foo", ""]], names=["ints", "strs"]
).to_reader(),
lambda: ArrayWrapper(
pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
lambda: StreamWrapper(
pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
],
)
@pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
(1.0, 2),
pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
pyarrow.table([[1.0], [2]], names=["float", "int"]),
ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
],
)
def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
ArrayWrapper(
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
),
StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
((x, y) for x, y in ((1, "a"), (3, None))),
],
)