Skip to content

Commit

Permalink
Fix reading from zstd decompression stream (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristofKaufmann authored Jul 8, 2024
1 parent 6a2b0d8 commit a99f3f6
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- Silence warning from `write_dataframe` with `GeoSeries.notna()` (#435).
- BUG: Enable mask & bbox filter when geometry column not read (#431).
- Prevent seek on read from non-seekable inputs, e.g. compressed streams (#443).

## 0.9.0 (2024-06-17)

Expand Down
26 changes: 26 additions & 0 deletions pyogrio/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import BytesIO
from pathlib import Path
from zipfile import ZIP_DEFLATED, ZipFile

Expand Down Expand Up @@ -178,6 +179,31 @@ def geojson_filelike(tmp_path):
yield f


@pytest.fixture(scope="function")
def nonseekable_bytes(tmp_path):
    """Return a one-feature GeoJSON document as a stream that cannot seek.

    The stream is a ``BytesIO`` subclass whose ``seekable()`` reports
    ``False`` and whose ``seek()`` raises ``OSError``; it mocks a
    non-seekable byte stream, such as a zstandard handle.
    """

    class _ForwardOnlyBytesIO(BytesIO):
        # Any attempt to reposition the stream is treated as an error.
        def seek(self, *_args, **_kwargs):
            raise OSError("cannot seek")

        # Advertise that the stream cannot be rewound.
        def seekable(self):
            return False

    # A FeatureCollection with a single Point feature and no properties.
    payload = """{
"type": "FeatureCollection",
"features": [
{
"type": "Feature",
"properties": { },
"geometry": { "type": "Point", "coordinates": [1, 1] }
}
]
}"""

    stream = _ForwardOnlyBytesIO(payload.encode("UTF-8"))
    return stream


@pytest.fixture(
scope="session",
params=[
Expand Down
6 changes: 6 additions & 0 deletions pyogrio/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def test_read_arrow_bytes(geojson_bytes):
assert len(table) == 3


def test_read_arrow_nonseekable_bytes(nonseekable_bytes):
    """read_arrow must work on a stream that refuses to seek."""
    info, arrow_table = read_arrow(nonseekable_bytes)

    # the fixture's single feature has no properties, hence no fields
    field_shape = info["fields"].shape
    assert field_shape == (0,)

    row_count = len(arrow_table)
    assert row_count == 1


def test_read_arrow_filelike(geojson_filelike):
meta, table = read_arrow(geojson_filelike)

Expand Down
21 changes: 21 additions & 0 deletions pyogrio/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def test_list_layers_bytes(geojson_bytes):
assert layers[0, 0] == "test"


def test_list_layers_nonseekable_bytes(nonseekable_bytes):
    """list_layers must work on a stream that refuses to seek."""
    found = list_layers(nonseekable_bytes)

    # exactly one layer row with two columns
    expected_shape = (1, 2)
    assert found.shape == expected_shape

    # second column of the first row holds the value "Point"
    second_column = found[0, 1]
    assert second_column == "Point"


def test_list_layers_filelike(geojson_filelike):
layers = list_layers(geojson_filelike)

Expand Down Expand Up @@ -218,6 +225,13 @@ def test_read_bounds_bytes(geojson_bytes):
assert allclose(bounds[:, 0], [-180.0, -18.28799, 180.0, -16.02088])


def test_read_bounds_nonseekable_bytes(nonseekable_bytes):
    """read_bounds must work on a stream that refuses to seek."""
    feature_ids, extents = read_bounds(nonseekable_bytes)

    # one feature in the fixture -> one fid and one column of bounds
    assert feature_ids.shape == (1,)
    assert extents.shape == (4, 1)

    # the only geometry is the point (1, 1), so every bound equals 1
    first_extent = extents[:, 0]
    assert allclose(first_extent, [1, 1, 1, 1])


def test_read_bounds_filelike(geojson_filelike):
fids, bounds = read_bounds(geojson_filelike)
assert fids.shape == (3,)
Expand Down Expand Up @@ -449,6 +463,13 @@ def test_read_info_bytes(geojson_bytes):
assert meta["features"] == 3


def test_read_info_nonseekable_bytes(nonseekable_bytes):
    """read_info must work on a stream that refuses to seek."""
    info = read_info(nonseekable_bytes)

    # the fixture's single feature has no properties, hence no fields
    field_shape = info["fields"].shape
    assert field_shape == (0,)

    feature_count = info["features"]
    assert feature_count == 1


def test_read_info_filelike(geojson_filelike):
meta = read_info(geojson_filelike)

Expand Down
6 changes: 6 additions & 0 deletions pyogrio/tests/test_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,12 @@ def test_read_from_file_like(tmp_path, naturalearth_lowres, driver, ext):
assert_equal_result((meta, index, geometry, field_data), result2)


def test_read_from_nonseekable_bytes(nonseekable_bytes):
    """read must work on a stream that refuses to seek."""
    # only the metadata and the geometry entries of the 4-tuple are checked
    info, _unused_second, geoms, _unused_fourth = read(nonseekable_bytes)

    field_shape = info["fields"].shape
    assert field_shape == (0,)

    geometry_count = len(geoms)
    assert geometry_count == 1


@pytest.mark.parametrize("ext", ["gpkg", "fgb"])
def test_read_write_data_types_numeric(tmp_path, ext):
# Point(0, 0)
Expand Down
2 changes: 1 addition & 1 deletion pyogrio/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_vsi_path_or_buffer(path_or_buffer):
bytes_buffer = path_or_buffer.read()

# rewind buffer if possible so that subsequent operations do not need to rewind
if hasattr(path_or_buffer, "seek"):
if hasattr(path_or_buffer, "seekable") and path_or_buffer.seekable():
path_or_buffer.seek(0)

return bytes_buffer
Expand Down

0 comments on commit a99f3f6

Please sign in to comment.