diff --git a/news/531.bugfix b/news/531.bugfix new file mode 100644 index 000000000..6c6e6339e --- /dev/null +++ b/news/531.bugfix @@ -0,0 +1 @@ +Fix OmegaConf.to_yaml(cfg) when keys are of Enum type diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 0a46cb5c8..40779037a 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -227,6 +227,8 @@ def convert(val: Node) -> Any: ) assert node is not None + if enum_to_str and isinstance(key, Enum): + key = f"{key.name}" if isinstance(node, Container): retdict[key] = BaseContainer._to_content( node, diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index ee5b7c244..a79d7e76e 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -585,7 +585,7 @@ def to_container( Resursively converts an OmegaConf config to a primitive container (dict or list). :param cfg: the config to convert :param resolve: True to resolve all values - :param enum_to_str: True to convert Enum values to strings + :param enum_to_str: True to convert Enum keys and values to strings :param exclude_structured_configs: If True, do not convert Structured Configs (DictConfigs backed by a dataclass) :return: A dict or a list representing this config as a primitive container. diff --git a/tests/test_to_container.py b/tests/test_to_container.py index 48c92cb58..9854ab1fd 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -1,25 +1,25 @@ import re from enum import Enum -from typing import Any +from typing import Any, Dict, List -import pytest +from pytest import mark, param, raises from omegaconf import DictConfig, ListConfig, OmegaConf from tests import Color, User -@pytest.mark.parametrize( +@mark.parametrize( "input_", [ - pytest.param([1, 2, 3], id="list"), - pytest.param([1, 2, {"a": 3}], id="dict_in_list"), - pytest.param([1, 2, [10, 20]], id="list_in_list"), - pytest.param({"b": {"b": 10}}, id="dict_in_dict"), - pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"), - pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"), - pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"), - pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"), - pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"), + param([1, 2, 3], id="list"), + param([1, 2, {"a": 3}], id="dict_in_list"), + param([1, 2, [10, 20]], id="list_in_list"), + param({"b": {"b": 10}}, id="dict_in_dict"), + param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"), + param({"b": DictConfig(content=None)}, id="none_dictconfig"), + param({"b": ListConfig(content=None)}, id="none_listconfig"), + param({"b": DictConfig(content="???")}, id="missing_dictconfig"), + param({"b": ListConfig(content="???")}, id="missing_listconfig"), ], ) def test_to_container_returns_primitives(input_: Any) -> None: @@ -38,20 +38,20 @@ def assert_container_with_primitives(item: Any) -> None: assert_container_with_primitives(res) -@pytest.mark.parametrize( +@mark.parametrize( "cfg,ex_false,ex_true", [ - pytest.param( + param( {"user": User(age=7, name="Bond")}, {"user": {"name": "Bond", "age": 7}}, {"user": User(age=7, name="Bond")}, ), - pytest.param( + param( [1, User(age=7, name="Bond")], [1, {"name": "Bond", "age": 7}], [1, User(age=7, name="Bond")], ), - pytest.param( + param( {"users": [User(age=1, name="a"), User(age=2, name="b")]}, {"users": [{"age": 1, "name": "a"}, {"age": 2, "name": "b"}]}, {"users": [User(age=1, name="a"), User(age=2, name="b")]}, @@ -67,56 +67,56 @@ def test_exclude_structured_configs(cfg: Any, ex_false: Any, ex_true: Any) -> No assert ret1 == ex_true -@pytest.mark.parametrize( +@mark.parametrize( "src, expected, expected_with_resolve", [ - pytest.param([], None, None, id="empty_list"), - pytest.param([1, 2, 3], None, None, id="list"), - pytest.param([None], None, None, id="list_with_none"), - pytest.param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"), - pytest.param({}, None, None, id="empty_dict"), - pytest.param({"foo": "bar"}, None, None, id="dict"), - pytest.param( + param([], None, None, id="empty_list"), + param([1, 2, 3], None, None, id="list"), + param([None], None, None, id="list_with_none"), + param([1, "${0}", 3], None, [1, 1, 3], id="list_with_inter"), + param({}, None, None, id="empty_dict"), + param({"foo": "bar"}, None, None, id="dict"), + param( {"foo": "${bar}", "bar": "zonk"}, None, {"foo": "zonk", "bar": "zonk"}, id="dict_with_inter", ), - pytest.param({"foo": None}, None, None, id="dict_with_none"), - pytest.param({"foo": "???"}, None, None, id="dict_missing_value"), - pytest.param({"foo": None}, None, None, id="dict_none_value"), + param({"foo": None}, None, None, id="dict_with_none"), + param({"foo": "???"}, None, None, id="dict_missing_value"), + param({"foo": None}, None, None, id="dict_none_value"), # containers - pytest.param( + param( {"foo": DictConfig(is_optional=True, content=None)}, {"foo": None}, None, id="dict_none_dictconfig", ), - pytest.param( + param( {"foo": DictConfig(content="???")}, {"foo": "???"}, None, id="dict_missing_dictconfig", ), - pytest.param( + param( {"foo": DictConfig(content="${bar}"), "bar": 10}, {"foo": "${bar}", "bar": 10}, {"foo": 10, "bar": 10}, id="dict_inter_dictconfig", ), - pytest.param( + param( {"foo": ListConfig(content="???")}, {"foo": "???"}, None, id="dict_missing_listconfig", ), - pytest.param( + param( {"foo": ListConfig(is_optional=True, content=None)}, {"foo": None}, None, id="dict_none_listconfig", ), - pytest.param( + param( {"foo": ListConfig(content="${bar}"), "bar": 10}, {"foo": "${bar}", "bar": 10}, {"foo": 10, "bar": 10}, @@ -137,7 +137,7 @@ def test_to_container(src: Any, expected: Any, expected_with_resolve: Any) -> No def test_to_container_invalid_input() -> None: - with pytest.raises( + with raises( ValueError, match=re.escape("Input cfg is not an OmegaConf config object (dict)"), ): @@ -153,11 +153,11 @@ def test_string_interpolation_with_readonly_parent() -> None: } -@pytest.mark.parametrize( +@mark.parametrize( "src,expected", [ - pytest.param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"), - pytest.param( + param(DictConfig(content="${bar}"), "${bar}", id="DictConfig"), + param( OmegaConf.create({"foo": DictConfig(content="${bar}")}), {"foo": "${bar}"}, id="nested_DictConfig", @@ -167,3 +167,49 @@ def test_string_interpolation_with_readonly_parent() -> None: def test_to_container_missing_inter_no_resolve(src: Any, expected: Any) -> None: res = OmegaConf.to_container(src, resolve=False) assert res == expected + + +class TestEnumToStr: + """Test the `enum_to_str` argument to the `OmegaConf.to_container function`""" + + @mark.parametrize( + "src,enum_to_str,expected", + [ + param({Color.RED: "enum key"}, True, "RED", id="convert"), + param({Color.RED: "enum key"}, False, Color.RED, id="dont-convert"), + ], + ) + def test_enum_to_str_for_keys( + self, src: Any, enum_to_str: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(src) + container: Dict[Any, Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore + assert container == {expected: "enum key"} + + @mark.parametrize( + "src,enum_to_str,expected", + [ + param({"enum val": Color.RED}, True, "RED", id="convert"), + param({"enum val": Color.RED}, False, Color.RED, id="dont-convert"), + ], + ) + def test_enum_to_str_for_values( + self, src: Any, enum_to_str: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(src) + container: Dict[Any, Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore + assert container == {"enum val": expected} + + @mark.parametrize( + "src,enum_to_str,expected", + [ + param([Color.RED], True, "RED", id="convert"), + param([Color.RED], False, Color.RED, id="dont-convert"), + ], + ) + def test_enum_to_str_for_list( + self, src: Any, enum_to_str: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(src) + container: List[Any] = OmegaConf.to_container(cfg, enum_to_str=enum_to_str) # type: ignore + assert container == [expected] diff --git a/tests/test_to_yaml.py b/tests/test_to_yaml.py index 8f88a891a..49ed1807c 100644 --- a/tests/test_to_yaml.py +++ b/tests/test_to_yaml.py @@ -126,6 +126,14 @@ def test_to_yaml_with_enum() -> None: ) +def test_to_yaml_with_enum_key() -> None: + cfg = OmegaConf.create({Enum1.FOO: "enum key"}) + expected = """FOO: enum key +""" + s = OmegaConf.to_yaml(cfg) + assert s == expected + + def test_pretty_deprecated() -> None: c = OmegaConf.create({"foo": "bar"}) with pytest.warns(