Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enum_to_str: convert enum keys, not just enum values #549

Merged
merged 12 commits into from
Feb 22, 2021
1 change: 1 addition & 0 deletions news/531.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix OmegaConf.to_yaml(cfg) when keys are of Enum type
2 changes: 2 additions & 0 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def convert(val: Node) -> Any:
)

assert node is not None
if enum_to_str and isinstance(key, Enum):
key = f"{key.name}"
if isinstance(node, Container):
retdict[key] = BaseContainer._to_content(
node,
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def to_container(
Resursively converts an OmegaConf config to a primitive container (dict or list).
:param cfg: the config to convert
:param resolve: True to resolve all values
:param enum_to_str: True to convert Enum values to strings
:param enum_to_str: True to convert Enum keys and values to strings
:param exclude_structured_configs: If True, do not convert Structured Configs
(DictConfigs backed by a dataclass)
:return: A dict or a list representing this config as a primitive container.
Expand Down
120 changes: 83 additions & 37 deletions tests/test_to_container.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import re
from enum import Enum
from typing import Any
from typing import Any, Dict, List

import pytest
from pytest import mark, param, raises

from omegaconf import DictConfig, ListConfig, OmegaConf
from tests import Color, User


@pytest.mark.parametrize(
@mark.parametrize(
"input_",
[
pytest.param([1, 2, 3], id="list"),
pytest.param([1, 2, {"a": 3}], id="dict_in_list"),
pytest.param([1, 2, [10, 20]], id="list_in_list"),
pytest.param({"b": {"b": 10}}, id="dict_in_dict"),
pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"),
pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"),
pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"),
pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"),
pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"),
param([1, 2, 3], id="list"),
param([1, 2, {"a": 3}], id="dict_in_list"),
param([1, 2, [10, 20]], id="list_in_list"),
param({"b": {"b": 10}}, id="dict_in_dict"),
param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"),
param({"b": DictConfig(content=None)}, id="none_dictconfig"),
param({"b": ListConfig(content=None)}, id="none_listconfig"),
param({"b": DictConfig(content="???")}, id="missing_dictconfig"),
param({"b": ListConfig(content="???")}, id="missing_listconfig"),
],
)
def test_to_container_returns_primitives(input_: Any) -> None:
Expand All @@ -38,20 +38,20 @@ def assert_container_with_primitives(item: Any) -> None:
assert_container_with_primitives(res)


@pytest.mark.parametrize(
@mark.parametrize(
"cfg,ex_false,ex_true",
[
pytest.param(
param(
{"user": User(age=7, name="Bond")},
{"user": {"name": "Bond", "age": 7}},
{"user": User(age=7, name="Bond")},
),
pytest.param(
param(
[1, User(age=7, name="Bond")],
[1, {"name": "Bond", "age": 7}],
[1, User(age=7, name="Bond")],
),
pytest.param(
param(
{"users": [User(age=1, name="a"), User(age=2, name="b")]},
{"users": [{"age": 1, "name": "a"}, {"age": 2, "name": "b"}]},
{"users": [User(age=1, name="a"), User(age=2, name="b")]},
Expand All @@ -67,56 +67,56 @@ def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> No
assert ret1 == ex_true


@pytest.mark.parametrize(
@mark.parametrize(
"src, expected, expected_with_resolve",
[
pytest.param([], None, None, id="empty_list"),
pytest.param([1, 2, 3], None, None, id="list"),
pytest.param([None], None, None, id="list_with_none"),
pytest.param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"),
pytest.param({}, None, None, id="empty_dict"),
pytest.param({"foo": "bar"}, None, None, id="dict"),
pytest.param(
param([], None, None, id="empty_list"),
param([1, 2, 3], None, None, id="list"),
param([None], None, None, id="list_with_none"),
param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"),
param({}, None, None, id="empty_dict"),
param({"foo": "bar"}, None, None, id="dict"),
param(
{"foo": "${bar}", "bar": "zonk"},
None,
{"foo": "zonk", "bar": "zonk"},
id="dict_with_inter",
),
pytest.param({"foo": None}, None, None, id="dict_with_none"),
pytest.param({"foo": "???"}, None, None, id="dict_missing_value"),
pytest.param({"foo": None}, None, None, id="dict_none_value"),
param({"foo": None}, None, None, id="dict_with_none"),
param({"foo": "???"}, None, None, id="dict_missing_value"),
param({"foo": None}, None, None, id="dict_none_value"),
# containers
pytest.param(
param(
{"foo": DictConfig(is_optional=True, content=None)},
{"foo": None},
None,
id="dict_none_dictconfig",
),
pytest.param(
param(
{"foo": DictConfig(content="???")},
{"foo": "???"},
None,
id="dict_missing_dictconfig",
),
pytest.param(
param(
{"foo": DictConfig(content="${bar}"), "bar": 10},
{"foo": "${bar}", "bar": 10},
{"foo": 10, "bar": 10},
id="dict_inter_dictconfig",
),
pytest.param(
param(
{"foo": ListConfig(content="???")},
{"foo": "???"},
None,
id="dict_missing_listconfig",
),
pytest.param(
param(
{"foo": ListConfig(is_optional=True, content=None)},
{"foo": None},
None,
id="dict_none_listconfig",
),
pytest.param(
param(
{"foo": ListConfig(content="${bar}"), "bar": 10},
{"foo": "${bar}", "bar": 10},
{"foo": 10, "bar": 10},
Expand All @@ -137,7 +137,7 @@ def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> No


def test_to_container_invalid_input() -> None:
with pytest.raises(
with raises(
ValueError,
match=re.escape("Input cfg is not an OmegaConf config object (dict)"),
):
Expand All @@ -153,11 +153,11 @@ def test_string_interpolation_with_readonly_parent() -> None:
}


@pytest.mark.parametrize(
@mark.parametrize(
"src,expected",
[
pytest.param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"),
pytest.param(
param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"),
param(
OmegaConf.create({"foo": DictConfig(content="${bar}")}),
{"foo": "${bar}"},
id="nested_DictConfig",
Expand All @@ -167,3 +167,49 @@ def test_string_interpolation_with_readonly_parent() -> None:
def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None:
res = OmegaConf.to_container(src, resolve=False)
assert res == expected

omry marked this conversation as resolved.
Show resolved Hide resolved

class TestEnumToStr:
"""Test the `enum_to_str` argument to the `OmegaConf.to_container function`"""

@mark.parametrize(
"src,enum_to_str,expected",
[
param({Color.RED: "enum key"}, True, "RED", id="convert"),
param({Color.RED: "enum key"}, False, Color.RED, id="dont-convert"),
],
)
def test_enum_to_str_for_keys(
self, src: Any, enum_to_str: bool, expected: Any
) -> None:
cfg = OmegaConf.create(src)
container: Dict[Any, Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore
assert container == {expected: "enum key"}

@mark.parametrize(
"src,enum_to_str,expected",
[
param({"enum val": Color.RED}, True, "RED", id="convert"),
param({"enum val": Color.RED}, False, Color.RED, id="dont-convert"),
],
)
def test_enum_to_str_for_values(
self, src: Any, enum_to_str: bool, expected: Any
) -> None:
cfg = OmegaConf.create(src)
container: Dict[Any, Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore
assert container == {"enum val": expected}

@mark.parametrize(
"src,enum_to_str,expected",
[
param([Color.RED], True, "RED", id="convert"),
param([Color.RED], False, Color.RED, id="dont-convert"),
],
)
def test_enum_to_str_for_list(
self, src: Any, enum_to_str: bool, expected: Any
) -> None:
cfg = OmegaConf.create(src)
container: List[Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore
assert container == [expected]
8 changes: 8 additions & 0 deletions tests/test_to_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ def test_to_yaml_with_enum() -> None:
)


def test_to_yaml_with_enum_key() -> None:
cfg = OmegaConf.create({Enum1.FOO: "enum key"})
expected = """FOO: enum key
"""
s = OmegaConf.to_yaml(cfg)
assert s == expected


def test_pretty_deprecated() -> None:
c = OmegaConf.create({"foo": "bar"})
with pytest.warns(
Expand Down