Speed up new backtracking parser
isidentical committed Dec 26, 2021
1 parent 092959f commit f5a202d
Showing 3 changed files with 169 additions and 5 deletions.
66 changes: 61 additions & 5 deletions src/blib2to3/pgen2/parse.py
@@ -46,6 +46,17 @@ def lam_sub(grammar: Grammar, node: RawNode) -> NL:
    return Node(type=node[0], children=node[3], context=node[2])


# A placeholder node, used when parser is backtracking.
FAKE_NODE = (-1, None, None, None)


def stack_copy(
    stack: List[Tuple[DFAS, int, RawNode]]
) -> List[Tuple[DFAS, int, RawNode]]:
    """Nodeless stack copy."""
    return [(copy.deepcopy(dfa), label, FAKE_NODE) for dfa, label, _ in stack]
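The payoff of FAKE_NODE comes from what stack_copy no longer copies: the
partially built syntax tree hanging off each stack entry. Only the DFA
bookkeeping is still deepcopied. A minimal sketch of the effect, using toy
stand-ins rather than blib2to3's real DFAS/RawNode types (illustration only,
not part of this commit):

import copy
import timeit

FAKE_NODE = (-1, None, None, None)

# Toy stack of (dfa, state, node) triples; the node payloads are deliberately
# nested to mimic a partially built parse tree.
toy_node = (1, "value", None, [(2, "child", None, None)] * 50)
toy_stack = [({"dfa": i}, i, toy_node) for i in range(30)]

full = timeit.timeit(lambda: copy.deepcopy(toy_stack), number=200)
nodeless = timeit.timeit(
    lambda: [(copy.deepcopy(dfa), label, FAKE_NODE) for dfa, label, _ in toy_stack],
    number=200,
)
print(f"full deepcopy: {full:.3f}s, nodeless copy: {nodeless:.3f}s")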


class Recorder:
    def __init__(self, parser: "Parser", ilabels: List[int], context: Context) -> None:
        self.parser = parser
@@ -54,21 +65,45 @@ def __init__(self, parser: "Parser", ilabels: List[int], context: Context) -> None:

        self._dead_ilabels: Set[int] = set()
        self._start_point = self.parser.stack
        self._points = {ilabel: stack_copy(self._start_point) for ilabel in ilabels}

    @property
    def ilabels(self) -> Set[int]:
        # _dead_ilabels is always a subset of _ilabels, so the symmetric
        # difference is effectively the set of still-live alternatives.
        return self._dead_ilabels.symmetric_difference(self._ilabels)

    @contextmanager
    def switch_to(self, ilabel: int) -> Iterator[None]:
        with self.patch():
            self.parser.stack = self._points[ilabel]
            try:
                yield
            except ParseError:
                self._dead_ilabels.add(ilabel)
            finally:
                self.parser.stack = self._start_point

    @contextmanager
    def patch(self) -> Iterator[None]:
        """
        Patch basic state operations (push/pop/shift) with node-level
        immutable variants. These still operate on the stack, but they
        won't create any new nodes or modify the contents of any other
        existing nodes.

        This saves us a ton of time when we are backtracking, since we
        want to restore the initial state as quickly as possible, which
        can only be done with as few mutations as possible.
        """
        original_functions = {}
        for name in self.parser.STATE_OPERATIONS:
            original_functions[name] = getattr(self.parser, name)
            safe_variant = getattr(self.parser, name + "_safe")
            setattr(self.parser, name, safe_variant)

        try:
            yield
        finally:
            for name, func in original_functions.items():
                setattr(self.parser, name, func)
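patch() relies on a plain Python trick: assigning to an instance attribute
shadows the method looked up on the class, and the finally block guarantees
the originals come back even when a ParseError aborts the exploration. A
self-contained sketch of the same swapping pattern on a toy class (the names
Toy/op/op_safe are invented for illustration):

from contextlib import contextmanager
from typing import Iterator


class Toy:
    STATE_OPERATIONS = ["op"]

    def op(self) -> str:
        return "mutating"

    def op_safe(self) -> str:
        return "immutable"

    @contextmanager
    def patched(self) -> Iterator[None]:
        # Save the bound originals, shadow them with the *_safe variants,
        # and restore unconditionally on exit.
        originals = {name: getattr(self, name) for name in self.STATE_OPERATIONS}
        for name in self.STATE_OPERATIONS:
            setattr(self, name, getattr(self, name + "_safe"))
        try:
            yield
        finally:
            for name, func in originals.items():
                setattr(self, name, func)


toy = Toy()
with toy.patched():
    assert toy.op() == "immutable"
assert toy.op() == "mutating"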

    def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None:
        func: Callable[..., Any]
@@ -317,6 +352,8 @@ def classify(self, type: int, value: Text, context: Context) -> List[int]:
            raise ParseError("bad token", type, value, context)
        return [ilabel]

    STATE_OPERATIONS = ["shift", "push", "pop"]

    def shift(self, type: int, value: Text, newstate: int, context: Context) -> None:
        """Shift a token. (Internal)"""
        dfa, state, node = self.stack[-1]
@@ -344,3 +381,22 @@ def pop(self) -> None:
        else:
            self.rootnode = newnode
            self.rootnode.used_names = self.used_names

    def shift_safe(
        self, type: int, value: Text, newstate: int, context: Context
    ) -> None:
        """Immutable (node-level) version of shift()."""
        dfa, state, _ = self.stack[-1]
        self.stack[-1] = (dfa, newstate, FAKE_NODE)

    def push_safe(
        self, type: int, newdfa: DFAS, newstate: int, context: Context
    ) -> None:
        """Immutable (node-level) version of push()."""
        dfa, state, _ = self.stack[-1]
        self.stack[-1] = (dfa, newstate, FAKE_NODE)
        self.stack.append((newdfa, 0, FAKE_NODE))

    def pop_safe(self) -> None:
        """Immutable (node-level) version of pop()."""
        self.stack.pop()
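
The contract the *_safe variants must honor is that the (dfa, state)
bookkeeping ends up exactly where the mutating originals would leave it,
since ParseError detection during backtracking depends only on that; the
nodes are the only thing dropped. A toy check of that invariant with
stand-in values (illustration only, not from the commit):

FAKE_NODE = (-1, None, None, None)


def shift_safe(stack, newstate):
    stack[-1] = (stack[-1][0], newstate, FAKE_NODE)


def push_safe(stack, newdfa, newstate):
    stack[-1] = (stack[-1][0], newstate, FAKE_NODE)
    stack.append((newdfa, 0, FAKE_NODE))


stack = [("dfa0", 0, FAKE_NODE)]
shift_safe(stack, 3)           # state advances in place
push_safe(stack, "dfa1", 5)    # current frame updated, new frame pushed
stack.pop()                    # pop_safe() is just a plain pop
assert [(d, s) for d, s, _ in stack] == [("dfa0", 5)]
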
107 changes: 107 additions & 0 deletions tests/data/pattern_matching_generic.py
@@ -0,0 +1,107 @@
re.match()
match = a
with match() as match:
    match = f"{match}"

re.match()
match = a
with match() as match:
    match = f"{match}"


def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
    if not target_versions:
        # No target_version specified, so try all grammars.
        return [
            # Python 3.7+
            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
            # Python 3.0-3.6
            pygram.python_grammar_no_print_statement_no_exec_statement,
            # Python 2.7 with future print_function import
            pygram.python_grammar_no_print_statement,
            # Python 2.7
            pygram.python_grammar,
        ]

    match match:
        case case:
            match match:
                case case:
                    pass

    if all(version.is_python2() for version in target_versions):
        # Python 2-only code, so try Python 2 grammars.
        return [
            # Python 2.7 with future print_function import
            pygram.python_grammar_no_print_statement,
            # Python 2.7
            pygram.python_grammar,
        ]

    re.match()
    match = a
    with match() as match:
        match = f"{match}"

    def test_patma_139(self):
        x = False
        match x:
            case bool(z):
                y = 0
        self.assertIs(x, False)
        self.assertEqual(y, 0)
        self.assertIs(z, x)

    # Python 3-compatible code, so only try Python 3 grammar.
    grammars = []
    if supports_feature(target_versions, Feature.PATTERN_MATCHING):
        # Python 3.10+
        grammars.append(pygram.python_grammar_soft_keywords)
    # If we have to parse both, try to parse async as a keyword first
    if not supports_feature(
        target_versions, Feature.ASYNC_IDENTIFIERS
    ) and not supports_feature(target_versions, Feature.PATTERN_MATCHING):
        # Python 3.7-3.9
        grammars.append(
            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
        )
    if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
        # Python 3.0-3.6
        grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)

    def test_patma_155(self):
        x = 0
        y = None
        match x:
            case 1e1000:
                y = 0
        self.assertEqual(x, 0)
        self.assertIs(y, None)

    x = range(3)
    match x:
        case [y, case as x, z]:
            w = 0

    # At least one of the above branches must have been taken, because every Python
    # version has exactly one of the two 'ASYNC_*' flags
    return grammars


def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
    """Given a string with source, return the lib2to3 Node."""
    if not src_txt.endswith("\n"):
        src_txt += "\n"

    grammars = get_grammars(set(target_versions))


re.match()
match = a
with match() as match:
    match = f"{match}"

re.match()
match = a
with match() as match:
    match = f"{match}"
1 change: 1 addition & 0 deletions tests/test_format.py
@@ -75,6 +75,7 @@
"pattern_matching_complex",
"pattern_matching_extras",
"pattern_matching_style",
"pattern_matching_generic",
"parenthesized_context_managers",
]

