Skip to content

Commit

Permalink
Add tests for completion
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanwn committed Apr 27, 2024
1 parent b794da8 commit de32935
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 117 deletions.
107 changes: 0 additions & 107 deletions scripts/debug.py

This file was deleted.

12 changes: 8 additions & 4 deletions src/souffle_analyzer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from souffle_analyzer.parser import Parser
from souffle_analyzer.sourceutil import (
get_before_token,
get_bracket_scores,
get_pair_symbol_score,
get_words_in_consecutive_block_at_line,
)
from souffle_analyzer.visitor.code_action_visitor import CodeActionVisitor
Expand Down Expand Up @@ -154,7 +154,9 @@ def get_completion_items(
else:
return []
else:
bracket_scores = get_bracket_scores(code)
paren_scores = get_pair_symbol_score(code, ("(", ")"))
curly_bracket_scores = get_pair_symbol_score(code, ("{", "}"))
square_bracket_scores = get_pair_symbol_score(code, ("{", "}"))
before_token = get_before_token(
code,
position.line,
Expand All @@ -165,9 +167,11 @@ def get_completion_items(
elif (
# This is essentially the score on the position before this.
# See the property of the get_bracket_scores function.
bracket_scores[position.line][position.character] == 0
paren_scores[position.line][position.character] == 0
and curly_bracket_scores[position.line][position.character] == 0
and square_bracket_scores[position.line][position.character] == 0
and (
before_token in [".input", ".output"]
before_token in [".input", ".output", ".printsize"]
or before_token.endswith(",")
or before_token.endswith(":-")
or before_token.endswith(".")
Expand Down
30 changes: 29 additions & 1 deletion src/souffle_analyzer/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ def covers(self, position: Position) -> bool:
return self.start <= position < self.end


@dataclass
class Location:
uri: str
range_: Range

@classmethod
def from_lsp_type(cls, loc: lsptypes.Location) -> Location:
return cls(uri=loc.uri, range_=Range.from_lsp_type(loc.range))

def to_lsp_type(self) -> lsptypes.Location:
return lsptypes.Location(uri=self.uri, range=self.range_.to_lsp_type())


@dataclass
class SyntaxIssue:
range_: Range
Expand Down Expand Up @@ -834,6 +847,7 @@ def accept(self, visitor: Visitor[T]) -> T:
@dataclass
class Argument(ValidNode):
ty: SouffleType
# parent: ValidNode | None


@dataclass
Expand All @@ -850,11 +864,25 @@ def accept(self, visitor: Visitor[T]) -> T:
return visitor.visit_variable(self)


@dataclass
class RecordInit(Argument):
arguments: list[Argument]
definition: TypeDeclaration | None = field(default=None)

@property
def children(self) -> list[Node]:
return [
*self.arguments,
]

def accept(self, visitor: Visitor[T]) -> T:
return visitor.visit_record_init(self)


@dataclass
class BranchInit(Argument):
name: ResultNode[BranchInitName]
arguments: list[Argument]
definition: ResultNode[AbstractDataTypeBranch] | None = field(default=None)

@property
def children(self) -> list[Node]:
Expand Down
18 changes: 18 additions & 0 deletions src/souffle_analyzer/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PreprocInclude,
QualifiedName,
Range,
RecordInit,
RecordTypeExpression,
RelationDeclaration,
RelationReference,
Expand Down Expand Up @@ -687,6 +688,9 @@ def parse_argument(self, node: ts.Node) -> Argument | None:
variable_node = self.get_child_of_type(node, "variable")
if variable_node:
return self.parse_variable(variable_node)
record_init_node = self.get_child_of_type(node, "record_init")
if record_init_node:
return self.parse_record_init(record_init_node)
branch_init_node = self.get_child_of_type(node, "branch_init")
if branch_init_node:
return self.parse_branch_init(branch_init_node)
Expand All @@ -711,6 +715,16 @@ def parse_variable(self, node: ts.Node) -> Variable:
range_=self.get_range(node),
name=self.get_text(node),
ty=UnresolvedType(),
# parent=None,
)

def parse_record_init(self, node: ts.Node) -> RecordInit:
arg_nodes = self.get_children_of_type(node, "argument")
arguments = list(filter(None, (self.parse_argument(_) for _ in arg_nodes)))
return RecordInit(
range_=self.get_range(node),
arguments=arguments,
ty=UnresolvedType(),
)

def parse_branch_init(self, node: ts.Node) -> BranchInit:
Expand All @@ -735,6 +749,7 @@ def parse_branch_init(self, node: ts.Node) -> BranchInit:
name=name,
arguments=arguments,
ty=UnresolvedType(),
# parent=None,
)

def parse_binary_operation(self, node: ts.Node) -> BinaryOperation:
Expand Down Expand Up @@ -810,6 +825,7 @@ def parse_binary_operation(self, node: ts.Node) -> BinaryOperation:
op=op,
rhs=rhs,
ty=UnresolvedType(),
# parent=None,
)

def parse_binary_operator(self, node: ts.Node) -> BinaryOperator:
Expand All @@ -823,13 +839,15 @@ def parse_decimal(self, node: ts.Node) -> NumberConstant:
range_=self.get_range(node),
val=int(self.get_text(node)),
ty=UnresolvedType(),
# parent=None,
)

def parse_string_literal(self, node: ts.Node) -> StringConstant:
return StringConstant(
range_=self.get_range(node),
val=self.get_text(node),
ty=UnresolvedType(),
# parent=None,
)

QualifiedNameT = TypeVar("QualifiedNameT", bound=QualifiedName)
Expand Down
8 changes: 4 additions & 4 deletions src/souffle_analyzer/sourceutil.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import re
from typing import List, Set
from typing import List, Set, Tuple


def get_consecutive_block_at_line(
Expand Down Expand Up @@ -33,7 +33,7 @@ def get_words_in_consecutive_block_at_line(code: str, line_no: int) -> Set[str]:
return res


def get_bracket_scores(code: str) -> List[List[int]]:
def get_pair_symbol_score(code: str, symbol_pair: Tuple[str, str]) -> List[List[int]]:
# On each line, there is always at least a value marking the
# bracket score at the "beginning" of the line before
# any character appears.
Expand All @@ -44,9 +44,9 @@ def get_bracket_scores(code: str) -> List[List[int]]:
for line in range(len(code_lines)):
scores.append([cur])
for character in range(len(code_lines[line])):
if code_lines[line][character] == "(":
if code_lines[line][character] == symbol_pair[0]:
cur += 1
elif code_lines[line][character] == ")":
elif code_lines[line][character] == symbol_pair[1]:
cur -= 1
scores[line].append(cur)
return scores
Expand Down
4 changes: 4 additions & 0 deletions src/souffle_analyzer/visitor/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Node,
PreprocInclude,
QualifiedName,
RecordInit,
RecordTypeExpression,
RelationDeclaration,
RelationReference,
Expand Down Expand Up @@ -136,6 +137,9 @@ def visit_variable(self, variable: Variable) -> T:
def visit_branch_init(self, branch_init: BranchInit) -> T:
return self.generic_visit(branch_init)

def visit_record_init(self, record_init: RecordInit) -> T:
return self.generic_visit(record_init)

def visit_binary_operation(self, binary_operation: BinaryOperation) -> T:
return self.generic_visit(binary_operation)

Expand Down
Loading

0 comments on commit de32935

Please sign in to comment.