diff --git a/openeo/internal/compat.py b/openeo/internal/compat.py new file mode 100644 index 000000000..f128d4f24 --- /dev/null +++ b/openeo/internal/compat.py @@ -0,0 +1,13 @@ +""" +Compatibility layer and small backports. +""" + +import contextlib + +try: + from contextlib import nullcontext +except ImportError: + # nullcontext for pre-3.7 python + @contextlib.contextmanager + def nullcontext(enter_result=None): + yield enter_result diff --git a/openeo/rest/_datacube.py b/openeo/rest/_datacube.py index fe7bb4f95..43fb53f86 100644 --- a/openeo/rest/_datacube.py +++ b/openeo/rest/_datacube.py @@ -1,8 +1,11 @@ import json import logging +import sys import typing +from pathlib import Path from typing import Optional, Union, Tuple +from openeo.internal.compat import nullcontext from openeo.internal.graph_building import PGNode, _FromNodeMixin from openeo.util import legacy_alias @@ -71,12 +74,23 @@ def print_json(self, *, file=None, indent: Union[int, None] = 2, separators: Opt Also see ``json.dumps`` docs for more information on the JSON formatting options. :param file: file-like object (stream) to print to (current ``sys.stdout`` by default). + Or a path (string or pathlib.Path) to a file to write to. :param indent: JSON indentation level. :param separators: (optional) tuple of item/key separators. .. versionadded:: 0.12.0 """ - print(self.to_json(indent=indent, separators=separators), file=file) + pg = {"process_graph": self.flat_graph()} + if isinstance(file, (str, Path)): + # Create (new) file and automatically close it + file_ctx = Path(file).open("w", encoding="utf8") + else: + # Just use file as-is, but don't close it automatically. + file_ctx = nullcontext(enter_result=file or sys.stdout) + with file_ctx as f: + json.dump(pg, f, indent=indent, separators=separators) + if indent is not None: + f.write("\n") @property def _api_version(self): diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index 5ec112d6b..660c2757c 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -1470,6 +1470,16 @@ def test_print_json_file(con100): assert f.getvalue() == EXPECTED_JSON_EXPORT_S2_NDVI + "\n" +@pytest.mark.parametrize("path_factory", [str, pathlib.Path]) +def test_print_json_file_path(con100, tmp_path, path_factory): + ndvi = con100.load_collection("S2").ndvi() + path = tmp_path / "dump.json" + assert not path.exists() + ndvi.print_json(file=path_factory(path)) + assert path.exists() + assert path.read_text() == EXPECTED_JSON_EXPORT_S2_NDVI + "\n" + + def test_sar_backscatter_defaults(con100): cube = con100.load_collection("S2").sar_backscatter() assert _get_leaf_node(cube) == {