diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index 5c87b67dd..d22f5ae0f 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -22,7 +22,7 @@ from slither.slithir.variables import Constant from slither.utils.colors import red from slither.utils.sarif import read_triage_info -from slither.utils.source_mapping import get_definition, get_references, get_implementation +from slither.utils.source_mapping import get_definition, get_references, get_all_implementations logger = logging.getLogger("Slither") logging.basicConfig() @@ -202,36 +202,11 @@ def offset_to_objects(self, filename_str: str, offset: int) -> Set[SourceMapping return self._offset_to_objects[filename][offset] def _compute_offsets_from_thing(self, thing: SourceMapping): - definition = get_definition(thing, self.crytic_compile) - references = get_references(thing) - implementations = set() - - # Abstract contracts and interfaces are implemented by their children - if isinstance(thing, Contract): - is_interface = thing.is_interface - is_implicitly_abstract = not thing.is_fully_implemented - is_explicitly_abstract = thing.is_abstract - if is_interface or is_implicitly_abstract or is_explicitly_abstract: - - for contract in self.contracts: - if thing in contract.immediate_inheritance: - implementations.add(contract.source_mapping) - - # Parent's virtual functions may be overridden by children - elif isinstance(thing, FunctionContract): - for over in thing.overridden_by: - implementations.add(over.source_mapping) - # Only show implemented virtual functions - if not thing.is_virtual or thing.is_implemented: - implementations.add(get_implementation(thing)) - - else: - implementations.add(get_implementation(thing)) + implementations = get_all_implementations(thing, self.contracts) for offset in range(definition.start, definition.end + 1): - if ( isinstance(thing, TopLevel) or ( @@ -265,8 +240,9 @@ def _compute_offsets_from_thing(self, thing: SourceMapping): if is_declared_function: # Only show the nearest lexical definition for declared contract-level functions if ( - offset > thing.contract.source_mapping.start - and offset < thing.contract.source_mapping.end + thing.contract.source_mapping.start + < offset + < thing.contract.source_mapping.end ): self._offset_to_definitions[ref.filename][offset].add(definition) diff --git a/slither/utils/source_mapping.py b/slither/utils/source_mapping.py index fe3f908b4..9bf772894 100644 --- a/slither/utils/source_mapping.py +++ b/slither/utils/source_mapping.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Set from crytic_compile import CryticCompile from slither.core.declarations import ( Contract, @@ -9,6 +9,7 @@ Pragma, Structure, CustomError, + FunctionContract, ) from slither.core.solidity_types import Type, TypeAlias from slither.core.source_mapping.source_mapping import Source, SourceMapping @@ -65,5 +66,34 @@ def get_implementation(target: SourceMapping) -> Source: return target.source_mapping +def get_all_implementations(target: SourceMapping, contracts: List[Contract]) -> Set[Source]: + """ + Get all implementations of a contract or function, accounting for inheritance and overrides + """ + implementations = set() + # Abstract contracts and interfaces are implemented by their children + if isinstance(target, Contract): + is_interface = target.is_interface + is_implicitly_abstract = not target.is_fully_implemented + is_explicitly_abstract = target.is_abstract + if is_interface or is_implicitly_abstract or is_explicitly_abstract: + for contract in contracts: + if target in contract.immediate_inheritance: + implementations.add(contract.source_mapping) + + # Parent's virtual functions may be overridden by children + elif isinstance(target, FunctionContract): + for over in target.overridden_by: + implementations.add(over.source_mapping) + # Only show implemented virtual functions + if not target.is_virtual or target.is_implemented: + implementations.add(get_implementation(target)) + + else: + implementations.add(get_implementation(target)) + + return implementations + + def get_references(target: SourceMapping) -> List[Source]: return target.references