Skip to content

Commit

Permalink
type lark.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tybug committed Jan 7, 2025
1 parent d7bb898 commit 1ec6a4a
Showing 1 changed file with 33 additions and 21 deletions.
54 changes: 33 additions & 21 deletions hypothesis-python/src/hypothesis/extra/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
from typing import Optional

import lark
from lark.grammar import NonTerminal, Terminal
from lark.grammar import NonTerminal, Rule, Symbol, Terminal
from lark.lark import Lark
from lark.lexer import TerminalDef

from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.validation import check_type
from hypothesis.strategies._internal.regex import IncompatibleWithAlphabet
Expand All @@ -40,7 +43,9 @@
__all__ = ["from_lark"]


def get_terminal_names(terminals, rules, ignore_names):
def get_terminal_names(
terminals: list[TerminalDef], rules: list[Rule], ignore_names: list[str]
) -> set[str]:
"""Get names of all terminals in the grammar.
The arguments are the results of calling ``Lark.grammar.compile()``,
Expand All @@ -60,13 +65,15 @@ class LarkStrategy(st.SearchStrategy):
See ``from_lark`` for details.
"""

def __init__(self, grammar, start, explicit, alphabet):
def __init__(
self,
grammar: Lark,
start: Optional[str],
explicit: dict[str, st.SearchStrategy[str]],
alphabet: st.SearchStrategy[str],
) -> None:
assert isinstance(grammar, lark.lark.Lark)
if start is None:
start = grammar.options.start
if not isinstance(start, list):
start = [start]
self.grammar = grammar
start: list[str] = grammar.options.start if start is None else [start]

# This is a total hack, but working around the changes is a nicer user
# experience than breaking for anyone who doesn't instantly update their
Expand All @@ -76,19 +83,18 @@ def __init__(self, grammar, start, explicit, alphabet):
terminals, rules, ignore_names = grammar.grammar.compile(start, ())
elif "start" in compile_args: # pragma: no cover
# Support lark <= 0.10.0, without the terminals_to_keep argument.
terminals, rules, ignore_names = grammar.grammar.compile(start)
terminals, rules, ignore_names = grammar.grammar.compile(start) # type: ignore
else: # pragma: no cover
# This branch is to support lark <= 0.7.1, without the start argument.
terminals, rules, ignore_names = grammar.grammar.compile()
terminals, rules, ignore_names = grammar.grammar.compile() # type: ignore

self.names_to_symbols = {}
self.names_to_symbols: dict[str, Symbol] = {}

for r in rules:
t = r.origin
self.names_to_symbols[t.name] = t
self.names_to_symbols[r.origin.name] = r.origin

disallowed = set()
self.terminal_strategies = {}
self.terminal_strategies: dict[str, st.SearchStrategy[str]] = {}
for t in terminals:
self.names_to_symbols[t.name] = Terminal(t.name)
s = st.from_regex(t.pattern.to_regexp(), fullmatch=True, alphabet=alphabet)
Expand Down Expand Up @@ -119,7 +125,8 @@ def __init__(self, grammar, start, explicit, alphabet):
)
self.terminal_strategies.update(explicit)

nonterminals = {}
# TODO: should this be NonTerminal only? but rule.expansion is Symbol...
nonterminals: dict[str, list[tuple[Symbol, ...]]] = {}

for rule in rules:
if disallowed.isdisjoint(r.name for r in rule.expansion):
Expand Down Expand Up @@ -149,23 +156,28 @@ def __init__(self, grammar, start, explicit, alphabet):
k: st.sampled_from(sorted(v, key=len)) for k, v in nonterminals.items()
}

self.__rule_labels = {}
self.__rule_labels: dict[str, int] = {}

def do_draw(self, data):
state = []
def do_draw(self, data: ConjectureData) -> str:
state: list[str] = []
start = data.draw(self.start)
self.draw_symbol(data, start, state)
return "".join(state)

def rule_label(self, name):
def rule_label(self, name: str) -> int:
try:
return self.__rule_labels[name]
except KeyError:
return self.__rule_labels.setdefault(
name, calc_label_from_name(f"LARK:{name}")
)

def draw_symbol(self, data, symbol, draw_state):
def draw_symbol(
self,
data: ConjectureData,
symbol: Symbol,
draw_state: list[str],
) -> None:
if isinstance(symbol, Terminal):
strategy = self.terminal_strategies[symbol.name]
draw_state.append(data.draw(strategy))
Expand All @@ -178,7 +190,7 @@ def draw_symbol(self, data, symbol, draw_state):
self.gen_ignore(data, draw_state)
data.stop_example()

def gen_ignore(self, data, draw_state):
def gen_ignore(self, data: ConjectureData, draw_state: list[str]) -> None:
if self.ignored_symbols and data.draw_boolean(1 / 4):
emit = data.draw(st.sampled_from(self.ignored_symbols))
self.draw_symbol(data, emit, draw_state)
Expand Down

0 comments on commit 1ec6a4a

Please sign in to comment.