Skip to content

Commit

Permalink
fix and run lint workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
z80dev committed Dec 13, 2024
1 parent be7f7c8 commit ddd8521
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 42 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.10'
- run: pip install pre-commit
- run: pre-commit run --all-files
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Set up Python
run: uv python install
- name: Install dependencies
run: uv sync --all-extras --dev
- run: uv run pre-commit run --all-files
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ dev = [
"flake8>=5.0.4,<5.1",
"pre-commit>=3.5.0,<4.0",
"pytest>=7.4.3,<7.5",
"ruff>=0.8.3",
]
61 changes: 43 additions & 18 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
from packaging.version import Version
from lsprotocol.types import (
CompletionItemLabelDetails,
Diagnostic,
DiagnosticSeverity,
ParameterInformation,
Position,
Range,
SignatureHelp,
SignatureInformation,
)
Expand Down Expand Up @@ -76,11 +73,10 @@ def signature_help(
module, fn = matches.groups()
logger.info(f"looking up function {fn} in module {module}")
if module in self.ast.imports:
logger.info(f"found module")
logger.info("found module")
if fn := self.ast.imports[module].functions[fn]:
logger.info(f"args: {fn.arguments}")


# this returns for all external functions
# TODO: Implement checking interfaces
if not expression.startswith("self."):
Expand Down Expand Up @@ -109,7 +105,6 @@ def signature_help(
line = doc.lines[search_start_line_no]
search_start_line_no += 1


fn_label = line.removeprefix("def ").removesuffix(":\n")

for arg in node.args.args:
Expand All @@ -131,10 +126,12 @@ def signature_help(
active_signature=0,
)

def _dot_completions_for_element(self, element: str, top_level_node = None, line: str="") -> List[CompletionItem]:
def _dot_completions_for_element(
self, element: str, top_level_node=None, line: str = ""
) -> List[CompletionItem]:
completions = []
logger.info(f"getting dot completions for element: {element}")
#logger.info(f"import keys: {self.ast.imports.keys()}")
# logger.info(f"import keys: {self.ast.imports.keys()}")
self.ast.imports.keys()
if element == "self":
for fn in self.ast.get_internal_functions():
Expand All @@ -148,7 +145,7 @@ def _dot_completions_for_element(self, element: str, top_level_node = None, line
if getattr(fn.ast_def, "doc_string", False):
doc_string = fn.ast_def.doc_string.value

#out = self._format_fn_signature(fn.decl_node)
# out = self._format_fn_signature(fn.decl_node)
out = format_fn(fn)

# NOTE: this just gets ignored by most editors
Expand All @@ -157,20 +154,38 @@ def _dot_completions_for_element(self, element: str, top_level_node = None, line

doc_string = f"{out}\n{doc_string}"

show_external: bool = isinstance(top_level_node, nodes.ExportsDecl) or line.startswith("exports:")
show_internal_and_deploy: bool = isinstance(top_level_node, nodes.FunctionDef)
show_external: bool = isinstance(
top_level_node, nodes.ExportsDecl
) or line.startswith("exports:")
show_internal_and_deploy: bool = isinstance(
top_level_node, nodes.FunctionDef
)

if show_internal_and_deploy and (fn.is_internal or fn.is_deploy):
completions.append(CompletionItem(label=name, documentation=doc_string, label_details=completion_item_label_details))
completions.append(
CompletionItem(
label=name,
documentation=doc_string,
label_details=completion_item_label_details,
)
)
elif show_external and fn.is_external:
completions.append(CompletionItem(label=name, documentation=doc_string, label_details=completion_item_label_details))
completions.append(
CompletionItem(
label=name,
documentation=doc_string,
label_details=completion_item_label_details,
)
)
elif element in self.ast.flags:
members = self.ast.flags[element]._flag_members
for member in members.keys():
completions.append(CompletionItem(label=member))

if isinstance(top_level_node, nodes.FunctionDef):
var_declarations = top_level_node.get_descendants(nodes.AnnAssign, filters={"target.id": element})
var_declarations = top_level_node.get_descendants(
nodes.AnnAssign, filters={"target.id": element}
)
assert len(var_declarations) <= 1
for vardecl in var_declarations:
type_name = vardecl.annotation.id
Expand Down Expand Up @@ -202,7 +217,9 @@ def get_completions_in_doc(
surrounding_node = self.ast.find_top_level_node_at_pos(pos)

# internal + imported fns, state vars, and flags
dot_completions = self._dot_completions_for_element(element, top_level_node=surrounding_node, line=current_line)
dot_completions = self._dot_completions_for_element(
element, top_level_node=surrounding_node, line=current_line
)
if len(dot_completions) > 0:
return CompletionList(is_incomplete=False, items=dot_completions)
else:
Expand All @@ -220,8 +237,17 @@ def get_completions_in_doc(

if params.context.trigger_character == ":":
# return empty_completions if the line starts with "flag", "struct", or "event"
object_declaration_keywords = ["flag", "struct", "event", "enum", "interface"]
if any(current_line.startswith(keyword) for keyword in object_declaration_keywords):
object_declaration_keywords = [
"flag",
"struct",
"event",
"enum",
"interface",
]
if any(
current_line.startswith(keyword)
for keyword in object_declaration_keywords
):
return no_completions

for typ in custom_types + BASE_TYPES:
Expand Down Expand Up @@ -288,7 +314,6 @@ def is_state_var(self, expression: str):
var_name = expression.split("self.")[-1]
return var_name in self.ast.variables


def hover_info(self, document: Document, pos: Position) -> Optional[str]:
if len(document.lines) < pos.line:
return None
Expand Down
4 changes: 2 additions & 2 deletions vyper_lsp/analyzer/BaseAnalyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from lsprotocol.types import CompletionList, CompletionParams, Diagnostic, Position
from typing import Optional
from lsprotocol.types import CompletionList, CompletionParams, Position
from pygls.server import LanguageServer
from pygls.workspace import Document

Expand Down
2 changes: 1 addition & 1 deletion vyper_lsp/analyzer/SourceAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def on_grammar_error(e: UnexpectedInput) -> bool:
return True
if (
last_error is not None
and type(last_error) == type(e)
and isinstance(last_error, type(e))
and last_error.line == e.line
):
return True
Expand Down
38 changes: 26 additions & 12 deletions vyper_lsp/ast.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import copy
from functools import cached_property
import logging
from pathlib import Path
from typing import Optional, List
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position
from lsprotocol.types import Diagnostic, Position
from pygls.workspace import Document
from vyper.ast import Module, VyperNode, nodes
from vyper.ast import VyperNode, nodes
from vyper.compiler import CompilerData
from vyper.compiler.input_bundle import FilesystemInputBundle
from vyper.compiler.phases import DEFAULT_CONTRACT_PATH, ModuleT
Expand All @@ -16,14 +14,20 @@
import warnings
import re

from vyper_lsp.utils import create_diagnostic_warning, diagnostic_from_exception, working_directory_for_document, document_to_fileinput
from vyper_lsp.utils import (
create_diagnostic_warning,
diagnostic_from_exception,
working_directory_for_document,
document_to_fileinput,
)

logger = logging.getLogger("vyper-lsp")


pattern_text = r"(.+) will be deprecated in a future release, use (.+) instead\."
deprecation_pattern = re.compile(pattern_text)


class AST:
ast_data = None
ast_data_annotated = None
Expand Down Expand Up @@ -68,10 +72,14 @@ def _load_module_data(self):
self.functions = ast._metadata["type"].functions
self.variables = ast._metadata["type"].variables

flagt_list = [FlagT.from_FlagDef(node) for node in ast._metadata["type"].flag_defs]
flagt_list = [
FlagT.from_FlagDef(node) for node in ast._metadata["type"].flag_defs
]
self.flags = {flagt.name: flagt for flagt in flagt_list}

structt_list = [StructT.from_StructDef(node) for node in ast._metadata["type"].struct_defs]
structt_list = [
StructT.from_StructDef(node) for node in ast._metadata["type"].struct_defs
]
self.structs = {structt.name: structt for structt in structt_list}

def update_ast(self, doc: Document) -> List[Diagnostic]:
Expand All @@ -84,7 +92,9 @@ def build_ast(self, doc: Document | str) -> List[Diagnostic]:
uri_parent_path = working_directory_for_document(doc)
search_paths = get_search_paths([str(uri_parent_path)])
fileinput = document_to_fileinput(doc)
compiler_data = CompilerData(fileinput, input_bundle=FilesystemInputBundle(search_paths))
compiler_data = CompilerData(
fileinput, input_bundle=FilesystemInputBundle(search_paths)
)
diagnostics = []
replacements = {}
warnings.simplefilter("always")
Expand All @@ -105,7 +115,9 @@ def build_ast(self, doc: Document | str) -> List[Diagnostic]:
diagnostics.append(diagnostic_from_exception(e))
if e.annotations:
for a in e.annotations:
diagnostics.append(diagnostic_from_exception(a, message=message))
diagnostics.append(
diagnostic_from_exception(a, message=message)
)

for warning in w:
m = deprecation_pattern.match(str(warning.message))
Expand Down Expand Up @@ -155,11 +167,11 @@ def get_top_level_nodes(self, *args, **kwargs):
return self.best_ast.get_children(*args, **kwargs)

def get_enums(self) -> List[str]:
#return [node.name for node in self.get_descendants(nodes.FlagDef)]
# return [node.name for node in self.get_descendants(nodes.FlagDef)]
return list(self.flags.keys())

def get_structs(self) -> List[str]:
#return [node.name for node in self.get_descendants(nodes.StructDef)]
# return [node.name for node in self.get_descendants(nodes.StructDef)]
return list(self.structs.keys())

def get_events(self) -> List[str]:
Expand All @@ -177,7 +189,9 @@ def get_constants(self):

return [
node.target.id
for node in self.ast_data.get_children(nodes.VariableDecl, {"is_constant": True})
for node in self.ast_data.get_children(
nodes.VariableDecl, {"is_constant": True}
)
]

def get_enum_variants(self, enum: str):
Expand Down
1 change: 0 additions & 1 deletion vyper_lsp/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
from pathlib import Path
from typing import Optional, List
import logging
from .logging import LanguageServerLogHandler
Expand Down
12 changes: 10 additions & 2 deletions vyper_lsp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position, Range
from packaging.version import Version
from pygls.workspace import Document
from vyper.ast import FunctionDef, VyperNode
from vyper.ast import VyperNode
from vyper.exceptions import VyperException
from vyper.compiler import FileInput

Expand Down Expand Up @@ -48,6 +48,7 @@ def is_attribute_access(line):
# access to as much cursor information as possible (ex. line number),
# it could open up some possibilies when refactoring for performance


def get_word_at_cursor(sentence: str, cursor_index: int) -> str:
start = cursor_index
end = cursor_index
Expand Down Expand Up @@ -207,26 +208,33 @@ def create_diagnostic_warning(
severity=DiagnosticSeverity.Warning,
)


def diagnostic_from_exception(node: VyperException, message=None) -> Diagnostic:
return Diagnostic(
range=range_from_exception(node),
message=message or str(node),
severity=DiagnosticSeverity.Error,
)


def document_to_fileinput(doc: Document) -> FileInput:
path = Path(doc.uri.replace("file://", ""))
return FileInput(0, path, path, doc.source)


def working_directory_for_document(doc: Document) -> Path:
return Path(doc.uri.replace("file://", "")).parent


def escape_underscores(expression: str) -> str:
return expression.replace("_", "\\_")


def format_fn(func) -> str:
args = ", ".join([f"{arg.name}: _{arg.typ}_" for arg in func.arguments])
return_value = f" -> _{func.return_type}_" if func.return_type is not None else ""
mutability = func.mutability.value
out = f"def __{escape_underscores(func.name)}__({args}){return_value}: _{mutability}_"
out = (
f"def __{escape_underscores(func.name)}__({args}){return_value}: _{mutability}_"
)
return out

0 comments on commit ddd8521

Please sign in to comment.