diff --git a/docs/get_started.qmd b/docs/get_started.qmd index 2f658d7..ef7f79d 100644 --- a/docs/get_started.qmd +++ b/docs/get_started.qmd @@ -75,6 +75,7 @@ Above, we saved the data as a CSV, but you can choose another option depending o - `type = "arrow"` uses `to_feather()` from pandas to create an Arrow/Feather file. - `type = "joblib"` uses `joblib.dump()` to create a binary Python data file, such as for storing a trained model. See the [joblib docs](https://joblib.readthedocs.io/en/latest/) for more information. - `type = "json"` uses `json.dump()` to create a JSON file. Pretty much every programming language can read JSON files, but they only work well for nested lists. +- `type = "geoparquet"` uses `to_parquet()` from [geopandas](https://github.com/geopandas/geopandas) to create a [GeoParquet](https://github.com/opengeospatial/geoparquet) file, which is a specialized Parquet format for geospatial data. Note that when the data lives elsewhere, pins takes care of downloading and caching so that it's only re-downloaded when needed. That said, most boards transmit pins over HTTP, and this is going to be slow and possibly unreliable for very large pins. diff --git a/pins/boards.py b/pins/boards.py index f305488..2c32c95 100644 --- a/pins/boards.py +++ b/pins/boards.py @@ -319,7 +319,7 @@ def pin_write( Pin name. type: File type used to save `x` to disk. May be "csv", "arrow", "parquet", - "joblib", or "json". + "joblib", "json", or "geoparquet". title: A title for the pin; most important for shared boards so that others can understand what the pin contains. If omitted, a brief description diff --git a/pins/drivers.py b/pins/drivers.py index 5aa3e18..2cc5e79 100644 --- a/pins/drivers.py +++ b/pins/drivers.py @@ -22,6 +22,16 @@ def _assert_is_pandas_df(x, file_type: str) -> None: ) +def _assert_is_geopandas_df(x): + # Assume we have already protected against uninstalled geopandas + import geopandas as gpd + + if not isinstance(x, gpd.GeoDataFrame): + raise NotImplementedError( + "Currently only geopandas.GeoDataFrame can be saved to a GeoParquet." + ) + + def load_path(meta, path_to_version): # Check that only a single file name was given fnames = [meta.file] if isinstance(meta.file, str) else meta.file @@ -104,6 +114,17 @@ def load_data( return pd.read_csv(f) + elif meta.type == "geoparquet": + try: + import geopandas as gpd + except ModuleNotFoundError: + raise ModuleNotFoundError( + 'The "geopandas" package is required to read "geoparquet" type ' + "files." + ) from None + + return gpd.read_parquet(f) + elif meta.type == "joblib": import joblib @@ -144,6 +165,8 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen if apply_suffix: if type == "file": suffix = "".join(Path(obj).suffixes) + elif type == "geoparquet": + suffix = ".parquet" else: suffix = f".{type}" else: @@ -175,6 +198,11 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen obj.to_parquet(final_name) + elif type == "geoparquet": + _assert_is_geopandas_df(obj) + + obj.to_parquet(final_name) + elif type == "joblib": import joblib @@ -203,10 +231,20 @@ def default_title(obj, name): import pandas as pd if isinstance(obj, pd.DataFrame): + try: + import geopandas as gpd + except ModuleNotFoundError: + obj_name = "DataFrame" + else: + if isinstance(obj, gpd.GeoDataFrame): + obj_name = "GeoDataFrame" + else: + obj_name = "DataFrame" + # TODO(compat): title says CSV rather than data.frame # see https://github.com/machow/pins-python/issues/5 shape_str = " x ".join(map(str, obj.shape)) - return f"{name}: a pinned {shape_str} DataFrame" + return f"{name}: a pinned {shape_str} {obj_name}" else: obj_name = type(obj).__qualname__ return f"{name}: a pinned {obj_name} object" diff --git a/pins/tests/test_drivers.py b/pins/tests/test_drivers.py index 230f0e8..4588ab5 100644 --- a/pins/tests/test_drivers.py +++ b/pins/tests/test_drivers.py @@ -1,6 +1,7 @@ from pathlib import Path import fsspec +import geopandas as gpd import pandas as pd import pytest @@ -34,6 +35,10 @@ class D: [ (pd.DataFrame({"x": [1, 2]}), "somename: a pinned 2 x 1 DataFrame"), (pd.DataFrame({"x": [1], "y": [2]}), "somename: a pinned 1 x 2 DataFrame"), + ( + gpd.GeoDataFrame({"x": [1], "geometry": [None]}), + "somename: a pinned 1 x 2 GeoDataFrame", + ), (ExC(), "somename: a pinned ExC object"), (ExC().D(), "somename: a pinned ExC.D object"), ([1, 2, 3], "somename: a pinned list object"), @@ -76,6 +81,27 @@ def test_driver_roundtrip(tmp_path: Path, type_): assert df.equals(obj) +def test_driver_geoparquet_roundtrip(tmp_path): + import geopandas as gpd + + gdf = gpd.GeoDataFrame( + {"x": [1, 2, 3], "geometry": gpd.points_from_xy([1, 2, 3], [1, 2, 3])} + ) + + fname = "some_gdf" + full_file = f"{fname}.parquet" + + p_obj = tmp_path / fname + res_fname = save_data(gdf, p_obj, "geoparquet") + + assert Path(res_fname).name == full_file + + meta = MetaRaw(full_file, "geoparquet", "my_pin") + obj = load_data(meta, fsspec.filesystem("file"), tmp_path, allow_pickle_read=True) + + assert gdf.equals(obj) + + @pytest.mark.parametrize( "type_", [ diff --git a/pyproject.toml b/pyproject.toml index 3497c51..294ec11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ test = [ "pytest-dotenv", "pytest-parallel", "s3fs", + "geopandas>=0.8.0", # At 0.8.0, the GeoParquet format was introduced. ] [build-system] diff --git a/requirements/dev.txt b/requirements/dev.txt index 2b2a43a..5331a03 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -61,6 +61,8 @@ cachetools==5.4.0 # via google-auth certifi==2024.7.4 # via + # pyogrio + # pyproj # requests # sphobjinv cffi==1.16.0 @@ -118,6 +120,8 @@ fsspec==2024.6.1 # s3fs gcsfs==2024.6.1 # via pins (setup.cfg) +geopandas==1.0.1 + # via pins (setup.cfg) google-api-core==2.19.1 # via # google-cloud-core @@ -235,20 +239,26 @@ nodeenv==1.9.1 numpy==2.0.0 # via # fastparquet + # geopandas # pandas # pyarrow + # pyogrio + # shapely oauthlib==3.2.2 # via requests-oauthlib packaging==24.1 # via # build # fastparquet + # geopandas # ipykernel + # pyogrio # pytest # pytest-cases pandas==2.2.2 # via # fastparquet + # geopandas # pins (setup.cfg) parso==0.8.4 # via jedi @@ -309,6 +319,10 @@ pyjwt==2.8.0 # via # msal # pyjwt +pyogrio==0.9.0 + # via geopandas +pyproj==3.6.1 + # via geopandas pyproject-hooks==1.1.0 # via # build @@ -373,6 +387,8 @@ rsa==4.9 # via google-auth s3fs==2024.6.1 # via pins (setup.cfg) +shapely==2.0.5 + # via geopandas six==1.16.0 # via # asttokens