From badebaa7673ee966a2c029e2ff5639c1d06b10a4 Mon Sep 17 00:00:00 2001 From: "Sasha (Alejandro Vicente Grabovetsky)" Date: Fri, 28 Feb 2020 11:19:26 +0200 Subject: [PATCH 1/4] Add support for pathlib.Path --- omegaconf/omegaconf.py | 5 +++-- tests/test_serialization.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 46579a10e..c03c59e55 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -2,6 +2,7 @@ import copy import io import os +import pathlib import re import sys from contextlib import contextmanager @@ -183,7 +184,7 @@ def create( # noqa F811 def load(file_: Union[str, IO[bytes]]) -> Union[DictConfig, ListConfig]: from ._utils import get_yaml_loader - if isinstance(file_, str): + if isinstance(file_, (str, pathlib.Path)): with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f: obj = yaml.load(f, Loader=get_yaml_loader()) assert isinstance(obj, (list, dict, str)) @@ -204,7 +205,7 @@ def save(config: Container, f: Union[str, IO[str]], resolve: bool = False) -> No :param resolve: True to save a resolved config (defaults to False) """ data = config.pretty(resolve=resolve) - if isinstance(f, str): + if isinstance(f, (str, pathlib.Path)): with io.open(os.path.abspath(f), "w", encoding="utf-8") as file: file.write(data) elif hasattr(f, "write"): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index b32084e2a..3c01f2045 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import io import os +import pathlib import tempfile from typing import Any, Dict @@ -37,6 +38,20 @@ def save_load_from_filename(conf: Container, resolve: bool, expected: Any) -> No os.unlink(fp.name) +def save_load_from_pathlib_path(conf: Container, resolve: bool, expected: Any) -> None: + if expected is None: + expected = conf + # note that delete=False here is a work around windows incompetence. + try: + with tempfile.NamedTemporaryFile(delete=False) as fp: + filepath = pathlib.Path(fp.name) + OmegaConf.save(conf, filepath, resolve=resolve) + c2 = OmegaConf.load(filepath) + assert c2 == expected + finally: + os.unlink(fp.name) + + def test_load_from_invalid() -> None: with pytest.raises(TypeError): OmegaConf.load(3.1415) # type: ignore @@ -64,6 +79,12 @@ def test_save_load__from_filename( cfg = OmegaConf.create(input_) save_load_from_filename(cfg, resolve, expected) + def test_save_load__from_pathlib_path( + self, input_: Dict[str, Any], resolve: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(input_) + save_load_from_pathlib_path(cfg, resolve, expected) + def test_save_illegal_type() -> None: with pytest.raises(TypeError): From 5e1d23ca2b237aaf8eb0f6ce5c3e0324a94678ad Mon Sep 17 00:00:00 2001 From: "Sasha (Alejandro Vicente Grabovetsky)" Date: Fri, 28 Feb 2020 16:07:02 +0200 Subject: [PATCH 2/4] Add news file --- news/159.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 news/159.bugfix diff --git a/news/159.bugfix b/news/159.bugfix new file mode 100644 index 000000000..7f7979662 --- /dev/null +++ b/news/159.bugfix @@ -0,0 +1 @@ +Add support for loading/saving config files by using native python pathlib.Path From 6ceff949fd30ae3758dd46ed730ae346d20cbeae Mon Sep 17 00:00:00 2001 From: "Sasha (Alejandro Vicente Grabovetsky)" Date: Fri, 28 Feb 2020 16:09:47 +0200 Subject: [PATCH 3/4] Update typing --- omegaconf/omegaconf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index c03c59e55..b3246f2ac 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -181,7 +181,7 @@ def create( # noqa F811 ) @staticmethod - def load(file_: Union[str, IO[bytes]]) -> Union[DictConfig, ListConfig]: + def load(file_: Union[str, pathlib.Path, IO[bytes]]) -> Union[DictConfig, ListConfig]: from ._utils import get_yaml_loader if isinstance(file_, (str, pathlib.Path)): @@ -197,7 +197,7 @@ def load(file_: Union[str, IO[bytes]]) -> Union[DictConfig, ListConfig]: raise TypeError("Unexpected file type") @staticmethod - def save(config: Container, f: Union[str, IO[str]], resolve: bool = False) -> None: + def save(config: Container, f: Union[str, pathlib.Path, IO[str]], resolve: bool = False) -> None: """ Save as configuration object to a file :param config: omegaconf.Config object (DictConfig or ListConfig). From ae6de0f6a2dc1cd966e929938ba349179d6d31dd Mon Sep 17 00:00:00 2001 From: "Sasha (Alejandro Vicente Grabovetsky)" Date: Fri, 28 Feb 2020 16:19:42 +0200 Subject: [PATCH 4/4] Parametrise tests --- tests/test_serialization.py | 42 +++++++++++-------------------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 3c01f2045..fd8a46beb 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,7 +3,7 @@ import os import pathlib import tempfile -from typing import Any, Dict +from typing import Any, Dict, Type import pytest @@ -25,26 +25,13 @@ def save_load_from_file(conf: Container, resolve: bool, expected: Any) -> None: os.unlink(fp.name) -def save_load_from_filename(conf: Container, resolve: bool, expected: Any) -> None: +def save_load_from_filename(conf: Container, resolve: bool, expected: Any, file_class: Type) -> None: if expected is None: expected = conf # note that delete=False here is a work around windows incompetence. try: with tempfile.NamedTemporaryFile(delete=False) as fp: - OmegaConf.save(conf, fp.name, resolve=resolve) - c2 = OmegaConf.load(fp.name) - assert c2 == expected - finally: - os.unlink(fp.name) - - -def save_load_from_pathlib_path(conf: Container, resolve: bool, expected: Any) -> None: - if expected is None: - expected = conf - # note that delete=False here is a work around windows incompetence. - try: - with tempfile.NamedTemporaryFile(delete=False) as fp: - filepath = pathlib.Path(fp.name) + filepath = file_class(fp.name) OmegaConf.save(conf, filepath, resolve=resolve) c2 = OmegaConf.load(filepath) assert c2 == expected @@ -58,32 +45,27 @@ def test_load_from_invalid() -> None: @pytest.mark.parametrize( - "input_,resolve,expected", + "input_,resolve,expected,file_class", [ - (dict(a=10), False, None), - ({"foo": 10, "bar": "${foo}"}, False, None), - ({"foo": 10, "bar": "${foo}"}, False, {"foo": 10, "bar": 10}), - ([u"שלום"], False, None), + (dict(a=10), False, None, str), + ({"foo": 10, "bar": "${foo}"}, False, None, str), + ({"foo": 10, "bar": "${foo}"}, False, None, pathlib.Path), + ({"foo": 10, "bar": "${foo}"}, False, {"foo": 10, "bar": 10}, str), + ([u"שלום"], False, None, str), ], ) class TestSaveLoad: def test_save_load__from_file( - self, input_: Dict[str, Any], resolve: bool, expected: Any + self, input_: Dict[str, Any], resolve: bool, expected: Any, file_class: Type ) -> None: cfg = OmegaConf.create(input_) save_load_from_file(cfg, resolve, expected) def test_save_load__from_filename( - self, input_: Dict[str, Any], resolve: bool, expected: Any - ) -> None: - cfg = OmegaConf.create(input_) - save_load_from_filename(cfg, resolve, expected) - - def test_save_load__from_pathlib_path( - self, input_: Dict[str, Any], resolve: bool, expected: Any + self, input_: Dict[str, Any], resolve: bool, expected: Any, file_class: Type ) -> None: cfg = OmegaConf.create(input_) - save_load_from_pathlib_path(cfg, resolve, expected) + save_load_from_filename(cfg, resolve, expected, file_class) def test_save_illegal_type() -> None: