From 1df8a05f9f740b7085fb65fe7a4d8b1e034f2ea6 Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek Date: Mon, 7 Mar 2022 15:34:57 +0100 Subject: [PATCH] add pathlib.Path support --- omegaconf/__init__.py | 1 + omegaconf/nodes.py | 35 +++++++++++++++++++ omegaconf/omegaconf.py | 4 +++ .../structured_conf/test_structured_config.py | 19 ++++++++++ tests/test_nodes.py | 5 +++ 5 files changed, 64 insertions(+) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 02d28a489..1fb870e2d 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -14,6 +14,7 @@ EnumNode, FloatNode, IntegerNode, + PathNode, StringNode, ValueNode, ) diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 558e77750..40bd38fdc 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -3,6 +3,7 @@ import sys from abc import abstractmethod from enum import Enum +from pathlib import Path from typing import Any, Dict, Optional, Type, Union from omegaconf._utils import ( @@ -179,6 +180,40 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode": return res +class PathNode(ValueNode): + def __init__( + self, + value: Any = None, + key: Any = None, + parent: Optional[Container] = None, + is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, + ): + super().__init__( + parent=parent, + value=value, + metadata=Metadata( + key=key, + optional=is_optional, + ref_type=Path, + object_type=Path, + flags=flags, + ), + ) + + def _validate_and_convert_impl(self, value: Any) -> Path: + from omegaconf import OmegaConf + + if OmegaConf.is_config(value) or is_primitive_container(value): + raise ValidationError("Cannot convert '$VALUE_TYPE' to Path: '$VALUE'") + return Path(value) + + def __deepcopy__(self, memo: Dict[int, Any]) -> "PathNode": + res = PathNode() + self._deepcopy_impl(res, memo) + return res + + class IntegerNode(ValueNode): def __init__( self, diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index b1a32b4f6..8b8278897 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -9,6 +9,7 @@ from collections import defaultdict from contextlib import contextmanager from enum import Enum +from pathlib import Path from textwrap import dedent from typing import ( IO, @@ -66,6 +67,7 @@ EnumNode, FloatNode, IntegerNode, + PathNode, StringNode, ValueNode, ) @@ -1062,6 +1064,8 @@ def _node_wrap( node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional) elif type_ == str: node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional) + elif type_ == Path: + node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional) else: if parent is not None and parent._get_flag("allow_objects") is True: node = AnyNode(value=value, key=key, parent=parent) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index bd188bc50..509540575 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1,4 +1,6 @@ import sys +import pathlib +import dataclasses from importlib import import_module from typing import Any, Dict, List, Optional @@ -814,6 +816,23 @@ def test_promote_to_object(self, module: Any) -> None: assert OmegaConf.get_type(conf) == module.BoolConfig assert conf.with_default is False + def test_promote_to_dataclass(self, module: Any) -> None: + @dataclasses.dataclass + class Foo: + foo: pathlib.Path + bar: str + qub: int = 5 + + x = DictConfig({"foo": "hello.txt", "bar": "hello.txt"}) + assert isinstance(x.foo, str) + assert isinstance(x.bar, str) + + x._promote(Foo) + assert isinstance(x.foo, pathlib.Path) + assert isinstance(x.bar, str) + assert x.qub == 5 + + def test_set_key_with_with_dataclass(self, module: Any) -> None: cfg = OmegaConf.create({"foo": [1, 2]}) cfg.foo = module.ListClass() diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 8d1fba29d..5aedd9de2 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -3,6 +3,7 @@ import re from enum import Enum from functools import partial +from pathlib import Path from typing import Any, Dict, Tuple, Type from pytest import mark, param, raises @@ -17,6 +18,7 @@ ListConfig, Node, OmegaConf, + PathNode, StringNode, ValueNode, ) @@ -79,6 +81,9 @@ (lambda v: EnumNode(enum_type=Color, value=v), "Color.RED", Color.RED), (lambda v: EnumNode(enum_type=Color, value=v), "RED", Color.RED), (lambda v: EnumNode(enum_type=Color, value=v), 1, Color.RED), + # Path node + (PathNode, "hello.txt", Path("hello.txt")), + (PathNode, Path("hello.txt"), Path("hello.txt")), ], ) def test_valid_inputs(type_: type, input_: Any, output_: Any) -> None: