From 75e7c544168b060a2f96fb3c9a5d5d91e1e91a5e Mon Sep 17 00:00:00 2001
From: Even Rouault
Date: Wed, 6 Sep 2023 23:31:31 +0200
Subject: [PATCH] Arrow/Parquet GetNextArrowArray(): implement full spatial filtering (not just bbox intersection) (fixes #8347)

---
 autotest/ogr/ogr_parquet.py               | 100 ++++++++++++++++++++++
 ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp |  22 ++---
 2 files changed, 111 insertions(+), 11 deletions(-)

diff --git a/autotest/ogr/ogr_parquet.py b/autotest/ogr/ogr_parquet.py
index 12bb12370e2f..37922e0806b1 100755
--- a/autotest/ogr/ogr_parquet.py
+++ b/autotest/ogr/ogr_parquet.py
@@ -1851,6 +1851,106 @@ def test_ogr_parquet_arrow_stream_numpy_fast_spatial_filter():
     assert len(batches) == 0
 
 
+###############################################################################
+
+
+def test_ogr_parquet_arrow_stream_numpy_detailed_spatial_filter(tmp_vsimem):
+    pytest.importorskip("osgeo.gdal_array")
+    pytest.importorskip("numpy")
+
+    filename = str(
+        tmp_vsimem
+        / "test_ogr_parquet_arrow_stream_numpy_detailed_spatial_filter.parquet"
+    )
+    ds = ogr.GetDriverByName("Parquet").CreateDataSource(filename)
+    lyr = ds.CreateLayer("test", options=["FID=fid"])
+    for idx, wkt in enumerate(
+        [
+            "POINT(1 2)",
+            "MULTIPOINT(0 0,1 2)",
+            "LINESTRING(3 4,5 6)",
+            "MULTILINESTRING((7 8,7.5 8.5),(3 4,5 6))",
+            "POLYGON((10 20,10 30,20 30,10 20),(11 21,11 29,19 29,11 21))",
+            "MULTIPOLYGON(((100 100,100 200,200 200,100 100)),((10 20,10 30,20 30,10 20),(11 21,11 29,19 29,11 21)))",
+            "LINESTRING EMPTY",
+            "MULTILINESTRING EMPTY",
+            "POLYGON EMPTY",
+            "MULTIPOLYGON EMPTY",
+            "GEOMETRYCOLLECTION EMPTY",
+        ]
+    ):
+        f = ogr.Feature(lyr.GetLayerDefn())
+        f.SetFID(idx)
+        f.SetGeometryDirectly(ogr.CreateGeometryFromWkt(wkt))
+        lyr.CreateFeature(f)
+    ds = None
+
+    ds = ogr.Open(filename)
+    lyr = ds.GetLayer(0)
+
+    eps = 1e-1
+
+    # Select nothing
+    with ogrtest.spatial_filter(lyr, 6, 0, 8, 1):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 0
+
+    # Select POINT and MULTIPOINT
+    with ogrtest.spatial_filter(lyr, 1 - eps, 2 - eps, 1 + eps, 2 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 1
+        assert list(batches[0]["fid"]) == [0, 1]
+        assert [f.GetFID() for f in lyr] == [0, 1]
+
+    # Select LINESTRING and MULTILINESTRING due to point falling in bbox
+    with ogrtest.spatial_filter(lyr, 3 - eps, 4 - eps, 3 + eps, 4 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 1
+        assert list(batches[0]["fid"]) == [2, 3]
+        assert [f.GetFID() for f in lyr] == [2, 3]
+
+    # Select LINESTRING and MULTILINESTRING due to point falling in bbox
+    with ogrtest.spatial_filter(lyr, 5 - eps, 6 - eps, 5 + eps, 6 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 1
+        assert list(batches[0]["fid"]) == [2, 3]
+        assert [f.GetFID() for f in lyr] == [2, 3]
+
+    # Select LINESTRING and MULTILINESTRING due to more generic intersection
+    with ogrtest.spatial_filter(lyr, 4 - eps, 5 - eps, 4 + eps, 5 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 1
+        assert list(batches[0]["fid"]) == [2, 3]
+        assert [f.GetFID() for f in lyr] == [2, 3]
+
+    # Select POLYGON and MULTIPOLYGON due to point falling in bbox
+    with ogrtest.spatial_filter(lyr, 10 - eps, 20 - eps, 10 + eps, 20 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        assert len(batches) == 1
+        assert list(batches[0]["fid"]) == [4, 5]
+        assert [f.GetFID() for f in lyr] == [4, 5]
+
+    # bbox with polygon hole
+    with ogrtest.spatial_filter(lyr, 12 - eps, 20.5 - eps, 12 + eps, 20.5 + eps):
+        stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+        batches = [batch for batch in stream]
+        if ogrtest.have_geos():
+            assert len(batches) == 0
+        else:
+            assert len(batches) == 1
+            assert list(batches[0]["fid"]) == [4, 5]
+            assert [f.GetFID() for f in lyr] == [4, 5]
+
+    ds = None
+    gdal.Unlink(filename)
+
+
 ###############################################################################
 # Test SetAttributeFilter() and arrow stream interface
 
diff --git a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
index ea23d72aac2d..86c7e7a0452b 100644
--- a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
+++ b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
@@ -2379,8 +2379,7 @@ CompactFixedWidthArray(struct ArrowArray *array, int nWidth,
 
 template <class OffsetType>
 static size_t
-FillValidityArrayFromWKBArray(struct ArrowArray *array,
-                              const OGREnvelope &sFilterEnvelope,
+FillValidityArrayFromWKBArray(struct ArrowArray *array, const OGRLayer *poLayer,
                               std::vector<bool> &abyValidityFromFilters)
 {
     const size_t nLength = static_cast<size_t>(array->length);
@@ -2399,11 +2398,12 @@ FillValidityArrayFromWKBArray(struct ArrowArray *array,
     {
         if (!pabyValidity || TestBit(pabyValidity, i + nOffset))
         {
-            if (OGRWKBGetBoundingBox(
-                    pabyData + panOffsets[i],
-                    static_cast<size_t>(panOffsets[i + 1] - panOffsets[i]),
-                    sEnvelope) &&
-                sFilterEnvelope.Intersects(sEnvelope))
+            const GByte *pabyWKB = pabyData + panOffsets[i];
+            const size_t nWKBSize =
+                static_cast<size_t>(panOffsets[i + 1] - panOffsets[i]);
+            if (poLayer->FilterWKBGeometry(pabyWKB, nWKBSize,
+                                           /* bEnvelopeAlreadySet=*/false,
+                                           sEnvelope))
             {
                 abyValidityFromFilters[i] = true;
                 nCountIntersecting++;
@@ -2880,11 +2880,11 @@ void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
     const size_t nCountIntersectingGeom =
         m_poFilterGeom ? (strcmp(schema->children[iGeomField]->format, "z") == 0
                               ? FillValidityArrayFromWKBArray<uint32_t>(
-                                    array->children[iGeomField],
-                                    m_sFilterEnvelope, abyValidityFromFilters)
+                                    array->children[iGeomField], this,
+                                    abyValidityFromFilters)
                               : FillValidityArrayFromWKBArray<uint64_t>(
-                                    array->children[iGeomField],
-                                    m_sFilterEnvelope, abyValidityFromFilters))
+                                    array->children[iGeomField], this,
+                                    abyValidityFromFilters))
                        : nLength;
     if (!m_poFilterGeom)
         abyValidityFromFilters.resize(nLength, true);