diff --git a/src/artifacts/ast/mod.rs b/src/artifacts/ast/mod.rs index ad863323..d54a2e28 100644 --- a/src/artifacts/ast/mod.rs +++ b/src/artifacts/ast/mod.rs @@ -560,15 +560,12 @@ impl VariableDeclaration { } } -/// Structured documentation (NatSpec). -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum StructuredDocumentation { - /// The documentation is provided in the form of an AST node. - Parsed { text: String }, - /// The documentation is provided in the form of a string literal. - Text(String), -} +ast_node!( + /// Structured documentation (NatSpec). + struct StructuredDocumentation { + text: String, + } +); ast_node!( /// An override specifier. @@ -1030,7 +1027,7 @@ ast_node!( /// A using for directive. struct UsingForDirective { #[serde(default, deserialize_with = "serde_helpers::default_for_null")] - function_list: Vec, + function_list: Vec, #[serde(default)] global: bool, library_name: Option, @@ -1038,12 +1035,25 @@ ast_node!( } ); +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum UsingForFunctionItem { + Function(FunctionIdentifierPath), + OverloadedOperator(OverloadedOperator), +} + /// A wrapper around [IdentifierPath] for the [UsingForDirective]. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct FunctionIdentifierPath { pub function: IdentifierPath, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OverloadedOperator { + pub definition: IdentifierPath, + pub operator: String, +} + ast_node!( /// An import directive. struct ImportDirective { diff --git a/src/artifacts/ast/visitor.rs b/src/artifacts/ast/visitor.rs index d2405190..336f74f3 100644 --- a/src/artifacts/ast/visitor.rs +++ b/src/artifacts/ast/visitor.rs @@ -50,6 +50,8 @@ pub trait Visitor { fn visit_return(&mut self, _return: &Return) {} fn visit_inheritance_specifier(&mut self, _specifier: &InheritanceSpecifier) {} fn visit_modifier_invocation(&mut self, _invocation: &ModifierInvocation) {} + fn visit_inline_assembly(&mut self, _assembly: &InlineAssembly) {} + fn visit_external_assembly_reference(&mut self, _ref: &ExternalInlineAssemblyReference) {} } pub trait Walk { @@ -244,10 +246,10 @@ impl_walk!(Statement, visit_statement, |statement, visitor| { Statement::Return(statement) => { statement.walk(visitor); } - Statement::Break(_) - | Statement::Continue(_) - | Statement::InlineAssembly(_) - | Statement::PlaceholderStatement(_) => {} + Statement::InlineAssembly(assembly) => { + assembly.walk(visitor); + } + Statement::Break(_) | Statement::Continue(_) | Statement::PlaceholderStatement(_) => {} } }); @@ -327,7 +329,7 @@ impl_walk!(UsingForDirective, visit_using_for, |directive, visitor| { library_name.walk(visitor); } for function in &directive.function_list { - function.function.walk(visitor); + function.walk(visitor); } }); @@ -526,6 +528,14 @@ impl_walk!(ModifierInvocation, visit_modifier_invocation, |invocation, visitor| invocation.modifier_name.walk(visitor); }); +impl_walk!(InlineAssembly, visit_inline_assembly, |assembly, visitor| { + assembly.external_references.iter().for_each(|reference| { + reference.walk(visitor); + }); +}); + +impl_walk!(ExternalInlineAssemblyReference, visit_external_assembly_reference); + impl_walk!(ElementaryTypeName, visit_elementary_type_name); impl_walk!(Literal, visit_literal); impl_walk!(ImportDirective, visit_import_directive); @@ -594,3 +604,18 @@ impl_walk!(ElementaryOrRawTypeName, |type_name, visitor| { ElementaryOrRawTypeName::Raw(_) => {} } }); + +impl_walk!(UsingForFunctionItem, |item, visitor| { + match item { + UsingForFunctionItem::Function(func) => { + func.function.walk(visitor); + } + UsingForFunctionItem::OverloadedOperator(operator) => { + operator.walk(visitor); + } + } +}); + +impl_walk!(OverloadedOperator, |operator, visitor| { + operator.definition.walk(visitor); +}); diff --git a/src/flatten.rs b/src/flatten.rs index d0eb2dd5..bc3a0be9 100644 --- a/src/flatten.rs +++ b/src/flatten.rs @@ -8,8 +8,8 @@ use crate::{ artifacts::{ ast::SourceLocation, visitor::{Visitor, Walk}, - ContractDefinitionPart, Identifier, IdentifierPath, MemberAccess, Source, SourceUnit, - SourceUnitPart, Sources, UserDefinedTypeName, + ContractDefinitionPart, ExternalInlineAssemblyReference, Identifier, IdentifierPath, + MemberAccess, Source, SourceUnit, SourceUnitPart, Sources, }, error::SolcError, utils, Graph, Project, ProjectCompileOutput, ProjectPathsConfig, Result, @@ -30,9 +30,21 @@ impl ItemLocation { Some(ItemLocation { path, start, end }) } + + fn length(&self) -> usize { + self.end - self.start + } } -/// Visitor exploring AST and collecting all references to any declarations +/// Visitor exploring AST and collecting all references to declarations via `Identifier` and +/// `IdentifierPath` nodes. +/// +/// It also collects `MemberAccess` parts. So, if we have `X.Y` expression, loc and AST ID will be +/// saved for Y only. +/// +/// That way, even if we have a long `MemberAccess` expression (a.b.c.d) then the first member (a) +/// will be collected as either `Identifier` or `IdentifierPath`, and all subsequent parts (b, c, d) +/// will be collected as `MemberAccess` parts. struct ReferencesCollector { path: PathBuf, references: HashMap>, @@ -53,34 +65,30 @@ impl Visitor for ReferencesCollector { } } - fn visit_user_defined_type_name(&mut self, type_name: &UserDefinedTypeName) { - self.process_referenced_declaration(type_name.referenced_declaration, &type_name.src); - } - - fn visit_member_access(&mut self, member_access: &MemberAccess) { - if let Some(id) = member_access.referenced_declaration { - self.process_referenced_declaration(id, &member_access.src); - } - } - fn visit_identifier_path(&mut self, path: &IdentifierPath) { self.process_referenced_declaration(path.referenced_declaration, &path.src); } -} - -/// Visitor exploring AST and collecting all references to any declarations found in -/// `UserDefinedTypeName` nodes -struct UserDefinedTypeNamesCollector { - path: PathBuf, - references: HashMap>, -} -impl Visitor for UserDefinedTypeNamesCollector { - fn visit_user_defined_type_name(&mut self, type_name: &UserDefinedTypeName) { - if let Some(loc) = ItemLocation::try_from_source_loc(&type_name.src, self.path.clone()) { - self.references.entry(type_name.referenced_declaration).or_default().insert(loc); + fn visit_member_access(&mut self, access: &MemberAccess) { + if let Some(referenced_declaration) = access.referenced_declaration { + if let (Some(src_start), Some(src_length)) = (access.src.start, access.src.length) { + let name_length = access.member_name.len(); + // Accessed member name is in the last name.len() symbols of the expression. + let start = src_start + src_length - name_length; + let end = start + name_length; + + self.references.entry(referenced_declaration).or_default().insert(ItemLocation { + start, + end, + path: self.path.to_path_buf(), + }); + } } } + + fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) { + self.process_referenced_declaration(reference.declaration as isize, &reference.src); + } } /// Updates to be applied to the sources. @@ -206,6 +214,9 @@ impl Flattener { let top_level_names = self.rename_top_level_definitions(&mut updates); self.rename_contract_level_types_references(&top_level_names, &mut updates); + self.remove_qualified_imports(&mut updates); + self.update_inheritdocs(&top_level_names, &mut updates); + self.remove_imports(&mut updates); let target_pragmas = self.process_pragmas(&mut updates); let target_license = self.process_licenses(&mut updates); @@ -246,8 +257,11 @@ impl Flattener { // `loc.path` is expected to be different for each id because there can't be 2 // top-level eclarations with the same name in the same file. // - // Sorting by loc.path to make the renaming process deterministic - ids.sort_by(|(_, loc_0), (_, loc_1)| loc_0.path.cmp(&loc_1.path)); + // Sorting by index loc.path in sorted files to make the renaming process + // deterministic. + ids.sort_by_key(|(_, loc)| { + self.ordered_sources.iter().position(|p| p == &loc.path).unwrap() + }); } for (i, (id, loc)) in ids.iter().enumerate() { if needs_rename { @@ -274,12 +288,52 @@ impl Flattener { top_level_names } - /// This is a workaround to be able to correctly process definitions which types - /// are present in the form of `ParentName.ChildName` where `ParentName` is a - /// contract name and `ChildName` is a struct/enum name. + /// This is not very clean, but in most cases effective enough method to remove qualified + /// imports from sources. + /// + /// Every qualified import part is an `Identifier` with `referencedDeclaration` field matching + /// ID of one of the import directives. + /// + /// This approach works by firstly collecting all IDs of import directives, and then looks for + /// any references of them. Once the reference is found, it's full length is getting removed + /// from source + 1 charater ('.') /// - /// Such types are represented as `UserDefinedTypeName` in AST and don't include any - /// information about parent in which the definition of child is present. + /// This should work correctly for vast majority of cases, however there are situations for + /// which such approach won't work, most of which are related to code being formatted in an + /// uncommon way. + fn remove_qualified_imports(&self, updates: &mut Updates) { + let imports_ids = self + .asts + .iter() + .flat_map(|(_, ast)| { + ast.nodes.iter().filter_map(|node| match node { + SourceUnitPart::ImportDirective(directive) => Some(directive.id), + _ => None, + }) + }) + .collect::>(); + + let references = self.collect_references(); + + for (id, locs) in references { + if !imports_ids.contains(&(id as usize)) { + continue; + } + + for loc in locs { + updates.entry(loc.path).or_default().insert(( + loc.start, + loc.end + 1, + "".to_string(), + )); + } + } + } + + /// Here we are going through all references to items defined in scope of contracts and updating + /// them to be using correct parent contract name. + /// + /// This will only operate on references from `IdentifierPath` nodes. fn rename_contract_level_types_references( &self, top_level_names: &HashMap, @@ -289,49 +343,129 @@ impl Flattener { for (path, ast) in &self.asts { for node in &ast.nodes { - let current_contract_scope = match node { - SourceUnitPart::ContractDefinition(contract) => Some(contract.id), - _ => None, - }; - let mut collector = UserDefinedTypeNamesCollector { - path: self.target.clone(), - references: HashMap::new(), - }; + let mut collector = + ReferencesCollector { path: self.target.clone(), references: HashMap::new() }; node.walk(&mut collector); - // Now this contains all definitions found in all UserDefinedTypeName nodes in the - // given source unit let references = collector.references; for (id, locs) in references { if let Some((name, contract_id)) = contract_level_definitions.get(&(id as usize)) { - if let Some(current_scope) = current_contract_scope { - // If this is a contract-level definition reference inside of the same - // contract it declared in, we replace it with its name - if current_scope == *contract_id { - updates.entry(path.clone()).or_default().extend( - locs.iter().map(|loc| (loc.start, loc.end, name.to_string())), - ); + for loc in &locs { + // If child item is referenced directly by it's name it's either defined + // in the same contract or in one of it's base contracts, so we don't + // have to change anything. + // Comparing lengths is enough because such items cannot be aliased. + if loc.length() == name.len() { continue; } + // If it was referenced somehow else, we rename it to `Parent.Child` + // format. + let parent_name = top_level_names.get(contract_id).unwrap(); + updates.entry(path.clone()).or_default().insert(( + loc.start, + loc.end, + format!("{}.{}", parent_name, name), + )); } - // If we are in some other contract or in global scope (file-level), then we - // should replace type name with `ParentName.ChildName`` - let parent_name = top_level_names.get(contract_id).unwrap(); - updates.entry(path.clone()).or_default().extend( - locs.iter().map(|loc| { - (loc.start, loc.end, format!("{}.{}", parent_name, name)) - }), - ); } } } } } + /// Finds all @inheritdoc tags in natspec comments and tries replacing them. + /// + /// We will either replace contract name or remove @inheritdoc tag completely to avoid + /// generating invalid source code. + fn update_inheritdocs(&self, top_level_names: &HashMap, updates: &mut Updates) { + for (path, ast) in &self.asts { + // Collect all exported symbols for this source unit + // @inheritdoc value is either one of those or qualified import path which we don't + // support + let exported_symbols = ast + .exported_symbols + .iter() + .filter_map( + |(name, ids)| { + if !ids.is_empty() { + Some((name.as_str(), ids[0])) + } else { + None + } + }, + ) + .collect::>(); + + // Collect all docs in all contracts + let docs = ast + .nodes + .iter() + .filter_map(|node| match node { + SourceUnitPart::ContractDefinition(d) => Some(d), + _ => None, + }) + .flat_map(|contract| { + contract.nodes.iter().filter_map(|node| match node { + ContractDefinitionPart::EventDefinition(event) => { + event.documentation.as_ref() + } + ContractDefinitionPart::ErrorDefinition(error) => { + error.documentation.as_ref() + } + ContractDefinitionPart::FunctionDefinition(func) => { + func.documentation.as_ref() + } + ContractDefinitionPart::VariableDeclaration(var) => { + var.documentation.as_ref() + } + _ => None, + }) + }); + + docs.for_each(|doc| { + let src_start = doc.src.start.unwrap(); + let src_end = src_start + doc.src.length.unwrap(); + + // Documentation node has `text` field, however, it does not contain + // slashes and we can't use if to find positions. + let content: &str = &self.sources.get(path).unwrap().content[src_start..src_end]; + let tag_len = "@inheritdoc".len(); + if let Some(tag_start) = content.find("@inheritdoc") { + if let Some(name_start) = content[tag_start + tag_len..] + .find(|c| c != ' ') + .map(|p| p + tag_start + tag_len) + { + let name_end = content[name_start..] + .find([' ', '\n', '*', '/']) + .map(|p| p + name_start) + .unwrap_or(content.len()); + + let name = &content[name_start..name_end]; + + if let Some(ast_id) = exported_symbols.get(name) { + let new_name = top_level_names.get(ast_id).unwrap(); + updates.entry(path.to_path_buf()).or_default().insert(( + src_start + name_start, + src_start + name_end, + new_name.to_string(), + )); + } else { + updates.entry(path.to_path_buf()).or_default().insert(( + src_start + tag_start, + src_start + name_end, + "".to_string(), + )); + } + } + } + }); + } + } + /// Processes all ASTs and collects all top-level definitions in the form of /// a mapping from name to (definition id, source location) fn collect_top_level_definitions(&self) -> HashMap<&String, HashSet<(usize, ItemLocation)>> { @@ -341,39 +475,49 @@ impl Flattener { ast.nodes .iter() .filter_map(|node| match node { - SourceUnitPart::ContractDefinition(contract) => { - Some((&contract.name, contract.id, &contract.src)) - } + SourceUnitPart::ContractDefinition(contract) => Some(( + &contract.name, + contract.id, + &contract.src, + &contract.name_location, + )), SourceUnitPart::EnumDefinition(enum_) => { - Some((&enum_.name, enum_.id, &enum_.src)) + Some((&enum_.name, enum_.id, &enum_.src, &enum_.name_location)) } SourceUnitPart::StructDefinition(struct_) => { - Some((&struct_.name, struct_.id, &struct_.src)) + Some((&struct_.name, struct_.id, &struct_.src, &struct_.name_location)) } - SourceUnitPart::FunctionDefinition(function) => { - Some((&function.name, function.id, &function.src)) + SourceUnitPart::FunctionDefinition(func) => { + Some((&func.name, func.id, &func.src, &func.name_location)) } - SourceUnitPart::VariableDeclaration(variable) => { - Some((&variable.name, variable.id, &variable.src)) + SourceUnitPart::VariableDeclaration(var) => { + Some((&var.name, var.id, &var.src, &var.name_location)) } - SourceUnitPart::UserDefinedValueTypeDefinition(value_type) => { - Some((&value_type.name, value_type.id, &value_type.src)) + SourceUnitPart::UserDefinedValueTypeDefinition(type_) => { + Some((&type_.name, type_.id, &type_.src, &type_.name_location)) } _ => None, }) - .map(|(name, id, src)| { - // Find location of name in source - let content: &str = &self.sources.get(path).unwrap().content; - let start = src.start.unwrap(); - let end = start + src.length.unwrap(); - - let name_start = content[start..end].find(name).unwrap(); - let name_end = name_start + name.len(); - - let loc = ItemLocation { - path: path.clone(), - start: start + name_start, - end: start + name_end, + .map(|(name, id, src, maybe_name_src)| { + let loc = match maybe_name_src { + Some(src) => { + ItemLocation::try_from_source_loc(src, path.clone()).unwrap() + } + None => { + // Find location of name in source + let content: &str = &self.sources.get(path).unwrap().content; + let start = src.start.unwrap(); + let end = start + src.length.unwrap(); + + let name_start = content[start..end].find(name).unwrap(); + let name_end = name_start + name.len(); + + ItemLocation { + path: path.clone(), + start: start + name_start, + end: start + name_end, + } + } }; (name, (id, loc)) diff --git a/tests/project.rs b/tests/project.rs index 6fe7dcb5..53b99af9 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -1028,9 +1028,9 @@ contract Foo { } } -contract Bar_1 is Foo {} - contract Bar_0 is Foo {} + +contract Bar_1 is Foo {} " ); } @@ -1334,6 +1334,330 @@ contract B is A {} }); } +#[test] +fn can_flatten_rename_inheritdocs() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "DuplicateA", + r#"pragma solidity ^0.8.10; +contract A {} +"#, + ) + .unwrap(); + + project + .add_source( + "A", + r#"pragma solidity ^0.8.10; +import {A as OtherName} from "./DuplicateA.sol"; + +contract A { + /// Documentation + function foo() public virtual {} +} +"#, + ) + .unwrap(); + + let target = project + .add_source( + "B", + r#"pragma solidity ^0.8.10; +import {A} from "./A.sol"; + +contract B is A { + /// @inheritdoc A + function foo() public override {} +}"#, + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +contract A_0 {} + +contract A_1 { + /// Documentation + function foo() public virtual {} +} + +contract B is A_1 { + /// @inheritdoc A_1 + function foo() public override {} +} +" + ); +} + +#[test] +fn can_flatten_rename_inheritdocs_alias() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "A", + r#"pragma solidity ^0.8.10; + +contract A { + /// Documentation + function foo() public virtual {} +} +"#, + ) + .unwrap(); + + let target = project + .add_source( + "B", + r#"pragma solidity ^0.8.10; +import {A as Alias} from "./A.sol"; + +contract B is Alias { + /// @inheritdoc Alias + function foo() public override {} +}"#, + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +contract A { + /// Documentation + function foo() public virtual {} +} + +contract B is A { + /// @inheritdoc A + function foo() public override {} +} +" + ); +} + +#[test] +fn can_flatten_rename_user_defined_functions() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "CustomUint", + r" +pragma solidity ^0.8.10; + +type CustomUint is uint256; + +function mul(CustomUint a, CustomUint b) pure returns(CustomUint) { + return CustomUint.wrap(CustomUint.unwrap(a) * CustomUint.unwrap(b)); +} + +using {mul} for CustomUint global;", + ) + .unwrap(); + + project + .add_source( + "CustomInt", + r"pragma solidity ^0.8.10; + +type CustomInt is int256; + +function mul(CustomInt a, CustomInt b) pure returns(CustomInt) { + return CustomInt.wrap(CustomInt.unwrap(a) * CustomInt.unwrap(b)); +} + +using {mul} for CustomInt global;", + ) + .unwrap(); + + let target = project + .add_source( + "Target", + r"pragma solidity ^0.8.10; + +import {CustomInt} from './CustomInt.sol'; +import {CustomUint} from './CustomUint.sol'; + +contract Foo { + function mul(CustomUint a, CustomUint b) public returns(CustomUint) { + return a.mul(b); + } + + function mul(CustomInt a, CustomInt b) public returns(CustomInt) { + return a.mul(b); + } +}", + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +type CustomInt is int256; + +function mul_0(CustomInt a, CustomInt b) pure returns(CustomInt) { + return CustomInt.wrap(CustomInt.unwrap(a) * CustomInt.unwrap(b)); +} + +using {mul_0} for CustomInt global; + +type CustomUint is uint256; + +function mul_1(CustomUint a, CustomUint b) pure returns(CustomUint) { + return CustomUint.wrap(CustomUint.unwrap(a) * CustomUint.unwrap(b)); +} + +using {mul_1} for CustomUint global; + +contract Foo { + function mul(CustomUint a, CustomUint b) public returns(CustomUint) { + return a.mul_1(b); + } + + function mul(CustomInt a, CustomInt b) public returns(CustomInt) { + return a.mul_0(b); + } +} +" + ); +} + +#[test] +fn can_flatten_rename_global_functions() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "func1", + r"pragma solidity ^0.8.10; + +function func() view {}", + ) + .unwrap(); + + project + .add_source( + "func2", + r"pragma solidity ^0.8.10; + +function func(uint256 x) view returns(uint256) { + return x + 1; +}", + ) + .unwrap(); + + let target = project + .add_source( + "Target", + r"pragma solidity ^0.8.10; + +import {func as func1} from './func1.sol'; +import {func as func2} from './func2.sol'; + +contract Foo { + constructor(uint256 x) { + func1(); + func2(x); + } +}", + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +function func_0() view {} + +function func_1(uint256 x) view returns(uint256) { + return x + 1; +} + +contract Foo { + constructor(uint256 x) { + func_0(); + func_1(x); + } +} +" + ); +} + +#[test] +fn can_flatten_rename_in_assembly() { + let project = TempProject::dapptools().unwrap(); + + project + .add_source( + "A", + r"pragma solidity ^0.8.10; + +uint256 constant a = 1;", + ) + .unwrap(); + + project + .add_source( + "B", + r"pragma solidity ^0.8.10; + +uint256 constant a = 2;", + ) + .unwrap(); + + let target = project + .add_source( + "Target", + r"pragma solidity ^0.8.10; + +import {a as a1} from './A.sol'; +import {a as a2} from './B.sol'; + +contract Foo { + function test() public returns(uint256 x) { + assembly { + x := mul(a1, a2) + } + } +}", + ) + .unwrap(); + + let result = + Flattener::new(project.project(), &project.compile().unwrap(), &target).unwrap().flatten(); + assert_eq!( + result, + r"pragma solidity ^0.8.10; + +uint256 constant a_0 = 1; + +uint256 constant a_1 = 2; + +contract Foo { + function test() public returns(uint256 x) { + assembly { + x := mul(a_0, a_1) + } + } +} +" + ); +} + #[test] fn can_compile_single_files() { let tmp = TempProject::dapptools().unwrap();