diff --git a/jmespath/ast.py b/jmespath/ast.py index dd56c6e..9d8fb0d 100644 --- a/jmespath/ast.py +++ b/jmespath/ast.py @@ -2,6 +2,10 @@ # {"type": ", children: [], "value": ""} +def assign(name, expr): + return {'type': 'assign', 'children': [expr], 'value': name} + + def comparator(name, first, second): return {'type': 'comparator', 'children': [first, second], 'value': name} @@ -46,6 +50,10 @@ def key_val_pair(key_name, node): return {"type": "key_val_pair", 'children': [node], "value": key_name} +def let_expression(bindings, expr): + return {'type': 'let_expression', 'children': [*bindings, expr]} + + def literal(literal_value): return {'type': 'literal', 'value': literal_value, 'children': []} @@ -88,3 +96,7 @@ def slice(start, end, step): def value_projection(left, right): return {'type': 'value_projection', 'children': [left, right]} + + +def variable_ref(name): + return {"type": "variable_ref", "children": [], "value": name} diff --git a/jmespath/exceptions.py b/jmespath/exceptions.py index 0156015..42c7eb4 100644 --- a/jmespath/exceptions.py +++ b/jmespath/exceptions.py @@ -120,3 +120,9 @@ def __init__(self): class UnknownFunctionError(JMESPathError): pass + + +class UndefinedVariable(JMESPathError): + def __init__(self, varname): + self.varname = varname + super().__init__(f"Reference to undefined variable: {self.varname}") diff --git a/jmespath/lexer.py b/jmespath/lexer.py index 8db05e3..e98e53c 100644 --- a/jmespath/lexer.py +++ b/jmespath/lexer.py @@ -87,22 +87,9 @@ def tokenize(self, expression): elif self._current == '!': yield self._match_or_else('=', 'ne', 'not') elif self._current == '=': - if self._next() == '=': - yield {'type': 'eq', 'value': '==', - 'start': self._position - 1, 'end': self._position} - self._next() - else: - if self._current is None: - # If we're at the EOF, we never advanced - # the position so we don't need to rewind - # it back one location. - position = self._position - else: - position = self._position - 1 - raise LexerError( - lexer_position=position, - lexer_value='=', - message="Unknown token '='") + yield self._match_or_else('=', 'eq', 'assign') + elif self._current == '$': + yield self._consume_variable() else: raise LexerError(lexer_position=self._position, lexer_value=self._current, @@ -117,6 +104,21 @@ def _consume_number(self): buff += self._current return buff + def _consume_variable(self): + start = self._position + buff = self._current + self._next() + if self._current not in self.START_IDENTIFIER: + raise LexerError( + lexer_position=start, + lexer_value=self._current, + message='Invalid variable starting character %s' % self._current) + buff += self._current + while self._next() in self.VALID_IDENTIFIER: + buff += self._current + return {'type': 'variable', 'value': buff, + 'start': start, 'end': start + len(buff)} + def _initialize_for_expression(self, expression): if not expression: raise EmptyExpressionError() diff --git a/jmespath/parser.py b/jmespath/parser.py index 4706688..a08e5fc 100644 --- a/jmespath/parser.py +++ b/jmespath/parser.py @@ -37,6 +37,8 @@ class Parser(object): BINDING_POWER = { 'eof': 0, + 'variable': 0, + 'assign': 0, 'unquoted_identifier': 0, 'quoted_identifier': 0, 'literal': 0, @@ -137,8 +139,39 @@ def _expression(self, binding_power=0): def _token_nud_literal(self, token): return ast.literal(token['value']) + def _token_nud_variable(self, token): + return ast.variable_ref(token['value'][1:]) + def _token_nud_unquoted_identifier(self, token): - return ast.field(token['value']) + if token['value'] == 'let' and \ + self._current_token() == 'variable': + return self._parse_let_expression() + else: + return ast.field(token['value']) + + def _parse_let_expression(self): + bindings = [] + while True: + var_token = self._lookahead_token(0) + # Strip off the '$'. + varname = var_token['value'][1:] + self._advance() + self._match('assign') + assign_expr = self._expression() + bindings.append(ast.assign(varname, assign_expr)) + if self._is_in_keyword(self._lookahead_token(0)): + self._advance() + break + else: + self._match('comma') + expr = self._expression() + return ast.let_expression(bindings, expr) + + def _is_in_keyword(self, token): + return ( + token['type'] == 'unquoted_identifier' and + token['value'] == 'in' + ) def _token_nud_quoted_identifier(self, token): field = ast.field(token['value']) diff --git a/jmespath/scope.py b/jmespath/scope.py new file mode 100644 index 0000000..34f5d3d --- /dev/null +++ b/jmespath/scope.py @@ -0,0 +1,40 @@ +from collections import deque + + +class ScopedChainDict: + """Dictionary that can delegate lookups to multiple dicts. + + This provides a basic get/set dict interface that is + backed by multiple dicts. Each dict is searched from + the top most (most recently pushed) scope dict until + a match is found. + + """ + def __init__(self, *scopes): + # The scopes are evaluated starting at the top of the stack (the most + # recently pushed scope via .push_scope()). If we use a normal list() + # and push/pop scopes by adding/removing to the end of the list, we'd + # have to always call reversed(self._scopes) whenever we resolve a key, + # because the end of the list is the top of the stack. + # To avoid this, we're using a deque so we can append to the front of + # the list via .appendleft() in constant time, and iterate over scopes + # without having to do so with a reversed() call each time. + self._scopes = deque(scopes) + + def __getitem__(self, key): + for scope in self._scopes: + if key in scope: + return scope[key] + raise KeyError(key) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def push_scope(self, scope): + self._scopes.appendleft(scope) + + def pop_scope(self): + self._scopes.popleft() diff --git a/jmespath/visitor.py b/jmespath/visitor.py index 15fb177..b2ea839 100644 --- a/jmespath/visitor.py +++ b/jmespath/visitor.py @@ -2,6 +2,8 @@ from jmespath import functions from jmespath.compat import string_type +from jmespath.scope import ScopedChainDict +from jmespath import exceptions from numbers import Number @@ -121,6 +123,7 @@ def __init__(self, options=None): self._functions = self._options.custom_functions else: self._functions = functions.Functions() + self._scope = ScopedChainDict() def default_visit(self, node, *args, **kwargs): raise NotImplementedError(node['type']) @@ -280,6 +283,27 @@ def visit_projection(self, node, value): collected.append(current) return collected + def visit_let_expression(self, node, value): + *bindings, expr = node['children'] + scope = {} + for assign in bindings: + scope.update(self.visit(assign, value)) + self._scope.push_scope(scope) + result = self.visit(expr, value) + self._scope.pop_scope() + return result + + def visit_assign(self, node, value): + name = node['value'] + value = self.visit(node['children'][0], value) + return {name: value} + + def visit_variable_ref(self, node, value): + try: + return self._scope[node['value']] + except KeyError: + raise exceptions.UndefinedVariable(node['value']) + def visit_value_projection(self, node, value): base = self.visit(node['children'][0], value) try: diff --git a/tests/compliance/letexpr.json b/tests/compliance/letexpr.json new file mode 100644 index 0000000..9347e82 --- /dev/null +++ b/tests/compliance/letexpr.json @@ -0,0 +1,247 @@ +[ + { + "given": { + "foo": { + "bar": "baz" + } + }, + "cases": [ + { + "expression": "let $foo = foo in $foo", + "result": { + "bar": "baz" + } + }, + { + "expression": "let $foo = foo.bar in $foo", + "result": "baz" + }, + { + "expression": "let $foo = foo.bar in [$foo, $foo]", + "result": [ + "baz", + "baz" + ] + }, + { + "command": "Multiple assignments", + "expression": "let $foo = 'foo', $bar = 'bar' in [$foo, $bar]", + "result": [ + "foo", + "bar" + ] + } + ] + }, + { + "given": { + "a": "topval", + "b": [ + { + "a": "inner1" + }, + { + "a": "inner2" + } + ] + }, + "cases": [ + { + "expression": "let $a = a in b[*].[a, $a, let $a = 'shadow' in $a]", + "result": [ + [ + "inner1", + "topval", + "shadow" + ], + [ + "inner2", + "topval", + "shadow" + ] + ] + }, + { + "comment": "Bindings only visible within expression clause", + "expression": "let $a = 'top-a' in let $a = 'in-a', $b = $a in $b", + "result": "top-a" + } + ] + }, + { + "given": { + "foo": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ], + [ + 4, + 5 + ] + ] + }, + "cases": [ + { + "comment": "Projection is stopped when bound to variable", + "expression": "let $foo = foo[*] in $foo[0]", + "result": [ + 0, + 1 + ] + } + ] + }, + { + "given": [ + { + "home_state": "WA", + "states": [ + { + "name": "WA", + "cities": [ + "Seattle", + "Bellevue", + "Olympia" + ] + }, + { + "name": "CA", + "cities": [ + "Los Angeles", + "San Francisco" + ] + }, + { + "name": "NY", + "cities": [ + "New York City", + "Albany" + ] + } + ] + }, + { + "home_state": "NY", + "states": [ + { + "name": "WA", + "cities": [ + "Seattle", + "Bellevue", + "Olympia" + ] + }, + { + "name": "CA", + "cities": [ + "Los Angeles", + "San Francisco" + ] + }, + { + "name": "NY", + "cities": [ + "New York City", + "Albany" + ] + } + ] + } + ], + "cases": [ + { + "expression": "[*].[let $home_state = home_state in states[? name == $home_state].cities[]][]", + "result": [ + [ + "Seattle", + "Bellevue", + "Olympia" + ], + [ + "New York City", + "Albany" + ] + ] + } + ] + }, + { + "given": { + "imageDetails": [ + { + "repositoryName": "org/first-repo", + "imageTags": [ + "latest", + "v1.0", + "v1.2" + ], + "imageDigest": "sha256:abcd" + }, + { + "repositoryName": "org/second-repo", + "imageTags": [ + "v2.0", + "v2.2" + ], + "imageDigest": "sha256:efgh" + } + ] + }, + "cases": [ + { + "expression": "imageDetails[].[\n let $repo = repositoryName,\n $digest = imageDigest\n in\n imageTags[].[@, $digest, $repo]\n][][]\n", + "result": [ + [ + "latest", + "sha256:abcd", + "org/first-repo" + ], + [ + "v1.0", + "sha256:abcd", + "org/first-repo" + ], + [ + "v1.2", + "sha256:abcd", + "org/first-repo" + ], + [ + "v2.0", + "sha256:efgh", + "org/second-repo" + ], + [ + "v2.2", + "sha256:efgh", + "org/second-repo" + ] + ] + } + ] + }, + { + "given": {}, + "cases": [ + { + "expression": "$noexist", + "error": "undefined-variable" + }, + { + "comment": "Reference out of scope variable", + "expression": "[let $scope = 'foo' in [$scope], $scope]", + "error": "undefined-variable" + }, + { + "comment": "Can't use var ref in RHS of subexpression", + "expression": "foo.$bar", + "error": "syntax" + } + ] + } +] diff --git a/tests/test_compliance.py b/tests/test_compliance.py index cde6acb..a352e62 100644 --- a/tests/test_compliance.py +++ b/tests/test_compliance.py @@ -93,7 +93,7 @@ def test_expression(given, expression, expected, filename): ) def test_error_expression(given, expression, error, filename): import jmespath.parser - if error not in ('syntax', 'invalid-type', + if error not in ('syntax', 'invalid-type', 'undefined-variable', 'unknown-function', 'invalid-arity', 'invalid-value'): raise RuntimeError("Unknown error type '%s'" % error) try: