Skip to content

Commit

Permalink
Implement lexical scoping with let expressions
Browse files Browse the repository at this point in the history
See jmespath/jmespath.jep#18 for more details.
  • Loading branch information
jamesls committed Mar 23, 2023
1 parent bbe7300 commit 6f469c2
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 18 deletions.
12 changes: 12 additions & 0 deletions jmespath/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# {"type": <node type>", children: [], "value": ""}


def assign(name, expr):
return {'type': 'assign', 'children': [expr], 'value': name}


def comparator(name, first, second):
return {'type': 'comparator', 'children': [first, second], 'value': name}

Expand Down Expand Up @@ -46,6 +50,10 @@ def key_val_pair(key_name, node):
return {"type": "key_val_pair", 'children': [node], "value": key_name}


def let_expression(bindings, expr):
return {'type': 'let_expression', 'children': [*bindings, expr]}


def literal(literal_value):
return {'type': 'literal', 'value': literal_value, 'children': []}

Expand Down Expand Up @@ -88,3 +96,7 @@ def slice(start, end, step):

def value_projection(left, right):
return {'type': 'value_projection', 'children': [left, right]}


def variable_ref(name):
return {"type": "variable_ref", "children": [], "value": name}
6 changes: 6 additions & 0 deletions jmespath/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,9 @@ def __init__(self):

class UnknownFunctionError(JMESPathError):
pass


class UndefinedVariable(JMESPathError):
def __init__(self, varname):
self.varname = varname
super().__init__(f"Reference to undefined variable: {self.varname}")
34 changes: 18 additions & 16 deletions jmespath/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,9 @@ def tokenize(self, expression):
elif self._current == '!':
yield self._match_or_else('=', 'ne', 'not')
elif self._current == '=':
if self._next() == '=':
yield {'type': 'eq', 'value': '==',
'start': self._position - 1, 'end': self._position}
self._next()
else:
if self._current is None:
# If we're at the EOF, we never advanced
# the position so we don't need to rewind
# it back one location.
position = self._position
else:
position = self._position - 1
raise LexerError(
lexer_position=position,
lexer_value='=',
message="Unknown token '='")
yield self._match_or_else('=', 'eq', 'assign')
elif self._current == '$':
yield self._consume_variable()
else:
raise LexerError(lexer_position=self._position,
lexer_value=self._current,
Expand All @@ -117,6 +104,21 @@ def _consume_number(self):
buff += self._current
return buff

def _consume_variable(self):
start = self._position
buff = self._current
self._next()
if self._current not in self.START_IDENTIFIER:
raise LexerError(
lexer_position=start,
lexer_value=self._current,
message='Invalid variable starting character %s' % self._current)
buff += self._current
while self._next() in self.VALID_IDENTIFIER:
buff += self._current
return {'type': 'variable', 'value': buff,
'start': start, 'end': start + len(buff)}

def _initialize_for_expression(self, expression):
if not expression:
raise EmptyExpressionError()
Expand Down
35 changes: 34 additions & 1 deletion jmespath/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
class Parser(object):
BINDING_POWER = {
'eof': 0,
'variable': 0,
'assign': 0,
'unquoted_identifier': 0,
'quoted_identifier': 0,
'literal': 0,
Expand Down Expand Up @@ -137,8 +139,39 @@ def _expression(self, binding_power=0):
def _token_nud_literal(self, token):
return ast.literal(token['value'])

def _token_nud_variable(self, token):
return ast.variable_ref(token['value'][1:])

def _token_nud_unquoted_identifier(self, token):
return ast.field(token['value'])
if token['value'] == 'let' and \
self._current_token() == 'variable':
return self._parse_let_expression()
else:
return ast.field(token['value'])

def _parse_let_expression(self):
bindings = []
while True:
var_token = self._lookahead_token(0)
# Strip off the '$'.
varname = var_token['value'][1:]
self._advance()
self._match('assign')
assign_expr = self._expression()
bindings.append(ast.assign(varname, assign_expr))
if self._is_in_keyword(self._lookahead_token(0)):
self._advance()
break
else:
self._match('comma')
expr = self._expression()
return ast.let_expression(bindings, expr)

def _is_in_keyword(self, token):
return (
token['type'] == 'unquoted_identifier' and
token['value'] == 'in'
)

def _token_nud_quoted_identifier(self, token):
field = ast.field(token['value'])
Expand Down
40 changes: 40 additions & 0 deletions jmespath/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections import deque


class ScopedChainDict:
"""Dictionary that can delegate lookups to multiple dicts.
This provides a basic get/set dict interface that is
backed by multiple dicts. Each dict is searched from
the top most (most recently pushed) scope dict until
a match is found.
"""
def __init__(self, *scopes):
# The scopes are evaluated starting at the top of the stack (the most
# recently pushed scope via .push_scope()). If we use a normal list()
# and push/pop scopes by adding/removing to the end of the list, we'd
# have to always call reversed(self._scopes) whenever we resolve a key,
# because the end of the list is the top of the stack.
# To avoid this, we're using a deque so we can append to the front of
# the list via .appendleft() in constant time, and iterate over scopes
# without having to do so with a reversed() call each time.
self._scopes = deque(scopes)

def __getitem__(self, key):
for scope in self._scopes:
if key in scope:
return scope[key]
raise KeyError(key)

def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default

def push_scope(self, scope):
self._scopes.appendleft(scope)

def pop_scope(self):
self._scopes.popleft()
24 changes: 24 additions & 0 deletions jmespath/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from jmespath import functions
from jmespath.compat import string_type
from jmespath.scope import ScopedChainDict
from jmespath import exceptions
from numbers import Number


Expand Down Expand Up @@ -121,6 +123,7 @@ def __init__(self, options=None):
self._functions = self._options.custom_functions
else:
self._functions = functions.Functions()
self._scope = ScopedChainDict()

def default_visit(self, node, *args, **kwargs):
raise NotImplementedError(node['type'])
Expand Down Expand Up @@ -280,6 +283,27 @@ def visit_projection(self, node, value):
collected.append(current)
return collected

def visit_let_expression(self, node, value):
*bindings, expr = node['children']
scope = {}
for assign in bindings:
scope.update(self.visit(assign, value))
self._scope.push_scope(scope)
result = self.visit(expr, value)
self._scope.pop_scope()
return result

def visit_assign(self, node, value):
name = node['value']
value = self.visit(node['children'][0], value)
return {name: value}

def visit_variable_ref(self, node, value):
try:
return self._scope[node['value']]
except KeyError:
raise exceptions.UndefinedVariable(node['value'])

def visit_value_projection(self, node, value):
base = self.visit(node['children'][0], value)
try:
Expand Down
Loading

0 comments on commit 6f469c2

Please sign in to comment.