Skip to content

Commit

Permalink
DataCube.download(): only add save_result when necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 28, 2022
1 parent ee659f1 commit 9e77631
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Improve default dimension metadata of a datacube created with `openeo.rest.datacube.DataCube.load_disk_collection`
- `DataCube.download()`: only automatically add `save_result` node when there is none yet.

### Removed

Expand Down
22 changes: 18 additions & 4 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,10 +1614,24 @@ def download(
:param options: Optional, file format options
:return: None if the result is stored to disk, or a bytes object returned by the backend.
"""
if not format:
format = guess_format(outputfile) if outputfile else "GTiff"
# TODO: only add `save_result` node when there is none yet?
cube = self.save_result(format=format, options=options)
if self.result_node().process_id == "save_result":
# There is already a `save_result` node: check if it is consistent with given format/options
args = self.result_node().arguments
if format is not None and format.lower() != args["format"].lower():
raise ValueError(
f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
)
if options is not None and options != args["options"]:
raise ValueError(
f"Existing `save_result` node with different options {args['options']!r} != {options!r}"
)
cube = self
else:
# No `save_result` node yet: automatically add it.
if not format:
format = guess_format(outputfile) if outputfile else "GTiff"
cube = self.save_result(format=format, options=options)

return self._connection.download(cube.flat_graph(), outputfile)

def validate(self) -> List[dict]:
Expand Down
45 changes: 45 additions & 0 deletions tests/rest/datacube/test_datacube100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,51 @@ def test_apply_math_simple(con100, math, process, args):
}


@pytest.mark.parametrize(["save_result_kwargs", "download_kwargs", "expected_fail"], [
    ({}, {}, None),
    ({"format": "GTiff"}, {}, None),
    ({}, {"format": "GTiff"}, None),
    ({"format": "GTiff"}, {"format": "GTiff"}, None),
    ({"format": "netCDF"}, {"format": "NETCDF"}, None),
    (
        {"format": "netCDF"},
        {"format": "JSON"},
        "Existing `save_result` node with different format 'netCDF' != 'JSON'"
    ),
    ({"options": {}}, {}, None),
    ({"options": {"quality": "low"}}, {"options": {"quality": "low"}}, None),
    (
        {"options": {"colormap": "jet"}},
        {"options": {"quality": "low"}},
        "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}"
    ),
])
def test_save_result_and_download(
        con100, requests_mock, tmp_path, save_result_kwargs, download_kwargs, expected_fail
):
    """Combining explicit `save_result` with `download`: at most one `save_result` node may end up in the graph,
    and conflicting format/options between the two calls must raise a ValueError."""

    def handle_post_result(request, context):
        # Inspect the submitted process graph: exactly one `save_result` node must be present,
        # regardless of whether it came from save_result() or was auto-added by download().
        graph = request.json()["process"]["process_graph"]
        save_result_nodes = sum(1 for node in graph.values() if node["process_id"] == "save_result")
        assert save_result_nodes == 1
        return b"tiffdata"

    result_mock = requests_mock.post(API_URL + '/result', content=handle_post_result)

    cube = con100.load_collection("S2")
    if save_result_kwargs:
        cube = cube.save_result(**save_result_kwargs)

    target = tmp_path / "tmp.tiff"
    if not expected_fail:
        # Consistent (or absent) format/options: download succeeds and writes the mocked payload.
        cube.download(str(target), **download_kwargs)
        assert target.read_bytes() == b"tiffdata"
        assert result_mock.call_count == 1
    else:
        # Conflicting format/options: download must fail client-side, before any backend request.
        with pytest.raises(ValueError, match=expected_fail):
            cube.download(str(target), **download_kwargs)
        assert result_mock.call_count == 0


class TestBatchJob:
_EXPECTED_SIMPLE_S2_JOB = {"process": {"process_graph": {
"loadcollection1": {
Expand Down

0 comments on commit 9e77631

Please sign in to comment.