diff --git a/pyproject.toml b/pyproject.toml index 45cef17f0..2754e9e29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "woke" -version = "3.6.0" +version = "3.6.1" description = "Woke is a Python-based development and testing framework for Solidity." license = "ISC" authors = ["Ackee Blockchain"] diff --git a/woke/ast/ir/expression/identifier.py b/woke/ast/ir/expression/identifier.py index d5b07eea4..d5038372f 100644 --- a/woke/ast/ir/expression/identifier.py +++ b/woke/ast/ir/expression/identifier.py @@ -13,6 +13,7 @@ from woke.ast.nodes import AstNodeId, SolcIdentifier if TYPE_CHECKING: + from woke.ast.ir.declaration.function_definition import FunctionDefinition from woke.ast.ir.meta.import_directive import ImportDirective from woke.ast.ir.meta.source_unit import SourceUnit @@ -27,7 +28,7 @@ class Identifier(ExpressionAbc): _name: str _overloaded_declarations: List[AstNodeId] - _referenced_declaration_id: Optional[AstNodeId] + _referenced_declaration_ids: Set[AstNodeId] def __init__( self, init: IrInitTuple, identifier: SolcIdentifier, parent: SolidityAbc @@ -37,9 +38,11 @@ def __init__( super().__init__(init, identifier, parent) self._name = identifier.name self._overloaded_declarations = list(identifier.overloaded_declarations) - self._referenced_declaration_id = identifier.referenced_declaration - if self._referenced_declaration_id is None: + if identifier.referenced_declaration is None: assert isinstance(self._parent, ImportDirective) + self._referenced_declaration_ids = set() + else: + self._referenced_declaration_ids = {identifier.referenced_declaration} init.reference_resolver.register_post_process_callback( self._post_process, priority=-1 ) @@ -47,40 +50,46 @@ def __init__( def _post_process(self, callback_params: CallbackParams): from ..meta.import_directive import ImportDirective - assert self._referenced_declaration_id is not None - if self._referenced_declaration_id < 0: - global_symbol = GlobalSymbolsEnum(self._referenced_declaration_id) - self._reference_resolver.register_global_symbol_reference( - global_symbol, self - ) - self._reference_resolver.register_destroy_callback( - self.file, partial(self._destroy, global_symbol) - ) - else: - node = self._reference_resolver.resolve_node( - self._referenced_declaration_id, self._cu_hash - ) + new_referenced_declaration_ids = set() - if isinstance(node, DeclarationAbc): - node.register_reference(self) + for referenced_declaration_id in self._referenced_declaration_ids: + if referenced_declaration_id < 0: + global_symbol = GlobalSymbolsEnum(referenced_declaration_id) + self._reference_resolver.register_global_symbol_reference( + global_symbol, self + ) self._reference_resolver.register_destroy_callback( - self.file, partial(self._destroy, node) + self.file, partial(self._destroy, global_symbol) ) - elif isinstance(node, ImportDirective): - # make this node to reference the source unit directly - assert node.unit_alias is not None - source_unit = callback_params.source_units[node.imported_file] - node_path_order = self._reference_resolver.get_node_path_order( - AstNodeId(source_unit.ast_node_id), - source_unit.cu_hash, + new_referenced_declaration_ids.add(referenced_declaration_id) + else: + node = self._reference_resolver.resolve_node( + referenced_declaration_id, self._cu_hash ) - self._referenced_declaration_id = ( - self._reference_resolver.get_ast_id_from_cu_node_path_order( - node_path_order, self.cu_hash + + if isinstance(node, DeclarationAbc): + node.register_reference(self) + self._reference_resolver.register_destroy_callback( + self.file, partial(self._destroy, node) ) - ) - else: - raise TypeError(f"Unexpected type: {type(node)}") + new_referenced_declaration_ids.add(referenced_declaration_id) + elif isinstance(node, ImportDirective): + # make this node to reference the source unit directly + assert node.unit_alias is not None + source_unit = callback_params.source_units[node.imported_file] + node_path_order = self._reference_resolver.get_node_path_order( + AstNodeId(source_unit.ast_node_id), + source_unit.cu_hash, + ) + new_referenced_declaration_ids.add( + self._reference_resolver.get_ast_id_from_cu_node_path_order( + node_path_order, self.cu_hash + ) + ) + else: + raise TypeError(f"Unexpected type: {type(node)}") + + self._referenced_declaration_ids = new_referenced_declaration_ids def _destroy( self, referenced_declaration: Union[GlobalSymbolsEnum, DeclarationAbc] @@ -119,20 +128,30 @@ def overloaded_declarations(self) -> Tuple[DeclarationAbc, ...]: @property def referenced_declaration( self, - ) -> Union[DeclarationAbc, GlobalSymbolsEnum, SourceUnit]: - from ..meta.source_unit import SourceUnit + ) -> Union[DeclarationAbc, GlobalSymbolsEnum, SourceUnit, Set[FunctionDefinition]]: + def resolve(referenced_declaration_id: AstNodeId): + if referenced_declaration_id < 0: + return GlobalSymbolsEnum(referenced_declaration_id) + + node = self._reference_resolver.resolve_node( + referenced_declaration_id, self._cu_hash + ) + assert isinstance( + node, (DeclarationAbc, SourceUnit) + ), f"Unexpected type: {type(node)}\n{node.source}\n{self.source}\n{self.file}" + return node - assert self._referenced_declaration_id is not None - if self._referenced_declaration_id < 0: - return GlobalSymbolsEnum(self._referenced_declaration_id) + from ..declaration.function_definition import FunctionDefinition + from ..meta.source_unit import SourceUnit - node = self._reference_resolver.resolve_node( - self._referenced_declaration_id, self._cu_hash - ) - assert isinstance( - node, (DeclarationAbc, SourceUnit) - ), f"Unexpected type: {type(node)}\n{node.source}\n{self.source}\n{self.file}" - return node + assert len(self._referenced_declaration_ids) != 0 + if len(self._referenced_declaration_ids) == 1: + return resolve(next(iter(self._referenced_declaration_ids))) + else: + # Identifier in ImportDirective symbol alias referencing multiple overloaded functions + ret = set(map(resolve, self._referenced_declaration_ids)) + assert all(isinstance(x, FunctionDefinition) for x in ret) + return ret # pyright: ignore reportGeneralTypeIssues @property @lru_cache(maxsize=2048) diff --git a/woke/ast/ir/meta/import_directive.py b/woke/ast/ir/meta/import_directive.py index 78a01e236..b1d436e6f 100644 --- a/woke/ast/ir/meta/import_directive.py +++ b/woke/ast/ir/meta/import_directive.py @@ -124,32 +124,37 @@ def __iter__(self) -> Iterator[IrAbc]: yield from symbol_alias.foreign def _post_process(self, callback_params: CallbackParams): - # referenced declaration ID is missing (for whatever reason) in import directive symbol aliases + from ..declaration.function_definition import FunctionDefinition + + # referenced declaration ID is missing in import directive symbol aliases + # the reason is that the Identifier may refer to multiple overloaded functions # for example `import { SafeType } from "SafeLib.sol";` # fix: find these reference IDs manually - # seems to be fixed in solc >= 0.8.12 for symbol_alias in self._symbol_aliases: - if symbol_alias.foreign._referenced_declaration_id is not None: + if len(symbol_alias.foreign._referenced_declaration_ids) != 0: continue source_units_queue: Deque[SourceUnit] = deque( [callback_params.source_units[self._imported_file]] ) processed_source_units: Set[Path] = {self._imported_file} - referenced_declaration = None + referenced_declarations = set() + search = True searched_name = symbol_alias.foreign.name - while source_units_queue and referenced_declaration is None: + while source_units_queue and search: imported_source_unit = source_units_queue.pop() for declaration in imported_source_unit.declarations_iter(): - if declaration.canonical_name == searched_name: - referenced_declaration = declaration - break + if declaration.name == searched_name: + referenced_declarations.add(declaration) + if not isinstance(declaration, FunctionDefinition): + search = False + break for import_ in imported_source_unit.imports: if import_.unit_alias == searched_name: - referenced_declaration = import_ + referenced_declarations.add(import_) break # handle the case when an imported symbol is an alias of another symbol @@ -162,17 +167,22 @@ def _post_process(self, callback_params: CallbackParams): ) processed_source_units.add(import_.imported_file) - assert referenced_declaration is not None - node_path_order = self._reference_resolver.get_node_path_order( - AstNodeId(referenced_declaration.ast_node_id), - referenced_declaration.cu_hash, - ) - referenced_declaration_id = ( - self._reference_resolver.get_ast_id_from_cu_node_path_order( - node_path_order, self.cu_hash + assert len(referenced_declarations) > 0 + + referenced_declaration_ids = set() + for referenced_declaration in referenced_declarations: + node_path_order = self._reference_resolver.get_node_path_order( + AstNodeId(referenced_declaration.ast_node_id), + referenced_declaration.cu_hash, + ) + referenced_declaration_ids.add( + self._reference_resolver.get_ast_id_from_cu_node_path_order( + node_path_order, self.cu_hash + ) ) + symbol_alias.foreign._referenced_declaration_ids = ( + referenced_declaration_ids ) - symbol_alias.foreign._referenced_declaration_id = referenced_declaration_id @property def parent(self) -> SourceUnit: diff --git a/woke/ast/ir/statement/expression_statement.py b/woke/ast/ir/statement/expression_statement.py index bd8b3b201..eb2922dce 100644 --- a/woke/ast/ir/statement/expression_statement.py +++ b/woke/ast/ir/statement/expression_statement.py @@ -20,6 +20,7 @@ from ..expression.index_range_access import IndexRangeAccess from ..expression.literal import Literal from ..expression.member_access import MemberAccess +from ..expression.new_expression import NewExpression from ..expression.tuple_expression import TupleExpression from ..expression.unary_operation import UnaryOperation @@ -46,9 +47,9 @@ class ExpressionStatement(StatementAbc): - a [FunctionCall][woke.ast.ir.expression.function_call.FunctionCall]: - `:::solidity require(arr.length > 1)` in line 3, - a [FunctionCallOptions][woke.ast.ir.expression.function_call_options.FunctionCallOptions]: - - `:::solidity payable(msg.sender).call{value: 1}` in line 16, + - `:::solidity payable(msg.sender).call{value: 1}` in line 17, - an [Identifier][woke.ast.ir.expression.identifier.Identifier]: - - `:::solidity this` in line 15, + - `:::solidity this` in line 16, - an [IndexAccess][woke.ast.ir.expression.index_access.IndexAccess]: - `:::solidity arr[0]` in line 9, - an [IndexRangeAccess][woke.ast.ir.expression.index_range_access.IndexRangeAccess]: @@ -57,8 +58,10 @@ class ExpressionStatement(StatementAbc): - `:::solidity 10` in line 12, - a [MemberAccess][woke.ast.ir.expression.member_access.MemberAccess]: - `:::solidity arr.length` in line 13, + - a [NewExpression][woke.ast.ir.expression.new_expression.NewExpression]: + - `:::solidity new uint[]` in line 14, - a [TupleExpression][woke.ast.ir.expression.tuple_expression.TupleExpression]: - - `:::solidity (arr)` in line 14, + - `:::solidity (arr)` in line 15, - an [UnaryOperation][woke.ast.ir.expression.unary_operation.UnaryOperation]: - `:::solidity i++` in line 6. @@ -76,6 +79,7 @@ class ExpressionStatement(StatementAbc): arr[0] + arr[1]; 10; arr.length; + new uint[]; (arr); this; // silence state mutability warning without generating bytecode payable(msg.sender).call{value: 1}; @@ -105,6 +109,7 @@ class ExpressionStatement(StatementAbc): IndexRangeAccess, Literal, MemberAccess, + NewExpression, TupleExpression, UnaryOperation, ] @@ -130,6 +135,7 @@ def __init__( IndexRangeAccess, Literal, MemberAccess, + NewExpression, TupleExpression, UnaryOperation, ), @@ -171,6 +177,7 @@ def expression( IndexRangeAccess, Literal, MemberAccess, + NewExpression, TupleExpression, UnaryOperation, ]: diff --git a/woke/lsp/features/definition.py b/woke/lsp/features/definition.py index 1a3637730..531cb360d 100644 --- a/woke/lsp/features/definition.py +++ b/woke/lsp/features/definition.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Set, Tuple, Union import woke.ast.ir.yul as yul from woke.ast.enums import GlobalSymbolsEnum @@ -65,6 +65,40 @@ def _get_results_from_node( byte_offset: int, node_name_location: Optional[Tuple[int, int]], ) -> Optional[List[Tuple[Path, Tuple[int, int]]]]: + def resolve(node) -> Set[Tuple[Path, Tuple[int, int]]]: + ret = set() + if isinstance(node, (FunctionDefinition, VariableDeclaration)): + if isinstance(node, VariableDeclaration) or node.implemented: + ret.add((node.file, node.name_location)) + + for base_function in node.base_functions: + if base_function.implemented: + ret.add((base_function.file, base_function.name_location)) + if isinstance(node, FunctionDefinition): + for child_function in node.child_functions: + if isinstance(child_function, VariableDeclaration): + ret.add((child_function.file, child_function.name_location)) + elif ( + isinstance(child_function, FunctionDefinition) + and child_function.implemented + ): + ret.add((child_function.file, child_function.name_location)) + elif isinstance(node, ModifierDefinition): + if node.implemented: + ret.add((node.file, node.name_location)) + for base_modifier in node.base_modifiers: + if base_modifier.implemented: + ret.add((base_modifier.file, base_modifier.name_location)) + for child_modifier in node.child_modifiers: + if child_modifier.implemented: + ret.add((child_modifier.file, child_modifier.name_location)) + elif isinstance(node, SourceUnit): + ret.add((node.file, node.byte_location)) + else: + ret.add((node.file, node.name_location)) + + return ret + if isinstance(original_node, DeclarationAbc): assert node_name_location is not None name_location_range = context.compiler.get_range_from_byte_offsets( @@ -96,46 +130,17 @@ def _get_results_from_node( else: node = original_node - if not isinstance(node, (DeclarationAbc, SourceUnit)): + if not isinstance(node, (DeclarationAbc, SourceUnit, set)): return None - definitions = [] - - if isinstance(node, (FunctionDefinition, VariableDeclaration)): - if isinstance(node, VariableDeclaration) or node.implemented: - definitions.append((node.file, node.name_location)) - - for base_function in node.base_functions: - if base_function.implemented: - definitions.append((base_function.file, base_function.name_location)) - if isinstance(node, FunctionDefinition): - for child_function in node.child_functions: - if isinstance(child_function, VariableDeclaration): - definitions.append( - (child_function.file, child_function.name_location) - ) - elif ( - isinstance(child_function, FunctionDefinition) - and child_function.implemented - ): - definitions.append( - (child_function.file, child_function.name_location) - ) - elif isinstance(node, ModifierDefinition): - if node.implemented: - definitions.append((node.file, node.name_location)) - for base_modifier in node.base_modifiers: - if base_modifier.implemented: - definitions.append((base_modifier.file, base_modifier.name_location)) - for child_modifier in node.child_modifiers: - if child_modifier.implemented: - definitions.append((child_modifier.file, child_modifier.name_location)) - elif isinstance(node, SourceUnit): - definitions.append((node.file, node.byte_location)) + if isinstance(node, set): + definitions = set() + for n in node: + definitions |= resolve(n) else: - definitions.append((node.file, node.name_location)) + definitions = resolve(node) - return definitions + return list(definitions) async def _get_definition_from_cache( diff --git a/woke/lsp/features/hover.py b/woke/lsp/features/hover.py index ab26ebe61..d72d55541 100644 --- a/woke/lsp/features/hover.py +++ b/woke/lsp/features/hover.py @@ -285,6 +285,11 @@ def _get_results_from_node( node, context.openzeppelin_contracts_version ) + return value, original_node_location + elif isinstance(node, set): + value = "\n".join( + "```solidity\n" + node.declaration_string + "\n```" for node in node + ) return value, original_node_location elif isinstance(node, yul.Identifier): if node.name in yul_definitions: diff --git a/woke/lsp/features/references.py b/woke/lsp/features/references.py index 1c459b467..532cf0b48 100644 --- a/woke/lsp/features/references.py +++ b/woke/lsp/features/references.py @@ -142,14 +142,21 @@ async def references( ): node = node.function - if not isinstance(node, DeclarationAbc): + if not isinstance(node, (DeclarationAbc, set)): return None refs = [] - for reference in node.get_all_references( - context.config.lsp.find_references.include_declarations - ): - refs.append(_generate_reference_location(reference, context)) + if isinstance(node, set): + for n in node: + for reference in n.get_all_references( + context.config.lsp.find_references.include_declarations + ): + refs.append(_generate_reference_location(reference, context)) + else: + for reference in node.get_all_references( + context.config.lsp.find_references.include_declarations + ): + refs.append(_generate_reference_location(reference, context)) if len(refs) == 0: return None diff --git a/woke/lsp/features/rename.py b/woke/lsp/features/rename.py index 82c80c653..d2e3d4f10 100644 --- a/woke/lsp/features/rename.py +++ b/woke/lsp/features/rename.py @@ -2,11 +2,12 @@ import re from collections import defaultdict from pathlib import Path -from typing import DefaultDict, List, Optional, Union +from typing import DefaultDict, List, Optional, Set, Union import woke.ast.ir.yul as yul from woke.ast.ir.abc import IrAbc from woke.ast.ir.declaration.abc import DeclarationAbc +from woke.ast.ir.declaration.function_definition import FunctionDefinition from woke.ast.ir.expression.binary_operation import BinaryOperation from woke.ast.ir.expression.identifier import Identifier from woke.ast.ir.expression.member_access import MemberAccess @@ -81,11 +82,20 @@ def _generate_reference_location( def _generate_workspace_edit( - declaration: DeclarationAbc, new_name: str, context: LspContext + declaration: Union[DeclarationAbc, Set[FunctionDefinition]], + new_name: str, + context: LspContext, ) -> WorkspaceEdit: changes_by_file: DefaultDict[Path, List[TextEdit]] = defaultdict(list) - for reference in declaration.get_all_references(True): + all_references = set() + if isinstance(declaration, set): + for func in declaration: + all_references.update(func.get_all_references(True)) + else: + all_references.update(declaration.get_all_references(True)) + + for reference in all_references: if not isinstance(reference, (UnaryOperation, BinaryOperation)): changes_by_file[reference.file].append( TextEdit( @@ -183,7 +193,7 @@ async def rename( raise LspError(ErrorCodes.RequestFailed, "Cannot rename this symbol") node = external_reference.referenced_declaration - if not isinstance(node, DeclarationAbc): + if not isinstance(node, (DeclarationAbc, set)): raise LspError(ErrorCodes.RequestFailed, "Cannot rename this symbol") return _generate_workspace_edit(node, params.new_name, context) @@ -249,6 +259,6 @@ async def prepare_rename( return None location = node.name_location - if not isinstance(node, DeclarationAbc) or location is None: + if not isinstance(node, (DeclarationAbc, set)) or location is None: return None return context.compiler.get_range_from_byte_offsets(path, location) diff --git a/woke/lsp/features/type_hierarchy.py b/woke/lsp/features/type_hierarchy.py index 1ae2a28c5..d8fbb67fa 100644 --- a/woke/lsp/features/type_hierarchy.py +++ b/woke/lsp/features/type_hierarchy.py @@ -129,6 +129,32 @@ def _get_node_symbol_kind( assert False, f"Unknown node type {type(node)}" +def prepare_type_hierarchy_item( + context: LspContext, + node: Union[ + ContractDefinition, FunctionDefinition, ModifierDefinition, VariableDeclaration + ], +) -> TypeHierarchyItem: + return TypeHierarchyItem( + name=node.canonical_name, + kind=_get_node_symbol_kind(node), + tags=None, + detail=None, + uri=DocumentUri(path_to_uri(node.file)), + range=context.compiler.get_range_from_byte_offsets( + node.file, node.byte_location + ), + selection_range=context.compiler.get_range_from_byte_offsets( + node.file, node.name_location + ), + data=TypeHierarchyItemData( + ast_node_id=node.ast_node_id, + cu_hash=node.cu_hash.hex(), + uri=DocumentUri(path_to_uri(node.file)), + ), + ) + + async def prepare_type_hierarchy( context: LspContext, params: TypeHierarchyPrepareParams ) -> Union[List[TypeHierarchyItem], None]: @@ -181,26 +207,9 @@ async def prepare_type_hierarchy( if isinstance( node, (ContractDefinition, FunctionDefinition, ModifierDefinition) ) or (isinstance(node, VariableDeclaration) and node.overrides is not None): - return [ - TypeHierarchyItem( - name=node.canonical_name, - kind=_get_node_symbol_kind(node), - tags=None, - detail=None, - uri=DocumentUri(path_to_uri(node.file)), - range=context.compiler.get_range_from_byte_offsets( - node.file, node.byte_location - ), - selection_range=context.compiler.get_range_from_byte_offsets( - node.file, node.name_location - ), - data=TypeHierarchyItemData( - ast_node_id=node.ast_node_id, - cu_hash=node.cu_hash.hex(), - uri=DocumentUri(path_to_uri(node.file)), - ), - ) - ] + return [prepare_type_hierarchy_item(context, node)] + elif isinstance(node, set): + return [prepare_type_hierarchy_item(context, n) for n in node] return None