Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more types of dictionary keys in overrides grammar #1208

Merged
merged 16 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions hydra/_internal/grammar/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import inspect
import re
from typing import Any, Union

from omegaconf._utils import is_dict_annotation, is_list_annotation

# All characters that must be escaped (must match the ESC grammar lexer token).
_ESC = "\\()[]{}:=, \t"

# Regular expression that matches any sequence of characters in `_ESC`.
_ESC_REGEX = re.compile(f"[{re.escape(_ESC)}]+")


def escape_special_characters(s: str) -> str:
"""Escape special characters in `s`"""
matches = _ESC_REGEX.findall(s)
if not matches:
return s
# Replace all special characters found in `s`. Performance should not be critical
# so we do one pass per special character.
all_special = set("".join(matches))
# '\' is even more special: it needs to be replaced first, otherwise we will
# mess up the other escaped characters.
try:
all_special.remove("\\")
except KeyError:
pass # no '\' in the string
else:
s = s.replace("\\", "\\\\")
for special_char in all_special:
s = s.replace(special_char, f"\\{special_char}")
return s


def is_type_matching(value: Any, type_: Any) -> bool:
# Union
Expand Down
28 changes: 18 additions & 10 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def is_ws(self, c: Any) -> bool:

def visitPrimitive(
self, ctx: OverrideParser.PrimitiveContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
return self.visitPrimitiveOrDictKey(ctx)

def visitPrimitiveOrDictKey(
self, ctx: Union[OverrideParser.PrimitiveContext, OverrideParser.DictKeyContext]
omry marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[Union[QuotedString, int, bool, float, str]]:
ret: Optional[Union[int, bool, float, str]]
first_idx = 0
Expand Down Expand Up @@ -144,8 +149,8 @@ def visitPrimitive(
return node.getText() # type: ignore
return ret

def visitListValue(
self, ctx: OverrideParser.ListValueContext
def visitListContainer(
self, ctx: OverrideParser.ListContainerContext
) -> List[ParsedElementType]:
ret: List[ParsedElementType] = []

Expand All @@ -159,22 +164,25 @@ def visitListValue(
ret.append(self.visitElement(element))
return ret

def visitDictValue(
self, ctx: OverrideParser.DictValueContext
def visitDictContainer(
self, ctx: OverrideParser.DictContainerContext
) -> Dict[str, ParsedElementType]:
assert self.is_matching_terminal(ctx.getChild(0), OverrideLexer.BRACE_OPEN)
return dict(
self.visitDictKeyValuePair(ctx.getChild(i))
for i in range(1, ctx.getChildCount() - 1, 2)
)

def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any:
return self.visitPrimitiveOrDictKey(ctx)

def visitDictKeyValuePair(
self, ctx: OverrideParser.DictKeyValuePairContext
) -> Tuple[str, ParsedElementType]:
children = ctx.getChildren()
item = next(children)
assert self.is_matching_terminal(item, OverrideLexer.ID)
pkey = item.getText()
assert isinstance(item, OverrideParser.DictKeyContext)
pkey = self.visitDictKey(item)
assert self.is_matching_terminal(next(children), OverrideLexer.COLON)
value = next(children)
assert isinstance(value, OverrideParser.ElementContext)
Expand All @@ -186,10 +194,10 @@ def visitElement(self, ctx: OverrideParser.ElementContext) -> ParsedElementType:
return self.visitFunction(ctx.function()) # type: ignore
elif ctx.primitive():
return self.visitPrimitive(ctx.primitive())
elif ctx.listValue():
return self.visitListValue(ctx.listValue())
elif ctx.dictValue():
return self.visitDictValue(ctx.dictValue())
elif ctx.listContainer():
return self.visitListContainer(ctx.listContainer())
elif ctx.dictContainer():
return self.visitDictContainer(ctx.dictContainer())
else:
assert False

Expand Down
22 changes: 17 additions & 5 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from omegaconf import OmegaConf
from omegaconf._utils import is_structured_config

from hydra._internal.grammar.utils import escape_special_characters
from hydra.core.config_loader import ConfigLoader
from hydra.core.object_type import ObjectType
from hydra.errors import HydraException
Expand All @@ -20,7 +21,7 @@ class Quote(Enum):
double = 1


@dataclass
@dataclass(frozen=True)
class QuotedString:
text: str

Expand Down Expand Up @@ -258,7 +259,16 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]:
if isinstance(value, list):
return [Override._convert_value(x) for x in value]
elif isinstance(value, dict):
return {k: Override._convert_value(v) for k, v in value.items()}

# Currently only strings are allowed as dictionary keys.
def check_str(k: Any) -> str:
assert isinstance(k, str)
return k

omry marked this conversation as resolved.
Show resolved Hide resolved
return {
omry marked this conversation as resolved.
Show resolved Hide resolved
check_str(Override._convert_value(k)): Override._convert_value(v)
for k, v in value.items()
}
elif isinstance(value, QuotedString):
return value.text
else:
Expand Down Expand Up @@ -413,15 +423,17 @@ def _get_value_element_as_str(
elif isinstance(value, dict):
s = comma.join(
[
f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
f"{Override._get_value_element_as_str(k)}{colon}"
f"{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
for k, v in value.items()
]
omry marked this conversation as resolved.
Show resolved Hide resolved
)
return "{" + s + "}"
elif isinstance(value, (str, int, bool, float)):
elif isinstance(value, str):
return escape_special_characters(value)
elif isinstance(value, (int, bool, float)):
return str(value)
elif is_structured_config(value):
print(value)
return Override._get_value_element_as_str(
OmegaConf.to_container(OmegaConf.structured(value))
)
Expand Down
2 changes: 2 additions & 0 deletions hydra/grammar/OverrideLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ NULL: [Nn][Uu][Ll][Ll];

UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings
ID: (CHAR|'_') (CHAR|DIGIT|'_')*;
// Note: when adding more characters to the ESC rule below, also add them to
// the `_ESC` string in `_internal/grammar/utils.py`.
ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
'\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
WS: [ \t]+;
Expand Down
23 changes: 18 additions & 5 deletions hydra/grammar/OverrideParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ value: element | simpleChoiceSweep;

element:
primitive
| listValue
| dictValue
| listContainer
| dictContainer
| function
;

Expand All @@ -47,12 +47,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE;

// Data structures.

listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
(element(COMMA element)*)?
BRACKET_CLOSE;

dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: ID COLON element;
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: dictKey COLON element;

// Primitive types.

Expand All @@ -69,3 +69,16 @@ primitive:
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;

// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed.
dictKey:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;
2 changes: 0 additions & 2 deletions tests/test_config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,14 +1152,12 @@ def test_apply_overrides_to_config(
id="default_change",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["optimizer={type:nesterov2,lr:1}"],
{"optimizer": {"type": "nesterov2", "lr": 1}},
id="dict_merge",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["+optimizer={foo:10}"],
{"optimizer": {"type": "nesterov", "lr": 0.001, "foo": 10}},
Expand Down
27 changes: 27 additions & 0 deletions tests/test_internal_grammar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
omry marked this conversation as resolved.
Show resolved Hide resolved
from pytest import mark, param

from hydra._internal.grammar.utils import escape_special_characters


@mark.parametrize( # type: ignore
("s", "expected"),
[
param("abc", "abc", id="no_esc"),
param("\\", "\\\\", id="esc_backslash"),
param("\\\\\\", "\\\\\\\\\\\\", id="esc_backslash_x3"),
param("()", "\\(\\)", id="esc_parentheses"),
param("[]", "\\[\\]", id="esc_brackets"),
param("{}", "\\{\\}", id="esc_braces"),
param(":=,", "\\:\\=\\,", id="esc_symbols"),
param(" \t", "\\ \\ \\\t", id="esc_ws"),
param(
"ab\\(cd{ef}[gh]): ij,kl\t",
"ab\\\\\\(cd\\{ef\\}\\[gh\\]\\)\\:\\ ij\\,kl\\\t",
id="esc_mixed",
),
],
)
def test_escape_special_characters(s: str, expected: str) -> None:
escaped = escape_special_characters(s)
assert escaped == expected
78 changes: 74 additions & 4 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def test_value(value: str, expected: Any) -> None:
pytest.param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"),
],
)
def test_list_value(value: str, expected: Any) -> None:
ret = parse_rule(value, "listValue")
def test_list_container(value: str, expected: Any) -> None:
ret = parse_rule(value, "listContainer")
assert ret == expected


Expand Down Expand Up @@ -277,10 +277,40 @@ def test_shuffle_sequence(value: str, expected: Any) -> None:
pytest.param("{a:10,b:20}", {"a": 10, "b": 20}, id="dict"),
pytest.param("{a:10,b:{}}", {"a": 10, "b": {}}, id="dict"),
pytest.param("{a:10,b:{c:[1,2]}}", {"a": 10, "b": {"c": [1, 2]}}, id="dict"),
pytest.param(
"{'0a': 0, \"1b\": 1}",
{
QuotedString(text="0a", quote=Quote.single): 0,
QuotedString(text="1b", quote=Quote.double): 1,
},
id="dict_quoted_key",
),
pytest.param("{null: 1}", {None: 1}, id="dict_null_key"),
pytest.param("{123: 1, 0: 2, -1: 3}", {123: 1, 0: 2, -1: 3}, id="dict_int_key"),
pytest.param("{3.14: 0, 1e3: 1}", {3.14: 0, 1000.0: 1}, id="dict_float_key"),
pytest.param("{true: 1, fAlSe: 0}", {True: 1, False: 0}, id="dict_bool_key"),
pytest.param("{/-\\+.$%*@: 1}", {"/-\\+.$%*@": 1}, id="dict_unquoted_char_key"),
pytest.param(
"{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1}",
{"\\()[]{}:= \t,": 1},
id="dict_esc_key",
),
pytest.param("{white spaces: 1}", {"white spaces": 1}, id="dict_ws_key"),
pytest.param(
"{'a:b': 1, ab 123.5 True: 2, null false: 3, 1: 4, null: 5}",
{
QuotedString(text="a:b", quote=Quote.single): 1,
"ab 123.5 True": 2,
"null false": 3,
1: 4,
None: 5,
},
id="dict_mixed_keys",
),
],
)
def test_dict_value(value: str, expected: Any) -> None:
ret = parse_rule(value, "dictValue")
def test_dict_container(value: str, expected: Any) -> None:
ret = parse_rule(value, "dictContainer")
assert ret == expected


Expand Down Expand Up @@ -929,6 +959,12 @@ def test_get_key_element(override: str, expected: str) -> None:
pytest.param("key='value'", "'value'", False, id="single_quoted"),
pytest.param('key="value"', '"value"', False, id="double_quoted"),
pytest.param("key='שלום'", "'שלום'", False, id="quoted_unicode"),
pytest.param(
"key=\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,",
"\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,",
False,
id="escaped_chars",
),
pytest.param("key=10", "10", False, id="int"),
pytest.param("key=3.1415", "3.1415", False, id="float"),
pytest.param("key=[]", "[]", False, id="list"),
Expand All @@ -942,6 +978,30 @@ def test_get_key_element(override: str, expected: str) -> None:
pytest.param("key={a:10,b:20}", "{a:10,b:20}", False, id="dict"),
pytest.param("key={a:10,b:20}", "{a: 10, b: 20}", True, id="dict"),
pytest.param("key={a:10,b:[1,2,3]}", "{a: 10, b: [1, 2, 3]}", True, id="dict"),
pytest.param(
"key={'null':1, \"a:b\": 0}",
"{'null': 1, \"a:b\": 0}",
True,
id="dict_quoted_key",
),
pytest.param(
"key={/-\\+.$%*@: 1}",
"{/-\\\\+.$%*@: 1}", # note that \ gets escaped
True,
id="dict_unquoted_key_special",
),
pytest.param(
"key={ white space\t: 2}",
"{white\\ \\ space: 2}",
True,
id="dict_ws_in_key",
),
pytest.param(
"key={\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 2}",
"{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 2}",
True,
id="dict_esc_key",
),
],
)
def test_override_get_value_element_method(
Expand Down Expand Up @@ -970,6 +1030,16 @@ def test_override_get_value_element_method(
pytest.param("key={a:10,b:20}", {"a": 10, "b": 20}, id="dict"),
pytest.param("key={a:10,b:20}", {"a": 10, "b": 20}, id="dict"),
pytest.param("key={a:10,b:[1,2,3]}", {"a": 10, "b": [1, 2, 3]}, id="dict"),
pytest.param("key={123id: 0}", {"123id": 0}, id="dict_key_int_plus_id"),
pytest.param("key={' abc ': 0}", {" abc ": 0}, id="dict_key_quoted_single"),
pytest.param('key={" abc ": 0}', {" abc ": 0}, id="dict_key_quoted_double"),
pytest.param("key={a/-\\+.$%*@: 0}", {"a/-\\+.$%*@": 0}, id="dict_key_noquote"),
pytest.param("key={w s: 0}", {"w s": 0}, id="dict_key_ws"),
pytest.param(
"key={\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 0}",
{"\\()[]{}:= \t,": 0},
id="dict_key_esc",
),
],
)
def test_override_value_method(override: str, expected: str) -> None:
Expand Down
Loading