From 68fb1db7c94d4a5f2bd4cd0d5bd8f6c72c4c6875 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Mon, 14 Dec 2020 12:32:25 -0500 Subject: [PATCH 01/16] Allow more types of dictionary keys in overrides grammar --- .../core/override_parser/overrides_visitor.py | 26 +++++---- hydra/core/override_parser/types.py | 18 +++++- hydra/grammar/OverrideParser.g4 | 23 ++++++-- tests/test_hydra.py | 55 +++++++++++++++++++ tests/test_overrides_parser.py | 38 +++++++++++-- .../docs/advanced/override_grammar/basic.md | 23 ++++++-- .../advanced/override_grammar/basic.md | 23 ++++++-- 7 files changed, 174 insertions(+), 32 deletions(-) diff --git a/hydra/core/override_parser/overrides_visitor.py b/hydra/core/override_parser/overrides_visitor.py index 36b5a86d82b..08fb49b478e 100644 --- a/hydra/core/override_parser/overrides_visitor.py +++ b/hydra/core/override_parser/overrides_visitor.py @@ -76,7 +76,7 @@ def is_ws(self, c: Any) -> bool: return isinstance(c, TerminalNodeImpl) and c.symbol.type == OverrideLexer.WS def visitPrimitive( - self, ctx: OverrideParser.PrimitiveContext + self, ctx: Union[OverrideParser.PrimitiveContext, OverrideParser.DictKeyContext] ) -> Optional[Union[QuotedString, int, bool, float, str]]: ret: Optional[Union[int, bool, float, str]] first_idx = 0 @@ -144,8 +144,8 @@ def visitPrimitive( return node.getText() # type: ignore return ret - def visitListValue( - self, ctx: OverrideParser.ListValueContext + def visitListContainer( + self, ctx: OverrideParser.ListContainerContext ) -> List[ParsedElementType]: ret: List[ParsedElementType] = [] @@ -159,8 +159,8 @@ def visitListValue( ret.append(self.visitElement(element)) return ret - def visitDictValue( - self, ctx: OverrideParser.DictValueContext + def visitDictContainer( + self, ctx: OverrideParser.DictContainerContext ) -> Dict[str, ParsedElementType]: assert self.is_matching_terminal(ctx.getChild(0), OverrideLexer.BRACE_OPEN) return dict( @@ -168,13 +168,17 @@ def visitDictValue( for i in range(1, ctx.getChildCount() - 1, 2) ) + def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any: + # Dictionary keys are a subset of primitives, they may thus be parsed as such. + return self.visitPrimitive(ctx) + def visitDictKeyValuePair( self, ctx: OverrideParser.DictKeyValuePairContext ) -> Tuple[str, ParsedElementType]: children = ctx.getChildren() item = next(children) - assert self.is_matching_terminal(item, OverrideLexer.ID) - pkey = item.getText() + assert isinstance(item, OverrideParser.DictKeyContext) + pkey = self.visitDictKey(item) assert self.is_matching_terminal(next(children), OverrideLexer.COLON) value = next(children) assert isinstance(value, OverrideParser.ElementContext) @@ -186,10 +190,10 @@ def visitElement(self, ctx: OverrideParser.ElementContext) -> ParsedElementType: return self.visitFunction(ctx.function()) # type: ignore elif ctx.primitive(): return self.visitPrimitive(ctx.primitive()) - elif ctx.listValue(): - return self.visitListValue(ctx.listValue()) - elif ctx.dictValue(): - return self.visitDictValue(ctx.dictValue()) + elif ctx.listContainer(): + return self.visitListContainer(ctx.listContainer()) + elif ctx.dictContainer(): + return self.visitDictContainer(ctx.dictContainer()) else: assert False diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index dfef6798b14..42b2fac06d9 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -26,6 +26,16 @@ class QuotedString: quote: Quote + def __hash__(self) -> int: + return hash(self.text) + + def __eq__(self, other: Any) -> Any: + # We do not care whether quotes match for equality. + if isinstance(other, QuotedString): + return self.text == other.text + else: + return NotImplemented + def with_quotes(self) -> str: if self.quote == Quote.single: q = "'" @@ -258,7 +268,10 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]: if isinstance(value, list): return [Override._convert_value(x) for x in value] elif isinstance(value, dict): - return {k: Override._convert_value(v) for k, v in value.items()} + return { + Override._convert_value(k): Override._convert_value(v) + for k, v in value.items() + } elif isinstance(value, QuotedString): return value.text else: @@ -413,7 +426,8 @@ def _get_value_element_as_str( elif isinstance(value, dict): s = comma.join( [ - f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}" + f"{Override._get_value_element_as_str(k)}{colon}" + f"{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}" for k, v in value.items() ] ) diff --git a/hydra/grammar/OverrideParser.g4 b/hydra/grammar/OverrideParser.g4 index 8645a4147da..953c736cecc 100644 --- a/hydra/grammar/OverrideParser.g4 +++ b/hydra/grammar/OverrideParser.g4 @@ -31,8 +31,8 @@ value: element | simpleChoiceSweep; element: primitive - | listValue - | dictValue + | listContainer + | dictContainer | function ; @@ -47,12 +47,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE; // Data structures. -listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] +listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] (element(COMMA element)*)? BRACKET_CLOSE; -dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} -dictKeyValuePair: ID COLON element; +dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: dictKey COLON element; // Primitive types. @@ -69,3 +69,16 @@ primitive: | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; + +// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed. +dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; diff --git a/tests/test_hydra.py b/tests/test_hydra.py index 2c5628ebf49..18609d49184 100644 --- a/tests/test_hydra.py +++ b/tests/test_hydra.py @@ -1047,6 +1047,29 @@ def test_run_pass_list(self, cmd_base: List[str], tmpdir: Any) -> None: ret, _err = get_run_output(cmd) assert OmegaConf.create(ret) == OmegaConf.create(expected) + def test_multirun_dict_keys(self, cmd_base: List[str], tmpdir: Any) -> None: + cmd = cmd_base + [ + "+foo={'null': 0},{'NuLl': 1},{123abc: 0},{/-\\+.$%*@: 1},{white space: 3}", + "--multirun", + ] + expected = """\ +foo: + 'null': 0 + +foo: + NuLl: 1 + +foo: + 123abc: 0 + +foo: + /-\\+.$%*@: 1 + +foo: + white space: 3""" + ret, _err = get_run_output(cmd) + assert normalize_newlines(ret) == normalize_newlines(expected) + def test_app_with_error_exception_sanitized(tmpdir: Any, monkeypatch: Any) -> None: monkeypatch.chdir("tests/test_apps/app_with_runtime_config_error") @@ -1185,3 +1208,35 @@ def test_structured_with_none_list(monkeypatch: Any, tmpdir: Path) -> None: ] result, _err = get_run_output(cmd) assert result == "{'list': None}" + + +def test_overrides_dict_keys(tmpdir: Path) -> None: + """Test that different types of dictionary keys can be overridden""" + # Not currently testing non-string keys since they are not supported + # by OmegaConf. + cfg = OmegaConf.create( + { + "foo": { + "quoted_$(){}[]": 0, + "id123": 0, + "123id": 0, + "a/-\\+.$%*@": 0, + "\\()[]{}:= \t,": 0, + "white space": 0, + } + } + ) + integration_test( + tmpdir=tmpdir, + task_config=cfg, + overrides=[ + "foo={'quoted_$(){}[]': 1, id123: 1, 123id: 1, a/-\\+.$%*@: 1, " + "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1, white space: 1}" + ], + prints=( + "','.join(map(repr, [cfg.foo[x] for x in [" + "'quoted_$(){}[]', 'id123', '123id', 'a/-\\+.$%*@', '\\()[]{}:= \t,', 'white space'" + "]]))" + ), + expected_outputs="1,1,1,1,1,1", + ) diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index 5e7af0e2572..839f16027eb 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -206,8 +206,8 @@ def test_value(value: str, expected: Any) -> None: pytest.param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"), ], ) -def test_list_value(value: str, expected: Any) -> None: - ret = parse_rule(value, "listValue") +def test_list_container(value: str, expected: Any) -> None: + ret = parse_rule(value, "listContainer") assert ret == expected @@ -277,10 +277,40 @@ def test_shuffle_sequence(value: str, expected: Any) -> None: pytest.param("{a:10,b:20}", {"a": 10, "b": 20}, id="dict"), pytest.param("{a:10,b:{}}", {"a": 10, "b": {}}, id="dict"), pytest.param("{a:10,b:{c:[1,2]}}", {"a": 10, "b": {"c": [1, 2]}}, id="dict"), + pytest.param( + "{'0a': 0, \"1b\": 1}", + { + QuotedString(text="0a", quote=Quote.double): 0, + QuotedString(text="1b", quote=Quote.single): 1, + }, + id="dict_quoted_key", + ), + pytest.param("{null: 1}", {None: 1}, id="dict_null_key"), + pytest.param("{123: 1, 0: 2, -1: 3}", {123: 1, 0: 2, -1: 3}, id="dict_int_key"), + pytest.param("{3.14: 0, 1e3: 1}", {3.14: 0, 1000.0: 1}, id="dict_float_key"), + pytest.param("{true: 1, fAlSe: 0}", {True: 1, False: 0}, id="dict_bool_key"), + pytest.param("{/-\\+.$%*@: 1}", {"/-\\+.$%*@": 1}, id="dict_unquoted_char_key"), + pytest.param( + "{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1}", + {"\\()[]{}:= \t,": 1}, + id="dict_esc_key", + ), + pytest.param("{white spaces: 1}", {"white spaces": 1}, id="dict_ws_key"), + pytest.param( + "{'a:b': 1, ab 123.5 True: 2, null false: 3, 1: 4, null: 5}", + { + QuotedString(text="a:b", quote=Quote.single): 1, + "ab 123.5 True": 2, + "null false": 3, + 1: 4, + None: 5, + }, + id="dict_mixed_keys", + ), ], ) -def test_dict_value(value: str, expected: Any) -> None: - ret = parse_rule(value, "dictValue") +def test_dict_container(value: str, expected: Any) -> None: + ret = parse_rule(value, "dictContainer") assert ret == expected diff --git a/website/docs/advanced/override_grammar/basic.md b/website/docs/advanced/override_grammar/basic.md index b6ef50a9b57..912b47f95ad 100644 --- a/website/docs/advanced/override_grammar/basic.md +++ b/website/docs/advanced/override_grammar/basic.md @@ -57,8 +57,8 @@ value: element | simpleChoiceSweep; element: primitive - | listValue - | dictValue + | listContainer + | dictContainer | function ; @@ -73,12 +73,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE; // Data structures. -listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] +listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] (element(COMMA element)*)? BRACKET_CLOSE; -dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} -dictKeyValuePair: ID COLON element; +dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: dictKey COLON element; // Primitive types. @@ -95,6 +95,19 @@ primitive: | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; + +// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed. +dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; ``` ## Elements diff --git a/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md b/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md index b6ef50a9b57..912b47f95ad 100644 --- a/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md +++ b/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md @@ -57,8 +57,8 @@ value: element | simpleChoiceSweep; element: primitive - | listValue - | dictValue + | listContainer + | dictContainer | function ; @@ -73,12 +73,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE; // Data structures. -listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] +listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] (element(COMMA element)*)? BRACKET_CLOSE; -dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} -dictKeyValuePair: ID COLON element; +dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: dictKey COLON element; // Primitive types. @@ -95,6 +95,19 @@ primitive: | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; + +// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed. +dictKey: + QUOTED_VALUE // 'hello world', "hello world" + | ( ID // foo_10 + | NULL // null, NULL + | INT // 0, 10, -20, 1_000_000 + | FLOAT // 3.14, -20.0, 1e-1, -10e3 + | BOOL // true, TrUe, false, False + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ + | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, + | WS // whitespaces + )+; ``` ## Elements From 1478301b471eb9c5ab4463f485b7c859254134ef Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Mon, 14 Dec 2020 16:29:57 -0500 Subject: [PATCH 02/16] Typechecking fix --- hydra/core/override_parser/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index 42b2fac06d9..377b8bba9fd 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -152,9 +152,9 @@ def __eq__(self, other: Any) -> Any: return NotImplemented -# Ideally we would use List[ElementType] and Dict[str, ElementType] but Python does not seem -# to support recursive type definitions. -ElementType = Union[str, int, float, bool, List[Any], Dict[str, Any]] +# Ideally we would use List[ElementType] and Dict[ElementType, ElementType] but Python +# does not seem to support recursive type definitions. +ElementType = Union[str, int, float, bool, List[Any], Dict[Any, Any]] ParsedElementType = Optional[Union[ElementType, QuotedString]] TransformerType = Callable[[ParsedElementType], Any] From ee5eaa80750adabfc22529574c7601072dfc6070 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 16 Dec 2020 16:33:14 -0500 Subject: [PATCH 03/16] Minor refactoring of override grammar visitor --- hydra/core/override_parser/overrides_visitor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hydra/core/override_parser/overrides_visitor.py b/hydra/core/override_parser/overrides_visitor.py index 08fb49b478e..89138616fcb 100644 --- a/hydra/core/override_parser/overrides_visitor.py +++ b/hydra/core/override_parser/overrides_visitor.py @@ -76,6 +76,11 @@ def is_ws(self, c: Any) -> bool: return isinstance(c, TerminalNodeImpl) and c.symbol.type == OverrideLexer.WS def visitPrimitive( + self, ctx: OverrideParser.PrimitiveContext + ) -> Optional[Union[QuotedString, int, bool, float, str]]: + return self.visitPrimitiveOrDictKey(ctx) + + def visitPrimitiveOrDictKey( self, ctx: Union[OverrideParser.PrimitiveContext, OverrideParser.DictKeyContext] ) -> Optional[Union[QuotedString, int, bool, float, str]]: ret: Optional[Union[int, bool, float, str]] @@ -169,8 +174,7 @@ def visitDictContainer( ) def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any: - # Dictionary keys are a subset of primitives, they may thus be parsed as such. - return self.visitPrimitive(ctx) + return self.visitPrimitiveOrDictKey(ctx) def visitDictKeyValuePair( self, ctx: OverrideParser.DictKeyValuePairContext From a99c690f67361829d3c1e62cc836589579914886 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 16 Dec 2020 16:37:57 -0500 Subject: [PATCH 04/16] Revert changes to 1.0 versioned doc --- .../advanced/override_grammar/basic.md | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md b/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md index 912b47f95ad..b6ef50a9b57 100644 --- a/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md +++ b/website/versioned_docs/version-1.0/advanced/override_grammar/basic.md @@ -57,8 +57,8 @@ value: element | simpleChoiceSweep; element: primitive - | listContainer - | dictContainer + | listValue + | dictValue | function ; @@ -73,12 +73,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE; // Data structures. -listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] +listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]] (element(COMMA element)*)? BRACKET_CLOSE; -dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} -dictKeyValuePair: dictKey COLON element; +dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20} +dictKeyValuePair: ID COLON element; // Primitive types. @@ -95,19 +95,6 @@ primitive: | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; - -// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed. -dictKey: - QUOTED_VALUE // 'hello world', "hello world" - | ( ID // foo_10 - | NULL // null, NULL - | INT // 0, 10, -20, 1_000_000 - | FLOAT // 3.14, -20.0, 1e-1, -10e3 - | BOOL // true, TrUe, false, False - | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @ - | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, - | WS // whitespaces - )+; ``` ## Elements From 0b532092a3dfee7b6180d942d1c29e374fe3ae31 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 16 Dec 2020 16:48:35 -0500 Subject: [PATCH 05/16] Remove irrelevant comments --- tests/test_config_loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 9d6a91d756b..d26893d80b6 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1152,14 +1152,12 @@ def test_apply_overrides_to_config( id="default_change", ), pytest.param( - # need to unset optimizer config group first, otherwise they get merged "config", ["optimizer={type:nesterov2,lr:1}"], {"optimizer": {"type": "nesterov2", "lr": 1}}, id="dict_merge", ), pytest.param( - # need to unset optimizer config group first, otherwise they get merged "config", ["+optimizer={foo:10}"], {"optimizer": {"type": "nesterov", "lr": 0.001, "foo": 10}}, From 6bcaac2077f2d1b8a08010da102ffdf6f0ba1f2a Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Wed, 16 Dec 2020 17:24:52 -0500 Subject: [PATCH 06/16] Move test from test_hydra to test_config_loader --- tests/test_config_loader.py | 32 ++++++++++++++++++++++++++++++++ tests/test_hydra.py | 32 -------------------------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index d26893d80b6..2fa62d7d0f5 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -593,6 +593,38 @@ def test_sweep_config_cache( monkeypatch.setenv("HOME", "/another/home/dir/") assert sweep_cfg.home == os.getenv("HOME") + @pytest.mark.parametrize( # type: ignore + "key,expected", + [ + pytest.param("id123", "id123", id="id"), + pytest.param("123id", "123id", id="int_plus_id"), + pytest.param("'quoted_single'", "quoted_single", id="quoted_single"), + pytest.param('"quoted_double"', "quoted_double", id="quoted_double"), + pytest.param("'quoted_$(){}[]'", "quoted_$(){}[]", id="quoted_misc_chars"), + pytest.param("a/-\\+.$%*@", "a/-\\+.$%*@", id="unquoted_misc_chars"), + pytest.param("white space", "white space", id="whitespace"), + pytest.param( + "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", + "\\()[]{}:= \t,", + id="unquoted_esc", + ), + ], + ) + def test_dict_key_formats( + self, hydra_restore_singletons: Any, path: str, key: str, expected: str + ) -> None: + """Test that we can assign dictionaries with keys that are not just IDs""" + config_loader = ConfigLoaderImpl( + config_search_path=create_config_search_path(path) + ) + cfg = config_loader.load_configuration( + config_name="config.yaml", + overrides=[f"+dict={{{key}: 123}}"], + run_mode=RunMode.RUN, + ) + assert "dict" in cfg + assert cfg.dict == {expected: 123} + @pytest.mark.parametrize( # type:ignore "config_file, overrides", diff --git a/tests/test_hydra.py b/tests/test_hydra.py index 18609d49184..d0dcb20c12f 100644 --- a/tests/test_hydra.py +++ b/tests/test_hydra.py @@ -1208,35 +1208,3 @@ def test_structured_with_none_list(monkeypatch: Any, tmpdir: Path) -> None: ] result, _err = get_run_output(cmd) assert result == "{'list': None}" - - -def test_overrides_dict_keys(tmpdir: Path) -> None: - """Test that different types of dictionary keys can be overridden""" - # Not currently testing non-string keys since they are not supported - # by OmegaConf. - cfg = OmegaConf.create( - { - "foo": { - "quoted_$(){}[]": 0, - "id123": 0, - "123id": 0, - "a/-\\+.$%*@": 0, - "\\()[]{}:= \t,": 0, - "white space": 0, - } - } - ) - integration_test( - tmpdir=tmpdir, - task_config=cfg, - overrides=[ - "foo={'quoted_$(){}[]': 1, id123: 1, 123id: 1, a/-\\+.$%*@: 1, " - "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1, white space: 1}" - ], - prints=( - "','.join(map(repr, [cfg.foo[x] for x in [" - "'quoted_$(){}[]', 'id123', '123id', 'a/-\\+.$%*@', '\\()[]{}:= \t,', 'white space'" - "]]))" - ), - expected_outputs="1,1,1,1,1,1", - ) From e249eb5ff5e421730c03c856601bbbbe0e20d914 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 09:42:04 -0500 Subject: [PATCH 07/16] Quote type matters again in quoted strings equality Also made the `QuotedString` dataclass immutable for hash safety. --- hydra/core/override_parser/types.py | 12 +----------- tests/test_overrides_parser.py | 4 ++-- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index 377b8bba9fd..719f1984fbb 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -20,22 +20,12 @@ class Quote(Enum): double = 1 -@dataclass +@dataclass(frozen=True) class QuotedString: text: str quote: Quote - def __hash__(self) -> int: - return hash(self.text) - - def __eq__(self, other: Any) -> Any: - # We do not care whether quotes match for equality. - if isinstance(other, QuotedString): - return self.text == other.text - else: - return NotImplemented - def with_quotes(self) -> str: if self.quote == Quote.single: q = "'" diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index 839f16027eb..b882c196467 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -280,8 +280,8 @@ def test_shuffle_sequence(value: str, expected: Any) -> None: pytest.param( "{'0a': 0, \"1b\": 1}", { - QuotedString(text="0a", quote=Quote.double): 0, - QuotedString(text="1b", quote=Quote.single): 1, + QuotedString(text="0a", quote=Quote.single): 0, + QuotedString(text="1b", quote=Quote.double): 1, }, id="dict_quoted_key", ), From fc874646ea0243a553e4bb804014e73b4e163620 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 10:47:52 -0500 Subject: [PATCH 08/16] Remove spurious print --- hydra/core/override_parser/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index 719f1984fbb..a8a2e9e689d 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -425,7 +425,6 @@ def _get_value_element_as_str( elif isinstance(value, (str, int, bool, float)): return str(value) elif is_structured_config(value): - print(value) return Override._get_value_element_as_str( OmegaConf.to_container(OmegaConf.structured(value)) ) From 8084317f1c9675c49047f6059957197d3d4d8811 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 10:46:12 -0500 Subject: [PATCH 09/16] Fix escaped characters in sweeps Also improved the corresponding tests for dictionary keys. --- hydra/_internal/grammar/utils.py | 28 +++++++++++++++++++++++++++ hydra/core/override_parser/types.py | 5 ++++- hydra/grammar/OverrideLexer.g4 | 2 ++ tests/test_hydra.py | 23 ---------------------- tests/test_overrides_parser.py | 30 +++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 24 deletions(-) diff --git a/hydra/_internal/grammar/utils.py b/hydra/_internal/grammar/utils.py index 765bf9365fe..60f7ae26af0 100644 --- a/hydra/_internal/grammar/utils.py +++ b/hydra/_internal/grammar/utils.py @@ -1,9 +1,37 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import inspect +import re from typing import Any, Union from omegaconf._utils import is_dict_annotation, is_list_annotation +# All characters that must be escaped (must match the ESC grammar lexer token). +_ESC = "\\()[]{}:=, \t" + +# Regular expression that matches any sequence of characters in `_ESC`. +_ESC_REGEX = re.compile(f"[{re.escape(_ESC)}]+") + + +def escape(s: str) -> str: + """Escape special characters in `s`""" + matches = _ESC_REGEX.findall(s) + if not matches: + return s + # Replace all special characters found in `s`. Performance should not be critical + # so we do one pass per special character. + all_special = set("".join(matches)) + # '\' is even more special: it needs to be replaced first, otherwise we will + # mess up the other escaped characters. + try: + all_special.remove("\\") + except KeyError: + pass # no '\' in the string + else: + s = s.replace("\\", "\\\\") + for special_char in all_special: + s = s.replace(special_char, f"\\{special_char}") + return s + def is_type_matching(value: Any, type_: Any) -> bool: # Union diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index a8a2e9e689d..5d3d1d73eaa 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -10,6 +10,7 @@ from omegaconf import OmegaConf from omegaconf._utils import is_structured_config +from hydra._internal.grammar.utils import escape from hydra.core.config_loader import ConfigLoader from hydra.core.object_type import ObjectType from hydra.errors import HydraException @@ -422,7 +423,9 @@ def _get_value_element_as_str( ] ) return "{" + s + "}" - elif isinstance(value, (str, int, bool, float)): + elif isinstance(value, str): + return escape(value) # ensure special characters are escaped + elif isinstance(value, (int, bool, float)): return str(value) elif is_structured_config(value): return Override._get_value_element_as_str( diff --git a/hydra/grammar/OverrideLexer.g4 b/hydra/grammar/OverrideLexer.g4 index 41fb9b606a2..8e25ff4d777 100644 --- a/hydra/grammar/OverrideLexer.g4 +++ b/hydra/grammar/OverrideLexer.g4 @@ -61,6 +61,8 @@ NULL: [Nn][Uu][Ll][Ll]; UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings ID: (CHAR|'_') (CHAR|DIGIT|'_')*; +// Note: when adding more characters to the ESC rule below, also add them to +// the `_ESC` string in `_internal/grammar/utils.py`. ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' | '\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+; WS: [ \t]+; diff --git a/tests/test_hydra.py b/tests/test_hydra.py index d0dcb20c12f..2c5628ebf49 100644 --- a/tests/test_hydra.py +++ b/tests/test_hydra.py @@ -1047,29 +1047,6 @@ def test_run_pass_list(self, cmd_base: List[str], tmpdir: Any) -> None: ret, _err = get_run_output(cmd) assert OmegaConf.create(ret) == OmegaConf.create(expected) - def test_multirun_dict_keys(self, cmd_base: List[str], tmpdir: Any) -> None: - cmd = cmd_base + [ - "+foo={'null': 0},{'NuLl': 1},{123abc: 0},{/-\\+.$%*@: 1},{white space: 3}", - "--multirun", - ] - expected = """\ -foo: - 'null': 0 - -foo: - NuLl: 1 - -foo: - 123abc: 0 - -foo: - /-\\+.$%*@: 1 - -foo: - white space: 3""" - ret, _err = get_run_output(cmd) - assert normalize_newlines(ret) == normalize_newlines(expected) - def test_app_with_error_exception_sanitized(tmpdir: Any, monkeypatch: Any) -> None: monkeypatch.chdir("tests/test_apps/app_with_runtime_config_error") diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index b882c196467..d14e1c7e211 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -959,6 +959,12 @@ def test_get_key_element(override: str, expected: str) -> None: pytest.param("key='value'", "'value'", False, id="single_quoted"), pytest.param('key="value"', '"value"', False, id="double_quoted"), pytest.param("key='שלום'", "'שלום'", False, id="quoted_unicode"), + pytest.param( + "key=\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", + "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", + False, + id="escaped_chars", + ), pytest.param("key=10", "10", False, id="int"), pytest.param("key=3.1415", "3.1415", False, id="float"), pytest.param("key=[]", "[]", False, id="list"), @@ -972,6 +978,30 @@ def test_get_key_element(override: str, expected: str) -> None: pytest.param("key={a:10,b:20}", "{a:10,b:20}", False, id="dict"), pytest.param("key={a:10,b:20}", "{a: 10, b: 20}", True, id="dict"), pytest.param("key={a:10,b:[1,2,3]}", "{a: 10, b: [1, 2, 3]}", True, id="dict"), + pytest.param( + "key={'null':1, \"a:b\": 0}", + "{'null': 1, \"a:b\": 0}", + True, + id="dict_quoted_key", + ), + pytest.param( + "key={/-\\+.$%*@: 1}", + "{/-\\\\+.$%*@: 1}", # note that \ gets escaped + True, + id="dict_unquoted_key_special", + ), + pytest.param( + "key={ white space\t: 2}", + "{white\\ \\ space: 2}", + True, + id="dict_ws_in_key", + ), + pytest.param( + "key={\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 2}", + "{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 2}", + True, + id="dict_esc_key", + ), ], ) def test_override_get_value_element_method( From cfe3932fc7af86ca0107c71a9fa0111bbbd3ee24 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 14:14:51 -0500 Subject: [PATCH 10/16] Actually at this time only strings are allowed as dictionary keys --- hydra/core/override_parser/types.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index 5d3d1d73eaa..f8166f29639 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -143,9 +143,9 @@ def __eq__(self, other: Any) -> Any: return NotImplemented -# Ideally we would use List[ElementType] and Dict[ElementType, ElementType] but Python -# does not seem to support recursive type definitions. -ElementType = Union[str, int, float, bool, List[Any], Dict[Any, Any]] +# Ideally we would use List[ElementType] and Dict[str, ElementType] but Python does not seem +# to support recursive type definitions. +ElementType = Union[str, int, float, bool, List[Any], Dict[str, Any]] ParsedElementType = Optional[Union[ElementType, QuotedString]] TransformerType = Callable[[ParsedElementType], Any] @@ -259,8 +259,14 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]: if isinstance(value, list): return [Override._convert_value(x) for x in value] elif isinstance(value, dict): + + # Currently only strings are allowed as dictionary keys. + def check_str(k: Any) -> str: + assert isinstance(k, str) + return k + return { - Override._convert_value(k): Override._convert_value(v) + check_str(Override._convert_value(k)): Override._convert_value(v) for k, v in value.items() } elif isinstance(value, QuotedString): From 4af259e03588ead9962a45b81264a0cf5c2e1cb8 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 15:01:33 -0500 Subject: [PATCH 11/16] Add test for escape_special_characters() --- hydra/_internal/grammar/utils.py | 2 +- hydra/core/override_parser/types.py | 4 ++-- tests/test_internal_grammar.py | 27 +++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 tests/test_internal_grammar.py diff --git a/hydra/_internal/grammar/utils.py b/hydra/_internal/grammar/utils.py index 60f7ae26af0..459e8413d91 100644 --- a/hydra/_internal/grammar/utils.py +++ b/hydra/_internal/grammar/utils.py @@ -12,7 +12,7 @@ _ESC_REGEX = re.compile(f"[{re.escape(_ESC)}]+") -def escape(s: str) -> str: +def escape_special_characters(s: str) -> str: """Escape special characters in `s`""" matches = _ESC_REGEX.findall(s) if not matches: diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index f8166f29639..be7611d24a1 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -10,7 +10,7 @@ from omegaconf import OmegaConf from omegaconf._utils import is_structured_config -from hydra._internal.grammar.utils import escape +from hydra._internal.grammar.utils import escape_special_characters from hydra.core.config_loader import ConfigLoader from hydra.core.object_type import ObjectType from hydra.errors import HydraException @@ -430,7 +430,7 @@ def _get_value_element_as_str( ) return "{" + s + "}" elif isinstance(value, str): - return escape(value) # ensure special characters are escaped + return escape_special_characters(value) elif isinstance(value, (int, bool, float)): return str(value) elif is_structured_config(value): diff --git a/tests/test_internal_grammar.py b/tests/test_internal_grammar.py new file mode 100644 index 00000000000..a01249961c6 --- /dev/null +++ b/tests/test_internal_grammar.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pytest import mark, param + +from hydra._internal.grammar.utils import escape_special_characters + + +@mark.parametrize( # type: ignore + ("s", "expected"), + [ + param("abc", "abc", id="no_esc"), + param("\\", "\\\\", id="esc_backslash"), + param("\\\\\\", "\\\\\\\\\\\\", id="esc_backslash_x3"), + param("()", "\\(\\)", id="esc_parentheses"), + param("[]", "\\[\\]", id="esc_brackets"), + param("{}", "\\{\\}", id="esc_braces"), + param(":=,", "\\:\\=\\,", id="esc_symbols"), + param(" \t", "\\ \\ \\\t", id="esc_ws"), + param( + "ab\\(cd{ef}[gh]): ij,kl\t", + "ab\\\\\\(cd\\{ef\\}\\[gh\\]\\)\\:\\ ij\\,kl\\\t", + id="esc_mixed", + ), + ], +) +def test_escape_special_characters(s: str, expected: str) -> None: + escaped = escape_special_characters(s) + assert escaped == expected From 2e545f9ecf891c68acb04b65f74beab5547a63b2 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 15:47:42 -0500 Subject: [PATCH 12/16] Make new test lower level (and faster!) --- tests/test_config_loader.py | 32 -------------------------------- tests/test_overrides_parser.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 2fa62d7d0f5..d26893d80b6 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -593,38 +593,6 @@ def test_sweep_config_cache( monkeypatch.setenv("HOME", "/another/home/dir/") assert sweep_cfg.home == os.getenv("HOME") - @pytest.mark.parametrize( # type: ignore - "key,expected", - [ - pytest.param("id123", "id123", id="id"), - pytest.param("123id", "123id", id="int_plus_id"), - pytest.param("'quoted_single'", "quoted_single", id="quoted_single"), - pytest.param('"quoted_double"', "quoted_double", id="quoted_double"), - pytest.param("'quoted_$(){}[]'", "quoted_$(){}[]", id="quoted_misc_chars"), - pytest.param("a/-\\+.$%*@", "a/-\\+.$%*@", id="unquoted_misc_chars"), - pytest.param("white space", "white space", id="whitespace"), - pytest.param( - "\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,", - "\\()[]{}:= \t,", - id="unquoted_esc", - ), - ], - ) - def test_dict_key_formats( - self, hydra_restore_singletons: Any, path: str, key: str, expected: str - ) -> None: - """Test that we can assign dictionaries with keys that are not just IDs""" - config_loader = ConfigLoaderImpl( - config_search_path=create_config_search_path(path) - ) - cfg = config_loader.load_configuration( - config_name="config.yaml", - overrides=[f"+dict={{{key}: 123}}"], - run_mode=RunMode.RUN, - ) - assert "dict" in cfg - assert cfg.dict == {expected: 123} - @pytest.mark.parametrize( # type:ignore "config_file, overrides", diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index d14e1c7e211..19e2852dea7 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -1030,6 +1030,16 @@ def test_override_get_value_element_method( pytest.param("key={a:10,b:20}", {"a": 10, "b": 20}, id="dict"), pytest.param("key={a:10,b:20}", {"a": 10, "b": 20}, id="dict"), pytest.param("key={a:10,b:[1,2,3]}", {"a": 10, "b": [1, 2, 3]}, id="dict"), + pytest.param("key={123id: 0}", {"123id": 0}, id="dict_key_int_plus_id"), + pytest.param("key={' abc ': 0}", {" abc ": 0}, id="dict_key_quoted_single"), + pytest.param('key={" abc ": 0}', {" abc ": 0}, id="dict_key_quoted_double"), + pytest.param("key={a/-\\+.$%*@: 0}", {"a/-\\+.$%*@": 0}, id="dict_key_noquote"), + pytest.param("key={w s: 0}", {"w s": 0}, id="dict_key_ws"), + pytest.param( + "key={\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 0}", + {"\\()[]{}:= \t,": 0}, + id="dict_key_esc", + ), ], ) def test_override_value_method(override: str, expected: str) -> None: From 0d745041ad22e811e97a467a61d9818bf5ef0dac Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 17:28:11 -0500 Subject: [PATCH 13/16] Remove type check for more explicit errors --- hydra/core/override_parser/types.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index be7611d24a1..6ff7b4118e0 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -260,13 +260,10 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]: return [Override._convert_value(x) for x in value] elif isinstance(value, dict): - # Currently only strings are allowed as dictionary keys. - def check_str(k: Any) -> str: - assert isinstance(k, str) - return k - return { - check_str(Override._convert_value(k)): Override._convert_value(v) + # We ignore potential type mismatch here so as to let OmegaConf + # raise an explicit error in case of invalid type. + Override._convert_value(k): Override._convert_value(v) # type: ignore for k, v in value.items() } elif isinstance(value, QuotedString): From 0c894b3de674a24c759ad69b212163cf282e4604 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 18:13:20 -0500 Subject: [PATCH 14/16] Move test to `test_overrides_parser.py` --- tests/test_internal_grammar.py | 27 --------------------------- tests/test_overrides_parser.py | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 27 deletions(-) delete mode 100644 tests/test_internal_grammar.py diff --git a/tests/test_internal_grammar.py b/tests/test_internal_grammar.py deleted file mode 100644 index a01249961c6..00000000000 --- a/tests/test_internal_grammar.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from pytest import mark, param - -from hydra._internal.grammar.utils import escape_special_characters - - -@mark.parametrize( # type: ignore - ("s", "expected"), - [ - param("abc", "abc", id="no_esc"), - param("\\", "\\\\", id="esc_backslash"), - param("\\\\\\", "\\\\\\\\\\\\", id="esc_backslash_x3"), - param("()", "\\(\\)", id="esc_parentheses"), - param("[]", "\\[\\]", id="esc_brackets"), - param("{}", "\\{\\}", id="esc_braces"), - param(":=,", "\\:\\=\\,", id="esc_symbols"), - param(" \t", "\\ \\ \\\t", id="esc_ws"), - param( - "ab\\(cd{ef}[gh]): ij,kl\t", - "ab\\\\\\(cd\\{ef\\}\\[gh\\]\\)\\:\\ ij\\,kl\\\t", - id="esc_mixed", - ), - ], -) -def test_escape_special_characters(s: str, expected: str) -> None: - escaped = escape_special_characters(s) - assert escaped == expected diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index 19e2852dea7..d6e5f2e9b46 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -9,6 +9,7 @@ from _pytest.python_api import RaisesContext from hydra._internal.grammar.functions import Functions +from hydra._internal.grammar.utils import escape_special_characters from hydra.core.override_parser.overrides_parser import ( OverridesParser, create_functions, @@ -2024,3 +2025,26 @@ def test_sweep_iterators( ] assert actual_sweep_string_list == expected_sweep_string_list assert actual_sweep_encoded_list == expected_sweep_encoded_list + + +@pytest.mark.parametrize( # type: ignore + ("s", "expected"), + [ + pytest.param("abc", "abc", id="no_esc"), + pytest.param("\\", "\\\\", id="esc_backslash"), + pytest.param("\\\\\\", "\\\\\\\\\\\\", id="esc_backslash_x3"), + pytest.param("()", "\\(\\)", id="esc_parentheses"), + pytest.param("[]", "\\[\\]", id="esc_brackets"), + pytest.param("{}", "\\{\\}", id="esc_braces"), + pytest.param(":=,", "\\:\\=\\,", id="esc_symbols"), + pytest.param(" \t", "\\ \\ \\\t", id="esc_ws"), + pytest.param( + "ab\\(cd{ef}[gh]): ij,kl\t", + "ab\\\\\\(cd\\{ef\\}\\[gh\\]\\)\\:\\ ij\\,kl\\\t", + id="esc_mixed", + ), + ], +) +def test_escape_special_characters(s: str, expected: str) -> None: + escaped = escape_special_characters(s) + assert escaped == expected From 8a16f54e485019ad5624d1c04bd7e6381dc5ad4b Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 23:35:58 -0500 Subject: [PATCH 15/16] Refactor: rename visitPrimitiveOrDictKey() into _createPrimitive() --- .../core/override_parser/overrides_visitor.py | 144 +++++++++--------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/hydra/core/override_parser/overrides_visitor.py b/hydra/core/override_parser/overrides_visitor.py index 89138616fcb..eedf239e484 100644 --- a/hydra/core/override_parser/overrides_visitor.py +++ b/hydra/core/override_parser/overrides_visitor.py @@ -3,7 +3,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from antlr4 import TerminalNode, Token +from antlr4 import ParserRuleContext, TerminalNode, Token from antlr4.error.ErrorListener import ErrorListener from antlr4.tree.Tree import TerminalNodeImpl @@ -78,76 +78,7 @@ def is_ws(self, c: Any) -> bool: def visitPrimitive( self, ctx: OverrideParser.PrimitiveContext ) -> Optional[Union[QuotedString, int, bool, float, str]]: - return self.visitPrimitiveOrDictKey(ctx) - - def visitPrimitiveOrDictKey( - self, ctx: Union[OverrideParser.PrimitiveContext, OverrideParser.DictKeyContext] - ) -> Optional[Union[QuotedString, int, bool, float, str]]: - ret: Optional[Union[int, bool, float, str]] - first_idx = 0 - last_idx = ctx.getChildCount() - # skip first if whitespace - if self.is_ws(ctx.getChild(0)): - if last_idx == 1: - # Only whitespaces => this is not allowed. - raise HydraException( - "Trying to parse a primitive that is all whitespaces" - ) - first_idx = 1 - if self.is_ws(ctx.getChild(-1)): - last_idx = last_idx - 1 - num = last_idx - first_idx - if num > 1: - # Concatenate, while un-escaping as needed. - tokens = [] - for i, n in enumerate(ctx.getChildren()): - if n.symbol.type == OverrideLexer.WS and ( - i < first_idx or i >= last_idx - ): - # Skip leading / trailing whitespaces. - continue - tokens.append( - n.symbol.text[1::2] # un-escape by skipping every other char - if n.symbol.type == OverrideLexer.ESC - else n.symbol.text - ) - ret = "".join(tokens) - else: - node = ctx.getChild(first_idx) - if node.symbol.type == OverrideLexer.QUOTED_VALUE: - text = node.getText() - qc = text[0] - text = text[1:-1] - if qc == "'": - quote = Quote.single - text = text.replace("\\'", "'") - elif qc == '"': - quote = Quote.double - text = text.replace('\\"', '"') - else: - assert False - return QuotedString(text=text, quote=quote) - elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION): - ret = node.symbol.text - elif node.symbol.type == OverrideLexer.INT: - ret = int(node.symbol.text) - elif node.symbol.type == OverrideLexer.FLOAT: - ret = float(node.symbol.text) - elif node.symbol.type == OverrideLexer.NULL: - ret = None - elif node.symbol.type == OverrideLexer.BOOL: - text = node.getText().lower() - if text == "true": - ret = True - elif text == "false": - ret = False - else: - assert False - elif node.symbol.type == OverrideLexer.ESC: - ret = node.symbol.text[1::2] - else: - return node.getText() # type: ignore - return ret + return self._createPrimitive(ctx) def visitListContainer( self, ctx: OverrideParser.ListContainerContext @@ -174,7 +105,7 @@ def visitDictContainer( ) def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any: - return self.visitPrimitiveOrDictKey(ctx) + return self._createPrimitive(ctx) def visitDictKeyValuePair( self, ctx: OverrideParser.DictKeyValuePairContext @@ -313,6 +244,75 @@ def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any: f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}" ) from e + def _createPrimitive( + self, ctx: ParserRuleContext + ) -> Optional[Union[QuotedString, int, bool, float, str]]: + ret: Optional[Union[int, bool, float, str]] + first_idx = 0 + last_idx = ctx.getChildCount() + # skip first if whitespace + if self.is_ws(ctx.getChild(0)): + if last_idx == 1: + # Only whitespaces => this is not allowed. + raise HydraException( + "Trying to parse a primitive that is all whitespaces" + ) + first_idx = 1 + if self.is_ws(ctx.getChild(-1)): + last_idx = last_idx - 1 + num = last_idx - first_idx + if num > 1: + # Concatenate, while un-escaping as needed. + tokens = [] + for i, n in enumerate(ctx.getChildren()): + if n.symbol.type == OverrideLexer.WS and ( + i < first_idx or i >= last_idx + ): + # Skip leading / trailing whitespaces. + continue + tokens.append( + n.symbol.text[1::2] # un-escape by skipping every other char + if n.symbol.type == OverrideLexer.ESC + else n.symbol.text + ) + ret = "".join(tokens) + else: + node = ctx.getChild(first_idx) + if node.symbol.type == OverrideLexer.QUOTED_VALUE: + text = node.getText() + qc = text[0] + text = text[1:-1] + if qc == "'": + quote = Quote.single + text = text.replace("\\'", "'") + elif qc == '"': + quote = Quote.double + text = text.replace('\\"', '"') + else: + assert False + return QuotedString(text=text, quote=quote) + elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION): + ret = node.symbol.text + elif node.symbol.type == OverrideLexer.INT: + ret = int(node.symbol.text) + elif node.symbol.type == OverrideLexer.FLOAT: + ret = float(node.symbol.text) + elif node.symbol.type == OverrideLexer.NULL: + ret = None + elif node.symbol.type == OverrideLexer.BOOL: + text = node.getText().lower() + if text == "true": + ret = True + elif text == "false": + ret = False + else: + assert False + elif node.symbol.type == OverrideLexer.ESC: + ret = node.symbol.text[1::2] + else: + return node.getText() # type: ignore + return ret + class HydraErrorListener(ErrorListener): # type: ignore def syntaxError( From 4f2665eec46a18033c58771c1922d0f929084988 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Fri, 18 Dec 2020 23:42:30 -0500 Subject: [PATCH 16/16] Replace list comprehension with for loop for readability --- hydra/core/override_parser/types.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index 6ff7b4118e0..6d32db687d5 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -418,14 +418,14 @@ def _get_value_element_as_str( ) return "[" + s + "]" elif isinstance(value, dict): - s = comma.join( - [ - f"{Override._get_value_element_as_str(k)}{colon}" - f"{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}" - for k, v in value.items() - ] - ) - return "{" + s + "}" + str_items = [] + for k, v in value.items(): + str_key = Override._get_value_element_as_str(k) + str_value = Override._get_value_element_as_str( + v, space_after_sep=space_after_sep + ) + str_items.append(f"{str_key}{colon}{str_value}") + return "{" + comma.join(str_items) + "}" elif isinstance(value, str): return escape_special_characters(value) elif isinstance(value, (int, bool, float)):