Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more types of dictionary keys in overrides grammar #1208

Merged
merged 16 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def is_ws(self, c: Any) -> bool:

def visitPrimitive(
self, ctx: OverrideParser.PrimitiveContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
return self.visitPrimitiveOrDictKey(ctx)

def visitPrimitiveOrDictKey(
self, ctx: Union[OverrideParser.PrimitiveContext, OverrideParser.DictKeyContext]
omry marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[Union[QuotedString, int, bool, float, str]]:
ret: Optional[Union[int, bool, float, str]]
first_idx = 0
Expand Down Expand Up @@ -144,8 +149,8 @@ def visitPrimitive(
return node.getText() # type: ignore
return ret

def visitListValue(
self, ctx: OverrideParser.ListValueContext
def visitListContainer(
self, ctx: OverrideParser.ListContainerContext
) -> List[ParsedElementType]:
ret: List[ParsedElementType] = []

Expand All @@ -159,22 +164,25 @@ def visitListValue(
ret.append(self.visitElement(element))
return ret

def visitDictValue(
self, ctx: OverrideParser.DictValueContext
def visitDictContainer(
self, ctx: OverrideParser.DictContainerContext
) -> Dict[str, ParsedElementType]:
assert self.is_matching_terminal(ctx.getChild(0), OverrideLexer.BRACE_OPEN)
return dict(
self.visitDictKeyValuePair(ctx.getChild(i))
for i in range(1, ctx.getChildCount() - 1, 2)
)

def visitDictKey(self, ctx: OverrideParser.DictKeyContext) -> Any:
return self.visitPrimitiveOrDictKey(ctx)

def visitDictKeyValuePair(
self, ctx: OverrideParser.DictKeyValuePairContext
) -> Tuple[str, ParsedElementType]:
children = ctx.getChildren()
item = next(children)
assert self.is_matching_terminal(item, OverrideLexer.ID)
pkey = item.getText()
assert isinstance(item, OverrideParser.DictKeyContext)
pkey = self.visitDictKey(item)
assert self.is_matching_terminal(next(children), OverrideLexer.COLON)
value = next(children)
assert isinstance(value, OverrideParser.ElementContext)
Expand All @@ -186,10 +194,10 @@ def visitElement(self, ctx: OverrideParser.ElementContext) -> ParsedElementType:
return self.visitFunction(ctx.function()) # type: ignore
elif ctx.primitive():
return self.visitPrimitive(ctx.primitive())
elif ctx.listValue():
return self.visitListValue(ctx.listValue())
elif ctx.dictValue():
return self.visitDictValue(ctx.dictValue())
elif ctx.listContainer():
return self.visitListContainer(ctx.listContainer())
elif ctx.dictContainer():
return self.visitDictContainer(ctx.dictContainer())
else:
assert False

Expand Down
24 changes: 19 additions & 5 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class QuotedString:

quote: Quote

def __hash__(self) -> int:
return hash(self.text)

def __eq__(self, other: Any) -> Any:
# We do not care whether quotes match for equality.
if isinstance(other, QuotedString):
return self.text == other.text
else:
return NotImplemented

omry marked this conversation as resolved.
Show resolved Hide resolved
def with_quotes(self) -> str:
if self.quote == Quote.single:
q = "'"
Expand Down Expand Up @@ -142,9 +152,9 @@ def __eq__(self, other: Any) -> Any:
return NotImplemented


# Ideally we would use List[ElementType] and Dict[str, ElementType] but Python does not seem
# to support recursive type definitions.
ElementType = Union[str, int, float, bool, List[Any], Dict[str, Any]]
# Ideally we would use List[ElementType] and Dict[ElementType, ElementType] but Python
omry marked this conversation as resolved.
Show resolved Hide resolved
# does not seem to support recursive type definitions.
ElementType = Union[str, int, float, bool, List[Any], Dict[Any, Any]]
ParsedElementType = Optional[Union[ElementType, QuotedString]]
TransformerType = Callable[[ParsedElementType], Any]

Expand Down Expand Up @@ -258,7 +268,10 @@ def _convert_value(value: ParsedElementType) -> Optional[ElementType]:
if isinstance(value, list):
return [Override._convert_value(x) for x in value]
elif isinstance(value, dict):
return {k: Override._convert_value(v) for k, v in value.items()}
return {
omry marked this conversation as resolved.
Show resolved Hide resolved
Override._convert_value(k): Override._convert_value(v)
for k, v in value.items()
}
elif isinstance(value, QuotedString):
return value.text
else:
Expand Down Expand Up @@ -413,7 +426,8 @@ def _get_value_element_as_str(
elif isinstance(value, dict):
s = comma.join(
[
f"{k}{colon}{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
f"{Override._get_value_element_as_str(k)}{colon}"
f"{Override._get_value_element_as_str(v, space_after_sep=space_after_sep)}"
for k, v in value.items()
]
omry marked this conversation as resolved.
Show resolved Hide resolved
)
Expand Down
23 changes: 18 additions & 5 deletions hydra/grammar/OverrideParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ value: element | simpleChoiceSweep;

element:
primitive
| listValue
| dictValue
| listContainer
| dictContainer
| function
;

Expand All @@ -47,12 +47,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE;

// Data structures.

listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
(element(COMMA element)*)?
BRACKET_CLOSE;

dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: ID COLON element;
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: dictKey COLON element;

// Primitive types.

Expand All @@ -69,3 +69,16 @@ primitive:
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;

// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed.
dictKey:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;
34 changes: 32 additions & 2 deletions tests/test_config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,38 @@ def test_sweep_config_cache(
monkeypatch.setenv("HOME", "/another/home/dir/")
assert sweep_cfg.home == os.getenv("HOME")

@pytest.mark.parametrize( # type: ignore
"key,expected",
[
pytest.param("id123", "id123", id="id"),
pytest.param("123id", "123id", id="int_plus_id"),
pytest.param("'quoted_single'", "quoted_single", id="quoted_single"),
pytest.param('"quoted_double"', "quoted_double", id="quoted_double"),
pytest.param("'quoted_$(){}[]'", "quoted_$(){}[]", id="quoted_misc_chars"),
pytest.param("a/-\\+.$%*@", "a/-\\+.$%*@", id="unquoted_misc_chars"),
pytest.param("white space", "white space", id="whitespace"),
pytest.param(
"\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,",
"\\()[]{}:= \t,",
id="unquoted_esc",
),
],
)
def test_dict_key_formats(
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
self, hydra_restore_singletons: Any, path: str, key: str, expected: str
) -> None:
"""Test that we can assign dictionaries with keys that are not just IDs"""
config_loader = ConfigLoaderImpl(
config_search_path=create_config_search_path(path)
)
cfg = config_loader.load_configuration(
config_name="config.yaml",
overrides=[f"+dict={{{key}: 123}}"],
run_mode=RunMode.RUN,
)
assert "dict" in cfg
assert cfg.dict == {expected: 123}


@pytest.mark.parametrize( # type:ignore
"config_file, overrides",
Expand Down Expand Up @@ -1152,14 +1184,12 @@ def test_apply_overrides_to_config(
id="default_change",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["optimizer={type:nesterov2,lr:1}"],
{"optimizer": {"type": "nesterov2", "lr": 1}},
id="dict_merge",
),
pytest.param(
# need to unset optimizer config group first, otherwise they get merged
"config",
["+optimizer={foo:10}"],
{"optimizer": {"type": "nesterov", "lr": 0.001, "foo": 10}},
Expand Down
23 changes: 23 additions & 0 deletions tests/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,29 @@ def test_run_pass_list(self, cmd_base: List[str], tmpdir: Any) -> None:
ret, _err = get_run_output(cmd)
assert OmegaConf.create(ret) == OmegaConf.create(expected)

def test_multirun_dict_keys(self, cmd_base: List[str], tmpdir: Any) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test and test_overrides_dict_keys are very high level.
You have some low-level tests in test_overrides_parser, and then you jump straight into testing as processes.

If I understand the purpose of those tests correctly (and it's hard because they are so high level), I think they should be replaced by lower-level config composition tests in test_config_loader.py.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, to clarify the motivation: those two new tests are meant to test the changes in types.py (which are not tested by the low level grammar tests)

The second one (test_overrides_dict_keys) should probably indeed be a lower level test. I wasn't familiar with the various Hydra tests and couldn't figure out where the basic override functionalities were being tested. I just moved it to test_config_loader.py as you suggested: see 6bcaac2, let me know if that makes more sense this way. The main thing that needs to be tested is quoted strings as keys, but I thought it wouldn't hurt to add tests for all other kinds of string formatting that we may expect, just in case (also later on, we should also test other types of keys like int / float / bool if we support them).

I'm not sure I can move test_multirun_dict_keys to this same file though. This test specifically tests the change to _get_value_element_as_str() (l. 429 of types.py in the current diff of this PR), which is used in sweeps to provide the proper command line overrides. Again, the main thing that needs testing is quoted strings, but I added the other key formats as well to be safe.
Maybe to clarify the purpose of this test, I can show you how it fails if I revert the l.429 change:

LexerNoViableAltException: +foo={QuotedString(text='null', quote=<Quote.single: 0>):0}

(this is because the sweeper would not replace the QuotedString with its actual quoted string representation in the command line override)

Side note: I also pushed 0b53209 which removes some comments that seemed irrelevant to me (probably a copy/paste leftover from another comment you can see below them)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great.
There is a test in test_overrides_parser that is dedicated to that logic: test_override_get_value_element_method

I think testing there should be sufficient.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think testing there should be sufficient.

Perfect! Done in 8084317

At the same time I uncovered a bug with escaped characters (that affected regular unquoted strings as well, not just dictionary keys), which is now fixed in that same commit.

cmd = cmd_base + [
"+foo={'null': 0},{'NuLl': 1},{123abc: 0},{/-\\+.$%*@: 1},{white space: 3}",
"--multirun",
]
expected = """\
foo:
'null': 0

foo:
NuLl: 1

foo:
123abc: 0

foo:
/-\\+.$%*@: 1

foo:
white space: 3"""
ret, _err = get_run_output(cmd)
assert normalize_newlines(ret) == normalize_newlines(expected)


def test_app_with_error_exception_sanitized(tmpdir: Any, monkeypatch: Any) -> None:
monkeypatch.chdir("tests/test_apps/app_with_runtime_config_error")
Expand Down
38 changes: 34 additions & 4 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def test_value(value: str, expected: Any) -> None:
pytest.param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"),
],
)
def test_list_value(value: str, expected: Any) -> None:
ret = parse_rule(value, "listValue")
def test_list_container(value: str, expected: Any) -> None:
ret = parse_rule(value, "listContainer")
assert ret == expected


Expand Down Expand Up @@ -277,10 +277,40 @@ def test_shuffle_sequence(value: str, expected: Any) -> None:
pytest.param("{a:10,b:20}", {"a": 10, "b": 20}, id="dict"),
pytest.param("{a:10,b:{}}", {"a": 10, "b": {}}, id="dict"),
pytest.param("{a:10,b:{c:[1,2]}}", {"a": 10, "b": {"c": [1, 2]}}, id="dict"),
pytest.param(
"{'0a': 0, \"1b\": 1}",
{
QuotedString(text="0a", quote=Quote.double): 0,
QuotedString(text="1b", quote=Quote.single): 1,
},
id="dict_quoted_key",
),
pytest.param("{null: 1}", {None: 1}, id="dict_null_key"),
pytest.param("{123: 1, 0: 2, -1: 3}", {123: 1, 0: 2, -1: 3}, id="dict_int_key"),
pytest.param("{3.14: 0, 1e3: 1}", {3.14: 0, 1000.0: 1}, id="dict_float_key"),
pytest.param("{true: 1, fAlSe: 0}", {True: 1, False: 0}, id="dict_bool_key"),
pytest.param("{/-\\+.$%*@: 1}", {"/-\\+.$%*@": 1}, id="dict_unquoted_char_key"),
pytest.param(
"{\\\\\\(\\)\\[\\]\\{\\}\\:\\=\\ \\\t\\,: 1}",
{"\\()[]{}:= \t,": 1},
id="dict_esc_key",
),
pytest.param("{white spaces: 1}", {"white spaces": 1}, id="dict_ws_key"),
pytest.param(
"{'a:b': 1, ab 123.5 True: 2, null false: 3, 1: 4, null: 5}",
{
QuotedString(text="a:b", quote=Quote.single): 1,
"ab 123.5 True": 2,
"null false": 3,
1: 4,
None: 5,
},
id="dict_mixed_keys",
),
],
)
def test_dict_value(value: str, expected: Any) -> None:
ret = parse_rule(value, "dictValue")
def test_dict_container(value: str, expected: Any) -> None:
ret = parse_rule(value, "dictContainer")
assert ret == expected


Expand Down
23 changes: 18 additions & 5 deletions website/docs/advanced/override_grammar/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ value: element | simpleChoiceSweep;

element:
primitive
| listValue
| dictValue
| listContainer
| dictContainer
| function
;

Expand All @@ -73,12 +73,12 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE;

// Data structures.

listValue: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
(element(COMMA element)*)?
BRACKET_CLOSE;

dictValue: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: ID COLON element;
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: dictKey COLON element;

// Primitive types.

Expand All @@ -95,6 +95,19 @@ primitive:
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;

// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed.
dictKey:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;
```

## Elements
Expand Down