From 1ec6a4a7b8deba4ab855f5ecc8be423d2cfe4c84 Mon Sep 17 00:00:00 2001 From: Liam DeVoe Date: Fri, 3 Jan 2025 15:11:40 -0500 Subject: [PATCH] type lark.py --- .../src/hypothesis/extra/lark.py | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/hypothesis-python/src/hypothesis/extra/lark.py b/hypothesis-python/src/hypothesis/extra/lark.py index 3011559796..bf053ed283 100644 --- a/hypothesis-python/src/hypothesis/extra/lark.py +++ b/hypothesis-python/src/hypothesis/extra/lark.py @@ -28,10 +28,13 @@ from typing import Optional import lark -from lark.grammar import NonTerminal, Terminal +from lark.grammar import NonTerminal, Rule, Symbol, Terminal +from lark.lark import Lark +from lark.lexer import TerminalDef from hypothesis import strategies as st from hypothesis.errors import InvalidArgument +from hypothesis.internal.conjecture.data import ConjectureData from hypothesis.internal.conjecture.utils import calc_label_from_name from hypothesis.internal.validation import check_type from hypothesis.strategies._internal.regex import IncompatibleWithAlphabet @@ -40,7 +43,9 @@ __all__ = ["from_lark"] -def get_terminal_names(terminals, rules, ignore_names): +def get_terminal_names( + terminals: list[TerminalDef], rules: list[Rule], ignore_names: list[str] +) -> set[str]: """Get names of all terminals in the grammar. The arguments are the results of calling ``Lark.grammar.compile()``, @@ -60,13 +65,15 @@ class LarkStrategy(st.SearchStrategy): See ``from_lark`` for details. """ - def __init__(self, grammar, start, explicit, alphabet): + def __init__( + self, + grammar: Lark, + start: Optional[str], + explicit: dict[str, st.SearchStrategy[str]], + alphabet: st.SearchStrategy[str], + ) -> None: assert isinstance(grammar, lark.lark.Lark) - if start is None: - start = grammar.options.start - if not isinstance(start, list): - start = [start] - self.grammar = grammar + start: list[str] = grammar.options.start if start is None else [start] # This is a total hack, but working around the changes is a nicer user # experience than breaking for anyone who doesn't instantly update their @@ -76,19 +83,18 @@ def __init__(self, grammar, start, explicit, alphabet): terminals, rules, ignore_names = grammar.grammar.compile(start, ()) elif "start" in compile_args: # pragma: no cover # Support lark <= 0.10.0, without the terminals_to_keep argument. - terminals, rules, ignore_names = grammar.grammar.compile(start) + terminals, rules, ignore_names = grammar.grammar.compile(start) # type: ignore else: # pragma: no cover # This branch is to support lark <= 0.7.1, without the start argument. - terminals, rules, ignore_names = grammar.grammar.compile() + terminals, rules, ignore_names = grammar.grammar.compile() # type: ignore - self.names_to_symbols = {} + self.names_to_symbols: dict[str, Symbol] = {} for r in rules: - t = r.origin - self.names_to_symbols[t.name] = t + self.names_to_symbols[r.origin.name] = r.origin disallowed = set() - self.terminal_strategies = {} + self.terminal_strategies: dict[str, st.SearchStrategy[str]] = {} for t in terminals: self.names_to_symbols[t.name] = Terminal(t.name) s = st.from_regex(t.pattern.to_regexp(), fullmatch=True, alphabet=alphabet) @@ -119,7 +125,8 @@ def __init__(self, grammar, start, explicit, alphabet): ) self.terminal_strategies.update(explicit) - nonterminals = {} + # TODO: should this be NonTerminal only? but rule.expansion is Symbol... + nonterminals: dict[str, list[tuple[Symbol, ...]]] = {} for rule in rules: if disallowed.isdisjoint(r.name for r in rule.expansion): @@ -149,15 +156,15 @@ def __init__(self, grammar, start, explicit, alphabet): k: st.sampled_from(sorted(v, key=len)) for k, v in nonterminals.items() } - self.__rule_labels = {} + self.__rule_labels: dict[str, int] = {} - def do_draw(self, data): - state = [] + def do_draw(self, data: ConjectureData) -> str: + state: list[str] = [] start = data.draw(self.start) self.draw_symbol(data, start, state) return "".join(state) - def rule_label(self, name): + def rule_label(self, name: str) -> int: try: return self.__rule_labels[name] except KeyError: @@ -165,7 +172,12 @@ def rule_label(self, name): name, calc_label_from_name(f"LARK:{name}") ) - def draw_symbol(self, data, symbol, draw_state): + def draw_symbol( + self, + data: ConjectureData, + symbol: Symbol, + draw_state: list[str], + ) -> None: if isinstance(symbol, Terminal): strategy = self.terminal_strategies[symbol.name] draw_state.append(data.draw(strategy)) @@ -178,7 +190,7 @@ def draw_symbol(self, data, symbol, draw_state): self.gen_ignore(data, draw_state) data.stop_example() - def gen_ignore(self, data, draw_state): + def gen_ignore(self, data: ConjectureData, draw_state: list[str]) -> None: if self.ignored_symbols and data.draw_boolean(1 / 4): emit = data.draw(st.sampled_from(self.ignored_symbols)) self.draw_symbol(data, emit, draw_state)