Skip to content
This repository has been archived by the owner on Jan 24, 2022. It is now read-only.

Handle initialize function lookup on dependencies contracts #234

Merged
merged 1 commit into from
Oct 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pragma solidity ^0.4.24;

contract GreeterBase {
event Greeting(string greeting);

uint256 public value;

function initialize(uint256 _value) public {
value = _value;
}

function clashingInitialize(uint256 _value) public {
value = _value;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
pragma solidity ^0.4.24;

contract GreeterImpl {
event Greeting(string greeting);
import "./GreeterBase.sol";

// This contract and its parent are used in CLI scripts/create.test.js to check initialization
// of a contract loaded from a dependency. Do not import this file or its parent from any mock
// contract in CLI, since one of the goals of the test is to check processing a contract that
// has not been locally compiled. Also, make sure to change the absolute path in the build artifacts
// so they point to a path that does not exist in your machine, since that's the typical scenario
// for contracts loaded from libs.

contract GreeterImpl is GreeterBase {
function clashingInitialize(string _value) public {
value = bytes(_value).length;
}

function greet(string who) public {
emit Greeting(greeting(who));
Expand Down
18 changes: 17 additions & 1 deletion packages/cli/test/scripts/create.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ contract('create script', function([_, owner]) {
const version = '1.1.0';
const txParams = { from: owner };

const assertProxy = async function(networkFile, alias, { version, say, implementation, packageName }) {
const assertProxy = async function(networkFile, alias, { version, say, implementation, packageName, value }) {
const proxyInfo = networkFile.getProxies({ contract: alias })[0]
proxyInfo.contract.should.eq(alias)
proxyInfo.address.should.be.nonzeroAddress;
Expand All @@ -41,6 +41,12 @@ contract('create script', function([_, owner]) {
said.should.eq(say);
}

if (value) {
const proxy = await ImplV1.at(proxyInfo.address);
const actualValue = await proxy.value();
actualValue.toNumber().should.eq(value);
}

if (implementation) {
proxyInfo.implementation.should.eq(implementation);
}
Expand Down Expand Up @@ -167,6 +173,16 @@ contract('create script', function([_, owner]) {
await createProxy({ packageName: 'mock-stdlib-undeployed', contractAlias: 'Greeter', network, txParams, networkFile: this.networkFile });
await assertProxy(this.networkFile, 'Greeter', { version, packageName: 'mock-stdlib-undeployed' });
});

it('should initialize a proxy from a dependency', async function () {
await createProxy({ packageName: 'mock-stdlib-undeployed', contractAlias: 'Greeter', network, txParams, networkFile: this.networkFile, initMethod: 'initialize', initArgs: ["42"] });
await assertProxy(this.networkFile, 'Greeter', { version, packageName: 'mock-stdlib-undeployed', value: 42 });
});

it('should initialize a proxy from a dependency using explicit function', async function () {
await createProxy({ packageName: 'mock-stdlib-undeployed', contractAlias: 'Greeter', network, txParams, networkFile: this.networkFile, initMethod: 'clashingInitialize(uint256)', initArgs: ["42"] });
await assertProxy(this.networkFile, 'Greeter', { version, packageName: 'mock-stdlib-undeployed', value: 42 });
});
});

describe('with unlinked dependency', function () {
Expand Down
3 changes: 2 additions & 1 deletion packages/lib/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"description": "ZeppelinOS library",
"scripts": {
"test": "scripts/test.sh",
"prepare": "rm -rf build/contracts && truffle compile && rm -rf lib && babel src --out-dir lib"
"prepare": "rm -rf build/contracts && truffle compile && rm -rf lib && babel src --out-dir lib",
"watch": "babel src -w -d lib"
},
"repository": {
"type": "git",
Expand Down
72 changes: 51 additions & 21 deletions packages/lib/src/utils/ABIs.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,61 @@ import encodeCall from '../helpers/encodeCall'
import ContractAST from './ContractAST';

export function buildCallData(contractClass, methodName, args) {
const method = getFunctionFromMostDerivedContract(contractClass, methodName, args)
const method = getABIFunction(contractClass, methodName, args)
const argTypes = method.inputs.map(input => input.type)
const callData = encodeCall(methodName, argTypes, args)
const callData = encodeCall(method.name, argTypes, args)
return { method, callData }
}

export function getFunctionFromMostDerivedContract(contractClass, methodName, args) {
const methodNode = getFunctionNodeFromMostDerivedContract(contractClass, methodName, args);
const inputs = methodNode.parameters.parameters.map(parameter => {
const typeString = parameter.typeDescriptions.typeString
const type = typeString.includes('contract') ? 'address' : typeString
return { name: parameter.name , type }
})

const targetMethod = { name: methodName, inputs }
const matchArgsTypes = fn => fn.inputs.every((input, index) => targetMethod.inputs[index] && targetMethod.inputs[index].type === input.type);
export function getABIFunction(contractClass, methodName, args) {
const targetMethod = tryGetTargetFunction(contractClass, methodName, args);
if (targetMethod) methodName = targetMethod.name;

const matchArgsTypes = fn => targetMethod && fn.inputs.every((input, index) => targetMethod.inputs[index] && targetMethod.inputs[index].type === input.type);
const matchNameAndArgsLength = fn => fn.name === methodName && fn.inputs.length === args.length;

let abiMethods = contractClass.abi.filter(fn => matchNameAndArgsLength(fn) && matchArgsTypes(fn));
if (abiMethods.length === 0) abiMethods = contractClass.abi.filter(fn => matchNameAndArgsLength(fn));

const abiMethod =
contractClass.abi.find(fn => matchNameAndArgsLength(fn) && matchArgsTypes(fn)) ||
contractClass.abi.find(fn => matchNameAndArgsLength(fn));

if (!abiMethod) throw Error(`Could not find method ${methodName} with ${args.length} arguments in contract ${contractClass.contractName}`)
return abiMethod;
switch (abiMethods.length) {
case 0: throw Error(`Could not find method ${methodName} with ${args.length} arguments in contract ${contractClass.contractName}`);
case 1: return abiMethods[0];
default: throw Error(`Found more than one match for function ${methodName} with ${args.length} arguments in contract ${contractClass.contractName}`);
}
}

function tryGetTargetFunction(contractClass, methodName, args) {
// Match foo(uint256,string) as method name, and look for that in the ABI
const match = methodName.match(/^\s*(.+)\((.*)\)\s*$/)
if (match) {
const name = match[1];
const inputs = match[2].split(',').map(arg => ({ type: arg }));
return { name, inputs };
}

// Otherwise, look for the most derived contract
const methodNode = tryGetFunctionNodeFromMostDerivedContract(contractClass, methodName, args);
if (methodNode) {
const inputs = methodNode.parameters.parameters.map(parameter => {
const typeString = parameter.typeDescriptions.typeString
const type = typeString.includes('contract') ? 'address' : typeString
return { name: parameter.name, type }
})
return { name: methodNode.name, inputs }
}
}

function getFunctionNodeFromMostDerivedContract(contractClass, methodName, args) {
const ast = new ContractAST(contractClass, null, { nodesFilter: ['ContractDefinition', 'FunctionDefinition'] });
function tryGetFunctionNodeFromMostDerivedContract(contractClass, methodName, args) {
const linearizedBaseContracts = tryGetLinearizedBaseContracts(contractClass);
if (!linearizedBaseContracts) return null;

const nodeMatches = (node) => (
node.nodeType === 'FunctionDefinition' &&
node.name === methodName &&
node.parameters.parameters.length === args.length
);

for (const contract of ast.getLinearizedBaseContracts(true)) {
for (const contract of linearizedBaseContracts) {
const funs = contract.nodes.filter(nodeMatches);
switch (funs.length) {
case 0: continue;
Expand All @@ -47,6 +67,16 @@ function getFunctionNodeFromMostDerivedContract(contractClass, methodName, args)
throw Error(`Could not find method ${methodName} with ${args.length} arguments in contract ${contractClass.contractName}`)
}

function tryGetLinearizedBaseContracts(contractClass) {
try {
const ast = new ContractAST(contractClass, null, { nodesFilter: ['ContractDefinition', 'FunctionDefinition'] });
return ast.getLinearizedBaseContracts(true);
} catch (err) {
// This lookup may fail on contracts loaded from libraries, so we just silently fail and fall back to other methods
return null;
}
}

export function callDescription(method, args) {
const argsDescriptions = method.inputs.map((input, index) => ` - ${input.name} (${input.type}): ${JSON.stringify(args[index])}`)
return `${method.name} with: \n${argsDescriptions.join('\n')}`
Expand Down
1 change: 0 additions & 1 deletion packages/lib/src/utils/ContractAST.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ export default class ContractAST {
.forEach(importPath => {
if (this.imports.has(importPath)) return;
this.imports.add(importPath);
console.log(`Adding ${importPath}`)
this.artifacts.getArtifactsFromSourcePath(importPath).forEach(importedArtifact => {
this._collectNodes(importedArtifact.ast)
this._collectImports(importedArtifact.ast)
Expand Down
12 changes: 8 additions & 4 deletions packages/lib/test/src/utils/ABIs.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

require('../../setup')

import { getFunctionFromMostDerivedContract as getFunction } from '../../../src/utils/ABIs'
import { getABIFunction as getFunction } from '../../../src/utils/ABIs'
import Contracts from '../../../src/utils/Contracts'

const should = require('chai').should()

describe('ABIs', function() {
describe('getFunctionFromMostDerivedContract', function () {
describe('getABIFunction', function () {
it('matches number of arguments', async function () {
testGetFunction('GetFunctionBase', [1,2], ['uint256', 'uint256']);
});
Expand All @@ -25,6 +25,10 @@ describe('ABIs', function() {
testGetFunction('GetFunctionOtherGrandchild', ['1'], ['bytes']);
});

it('chooses function based on explicit types', async function () {
testGetFunction('GetFunctionGrandchild', ['1'], ['uint256'], 'initialize(uint256)');
});

it('throws if not found', async function () {
expect(() => testGetFunction('GetFunctionBase', [1,2,3])).to.throw("Could not find method initialize with 3 arguments in contract GetFunctionBase")
});
Expand All @@ -35,9 +39,9 @@ describe('ABIs', function() {
});
})

function testGetFunction(contractName, args, expectedTypes) {
function testGetFunction(contractName, args, expectedTypes, funName = 'initialize') {
const contractClass = Contracts.getFromLocal(contractName);
const method = getFunction(contractClass, 'initialize', args);
const method = getFunction(contractClass, funName, args);
should.exist(method)
method.inputs.map(m => m.type).should.be.deep.eq(expectedTypes);
}