ingest data supporting the Arrow PyCapsule protocol
jorisvandenbossche committed Dec 7, 2023
1 parent 07da02c commit e0fdba2
Showing 4 changed files with 105 additions and 32 deletions.
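
In short, both the low-level statement bindings and the DB-API layer now accept any object implementing the Arrow PyCapsule protocol (an object with an __arrow_c_array__ or __arrow_c_stream__ method), not only pyarrow classes. A minimal usage sketch of what this enables follows; the SQLite driver, the in-memory URI, the table name, and the MyStream wrapper are illustrative assumptions, not part of this commit:

import pyarrow
import adbc_driver_sqlite.dbapi  # assumed installed; any ADBC driver should work


class MyStream:
    # Stand-in for third-party data (e.g. a non-pyarrow dataframe) that
    # exposes the Arrow PyCapsule protocol via __arrow_c_stream__.
    def __init__(self, table):
        self._table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self._table.__arrow_c_stream__(requested_schema=requested_schema)


with adbc_driver_sqlite.dbapi.connect(":memory:") as conn:
    cursor = conn.cursor()
    data = MyStream(pyarrow.table({"ints": [1, 2], "strs": ["foo", ""]}))
    # Before this commit, `data` had to be a pyarrow RecordBatch, Table or
    # RecordBatchReader; now the capsules are consumed directly.
    cursor.adbc_ingest("example", data, mode="create")
    conn.commit()
    cursor.close()
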
31 changes: 21 additions & 10 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,7 +27,9 @@ from typing import List, Tuple
cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
from cpython.pycapsule cimport (
PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
)
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
@@ -1086,26 +1088,31 @@ cdef class AdbcStatement(_AdbcHandle):

Parameters
----------
data : int or ArrowArrayHandle
schema : int or ArrowSchemaHandle
data : PyCapsule or int or ArrowArrayHandle
schema : PyCapsule or int or ArrowSchemaHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema

if isinstance(data, ArrowArrayHandle):
if PyCapsule_CheckExact(data):
c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
elif isinstance(data, int):
c_array = <CArrowArray*> data
else:
raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")
raise TypeError(
f"data must be a PyCapsule, int or ArrowArrayHandle, not {type(data)}")

if isinstance(schema, ArrowSchemaHandle):
if PyCapsule_CheckExact(schema):
c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
elif isinstance(schema, int):
c_schema = <CArrowSchema*> schema
else:
raise TypeError(f"schema must be int or ArrowSchemaHandle, "
raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
f"not {type(schema)}")

with nogil:
@@ -1122,17 +1129,21 @@ cdef class AdbcStatement(_AdbcHandle):

Parameters
----------
stream : int or ArrowArrayStreamHandle
stream : PyCapsule or int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream

if isinstance(stream, ArrowArrayStreamHandle):
if PyCapsule_CheckExact(stream):
c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
stream, "arrow_array_stream"
)
elif isinstance(stream, ArrowArrayStreamHandle):
c_stream = &(<ArrowArrayStreamHandle> stream).stream
elif isinstance(stream, int):
c_stream = <CArrowArrayStream*> stream
else:
raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
f"not {type(stream)}")

with nogil:
52 changes: 35 additions & 17 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,22 @@ def close(self):
self._closed = True

def _bind(self, parameters) -> None:
if isinstance(parameters, pyarrow.RecordBatch):
if hasattr(parameters, "__arrow_c_array__"):
schema_capsule, array_capsule = parameters.__arrow_c_array__()
self._stmt.bind(array_capsule, schema_capsule)
elif hasattr(parameters, "__arrow_c_stream__"):
self._stmt.bind_stream(parameters.__arrow_c_stream__())
elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
parameters._export_to_c(arr_handle.address, sch_handle.address)
self._stmt.bind(arr_handle, sch_handle)
return
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)
else:
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)

def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -639,9 +644,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
# Not all drivers support it
pass

if isinstance(
parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
):
if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
rb = pyarrow.record_batch(
@@ -682,7 +685,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
operation : bytes or str
The query to execute. Pass SQL queries as strings,
(serialized) Substrait plans as bytes.
parameters
seq_of_parameters
Parameters to bind. Can be a list of Python sequences, or
an Arrow record batch, table, or record batch reader. If
None, then the query will be executed once, else it will
@@ -694,10 +697,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
self._stmt.set_sql_query(operation)
self._stmt.prepare()

if isinstance(
seq_of_parameters,
(pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
):
if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -805,7 +805,10 @@ def adbc_ingest(
table_name
The table to insert into.
data
The Arrow data to insert.
The Arrow data to insert. This can be a pyarrow RecordBatch, Table
or RecordBatchReader, or any Arrow-compatible data that implements
the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
or ``__arrow_c_stream__`` method).
mode
How to deal with existing data:
@@ -877,7 +880,12 @@ def adbc_ingest(
except NotSupportedError:
pass

if isinstance(data, pyarrow.RecordBatch):
if hasattr(data, "__arrow_c_array__"):
schema_capsule, array_capsule = data.__arrow_c_array__()
self._stmt.bind(array_capsule, schema_capsule)
elif hasattr(data, "__arrow_c_stream__"):
self._stmt.bind_stream(data.__arrow_c_stream__())
elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
data._export_to_c(array.address, schema.address)
@@ -1150,3 +1158,13 @@ def _warn_unclosed(name):
category=ResourceWarning,
stacklevel=2,
)


def _is_arrow_data(data):
return (
hasattr(data, "__arrow_c_array__")
or hasattr(data, "__arrow_c_stream__")
or isinstance(
data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
)
)
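
The same dispatch applies to query parameters: _bind() and _is_arrow_data() above check for the protocol methods before falling back to the pyarrow isinstance checks, so execute() and executemany() can bind PyCapsule-protocol objects directly. A hedged sketch; the ParamBatch wrapper and the already-open connection `conn` are assumptions for illustration, mirroring the test cases below:

import pyarrow


class ParamBatch:
    # Hypothetical wrapper: anything exposing __arrow_c_array__ is accepted
    # as bound parameters, without converting to a pyarrow RecordBatch first.
    def __init__(self, batch):
        self._batch = batch

    def __arrow_c_array__(self, requested_schema=None):
        return self._batch.__arrow_c_array__(requested_schema=requested_schema)


# `conn` is assumed to be an open adbc_driver_manager.dbapi.Connection,
# e.g. to the SQLite driver used by the tests below.
cursor = conn.cursor()
params = ParamBatch(pyarrow.record_batch([[1.0], [2]], names=["float", "int"]))
cursor.execute("SELECT ?, ?", parameters=params)
print(cursor.fetchall())  # e.g. [(1.0, 2)]
cursor.close()
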
28 changes: 28 additions & 0 deletions python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
assert sqlite.adbc_get_table_types() == ["table", "view"]


class ArrayWrapper:
def __init__(self, array):
self.array = array

def __arrow_c_array__(self, requested_schema=None):
return self.array.__arrow_c_array__(requested_schema=requested_schema)


class StreamWrapper:
def __init__(self, stream):
self.stream = stream

def __arrow_c_stream__(self, requested_schema=None):
return self.stream.__arrow_c_stream__(requested_schema=requested_schema)


@pytest.mark.parametrize(
"data",
[
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
lambda: pyarrow.table(
[[1, 2], ["foo", ""]], names=["ints", "strs"]
).to_reader(),
lambda: ArrayWrapper(
pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
lambda: StreamWrapper(
pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
],
)
@pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
(1.0, 2),
pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
pyarrow.table([[1.0], [2]], names=["float", "int"]),
ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
],
)
def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
ArrayWrapper(
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
),
StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
((x, y) for x, y in ((1, "a"), (3, None))),
],
)
26 changes: 21 additions & 5 deletions python/adbc_driver_manager/tests/test_lowlevel.py
@@ -407,9 +407,18 @@ def test_pycapsule(sqlite):
],
names=["ints", "strs"],
)
table = pyarrow.Table.from_batches([data])

with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
_bind(stmt, data)
schema_capsule, array_capsule = data.__arrow_c_array__()
stmt.bind(array_capsule, schema_capsule)
stmt.execute_update()

with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"})
stream_capsule = data.__arrow_c_stream__()
stmt.bind_stream(stream_capsule)
stmt.execute_update()

# importing a schema
pyarrow.schema(handle)

# smoke test for the capsule calling release
capsule = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo").__arrow_c_schema__()
capsule = conn.get_table_schema(
catalog=None, db_schema=None, table_name="foo"
).__arrow_c_schema__()
del capsule

# importing a stream
handle, _ = stmt.execute_query()

result = pyarrow.table(handle)
assert result.to_batches()[0] == data
assert result == table

with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_sql_query("SELECT * FROM bar")
handle, _ = stmt.execute_query()

result = pyarrow.table(handle)
assert result == table

# ensure consumed schema was marked as such
with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
stmt.set_sql_query("SELECT * FROM foo")
capsule = stmt.execute_query()[0].__arrow_c_stream__()
del capsule

# TODO: also need to import from things supporting protocol
