From 9e7763191654048867336ed7644809e30afce394 Mon Sep 17 00:00:00 2001
From: Stefaan Lippens
Date: Tue, 27 Sep 2022 19:49:06 +0200
Subject: [PATCH] `DataCube.download()`: only add `save_result` when necessary

---
 CHANGELOG.md                            |  1 +
 openeo/rest/datacube.py                 | 22 +++++++++---
 tests/rest/datacube/test_datacube100.py | 45 +++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0dcc11a20..d1bc07451 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
 - Improve default dimension metadata of a datacube created with `openeo.rest.datacube.DataCube.load_disk_collection`
+- `DataCube.download()`: only automatically add `save_result` node when there is none yet.
 
 ### Removed
diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py
index 337551aba..148988193 100644
--- a/openeo/rest/datacube.py
+++ b/openeo/rest/datacube.py
@@ -1614,10 +1614,24 @@ def download(
         :param options: Optional, file format options
         :return: None if the result is stored to disk, or a bytes object returned by the backend.
         """
-        if not format:
-            format = guess_format(outputfile) if outputfile else "GTiff"
-        # TODO: only add `save_result` node when there is none yet?
-        cube = self.save_result(format=format, options=options)
+        if self.result_node().process_id == "save_result":
+            # There is already a `save_result` node: check if it is consistent with given format/options
+            args = self.result_node().arguments
+            if format is not None and format.lower() != args["format"].lower():
+                raise ValueError(
+                    f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
+                )
+            if options is not None and options != args["options"]:
+                raise ValueError(
+                    f"Existing `save_result` node with different options {args['options']!r} != {options!r}"
+                )
+            cube = self
+        else:
+            # No `save_result` node yet: automatically add it.
+            if not format:
+                format = guess_format(outputfile) if outputfile else "GTiff"
+            cube = self.save_result(format=format, options=options)
+
         return self._connection.download(cube.flat_graph(), outputfile)
 
     def validate(self) -> List[dict]:
diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py
index 660c2757c..43955dc2e 100644
--- a/tests/rest/datacube/test_datacube100.py
+++ b/tests/rest/datacube/test_datacube100.py
@@ -2044,6 +2044,51 @@ def test_apply_math_simple(con100, math, process, args):
     }
 
 
+@pytest.mark.parametrize(["save_result_kwargs", "download_kwargs", "expected_fail"], [
+    ({}, {}, None),
+    ({"format": "GTiff"}, {}, None),
+    ({}, {"format": "GTiff"}, None),
+    ({"format": "GTiff"}, {"format": "GTiff"}, None),
+    ({"format": "netCDF"}, {"format": "NETCDF"}, None),
+    (
+        {"format": "netCDF"},
+        {"format": "JSON"},
+        "Existing `save_result` node with different format 'netCDF' != 'JSON'"
+    ),
+    ({"options": {}}, {}, None),
+    ({"options": {"quality": "low"}}, {"options": {"quality": "low"}}, None),
+    (
+        {"options": {"colormap": "jet"}},
+        {"options": {"quality": "low"}},
+        "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}"
+    ),
+])
+def test_save_result_and_download(
+        con100, requests_mock, tmp_path, save_result_kwargs, download_kwargs, expected_fail
+):
+    def post_result(request, context):
+        pg = request.json()["process"]["process_graph"]
+        process_histogram = collections.Counter(p["process_id"] for p in pg.values())
+        assert process_histogram["save_result"] == 1
+        return b"tiffdata"
+
+    post_result_mock = requests_mock.post(API_URL + '/result', content=post_result)
+
+    cube = con100.load_collection("S2")
+    if save_result_kwargs:
+        cube = cube.save_result(**save_result_kwargs)
+
+    path = tmp_path / "tmp.tiff"
+    if expected_fail:
+        with pytest.raises(ValueError, match=expected_fail):
+            cube.download(str(path), **download_kwargs)
+        assert post_result_mock.call_count == 0
+    else:
+        cube.download(str(path), **download_kwargs)
+        assert path.read_bytes() == b"tiffdata"
+        assert post_result_mock.call_count == 1
+
+
 class TestBatchJob:
     _EXPECTED_SIMPLE_S2_JOB = {"process": {"process_graph": {
         "loadcollection1": {
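
For reference, a minimal usage sketch of the `DataCube.download()` behaviour this patch introduces; the backend URL and collection id below are placeholders for illustration, not taken from the patch:

    import openeo

    connection = openeo.connect("https://openeo.example.com")  # placeholder backend URL
    cube = connection.load_collection("SENTINEL2_L2A")          # placeholder collection id

    # No `save_result` node in the process graph yet: download() adds one,
    # guessing the format ("GTiff") from the file extension.
    cube.download("result.tiff")

    # A `save_result` node is already present: download() reuses it
    # instead of appending a second one.
    nc_cube = cube.save_result(format="netCDF")
    nc_cube.download("result.nc")

    # Requesting a conflicting format now raises ValueError
    # ("Existing `save_result` node with different format ..."):
    # nc_cube.download("result.json", format="JSON")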