diff --git a/docs/notebook/Tutorial.ipynb b/docs/notebook/Tutorial.ipynb index fbbb89f9e..65da517ce 100644 --- a/docs/notebook/Tutorial.ipynb +++ b/docs/notebook/Tutorial.ipynb @@ -698,7 +698,8 @@ "source": [ "## Environment variable interpolation\n", "\n", - "Environment variable interpolation is also supported." + "Environment variable interpolation is also supported.\n", + "An environment variable is always returned as a string." ] }, { @@ -733,8 +734,8 @@ "output_type": "stream", "text": [ "user:\n", - " name: ${env:USER}\n", - " home: /home/${env:USER}\n", + " name: ${oc.env:USER}\n", + " home: /home/${oc.env:USER}\n", "\n" ] } @@ -773,7 +774,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can specify a default value to use in case the environment variable is not defined. The following example sets `abc123` as the the default value when `DB_PASSWORD` is not defined." + "You can specify a default value to use in case the environment variable is not defined.\n", + "This default value can be a string or ``null`` (representing Python ``None``). Passing a default with a different type will result in an error.\n", + "\n", + "The following example sets default database passwords when ``DB_PASSWORD`` is not defined:" ] }, { @@ -782,27 +786,43 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'abc123'" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "'abc123'\n", + "'12345'\n" + ] } ], "source": [ "os.environ.pop('DB_PASSWORD', None) # ensure env variable does not exist\n", - "cfg = OmegaConf.create({'database': {'password': '${env:DB_PASSWORD,abc123}'}})\n", - "cfg.database.password" + "cfg = OmegaConf.create(\n", + " {\n", + " \"database\": {\n", + " \"password1\": \"${oc.env:DB_PASSWORD,abc123}\", # the string 'abc123'\n", + " \"password2\": \"${oc.env:DB_PASSWORD,'12345'}\", # the string '12345'\n", + " },\n", + " }\n", + ")\n", + "print(repr(cfg.database.password1))\n", + "print(repr(cfg.database.password2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Environment variables are parsed when they are recognized as valid quantities that may be evaluated (e.g., int, float, dict, list):" + "## Decoding strings with interpolations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can automatically convert a string to its corresponding type (e.g., bool, int, float, dict, list) using `oc.decode` (which can even resolve interpolations).\n", + "This resolver also accepts ``None`` as input, in which case it returns ``None``.\n", + "\n", + "This can be useful for instance to parse environment variables:" ] }, { @@ -814,25 +834,34 @@ "name": "stdout", "output_type": "stream", "text": [ - "3308\n", - "['host1', 'host2', 'host3']\n", - "'a%#@~{}$*&^?/<'\n" + "port (int): 3308\n", + "nodes (list): ['host1', 'host2', 'host3']\n", + "timeout (missing variable): None\n", + "timeout (interpolation): 3308\n" ] } ], "source": [ - "cfg = OmegaConf.create({'database': {'password': '${env:DB_PASSWORD,abc123}',\n", - " 'user': 'someuser',\n", - " 'port': '${env:DB_PORT,3306}',\n", - " 'nodes': '${env:DB_NODES,[]}'}})\n", + "cfg = OmegaConf.create(\n", + " {\n", + " \"database\": {\n", + " \"port\": '${oc.decode:${oc.env:DB_PORT}}',\n", + " \"nodes\": '${oc.decode:${oc.env:DB_NODES,null}}',\n", + " \"timeout\": '${oc.decode:${oc.env:DB_TIMEOUT,null}}',\n", + " }\n", + " }\n", + ")\n", + "\n", + "os.environ[\"DB_PORT\"] = \"3308\" # integer\n", + "os.environ[\"DB_NODES\"] = \"[host1, host2, host3]\" # list\n", + "os.environ.pop(\"DB_TIMEOUT\", None) # unset variable\n", "\n", - "os.environ[\"DB_PORT\"] = '3308' # integer\n", - "os.environ[\"DB_NODES\"] = '[host1, host2, host3]' # list\n", - "os.environ[\"DB_PASSWORD\"] = 'a%#@~{}$*&^?/<' # string\n", + "print(\"port (int):\", repr(cfg.database.port))\n", + "print(\"nodes (list):\", repr(cfg.database.nodes))\n", + "print(\"timeout (missing variable):\", repr(cfg.database.timeout))\n", "\n", - "print(repr(cfg.database.port))\n", - "print(repr(cfg.database.nodes))\n", - "print(repr(cfg.database.password))" + "os.environ[\"DB_TIMEOUT\"] = \"${.port}\"\n", + "print(\"timeout (interpolation):\", repr(cfg.database.timeout))" ] }, { diff --git a/docs/source/env_interpolation.yaml b/docs/source/env_interpolation.yaml index 75f44169b..3f0b64b8e 100644 --- a/docs/source/env_interpolation.yaml +++ b/docs/source/env_interpolation.yaml @@ -1,3 +1,3 @@ user: - name: ${env:USER} - home: /home/${env:USER} + name: ${oc.env:USER} + home: /home/${oc.env:USER} diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 834a66a36..73ce66dd0 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -6,6 +6,8 @@ import tempfile import pickle os.environ['USER'] = 'omry' + # ensures that DB_TIMEOUT is not set in the doc. + os.environ.pop('DB_TIMEOUT', None) .. testsetup:: loaded @@ -334,20 +336,17 @@ Example: .. doctest:: >>> conf = OmegaConf.load('source/config_interpolation.yaml') + >>> def show(x): + ... print(f"type: {type(x).__name__}, value: {repr(x)}") >>> # Primitive interpolation types are inherited from the reference - >>> conf.client.server_port - 80 - >>> type(conf.client.server_port).__name__ - 'int' - >>> conf.client.description - 'Client of http://localhost:80/' - - >>> # Composite interpolation types are always string - >>> conf.client.url - 'http://localhost:80/' - >>> type(conf.client.url).__name__ - 'str' - + >>> show(conf.client.server_port) + type: int, value: 80 + >>> # String interpolations concatenate fragments into a string + >>> show(conf.client.url) + type: str, value: 'http://localhost:80/' + >>> # Relative interpolation example + >>> show(conf.client.description) + type: str, value: 'Client of http://localhost:80/' Interpolations may be nested, enabling more advanced behavior like dynamically selecting a sub-config: @@ -386,7 +385,7 @@ Interpolated nodes can be any node in the config, not just leaf nodes: Environment variable interpolation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Environment variable interpolation is also supported. +Access to environment variables is supported using ``oc.env``: Input YAML file: @@ -402,36 +401,59 @@ Input YAML file: '/home/omry' You can specify a default value to use in case the environment variable is not defined. -The following example sets `abc123` as the the default value when `DB_PASSWORD` is not defined. +This default value can be a string or ``null`` (representing Python ``None``). Passing a default with a different type will result in an error. +The following example falls back to default passwords when ``DB_PASSWORD`` is not defined: .. doctest:: - >>> cfg = OmegaConf.create({ - ... 'database': {'password': '${env:DB_PASSWORD,abc123}'} - ... }) - >>> cfg.database.password + >>> cfg = OmegaConf.create( + ... { + ... "database": { + ... "password1": "${oc.env:DB_PASSWORD,abc123}", + ... "password2": "${oc.env:DB_PASSWORD,'12345'}", + ... }, + ... } + ... ) + >>> cfg.database.password1 # the string 'abc123' 'abc123' + >>> cfg.database.password2 # the string '12345' + '12345' + -Environment variables are parsed when they are recognized as valid quantities that -may be evaluated (e.g., int, float, dict, list): +Decoding strings with interpolations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Strings may be converted using ``oc.decode``: + +- Primitive values (e.g., ``"true"``, ``"1"``, ``"1e-3"``) are automatically converted to their corresponding type (bool, int, float) +- Dictionaries and lists (e.g., ``"{a: b}"``, ``"[a, b, c]"``) are returned as transient config nodes (DictConfig and ListConfig) +- Interpolations (e.g., ``"${foo}"``) are automatically resolved +- ``None`` is the only valid non-string input to ``oc.decode`` (returning ``None`` in that case) + +This can be useful for instance to parse environment variables: .. doctest:: - >>> cfg = OmegaConf.create({ - ... 'database': {'password': '${env:DB_PASSWORD,abc123}', - ... 'user': 'someuser', - ... 'port': '${env:DB_PORT,3306}', - ... 'nodes': '${env:DB_NODES,[]}'} - ... }) - >>> os.environ["DB_PORT"] = '3308' - >>> cfg.database.port # converted to int - 3308 - >>> os.environ["DB_NODES"] = '[host1, host2, host3]' - >>> cfg.database.nodes # converted to list - ['host1', 'host2', 'host3'] - >>> os.environ["DB_PASSWORD"] = 'a%#@~{}$*&^?/<' - >>> cfg.database.password # kept as a string - 'a%#@~{}$*&^?/<' + >>> cfg = OmegaConf.create( + ... { + ... "database": { + ... "port": '${oc.decode:${oc.env:DB_PORT}}', + ... "nodes": '${oc.decode:${oc.env:DB_NODES}}', + ... "timeout": '${oc.decode:${oc.env:DB_TIMEOUT,null}}', + ... } + ... } + ... ) + >>> os.environ["DB_PORT"] = "3308" + >>> show(cfg.database.port) # converted to int + type: int, value: 3308 + >>> os.environ["DB_NODES"] = "[host1, host2, host3]" + >>> show(cfg.database.nodes) # converted to a ListConfig + type: ListConfig, value: ['host1', 'host2', 'host3'] + >>> show(cfg.database.timeout) # keeping `None` as is + type: NoneType, value: None + >>> os.environ["DB_TIMEOUT"] = "${.port}" + >>> show(cfg.database.timeout) # resolving interpolation + type: int, value: 3308 Custom interpolations @@ -762,12 +784,11 @@ If resolve is set to True, interpolations will be resolved during conversion. >>> conf = OmegaConf.create({"foo": "bar", "foo2": "${foo}"}) >>> assert type(conf) == DictConfig >>> primitive = OmegaConf.to_container(conf) - >>> assert type(primitive) == dict - >>> print(primitive) - {'foo': 'bar', 'foo2': '${foo}'} + >>> show(primitive) + type: dict, value: {'foo': 'bar', 'foo2': '${foo}'} >>> resolved = OmegaConf.to_container(conf, resolve=True) - >>> print(resolved) - {'foo': 'bar', 'foo2': 'bar'} + >>> show(resolved) + type: dict, value: {'foo': 'bar', 'foo2': 'bar'} You can customize the treatment of **OmegaConf.to_container()** for Structured Config nodes using the `structured_config_mode` option. @@ -780,11 +801,10 @@ as DictConfig, allowing attribute style access on the resulting node. >>> from omegaconf import SCMode >>> conf = OmegaConf.create({"structured_config": MyConfig}) >>> container = OmegaConf.to_container(conf, structured_config_mode=SCMode.DICT_CONFIG) - >>> print(container) - {'structured_config': {'port': 80, 'host': 'localhost'}} - >>> assert type(container) is dict - >>> assert type(container["structured_config"]) is DictConfig - >>> assert container["structured_config"].port == 80 + >>> show(container) + type: dict, value: {'structured_config': {'port': 80, 'host': 'localhost'}} + >>> show(container["structured_config"]) + type: DictConfig, value: {'port': 80, 'host': 'localhost'} OmegaConf.select ^^^^^^^^^^^^^^^^ diff --git a/news/230.bugfix b/news/230.bugfix deleted file mode 100644 index 61009f8be..000000000 --- a/news/230.bugfix +++ /dev/null @@ -1 +0,0 @@ -`${env:MYVAR,null}` now properly returns `None` if the environment variable MYVAR is undefined. diff --git a/news/445.feature.1 b/news/445.feature.1 index f43a7de6a..31bec183b 100644 --- a/news/445.feature.1 +++ b/news/445.feature.1 @@ -1 +1 @@ -Add ability to nest interpolations, e.g. ${foo.${bar}}}, ${env:{$var1},${var2}}, or ${${func}:x1,x2} +Add ability to nest interpolations, e.g. ${foo.${bar}}}, ${oc.env:{$var1},${var2}}, or ${${func}:x1,x2} diff --git a/news/573.api_change b/news/573.api_change new file mode 100644 index 000000000..c7f9ccdd7 --- /dev/null +++ b/news/573.api_change @@ -0,0 +1 @@ +The `env` resolver is deprecated in favor of `oc.env`, which keeps the string representation of environment variables, does not cache the resulting value, and handles "null" as default value. diff --git a/news/574.feature b/news/574.feature new file mode 100644 index 000000000..c62d93f6f --- /dev/null +++ b/news/574.feature @@ -0,0 +1 @@ +New resolver `oc.decode` that can be used to automatically convert a string to bool, int, float, dict, list, etc. diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 08e49e9be..8a538471b 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -63,6 +63,18 @@ _CMP_TYPES = {t: i for i, t in enumerate([float, int, bool, str, type(None)])} +class Marker: + def __init__(self, desc: str): + self.desc = desc + + def __repr__(self) -> str: + return self.desc + + +# To be used as default value when `None` is not an option. +_DEFAULT_MARKER_: Any = Marker("_DEFAULT_MARKER_") + + class OmegaConfDumper(yaml.Dumper): # type: ignore str_representer_added = False @@ -404,6 +416,12 @@ def get_value_kind( return ValueKind.VALUE +# DEPRECATED: remove in 2.2 +def is_bool(st: str) -> bool: + st = str.lower(st) + return st == "true" or st == "false" + + def is_float(st: str) -> bool: try: float(st) @@ -420,6 +438,20 @@ def is_int(st: str) -> bool: return False +# DEPRECATED: remove in 2.2 +def decode_primitive(s: str) -> Any: + if is_bool(s): + return str.lower(s) == "true" + + if is_int(s): + return int(s) + + if is_float(s): + return float(s) + + return s + + def is_primitive_list(obj: Any) -> bool: from .base import Container diff --git a/omegaconf/base.py b/omegaconf/base.py index 4a0a213fd..ec78ed965 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -9,6 +9,7 @@ from antlr4 import ParserRuleContext from ._utils import ( + _DEFAULT_MARKER_, ValueKind, _get_value, _is_missing_value, @@ -32,8 +33,6 @@ DictKeyType = Union[str, int, Enum, float, bool] -_MARKER_ = object() - @dataclass class Metadata: @@ -155,8 +154,8 @@ def _get_flag(self, flag: str) -> Optional[bool]: if cache is None: cache = self.__dict__["_flags_cache"] = {} - ret = cache.get(flag, _MARKER_) - if ret is _MARKER_: + ret = cache.get(flag, _DEFAULT_MARKER_) + if ret is _DEFAULT_MARKER_: ret = self._get_flag_no_cache(flag) cache[flag] = ret assert ret is None or isinstance(ret, bool) diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 74ce491ef..b3bad23de 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -9,6 +9,7 @@ import yaml from ._utils import ( + _DEFAULT_MARKER_, _ensure_container, _get_value, _is_interpolation, @@ -37,8 +38,6 @@ if TYPE_CHECKING: from .dictconfig import DictConfig # pragma: no cover -DEFAULT_VALUE_MARKER: Any = str("__DEFAULT_VALUE_MARKER__") - class BaseContainer(Container, ABC): # static @@ -52,11 +51,11 @@ def _resolve_with_default( self, key: Union[DictKeyType, int], value: Node, - default_value: Any = DEFAULT_VALUE_MARKER, + default_value: Any = _DEFAULT_MARKER_, ) -> Any: """returns the value with the specified key, like obj.key and obj['key']""" if _is_missing_value(value): - if default_value is not DEFAULT_VALUE_MARKER: + if default_value is not _DEFAULT_MARKER_: return default_value raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY") diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index a8532602c..e19025c22 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -16,6 +16,7 @@ ) from ._utils import ( + _DEFAULT_MARKER_, ValueKind, _get_value, _is_interpolation, @@ -35,7 +36,7 @@ valid_value_annotation_type, ) from .base import Container, ContainerMetadata, DictKeyType, Node -from .basecontainer import DEFAULT_VALUE_MARKER, BaseContainer +from .basecontainer import BaseContainer from .errors import ( ConfigAttributeError, ConfigKeyError, @@ -345,7 +346,7 @@ def __getattr__(self, key: str) -> Any: raise AttributeError() try: - return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER) + return self._get_impl(key=key, default_value=_DEFAULT_MARKER_) except ConfigKeyError as e: self._format_and_raise( key=key, value=None, cause=e, type_override=ConfigAttributeError @@ -361,7 +362,7 @@ def __getitem__(self, key: DictKeyType) -> Any: """ try: - return self._get_impl(key=key, default_value=DEFAULT_VALUE_MARKER) + return self._get_impl(key=key, default_value=_DEFAULT_MARKER_) except AttributeError as e: self._format_and_raise( key=key, value=None, cause=e, type_override=ConfigKeyError @@ -414,7 +415,7 @@ def _get_impl(self, key: DictKeyType, default_value: Any) -> Any: try: node = self._get_node(key=key, throw_on_missing_key=True) except (ConfigAttributeError, ConfigKeyError): - if default_value is not DEFAULT_VALUE_MARKER: + if default_value is not _DEFAULT_MARKER_: return default_value else: raise @@ -449,7 +450,7 @@ def _get_node( raise MissingMandatoryValue("Missing mandatory value") return value - def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any: + def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any: try: if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot pop from read-only node") @@ -470,7 +471,7 @@ def pop(self, key: DictKeyType, default: Any = DEFAULT_VALUE_MARKER) -> Any: del self[key] return value else: - if default is not DEFAULT_VALUE_MARKER: + if default is not _DEFAULT_MARKER_: return default else: full = self._get_full_key(key=key) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 61e4c0ddb..d510d40bc 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -28,10 +28,12 @@ from . import DictConfig, DictKeyType, ListConfig from ._utils import ( + _DEFAULT_MARKER_, _ensure_container, _get_value, _is_none, _make_hashable, + decode_primitive, format_and_raise, get_dict_key_value_types, get_list_element_type, @@ -53,8 +55,6 @@ from .basecontainer import BaseContainer from .errors import ( ConfigKeyError, - GrammarParseError, - InterpolationKeyError, MissingMandatoryValue, OmegaConfBaseException, UnsupportedInterpolationType, @@ -73,12 +73,6 @@ MISSING: Any = "???" -# A marker used: -# - in OmegaConf.create() to differentiate between creating an empty {} DictConfig -# and creating a DictConfig with None content -# - in env() to detect between no default value vs a default value set to None -_EMPTY_MARKER_ = object() - Resolver = Callable[..., Any] @@ -101,43 +95,61 @@ def SI(interpolation: str) -> Any: def register_default_resolvers() -> None: - def env(key: str, default: Any = _EMPTY_MARKER_) -> Any: + # DEPRECATED: remove in 2.2 + def legacy_env(key: str, default: Optional[str] = None) -> Any: + warnings.warn( + "The `env` resolver is deprecated, see https://github.com/omry/omegaconf/issues/573", + ) + try: - val_str = os.environ[key] + return decode_primitive(os.environ[key]) except KeyError: - if default is not _EMPTY_MARKER_: - return default + if default is not None: + return decode_primitive(default) else: raise ValidationError(f"Environment variable '{key}' not found") - # We obtained a string from the environment variable: we try to parse it - # using the grammar (as if it was a resolver argument), so that expressions - # like numbers, booleans, lists and dictionaries can be properly evaluated. - try: - parse_tree = parse( - val_str, parser_rule="singleElement", lexer_mode="VALUE_MODE" + def env(key: str, default: Optional[str] = _DEFAULT_MARKER_) -> Optional[str]: + if ( + default is not _DEFAULT_MARKER_ + and default is not None + and not isinstance(default, str) + ): + raise TypeError( + f"The default value of the `oc.env` resolver must be a string or " + f"None, but `{default}` is of type {type(default).__name__}" ) - except GrammarParseError: - # Un-parsable as a resolver argument: keep the string unchanged. - return val_str - # Resolve the parse tree. We use an empty config for this, which means that - # interpolations referring to other nodes will fail. - empty_config = DictConfig({}) try: - val = empty_config.resolve_parse_tree(parse_tree) - except InterpolationKeyError as exc: - raise InterpolationKeyError( - f"When attempting to resolve env variable '{key}', a node interpolation " - f"caused the following exception: {exc}. Node interpolations are not " - f"supported in environment variables: either remove them, or escape " - f"them to keep them as a strings." - ).with_traceback(sys.exc_info()[2]) + return os.environ[key] + except KeyError: + if default is not _DEFAULT_MARKER_: + return default + else: + raise KeyError(f"Environment variable '{key}' not found") + + def decode(expr: Optional[str], _parent_: Container) -> Any: + """ + Parse and evaluate `expr` according to the `singleElement` rule of the grammar. + + If `expr` is `None`, then return `None`. + """ + if expr is None: + return None + + if not isinstance(expr, str): + raise TypeError( + f"`oc.decode` can only take strings or None as input, " + f"but `{expr}` is of type {type(expr).__name__}" + ) + parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE") + val = _parent_.resolve_parse_tree(parse_tree) return _get_value(val) - # Note that the `env` resolver does *NOT* use the cache. - OmegaConf.register_new_resolver("env", env, use_cache=True) + OmegaConf.legacy_register_resolver("env", legacy_env) + OmegaConf.register_new_resolver("oc.env", env, use_cache=False) + OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False) class OmegaConf: @@ -201,7 +213,7 @@ def create( @staticmethod def create( # noqa F811 - obj: Any = _EMPTY_MARKER_, + obj: Any = _DEFAULT_MARKER_, parent: Optional[BaseContainer] = None, flags: Optional[Dict[str, bool]] = None, ) -> Union[DictConfig, ListConfig]: @@ -666,7 +678,7 @@ def select( cfg: Container, key: str, *, - default: Any = _EMPTY_MARKER_, + default: Any = _DEFAULT_MARKER_, throw_on_resolution_failure: bool = True, throw_on_missing: bool = False, ) -> Any: @@ -678,13 +690,13 @@ def select( throw_on_resolution_failure=throw_on_resolution_failure, ) except ConfigKeyError: - if default is not _EMPTY_MARKER_: + if default is not _DEFAULT_MARKER_: return default else: raise if ( - default is not _EMPTY_MARKER_ + default is not _DEFAULT_MARKER_ and _root is not None and _last_key is not None and _last_key not in _root @@ -788,7 +800,7 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str: @staticmethod def _create_impl( # noqa F811 - obj: Any = _EMPTY_MARKER_, + obj: Any = _DEFAULT_MARKER_, parent: Optional[BaseContainer] = None, flags: Optional[Dict[str, bool]] = None, ) -> Union[DictConfig, ListConfig]: @@ -797,7 +809,7 @@ def _create_impl( # noqa F811 from .dictconfig import DictConfig from .listconfig import ListConfig - if obj is _EMPTY_MARKER_: + if obj is _DEFAULT_MARKER_: obj = {} if isinstance(obj, str): obj = yaml.load(obj, Loader=get_yaml_loader()) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index e46bec84a..f61615b47 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -211,71 +211,210 @@ def test_type_inherit_type(cfg: Any) -> None: assert type(cfg.s) == str # check that string interpolations are always strings +@pytest.mark.parametrize("env_func", ["env", "oc.env"]) +class TestEnvInterpolation: + @pytest.mark.parametrize( + ("cfg", "env_name", "env_val", "key", "expected"), + [ + pytest.param( + {"path": "/test/${${env_func}:foo}"}, + "foo", + "1234", + "path", + "/test/1234", + id="simple", + ), + pytest.param( + {"path": "/test/${${env_func}:not_found,ZZZ}"}, + None, + None, + "path", + "/test/ZZZ", + id="not_found_with_default", + ), + pytest.param( + {"path": "/test/${${env_func}:not_found,a/b}"}, + None, + None, + "path", + "/test/a/b", + id="not_found_with_default", + ), + ], + ) + def test_env_interpolation( + self, + # DEPRECATED: remove in 2.2 with the legacy env resolver + recwarn: Any, + monkeypatch: Any, + env_func: str, + cfg: Any, + env_name: Optional[str], + env_val: str, + key: str, + expected: Any, + ) -> None: + if env_name is not None: + monkeypatch.setenv(env_name, env_val) + + cfg["env_func"] = env_func # allows choosing which env resolver to use + cfg = OmegaConf.create(cfg) + + assert OmegaConf.select(cfg, key) == expected + + @pytest.mark.parametrize( + ("cfg", "key", "expected"), + [ + pytest.param( + {"path": "/test/${${env_func}:not_found}"}, + "path", + pytest.raises( + InterpolationResolutionError, + match=re.escape("Environment variable 'not_found' not found"), + ), + id="not_found", + ), + ], + ) + def test_env_interpolation_error( + self, + # DEPRECATED: remove in 2.2 with the legacy env resolver + recwarn: Any, + env_func: str, + cfg: Any, + key: str, + expected: Any, + ) -> None: + cfg["env_func"] = env_func # allows choosing which env resolver to use + cfg = _ensure_container(cfg) + + with expected: + OmegaConf.select(cfg, key) + + +def test_legacy_env_is_cached(monkeypatch: Any) -> None: + monkeypatch.setenv("FOOBAR", "1234") + c = OmegaConf.create({"foobar": "${env:FOOBAR}"}) + with pytest.warns(UserWarning): + before = c.foobar + monkeypatch.setenv("FOOBAR", "3456") + assert c.foobar == before + + +def test_env_is_not_cached(monkeypatch: Any) -> None: + monkeypatch.setenv("FOOBAR", "1234") + c = OmegaConf.create({"foobar": "${oc.env:FOOBAR}"}) + before = c.foobar + monkeypatch.setenv("FOOBAR", "3456") + assert c.foobar != before + + @pytest.mark.parametrize( - "cfg,env_name,env_val,key,expected", + "value,expected", + [ + # We only test a few typical cases: more extensive grammar tests are + # found in `test_grammar.py`. + # bool + ("false", False), + ("true", True), + # int + ("10", 10), + ("-10", -10), + # float + ("10.0", 10.0), + ("-10.0", -10.0), + # null + ("null", None), + ("NulL", None), + # strings + ("hello", "hello"), + ("hello world", "hello world"), + (" 123 ", " 123 "), + ('"123"', "123"), + # lists and dicts + ("[1, 2, 3]", [1, 2, 3]), + ("{a: 0, b: 1}", {"a": 0, "b": 1}), + ("[\t1, 2, 3\t]", [1, 2, 3]), + ("{ a: b\t }", {"a": "b"}), + # interpolations + ("${parent.sibling}", 1), + ("${.sibling}", 1), + ("${..parent.sibling}", 1), + ("${uncle}", 2), + ("${..uncle}", 2), + ("${oc.env:MYKEY}", 456), + ], +) +def test_decode(monkeypatch: Any, value: Optional[str], expected: Any) -> None: + monkeypatch.setenv("MYKEY", "456") + c = OmegaConf.create( + { + # The node of interest is "node" (others are used to test interpolations). + "parent": { + "node": f"${{oc.decode:'{value}'}}", + "sibling": 1, + }, + "uncle": 2, + } + ) + assert c.parent.node == expected + + +def test_decode_none() -> None: + c = OmegaConf.create({"x": "${oc.decode:null}"}) + assert c.x is None + + +@pytest.mark.parametrize( + ("value", "exc"), [ pytest.param( - {"path": "/test/${env:foo}"}, - "foo", - "1234", - "path", - "/test/1234", - id="simple", - ), - pytest.param( - {"path": "/test/${env:not_found}"}, - None, - None, - "path", + 123, pytest.raises( InterpolationResolutionError, - match=re.escape("Environment variable 'not_found' not found"), + match=re.escape( + "TypeError raised while resolving interpolation: " + "`oc.decode` can only take strings or None as input, but `123` is of type int" + ), ), - id="not_found", + id="bad_type", ), pytest.param( - {"path": "/test/${env:not_found,ZZZ}"}, - None, - None, - "path", - "/test/ZZZ", - id="not_found_with_default", + "'[1, '", + pytest.raises( + InterpolationResolutionError, + match=re.escape( + "GrammarParseError raised while resolving interpolation: " + "missing BRACKET_CLOSE at ''" + ), + ), + id="parse_error", ), pytest.param( - {"path": "/test/${env:not_found,a/b}"}, - None, - None, - "path", - "/test/a/b", - id="not_found_with_default", + # Must be escaped to prevent resolution before feeding it to `oc.decode`. + "'\\${foo}'", + pytest.raises( + InterpolationResolutionError, + match=re.escape("Interpolation key 'foo' not found"), + ), + id="interpolation_not_found", ), ], ) -def test_env_interpolation( - monkeypatch: Any, - cfg: Any, - env_name: Optional[str], - env_val: str, - key: str, - expected: Any, -) -> None: - if env_name is not None: - monkeypatch.setenv(env_name, env_val) - - cfg = _ensure_container(cfg) - if isinstance(expected, RaisesContext): - with expected: - OmegaConf.select(cfg, key) - else: - assert OmegaConf.select(cfg, key) == expected +def test_decode_error(monkeypatch: Any, value: Any, exc: Any) -> None: + c = OmegaConf.create({"x": f"${{oc.decode:{value}}}"}) + with exc: + c.x -def test_env_is_cached(monkeypatch: Any) -> None: - monkeypatch.setenv("foobar", "1234") - c = OmegaConf.create({"foobar": "${env:foobar}"}) - before = c.foobar - monkeypatch.setenv("foobar", "3456") - assert c.foobar == before +@pytest.mark.parametrize( + "value", + ["false", "true", "10", "1.5", "null", "None", "${foo}"], +) +def test_env_preserves_string(monkeypatch: Any, value: str) -> None: + monkeypatch.setenv("MYKEY", value) + c = OmegaConf.create({"my_key": "${oc.env:MYKEY}"}) + assert c.my_key == value @pytest.mark.parametrize( @@ -304,47 +443,44 @@ def test_env_is_cached(monkeypatch: Any) -> None: # more advanced uses of the grammar ("ab \\{foo} cd", "ab \\{foo} cd"), ("ab \\\\{foo} cd", "ab \\\\{foo} cd"), - ("'\\${other_key}'", "${other_key}"), # escaped interpolation - ("'ab \\${other_key} cd'", "ab ${other_key} cd"), # escaped interpolation - ("[1, 2, 3]", [1, 2, 3]), - ("{a: 0, b: 1}", {"a": 0, "b": 1}), - (" 123 ", " 123 "), (" 1 2 3 ", " 1 2 3 "), ("\t[1, 2, 3]\t", "\t[1, 2, 3]\t"), - ("[\t1, 2, 3\t]", [1, 2, 3]), (" {a: b}\t ", " {a: b}\t "), - ("{ a: b\t }", {"a": "b"}), - ("'123'", "123"), - ("${env:my_key_2}", 456), # can call another resolver ], ) -def test_env_values_are_typed(monkeypatch: Any, value: Any, expected: Any) -> None: - monkeypatch.setenv("my_key", value) - monkeypatch.setenv("my_key_2", "456") - c = OmegaConf.create({"my_key": "${env:my_key}"}) - assert c.my_key == expected +def test_legacy_env_values_are_typed( + monkeypatch: Any, value: Any, expected: Any +) -> None: + monkeypatch.setenv("MYKEY", value) + c = OmegaConf.create({"my_key": "${env:MYKEY}"}) + with pytest.warns(UserWarning, match=re.escape("The `env` resolver is deprecated")): + assert c.my_key == expected -def test_env_node_interpolation(monkeypatch: Any) -> None: - # Test that node interpolations are not supported in env variables. - monkeypatch.setenv("MYKEY", "${other_key}") - c = OmegaConf.create({"my_key": "${env:MYKEY}", "other_key": 123}) +def test_env_default_none(monkeypatch: Any) -> None: + monkeypatch.delenv("MYKEY", raising=False) + c = OmegaConf.create({"my_key": "${oc.env:MYKEY, null}"}) + assert c.my_key is None + + +@pytest.mark.parametrize("has_var", [True, False]) +def test_env_non_str_default(monkeypatch: Any, has_var: bool) -> None: + if has_var: + monkeypatch.setenv("MYKEY", "456") + else: + monkeypatch.delenv("MYKEY", raising=False) + + c = OmegaConf.create({"my_key": "${oc.env:MYKEY, 123}"}) with pytest.raises( - InterpolationKeyError, + InterpolationResolutionError, match=re.escape( - "When attempting to resolve env variable 'MYKEY', a node interpolation caused " - "the following exception: Interpolation key 'other_key' not found." + "TypeError raised while resolving interpolation: The default value " + "of the `oc.env` resolver must be a string or None, but `123` is of type int" ), ): c.my_key -def test_env_default_none(monkeypatch: Any) -> None: - monkeypatch.delenv("my_key", raising=False) - c = OmegaConf.create({"my_key": "${env:my_key, null}"}) - assert c.my_key is None - - def test_register_resolver_twice_error(restore_resolvers: Any) -> None: def foo(_: Any) -> int: return 10 diff --git a/tests/test_utils.py b/tests/test_utils.py index fe80d283d..a308ad2c6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,6 +8,7 @@ from omegaconf import DictConfig, ListConfig, Node, OmegaConf, _utils from omegaconf._utils import ( + Marker, _ensure_container, _get_value, _make_hashable, @@ -58,7 +59,7 @@ param(float, 1, FloatNode(1), id="float"), param(float, 1.0, FloatNode(1.0), id="float"), param(float, Color.RED, ValidationError, id="float"), - # # bool + # bool param(bool, "foo", ValidationError, id="bool"), param(bool, True, BooleanNode(True), id="bool"), param(bool, 1, BooleanNode(True), id="bool"), @@ -631,3 +632,8 @@ def test_ensure_container_raises_ValueError() -> None: ), ): _ensure_container("abc") + + +def test_marker_string_representation() -> None: + marker = Marker("marker") + assert repr(marker) == "marker"