From a99f3f638a929294a6039badc94c3098aa4f0e97 Mon Sep 17 00:00:00 2001 From: ChristofKaufmann Date: Mon, 8 Jul 2024 20:56:41 +0200 Subject: [PATCH] Fix reading from zstd decompression stream (#443) --- CHANGES.md | 1 + pyogrio/tests/conftest.py | 26 ++++++++++++++++++++++++++ pyogrio/tests/test_arrow.py | 6 ++++++ pyogrio/tests/test_core.py | 21 +++++++++++++++++++++ pyogrio/tests/test_raw_io.py | 6 ++++++ pyogrio/util.py | 2 +- 6 files changed, 61 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 786762c9..b655b86e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,7 @@ - Silence warning from `write_dataframe` with `GeoSeries.notna()` (#435). - BUG: Enable mask & bbox filter when geometry column not read (#431). +- Prevent seek on read from compressed inputs (#443). ## 0.9.0 (2024-06-17) diff --git a/pyogrio/tests/conftest.py b/pyogrio/tests/conftest.py index 56fd3bfc..d562a4d1 100644 --- a/pyogrio/tests/conftest.py +++ b/pyogrio/tests/conftest.py @@ -1,3 +1,4 @@ +from io import BytesIO from pathlib import Path from zipfile import ZIP_DEFLATED, ZipFile @@ -178,6 +179,31 @@ def geojson_filelike(tmp_path): yield f +@pytest.fixture(scope="function") +def nonseekable_bytes(tmp_path): + # mock a non-seekable byte stream, such as a zstandard handle + class NonSeekableBytesIO(BytesIO): + def seekable(self): + return False + + def seek(self, *args, **kwargs): + raise OSError("cannot seek") + + # wrap GeoJSON into a non-seekable BytesIO + geojson = """{ + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": { }, + "geometry": { "type": "Point", "coordinates": [1, 1] } + } + ] + }""" + + return NonSeekableBytesIO(geojson.encode("UTF-8")) + + @pytest.fixture( scope="session", params=[ diff --git a/pyogrio/tests/test_arrow.py b/pyogrio/tests/test_arrow.py index 7edddcdb..0d3c7147 100644 --- a/pyogrio/tests/test_arrow.py +++ b/pyogrio/tests/test_arrow.py @@ -168,6 +168,12 @@ def test_read_arrow_bytes(geojson_bytes): assert len(table) == 3 +def test_read_arrow_nonseekable_bytes(nonseekable_bytes): + meta, table = read_arrow(nonseekable_bytes) + assert meta["fields"].shape == (0,) + assert len(table) == 1 + + def test_read_arrow_filelike(geojson_filelike): meta, table = read_arrow(geojson_filelike) diff --git a/pyogrio/tests/test_core.py b/pyogrio/tests/test_core.py index 765564eb..ed0a5ef9 100644 --- a/pyogrio/tests/test_core.py +++ b/pyogrio/tests/test_core.py @@ -184,6 +184,13 @@ def test_list_layers_bytes(geojson_bytes): assert layers[0, 0] == "test" +def test_list_layers_nonseekable_bytes(nonseekable_bytes): + layers = list_layers(nonseekable_bytes) + + assert layers.shape == (1, 2) + assert layers[0, 1] == "Point" + + def test_list_layers_filelike(geojson_filelike): layers = list_layers(geojson_filelike) @@ -218,6 +225,13 @@ def test_read_bounds_bytes(geojson_bytes): assert allclose(bounds[:, 0], [-180.0, -18.28799, 180.0, -16.02088]) +def test_read_bounds_nonseekable_bytes(nonseekable_bytes): + fids, bounds = read_bounds(nonseekable_bytes) + assert fids.shape == (1,) + assert bounds.shape == (4, 1) + assert allclose(bounds[:, 0], [1, 1, 1, 1]) + + def test_read_bounds_filelike(geojson_filelike): fids, bounds = read_bounds(geojson_filelike) assert fids.shape == (3,) @@ -449,6 +463,13 @@ def test_read_info_bytes(geojson_bytes): assert meta["features"] == 3 +def test_read_info_nonseekable_bytes(nonseekable_bytes): + meta = read_info(nonseekable_bytes) + + assert meta["fields"].shape == (0,) + assert meta["features"] == 1 + + def test_read_info_filelike(geojson_filelike): meta = read_info(geojson_filelike) diff --git a/pyogrio/tests/test_raw_io.py b/pyogrio/tests/test_raw_io.py index ebabe073..f45d1f85 100644 --- a/pyogrio/tests/test_raw_io.py +++ b/pyogrio/tests/test_raw_io.py @@ -819,6 +819,12 @@ def test_read_from_file_like(tmp_path, naturalearth_lowres, driver, ext): assert_equal_result((meta, index, geometry, field_data), result2) +def test_read_from_nonseekable_bytes(nonseekable_bytes): + meta, _, geometry, _ = read(nonseekable_bytes) + assert meta["fields"].shape == (0,) + assert len(geometry) == 1 + + @pytest.mark.parametrize("ext", ["gpkg", "fgb"]) def test_read_write_data_types_numeric(tmp_path, ext): # Point(0, 0) diff --git a/pyogrio/util.py b/pyogrio/util.py index 26277f58..6bf04693 100644 --- a/pyogrio/util.py +++ b/pyogrio/util.py @@ -38,7 +38,7 @@ def get_vsi_path_or_buffer(path_or_buffer): bytes_buffer = path_or_buffer.read() # rewind buffer if possible so that subsequent operations do not need to rewind - if hasattr(path_or_buffer, "seek"): + if hasattr(path_or_buffer, "seekable") and path_or_buffer.seekable(): path_or_buffer.seek(0) return bytes_buffer