diff --git a/CHANGES.md b/CHANGES.md index 048506684..b4c0361ad 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,6 @@ ## Version 2.0.0.dev19 (in development) - +* Added operation `write_geo_data_frame()` [#758](https://github.com/CCI-Tools/cate/issues/758) * Numbers displayed with too many digits [#754](https://github.com/CCI-Tools/cate/issues/754) * Improved error handling in operation `pearson_correlation_scalar()``, addresses [#746](https://github.com/CCI-Tools/cate/issues/746) * Fixed error in `plot_xxx()` operations `"'NoneType' object is not iterable"` [#749](https://github.com/CCI-Tools/cate/issues/749) diff --git a/cate/ops/io.py b/cate/ops/io.py index ebcb7b81d..49c4b4b88 100644 --- a/cate/ops/io.py +++ b/cate/ops/io.py @@ -408,19 +408,24 @@ def write_csv(obj: DataFrameLike.TYPE, raise ValidationError('obj must be a pandas.DataFrame or a xarray.Dataset') +GEO_DATA_FRAME_FILE_FILTERS = [ + dict(name='ESRI Shapefile', extensions=['shp']), + dict(name='GeoJSON', extensions=['json', 'geojson']), + dict(name='GPX', extensions=['gpx']), + dict(name='GPKG', extensions=['gpkg']), + _ALL_FILE_FILTER +] + + # noinspection PyIncorrectDocstring,PyUnusedLocal @op(tags=['input'], res_pattern='gdf_{index}') -@op_input('file', - file_open_mode='r', - file_filters=[dict(name='ESRI Shapefiles', extensions=['shp']), - dict(name='GeoJSON', extensions=['json', 'geojson']), - _ALL_FILE_FILTER]) +@op_input('file', file_open_mode='r', file_filters=GEO_DATA_FRAME_FILE_FILTERS) @op_input('crs', nullable=True, deprecated="Not used at all.") @op_input('more_args', nullable=True, data_type=DictLike) def read_geo_data_frame(file: str, crs: str = None, more_args: DictLike.TYPE = None) -> gpd.GeoDataFrame: """ - Reads geo-data from files with formats such as ESRI Shapefile, GeoJSON, GML. + Read a geo data frame from a file with a format such as ESRI Shapefile or GeoJSON. :param file: Is either the absolute or relative path to the file to be opened. :param more_args: Other optional keyword arguments. @@ -432,6 +437,44 @@ def read_geo_data_frame(file: str, crs: str = None, return GeoDataFrame.from_features(features) +# noinspection PyIncorrectDocstring,PyUnusedLocal +@op(tags=['output'], no_cache=True) +@op_input('gdf') +@op_input('file', file_open_mode='w', file_filters=GEO_DATA_FRAME_FILE_FILTERS) +@op_input('more_args', nullable=True, data_type=DictLike) +def write_geo_data_frame(gdf: gpd.GeoDataFrame, + file: str, crs: str = None, + more_args: DictLike.TYPE = None): + """ + Write a geo data frame to files with formats such as ESRI Shapefile or GeoJSON. + + :param gdf: A geo data frame. + :param file: Is either the absolute or relative path to the file to be opened. + :param more_args: Other optional keyword arguments. + Please refer to Python documentation of ``fiona.open()`` function. + """ + kwargs = DictLike.convert(more_args) or {} + if "driver" in kwargs: + driver = kwargs.pop("driver") + else: + root, ext = os.path.splitext(file) + ext_low = ext.lower() + if ext_low == "": + driver = "ESRI Shapefile" + file += ".shp" + elif ext_low == ".shp": + driver = "ESRI Shapefile" + elif ext_low == ".json" or ext_low == ".geojson": + driver = "GeoJSON" + elif ext_low == ".gpx": + driver = "GPX" + elif ext_low == ".gpkg": + driver = "GPKG" + else: + raise ValidationError(f'Cannot detect supported format from file extension "{ext}"') + gdf.to_file(file, driver=driver, **kwargs) + + @op(tags=['input'], res_pattern='ds_{index}') @op_input('path', file_open_mode='r', file_filters=[dict(name='Zarr', extensions=['zarr'])]) @op_input('file_system', value_set=['Local', 'S3', 'OBS']) diff --git a/test/ops/test_io.py b/test/ops/test_io.py index 0ac7510f6..eb7418142 100644 --- a/test/ops/test_io.py +++ b/test/ops/test_io.py @@ -1,8 +1,11 @@ """ Test the IO operations """ - +import fiona.errors import os +import shutil + +import shapely.wkt import unittest from io import StringIO from unittest import TestCase @@ -10,7 +13,7 @@ import geopandas as gpd from cate.core.types import ValidationError -from cate.ops.io import open_dataset, save_dataset, read_csv, read_geo_data_frame, write_csv +from cate.ops.io import open_dataset, save_dataset, read_csv, read_geo_data_frame, write_csv, write_geo_data_frame class TestIO(TestCase): @@ -45,6 +48,7 @@ def test_save_dataset(self): # Test behavior when passing unexpected type with self.assertRaises(NotImplementedError): dataset = ('a', 1, 3, 5) + # noinspection PyTypeChecker save_dataset(dataset, 'remove_me.nc') self.assertFalse(os.path.isfile('remove_me.nc')) @@ -72,16 +76,60 @@ def test_read_geo_data_frame(self): file = os.path.join(os.path.dirname(__file__), '..', '..', 'cate', 'ds', 'data', 'countries', 'countries.geojson') - data_frame = read_geo_data_frame(file) + data_frame = read_geo_data_frame(file=file) self.assertIsInstance(data_frame, gpd.GeoDataFrame) self.assertEqual(len(data_frame), 179) data_frame.close() - # Now with crs - data_frame = read_geo_data_frame(file, crs="EPSG:4326") - self.assertIsInstance(data_frame, gpd.GeoDataFrame) - self.assertEqual(len(data_frame), 179) - data_frame.close() + def test_write_geo_data_frame(self): + gdf = gpd.GeoDataFrame({'coli': [1, 2, 3, 4, 5, 6], + 'cols': ['a', 'b', 'c', 'x', 'y', 'z'], + 'colf': [0.4, 0.5, 0.3, 0.3, 0.1, 0.4], + 'geometry': gpd.GeoSeries([ + shapely.wkt.loads('POINT(10 10)'), + shapely.wkt.loads('POINT(10 20)'), + shapely.wkt.loads('POINT(10 30)'), + shapely.wkt.loads('POINT(20 30)'), + shapely.wkt.loads('POINT(20 20)'), + shapely.wkt.loads('POINT(20 10)'), + ])}) + + out_dir = os.path.join(os.path.dirname(__file__), '..', '..', "_test_outputs") + shutil.rmtree(out_dir, ignore_errors=True) + os.makedirs(out_dir, exist_ok=True) + + file = os.path.join(out_dir, 'test1.geojson') + write_geo_data_frame(gdf=gdf, file=file) + self.assertTrue(os.path.isfile(file)) + + file = os.path.join(out_dir, 'test2.js') + write_geo_data_frame(gdf=gdf, file=file, more_args=dict(driver='GeoJSON')) + self.assertTrue(os.path.isfile(file)) + + file = os.path.join(out_dir, 'test3') + write_geo_data_frame(gdf=gdf, file=file) + self.assertTrue(os.path.isfile(file + ".shp")) + + file = os.path.join(out_dir, 'test4.shp') + write_geo_data_frame(gdf=gdf, file=file) + self.assertTrue(os.path.isfile(file)) + + file = os.path.join(out_dir, 'test5.gpkg') + write_geo_data_frame(gdf=gdf, file=file) + self.assertTrue(os.path.isfile(file)) + + file = os.path.join(out_dir, 'test6.bibo') + with self.assertRaises(ValidationError) as cm: + write_geo_data_frame(gdf=gdf, file=file) + self.assertEquals(f'{cm.exception}', 'Cannot detect supported format from file extension ".bibo"') + + file = os.path.join(out_dir, 'test7.gpx') + try: + write_geo_data_frame(gdf=gdf, file=file) + except fiona.errors.SchemaError: + pass + + shutil.rmtree(out_dir, ignore_errors=True) def test_write_csv_with_dataset(self): import io