Skip to content

Commit

Permalink
add pathlib.Path support
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenzek committed Mar 7, 2022
1 parent 5e73ee6 commit 1df8a05
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions omegaconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EnumNode,
FloatNode,
IntegerNode,
PathNode,
StringNode,
ValueNode,
)
Expand Down
35 changes: 35 additions & 0 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

from omegaconf._utils import (
Expand Down Expand Up @@ -179,6 +180,40 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "StringNode":
return res


class PathNode(ValueNode):
def __init__(
self,
value: Any = None,
key: Any = None,
parent: Optional[Container] = None,
is_optional: bool = True,
flags: Optional[Dict[str, bool]] = None,
):
super().__init__(
parent=parent,
value=value,
metadata=Metadata(
key=key,
optional=is_optional,
ref_type=Path,
object_type=Path,
flags=flags,
),
)

def _validate_and_convert_impl(self, value: Any) -> Path:
from omegaconf import OmegaConf

if OmegaConf.is_config(value) or is_primitive_container(value):
raise ValidationError("Cannot convert '$VALUE_TYPE' to Path: '$VALUE'")
return Path(value)

def __deepcopy__(self, memo: Dict[int, Any]) -> "PathNode":
res = PathNode()
self._deepcopy_impl(res, memo)
return res


class IntegerNode(ValueNode):
def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
from textwrap import dedent
from typing import (
IO,
Expand Down Expand Up @@ -66,6 +67,7 @@
EnumNode,
FloatNode,
IntegerNode,
PathNode,
StringNode,
ValueNode,
)
Expand Down Expand Up @@ -1062,6 +1064,8 @@ def _node_wrap(
node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
elif type_ == str:
node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
elif type_ == Path:
node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional)
else:
if parent is not None and parent._get_flag("allow_objects") is True:
node = AnyNode(value=value, key=key, parent=parent)
Expand Down
19 changes: 19 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys
import pathlib
import dataclasses
from importlib import import_module
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -814,6 +816,23 @@ def test_promote_to_object(self, module: Any) -> None:
assert OmegaConf.get_type(conf) == module.BoolConfig
assert conf.with_default is False

def test_promote_to_dataclass(self, module: Any) -> None:
@dataclasses.dataclass
class Foo:
foo: pathlib.Path
bar: str
qub: int = 5

x = DictConfig({"foo": "hello.txt", "bar": "hello.txt"})
assert isinstance(x.foo, str)
assert isinstance(x.bar, str)

x._promote(Foo)
assert isinstance(x.foo, pathlib.Path)
assert isinstance(x.bar, str)
assert x.qub == 5


def test_set_key_with_with_dataclass(self, module: Any) -> None:
cfg = OmegaConf.create({"foo": [1, 2]})
cfg.foo = module.ListClass()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Dict, Tuple, Type

from pytest import mark, param, raises
Expand All @@ -17,6 +18,7 @@
ListConfig,
Node,
OmegaConf,
PathNode,
StringNode,
ValueNode,
)
Expand Down Expand Up @@ -79,6 +81,9 @@
(lambda v: EnumNode(enum_type=Color, value=v), "Color.RED", Color.RED),
(lambda v: EnumNode(enum_type=Color, value=v), "RED", Color.RED),
(lambda v: EnumNode(enum_type=Color, value=v), 1, Color.RED),
# Path node
(PathNode, "hello.txt", Path("hello.txt")),
(PathNode, Path("hello.txt"), Path("hello.txt")),
],
)
def test_valid_inputs(type_: type, input_: Any, output_: Any) -> None:
Expand Down

0 comments on commit 1df8a05

Please sign in to comment.