diff --git a/.circleci/config.yml b/.circleci/config.yml index 7676f8ec3..7ffe0fac4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,47 +1,41 @@ version: 2.1 -jobs: - py36_linux: - docker: - - image: circleci/python:3.6 - steps: - - checkout - - run: echo 'export NOX_PYTHON_VERSIONS=3.6' >> $BASH_ENV - - run: sudo pip install nox - - run: nox --add-timestamp - - py37_linux: - docker: - - image: circleci/python:3.7 +commands: + linux: + description: "Commands run on Linux" + parameters: + py_version: + type: string steps: - checkout - - run: echo 'export NOX_PYTHON_VERSIONS=3.7' >> $BASH_ENV - - run: sudo pip install nox - - run: nox --add-timestamp + - run: + name: "Preparing environment" + command: | + sudo apt-get update + sudo apt-get install -y openjdk-11-jre + sudo pip install nox + - run: + name: "Testing OmegaConf" + command: | + export NOX_PYTHON_VERSIONS=<< parameters.py_version >> + nox --add-timestamp - py38_linux: - docker: - - image: circleci/python:3.8 - steps: - - checkout - - run: echo 'export NOX_PYTHON_VERSIONS=3.8' >> $BASH_ENV - - run: sudo pip install nox - - run: nox --add-timestamp - - py39_linux: +jobs: + test_linux: + parameters: + py_version: + type: string docker: - - image: circleci/python:3.9 + - image: circleci/python:<< parameters.py_version >> steps: - - checkout - - run: echo 'export NOX_PYTHON_VERSIONS=3.9' >> $BASH_ENV - - run: sudo pip install nox - - run: nox --add-timestamp + - linux: + py_version: << parameters.py_version >> workflows: version: 2 build: jobs: - - py36_linux - - py37_linux - - py38_linux - - py39_linux + - test_linux: + matrix: + parameters: + py_version: ["3.6", "3.7", "3.8", "3.9"] diff --git a/.coveragerc b/.coveragerc index 4c9bab177..7121c9090 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,6 +3,7 @@ omit = .nox/* *tests* docs/* + omegaconf/grammar/gen/* omegaconf/version.py .stubs @@ -15,7 +16,7 @@ exclude_lines = assert False @abstractmethod \.\.\. - + if TYPE_CHECKING: [html] directory = docs/build/coverage diff --git a/.flake8 b/.flake8 index 20adc760b..e6d773080 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -exclude = .git,.nox,.tox +exclude = .git,.nox,.tox,omegaconf/grammar/gen max-line-length = 119 select = E,F,W,C ignore=W503,E203 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..084917e80 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.jar binary diff --git a/.gitignore b/.gitignore index fbba511c5..16ea1a711 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,9 @@ TODO .coverage .eggs .mypy_cache +/omegaconf/grammar/gen /pip-wheel-metadata /.pyre .dmypy.json .python-version +.vscode diff --git a/.isort.cfg b/.isort.cfg index ba22f7506..a1bd71c51 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -7,3 +7,4 @@ line_length=88 ensure_newline_before_comments=True known_third_party=attr,pytest known_first_party=omegaconf +skip=.eggs,.nox,omegaconf/grammar/gen diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 8638d8217..1c77ad281 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -4,6 +4,7 @@ from pytest import fixture, mark, param from omegaconf import OmegaConf +from omegaconf._utils import ValueKind, get_value_kind def build_dict( @@ -119,3 +120,23 @@ def iterate(seq: Any) -> None: pass benchmark(iterate, lst) + + +@mark.parametrize( + "strict_interpolation_validation", + [True, False], +) +@mark.parametrize( + ("value", "expected"), + [ + ("simple", ValueKind.VALUE), + ("${a}", ValueKind.INTERPOLATION), + ("${a:b,c,d}", ValueKind.INTERPOLATION), + ("${${b}}", ValueKind.INTERPOLATION), + ("${a:${b}}", ValueKind.INTERPOLATION), + ], +) +def test_get_value_kind( + strict_interpolation_validation: bool, value: Any, expected: Any, benchmark: Any +) -> None: + assert benchmark(get_value_kind, value, strict_interpolation_validation) == expected diff --git a/build_helpers/__init__.py b/build_helpers/__init__.py new file mode 100644 index 000000000..aa0c875c5 --- /dev/null +++ b/build_helpers/__init__.py @@ -0,0 +1,3 @@ +# Order of imports is important (see warning otherwise when running tests) +import setuptools # isort:skip # noqa +import distutils # isort:skip # noqa diff --git a/build_helpers/bin/antlr-4.8-complete.jar b/build_helpers/bin/antlr-4.8-complete.jar new file mode 100644 index 000000000..89a0640e2 Binary files /dev/null and b/build_helpers/bin/antlr-4.8-complete.jar differ diff --git a/build_helpers/build_helpers.py b/build_helpers/build_helpers.py new file mode 100644 index 000000000..98baacd0e --- /dev/null +++ b/build_helpers/build_helpers.py @@ -0,0 +1,195 @@ +import codecs +import distutils.log +import errno +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path +from typing import List, Optional + +from setuptools import Command +from setuptools.command import build_py, develop, sdist # type: ignore + + +class ANTLRCommand(Command): # type: ignore # pragma: no cover + """Generate parsers using ANTLR.""" + + description = "Run ANTLR" + user_options: List[str] = [] + + def run(self) -> None: + """Run command.""" + build_dir = Path(__file__).parent.absolute() + project_root = build_dir.parent + for grammar in [ + "OmegaConfGrammarLexer.g4", + "OmegaConfGrammarParser.g4", + ]: + command = [ + "java", + "-jar", + str(build_dir / "bin" / "antlr-4.8-complete.jar"), + "-Dlanguage=Python3", + "-o", + str(project_root / "omegaconf" / "grammar" / "gen"), + "-Xexact-output-dir", + "-visitor", + str(project_root / "omegaconf" / "grammar" / grammar), + ] + + self.announce( + f"Generating parser for Python3: {command}", + level=distutils.log.INFO, + ) + + subprocess.check_call(command) + + def initialize_options(self) -> None: + pass + + def finalize_options(self) -> None: + pass + + +class BuildPyCommand(build_py.build_py): # type: ignore # pragma: no cover + def run(self) -> None: + if not self.dry_run: + self.run_command("clean") + run_antlr(self) + build_py.build_py.run(self) + + +class CleanCommand(Command): # type: ignore # pragma: no cover + """ + Our custom command to clean out junk files. + """ + + description = "Cleans out generated and junk files we don't want in the repo" + dry_run: bool + user_options: List[str] = [] + + def run(self) -> None: + root = Path(__file__).parent.parent.absolute() + files = find( + root=root, + include_files=["^omegaconf/grammar/gen/.*"], + include_dirs=[ + "^omegaconf\\.egg-info$", + "\\.eggs$", + "^\\.mypy_cache$", + "^\\.nox$", + "^\\.pytest_cache$", + ".*/__pycache__$", + "^__pycache__$", + "^build$", + "^dist$", + ], + scan_exclude=["^.git$", "^.nox/.*$"], + excludes=[".*\\.gitignore$", ".*/__init__.py"], + ) + + if self.dry_run: + print("Dry run! Would clean up the following files and dirs:") + print("\n".join(sorted(map(str, files)))) + else: + for f in files: + if f.exists(): + if f.is_dir(): + shutil.rmtree(f, ignore_errors=True) + else: + f.unlink() + + def initialize_options(self) -> None: + pass + + def finalize_options(self) -> None: + pass + + +class DevelopCommand(develop.develop): # type: ignore # pragma: no cover + def run(self) -> None: + if not self.dry_run: + run_antlr(self) + develop.develop.run(self) + + +class SDistCommand(sdist.sdist): # type: ignore # pragma: no cover + def run(self) -> None: + if not self.dry_run: + self.run_command("clean") + run_antlr(self) + sdist.sdist.run(self) + + +def find( + root: Path, + include_files: List[str], + include_dirs: List[str], + excludes: List[str], + rbase: Optional[Path] = None, + scan_exclude: Optional[List[str]] = None, +) -> List[Path]: + if rbase is None: + rbase = Path() + if scan_exclude is None: + scan_exclude = [] + files = [] + scan_root = root / rbase + for entry in scan_root.iterdir(): + path = rbase / entry.name + if matches(scan_exclude, path): + continue + + if entry.is_dir(): + if matches(include_dirs, path): + if not matches(excludes, path): + files.append(path) + else: + ret = find( + root=root, + include_files=include_files, + include_dirs=include_dirs, + excludes=excludes, + rbase=path, + scan_exclude=scan_exclude, + ) + files.extend(ret) + else: + if matches(include_files, path) and not matches(excludes, path): + files.append(path) + + return files + + +def find_version(*file_paths: str) -> str: + root = Path(__file__).parent.parent.absolute() + with codecs.open(root / Path(*file_paths), "r") as fp: # type: ignore + version_file = fp.read() + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) + if version_match: + return version_match.group(1) + raise RuntimeError("Unable to find version string.") # pragma: no cover + + +def matches(patterns: List[str], path: Path) -> bool: + string = str(path).replace(os.sep, "/") # for Windows + for pattern in patterns: + if re.match(pattern, string): + return True + return False + + +def run_antlr(cmd: Command) -> None: # pragma: no cover + try: + cmd.announce("Generating parsers with antlr4", level=distutils.log.INFO) + cmd.run_command("antlr") + except OSError as e: + if e.errno == errno.ENOENT: + msg = f"| Unable to generate parsers: {e} |" + msg = "=" * len(msg) + "\n" + msg + "\n" + "=" * len(msg) + cmd.announce(f"{msg}", level=distutils.log.FATAL) + sys.exit(1) + else: + raise diff --git a/build_helpers/test_files/a/b/bad_dir/.gitkeep b/build_helpers/test_files/a/b/bad_dir/.gitkeep new file mode 100644 index 000000000..0b05dbc19 --- /dev/null +++ b/build_helpers/test_files/a/b/bad_dir/.gitkeep @@ -0,0 +1 @@ +Intentionally left empty for git to keep the directory \ No newline at end of file diff --git a/build_helpers/test_files/a/b/file1.txt b/build_helpers/test_files/a/b/file1.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_files/a/b/file2.txt b/build_helpers/test_files/a/b/file2.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_files/a/b/junk.txt b/build_helpers/test_files/a/b/junk.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_files/c/bad_dir/.gitkeep b/build_helpers/test_files/c/bad_dir/.gitkeep new file mode 100644 index 000000000..0b05dbc19 --- /dev/null +++ b/build_helpers/test_files/c/bad_dir/.gitkeep @@ -0,0 +1 @@ +Intentionally left empty for git to keep the directory \ No newline at end of file diff --git a/build_helpers/test_files/c/file1.txt b/build_helpers/test_files/c/file1.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_files/c/file2.txt b/build_helpers/test_files/c/file2.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_files/c/junk.txt b/build_helpers/test_files/c/junk.txt new file mode 100644 index 000000000..e69de29bb diff --git a/build_helpers/test_helpers.py b/build_helpers/test_helpers.py new file mode 100644 index 000000000..7f10a1d85 --- /dev/null +++ b/build_helpers/test_helpers.py @@ -0,0 +1,142 @@ +from pathlib import Path +from typing import List + +import pytest + +from build_helpers.build_helpers import find, find_version, matches + + +@pytest.mark.parametrize( + "path_rel,include_files,include_dirs,excludes,scan_exclude,expected", + [ + pytest.param("test_files", [], [], [], None, [], id="none"), + pytest.param( + "test_files", + [".*"], + [], + [], + [], + [ + "a/b/bad_dir/.gitkeep", + "a/b/file2.txt", + "a/b/file1.txt", + "a/b/junk.txt", + "c/bad_dir/.gitkeep", + "c/file2.txt", + "c/file1.txt", + "c/junk.txt", + ], + id="all", + ), + pytest.param( + "test_files", + [".*"], + [], + ["^a/.*"], + [], + ["c/bad_dir/.gitkeep", "c/file2.txt", "c/file1.txt", "c/junk.txt"], + id="filter_a", + ), + pytest.param( + "test_files", + [".*"], + [], + [], + ["^a/.*"], + ["c/bad_dir/.gitkeep", "c/file2.txt", "c/file1.txt", "c/junk.txt"], + id="do_not_scan_a", + ), + pytest.param( + "test_files", + ["^a/.*"], + [], + [], + [], + ["a/b/bad_dir/.gitkeep", "a/b/file2.txt", "a/b/file1.txt", "a/b/junk.txt"], + id="include_a", + ), + pytest.param( + "test_files", + ["^a/.*"], + [], + [".*/file1\\.txt"], + [], + ["a/b/bad_dir/.gitkeep", "a/b/file2.txt", "a/b/junk.txt"], + id="include_a,exclude_file1", + ), + pytest.param( + "test_files", + [".*"], + [], + ["^.*/junk.txt$"], + [], + [ + "a/b/bad_dir/.gitkeep", + "a/b/file2.txt", + "a/b/file1.txt", + "c/bad_dir/.gitkeep", + "c/file2.txt", + "c/file1.txt", + ], + id="no_junk", + ), + pytest.param( + "test_files", + ["^.*/junk.txt"], + [], + [], + [], + ["a/b/junk.txt", "c/junk.txt"], + id="junk_only", + ), + pytest.param("test_files", [], ["^a$"], [], [], ["a"], id="exact_a"), + pytest.param( + "test_files", + [], + [".*bad_dir$"], + [], + [], + ["a/b/bad_dir", "c/bad_dir"], + id="bad_dirs", + ), + ], +) +def test_find( + path_rel: str, + include_files: List[str], + include_dirs: List[str], + excludes: List[str], + scan_exclude: List[str], + expected: List[str], +) -> None: + basedir = Path(__file__).parent.absolute() + path = basedir / path_rel + ret = find( + root=path, + excludes=excludes, + include_files=include_files, + include_dirs=include_dirs, + scan_exclude=scan_exclude, + ) + + ret_set = set([str(x) for x in ret]) + expected_set = set([str(Path(x)) for x in expected]) + assert ret_set == expected_set + + +@pytest.mark.parametrize( + "patterns,query,expected", + [ + (["^a/.*"], Path("a") / "b.txt", True), + (["^/foo/bar/.*"], Path("/foo") / "bar" / "blag", True), + ], +) +def test_matches(patterns: List[str], query: Path, expected: bool) -> None: + ret = matches(patterns, query) + assert ret == expected + + +def test_find_version() -> None: + version = find_version("omegaconf", "version.py") + # Ensure `version` is a string starting with a digit. + assert isinstance(version, str) and version and "0" <= version[0] <= "9" diff --git a/docs/notebook/Tutorial.ipynb b/docs/notebook/Tutorial.ipynb index 9898a7307..b3547fcfc 100644 --- a/docs/notebook/Tutorial.ipynb +++ b/docs/notebook/Tutorial.ipynb @@ -520,7 +520,7 @@ "\n", "OmegaConf support variable interpolation, Interpolations are evaluated lazily on access.\n", "\n", - "### Config node interpolation" + "## Config node interpolation" ] }, { @@ -585,7 +585,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "to_yaml will resolve interpolation if `resolve=True` is passed" + "`to_yaml()` will resolve interpolations if `resolve=True` is passed" ] }, { @@ -617,20 +617,106 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Environment variable interpolation\n", + "Interpolated nodes can be any node in the config, not just leaf nodes:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cfg.player.height: 180\n", + "cfg.player.weight: 75\n", + "=== Switching player\n", + "cfg.player.height: 195\n", + "cfg.player.weight: 90\n" + ] + } + ], + "source": [ + "from textwrap import dedent\n", + "cfg = OmegaConf.create(\n", + " dedent(\n", + " \"\"\"\\\n", + " john:\n", + " height: 180\n", + " weight: 75\n", "\n", - "Environment variable interpolation is also supported.\n", + " fred:\n", + " height: 195\n", + " weight: 90\n", + " \n", + " player: ${john}\n", + " \"\"\"\n", + " )\n", + ")\n", + "print(f\"cfg.player.height: {cfg.player.height}\")\n", + "print(f\"cfg.player.weight: {cfg.player.weight}\")\n", + "print(\"=== Switching player\")\n", + "cfg.player = \"${fred}\"\n", + "print(f\"cfg.player.height: {cfg.player.height}\")\n", + "print(f\"cfg.player.weight: {cfg.player.weight}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Interpolations may be nested, enabling more advanced behavior like dynamically selecting a sub-config:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Default: cfg.plan = plan A\n", + "After selecting plan B: cfg.plan = plan B\n" + ] + } + ], + "source": [ + "from textwrap import dedent\n", + "cfg = OmegaConf.create(\n", + " dedent(\n", + " \"\"\"\\\n", + " plans:\n", + " A: plan A\n", + " B: plan B\n", + " selected_plan: A\n", + " plan: ${plans.${selected_plan}}\n", + " \"\"\"\n", + " )\n", + ")\n", + "print(f\"Default: cfg.plan = {cfg.plan}\")\n", + "cfg.selected_plan = \"B\"\n", + "print(f\"After selecting plan B: cfg.plan = {cfg.plan}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment variable interpolation\n", "\n", - "Input yaml file:" + "Environment variable interpolation is also supported." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "# Let's set up the environment first (Only needed for this demonstration)\n", + "# Let's set up the environment first (only needed for this demonstration)\n", "import os\n", "os.environ['USER'] = 'omry'" ] @@ -644,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": { "pycharm": { "name": "#%%\n" @@ -669,7 +755,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": { "pycharm": { "name": "#%%\n" @@ -696,25 +782,63 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can also set a default value for environment variables:" + "You can specify a default value to use in case the environment variable is not defined. The following example sets `abc123` as the the default value when `DB_PASSWORD` is not defined." ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "42\n" + "'abc123'\n" ] } ], "source": [ - "conf = OmegaConf.create({\"user\" : {\"age\" : \"${env:AGE, 42}\"}})\n", - "print(conf.user.age)" + "os.environ.pop('DB_PASSWORD', None) # ensure env variable does not exist\n", + "cfg = OmegaConf.create({'database': {'password': '${env:DB_PASSWORD,abc123}'}})\n", + "print(repr(cfg.database.password))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Environment variables are parsed when they are recognized as valid quantities that may be evaluated (e.g., int, float, dict, list):" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3308\n", + "['host1', 'host2', 'host3']\n", + "'a%#@~{}$*&^?/<'\n" + ] + } + ], + "source": [ + "cfg = OmegaConf.create({'database': {'password': '${env:DB_PASSWORD,abc123}',\n", + " 'user': 'someuser',\n", + " 'port': '${env:DB_PORT,3306}',\n", + " 'nodes': '${env:DB_NODES,[]}'}})\n", + "\n", + "os.environ[\"DB_PORT\"] = '3308' # integer\n", + "os.environ[\"DB_NODES\"] = '[host1, host2, host3]' # list\n", + "os.environ[\"DB_PASSWORD\"] = 'a%#@~{}$*&^?/<' # string\n", + "\n", + "print(repr(cfg.database.port))\n", + "print(repr(cfg.database.nodes))\n", + "print(repr(cfg.database.password))" ] }, { @@ -728,30 +852,106 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can add additional interpolation types using custom resolvers. This example creates a resolver that adds 10 the the given value." + "You can add additional interpolation types using custom resolvers.\n", + "The example below creates a resolver that adds 10 to the given value." ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": { "pycharm": { "name": "#%%\n" } }, + "outputs": [ + { + "data": { + "text/plain": [ + "1000" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "OmegaConf.register_new_resolver(\"plus_10\", lambda x: x + 10)\n", + "conf = OmegaConf.create({'key': '${plus_10:990}'})\n", + "conf.key" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can take advantage of nested interpolations to perform custom operations over variables:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "OmegaConf.register_new_resolver(\"plus\", lambda x, y: x + y)\n", + "conf = OmegaConf.create({\"a\": 1, \"b\": 2, \"a_plus_b\": \"${plus:${a},${b}}\"})\n", + "conf.a_plus_b" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default a custom resolver’s output is cached, so that when it is called with the same inputs we always return the same value. This behavior may be disabled by setting `use_cache=False`:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1000\n" + "With cache: \n", + "0.9664535356921388\n", + "0.9664535356921388\n", + "Without cache: \n", + "0.4407325991753527\n", + "0.007491470058587191\n" ] } ], "source": [ - "OmegaConf.register_resolver(\"plus_10\", lambda x: int(x) + 10)\n", - "conf = OmegaConf.create({'key': '${plus_10:990}'})\n", - "print(conf.key)" + "import random\n", + "random.seed(1234)\n", + "\n", + "OmegaConf.register_new_resolver(\"cached\", random.random)\n", + "OmegaConf.register_new_resolver(\"uncached\", random.random, use_cache=False)\n", + "\n", + "cfg = OmegaConf.create({\"cached\": \"${cached:}\", \"uncached\": \"${uncached:}\"})\n", + "\n", + "print(\"With cache: \")\n", + "print(cfg.cached)\n", + "print(cfg.cached) # same as above\n", + "\n", + "print(\"Without cache: \")\n", + "print(cfg.uncached)\n", + "print(cfg.uncached) # *not* the same as above" ] }, { @@ -777,7 +977,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 30, "metadata": { "pycharm": { "name": "#%%\n" @@ -804,7 +1004,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 31, "metadata": { "pycharm": { "name": "#%%\n" @@ -828,7 +1028,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 32, "metadata": { "pycharm": { "name": "#%%\n" @@ -863,13 +1063,6 @@ "conf.merge_with_cli()\n", "print(OmegaConf.to_yaml(conf))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -888,9 +1081,18 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.1" + "version": "3.8.3" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 29171185a..243fd6962 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -313,17 +313,36 @@ Example: .. doctest:: >>> conf = OmegaConf.load('source/config_interpolation.yaml') - >>> # Primitive interpolation types are inherited from the referenced value - >>> print(conf.client.server_port) + >>> # Primitive interpolation types are inherited from the reference + >>> conf.client.server_port 80 - >>> print(type(conf.client.server_port).__name__) - int + >>> type(conf.client.server_port).__name__ + 'int' >>> # Composite interpolation types are always string - >>> print(conf.client.url) - http://localhost:80/ - >>> print(type(conf.client.url).__name__) - str + >>> conf.client.url + 'http://localhost:80/' + >>> type(conf.client.url).__name__ + 'str' + + +Interpolations may be nested, enabling more advanced behavior like dynamically selecting a sub-config: + +.. doctest:: + + + >>> cfg = OmegaConf.create( + ... { + ... "plans": {"A": "plan A", "B": "plan B"}, + ... "selected_plan": "A", + ... "plan": "${plans.${selected_plan}}", + ... } + ... ) + >>> cfg.plan # default plan + 'plan A' + >>> cfg.selected_plan = "B" + >>> cfg.plan # new plan + 'plan B' Interpolated nodes can be any node in the config, not just leaf nodes: @@ -352,34 +371,53 @@ Input YAML file: .. doctest:: >>> conf = OmegaConf.load('source/env_interpolation.yaml') - >>> print(conf.user.name) - omry - >>> print(conf.user.home) - /home/omry + >>> conf.user.name + 'omry' + >>> conf.user.home + '/home/omry' You can specify a default value to use in case the environment variable is not defined. -The following example sets `12345` as the the default value for the `DB_PASSWORD` environment variable. +The following example sets `abc123` as the the default value when `DB_PASSWORD` is not defined. .. doctest:: >>> cfg = OmegaConf.create({ - ... 'database': {'password': '${env:DB_PASSWORD,12345}'} + ... 'database': {'password': '${env:DB_PASSWORD,abc123}'} ... }) - >>> print(cfg.database.password) - 12345 - >>> OmegaConf.clear_cache(cfg) # clear resolver cache - >>> os.environ["DB_PASSWORD"] = 'secret' - >>> print(cfg.database.password) - secret + >>> cfg.database.password + 'abc123' + +Environment variables are parsed when they are recognized as valid quantities that +may be evaluated (e.g., int, float, dict, list): + +.. doctest:: + + >>> cfg = OmegaConf.create({ + ... 'database': {'password': '${env:DB_PASSWORD,abc123}', + ... 'user': 'someuser', + ... 'port': '${env:DB_PORT,3306}', + ... 'nodes': '${env:DB_NODES,[]}'} + ... }) + >>> os.environ["DB_PORT"] = '3308' + >>> cfg.database.port # converted to int + 3308 + >>> os.environ["DB_NODES"] = '[host1, host2, host3]' + >>> cfg.database.nodes # converted to list + ['host1', 'host2', 'host3'] + >>> os.environ["DB_PASSWORD"] = 'a%#@~{}$*&^?/<' + >>> cfg.database.password # kept as a string + 'a%#@~{}$*&^?/<' + Custom interpolations ^^^^^^^^^^^^^^^^^^^^^ + You can add additional interpolation types using custom resolvers. -This example creates a resolver that adds 10 the the given value. +The example below creates a resolver that adds 10 to the given value. .. doctest:: - >>> OmegaConf.register_resolver("plus_10", lambda x: int(x) + 10) + >>> OmegaConf.register_new_resolver("plus_10", lambda x: x + 10) >>> c = OmegaConf.create({'key': '${plus_10:990}'}) >>> c.key 1000 @@ -387,15 +425,17 @@ This example creates a resolver that adds 10 the the given value. Custom resolvers support variadic argument lists in the form of a comma separated list of zero or more values. Whitespaces are stripped from both ends of each value ("foo,bar" is the same as "foo, bar "). -You can use literal commas and spaces anywhere by escaping (:code:`\,` and :code:`\ `). +You can use literal commas and spaces anywhere by escaping (:code:`\,` and :code:`\ `), or +simply use quotes to bypass character limitations in strings. .. doctest:: - >>> OmegaConf.register_resolver("concat", lambda x,y: x+y) + >>> OmegaConf.register_new_resolver("concat", lambda x, y: x+y) >>> c = OmegaConf.create({ ... 'key1': '${concat:Hello,World}', ... 'key_trimmed': '${concat:Hello , World}', ... 'escape_whitespace': '${concat:Hello,\ World}', + ... 'quoted': '${concat:"Hello,", " World"}', ... }) >>> c.key1 'HelloWorld' @@ -403,7 +443,36 @@ You can use literal commas and spaces anywhere by escaping (:code:`\,` and :code 'HelloWorld' >>> c.escape_whitespace 'Hello World' + >>> c.quoted + 'Hello, World' + +You can take advantage of nested interpolations to perform custom operations over variables: + +.. doctest:: + + >>> OmegaConf.register_new_resolver("plus", lambda x, y: x + y) + >>> c = OmegaConf.create({"a": 1, + ... "b": 2, + ... "a_plus_b": "${plus:${a},${b}}"}) + >>> c.a_plus_b + 3 + +By default a custom resolver's output is cached, so that when it is called with the same +inputs we always return the same value. This behavior may be disabled by setting ``use_cache=False``: + +.. doctest:: + >>> import random + >>> random.seed(1234) + >>> OmegaConf.register_new_resolver("cached", random.randint) + >>> OmegaConf.register_new_resolver( + ... "uncached", random.randint, use_cache=False) + >>> c = OmegaConf.create({"cached": "${cached:0,10000}", + ... "uncached": "${uncached:0,10000}"}) + >>> # same value on repeated access thanks to the cache + >>> assert c.cached == c.cached == 7220 + >>> # not the same since the cache is disabled + >>> assert c.uncached != c.uncached Merging configurations diff --git a/news/230.bugfix b/news/230.bugfix new file mode 100644 index 000000000..61009f8be --- /dev/null +++ b/news/230.bugfix @@ -0,0 +1 @@ +`${env:MYVAR,null}` now properly returns `None` if the environment variable MYVAR is undefined. diff --git a/news/426.api_change b/news/426.api_change new file mode 100644 index 000000000..963543ce5 --- /dev/null +++ b/news/426.api_change @@ -0,0 +1 @@ +`register_resolver()` is deprecated in favor of `register_new_resolver()`, allowing resolvers to take non-string arguments like int, float, dict, etc. diff --git a/news/445.feature.1 b/news/445.feature.1 new file mode 100644 index 000000000..a11b098e2 --- /dev/null +++ b/news/445.feature.1 @@ -0,0 +1 @@ +Add ability to nest interpolations, e.g. ${env:{$var}} or ${foo.${bar}.${baz}} diff --git a/news/445.feature.2 b/news/445.feature.2 new file mode 100644 index 000000000..8326bd016 --- /dev/null +++ b/news/445.feature.2 @@ -0,0 +1 @@ +Custom resolvers may take non string arguments as input, and control whether or not to use the cache. diff --git a/news/445.feature.3 b/news/445.feature.3 new file mode 100644 index 000000000..0ec958b12 --- /dev/null +++ b/news/445.feature.3 @@ -0,0 +1 @@ +The `env` resolver parses environment variables. Supported types includes primitives (int, float, bool, ...) and containers like dict and list. The used grammar is a subset of the interpolation grammar. diff --git a/noxfile.py b/noxfile.py index 6b226a40a..484c1c466 100644 --- a/noxfile.py +++ b/noxfile.py @@ -10,26 +10,27 @@ ).split(",") -def deps(session): +def deps(session, editable_installl): session.install("--upgrade", "setuptools", "pip") - session.install("-r", "requirements/dev.txt", ".", silent=True) + extra_flags = ["-e"] if editable_installl else [] + session.install("-r", "requirements/dev.txt", *extra_flags, ".", silent=True) @nox.session(python=PYTHON_VERSIONS) def omegaconf(session): - deps(session) + deps(session, editable_installl=False) # ensure we test the regular install session.run("pytest") @nox.session(python=PYTHON_VERSIONS) def benchmark(session): - deps(session) + deps(session, editable_installl=True) session.run("pytest", "benchmark/benchmark.py") @nox.session def docs(session): - deps(session) + deps(session, editable_installl=True) session.chdir("docs") session.run("sphinx-build", "-W", "-b", "doctest", "source", "build") session.run("sphinx-build", "-W", "-b", "html", "source", "build") @@ -37,7 +38,12 @@ def docs(session): @nox.session(python=PYTHON_VERSIONS) def coverage(session): - deps(session) + # For coverage, we must use the editable installation because + # `coverage run -m pytest` prepends `sys.path` with "." (the current + # folder), so that the local code will be used in tests even if we set + # `editable_installl=False`. This would cause problems due to potentially + # missing the generated grammar files. + deps(session, editable_installl=True) session.run("coverage", "erase") session.run("coverage", "run", "--append", "-m", "pytest", silent=True) session.run("coverage", "report", "--fail-under=100") @@ -49,7 +55,7 @@ def coverage(session): @nox.session(python=PYTHON_VERSIONS) def lint(session): - deps(session) + deps(session, editable_installl=True) session.run("mypy", ".", "--strict", silent=True) session.run("isort", ".", "--check", silent=True) session.run("black", "--check", ".", silent=True) @@ -64,6 +70,6 @@ def test_jupyter_notebook(session): session.python, ",".join(DEFAULT_PYTHON_VERSIONS) ) ) - deps(session) + deps(session, editable_installl=False) session.install("jupyter", "nbval") session.run("pytest", "--nbval", "docs/notebook/Tutorial.ipynb", silent=True) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 1b9f461b6..766437dd6 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -4,8 +4,9 @@ import string import sys from enum import Enum +from functools import cmp_to_key from textwrap import dedent -from typing import Any, Dict, List, Match, Optional, Tuple, Type, Union, get_type_hints +from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_type_hints import yaml @@ -15,6 +16,7 @@ ConfigValueError, OmegaConfBaseException, ) +from .grammar_parser import parse try: import dataclasses @@ -28,6 +30,19 @@ except ImportError: # pragma: no cover attr = None # type: ignore # pragma: no cover +# Build regex pattern to efficiently identify typical interpolations. +# See test `test_match_simple_interpolation_pattern` for examples. +_id = "[a-zA-Z_]\\w*" # foo, foo_bar, abc123 +_dot_path = f"{_id}(\\.{_id})*" # foo, foo.bar3, foo_.b4r.b0z +_inter_node = f"\\${{\\s*{_dot_path}\\s*}}" # node interpolation +_arg = "[a-zA-Z_0-9/\\-\\+.$%*@]+" # string representing a resolver argument +_args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments +_inter_res = f"\\${{\\s*{_id}\\s*:\\s*{_args}?\\s*}}" # resolver interpolation +_inter = f"({_inter_node}|{_inter_res})" # any kind of interpolation +_outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {) +SIMPLE_INTERPOLATION_PATTERN = re.compile( + f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII +) # source: https://yaml.org/type/bool.html YAML_BOOL_TYPES = [ @@ -55,6 +70,10 @@ "OFF", ] +# Define an arbitrary (but fixed) ordering over the types of dictionary keys +# that may be encountered when calling `_make_hashable()` on a dict. +_CMP_TYPES = {t: i for i, t in enumerate([float, int, bool, str, type(None)])} + class OmegaConfDumper(yaml.Dumper): # type: ignore str_representer_added = False @@ -313,61 +332,45 @@ class ValueKind(Enum): VALUE = 0 MANDATORY_MISSING = 1 INTERPOLATION = 2 - STR_INTERPOLATION = 3 -def get_value_kind(value: Any, return_match_list: bool = False) -> Any: +def get_value_kind( + value: Any, strict_interpolation_validation: bool = False +) -> ValueKind: """ Determine the kind of a value Examples: - MANDATORY_MISSING : "??? - VALUE : "10", "20", True, - INTERPOLATION: "${foo}", "${foo.bar}" - STR_INTERPOLATION: "ftp://${host}/path" - - :param value: input string to classify - :param return_match_list: True to return the match list as well - :return: ValueKind + VALUE : "10", "20", True + MANDATORY_MISSING : "???" + INTERPOLATION: "${foo.bar}", "${foo.${bar}}", "${foo:bar}", "[${foo}, ${bar}]", + "ftp://${host}/path", "${foo:${bar}, [true], {'baz': ${baz}}}" + + :param value: Input to classify. + :param strict_interpolation_validation: If `True`, then when `value` is a string + containing "${", it is parsed to validate the interpolation syntax. If `False`, + this parsing step is skipped: this is more efficient, but will not detect errors. """ - key_prefix = r"\${(\w+:)?" - legal_characters = r"([\w\.%_ \\/:,-@]*?)}" - match_list: Optional[List[Match[str]]] = None - - def ret( - value_kind: ValueKind, - ) -> Union[ValueKind, Tuple[ValueKind, Optional[List[Match[str]]]]]: - if return_match_list: - return value_kind, match_list - else: - return value_kind - - from .base import Container - - if isinstance(value, Container): - if value._is_interpolation() or value._is_missing(): - value = value._value() - value = _get_value(value) - if value == "???": - return ret(ValueKind.MANDATORY_MISSING) - if not isinstance(value, str): - return ret(ValueKind.VALUE) - - match_list = list(re.finditer(key_prefix + legal_characters, value)) - if len(match_list) == 0: - return ret(ValueKind.VALUE) - - if len(match_list) == 1 and value == match_list[0].group(0): - return ret(ValueKind.INTERPOLATION) + if value == "???": + return ValueKind.MANDATORY_MISSING + + # We identify potential interpolations by the presence of "${" in the string. + # Note that escaped interpolations (ex: "esc: \${bar}") are identified as + # interpolations: this is intended, since they must be processed as interpolations + # for the string to be properly un-escaped. + # Keep in mind that invalid interpolations will only be detected when + # `strict_interpolation_validation` is True. + if isinstance(value, str) and "${" in value: + if strict_interpolation_validation: + # First try the cheap regex matching that detects common interpolations. + if SIMPLE_INTERPOLATION_PATTERN.match(value) is None: + # If no match, do the more expensive grammar parsing to detect errors. + parse(value) + return ValueKind.INTERPOLATION else: - return ret(ValueKind.STR_INTERPOLATION) - - -def is_bool(st: str) -> bool: - st = str.lower(st) - return st == "true" or st == "false" + return ValueKind.VALUE def is_float(st: str) -> bool: @@ -386,19 +389,6 @@ def is_int(st: str) -> bool: return False -def decode_primitive(s: str) -> Any: - if is_bool(s): - return str.lower(s) == "true" - - if is_int(s): - return int(s) - - if is_float(s): - return float(s) - - return s - - def is_primitive_list(obj: Any) -> bool: from .base import Container @@ -496,11 +486,11 @@ def is_primitive_type(type_: Any) -> bool: return issubclass(type_, Enum) or type_ in (int, float, bool, str, type(None)) -def _is_interpolation(v: Any) -> bool: +def _is_interpolation(v: Any, strict_interpolation_validation: bool = False) -> bool: if isinstance(v, str): - ret = get_value_kind(v) in ( - ValueKind.INTERPOLATION, - ValueKind.STR_INTERPOLATION, + ret = ( + get_value_kind(v, strict_interpolation_validation) + == ValueKind.INTERPOLATION ) assert isinstance(ret, bool) return ret @@ -511,8 +501,10 @@ def _get_value(value: Any) -> Any: from .base import Container from .nodes import ValueNode - if isinstance(value, Container) and value._is_none(): - return None + if isinstance(value, Container) and ( + value._is_none() or value._is_interpolation() or value._is_missing() + ): + return value._value() if isinstance(value, ValueNode): value = value._value() return value @@ -595,7 +587,7 @@ def format_and_raise( ref_type = get_ref_type(node) ref_type_str = type_str(ref_type) - msg = string.Template(msg).substitute( + msg = string.Template(msg).safe_substitute( REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, KEY=key, @@ -731,3 +723,64 @@ def is_generic_dict(type_: Any) -> bool: def is_container_annotation(type_: Any) -> bool: return is_list_annotation(type_) or is_dict_annotation(type_) + + +def _make_hashable(x: Any) -> Any: + """ + Obtain a hashable version of `x`. + + This is achieved by turning into tuples the lists and dicts that may be + stored within `x`. + Note that dicts are sorted, so that two dicts ordered differently will + lead to the same resulting hashable key. + + :return: a hashable version of `x` (which may be `x` itself if already hashable). + """ + # Hopefully it is already hashable and we have nothing to do! + try: + hash(x) + return x + except TypeError: + pass + + if isinstance(x, (list, tuple)): + return tuple(_make_hashable(y) for y in x) + elif isinstance(x, dict): + # We sort the dictionary so that the order of keys does not matter. + # Note that since keys might be of different types, and comparisons + # between different types are not always allowed, we use a custom + # `_safe_items_sort_key()` function to order keys. + return _make_hashable(tuple(sorted(x.items(), key=_safe_items_sort_key))) + else: + raise NotImplementedError(f"type {type(x)} cannot be made hashable") + + +def _safe_cmp(x: Any, y: Any) -> int: + """ + Compare two elements `x` and `y` in a "safe" way. + + By default, this function uses regular comparison operators (== and <), but + if an exception is raised (due to not being able to compare x and y), we instead + use `_CMP_TYPES` to decide which order to use. + """ + try: + return 0 if x == y else -1 if x < y else 1 + except Exception: + type_x, type_y = type(x), type(y) + try: + idx_x = _CMP_TYPES[type_x] + idx_y = _CMP_TYPES[type_y] + except KeyError: + bad_type = type_x if type_y in _CMP_TYPES else type_y + raise TypeError(f"Invalid data type: `{bad_type}`") + if idx_x == idx_y: # cannot compare two elements of the same type?! + raise # pragma: no cover + return -1 if idx_x < idx_y else 1 + + +_safe_key = cmp_to_key(_safe_cmp) + + +def _safe_items_sort_key(kv: Tuple[Any, Any]) -> Any: + """Safe function to use as sort key when sorting items in a dictionary""" + return _safe_key(kv[0]) diff --git a/omegaconf/base.py b/omegaconf/base.py index 46e18505a..d0da37fc7 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -5,6 +5,8 @@ from enum import Enum from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union +from antlr4 import ParserRuleContext + from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind from .errors import ( ConfigKeyError, @@ -13,6 +15,9 @@ OmegaConfBaseException, UnsupportedInterpolationType, ) +from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser +from .grammar_parser import parse +from .grammar_visitor import GrammarVisitor DictKeyType = Union[str, int, Enum] @@ -178,47 +183,26 @@ def _get_full_key(self, key: Union[str, Enum, int, None]) -> str: ... def _dereference_node( - self, throw_on_missing: bool = False, throw_on_resolution_failure: bool = True + self, + throw_on_missing: bool = False, + throw_on_resolution_failure: bool = True, ) -> Optional["Node"]: - from .nodes import StringNode - if self._is_interpolation(): - value_kind, match_list = get_value_kind( - value=self._value(), return_match_list=True - ) - match = match_list[0] parent = self._get_parent() - key = self._key() - if value_kind == ValueKind.INTERPOLATION: - if parent is None: - raise OmegaConfBaseException( - "Cannot resolve interpolation for a node without a parent" - ) - v = parent._resolve_simple_interpolation( - key=key, - inter_type=match.group(1), - inter_key=match.group(2), - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - return v - elif value_kind == ValueKind.STR_INTERPOLATION: - assert parent is not None - ret = parent._resolve_interpolation( - key=key, - value=self, - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - if ret is None: - return ret - return StringNode( - value=ret, - key=key, - parent=parent, - is_optional=self._metadata.optional, + if parent is None: + raise OmegaConfBaseException( + "Cannot resolve interpolation for a node without a parent" ) - assert False + assert parent is not None + key = self._key() + return parent._resolve_interpolation_from_parse_tree( + parent=parent, + key=key, + value=self, + parse_tree=parse(_get_value(self)), + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) else: # not interpolation, compare directly if throw_on_missing: @@ -380,9 +364,11 @@ def _select_impl( ) if ret is not None and not isinstance(ret, Container): + parent_key = ".".join(split[0 : i + 1]) + child_key = split[i + 1] raise ConfigKeyError( - f"Error trying to access {key}: node `{'.'.join(split[0:i + 1])}` " - f"is not a container and thus cannot contain `{split[i + 1]}``" + f"Error trying to access {key}: node `{parent_key}` " + f"is not a container and thus cannot contain `{child_key}`" ) root = ret @@ -397,8 +383,9 @@ def _select_impl( throw_on_type_error=throw_on_resolution_failure, ) if value is None: - return root, last_key, value - value = root._resolve_interpolation( + return root, last_key, None + value = root._maybe_resolve_interpolation( + parent=root, key=last_key, value=value, throw_on_missing=throw_on_missing, @@ -406,111 +393,182 @@ def _select_impl( ) return root, last_key, value - def _resolve_simple_interpolation( + def _resolve_interpolation_from_parse_tree( self, + parent: Optional["Container"], + value: "Node", key: Any, - inter_type: str, + parse_tree: OmegaConfGrammarParser.ConfigValueContext, + throw_on_missing: bool, + throw_on_resolution_failure: bool, + ) -> Optional["Node"]: + from .nodes import StringNode + + resolved = self.resolve_parse_tree( + parse_tree=parse_tree, + key=key, + parent=parent, + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) + + if resolved is None: + return None + elif isinstance(resolved, str): + # Result is a string: create a new StringNode for it. + return StringNode( + value=resolved, + key=key, + parent=parent, + is_optional=value._metadata.optional, + ) + else: + assert isinstance(resolved, Node) + return resolved + + def _resolve_node_interpolation( + self, inter_key: str, throw_on_missing: bool, throw_on_resolution_failure: bool, + ) -> Optional["Node"]: + """A node interpolation is of the form `${foo.bar}`""" + root_node, inter_key = self._resolve_key_and_root(inter_key) + parent, last_key, value = root_node._select_impl( + inter_key, + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) + + if parent is None or value is None: + if throw_on_resolution_failure: + raise InterpolationResolutionError( + f"Interpolation key '{inter_key}' not found" + ) + else: + return None + assert isinstance(value, Node) + return value + + def _evaluate_custom_resolver( + self, + key: Any, + inter_type: str, + inter_args: Tuple[Any, ...], + throw_on_missing: bool, + throw_on_resolution_failure: bool, + inter_args_str: Tuple[str, ...], ) -> Optional["Node"]: from omegaconf import OmegaConf from .nodes import ValueNode - inter_type = ("str:" if inter_type is None else inter_type)[0:-1] - if inter_type == "str": - root_node, inter_key = self._resolve_key_and_root(inter_key) - parent, last_key, value = root_node._select_impl( - inter_key, - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - - # if parent is None or (value is None and last_key not in parent): # type: ignore - if parent is None or value is None: + resolver = OmegaConf.get_resolver(inter_type) + if resolver is not None: + root_node = self._get_root() + try: + value = resolver(root_node, inter_args, inter_args_str) + return ValueNode( + value=value, + parent=self, + metadata=Metadata( + ref_type=Any, object_type=Any, key=key, optional=True + ), + ) + except Exception as e: if throw_on_resolution_failure: - raise InterpolationResolutionError( - f"{inter_type} interpolation key '{inter_key}' not found" - ) + self._format_and_raise(key=None, value=None, cause=e) + assert False else: return None - assert isinstance(value, Node) - return value else: - resolver = OmegaConf.get_resolver(inter_type) - if resolver is not None: - root_node = self._get_root() - try: - value = resolver(root_node, inter_key) - return ValueNode( - value=value, - parent=self, - metadata=Metadata( - ref_type=Any, object_type=Any, key=key, optional=True - ), - ) - except Exception as e: - if throw_on_resolution_failure: - self._format_and_raise(key=inter_key, value=None, cause=e) - assert False - else: - return None + if throw_on_resolution_failure: + raise UnsupportedInterpolationType( + f"Unsupported interpolation type {inter_type}" + ) else: - if throw_on_resolution_failure: - raise UnsupportedInterpolationType( - f"Unsupported interpolation type {inter_type}" - ) - else: - return None + return None - def _resolve_interpolation( + def _maybe_resolve_interpolation( self, + parent: Optional["Container"], key: Any, value: "Node", throw_on_missing: bool, throw_on_resolution_failure: bool, ) -> Any: + value_kind = get_value_kind(value) + if value_kind != ValueKind.INTERPOLATION: + return value + + parse_tree = parse(_get_value(value)) + return self._resolve_interpolation_from_parse_tree( + parent=parent, + value=value, + key=key, + parse_tree=parse_tree, + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) + + def resolve_parse_tree( + self, + parse_tree: ParserRuleContext, + key: Optional[Any] = None, + parent: Optional["Container"] = None, + throw_on_missing: bool = True, + throw_on_resolution_failure: bool = True, + ) -> Any: + """ + Resolve a given parse tree into its value. + + We make no assumption here on the type of the tree's root, so that the + return value may be of any type. + """ from .nodes import StringNode - value_kind, match_list = get_value_kind(value=value, return_match_list=True) - if value_kind not in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION): - return value + # Common arguments to all callbacks. + callback_args: Dict[str, Any] = dict( + throw_on_missing=throw_on_missing, + throw_on_resolution_failure=throw_on_resolution_failure, + ) - if value_kind == ValueKind.INTERPOLATION: - # simple interpolation, inherit type - match = match_list[0] - return self._resolve_simple_interpolation( + def node_interpolation_callback(inter_key: str) -> Optional["Node"]: + return self._resolve_node_interpolation( + inter_key=inter_key, **callback_args + ) + + def resolver_interpolation_callback( + name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...] + ) -> Optional["Node"]: + return self._evaluate_custom_resolver( key=key, - inter_type=match.group(1), - inter_key=match.group(2), - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, + inter_type=name, + inter_args=args, + inter_args_str=args_str, + **callback_args, ) - elif value_kind == ValueKind.STR_INTERPOLATION: - value = _get_value(value) - assert isinstance(value, str) - orig = value - new = "" - last_index = 0 - for match in match_list: - new_val = self._resolve_simple_interpolation( + + def quoted_string_callback(quoted_str: str) -> str: + quoted_val = self._maybe_resolve_interpolation( + key=key, + parent=parent, + value=StringNode( + value=quoted_str, key=key, - inter_type=match.group(1), - inter_key=match.group(2), - throw_on_missing=throw_on_missing, - throw_on_resolution_failure=throw_on_resolution_failure, - ) - # if failed to resolve, return None for the whole thing. - if new_val is None: - return None - new += orig[last_index : match.start(0)] + str(new_val) - last_index = match.end(0) + parent=parent, + is_optional=False, + ), + **callback_args, + ) + return str(quoted_val) - new += orig[last_index:] - return StringNode(value=new, key=key) - else: - assert False + visitor = GrammarVisitor( + node_interpolation_callback=node_interpolation_callback, + resolver_interpolation_callback=resolver_interpolation_callback, + quoted_string_callback=quoted_string_callback, + ) + return visitor.visit(parse_tree) def _re_parent(self) -> None: from .dictconfig import DictConfig diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index dfeace1af..0a46cb5c8 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -51,14 +51,15 @@ def _resolve_with_default( """returns the value with the specified key, like obj.key and obj['key']""" def is_mandatory_missing(val: Any) -> bool: - return get_value_kind(val) == ValueKind.MANDATORY_MISSING # type: ignore + return bool(get_value_kind(val) == ValueKind.MANDATORY_MISSING) - value = _get_value(value) + val = _get_value(value) has_default = default_value is not DEFAULT_VALUE_MARKER - if has_default and (value is None or is_mandatory_missing(value)): + if has_default and (val is None or is_mandatory_missing(val)): return default_value - resolved = self._resolve_interpolation( + resolved = self._maybe_resolve_interpolation( + parent=self, key=key, value=value, throw_on_missing=not has_default, @@ -455,7 +456,9 @@ def _list_merge(dest: Any, src: Any) -> None: def merge_with( self, - *others: Union["BaseContainer", Dict[str, Any], List[Any], Tuple[Any], Any], + *others: Union[ + "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any + ], ) -> None: try: self._merge_with(*others) @@ -464,7 +467,9 @@ def merge_with( def _merge_with( self, - *others: Union["BaseContainer", Dict[str, Any], List[Any], Tuple[Any], Any], + *others: Union[ + "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any + ], ) -> None: from .dictconfig import DictConfig from .listconfig import ListConfig diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 4484ef218..6141fdc5d 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -167,7 +167,7 @@ def _validate_set(self, key: Any, value: Any) -> None: from omegaconf import OmegaConf vk = get_value_kind(value) - if vk in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION): + if vk == ValueKind.INTERPOLATION: return self._validate_non_optional(key, value) if value == "???" or value is None: @@ -613,7 +613,7 @@ def _set_value_impl( if OmegaConf.is_none(value): self.__dict__["_content"] = None self._metadata.object_type = None - elif _is_interpolation(value): + elif _is_interpolation(value, strict_interpolation_validation=True): self.__dict__["_content"] = value self._metadata.object_type = None elif value == "???": diff --git a/omegaconf/errors.py b/omegaconf/errors.py index 79c23a07e..0ff62e601 100644 --- a/omegaconf/errors.py +++ b/omegaconf/errors.py @@ -109,3 +109,9 @@ class ConfigValueError(OmegaConfBaseException, ValueError): """ Thrown from a config object when a regular access would have caused a ValueError. """ + + +class GrammarParseError(OmegaConfBaseException): + """ + Thrown when failing to parse an expression according to the ANTLR grammar. + """ diff --git a/omegaconf/grammar/OmegaConfGrammarLexer.g4 b/omegaconf/grammar/OmegaConfGrammarLexer.g4 new file mode 100644 index 000000000..4960ed02e --- /dev/null +++ b/omegaconf/grammar/OmegaConfGrammarLexer.g4 @@ -0,0 +1,80 @@ +// Regenerate lexer and parser by running 'python setup.py antlr' at project root. +// See `OmegaConfGrammarParser.g4` for some important information regarding how to +// properly maintain this grammar. + +lexer grammar OmegaConfGrammarLexer; + +// Re-usable fragments. +fragment CHAR: [a-zA-Z]; +fragment DIGIT: [0-9]; +fragment INT_UNSIGNED: '0' | [1-9] (('_')? DIGIT)*; +fragment ESC_BACKSLASH: '\\\\'; // escaped backslash + +///////////////////////////// +// DEFAULT_MODE (TOPLEVEL) // +///////////////////////////// + +TOP_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE); + +ESC_INTER: '\\${'; +TOP_ESC: ESC_BACKSLASH+ -> type(ESC); + +// The backslash and dollar characters must not be grouped with others, so that +// we can properly detect the tokens above. +TOP_CHAR: [\\$]; +TOP_STR: ~[\\$]+; // anything else + +//////////////// +// VALUE_MODE // +//////////////// + +mode VALUE_MODE; + +INTER_OPEN: '${' -> pushMode(INTERPOLATION_MODE); +BRACE_OPEN: '{' WS? -> pushMode(VALUE_MODE); // must keep track of braces to detect end of interpolation +BRACE_CLOSE: WS? '}' -> popMode; + +COMMA: WS? ',' WS?; +BRACKET_OPEN: '[' WS?; +BRACKET_CLOSE: WS? ']'; +COLON: WS? ':' WS?; + +// Numbers. + +fragment POINT_FLOAT: INT_UNSIGNED '.' | INT_UNSIGNED? '.' DIGIT (('_')? DIGIT)*; +fragment EXPONENT_FLOAT: (INT_UNSIGNED | POINT_FLOAT) [eE] [+-]? DIGIT (('_')? DIGIT)*; +FLOAT: [+-]? (POINT_FLOAT | EXPONENT_FLOAT | [Ii][Nn][Ff] | [Nn][Aa][Nn]); +INT: [+-]? INT_UNSIGNED; + +// Other reserved keywords. + +BOOL: + [Tt][Rr][Uu][Ee] // TRUE + | [Ff][Aa][Ll][Ss][Ee]; // FALSE + +NULL: [Nn][Uu][Ll][Ll]; + +UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings +ID: (CHAR|'_') (CHAR|DIGIT|'_')*; +ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' | + '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+; +WS: [ \t]+; + +QUOTED_VALUE: + '\'' ('\\\''|.)*? '\'' // Single quotes, can contain escaped single quote : /' + | '"' ('\\"'|.)*? '"' ; // Double quotes, can contain escaped double quote : /" + +//////////////////////// +// INTERPOLATION_MODE // +//////////////////////// + +mode INTERPOLATION_MODE; + +NESTED_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE); +INTER_COLON: ':' WS? -> type(COLON), mode(VALUE_MODE); +INTER_CLOSE: '}' -> popMode; + +DOT: '.'; +INTER_ID: ID -> type(ID); +INTER_KEY: ~[\\${}()[\]:. \t'"]+; // interpolation key, may contain any non special character +INTER_WS: WS -> skip; diff --git a/omegaconf/grammar/OmegaConfGrammarParser.g4 b/omegaconf/grammar/OmegaConfGrammarParser.g4 new file mode 100644 index 000000000..d92495b49 --- /dev/null +++ b/omegaconf/grammar/OmegaConfGrammarParser.g4 @@ -0,0 +1,81 @@ +// Regenerate parser by running 'python setup.py antlr' at project root. + +// Maintenance guidelines when modifying this grammar: +// +// - For initial testing of the parsing abilities of the modified grammer before +// writing all the support visitor code, change the test +// `tests/test_interpolation.py::test_all_interpolations` +// by setting `dbg_test_access_only = True`, and run it. You will also probably +// need to comment / hijack the code accesssing the visitor. Tests that expect +// errors raised from the visitor will obviously fail. +// +// - Update Hydra's grammar accordingly, and if you added more cases to the test +// mentioned above, copy the latest version of `TEST_CONFIG_DATA` to Hydra (see +// Hydra's test: `tests/test_overrides_parser.py::test_omegaconf_interpolations`). + +// - Keep up-to-date the comments in the visitor (in `grammar_visitor.py`) +// that contain grammar excerpts (within each `visit...()` method). +// +// - Remember to update the documentation (including the tutorial notebook) + +parser grammar OmegaConfGrammarParser; +options {tokenVocab = OmegaConfGrammarLexer;} + +// Main rules used to parse OmegaConf strings. + +configValue: (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF; +singleElement: element EOF; + +// Top-level string (that does not need to be parsed). +toplevelStr: (ESC | ESC_INTER | TOP_CHAR | TOP_STR)+; + +// Elements. + +element: + primitive + | listContainer + | dictContainer +; + +// Data structures. + +listContainer: BRACKET_OPEN sequence? BRACKET_CLOSE; // [], [1,2,3], [a,b,[1,2]] +dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: dictKey COLON element; +sequence: element (COMMA element)*; + +// Interpolations. + +interpolation: interpolationNode | interpolationResolver; +interpolationNode: INTER_OPEN DOT* configKey (DOT configKey)* INTER_CLOSE; +interpolationResolver: INTER_OPEN (interpolation | ID) COLON sequence? BRACE_CLOSE; +configKey: interpolation | ID | INTER_KEY; + +// Primitive types. + +primitive: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | COLON // : + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + | interpolation + )+; + +// Same as `primitive` except that `COLON` and interpolations are not allowed. +dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; \ No newline at end of file diff --git a/omegaconf/grammar/__init__.py b/omegaconf/grammar/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/omegaconf/grammar/gen/__init__.py b/omegaconf/grammar/gen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/omegaconf/grammar_parser.py b/omegaconf/grammar_parser.py new file mode 100644 index 000000000..f379e9b2a --- /dev/null +++ b/omegaconf/grammar_parser.py @@ -0,0 +1,117 @@ +from typing import Any + +from antlr4 import CommonTokenStream, InputStream, ParserRuleContext +from antlr4.error.ErrorListener import ErrorListener + +from .errors import GrammarParseError + +# Import from visitor in order to check the presence of generated grammar files +# files in a single place. +from .grammar_visitor import ( # type: ignore + OmegaConfGrammarLexer, + OmegaConfGrammarParser, +) + +# Used to cache grammar objects to avoid re-creating them on each call to `parse()`. +_grammar_cache = None + + +class OmegaConfErrorListener(ErrorListener): # type: ignore + def syntaxError( + self, + recognizer: Any, + offending_symbol: Any, + line: Any, + column: Any, + msg: Any, + e: Any, + ) -> None: + raise GrammarParseError(str(e) if msg is None else msg) from e + + def reportAmbiguity( + self, + recognizer: Any, + dfa: Any, + startIndex: Any, + stopIndex: Any, + exact: Any, + ambigAlts: Any, + configs: Any, + ) -> None: + raise GrammarParseError("ANTLR error: Ambiguity") # pragma: no cover + + def reportAttemptingFullContext( + self, + recognizer: Any, + dfa: Any, + startIndex: Any, + stopIndex: Any, + conflictingAlts: Any, + configs: Any, + ) -> None: + # Note: for now we raise an error to be safe. However this is mostly a + # performance warning, so in the future this may be relaxed if we need + # to change the grammar in such a way that this warning cannot be + # avoided (another option would be to switch to SLL parsing mode). + raise GrammarParseError( + "ANTLR error: Attempting Full Context" + ) # pragma: no cover + + def reportContextSensitivity( + self, + recognizer: Any, + dfa: Any, + startIndex: Any, + stopIndex: Any, + prediction: Any, + configs: Any, + ) -> None: + raise GrammarParseError("ANTLR error: ContextSensitivity") # pragma: no cover + + +def parse( + value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE" +) -> ParserRuleContext: + """ + Parse interpolated string `value` (and return the parse tree). + """ + global _grammar_cache + + l_mode = getattr(OmegaConfGrammarLexer, lexer_mode) + istream = InputStream(value) + + if _grammar_cache is None: + error_listener = OmegaConfErrorListener() + lexer = OmegaConfGrammarLexer(istream) + lexer.removeErrorListeners() + lexer.addErrorListener(error_listener) + lexer.mode(l_mode) + tokens = CommonTokenStream(lexer) + parser = OmegaConfGrammarParser(tokens) + parser.removeErrorListeners() + parser.addErrorListener(error_listener) + + # The two lines below could be enabled in the future if we decide to switch + # to SLL prediction mode. Warning though, it has not been fully tested yet! + # from antlr4 import PredictionMode + # parser._interp.predictionMode = PredictionMode.SLL + + _grammar_cache = lexer, tokens, parser + + else: + lexer, tokens, parser = _grammar_cache + lexer.inputStream = istream + lexer.mode(l_mode) + tokens.setTokenSource(lexer) + parser.reset() + + try: + return getattr(parser, parser_rule)() + except Exception as exc: + if type(exc) is Exception and str(exc) == "Empty Stack": + # This exception is raised by antlr when trying to pop a mode while + # no mode has been pushed. We convert it into an `GrammarParseError` + # to facilitate exception handling from the caller. + raise GrammarParseError("Empty Stack") + else: + raise diff --git a/omegaconf/grammar_visitor.py b/omegaconf/grammar_visitor.py new file mode 100644 index 000000000..321f4de00 --- /dev/null +++ b/omegaconf/grammar_visitor.py @@ -0,0 +1,356 @@ +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) + +from antlr4 import TerminalNode + +from .errors import GrammarParseError + +if TYPE_CHECKING: + from .base import Node # noqa F401 + +try: + from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer + from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser + from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import ( + OmegaConfGrammarParserVisitor, + ) + +except ModuleNotFoundError: # pragma: no cover + print( + "Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.", + file=sys.stderr, + ) + sys.exit(1) + + +class GrammarVisitor(OmegaConfGrammarParserVisitor): + def __init__( + self, + node_interpolation_callback: Callable[[str], Optional["Node"]], + resolver_interpolation_callback: Callable[..., Optional["Node"]], + quoted_string_callback: Callable[[str], str], + **kw: Dict[Any, Any], + ): + """ + Constructor. + + :param node_interpolation_callback: Callback function that is called when + needing to resolve a node interpolation. This function should take a single + string input which is the key's dot path (ex: `"foo.bar"`). + + :param resolver_interpolation_callback: Callback function that is called when + needing to resolve a resolver interpolation. This function should accept + three keyword arguments: `name` (str, the name of the resolver), + `args` (tuple, the inputs to the resolver), and `args_str` (tuple, + the string representation of the inputs to the resolver). + + :param quoted_string_callback: Callback function that is called when needing to + resolve a quoted string (that may or may not contain interpolations). This + function should take a single string input which is the content of the quoted + string (without its enclosing quotes). + + :param kw: Additional keyword arguments to be forwarded to parent class. + """ + super().__init__(**kw) + self.node_interpolation_callback = node_interpolation_callback + self.resolver_interpolation_callback = resolver_interpolation_callback + self.quoted_string_callback = quoted_string_callback + + def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]: + raise NotImplementedError + + def defaultResult(self) -> List[Any]: + # Raising an exception because not currently used (like `aggregateResult()`). + raise NotImplementedError + + def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str: + from ._utils import _get_value + + # interpolation | ID | INTER_KEY + assert ctx.getChildCount() == 1 + child = ctx.getChild(0) + if isinstance(child, OmegaConfGrammarParser.InterpolationContext): + res = _get_value(self.visitInterpolation(child)) + if not isinstance(res, str): + raise GrammarParseError( + f"The following interpolation is used to denote a config key and " + f"thus should return a string, but instead returned `{res}` of " + f"type `{type(res)}`: {ctx.getChild(0).getText()}" + ) + return res + else: + assert isinstance(child, TerminalNode) and isinstance( + child.symbol.text, str + ) + return child.symbol.text + + def visitConfigValue( + self, ctx: OmegaConfGrammarParser.ConfigValueContext + ) -> Union[str, Optional["Node"]]: + # (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF + # Visit all children (except last one which is EOF) + vals = [self.visit(c) for c in list(ctx.getChildren())[:-1]] + assert vals + if len(vals) == 1 and isinstance( + ctx.getChild(0), OmegaConfGrammarParser.InterpolationContext + ): + from .base import Node # noqa F811 + + # Single interpolation: return the resulting node "as is". + ret = vals[0] + assert ret is None or isinstance(ret, Node), ret + return ret + # Concatenation of multiple components. + return "".join(map(str, vals)) + + def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any: + return self._createPrimitive(ctx) + + def visitDictContainer( + self, ctx: OmegaConfGrammarParser.DictContainerContext + ) -> Dict[Any, Any]: + # BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE + assert ctx.getChildCount() >= 2 + return dict( + self.visitDictKeyValuePair(ctx.getChild(i)) + for i in range(1, ctx.getChildCount() - 1, 2) + ) + + def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any: + # primitive | listContainer | dictContainer + assert ctx.getChildCount() == 1 + return self.visit(ctx.getChild(0)) + + def visitInterpolation( + self, ctx: OmegaConfGrammarParser.InterpolationContext + ) -> Optional["Node"]: + from .base import Node # noqa F811 + + assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver + ret = self.visit(ctx.getChild(0)) + assert ret is None or isinstance(ret, Node) + return ret + + def visitInterpolationNode( + self, ctx: OmegaConfGrammarParser.InterpolationNodeContext + ) -> Optional["Node"]: + # INTER_OPEN DOT* configKey (DOT configKey)* INTER_CLOSE + assert ctx.getChildCount() >= 3 + + inter_key_tokens = [] # parsed elements of the dot path + for child in ctx.getChildren(): + if isinstance(child, TerminalNode): + if child.symbol.type == OmegaConfGrammarLexer.DOT: + inter_key_tokens.append(".") # preserve dots + else: + assert child.symbol.type in ( + OmegaConfGrammarLexer.INTER_OPEN, + OmegaConfGrammarLexer.INTER_CLOSE, + ) + else: + assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext) + inter_key_tokens.append(self.visitConfigKey(child)) + + inter_key = "".join(inter_key_tokens) + return self.node_interpolation_callback(inter_key) + + def visitInterpolationResolver( + self, ctx: OmegaConfGrammarParser.InterpolationResolverContext + ) -> Optional["Node"]: + from ._utils import _get_value + + # INTER_OPEN (interpolation | ID) COLON sequence? BRACE_CLOSE; + resolver_name = None + args = [] + args_str = [] + for child in ctx.getChildren(): + if ( + isinstance(child, TerminalNode) + and child.symbol.type == OmegaConfGrammarLexer.ID + ): + assert resolver_name is None + resolver_name = child.symbol.text + elif isinstance(child, OmegaConfGrammarParser.InterpolationContext): + assert resolver_name is None + resolver_name = _get_value(self.visitInterpolation(child)) + if not isinstance(resolver_name, str): + raise GrammarParseError( + f"The name of a resolver must be a string, but the interpolation " + f"{child.getText()} resolved to `{resolver_name}` which is of type " + f"{type(resolver_name)}" + ) + elif isinstance(child, OmegaConfGrammarParser.SequenceContext): + assert resolver_name is not None + for val, txt in self.visitSequence(child): + args.append(val) + args_str.append(txt) + else: + assert isinstance(child, TerminalNode) + + assert resolver_name is not None + return self.resolver_interpolation_callback( + name=resolver_name, + args=tuple(args), + args_str=tuple(args_str), + ) + + def visitDictKeyValuePair( + self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext + ) -> Tuple[Any, Any]: + from ._utils import _get_value + + assert ctx.getChildCount() == 3 # dictKey COLON element + key = self.visit(ctx.getChild(0)) + colon = ctx.getChild(1) + assert ( + isinstance(colon, TerminalNode) + and colon.symbol.type == OmegaConfGrammarLexer.COLON + ) + value = _get_value(self.visitElement(ctx.getChild(2))) + return key, value + + def visitListContainer( + self, ctx: OmegaConfGrammarParser.ListContainerContext + ) -> List[Any]: + # BRACKET_OPEN sequence? BRACKET_CLOSE; + assert ctx.getChildCount() in (2, 3) + if ctx.getChildCount() == 2: + return [] + sequence = ctx.getChild(1) + assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext) + return list(val for val, _ in self.visitSequence(sequence)) # ignore raw text + + def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any: + return self._createPrimitive(ctx) + + def visitSequence( + self, ctx: OmegaConfGrammarParser.SequenceContext + ) -> Generator[Any, None, None]: + from ._utils import _get_value + + assert ctx.getChildCount() >= 1 # element (COMMA element)* + for i, child in enumerate(ctx.getChildren()): + if i % 2 == 0: + assert isinstance(child, OmegaConfGrammarParser.ElementContext) + # Also preserve the original text representation of `child` so + # as to allow backward compatibility with old resolvers (registered + # with `legacy_register_resolver()`). Note that we cannot just cast + # the value to string later as for instance `null` would become "None". + yield _get_value(self.visitElement(child)), child.getText() + else: + assert ( + isinstance(child, TerminalNode) + and child.symbol.type == OmegaConfGrammarLexer.COMMA + ) + + def visitSingleElement( + self, ctx: OmegaConfGrammarParser.SingleElementContext + ) -> Any: + # element EOF + assert ctx.getChildCount() == 2 + return self.visit(ctx.getChild(0)) + + def visitToplevelStr(self, ctx: OmegaConfGrammarParser.ToplevelStrContext) -> str: + # (ESC | ESC_INTER | TOP_CHAR | TOP_STR)+ + return self._unescape(ctx.getChildren()) + + def _createPrimitive( + self, + ctx: Union[ + OmegaConfGrammarParser.PrimitiveContext, + OmegaConfGrammarParser.DictKeyContext, + ], + ) -> Any: + # QUOTED_VALUE | + # (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+ + if ctx.getChildCount() == 1: + child = ctx.getChild(0) + if isinstance(child, OmegaConfGrammarParser.InterpolationContext): + return self.visitInterpolation(child) + assert isinstance(child, TerminalNode) + symbol = child.symbol + # Parse primitive types. + if symbol.type == OmegaConfGrammarLexer.QUOTED_VALUE: + return self._resolve_quoted_string(symbol.text) + elif symbol.type in ( + OmegaConfGrammarLexer.ID, + OmegaConfGrammarLexer.UNQUOTED_CHAR, + OmegaConfGrammarLexer.COLON, + ): + return symbol.text + elif symbol.type == OmegaConfGrammarLexer.NULL: + return None + elif symbol.type == OmegaConfGrammarLexer.INT: + return int(symbol.text) + elif symbol.type == OmegaConfGrammarLexer.FLOAT: + return float(symbol.text) + elif symbol.type == OmegaConfGrammarLexer.BOOL: + return symbol.text.lower() == "true" + elif symbol.type == OmegaConfGrammarLexer.ESC: + return self._unescape([child]) + elif symbol.type == OmegaConfGrammarLexer.WS: # pragma: no cover + # A single WS should have been "consumed" by another token. + raise AssertionError("WS should never be reached") + assert False, symbol.type + # Concatenation of multiple items ==> un-escape the concatenation. + return self._unescape(ctx.getChildren()) + + def _resolve_quoted_string(self, quoted: str) -> str: + """ + Parse a quoted string. + """ + # Identify quote type. + assert len(quoted) >= 2 and quoted[0] == quoted[-1] + quote_type = quoted[0] + assert quote_type in ["'", '"'] + + # Un-escape quotes and backslashes within the string (the two kinds of + # escapable characters in quoted strings). We do it in two passes: + # 1. Replace `\"` with `"` (and similarly for single quotes) + # 2. Replace `\\` with `\` + # The order is important so that `\\"` is replaced with an escaped quote `\"`. + # We also remove the start and end quotes. + esc_quote = f"\\{quote_type}" + quoted_content = ( + quoted[1:-1].replace(esc_quote, quote_type).replace("\\\\", "\\") + ) + + # Parse the string. + return self.quoted_string_callback(quoted_content) + + def _unescape( + self, + seq: Iterable[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]], + ) -> str: + """ + Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed. + + Interpolations are resolved and cast to string *WITHOUT* escaping their result + (it is assumed that whatever escaping is required was already handled during the + resolving of the interpolation). + """ + chrs = [] + for node in seq: + if isinstance(node, TerminalNode): + s = node.symbol + if s.type == OmegaConfGrammarLexer.ESC: + chrs.append(s.text[1::2]) + elif s.type == OmegaConfGrammarLexer.ESC_INTER: + chrs.append(s.text[1:]) + else: + chrs.append(s.text) + else: + assert isinstance(node, OmegaConfGrammarParser.InterpolationContext) + chrs.append(str(self.visitInterpolation(node))) + return "".join(chrs) diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 1f60b7f2f..32157750b 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -570,7 +570,7 @@ def _set_value_impl( if flags is None: flags = {} - vk = get_value_kind(value) + vk = get_value_kind(value, strict_interpolation_validation=True) if OmegaConf.is_none(value): if not self._is_optional(): raise ValidationError( @@ -579,10 +579,7 @@ def _set_value_impl( self.__dict__["_content"] = None elif vk is ValueKind.MANDATORY_MISSING: self.__dict__["_content"] = "???" - elif vk in ( - ValueKind.INTERPOLATION, - ValueKind.STR_INTERPOLATION, - ): + elif vk == ValueKind.INTERPOLATION: self.__dict__["_content"] = value else: if not (is_primitive_list(value) or isinstance(value, ListConfig)): diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index c8e38425b..b1e318b9f 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -4,7 +4,13 @@ from enum import Enum from typing import Any, Dict, Optional, Type, Union -from omegaconf._utils import _is_interpolation, get_type_of, is_primitive_container +from omegaconf._utils import ( + ValueKind, + _is_interpolation, + get_type_of, + get_value_kind, + is_primitive_container, +) from omegaconf.base import Container, Metadata, Node from omegaconf.errors import ( ConfigKeyError, @@ -28,14 +34,13 @@ def _value(self) -> Any: return self._val def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None: - from ._utils import ValueKind, get_value_kind - if self._get_flag("readonly"): raise ReadonlyConfigError("Cannot set value of read-only config node") - if isinstance(value, str) and get_value_kind(value) in ( + if isinstance(value, str) and get_value_kind( + value, strict_interpolation_validation=True + ) in ( ValueKind.INTERPOLATION, - ValueKind.STR_INTERPOLATION, ValueKind.MANDATORY_MISSING, ): self._val = value diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index e93166b01..ee5b7c244 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -3,7 +3,6 @@ import io import os import pathlib -import re import sys import warnings from collections import defaultdict @@ -17,7 +16,6 @@ Dict, Generator, List, - Match, Optional, Tuple, Type, @@ -26,13 +24,12 @@ ) import yaml -from typing_extensions import Protocol from . import DictConfig, DictKeyType, ListConfig from ._utils import ( _ensure_container, _get_value, - decode_primitive, + _make_hashable, format_and_raise, get_dict_key_value_types, get_list_element_type, @@ -54,12 +51,14 @@ from .basecontainer import BaseContainer from .errors import ( ConfigKeyError, + GrammarParseError, InterpolationResolutionError, MissingMandatoryValue, OmegaConfBaseException, UnsupportedInterpolationType, ValidationError, ) +from .grammar_parser import parse from .nodes import ( AnyNode, BooleanNode, @@ -72,10 +71,14 @@ MISSING: Any = "???" -# A marker used in OmegaConf.create() to differentiate between creating an empty {} DictConfig -# and creating a DictConfig with None content. +# A marker used: +# - in OmegaConf.create() to differentiate between creating an empty {} DictConfig +# and creating a DictConfig with None content +# - in env() to detect between no default value vs a default value set to None _EMPTY_MARKER_ = object() +Resolver = Callable[..., Any] + def II(interpolation: str) -> Any: """ @@ -95,40 +98,43 @@ def SI(interpolation: str) -> Any: return interpolation -class Resolver0(Protocol): - def __call__(self) -> Any: - ... - - -class Resolver1(Protocol): - def __call__(self, __x1: str) -> Any: - ... - - -class Resolver2(Protocol): - def __call__(self, __x1: str, __x2: str) -> Any: - ... - - -class Resolver3(Protocol): - def __call__(self, __x1: str, __x2: str, __x3: str) -> Any: - ... - - -Resolver = Union[Resolver0, Resolver1, Resolver2, Resolver3] - - def register_default_resolvers() -> None: - def env(key: str, default: Optional[str] = None) -> Any: + def env(key: str, default: Any = _EMPTY_MARKER_) -> Any: try: - return decode_primitive(os.environ[key]) + val_str = os.environ[key] except KeyError: - if default is not None: - return decode_primitive(default) + if default is not _EMPTY_MARKER_: + return default else: raise ValidationError(f"Environment variable '{key}' not found") - OmegaConf.register_resolver("env", env) + # We obtained a string from the environment variable: we try to parse it + # using the grammar (as if it was a resolver argument), so that expressions + # like numbers, booleans, lists and dictionaries can be properly evaluated. + try: + parse_tree = parse( + val_str, parser_rule="singleElement", lexer_mode="VALUE_MODE" + ) + except GrammarParseError: + # Un-parsable as a resolver argument: keep the string unchanged. + return val_str + + # Resolve the parse tree. We use an empty config for this, which means that + # interpolations referring to other nodes will fail. + empty_config = DictConfig({}) + try: + val = empty_config.resolve_parse_tree(parse_tree) + except InterpolationResolutionError as exc: + raise InterpolationResolutionError( + f"When attempting to resolve env variable '{key}', a node interpolation " + f"caused the following exception: {exc}. Node interpolations are not " + f"supported in environment variables: either remove them, or escape " + f"them to keep them as a strings." + ) + return _get_value(val) + + # Note that the `env` resolver does *NOT* use the cache. + OmegaConf.register_new_resolver("env", env, use_cache=True) class OmegaConf: @@ -398,41 +404,110 @@ def unsafe_merge( return target - @staticmethod - def _tokenize_args(string: Optional[str]) -> List[str]: - if string is None or string == "": - return [] - - def _unescape_word_boundary(match: Match[str]) -> str: - if match.start() == 0 or match.end() == len(match.string): - return "" - return match.group(0) - - escaped = re.split(r"(? None: + # TODO re-enable warning message before 2.1 release (see also test_resolver_deprecated_behavior) + # warnings.warn( + # dedent( + # """\ + # register_resolver() is deprecated. + # See https://github.com/omry/omegaconf/issues/426 for migration instructions. + # """ + # ), + # stacklevel=2, + # ) + return OmegaConf.legacy_register_resolver(name, resolver) + + # This function will eventually be deprecated and removed. + @staticmethod + def legacy_register_resolver(name: str, resolver: Resolver) -> None: assert callable(resolver), "resolver must be callable" # noinspection PyProtectedMember assert ( name not in BaseContainer._resolvers - ), "resolved {} is already registered".format(name) + ), f"resolver {name} is already registered" - def caching(config: BaseContainer, key: str) -> Any: + def resolver_wrapper( + config: BaseContainer, + args: Tuple[Any, ...], + args_str: Tuple[str, ...], + ) -> Any: cache = OmegaConf.get_cache(config)[name] - val = ( - cache[key] if key in cache else resolver(*OmegaConf._tokenize_args(key)) - ) + # "Un-escape " spaces and commas. + args_unesc = [x.replace(r"\ ", " ").replace(r"\,", ",") for x in args_str] + + # Nested interpolations behave in a potentially surprising way with + # legacy resolvers (they remain as strings, e.g., "${foo}"). If any + # input looks like an interpolation we thus raise an exception. + try: + bad_arg = next(i for i in args_unesc if "${" in i) + except StopIteration: + pass + else: + raise ValueError( + f"Resolver '{name}' was called with argument '{bad_arg}' that appears " + f"to be an interpolation. Nested interpolations are not supported for " + f"resolvers registered with `[legacy_]register_resolver()`, please use " + f"`register_new_resolver()` instead (see " + f"https://github.com/omry/omegaconf/issues/426 for migration instructions)." + ) + key = args + val = cache[key] if key in cache else resolver(*args_unesc) cache[key] = val return val # noinspection PyProtectedMember - BaseContainer._resolvers[name] = caching + BaseContainer._resolvers[name] = resolver_wrapper + + @staticmethod + def register_new_resolver( + name: str, + resolver: Resolver, + use_cache: Optional[bool] = True, + ) -> None: + """ + Register a resolver. + + :param name: Name of the resolver. + :param resolver: Callable whose arguments are provided in the interpolation, + e.g., with ${foo:x,0,${y.z}} these arguments are respectively "x" (str), + 0 (int) and the value of `y.z`. + :param use_cache: Whether the resolver's outputs should be cached. The cache is + based only on the list of arguments given in the interpolation, i.e., for a + given list of arguments, the same value will always be returned. + """ + assert callable(resolver), "resolver must be callable" + # noinspection PyProtectedMember + assert ( + name not in BaseContainer._resolvers + ), "resolver {} is already registered".format(name) + + def resolver_wrapper( + config: BaseContainer, + args: Tuple[Any, ...], + args_str: Tuple[str, ...], + ) -> Any: + if use_cache: + cache = OmegaConf.get_cache(config)[name] + hashable_key = _make_hashable(args) + try: + return cache[hashable_key] + except KeyError: + pass + + # Call resolver. + ret = resolver(*args) + if use_cache: + cache[hashable_key] = ret + return ret + + # noinspection PyProtectedMember + BaseContainer._resolvers[name] = resolver_wrapper @staticmethod - def get_resolver(name: str) -> Optional[Callable[[Container, Any], Any]]: + def get_resolver( + name: str, + ) -> Optional[Callable[[Container, Tuple[Any, ...], Tuple[str, ...]], Any]]: # noinspection PyProtectedMember return ( BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None diff --git a/pyproject.toml b/pyproject.toml index d4a040c43..f38b84fc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,19 @@ +[tool.black] +exclude = ''' +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.mypy_cache + | \omegaconf/grammar/gen + | \.nox + ) +) +''' + +[tool.pytest.ini_options] +addopts = "--import-mode=append" + [tool.towncrier] package = "omegaconf" package_dir = "" diff --git a/requirements/base.txt b/requirements/base.txt index 8746b97b2..ec35d86fa 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,4 @@ +antlr4-python3-runtime==4.8 PyYAML>=5.1.* # Use dataclasses backport for Python 3.6. dataclasses;python_version=='3.6' -typing-extensions diff --git a/setup.cfg b/setup.cfg index 661e6c1b1..265a481c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,4 +3,10 @@ test=pytest [mypy] python_version = 3.6 -mypy_path=.stubs \ No newline at end of file +mypy_path=.stubs + +[mypy-antlr4.*] +ignore_missing_imports = True + +[mypy-omegaconf.grammar.gen.*] +ignore_errors = True diff --git a/setup.py b/setup.py index 95b4d5502..3145f83c6 100644 --- a/setup.py +++ b/setup.py @@ -8,14 +8,20 @@ # Upload: twine upload dist/* """ -import codecs -import os import pathlib -import re import pkg_resources import setuptools +from build_helpers.build_helpers import ( + ANTLRCommand, + BuildPyCommand, + CleanCommand, + DevelopCommand, + SDistCommand, + find_version, +) + with pathlib.Path("requirements/base.txt").open() as requirements_txt: install_requires = [ str(requirement) @@ -23,23 +29,16 @@ ] -def find_version(*file_paths): - here = os.path.abspath(os.path.dirname(__file__)) - - def read(*parts): - with codecs.open(os.path.join(here, *parts), "r") as fp: - return fp.read() - - version_file = read(*file_paths) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") - - with open("README.md", "r") as fh: LONG_DESC = fh.read() setuptools.setup( + cmdclass={ + "antlr": ANTLRCommand, + "clean": CleanCommand, + "sdist": SDistCommand, + "build_py": BuildPyCommand, + "develop": DevelopCommand, + }, name="omegaconf", version=find_version("omegaconf", "version.py"), author="Omry Yadan", @@ -51,7 +50,7 @@ def read(*parts): tests_require=["pytest"], url="https://github.com/omry/omegaconf", keywords="yaml configuration config", - packages=["omegaconf"], + packages=["omegaconf", "omegaconf.grammar", "omegaconf.grammar.gen"], python_requires=">=3.6", classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/tests/test_base_config.py b/tests/test_base_config.py index d7c2c18dc..e10eed961 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -1,7 +1,7 @@ import copy import re from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, Union import pytest from pytest import raises @@ -520,26 +520,6 @@ def test_read_write_override(src: Any, func: Any, expectation: Any) -> None: func(c) -@pytest.mark.parametrize( - "string, tokenized", - [ - ("dog,cat", ["dog", "cat"]), - ("dog\\,cat\\ ", ["dog,cat "]), - ("dog,\\ cat", ["dog", " cat"]), - ("\\ ,cat", [" ", "cat"]), - ("dog, cat", ["dog", "cat"]), - ("dog, ca t", ["dog", "ca t"]), - ("dog, cat", ["dog", "cat"]), - ("whitespace\\ , before comma", ["whitespace ", "before comma"]), - (None, []), - ("", []), - ("no , escape", ["no", "escape"]), - ], -) -def test_tokenize_with_escapes(string: str, tokenized: List[str]) -> None: - assert OmegaConf._tokenize_args(string) == tokenized - - @pytest.mark.parametrize( "src, func, expectation", [({}, lambda c: c.__setattr__("foo", 1), raises(AttributeError))], @@ -690,7 +670,8 @@ def test_omegaconf_init_not_implemented() -> None: def test_resolve_str_interpolation(query: str, result: Any) -> None: cfg = OmegaConf.create({"foo": 10, "bar": "${foo}"}) assert ( - cfg._resolve_interpolation( + cfg._maybe_resolve_interpolation( + parent=None, key=None, value=StringNode(value=query), throw_on_missing=False, diff --git a/tests/test_errors.py b/tests/test_errors.py index c59ad1188..2d0c4d9bb 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -22,6 +22,7 @@ ConfigKeyError, ConfigTypeError, ConfigValueError, + GrammarParseError, InterpolationResolutionError, KeyValidationError, MissingMandatoryValue, @@ -198,7 +199,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.create({"foo": "${missing}"}), op=lambda cfg: getattr(cfg, "foo"), exception_type=InterpolationResolutionError, - msg="str interpolation key 'missing' not found", + msg="Interpolation key 'missing' not found", key="foo", child_node=lambda cfg: cfg._get_node("foo"), ), @@ -209,7 +210,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.create({"foo": "foo_${missing}"}), op=lambda cfg: getattr(cfg, "foo"), exception_type=InterpolationResolutionError, - msg="str interpolation key 'missing' not found", + msg="Interpolation key 'missing' not found", key="foo", child_node=lambda cfg: cfg._get_node("foo"), ), @@ -220,7 +221,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.create({"foo": {"bar": "${.missing}"}}), op=lambda cfg: getattr(cfg.foo, "bar"), exception_type=InterpolationResolutionError, - msg="str interpolation key 'missing' not found", + msg="Interpolation key 'missing' not found", key="bar", full_key="foo.bar", child_node=lambda cfg: cfg.foo._get_node("bar"), @@ -1149,14 +1150,32 @@ def test_errors(expected: Expected, monkeypatch: Any) -> None: assert e.__cause__ is None -def test_assertion_error(restore_resolvers: Any) -> None: +@pytest.mark.parametrize( + "register_func", [OmegaConf.register_resolver, OmegaConf.register_new_resolver] +) +def test_assertion_error(restore_resolvers: Any, register_func: Any) -> None: def assert_false() -> None: assert False # The purpose of this test is to cover the case where an `AssertionError` # is processed in `format_and_raise()`. Using a resolver to trigger the assertion # error is just one way of achieving this goal. - OmegaConf.register_resolver("assert_false", assert_false) + register_func("assert_false", assert_false) c = OmegaConf.create({"trigger": "${assert_false:}"}) with pytest.raises(AssertionError): c.trigger + + +@pytest.mark.parametrize( + ["create_func", "arg"], + [ + (OmegaConf.create, {"a": "${b"}), + (DictConfig, "${b"), + (ListConfig, "${b"), + ], +) +def test_parse_error_on_creation(create_func: Any, arg: Any) -> None: + with pytest.raises( + GrammarParseError, match=re.escape("no viable alternative at input '${b'") + ): + create_func(arg) diff --git a/tests/test_grammar.py b/tests/test_grammar.py new file mode 100644 index 000000000..8fca34863 --- /dev/null +++ b/tests/test_grammar.py @@ -0,0 +1,479 @@ +import math +from typing import Any, Callable, List, Optional, Tuple + +import antlr4 +from pytest import mark, param, raises + +from omegaconf import ( + DictConfig, + ListConfig, + OmegaConf, + _utils, + grammar_parser, + grammar_visitor, +) +from omegaconf.errors import ( + GrammarParseError, + InterpolationResolutionError, + UnsupportedInterpolationType, +) + +# A fixed config that may be used (but not modified!) by tests. +BASE_TEST_CFG = OmegaConf.create( + { + # Standard data types. + "str": "hi", + "int": 123, + "float": 1.2, + "dict": {"a": 0}, + "list": [x - 1 for x in range(11)], + "null": None, + # Special cases. + "x@y": 123, # to test keys with @ in name + "0": 0, # to test keys with int names + "1": {"2": 12}, # to test dot-path with int keys + "FalsE": {"TruE": True}, # to test keys with bool names + "None": {"null": 1}, # to test keys with null-like names + # Used in nested interpolations. + "str_test": "test", + "ref_str": "str", + "options": {"a": "A", "b": "B"}, + "choice": "a", + "rel_opt": ".options", + } +) + + +# Parameters for tests of the "singleElement" rule when there is no interpolation. +# Each item is a tuple with three elements: +# - The id of the test. +# - The expression to be evaluated. +# - The expected result, that may be an exception. If it is a `GrammarParseError` then +# it is assumed that the parsing will fail. If it is another kind of exception then +# it is assumed that the parsing will succeed, but this exception will be raised when +# visiting (= evaluating) the parse tree. If the expected behavior is for the parsing +# to succeed, but a `GrammarParseError` to be raised when visiting it, then set the +# expected result to the pair `(None, GrammarParseError)`. +PARAMS_SINGLE_ELEMENT_NO_INTERPOLATION: List[Tuple[str, str, Any]] = [ + # Special keywords. + ("null", "null", None), + ("true", "TrUe", True), + ("false", "falsE", False), + ("true_false", "true_false", "true_false"), + # Integers. + ("int", "123", 123), + ("int_pos", "+123", 123), + ("int_neg", "-123", -123), + ("int_underscore", "1_000", 1000), + ("int_bad_underscore_1", "1_000_", "1_000_"), + ("int_bad_underscore_2", "1__000", "1__000"), + ("int_bad_underscore_3", "_1000", "_1000"), + ("int_bad_zero_start", "007", "007"), + # Floats. + ("float", "1.1", 1.1), + ("float_no_int", ".1", 0.1), + ("float_no_decimal", "1.", 1.0), + ("float_minus", "-.2", -0.2), + ("float_underscore", "1.1_1", 1.11), + ("float_bad_1", "1.+2", "1.+2"), + ("float_bad_2", r"1\.2", r"1\.2"), + ("float_bad_3", "1.2_", "1.2_"), + ("float_exp_1", "-1e2", -100.0), + ("float_exp_2", "+1E-2", 0.01), + ("float_exp_3", "1_0e1_0", 10e10), + ("float_exp_4", "1.07e+2", 107.0), + ("float_exp_5", "1e+03", 1000.0), + ("float_exp_bad_1", "e-2", "e-2"), + ("float_exp_bad_2", "01e2", "01e2"), + ("float_inf", "inf", math.inf), + ("float_plus_inf", "+inf", math.inf), + ("float_minus_inf", "-inf", -math.inf), + ("float_nan", "nan", math.nan), + ("float_plus_nan", "+nan", math.nan), + ("float_minus_nan", "-nan", math.nan), + # Unquoted strings. + ("str_legal", "a/-\\+.$*@\\\\", "a/-\\+.$*@\\"), + ("str_illegal_1", "a,=b", GrammarParseError), + ("str_illegal_2", f"{chr(200)}", GrammarParseError), + ("str_illegal_3", f"{chr(129299)}", GrammarParseError), + ("str_dot", ".", "."), + ("str_dollar", "$", "$"), + ("str_colon", ":", ":"), + ("str_ws_1", "hello world", "hello world"), + ("str_ws_2", "a b\tc \t\t d", "a b\tc \t\t d"), + ("str_esc_ws_1", r"\ hello\ world\ ", " hello world "), + ("str_esc_ws_2", "\\ \\\t\\\t", " \t\t"), + ("str_esc_comma", r"hello\, world", "hello, world"), + ("str_esc_colon", r"a\:b", "a:b"), + ("str_esc_equal", r"a\=b", "a=b"), + ("str_esc_parentheses", r"\(foo\)", "(foo)"), + ("str_esc_brackets", r"\[foo\]", "[foo]"), + ("str_esc_braces", r"\{foo\}", "{foo}"), + ("str_esc_backslash", r"\\", "\\"), + ("str_backslash_noesc", r"ab\cd", r"ab\cd"), + ("str_esc_illegal_1", r"\#", GrammarParseError), + ("str_esc_illegal_2", "\\'\\\"", GrammarParseError), + # Quoted strings. + ("str_quoted_single", "'!@#$%^&*()[]:.,\"'", '!@#$%^&*()[]:.,"'), + ("str_quoted_double", '"!@#$%^&*()[]:.,\'"', "!@#$%^&*()[]:.,'"), + ("str_quoted_outer_ws_single", "' a \t'", " a \t"), + ("str_quoted_outer_ws_double", '" a \t"', " a \t"), + ("str_quoted_int", "'123'", "123"), + ("str_quoted_null", "'null'", "null"), + ("str_quoted_bool", "['truE', \"FalSe\"]", ["truE", "FalSe"]), + ("str_quoted_list", "'[a,b, c]'", "[a,b, c]"), + ("str_quoted_dict", '"{a:b, c: d}"', "{a:b, c: d}"), + ("str_quoted_backslash_noesc_single", r"'a\b'", r"a\b"), + ("str_quoted_backslash_noesc_double", r'"a\b"', r"a\b"), + ("str_quoted_concat_bad_2", "'Hi''there'", GrammarParseError), + ("str_quoted_too_many_1", "''a'", GrammarParseError), + ("str_quoted_too_many_2", "'a''", GrammarParseError), + ("str_quoted_too_many_3", "''a''", GrammarParseError), + # Lists and dictionaries. + ("list", "[0, 1]", [0, 1]), + ( + "dict", + "{x: 1, a: b, y: 1e2, null2: 0.1, true3: false, inf4: true}", + {"x": 1, "a": "b", "y": 100.0, "null2": 0.1, "true3": False, "inf4": True}, + ), + ( + "dict_unquoted_key", + "{a0-null-1-3.14-NaN- \t-true-False-/\\+.$%*@\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,:0}", + {"a0-null-1-3.14-NaN- \t-true-False-/\\+.$%*@()[]{}:= \t,": 0}, + ), + ( + "dict_quoted", + "{0: 1, 'a': 'b', 1.1: 1e2, null: 0.1, true: false, -inf: true}", + {0: 1, "a": "b", 1.1: 100.0, None: 0.1, True: False, -math.inf: True}, + ), + ( + "structured_mixed", + "[10,str,3.14,true,false,inf,[1,2,3], 'quoted', \"quoted\", 'a,b,c']", + [ + 10, + "str", + 3.14, + True, + False, + math.inf, + [1, 2, 3], + "quoted", + "quoted", + "a,b,c", + ], + ), + ("dict_int_key", "{0: 0}", {0: 0}), + ("dict_float_key", "{1.1: 0}", {1.1: 0}), + ("dict_null_key", "{null: 0}", {None: 0}), + ("dict_nan_like_key", "{'nan': 0}", {"nan": 0}), + ("dict_list_as_key", "{[0]: 1}", GrammarParseError), + ( + "dict_bool_key", + "{true: true, false: 'false'}", + {True: True, False: "false"}, + ), + ("empty_dict", "{}", {}), + ("empty_list", "[]", []), + ( + "structured_deep", + "{null0: [0, 3.14, false], true1: {a: [0, 1, 2], b: {}}}", + {"null0": [0, 3.14, False], "true1": {"a": [0, 1, 2], "b": {}}}, + ), +] + +# Parameters for tests of the "singleElement" rule when there are interpolations. +PARAMS_SINGLE_ELEMENT_WITH_INTERPOLATION = [ + # Node interpolations. + ("dict_access", "${dict.a}", 0), + ("list_access", "${list.0}", -1), + ("list_access_underscore", "${list.1_0}", 9), + ("list_access_bad_negative", "${list.-1}", InterpolationResolutionError), + ("dict_access_list_like_1", "${0}", 0), + ("dict_access_list_like_2", "${1.2}", 12), + ("bool_like_keys", "${FalsE.TruE}", True), + ("null_like_key_ok", "${None.null}", 1), + ("null_like_key_bad_case", "${NoNe.null}", InterpolationResolutionError), + ("null_like_key_quoted_1", "${'None'.'null'}", GrammarParseError), + ("null_like_key_quoted_2", "${'None.null'}", GrammarParseError), + ("dotpath_bad_type", "${dict.${float}}", (None, GrammarParseError)), + ("at_in_key", "${x@y}", 123), + # Interpolations in dictionaries. + ("dict_interpolation_value", "{hi: ${str}, int: ${int}}", {"hi": "hi", "int": 123}), + ("dict_interpolation_key", "{${str}: 0, ${null}: 1", GrammarParseError), + # Interpolations in lists. + ("list_interpolation", "[${str}, ${int}]", ["hi", 123]), + # Interpolations in unquoted strings. + ("str_dollar_and_inter", "$$${str}", "$$hi"), + ("str_inter", "hi_${str}", "hi_hi"), + ("str_esc_illegal_3", r"\${foo\}", GrammarParseError), + # Interpolations in quoted strings. + ("str_quoted_inter", "'${null}'", "None"), + ("str_quoted_esc_single_1", r"'ab\'cd\'\'${str}'", "ab'cd''hi"), + ("str_quoted_esc_single_2", "'\"\\\\\\\\\\${foo}\\ '", r'"\${foo}\ '), + ("str_quoted_esc_double_1", r'"ab\"cd\"\"${str}"', 'ab"cd""hi'), + ("str_quoted_esc_double_2", '"\'\\\\\\\\\\${foo}\\ "', r"'\${foo}\ "), + ("str_quoted_concat_bad_1", '"Hi "${str}', GrammarParseError), + # Whitespaces. + ("ws_inter_node_outer", "${ \tdict.a \t}", 0), + ("ws_inter_node_around_dot", "${dict .\ta}", 0), + ("ws_inter_node_inside_id", "${d i c t.a}", GrammarParseError), + ("ws_inter_res_outer", "${\t test:foo\t }", "foo"), + ("ws_inter_res_around_colon", "${test\t : \tfoo}", "foo"), + ("ws_inter_res_inside_id", "${te st:foo}", GrammarParseError), + ("ws_inter_res_inside_args", "${test:f o o}", "f o o"), + ("ws_list", "${test:[\t a, b, ''\t ]}", ["a", "b", ""]), + ("ws_dict", "${test:{\t a : 1\t , b: \t''}}", {"a": 1, "b": ""}), + ("ws_quoted_single", "${test: \t'foo'\t }", "foo"), + ("ws_quoted_double", '${test: \t"foo"\t }', "foo"), + # Nested interpolations. + ("nested_simple", "${${ref_str}}", "hi"), + ("nested_select", "${options.${choice}}", "A"), + ("nested_relative", "${${rel_opt}.b}", "B"), + ("str_quoted_nested", r"'AB${test:\'CD${test:\\'EF\\'}GH\'}'", "ABCDEFGH"), + # Resolver interpolations. + ("no_args", "${test:}", []), + ("space_in_args", "${test:a, b c}", ["a", "b c"]), + ("list_as_input", "${test:[a, b], 0, [1.1]}", [["a", "b"], 0, [1.1]]), + ("dict_as_input", "${test:{a: 1.1, b: b}}", {"a": 1.1, "b": "b"}), + ("dict_as_input_quotes", "${test:{'a': 1.1, b: b}}", {"a": 1.1, "b": "b"}), + ("dict_typo_colons", "${test:{a: 1.1, b:: b}}", {"a": 1.1, "b": ": b"}), + ("missing_resolver", "${MiSsInG_ReSoLvEr:0}", UnsupportedInterpolationType), + ("at_in_resolver", "${y@z:}", GrammarParseError), + # Nested resolvers. + ("nested_resolver", "${${str_test}:a, b, c}", ["a", "b", "c"]), + ("nested_deep", "${test:${${test:${ref_str}}}}", "hi"), + ( + "nested_resolver_combined_illegal", + "${some_${resolver}:a, b, c}", + GrammarParseError, + ), + ("nested_args", "${test:${str}, ${null}, ${int}}", ["hi", None, 123]), + # Invalid resolver names. + ("int_resolver_quoted", "${'0':1,2,3}", GrammarParseError), + ("int_resolver_noquote", "${0:1,2,3}", GrammarParseError), + ("float_resolver_quoted", "${'1.1':1,2,3}", GrammarParseError), + ("float_resolver_noquote", "${1.1:1,2,3}", GrammarParseError), + ("float_resolver_exp", "${1e1:1,2,3}", GrammarParseError), + ("inter_float_resolver", "${${float}:1,2,3}", (None, GrammarParseError)), + # NaN as dictionary key (a resolver is used here to output only the key). + ("dict_nan_key_1", "${first:{nan: 0}}", math.nan), + ("dict_nan_key_2", "${first:{${test:nan}: 0}}", GrammarParseError), +] + +# Parameters for tests of the "configValue" rule (may contain node +# interpolations, but no resolvers). +PARAMS_CONFIG_VALUE = [ + # String interpolations (top-level). + ("str_top_basic", "bonjour ${str}", "bonjour hi"), + ("str_top_quotes_single_1", "'bonjour ${str}'", "'bonjour hi'"), + ( + "str_top_quotes_single_2", + "'Bonjour ${str}', I said.", + "'Bonjour hi', I said.", + ), + ("str_top_quotes_double_1", '"bonjour ${str}"', '"bonjour hi"'), + ( + "str_top_quotes_double_2", + '"Bonjour ${str}", I said.', + '"Bonjour hi", I said.', + ), + ("str_top_missing_end_quote_single", "'${str}", "'hi"), + ("str_top_missing_end_quote_double", '"${str}', '"hi'), + ("str_top_missing_start_quote_double", '${str}"', 'hi"'), + ("str_top_missing_start_quote_single", "${str}'", "hi'"), + ("str_top_middle_quote_single", "I'd like ${str}", "I'd like hi"), + ("str_top_middle_quote_double", 'I"d like ${str}', 'I"d like hi'), + ("str_top_middle_quotes_single", "I like '${str}'", "I like 'hi'"), + ("str_top_middle_quotes_double", 'I like "${str}"', 'I like "hi"'), + ("str_top_any_char", "${str} !@\\#$%^&*})][({,/?;", "hi !@\\#$%^&*})][({,/?;"), + ("str_top_esc_inter", r"Esc: \${str}", "Esc: ${str}"), + ("str_top_esc_inter_wrong_1", r"Wrong: $\{str\}", r"Wrong: $\{str\}"), + ("str_top_esc_inter_wrong_2", r"Wrong: \${str\}", r"Wrong: ${str\}"), + ("str_top_esc_backslash", r"Esc: \\${str}", r"Esc: \hi"), + ("str_top_quoted_braces_wrong", r"Wrong: \{${str}\}", r"Wrong: \{hi\}"), + ("str_top_leading_dollars", r"$$${str}", "$$hi"), + ("str_top_trailing_dollars", r"${str}$$$$", "hi$$$$"), + ("str_top_leading_escapes", r"\\\\\${str}", r"\\${str}"), + ("str_top_middle_escapes", r"abc\\\\\${str}", r"abc\\${str}"), + ("str_top_trailing_escapes", "${str}" + "\\" * 5, "hi" + "\\" * 3), + ("str_top_concat_interpolations", "${null}${float}", "None1.2"), + # Whitespaces. + ("ws_toplevel", " \tab ${str} cd ${int}\t", " \tab hi cd 123\t"), + # Unmatched braces. + ("missing_brace_1", "${test:${str}", GrammarParseError), + ("missing_brace_2", "${${test:str}", GrammarParseError), + ("extra_brace", "${str}}", "hi}"), +] + + +def parametrize_from( + data: List[Tuple[str, str, Any]] +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Utility function to create PyTest parameters from the lists above""" + return mark.parametrize( + ["definition", "expected"], + [param(definition, expected, id=key) for key, definition, expected in data], + ) + + +class TestOmegaConfGrammar: + """ + Test most grammar constructs. + + Each method in this class tests the validity of expressions in a specific + setting. For instance, `test_single_element_no_interpolation()` tests the + "singleElement" parsing rule on expressions that do not contain interpolations + (which allows for faster tests without using any config object). + + Tests that actually need a config object all re-use the same `BASE_TEST_CFG` + config, to avoid creating a new config for each test. + """ + + @parametrize_from(PARAMS_SINGLE_ELEMENT_NO_INTERPOLATION) + def test_single_element_no_interpolation( + self, definition: str, expected: Any + ) -> None: + parse_tree, expected_visit = self._parse("singleElement", definition, expected) + if parse_tree is None: + return + + # Since there are no interpolations here, we do not need to provide + # callbacks to resolve them, and the quoted string callback can simply + # be the identity. + visitor = grammar_visitor.GrammarVisitor( + node_interpolation_callback=None, # type: ignore + resolver_interpolation_callback=None, # type: ignore + quoted_string_callback=lambda s: s, + ) + self._visit(lambda: visitor.visit(parse_tree), expected_visit) + + @parametrize_from(PARAMS_SINGLE_ELEMENT_WITH_INTERPOLATION) + def test_single_element_with_resolver( + self, restore_resolvers: Any, definition: str, expected: Any + ) -> None: + parse_tree, expected_visit = self._parse("singleElement", definition, expected) + + OmegaConf.register_new_resolver("test", self._resolver_test) + OmegaConf.register_new_resolver("first", self._resolver_first) + + self._visit_with_config(parse_tree, expected_visit) + + @parametrize_from(PARAMS_CONFIG_VALUE) + def test_config_value( + self, restore_resolvers: Any, definition: str, expected: Any + ) -> None: + parse_tree, expected_visit = self._parse("configValue", definition, expected) + self._visit_with_config(parse_tree, expected_visit) + + def _check_is_same_type(self, value: Any, expected: Any) -> None: + """ + Helper function to validate that types of `value` and `expected are the same. + + This function assumes that `value == expected` holds, and performs a "deep" + comparison of types (= it goes into data structures like dictionaries, lists + and tuples). + + Note that dictionaries being compared must have keys ordered the same way! + """ + assert type(value) is type(expected) + if isinstance(value, (str, int, float)): + pass + elif isinstance(value, (list, tuple, ListConfig)): + for vx, ex in zip(value, expected): + self._check_is_same_type(vx, ex) + elif isinstance(value, (dict, DictConfig)): + for (vk, vv), (ek, ev) in zip(value.items(), expected.items()): + assert vk == ek, "dictionaries are not ordered the same" + self._check_is_same_type(vk, ek) + self._check_is_same_type(vv, ev) + elif value is None: + assert expected is None + else: + raise NotImplementedError(type(value)) + + def _get_expected(self, expected: Any) -> Tuple[Any, Any]: + """Obtain the expected result of the parse & visit steps""" + if isinstance(expected, tuple): + # Outcomes of both the parse and visit steps are provided. + assert len(expected) == 2 + return expected[0], expected[1] + elif expected is GrammarParseError: + # If only a `GrammarParseError` is expected, assume it happens in parse step. + return expected, None + else: + # If anything else is provided, assume it is the outcome of the visit step. + return None, expected + + def _get_lexer_mode(self, rule: str) -> str: + return {"configValue": "DEFAULT_MODE", "singleElement": "VALUE_MODE"}[rule] + + def _parse( + self, rule: str, definition: str, expected: Any + ) -> Tuple[Optional[antlr4.ParserRuleContext], Any]: + """ + Parse the expression given by `definition`. + + Return both the parse tree and the expected result when visiting this tree. + """ + + def get_tree() -> antlr4.ParserRuleContext: + return grammar_parser.parse( + value=definition, + parser_rule=rule, + lexer_mode=self._get_lexer_mode(rule), + ) + + expected_parse, expected_visit = self._get_expected(expected) + if expected_parse is None: + return get_tree(), expected_visit + else: # expected failure on the parse step + with raises(expected_parse): + get_tree() + return None, None + + def _resolver_first(self, item: Any, *_: Any) -> Any: + """Resolver that returns the first element of its first input""" + try: + return next(iter(item)) + except StopIteration: + assert False # not supposed to happen in current tests + + def _resolver_test(self, *args: Any) -> Any: + """Resolver that returns the list of its inputs""" + return args[0] if len(args) == 1 else list(args) + + def _visit(self, visit: Callable[[], Any], expected: Any) -> None: + """Run the `visit()` function to visit the parse tree and validate the result""" + if isinstance(expected, type) and issubclass(expected, Exception): + with raises(expected): + visit() + else: + result = visit() + if expected is math.nan: + # Special case since nan != nan. + assert math.isnan(result) + else: + assert result == expected + # We also check types in particular because instances of `Node` are very + # good at mimicking their underlying type's behavior, and it is easy to + # fail to notice that the result contains nodes when it should not. + self._check_is_same_type(result, expected) + + def _visit_with_config( + self, parse_tree: antlr4.ParserRuleContext, expected: Any + ) -> None: + """Visit the tree using the default config `BASE_TEST_CFG`""" + if parse_tree is None: + return + cfg = BASE_TEST_CFG + + def visit() -> Any: + return _utils._get_value( + cfg.resolve_parse_tree( + parse_tree, + key=None, + parent=cfg, + ) + ) + + self._visit(visit, expected) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index c561c3786..40bb7f8c7 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,5 +1,4 @@ import copy -import os import random import re from typing import Any, Optional, Tuple @@ -16,16 +15,27 @@ OmegaConf, Resolver, ValidationError, + grammar_parser, ) from omegaconf._utils import _ensure_container from omegaconf.errors import ( ConfigAttributeError, + GrammarParseError, InterpolationResolutionError, OmegaConfBaseException, + UnsupportedInterpolationType, ) from . import StructuredWithMissing +# file deepcode ignore CopyPasteError: +# The above comment is a statement to stop DeepCode from raising a warning on +# lines that do equality checks of the form +# c.k == c.k + +# Characters that are not allowed by the grammar in config key names. +INVALID_CHARS_IN_KEY_NAMES = "\\${}()[].: '\"" + @pytest.mark.parametrize( "cfg,key,expected", @@ -251,6 +261,14 @@ def test_env_interpolation( assert OmegaConf.select(cfg, key) == expected +def test_env_is_cached(monkeypatch: Any) -> None: + monkeypatch.setenv("foobar", "1234") + c = OmegaConf.create({"foobar": "${env:foobar}"}) + before = c.foobar + monkeypatch.setenv("foobar", "3456") + assert c.foobar == before + + @pytest.mark.parametrize( "value,expected", [ @@ -274,36 +292,91 @@ def test_env_interpolation( # yaml strings are not getting parsed by the env resolver ("foo: bar", "foo: bar"), ("foo: \n - bar\n - baz", "foo: \n - bar\n - baz"), + # more advanced uses of the grammar + ("ab \\{foo} cd", "ab \\{foo} cd"), + ("ab \\\\{foo} cd", "ab \\\\{foo} cd"), + ("'\\${other_key}'", "${other_key}"), # escaped interpolation + ("'ab \\${other_key} cd'", "ab ${other_key} cd"), # escaped interpolation + ("[1, 2, 3]", [1, 2, 3]), + ("{a: 0, b: 1}", {"a": 0, "b": 1}), + (" 123 ", " 123 "), + (" 1 2 3 ", " 1 2 3 "), + ("\t[1, 2, 3]\t", "\t[1, 2, 3]\t"), + ("[\t1, 2, 3\t]", [1, 2, 3]), + (" {a: b}\t ", " {a: b}\t "), + ("{ a: b\t }", {"a": "b"}), + ("'123'", "123"), + ("${env:my_key_2}", 456), # can call another resolver ], ) -def test_env_values_are_typed(value: Any, expected: Any) -> None: - try: - os.environ["my_key"] = value - c = OmegaConf.create(dict(my_key="${env:my_key}")) - assert c.my_key == expected - finally: - del os.environ["my_key"] +def test_env_values_are_typed(monkeypatch: Any, value: Any, expected: Any) -> None: + monkeypatch.setenv("my_key", value) + monkeypatch.setenv("my_key_2", "456") + c = OmegaConf.create(dict(my_key="${env:my_key}")) + assert c.my_key == expected + + +def test_env_node_interpolation(monkeypatch: Any) -> None: + # Test that node interpolations are not supported in env variables. + monkeypatch.setenv("my_key", "${other_key}") + c = OmegaConf.create(dict(my_key="${env:my_key}", other_key=123)) + with pytest.raises(InterpolationResolutionError): + c.my_key + + +def test_env_default_none(monkeypatch: Any) -> None: + monkeypatch.delenv("my_key", raising=False) + c = OmegaConf.create({"my_key": "${env:my_key, null}"}) + assert c.my_key is None def test_register_resolver_twice_error(restore_resolvers: Any) -> None: + def foo(_: Any) -> int: + return 10 + + OmegaConf.register_new_resolver("foo", foo) + with pytest.raises(AssertionError): + OmegaConf.register_new_resolver("foo", lambda _: 10) + + +def test_register_resolver_twice_error_legacy(restore_resolvers: Any) -> None: def foo() -> int: return 10 - OmegaConf.register_resolver("foo", foo) + OmegaConf.legacy_register_resolver("foo", foo) with pytest.raises(AssertionError): - OmegaConf.register_resolver("foo", lambda: 10) + OmegaConf.register_new_resolver("foo", lambda: 10) def test_clear_resolvers(restore_resolvers: Any) -> None: assert OmegaConf.get_resolver("foo") is None - OmegaConf.register_resolver("foo", lambda x: int(x) + 10) + OmegaConf.register_new_resolver("foo", lambda x: x + 10) + assert OmegaConf.get_resolver("foo") is not None + OmegaConf.clear_resolvers() + assert OmegaConf.get_resolver("foo") is None + + +def test_clear_resolvers_legacy(restore_resolvers: Any) -> None: + assert OmegaConf.get_resolver("foo") is None + OmegaConf.legacy_register_resolver("foo", lambda x: int(x) + 10) assert OmegaConf.get_resolver("foo") is not None OmegaConf.clear_resolvers() assert OmegaConf.get_resolver("foo") is None def test_register_resolver_1(restore_resolvers: Any) -> None: - OmegaConf.register_resolver("plus_10", lambda x: int(x) + 10) + OmegaConf.register_new_resolver("plus_10", lambda x: x + 10) + c = OmegaConf.create( + {"k": "${plus_10:990}", "node": {"bar": 10, "foo": "${plus_10:${.bar}}"}} + ) + + assert type(c.k) == int + assert c.k == 1000 + assert c.node.foo == 20 # this also tests relative interpolations with resolvers + + +def test_register_resolver_1_legacy(restore_resolvers: Any) -> None: + OmegaConf.legacy_register_resolver("plus_10", lambda x: int(x) + 10) c = OmegaConf.create({"k": "${plus_10:990}"}) assert type(c.k) == int @@ -315,7 +388,13 @@ def test_resolver_cache_1(restore_resolvers: Any) -> None: # subsequent calls to the same function with the same argument will always return the same value. # this is important to allow embedding of functions like time() without having the value change during # the program execution. - OmegaConf.register_resolver("random", lambda _: random.randint(0, 10000000)) + OmegaConf.register_new_resolver("random", lambda _: random.randint(0, 10000000)) + c = OmegaConf.create({"k": "${random:__}"}) + assert c.k == c.k + + +def test_resolver_cache_1_legacy(restore_resolvers: Any) -> None: + OmegaConf.legacy_register_resolver("random", lambda _: random.randint(0, 10000000)) c = OmegaConf.create({"k": "${random:_}"}) assert c.k == c.k @@ -324,7 +403,17 @@ def test_resolver_cache_2(restore_resolvers: Any) -> None: """ Tests that resolver cache is not shared between different OmegaConf objects """ - OmegaConf.register_resolver("random", lambda _: random.randint(0, 10000000)) + OmegaConf.register_new_resolver("random", lambda _: random.randint(0, 10000000)) + c1 = OmegaConf.create({"k": "${random:__}"}) + c2 = OmegaConf.create({"k": "${random:__}"}) + + assert c1.k != c2.k + assert c1.k == c1.k + assert c2.k == c2.k + + +def test_resolver_cache_2_legacy(restore_resolvers: Any) -> None: + OmegaConf.legacy_register_resolver("random", lambda _: random.randint(0, 10000000)) c1 = OmegaConf.create({"k": "${random:_}"}) c2 = OmegaConf.create({"k": "${random:_}"}) @@ -333,11 +422,54 @@ def test_resolver_cache_2(restore_resolvers: Any) -> None: assert c2.k == c2.k +def test_resolver_cache_3_dict_list(restore_resolvers: Any) -> None: + """ + Tests that the resolver cache works as expected with lists and dicts. + """ + OmegaConf.register_new_resolver("random", lambda _: random.uniform(0, 1)) + c = OmegaConf.create( + dict( + lst1="${random:[0, 1]}", + lst2="${random:[0, 1]}", + lst3="${random:[]}", + dct1="${random:{a: 1, b: 2}}", + dct2="${random:{b: 2, a: 1}}", + mixed1="${random:{x: [1.1], y: {a: true, b: false, c: null, d: []}}}", + mixed2="${random:{x: [1.1], y: {b: false, c: null, a: true, d: []}}}", + ) + ) + assert c.lst1 == c.lst1 + assert c.lst1 == c.lst2 + assert c.lst1 != c.lst3 + assert c.dct1 == c.dct1 + assert c.dct1 == c.dct2 + assert c.mixed1 == c.mixed1 + assert c.mixed2 == c.mixed2 + assert c.mixed1 == c.mixed2 + + +def test_resolver_no_cache(restore_resolvers: Any) -> None: + OmegaConf.register_new_resolver( + "random", lambda _: random.uniform(0, 1), use_cache=False + ) + c = OmegaConf.create(dict(k="${random:__}")) + assert c.k != c.k + + def test_resolver_dot_start(restore_resolvers: Any) -> None: """ Regression test for #373 """ - OmegaConf.register_resolver("identity", lambda x: x) + OmegaConf.register_new_resolver("identity", lambda x: x) + c = OmegaConf.create( + {"foo_nodot": "${identity:bar}", "foo_dot": "${identity:.bar}"} + ) + assert c.foo_nodot == "bar" + assert c.foo_dot == ".bar" + + +def test_resolver_dot_start_legacy(restore_resolvers: Any) -> None: + OmegaConf.legacy_register_resolver("identity", lambda x: x) c = OmegaConf.create( {"foo_nodot": "${identity:bar}", "foo_dot": "${identity:.bar}"} ) @@ -358,23 +490,80 @@ def test_resolver_dot_start(restore_resolvers: Any) -> None: ( lambda *args: args, "escape_whitespace", + "${my_resolver:cat,\\ do g}", + ("cat", " do g"), + ), + (lambda: "zero", "zero_arg", "${my_resolver:}", "zero"), + ], +) +def test_resolver_that_allows_a_list_of_arguments( + restore_resolvers: Any, resolver: Resolver, name: str, key: str, result: Any +) -> None: + OmegaConf.register_new_resolver("my_resolver", resolver) + c = OmegaConf.create({name: key}) + assert c[name] == result + + +@pytest.mark.parametrize( + "resolver,name,key,result", + [ + (lambda *args: args, "arg_list", "${my_resolver:cat, dog}", ("cat", "dog")), + ( + lambda *args: args, + "escape_comma", "${my_resolver:cat\\, do g}", ("cat, do g",), ), + ( + lambda *args: args, + "escape_whitespace", + "${my_resolver:cat,\\ do g}", + ("cat", " do g"), + ), (lambda: "zero", "zero_arg", "${my_resolver:}", "zero"), ], ) -def test_resolver_that_allows_a_list_of_arguments( +def test_resolver_that_allows_a_list_of_arguments_legacy( restore_resolvers: Any, resolver: Resolver, name: str, key: str, result: Any ) -> None: - OmegaConf.register_resolver("my_resolver", resolver) + OmegaConf.legacy_register_resolver("my_resolver", resolver) c = OmegaConf.create({name: key}) assert c[name] == result +def test_resolver_deprecated_behavior(restore_resolvers: Any) -> None: + # Ensure that resolvers registered with the old "register_resolver()" function + # behave as expected. + + # The registration should trigger a deprecation warning. + # with pytest.warns(UserWarning): # TODO re-enable this check with the warning + OmegaConf.register_resolver("my_resolver", lambda *args: args) + + c = OmegaConf.create( + { + "int": "${my_resolver:1}", + "null": "${my_resolver:null}", + "bool": "${my_resolver:TruE,falSE}", + "str": "${my_resolver:a,b,c}", + "inter": "${my_resolver:${int}}", + } + ) + + # All resolver arguments should be provided as strings (with no modification). + assert c.int == ("1",) + assert c.null == ("null",) + assert c.bool == ("TruE", "falSE") + assert c.str == ("a", "b", "c") + + # Trying to nest interpolations should trigger an error (users should switch to + # `register_new_resolver()` in order to use nested interpolations). + with pytest.raises(ValueError): + c.inter + + def test_copy_cache(restore_resolvers: Any) -> None: - OmegaConf.register_resolver("random", lambda _: random.randint(0, 10000000)) - d = {"k": "${random:_}"} + OmegaConf.register_new_resolver("random", lambda _: random.randint(0, 10000000)) + d = {"k": "${random:__}"} c1 = OmegaConf.create(d) assert c1.k == c1.k @@ -391,21 +580,56 @@ def test_copy_cache(restore_resolvers: Any) -> None: def test_clear_cache(restore_resolvers: Any) -> None: - OmegaConf.register_resolver("random", lambda _: random.randint(0, 10000000)) - c = OmegaConf.create(dict(k="${random:_}")) + OmegaConf.register_new_resolver("random", lambda _: random.randint(0, 10000000)) + c = OmegaConf.create(dict(k="${random:__}")) old = c.k OmegaConf.clear_cache(c) assert old != c.k def test_supported_chars() -> None: - supported_chars = "%_-abc123." + supported_chars = "abc123_/:-\\+.$%*@" c = OmegaConf.create(dict(dir1="${copy:" + supported_chars + "}")) - OmegaConf.register_resolver("copy", lambda x: x) + OmegaConf.register_new_resolver("copy", lambda x: x) assert c.dir1 == supported_chars +def test_valid_chars_in_key_names() -> None: + valid_chars = "".join( + chr(i) for i in range(33, 128) if chr(i) not in INVALID_CHARS_IN_KEY_NAMES + ) + cfg_dict = {valid_chars: 123, "inter": f"${{{valid_chars}}}"} + cfg = OmegaConf.create(cfg_dict) + # Test that we can access the node made of all valid characters, both + # directly and through interpolations. + assert cfg[valid_chars] == 123 + assert cfg.inter == 123 + + +@pytest.mark.parametrize("c", list(INVALID_CHARS_IN_KEY_NAMES)) +def test_invalid_chars_in_key_names(c: str) -> None: + def create() -> DictConfig: + return OmegaConf.create({"invalid": f"${{ab{c}de}}"}) + + # Test that all invalid characters trigger errors in interpolations. + if c in [".", "}"]: + # With '.', we try to access `${ab.de}`. + # With '}', we try to access `${ab}`. + cfg = create() + with pytest.raises(InterpolationResolutionError): + cfg.invalid + elif c == ":": + # With ':', we try to run a resolver `${ab:de}` + cfg = create() + with pytest.raises(UnsupportedInterpolationType): + cfg.invalid + else: + # Other invalid characters should be detected at creation time. + with pytest.raises(GrammarParseError): + create() + + def test_interpolation_in_list_key_error() -> None: # Test that a KeyError is thrown if an str_interpolation key is not available c = OmegaConf.create(["${10}"]) @@ -477,3 +701,11 @@ def test_optional_after_interpolation() -> None: # Ensure that we can set an optional field to `None` even when it currently # points to a non-optional field. cfg.opt_num = None + + +def test_empty_stack() -> None: + """ + Check that an empty stack during ANTLR parsing raises a `GrammarParseError`. + """ + with pytest.raises(GrammarParseError): + grammar_parser.parse("ab}", lexer_mode="VALUE_MODE") diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 492aa3454..74ce9b835 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -170,11 +170,14 @@ def test_none_construction(self, node_type: Any, values: Any) -> None: with pytest.raises(ValidationError): node_type(value=None, is_optional=False) + @pytest.mark.parametrize( + "register_func", [OmegaConf.register_resolver, OmegaConf.register_new_resolver] + ) def test_interpolation( - self, node_type: Any, values: Any, restore_resolvers: Any + self, node_type: Any, values: Any, restore_resolvers: Any, register_func: Any ) -> None: resolver_output = 9999 - OmegaConf.register_resolver("func", lambda: resolver_output) + register_func("func", lambda: resolver_output) values = copy.deepcopy(values) for value in values: node = { diff --git a/tests/test_merge.py b/tests/test_merge.py index 85e2603bb..943fb21e3 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -668,11 +668,16 @@ def test_merge_with_error_not_changing_target(c1: Any, c2: Any) -> Any: assert c1 == backup -def test_into_custom_resolver_that_throws(restore_resolvers: Any) -> None: +@pytest.mark.parametrize( + "register_func", [OmegaConf.register_resolver, OmegaConf.register_new_resolver] +) +def test_into_custom_resolver_that_throws( + restore_resolvers: Any, register_func: Any +) -> None: def fail() -> None: raise ValueError() - OmegaConf.register_resolver("fail", fail) + register_func("fail", fail) configs = ( {"d": 20, "i": "${fail:}"}, diff --git a/tests/test_nodes.py b/tests/test_nodes.py index a27774ccd..c6ad1ab9f 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -361,7 +361,7 @@ def test_illegal_assignment(node: ValueNode, value: Any) -> None: def test_legal_assignment_enum( node_type: Type[EnumNode], enum_type: Type[Enum], - values: Tuple[Any], + values: Tuple[Any, ...], success_map: Dict[Any, Any], ) -> None: assert isinstance(values, (list, tuple)) diff --git a/tests/test_select.py b/tests/test_select.py index 68780476d..6fd2bd72d 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -16,6 +16,9 @@ def test_select_key_from_empty(struct: Optional[bool]) -> None: assert OmegaConf.select(c, "not_there") is None +@pytest.mark.parametrize( + "register_func", [OmegaConf.register_resolver, OmegaConf.register_new_resolver] +) @pytest.mark.parametrize( "cfg, key, expected", [ @@ -41,8 +44,10 @@ def test_select_key_from_empty(struct: Optional[bool]) -> None: pytest.param({"a": {"b": "one=${func:1}"}}, "a.b", "one=_1_", id="resolver"), ], ) -def test_select(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None: - OmegaConf.register_resolver("func", lambda x: f"_{x}_") +def test_select( + restore_resolvers: Any, cfg: Any, key: Any, expected: Any, register_func: Any +) -> None: + register_func("func", lambda x: f"_{x}_") cfg = _ensure_container(cfg) if isinstance(expected, RaisesContext): with expected: diff --git a/tests/test_utils.py b/tests/test_utils.py index 207e54230..0d3243ba2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,13 @@ from pytest import mark, param, raises from omegaconf import DictConfig, ListConfig, Node, OmegaConf, _utils -from omegaconf._utils import is_dict_annotation, is_list_annotation +from omegaconf._utils import ( + SIMPLE_INTERPOLATION_PATTERN, + _get_value, + _make_hashable, + is_dict_annotation, + is_list_annotation, +) from omegaconf.errors import UnsupportedValueType, ValidationError from omegaconf.nodes import ( AnyNode, @@ -251,7 +257,7 @@ class Dataclass: (Dataclass, _utils.ValueKind.VALUE), ("???", _utils.ValueKind.MANDATORY_MISSING), ("${foo.bar}", _utils.ValueKind.INTERPOLATION), - ("ftp://${host}/path", _utils.ValueKind.STR_INTERPOLATION), + ("ftp://${host}/path", _utils.ValueKind.INTERPOLATION), ("${func:foo}", _utils.ValueKind.INTERPOLATION), ("${func:a/b}", _utils.ValueKind.INTERPOLATION), ("${func:c:\\a\\b}", _utils.ValueKind.INTERPOLATION), @@ -550,3 +556,90 @@ def test_get_node_ref_type(obj: Any, key: str, expected: Any) -> None: def test_get_ref_type_error() -> None: with raises(ValueError): _utils.get_ref_type(AnyNode(), "foo") + + +@mark.parametrize( + "value", + [ + 1, + None, + {"a": 0}, + [1, 2, 3], + ], +) +def test_get_value_basic(value: Any) -> None: + val_node = _node_wrap( + value=value, type_=Any, parent=None, is_optional=True, key=None + ) + assert _get_value(val_node) == value + + +@mark.parametrize( + "content", + [{"a": 0, "b": 1}, "???", None, "${bar}"], +) +def test_get_value_container(content: Any) -> None: + cfg = DictConfig({}) + cfg._set_value(content) + assert _get_value(cfg) == content + + +@mark.parametrize( + "input_1,input_2", + [ + (0, 0), + ([0, 1], (0, 1)), + ([0, (1, 2)], (0, [1, 2])), + ({0: 1, 1: 2}, {1: 2, 0: 1}), + ({"": 1, 0: 2}, {0: 2, "": 1}), + ( + {1: 0, 1.1: 2.0, "1": "0", True: False, None: None}, + {None: None, 1.1: 2.0, True: False, "1": "0", 1: 0}, + ), + ], +) +def test_make_hashable(input_1: Any, input_2: Any) -> None: + out_1, out_2 = _make_hashable(input_1), _make_hashable(input_2) + assert out_1 == out_2 + hash_1, hash_2 = hash(out_1), hash(out_2) + assert hash_1 == hash_2 + + +def test_make_hashable_type_error() -> None: + with raises(TypeError): + _make_hashable({...: 0, None: 0}) + + +@mark.parametrize( + "expression", + [ + "${foo}", + "${foo.bar}", + "${a_b.c123}", + "${ foo \t}", + "x ${ab.cd.ef.gh} y", + "$ ${foo} ${bar} ${boz} $", + "${foo:bar}", + "${foo : bar, baz, boz}", + "${foo:bar,0,a-b+c*d/$.%@}", + "\\${foo}", + ], +) +def test_match_simple_interpolation_pattern(expression: str) -> None: + assert SIMPLE_INTERPOLATION_PATTERN.match(expression) is not None + + +@mark.parametrize( + "expression", + [ + "${foo", + "${0foo}", + "${0foo:bar}", + "${foo.${bar}}", + "${foo:${bar}}", + "${foo:'hello'}", + "\\${foo", + ], +) +def test_do_not_match_simple_interpolation_pattern(expression: str) -> None: + assert SIMPLE_INTERPOLATION_PATTERN.match(expression) is None