Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python bindings: add an ogr.Layer.WriteArrow() method consuming __arrow_c_stream__ or __arrow_c_array__ interfaces #9133

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions autotest/ogr/ogr_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,66 @@ def test_ogr_mem_arrow_stream_pycapsule_interface():
del stream


###############################################################################
# Test consuming __arrow_c_stream__() interface.
# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html


@gdaltest.enable_exceptions()
def test_ogr_mem_consume_arrow_stream_pycapsule_interface():

    mem_ds = ogr.GetDriverByName("Memory").CreateDataSource("")

    # Source layer: created without a default geometry field, then given an
    # explicitly named geometry field and one attribute field.
    src_lyr = mem_ds.CreateLayer("foo", geom_type=ogr.wkbNone)
    src_lyr.CreateGeomField(ogr.GeomFieldDefn("my_geometry"))
    src_lyr.CreateField(ogr.FieldDefn("foo"))

    src_feature = ogr.Feature(src_lyr.GetLayerDefn())
    src_feature["foo"] = "bar"
    src_feature.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)"))
    src_lyr.CreateFeature(src_feature)

    # The source layer itself is handed to WriteArrow() as the Arrow producer.
    dst_lyr = mem_ds.CreateLayer("foo2")
    dst_lyr.WriteArrow(src_lyr)

    # The attribute and the geometry must both have been carried over.
    dst_feature = dst_lyr.GetNextFeature()
    assert dst_feature["foo"] == "bar"
    assert dst_feature.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)"


###############################################################################
# Test consuming __arrow_c_array__() interface.
# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html


@gdaltest.enable_exceptions()
def test_ogr_mem_consume_arrow_array_pycapsule_interface():
    pa = pytest.importorskip("pyarrow")
    pa_major_version = int(pa.__version__.split(".")[0])
    if pa_major_version < 14:
        pytest.skip("pyarrow >= 14 needed")

    mem_ds = ogr.GetDriverByName("Memory").CreateDataSource("")

    # Source layer holding a single feature with one attribute and a point.
    src_lyr = mem_ds.CreateLayer("foo")
    src_lyr.CreateField(ogr.FieldDefn("foo"))
    src_feature = ogr.Feature(src_lyr.GetLayerDefn())
    src_feature["foo"] = "bar"
    src_feature.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)"))
    src_lyr.CreateFeature(src_feature)

    # Materialize the layer as a pyarrow table.
    pa_table = pa.table(src_lyr)

    # Feed each record batch back to OGR as a struct array, so that the
    # __arrow_c_array__ code path of WriteArrow() is the one exercised.
    dst_lyr = mem_ds.CreateLayer("foo2")
    for record_batch in pa_table.to_batches():
        struct_array = record_batch.to_struct_array()
        if not hasattr(struct_array, "__arrow_c_array__"):
            pytest.skip("table does not declare __arrow_c_array__")

        dst_lyr.WriteArrow(struct_array)

    dst_feature = dst_lyr.GetNextFeature()
    assert dst_feature["foo"] == "bar"
    assert dst_feature.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)"


###############################################################################


Expand Down
181 changes: 181 additions & 0 deletions swig/include/ogr.i
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,73 @@ static void ReleaseArrowArrayStreamPyCapsule(PyObject* capsule) {
}
CPLFree(stream);
}

// Decode the binary key/value metadata blob attached to an ArrowSchema into
// a CSL "KEY=VALUE" string list.
//
// Layout read by this function: an int32 number of pairs, then for each pair
// an int32 key length, the key bytes, an int32 value length and the value
// bytes. Integers are read with memcpy(), i.e. in native byte order and
// without any alignment requirement; strings are not NUL-terminated.
//
// The blob carries no overall size, so reads cannot be bounds-checked here.
// Only negative lengths can be detected: parsing then stops and the pairs
// decoded so far are returned.
//
// Returns a string list to be released with CSLDestroy(), or NULL when
// pabyMetadata is NULL or no pair could be decoded.
static char** ParseArrowMetadata(const char *pabyMetadata)
{
    char** ret = NULL;
    if( pabyMetadata == NULL )
        return ret;

    int32_t nKVP;
    memcpy(&nKVP, pabyMetadata, sizeof(int32_t));
    pabyMetadata += sizeof(int32_t);

    // A negative pair count simply results in zero iterations.
    for (int32_t i = 0; i < nKVP; ++i)
    {
        int32_t nSizeKey;
        memcpy(&nSizeKey, pabyMetadata, sizeof(int32_t));
        pabyMetadata += sizeof(int32_t);
        // A negative length would be converted to a huge size_t below.
        if( nSizeKey < 0 )
            break;
        const std::string osKey(pabyMetadata, static_cast<size_t>(nSizeKey));
        pabyMetadata += nSizeKey;

        int32_t nSizeValue;
        memcpy(&nSizeValue, pabyMetadata, sizeof(int32_t));
        pabyMetadata += sizeof(int32_t);
        if( nSizeValue < 0 )
            break;
        const std::string osValue(pabyMetadata, static_cast<size_t>(nSizeValue));
        pabyMetadata += nSizeValue;

        ret = CSLSetNameValue(ret, osKey.c_str(), osValue.c_str());
    }

    return ret;
}

// Create output fields using CreateFieldFromArrowSchema()
//
// Walks the direct children of schemaSrc and creates one field per child on
// hDstLayer, except for:
//  - children whose "ARROW:extension:name" metadata is "ogc.wkb" or
//    "geoarrow.wkb" (case-insensitive comparison),
//  - children named "OGC_FID" or "wkb_geometry" (case-insensitive).
//
// options is forwarded unchanged to OGR_L_CreateFieldFromArrowSchema().
//
// Returns true on success. On the first field that cannot be created, emits
// a CPLError() and returns false; fields created before that point are left
// in place.
static bool CreateFieldsFromArrowSchema(OGRLayerH hDstLayer,
                                        const struct ArrowSchema* schemaSrc,
                                        char** options)
{
    const char *ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name";
    const char *EXTENSION_NAME_OGC_WKB = "ogc.wkb";
    const char *EXTENSION_NAME_GEOARROW_WKB = "geoarrow.wkb";

    for (int64_t i = 0; i < schemaSrc->n_children; ++i)
    {
        struct ArrowSchema* child = schemaSrc->children[i];

        if( child->metadata )
        {
            char** keyValues = ParseArrowMetadata(child->metadata);
            const char* pszExtensionName =
                CSLFetchNameValue(keyValues, ARROW_EXTENSION_NAME_KEY);
            // pszExtensionName points into keyValues: evaluate it before
            // the list is destroyed.
            const bool bIsWkbExtension =
                pszExtensionName &&
                (EQUAL(pszExtensionName, EXTENSION_NAME_OGC_WKB) ||
                 EQUAL(pszExtensionName, EXTENSION_NAME_GEOARROW_WKB));
            CSLDestroy(keyValues);
            if( bIsWkbExtension )
                continue;
        }

        // The Arrow C data interface allows an omitted name to be NULL:
        // substitute an empty string so that the comparisons and the error
        // message below never dereference a NULL pointer.
        const char *pszFieldName = child->name ? child->name : "";
        if( EQUAL(pszFieldName, "OGC_FID") ||
            EQUAL(pszFieldName, "wkb_geometry") )
        {
            continue;
        }

        if( !OGR_L_CreateFieldFromArrowSchema(hDstLayer, child, options) )
        {
            CPLError(CE_Failure, CPLE_AppDefined,
                     "Cannot create field %s",
                     pszFieldName);
            return false;
        }
    }
    return true;
}

%}

#endif
Expand Down Expand Up @@ -1580,6 +1647,120 @@ public:
{
return OGR_L_WriteArrowBatch(self, schema, array, options) ? OGRERR_NONE : OGRERR_FAILURE;
}

// Write every batch of an Arrow array stream, passed as a PyCapsule named
// "arrow_array_stream", into the layer.
//
// capsule: PyCapsule wrapping an ArrowArrayStream.
// createFieldsFromSchema: TRUE to create fields from the stream schema,
//   -1 to do so only when the layer has no field yet, any other value to
//   never create fields.
// options: forwarded to CreateFieldsFromArrowSchema() and
//   OGR_L_WriteArrowBatch().
//
// Returns OGRERR_NONE on success, OGRERR_FAILURE otherwise (a CPLError() is
// emitted by this function or by the failing callee). Once the capsule
// content has been validated, the stream is released before returning,
// whatever the outcome.
OGRErr WriteArrowStreamCapsule(PyObject* capsule, int createFieldsFromSchema, char** options = NULL)
{
    ArrowArrayStream* stream = (ArrowArrayStream*)PyCapsule_GetPointer(capsule, "arrow_array_stream");
    if( !stream )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(capsule, \"arrow_array_stream\") failed");
        return OGRERR_FAILURE;
    }
    if( stream->release == NULL )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "stream->release == NULL");
        return OGRERR_FAILURE;
    }

    ArrowSchema schema;
    if( stream->get_schema(stream, &schema) != 0 )
    {
        // Report the failure like every other error path of this method.
        CPLError(CE_Failure, CPLE_AppDefined, "stream->get_schema(stream, &schema) failed");
        stream->release(stream);
        return OGRERR_FAILURE;
    }

    // From here on, schema and stream are released in one place at the end.
    OGRErr eErr = OGRERR_NONE;

    if( createFieldsFromSchema == TRUE ||
        (createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) )
    {
        if( !CreateFieldsFromArrowSchema(self, &schema, options) )
            eErr = OGRERR_FAILURE;
    }

    while( eErr == OGRERR_NONE )
    {
        ArrowArray array;
        if( stream->get_next(stream, &array) != 0 )
        {
            CPLError(CE_Failure, CPLE_AppDefined, "stream->get_next(stream, &array) failed");
            eErr = OGRERR_FAILURE;
            break;
        }

        // An already-released array is treated as the normal end of stream.
        if( array.release == NULL )
            break;

        if( !OGR_L_WriteArrowBatch(self, &schema, &array, options) )
            eErr = OGRERR_FAILURE;

        // Re-checked after the write, which receives a mutable array.
        // NOTE(review): presumably the writer may take ownership and reset
        // array.release - confirm against OGR_L_WriteArrowBatch() doc.
        if( array.release )
            array.release(&array);
    }

    schema.release(&schema);
    stream->release(stream);
    return eErr;
}

// Write a single Arrow array, described by a schema, into the layer. Both
// are passed as PyCapsule objects named "arrow_schema" and "arrow_array".
//
// schemaCapsule: PyCapsule wrapping an ArrowSchema.
// arrayCapsule: PyCapsule wrapping an ArrowArray.
// createFieldsFromSchema: TRUE to create fields from the schema, -1 to do so
//   only when the layer has no field yet, any other value to never create
//   fields.
// options: forwarded to CreateFieldsFromArrowSchema() and
//   OGR_L_WriteArrowBatch().
//
// Both capsules are validated before the layer is touched, so that an
// invalid array capsule cannot leave the layer with newly created fields
// and no data.
//
// Returns OGRERR_NONE on success, OGRERR_FAILURE otherwise (a CPLError() is
// emitted by this function or by the failing callee). Once both capsules
// have been validated, schema and array are released before returning.
OGRErr WriteArrowSchemaAndArrowArrayCapsule(PyObject* schemaCapsule, PyObject* arrayCapsule, int createFieldsFromSchema, char** options = NULL)
{
    ArrowSchema* schema = (ArrowSchema*)PyCapsule_GetPointer(schemaCapsule, "arrow_schema");
    if( !schema )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(schemaCapsule, \"arrow_schema\") failed");
        return OGRERR_FAILURE;
    }
    if( schema->release == NULL )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "schema->release == NULL");
        return OGRERR_FAILURE;
    }

    ArrowArray* array = (ArrowArray*)PyCapsule_GetPointer(arrayCapsule, "arrow_array");
    if( !array )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(arrayCapsule, \"arrow_array\") failed");
        schema->release(schema);
        return OGRERR_FAILURE;
    }
    if( array->release == NULL )
    {
        CPLError(CE_Failure, CPLE_AppDefined, "array->release == NULL");
        schema->release(schema);
        return OGRERR_FAILURE;
    }

    OGRErr eErr = OGRERR_NONE;

    if( createFieldsFromSchema == TRUE ||
        (createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) )
    {
        if( !CreateFieldsFromArrowSchema(self, schema, options) )
            eErr = OGRERR_FAILURE;
    }

    // Only attempt the write when field creation did not fail.
    if( eErr == OGRERR_NONE &&
        !OGR_L_WriteArrowBatch(self, schema, array, options) )
    {
        eErr = OGRERR_FAILURE;
    }

    // Release callbacks are re-checked because the calls above receive
    // mutable structures.
    if( schema->release )
        schema->release(schema);
    if( array->release )
        array->release(array);
    return eErr;
}
#endif

#ifdef SWIGPYTHON
Expand Down
49 changes: 48 additions & 1 deletion swig/include/python/ogr_python.i
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,55 @@ def ReleaseResultSet(self, sql_lyr):
return self.CreateFieldFromArrowSchema(schema, options)


def WriteArrow(self, obj, requested_schema=None, createFieldsFromSchema=None, options=[]):
    """Write the content of the passed object, which must implement the
       __arrow_c_stream__ or __arrow_c_array__ interface, into the layer.

       When the object implements both interfaces, __arrow_c_stream__ is used.

       Parameters
       ----------
       obj:
           Object implementing the __arrow_c_stream__ or __arrow_c_array__ interface

       requested_schema: PyCapsule, object implementing __arrow_c_schema__ or None. Default None
           The schema to which the stream should be cast, passed as a
           PyCapsule containing a C ArrowSchema representation of the
           requested schema, or an object implementing the __arrow_c_schema__ interface.

       createFieldsFromSchema: boolean or None. Default to None
           Whether OGRLayer::CreateFieldFromArrowSchema() should be called. If None
           specified, it is called if no fields have been created yet.
           Any other value is interpreted according to its truthiness.

       options: list of strings
           Options to pass to OGRLayer::CreateFieldFromArrowSchema() and OGRLayer::WriteArrowBatch()

       Returns
       -------
       The value returned by WriteArrowStreamCapsule() or
       WriteArrowSchemaAndArrowArrayCapsule().

       Raises
       ------
       Exception
           If obj implements neither __arrow_c_stream__ nor __arrow_c_array__.
    """

    # Convert to the tri-state integer expected by the capsule-level
    # methods: -1 = decide from the layer, 1 = create fields, 0 = do not.
    # Truthiness (rather than identity with True) is used so that values
    # such as 1 are not silently treated as False.
    if createFieldsFromSchema is None:
        createFieldsFromSchema = -1
    else:
        createFieldsFromSchema = 1 if createFieldsFromSchema else 0

    # Accept a schema object as well as an already exported schema capsule.
    if requested_schema is not None and hasattr(requested_schema, "__arrow_c_schema__"):
        requested_schema = requested_schema.__arrow_c_schema__()

    if hasattr(obj, "__arrow_c_stream__"):
        stream_capsule = obj.__arrow_c_stream__(requested_schema=requested_schema)
        return self.WriteArrowStreamCapsule(stream_capsule, createFieldsFromSchema, options)

    if hasattr(obj, "__arrow_c_array__"):
        schema_capsule, array_capsule = obj.__arrow_c_array__(requested_schema=requested_schema)
        return self.WriteArrowSchemaAndArrowArrayCapsule(schema_capsule, array_capsule, createFieldsFromSchema, options)

    raise Exception("Passed object does not implement the __arrow_c_stream__ or __arrow_c_array__ interface.")


def WritePyArrow(self, pa_batch, options=[]):
"""Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer."""
"""Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer.

See also the WriteArrow() method to be independent of PyArrow
"""

import pyarrow as pa

Expand Down
Loading