Skip to content

Commit

Permalink
[inter] Mechanism to use the new interpolation parser when resolving …
Browse files Browse the repository at this point in the history
…interpolations

Fixes omry#100 and omry#318
  • Loading branch information
odelalleau committed Sep 16, 2020
1 parent b8208ba commit 5cfd674
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 103 deletions.
56 changes: 26 additions & 30 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,32 +309,37 @@ class ValueKind(Enum):
VALUE = 0
MANDATORY_MISSING = 1
INTERPOLATION = 2
STR_INTERPOLATION = 3


def get_value_kind(value: Any, return_match_list: bool = False) -> Any:
def get_value_kind(
value: Any, return_parse_tree: bool = False
) -> Union[
ValueKind,
Tuple[ValueKind, Optional[OmegaConfGrammarParser.ConfigValueContext]],
]:
"""
Determine the kind of a value
Examples:
MANDATORY_MISSING : "???
VALUE : "10", "20", True,
INTERPOLATION: "${foo}", "${foo.bar}"
STR_INTERPOLATION: "ftp://${host}/path"
:param value: input string to classify
:param return_match_list: True to return the match list as well
:return: ValueKind
VALUE : "10", "20", True
MANDATORY_MISSING : "???"
INTERPOLATION: "${foo.bar}", "${foo.${bar}}", "${foo:bar}", "[${foo}, ${bar}]",
"ftp://${host}/path", "${foo:${bar}, [true], {'baz': ${baz}}}"
:param value: Input to classify.
:param return_parse_tree: Whether to also return the interpolation parse tree.
:return: ValueKind (and optionally the associated interpolation parse tree).
"""

key_prefix = r"\${(\w+:)?"
legal_characters = r"([\w\.%_ \\/:,-]*?)}"
match_list: Optional[List[Match[str]]] = None
parse_tree: Optional[OmegaConfGrammarParser.ConfigValueContext] = None

def ret(
value_kind: ValueKind,
) -> Union[ValueKind, Tuple[ValueKind, Optional[List[Match[str]]]]]:
if return_match_list:
return value_kind, match_list
) -> Union[
ValueKind,
Tuple[ValueKind, Optional[OmegaConfGrammarParser.ConfigValueContext]],
]:
if return_parse_tree:
assert value_kind != ValueKind.INTERPOLATION or parse_tree is not None
return value_kind, parse_tree
else:
return value_kind

Expand All @@ -348,22 +353,13 @@ def ret(
if value == "???":
return ret(ValueKind.MANDATORY_MISSING)

if not isinstance(value, str):
return ret(ValueKind.VALUE)

match_list = list(re.finditer(key_prefix + legal_characters, value))
if len(match_list) == 0:
if not isinstance(value, str) or "${" not in value:
return ret(ValueKind.VALUE)

if len(match_list) == 1 and value == match_list[0].group(0):
return ret(ValueKind.INTERPOLATION)
else:
return ret(ValueKind.STR_INTERPOLATION)

if return_parse_tree:
parse_tree = grammar_parser.parse(value)

def is_bool(st: str) -> bool:
st = str.lower(st)
return st == "true" or st == "false"
return ret(ValueKind.INTERPOLATION)


def is_float(st: str) -> bool:
Expand Down
139 changes: 68 additions & 71 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,42 +111,19 @@ def _get_full_key(self, key: Union[str, Enum, int, None]) -> str:
def _dereference_node(
self, throw_on_missing: bool = False, throw_on_resolution_failure: bool = True
) -> Optional["Node"]:
from .nodes import StringNode

if self._is_interpolation():
value_kind, match_list = get_value_kind(
value=self._value(), return_match_list=True
)
match = match_list[0]
parent = self._get_parent()
assert parent is not None
key = self._key()
if value_kind == ValueKind.INTERPOLATION:
assert parent is not None
v = parent._resolve_simple_interpolation(
key=key,
inter_type=match.group(1),
inter_key=match.group(2),
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
return v
elif value_kind == ValueKind.STR_INTERPOLATION:
assert parent is not None
ret = parent._resolve_interpolation(
key=key,
value=self,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
if ret is None:
return ret
return StringNode(
value=ret,
key=key,
parent=parent,
is_optional=self._metadata.optional,
)
assert False
rval = parent.resolve_interpolation(
parent=parent,
key=key,
value=self,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
assert rval is None or isinstance(rval, Node)
return rval
else:
# not interpolation, compare directly
if throw_on_missing:
Expand Down Expand Up @@ -324,6 +301,51 @@ def _select_impl(
)
return root, last_key, value

def _resolve_complex_interpolation(
self,
parent: Optional["Container"],
value: "Node",
key: Any,
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
"""
A "complex" interpolation is any interpolation that cannot be handled by
`resolve_simple_interpolation()`, i.e. that either contains nested
interpolations or is not a single "${..}" block.
"""

from .nodes import StringNode

value_str = value._value()
assert isinstance(value_str, str)

visitor = GrammarVisitor(
container=self,
resolve_args=dict(
key=key,
parent=parent,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
),
)

resolved = visitor.visit(parse_tree)
if resolved is None:
return None
elif isinstance(resolved, str):
# Result is a string: create a new node to store it.
return StringNode(
value=resolved,
key=key,
parent=parent,
is_optional=value._metadata.optional,
)
else:
assert isinstance(resolved, Node)
return resolved

def resolve_simple_interpolation(
self,
key: Any,
Expand Down Expand Up @@ -379,53 +401,28 @@ def resolve_simple_interpolation(
else:
return None

def _resolve_interpolation(
def resolve_interpolation(
self,
parent: Optional["Container"],
key: Any,
value: "Node",
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Any:
from .nodes import StringNode
value_kind, parse_tree = get_value_kind(value=value, return_parse_tree=True) # type: ignore

value_kind, match_list = get_value_kind(value=value, return_match_list=True)
if value_kind not in (ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION):
if value_kind != ValueKind.INTERPOLATION:
return value

if value_kind == ValueKind.INTERPOLATION:
# simple interpolation, inherit type
match = match_list[0]
return self._resolve_simple_interpolation(
key=key,
inter_type=match.group(1),
inter_key=match.group(2),
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
elif value_kind == ValueKind.STR_INTERPOLATION:
value = _get_value(value)
assert isinstance(value, str)
orig = value
new = ""
last_index = 0
for match in match_list:
new_val = self._resolve_simple_interpolation(
key=key,
inter_type=match.group(1),
inter_key=match.group(2),
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
# if failed to resolve, return None for the whole thing.
if new_val is None:
return None
new += orig[last_index : match.start(0)] + str(new_val)
last_index = match.end(0)

new += orig[last_index:]
return StringNode(value=new, key=key)
else:
assert False
assert parse_tree is not None
return self._resolve_complex_interpolation(
parent=parent,
value=value,
key=key,
parse_tree=parse_tree,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

def _re_parent(self) -> None:
from .dictconfig import DictConfig
Expand Down
4 changes: 2 additions & 2 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def _resolve_with_default(
def is_mandatory_missing(val: Any) -> bool:
return get_value_kind(val) == ValueKind.MANDATORY_MISSING # type: ignore

value = _get_value(value)
val = _get_value(value)
has_default = default_value is not DEFAULT_VALUE_MARKER
if has_default and (value is None or is_mandatory_missing(value)):
if has_default and (val is None or is_mandatory_missing(val)):
return default_value

resolved = self.resolve_interpolation(
Expand Down

0 comments on commit 5cfd674

Please sign in to comment.