Skip to content

Commit

Permalink
Python bindings: add a ogr.Layer.WriteArrow() method consuming __arro…
Browse files Browse the repository at this point in the history
…w_c_stream__ or __arrow_c_array__ interfaces

fixes #9132
  • Loading branch information
rouault committed Jan 24, 2024
1 parent f7a22b6 commit 6a00ccd
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 1 deletion.
60 changes: 60 additions & 0 deletions autotest/ogr/ogr_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,66 @@ def test_ogr_mem_arrow_stream_pycapsule_interface():
del stream


###############################################################################
# Test consuming __arrow_c_stream__() interface.
# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html


@gdaltest.enable_exceptions()
def test_ogr_mem_consume_arrow_stream_pycapsule_interface():

ds = ogr.GetDriverByName("Memory").CreateDataSource("")
lyr = ds.CreateLayer("foo", geom_type=ogr.wkbNone)
lyr.CreateGeomField(ogr.GeomFieldDefn("my_geometry"))
lyr.CreateField(ogr.FieldDefn("foo"))
f = ogr.Feature(lyr.GetLayerDefn())
f["foo"] = "bar"
f.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)"))
lyr.CreateFeature(f)

lyr2 = ds.CreateLayer("foo2")
lyr2.WriteArrow(lyr)

f = lyr2.GetNextFeature()
assert f["foo"] == "bar"
assert f.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)"


###############################################################################
# Test consuming __arrow_c_array__() interface.
# Cf https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html


@gdaltest.enable_exceptions()
def test_ogr_mem_consume_arrow_arrow_pycapsule_interface():
pyarrow = pytest.importorskip("pyarrow")
if int(pyarrow.__version__.split(".")[0]) < 14:
pytest.skip("pyarrow >= 14 needed")

ds = ogr.GetDriverByName("Memory").CreateDataSource("")
lyr = ds.CreateLayer("foo")
lyr.CreateField(ogr.FieldDefn("foo"))
f = ogr.Feature(lyr.GetLayerDefn())
f["foo"] = "bar"
f.SetGeometry(ogr.CreateGeometryFromWkt("POINT (1 2)"))
lyr.CreateFeature(f)

table = pyarrow.table(lyr)

lyr2 = ds.CreateLayer("foo2")
batches = table.to_batches()
for batch in batches:
array = batch.to_struct_array()
if not hasattr(array, "__arrow_c_array__"):
pytest.skip("table does not declare __arrow_c_array__")

lyr2.WriteArrow(array)

f = lyr2.GetNextFeature()
assert f["foo"] == "bar"
assert f.GetGeometryRef().ExportToIsoWkt() == "POINT (1 2)"


###############################################################################


Expand Down
181 changes: 181 additions & 0 deletions swig/include/ogr.i
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,73 @@ static void ReleaseArrowArrayStreamPyCapsule(PyObject* capsule) {
}
CPLFree(stream);
}

static char** ParseArrowMetadata(const char *pabyMetadata)
{
char** ret = NULL;
int32_t nKVP;
memcpy(&nKVP, pabyMetadata, sizeof(int32_t));
pabyMetadata += sizeof(int32_t);
for (int i = 0; i < nKVP; ++i)
{
int32_t nSizeKey;
memcpy(&nSizeKey, pabyMetadata, sizeof(int32_t));
pabyMetadata += sizeof(int32_t);
std::string osKey;
osKey.assign(pabyMetadata, nSizeKey);
pabyMetadata += nSizeKey;

int32_t nSizeValue;
memcpy(&nSizeValue, pabyMetadata, sizeof(int32_t));
pabyMetadata += sizeof(int32_t);
std::string osValue;
osValue.assign(pabyMetadata, nSizeValue);
pabyMetadata += nSizeValue;

ret = CSLSetNameValue(ret, osKey.c_str(), osValue.c_str());
}

return ret;
}

// Create output fields using CreateFieldFromArrowSchema()
static bool CreateFieldsFromArrowSchema(OGRLayerH hDstLayer,
const struct ArrowSchema* schemaSrc,
char** options)
{
for (int i = 0; i < schemaSrc->n_children; ++i)
{
const char *metadata =
schemaSrc->children[i]->metadata;
if( metadata )
{
char** keyValues = ParseArrowMetadata(metadata);
const char *ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name";
const char *EXTENSION_NAME_OGC_WKB = "ogc.wkb";
const char *EXTENSION_NAME_GEOARROW_WKB = "geoarrow.wkb";
const char* value = CSLFetchNameValue(keyValues, ARROW_EXTENSION_NAME_KEY);
const bool bSkip = ( value && (EQUAL(value, EXTENSION_NAME_OGC_WKB) || EQUAL(value, EXTENSION_NAME_GEOARROW_WKB)) );
CSLDestroy(keyValues);
if( bSkip )
continue;
}

const char *pszFieldName =
schemaSrc->children[i]->name;
if (!EQUAL(pszFieldName, "OGC_FID") &&
!EQUAL(pszFieldName, "wkb_geometry") &&
!OGR_L_CreateFieldFromArrowSchema(
hDstLayer, schemaSrc->children[i], options))
{
CPLError(CE_Failure, CPLE_AppDefined,
"Cannot create field %s",
pszFieldName);
return false;
}
}
return true;
}

%}

#endif
Expand Down Expand Up @@ -1580,6 +1647,120 @@ public:
{
return OGR_L_WriteArrowBatch(self, schema, array, options) ? OGRERR_NONE : OGRERR_FAILURE;
}

OGRErr WriteArrowStreamCapsule(PyObject* capsule, int createFieldsFromSchema, char** options = NULL)
{
ArrowArrayStream* stream = (ArrowArrayStream*)PyCapsule_GetPointer(capsule, "arrow_array_stream");
if( !stream )
{
CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(capsule, \"arrow_array_stream\") failed");
return OGRERR_FAILURE;
}
if( stream->release == NULL )
{
CPLError(CE_Failure, CPLE_AppDefined, "stream->release == NULL");
return OGRERR_FAILURE;
}

ArrowSchema schema;
if( stream->get_schema(stream, &schema) != 0 )
{
stream->release(stream);
return OGRERR_FAILURE;
}

if( createFieldsFromSchema == TRUE ||
(createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) )
{
if( !CreateFieldsFromArrowSchema(self, &schema, options) )
{
schema.release(&schema);
stream->release(stream);
return OGRERR_FAILURE;
}
}

while( true )
{
ArrowArray array;
if( stream->get_next(stream, &array) == 0 )
{
if( array.release == NULL )
break;
if( !OGR_L_WriteArrowBatch(self, &schema, &array, options) )
{
if( array.release )
array.release(&array);
schema.release(&schema);
stream->release(stream);
return OGRERR_FAILURE;
}
if( array.release )
array.release(&array);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined, "stream->get_next(stream, &array) failed");
schema.release(&schema);
stream->release(stream);
return OGRERR_FAILURE;
}
}
schema.release(&schema);
stream->release(stream);
return OGRERR_NONE;
}

OGRErr WriteArrowSchemaAndArrowArrayCapsule(PyObject* schemaCapsule, PyObject* arrayCapsule, int createFieldsFromSchema, char** options = NULL)
{
ArrowSchema* schema = (ArrowSchema*)PyCapsule_GetPointer(schemaCapsule, "arrow_schema");
if( !schema )
{
CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(schemaCapsule, \"arrow_schema\") failed");
return OGRERR_FAILURE;
}
if( schema->release == NULL )
{
CPLError(CE_Failure, CPLE_AppDefined, "schema->release == NULL");
return OGRERR_FAILURE;
}

if( createFieldsFromSchema == TRUE ||
(createFieldsFromSchema == -1 && OGR_FD_GetFieldCount(OGR_L_GetLayerDefn(self)) == 0) )
{
if( !CreateFieldsFromArrowSchema(self, schema, options) )
{
schema->release(schema);
return OGRERR_FAILURE;
}
}

ArrowArray* array = (ArrowArray*)PyCapsule_GetPointer(arrayCapsule, "arrow_array");
if( !array )
{
CPLError(CE_Failure, CPLE_AppDefined, "PyCapsule_GetPointer(arrayCapsule, \"arrow_array\") failed");
schema->release(schema);
return OGRERR_FAILURE;
}
if( array->release == NULL )
{
CPLError(CE_Failure, CPLE_AppDefined, "array->release == NULL");
schema->release(schema);
return OGRERR_FAILURE;
}

OGRErr eErr = OGRERR_NONE;
if( !OGR_L_WriteArrowBatch(self, schema, array, options) )
{
eErr = OGRERR_FAILURE;
}

if( schema->release )
schema->release(schema);
if( array->release )
array->release(array);
return eErr;
}
#endif

#ifdef SWIGPYTHON
Expand Down
46 changes: 45 additions & 1 deletion swig/include/python/ogr_python.i
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,52 @@ def ReleaseResultSet(self, sql_lyr):
return self.CreateFieldFromArrowSchema(schema, options)


def WriteArrow(self, obj, requested_schema=None, createFieldsFromSchema=None, options=[]):
"""Write the content of the passed object, which must implement the
__arrow_c_stream__ or __arrow_c_array__ interface, into the layer.

Parameters
----------
obj:
Object implementing the __arrow_c_stream__ or __arrow_c_array__ interface

requested_schema: PyCapsule, default None
The schema to which the stream should be casted, passed as a
PyCapsule containing a C ArrowSchema representation of the
requested schema.

createFieldsFromSchema: boolean or None. Default to None
Whether OGRLayer::CreateFieldFromArrowSchema() should be called. If None
specified, it is called if no fields have been created yet

options: list of strings
Options to pass to OGRLayer::CreateFieldFromArrowSchema() and OGRLayer::WriteArrowBatch()

"""

if createFieldsFromSchema is None:
createFieldsFromSchema = -1
elif createFieldsFromSchema is True:
createFieldsFromSchema = 1
else:
createFieldsFromSchema = 0

if hasattr(obj, "__arrow_c_stream__"):
stream_capsule = obj.__arrow_c_stream__(requested_schema=requested_schema)
return self.WriteArrowStreamCapsule(stream_capsule, createFieldsFromSchema, options)

if hasattr(obj, "__arrow_c_array__"):
schema_capsule, array_capsule = obj.__arrow_c_array__(requested_schema=requested_schema)
return self.WriteArrowSchemaAndArrowArrayCapsule(schema_capsule, array_capsule, createFieldsFromSchema, options)

raise Exception("Passed object does not implement the __arrow_c_stream__ or __arrow_c_array__ interface.")


def WritePyArrow(self, pa_batch, options=[]):
"""Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer."""
"""Write the content of the passed PyArrow batch (either a pyarrow.Table, a pyarrow.RecordBatch or a pyarrow.StructArray) into the layer.

See also the WriteArrow() method to be independent of PyArrow
"""

import pyarrow as pa

Expand Down

0 comments on commit 6a00ccd

Please sign in to comment.