Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more types of dictionary keys in overrides grammar #1208

Merged
merged 16 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions hydra/_internal/grammar/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import inspect
import re
from typing import Any, Union

from omegaconf._utils import is_dict_annotation, is_list_annotation

# All characters that must be escaped (must match the ESC grammar lexer token).
_ESC = "\\()[]{}:=, \t"

# Regular expression that matches any sequence of characters in `_ESC`.
_ESC_REGEX = re.compile(f"[{re.escape(_ESC)}]+")


def escape_special_characters(s: str) -> str:
    r"""Escape special characters in `s`.

    Every character of `s` that appears in `_ESC` is prefixed with a single
    backslash; all other characters are copied unchanged.

    :param s: the string to escape.
    :return: the escaped string (equal to `s` when nothing needs escaping).

    >>> escape_special_characters("a=b c")
    'a\\=b\\ c'
    """
    # A single left-to-right substitution pass: each run of special characters
    # is rewritten character by character. Because the output is never
    # re-scanned, the backslashes we insert cannot be escaped a second time,
    # so '\' needs no special ordering relative to the other characters.
    return _ESC_REGEX.sub(
        lambda run: "".join(f"\\{char}" for char in run.group(0)), s
    )


def is_type_matching(value: Any, type_: Any) -> bool:
# Union
Expand Down
160 changes: 84 additions & 76 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

from antlr4 import TerminalNode, Token
from antlr4 import ParserRuleContext, TerminalNode, Token
from antlr4.error.ErrorListener import ErrorListener
from antlr4.tree.Tree import TerminalNodeImpl

Expand Down Expand Up @@ -78,74 +78,10 @@ def is_ws(self, c: Any) -> bool:
def visitPrimitive(
self, ctx: OverrideParser.PrimitiveContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
ret: Optional[Union[int, bool, float, str]]
first_idx = 0
last_idx = ctx.getChildCount()
# skip first if whitespace
if self.is_ws(ctx.getChild(0)):
if last_idx == 1:
# Only whitespaces => this is not allowed.
raise HydraException(
"Trying to parse a primitive that is all whitespaces"
)
first_idx = 1
if self.is_ws(ctx.getChild(-1)):
last_idx = last_idx - 1
num = last_idx - first_idx
if num > 1:
# Concatenate, while un-escaping as needed.
tokens = []
for i, n in enumerate(ctx.getChildren()):
if n.symbol.type == OverrideLexer.WS and (
i < first_idx or i >= last_idx
):
# Skip leading / trailing whitespaces.
continue
tokens.append(
n.symbol.text[1::2] # un-escape by skipping every other char
if n.symbol.type == OverrideLexer.ESC
else n.symbol.text
)
ret = "".join(tokens)
else:
node = ctx.getChild(first_idx)
if node.symbol.type == OverrideLexer.QUOTED_VALUE:
text = node.getText()
qc = text[0]
text = text[1:-1]
if qc == "'":
quote = Quote.single
text = text.replace("\\'", "'")
elif qc == '"':
quote = Quote.double
text = text.replace('\\"', '"')
else:
assert False
return QuotedString(text=text, quote=quote)
elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
ret = node.symbol.text
elif node.symbol.type == OverrideLexer.INT:
ret = int(node.symbol.text)
elif node.symbol.type == OverrideLexer.FLOAT:
ret = float(node.symbol.text)
elif node.symbol.type == OverrideLexer.NULL:
ret = None
elif node.symbol.type == OverrideLexer.BOOL:
text = node.getText().lower()
if text == "true":
ret = True
elif text == "false":
ret = False
else:
assert False
elif node.symbol.type == OverrideLexer.ESC:
ret = node.symbol.text[1::2]
else:
return node.getText() # type: ignore
return ret
return self._createPrimitive(ctx)

def visitListValue(
self, ctx: OverrideParser.ListValueContext
def visitListContainer(
self, ctx: OverrideParser.ListContainerContext
) -> List[ParsedElementType]:
ret: List[ParsedElementType] = []

Expand All @@ -159,22 +95,25 @@ def visitListValue(
ret.append(self.visitElement(element))
return ret

def visitDictValue(
self, ctx: OverrideParser.DictValueContext
def visitDictContainer(
self, ctx: OverrideParser.DictContainerContext
) -> Dict[str, ParsedElementType]:
assert self.is_matching_terminal(ctx.getChild(0), OverrideLexer.BRACE_OPEN)
return dict(
self.visitDictKeyValuePair(ctx.getChild(i))
for i in range(1, ctx.getChildCount() - 1, 2)
)

def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any:
    """Convert a `dictKey` parse node into the Python value of the key.

    The conversion is delegated to `_createPrimitive`.
    """
    key = self._createPrimitive(ctx)
    return key

def visitDictKeyValuePair(
self, ctx: OverrideParser.DictKeyValuePairContext
) -> Tuple[str, ParsedElementType]:
children = ctx.getChildren()
item = next(children)
assert self.is_matching_terminal(item, OverrideLexer.ID)
pkey = item.getText()
assert isinstance(item, OverrideParser.DictKeyContext)
pkey = self.visitDictKey(item)
assert self.is_matching_terminal(next(children), OverrideLexer.COLON)
value = next(children)
assert isinstance(value, OverrideParser.ElementContext)
Expand All @@ -186,10 +125,10 @@ def visitElement(self, ctx: OverrideParser.ElementContext) -> ParsedElementType:
return self.visitFunction(ctx.function()) # type: ignore
elif ctx.primitive():
return self.visitPrimitive(ctx.primitive())
elif ctx.listValue():
return self.visitListValue(ctx.listValue())
elif ctx.dictValue():
return self.visitDictValue(ctx.dictValue())
elif ctx.listContainer():
return self.visitListContainer(ctx.listContainer())
elif ctx.dictContainer():
return self.visitDictContainer(ctx.dictContainer())
else:
assert False

Expand Down Expand Up @@ -305,6 +244,75 @@ def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any:
f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}"
) from e

def _createPrimitive(
    self, ctx: ParserRuleContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
    """Build the Python value represented by the terminal children of `ctx`.

    Leading and trailing whitespace tokens are ignored. If more than one token
    remains, their texts are concatenated into a string (ESC tokens are
    un-escaped). Otherwise the single remaining token is converted according
    to its lexer type.

    :param ctx: parse node whose children are terminal nodes.
    :return: a `QuotedString`, `int`, `float`, `bool`, `str` or `None`.
    :raises HydraException: if `ctx` consists of a single whitespace token.
    """
    child_count = ctx.getChildCount()

    # Work out the [start, stop) window that excludes edge whitespace.
    has_leading_ws = self.is_ws(ctx.getChild(0))
    if has_leading_ws and child_count == 1:
        # Only whitespaces => this is not allowed.
        raise HydraException("Trying to parse a primitive that is all whitespaces")
    has_trailing_ws = self.is_ws(ctx.getChild(-1))
    start = 1 if has_leading_ws else 0
    stop = child_count - 1 if has_trailing_ws else child_count

    if stop - start > 1:
        # Several tokens: concatenate them, un-escaping as needed.
        pieces = []
        for position, child in enumerate(ctx.getChildren()):
            token = child.symbol
            on_edge = position < start or position >= stop
            if on_edge and token.type == OverrideLexer.WS:
                # Skip leading / trailing whitespaces.
                continue
            if token.type == OverrideLexer.ESC:
                # An ESC token is a run of (backslash, char) pairs: keeping
                # every other character drops the backslashes.
                pieces.append(token.text[1::2])
            else:
                pieces.append(token.text)
        return "".join(pieces)

    # A single token: convert it based on its lexer type.
    node = ctx.getChild(start)
    kind = node.symbol.type

    if kind == OverrideLexer.QUOTED_VALUE:
        raw = node.getText()
        quote_char = raw[0]
        inner = raw[1:-1]  # strip the surrounding quotes
        if quote_char == "'":
            quote = Quote.single
            inner = inner.replace("\\'", "'")
        elif quote_char == '"':
            quote = Quote.double
            inner = inner.replace('\\"', '"')
        else:
            assert False
        return QuotedString(text=inner, quote=quote)

    if kind in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
        return node.symbol.text

    if kind == OverrideLexer.INT:
        return int(node.symbol.text)

    if kind == OverrideLexer.FLOAT:
        return float(node.symbol.text)

    if kind == OverrideLexer.NULL:
        return None

    if kind == OverrideLexer.BOOL:
        lowered = node.getText().lower()
        if lowered == "true":
            flag = True
        elif lowered == "false":
            flag = False
        else:
            assert False
        return flag

    if kind == OverrideLexer.ESC:
        return node.symbol.text[1::2]

    # Any other token type is returned as its raw text.
    return node.getText()  # type: ignore


class HydraErrorListener(ErrorListener): # type: ignore
def syntaxError(
Expand Down
31 changes: 20 additions & 11 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from omegaconf import OmegaConf
from omegaconf._utils import is_structured_config

from hydra._internal.grammar.utils import escape_special_characters
from hydra.core.config_loader import ConfigLoader
from hydra.core.object_type import ObjectType
from hydra.errors import HydraException
Expand All @@ -20,7 +21,7 @@ class Quote(Enum):
double = 1


@dataclass
@dataclass(frozen=True)
class QuotedString:
text: str

Expand Down Expand Up @@ -258,7 +259,13 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]:
if isinstance(value, list):
return [Override._convert_value(x) for x in value]
elif isinstance(value, dict):
return {k: Override._convert_value(v) for k, v in value.items()}

return {
omry marked this conversation as resolved.
Show resolved Hide resolved
# We ignore potential type mismatch here so as to let OmegaConf
# raise an explicit error in case of invalid type.
Override._convert_value(k): Override._convert_value(v) # type: ignore
for k, v in value.items()
}
elif isinstance(value, QuotedString):
return value.text
else:
Expand Down Expand Up @@ -411,17 +418,19 @@ def _get_value_element_as_str(
)
return "[" + s + "]"
elif isinstance(value, dict):
s = comma.join(
[
f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
for k, v in value.items()
]
)
return "{" + s + "}"
elif isinstance(value, (str, int, bool, float)):
str_items = []
for k, v in value.items():
str_key = Override._get_value_element_as_str(k)
str_value = Override._get_value_element_as_str(
v, space_after_sep=space_after_sep
)
str_items.append(f"{str_key}{colon}{str_value}")
return "{" + comma.join(str_items) + "}"
elif isinstance(value, str):
return escape_special_characters(value)
elif isinstance(value, (int, bool, float)):
return str(value)
elif is_structured_config(value):
print(value)
return Override._get_value_element_as_str(
OmegaConf.to_container(OmegaConf.structured(value))
)
Expand Down
2 changes: 2 additions & 0 deletions hydra/grammar/OverrideLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ NULL: [Nn][Uu][Ll][Ll];

UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings
ID: (CHAR|'_') (CHAR|DIGIT|'_')*;
// Note: when adding more characters to the ESC rule below, also add them to
// the `_ESC` string in `_internal/grammar/utils.py`.
ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
'\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
WS: [ \t]+;
Expand Down
23 changes: 18 additions & 5 deletions hydra/grammar/OverrideParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ value: element | simpleChoiceSweep;

element:
primitive
| listValue
| dictValue
| listContainer
| dictContainer
| function
;

Expand All @@ -47,12 +47,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE;

// Data structures.

listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
(element(COMMA element)*)?
BRACKET_CLOSE;

dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: ID COLON element;
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: dictKey COLON element;

// Primitive types.

Expand All @@ -69,3 +69,16 @@ primitive:
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;

// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed.
// Used as the key part of `dictKeyValuePair` (`dictKey COLON element`).
// A key is either a single quoted string, or a non-empty sequence of the
// unquoted tokens listed in the second alternative.
// NOTE(review): presumably `COLON` is excluded because it separates the key
// from its value - confirm against the `primitive` rule.
dictKey:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;
2 changes: 0 additions & 2 deletions tests/test_config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,14 +1152,12 @@ def test_apply_overrides_to_config(
id="default_change",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["optimizer={type:nesterov2,lr:1}"],
{"optimizer": {"type": "nesterov2", "lr": 1}},
id="dict_merge",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["+optimizer={foo:10}"],
{"optimizer": {"type": "nesterov", "lr": 0.001, "foo": 10}},
Expand Down
Loading