diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index e435dea10a..b5b26cf8f5 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -15,6 +15,7 @@
     List,
     Mapping,
     Optional,
+    Protocol,
     Tuple,
     Union,
     overload,
@@ -34,7 +35,7 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pyarrow.fs as pa_fs
-from pyarrow.lib import RecordBatchReader
+from pyarrow import RecordBatchReader
 
 from ._internal import DeltaDataChecker as _DeltaDataChecker
 from ._internal import batch_distinct
@@ -70,6 +71,17 @@
 DEFAULT_DATA_SKIPPING_NUM_INDEX_COLS = 32
 
 
+class ArrowStreamExportable(Protocol):
+    """Type hint for objects exporting an Arrow C Stream via the Arrow PyCapsule Interface.
+
+    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+    """
+
+    def __arrow_c_stream__(
+        self, requested_schema: Optional[object] = None
+    ) -> object: ...
+
+
 @dataclass
 class AddAction:
     path: str
@@ -90,6 +102,7 @@ def write_deltalake(
         pa.RecordBatch,
         Iterable[pa.RecordBatch],
         RecordBatchReader,
+        ArrowStreamExportable,
     ],
     *,
     schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
@@ -123,6 +136,7 @@ def write_deltalake(
         pa.RecordBatch,
         Iterable[pa.RecordBatch],
         RecordBatchReader,
+        ArrowStreamExportable,
     ],
     *,
     schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
@@ -150,6 +164,7 @@ def write_deltalake(
         pa.RecordBatch,
         Iterable[pa.RecordBatch],
         RecordBatchReader,
+        ArrowStreamExportable,
     ],
     *,
     schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
@@ -177,6 +192,7 @@ def write_deltalake(
         pa.RecordBatch,
         Iterable[pa.RecordBatch],
         RecordBatchReader,
+        ArrowStreamExportable,
     ],
     *,
     schema: Optional[Union[pa.Schema, DeltaSchema]] = None,
@@ -285,12 +301,26 @@ def write_deltalake(
             data = convert_pyarrow_table(
                 pa.Table.from_pandas(data), large_dtypes=large_dtypes
             )
+    elif hasattr(data, "__arrow_c_array__"):
+        data = convert_pyarrow_recordbatch(
+            pa.record_batch(data),  # type:ignore[attr-defined]
+            large_dtypes,
+        )
+    elif hasattr(data, "__arrow_c_stream__"):
+        if not hasattr(RecordBatchReader, "from_stream"):
+            raise ValueError(
+                "pyarrow 15 or later required to read stream via pycapsule interface"
+            )
+
+        data = convert_pyarrow_recordbatchreader(
+            RecordBatchReader.from_stream(data), large_dtypes
+        )
     elif isinstance(data, Iterable):
         if schema is None:
             raise ValueError("You must provide schema if data is Iterable")
     else:
         raise TypeError(
-            f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for source."
+            f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset, Pandas DataFrame, or objects implementing the Arrow PyCapsule Interface are valid inputs for source."
         )
 
     if schema is None:
@@ -443,7 +473,7 @@ def visitor(written_file: Any) -> None:
             raise DeltaProtocolError(
                 "This table's min_writer_version is "
                 f"{table_protocol.min_writer_version}, "
-                f"""but this method only supports version 2 or 7 with at max these features {SUPPORTED_WRITER_FEATURES} enabled. 
+                f"""but this method only supports version 2 or 7 with at max these features {SUPPORTED_WRITER_FEATURES} enabled.
                 Try engine='rust' instead which supports more features and writer versions."""
             )
         if (
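
For reference, a minimal usage sketch of the new code path (not part of the diff): it assumes this patch is installed as `deltalake` and pyarrow >= 15 is available, and the `MyStreamSource` class and the table path are hypothetical, standing in for any third-party Arrow producer that only exposes `__arrow_c_stream__`. Objects that expose only `__arrow_c_array__` go through the analogous branch via `pa.record_batch(data)`.

```python
import pyarrow as pa
from deltalake import write_deltalake


class MyStreamSource:
    """Hypothetical stand-in for a third-party Arrow producer that only
    implements the Arrow PyCapsule stream protocol (__arrow_c_stream__)."""

    def __init__(self, table: pa.Table) -> None:
        self._table = table

    def __arrow_c_stream__(self, requested_schema=None):
        # Delegate capsule export to pyarrow (Table supports this in pyarrow >= 14).
        return self._table.__arrow_c_stream__(requested_schema)


source = MyStreamSource(pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]}))

# With this change, write_deltalake detects __arrow_c_stream__ and reads the
# data via RecordBatchReader.from_stream (hence the pyarrow >= 15 check above),
# instead of rejecting the input with a TypeError.
write_deltalake("./example_table", source)
```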