Skip to content

Commit

Permalink
🖼 Mock contract chaining behaviour (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawelpolak2 authored Jan 23, 2023
1 parent fb6863d commit 46b954e
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 45 deletions.
5 changes: 5 additions & 0 deletions .changeset/orange-deers-sit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ethereum-waffle/mock-contract": patch
---

Add mock contract chaining behaviour
49 changes: 49 additions & 0 deletions docs/source/mock-contract.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,55 @@ Mock contract will be used to mock exactly this call with values that are releva
});
});
Mocking multiple calls
----------------------

Mock contract allows to queue multiple mock calls to the same function. This can only be done if the function is not pure or view. That's because the mock call queue is stored on the blockchain and we need to modify it.

.. code-block:: ts
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
await mockContract.<nameOfMethod>() // returns <value1>
await mockContract.<nameOfMethod>() // returns <value2>
Just like with regular mock calls, the queue can be set up to revert or return a specified value. It can also be set up to return different values for different arguments.

.. code-block:: ts
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
await mockContract.mock.<nameOfMethod>.withArgs(<arguments1>).returns(<value3>);
await mockContract.<nameOfMethod>() // returns <value1>
await mockContract.<nameOfMethod>() // returns <value2>
await mockContract.<nameOfMethod>(<arguments1>) // returns <value3>
Keep in mind that the mocked revert must be at the end of the queue, because it prevents the contract from updating the queue.

.. code-block:: ts
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>).reverts();
await mockContract.<nameOfMethod>() // returns <value1>
await mockContract.<nameOfMethod>() // returns <value2>
await mockContract.<nameOfMethod>() // reverts
When the queue is empty, the mock contract will return the last value from the queue and each time the you set up a new queue, the old one is overwritten.

.. code-block:: ts
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
await mockContract.<nameOfMethod>() // returns <value1>
await mockContract.<nameOfMethod>() // returns <value2>
await mockContract.<nameOfMethod>() // returns <value2>
await mockContract.mock.<nameOfMethod>.returns(<value1>).returns(<value2>);
await mockContract.mock.<nameOfMethod>.returns(<value3>).returns(<value4>);
await mockContract.<nameOfMethod>() // returns <value3>
await mockContract.<nameOfMethod>() // returns <value4>
Mocking receive function
------------------------

Expand Down
2 changes: 1 addition & 1 deletion waffle-mock-contract/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"module": "dist/esm/src/index.ts",
"types": "dist/esm/src/index.d.ts",
"scripts": {
"test": "export NODE_ENV=test && mocha",
"test": "ts-node ./test/helpers/buildTestContracts.ts && export NODE_ENV=test && mocha",
"lint": "eslint '{src,test}/**/*.ts'",
"lint:fix": "eslint --fix '{src,test}/**/*.ts'",
"build": "rimraf ./dist && yarn build:sol && yarn build:esm && yarn build:cjs && ts-node ./test/helpers/buildTestContracts.ts",
Expand Down
98 changes: 83 additions & 15 deletions waffle-mock-contract/src/Doppelganger.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,101 @@
pragma solidity ^0.6.3;

contract Doppelganger {

// ============================== Linked list queues data structure explainer ==============================
// mockConfig contains multiple linked lists, one for each unique call
// mockConfig[<callData hash>] => root node of the linked list for this call
// mockConfig[<callData hash>].next => 'address' of the next node. It's always defined, even if it's the last node.
// When defining a new node .next is set to the hash of the 'address' of the last node
// mockConfig[mockConfig[<callData hash>].next] => next node (possibly undefined)
// tails[<callData hash>] => 'address' of the node 'one after' the last node (<last node>.next)
// in the linked list with root node at <callData hash>

struct MockCall {
bool initialized;
bytes32 next;
bool reverts;
string revertReason;
bytes returnValue;
}

mapping(bytes32 => MockCall) mockConfig;
mapping(bytes32 => bytes32) tails;
bool receiveReverts;
string receiveRevertReason;

fallback() external payable {
MockCall storage mockCall = __internal__getMockCall();
MockCall memory mockCall = __internal__getMockCall();
if (mockCall.reverts == true) {
__internal__mockRevert(mockCall.revertReason);
return;
}
__internal__mockReturn(mockCall.returnValue);
}

receive() payable external {
require(receiveReverts == false, receiveRevertReason);
}

function __clearQueue(bytes32 at) private {
tails[at] = at;
while(mockConfig[at].next != "") {
bytes32 next = mockConfig[at].next;
delete mockConfig[at];
at = next;
}
}

function __waffle__mockReverts(bytes memory data, string memory reason) public {
mockConfig[keccak256(data)] = MockCall({
initialized: true,
function __waffle__queueRevert(bytes memory data, string memory reason) public {
// get the root node of the linked list for this call
bytes32 root = keccak256(data);

// get the 'address' of the node 'one after' the last node
// this is where the new node will be inserted
bytes32 tail = tails[root];
if(tail == "") tail = keccak256(data);

// new tail is set to the hash of the current tail
tails[root] = keccak256(abi.encodePacked(tail));

// initialize the new node
mockConfig[tail] = MockCall({
next: tails[root],
reverts: true,
revertReason: reason,
returnValue: ""
});
}

function __waffle__mockReturns(bytes memory data, bytes memory value) public {
mockConfig[keccak256(data)] = MockCall({
initialized: true,
function __waffle__mockReverts(bytes memory data, string memory reason) public {
__clearQueue(keccak256(data));
__waffle__queueRevert(data, reason);
}

function __waffle__queueReturn(bytes memory data, bytes memory value) public {
// get the root node of the linked list for this call
bytes32 root = keccak256(data);

// get the 'address' of the node 'one after' the last node
// this is where the new node will be inserted
bytes32 tail = tails[root];
if(tail == "") tail = keccak256(data);

// new tail is set to the hash of the current tail
tails[root] = keccak256(abi.encodePacked(tail));

// initialize the new node
mockConfig[tail] = MockCall({
next: tails[root],
reverts: false,
revertReason: "",
returnValue: value
});
}

function __waffle__mockReturns(bytes memory data, bytes memory value) public {
__clearQueue(keccak256(data));
__waffle__queueReturn(data, value);
}

function __waffle__receiveReverts(string memory reason) public {
receiveReverts = true;
receiveRevertReason = reason;
Expand All @@ -61,15 +114,30 @@ contract Doppelganger {
return returnValue;
}

function __internal__getMockCall() view private returns (MockCall storage mockCall) {
mockCall = mockConfig[keccak256(msg.data)];
if (mockCall.initialized == true) {
function __internal__getMockCall() private returns (MockCall memory mockCall) {
// get the root node of the queue for this call
bytes32 root = keccak256(msg.data);
mockCall = mockConfig[root];
if (mockCall.next != "") {
// Mock method with specified arguments

// If there is a next mock call, set it as the current mock call
// We check if the next mock call is defined by checking if it has a 'next' variable defined
// (next value is always defined, even if it's the last mock call)
if(mockConfig[mockCall.next].next != ""){ // basically if it's not the last mock call
mockConfig[root] = mockConfig[mockCall.next];
delete mockConfig[mockCall.next];
}
return mockCall;
}
mockCall = mockConfig[keccak256(abi.encodePacked(msg.sig))];
if (mockCall.initialized == true) {
root = keccak256(abi.encodePacked(msg.sig));
mockCall = mockConfig[root];
if (mockCall.next != "") {
// Mock method with any arguments
if(mockConfig[mockCall.next].next != ""){ // same as above
mockConfig[root] = mockConfig[mockCall.next];
delete mockConfig[mockCall.next];
}
return mockCall;
}
revert("Mock on the method is not initialized");
Expand Down
142 changes: 115 additions & 27 deletions waffle-mock-contract/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,126 @@ import type {JsonRpcProvider} from '@ethersproject/providers';

type ABI = string | Array<utils.Fragment | JsonFragment | string>

export type Stub = ReturnType<typeof stub>;

type DeployOptions = {
address: string;
override?: boolean;
interface StubInterface {
returns(...args: any): StubInterface;
reverts(): StubInterface;
revertsWithReason(reason: string): StubInterface;
withArgs(...args: any[]): StubInterface;
}

export interface MockContract extends Contract {
mock: {
[key: string]: Stub;
[key: string]: StubInterface;
};
call (contract: Contract, functionName: string, ...params: any[]): Promise<any>;
staticcall (contract: Contract, functionName: string, ...params: any[]): Promise<any>;
}

class Stub implements StubInterface {
callData: string;
stubCalls: Array<() => Promise<any>> = [];
revertSet = false;
argsSet = false;

constructor(
private mockContract: Contract,
private encoder: utils.AbiCoder,
private func: utils.FunctionFragment
) {
this.callData = mockContract.interface.getSighash(func);
}

private err(reason: string): never {
this.stubCalls = [];
this.revertSet = false;
this.argsSet = false;
throw new Error(reason);
}

returns(...args: any) {
if (this.revertSet) this.err('Revert must be the last call');
if (!this.func.outputs) this.err('Cannot mock return values from a void function');
const encoded = this.encoder.encode(this.func.outputs, args);

// if there no calls then this is the first call and we need to use mockReturns to override the queue
if (this.stubCalls.length === 0) {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__mockReturns(this.callData, encoded);
});
} else {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__queueReturn(this.callData, encoded);
});
}
return this;
}

reverts() {
if (this.revertSet) this.err('Revert must be the last call');

// if there no calls then this is the first call and we need to use mockReturns to override the queue
if (this.stubCalls.length === 0) {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__mockReverts(this.callData, 'Mock revert');
});
} else {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__queueRevert(this.callData, 'Mock revert');
});
}
this.revertSet = true;
return this;
}

revertsWithReason(reason: string) {
if (this.revertSet) this.err('Revert must be the last call');

// if there no calls then this is the first call and we need to use mockReturns to override the queue
if (this.stubCalls.length === 0) {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__mockReverts(this.callData, reason);
});
} else {
this.stubCalls.push(async () => {
await this.mockContract.__waffle__queueRevert(this.callData, reason);
});
}
this.revertSet = true;
return this;
}

withArgs(...params: any[]) {
if (this.argsSet) this.err('withArgs can be called only once');
this.callData = this.mockContract.interface.encodeFunctionData(this.func, params);
this.argsSet = true;
return this;
}

async then(resolve: () => void, reject: (e: any) => void) {
for (let i = 0; i < this.stubCalls.length; i++) {
try {
await this.stubCalls[i]();
} catch (e) {
this.stubCalls = [];
this.argsSet = false;
this.revertSet = false;
reject(e);
return;
}
}

this.stubCalls = [];
this.argsSet = false;
this.revertSet = false;
resolve();
}
}

type DeployOptions = {
address: string;
override?: boolean;
}

async function deploy(signer: Signer, options?: DeployOptions) {
if (options) {
const {address, override} = options;
Expand Down Expand Up @@ -50,29 +155,12 @@ async function deploy(signer: Signer, options?: DeployOptions) {
return factory.deploy();
}

function stub(mockContract: Contract, encoder: utils.AbiCoder, func: utils.FunctionFragment, params?: any[]) {
const callData = params
? mockContract.interface.encodeFunctionData(func, params)
: mockContract.interface.getSighash(func);

return {
returns: async (...args: any) => {
if (!func.outputs) return;
const encoded = encoder.encode(func.outputs, args);
await mockContract.__waffle__mockReturns(callData, encoded);
},
reverts: async () => mockContract.__waffle__mockReverts(callData, 'Mock revert'),
revertsWithReason: async (reason: string) => mockContract.__waffle__mockReverts(callData, reason),
withArgs: (...args: any[]) => stub(mockContract, encoder, func, args)
};
}

function createMock(abi: ABI, mockContractInstance: Contract) {
const {functions} = new utils.Interface(abi);
const encoder = new utils.AbiCoder();

const mockedAbi = Object.values(functions).reduce((acc, func) => {
const stubbed = stub(mockContractInstance, encoder, func);
const stubbed = new Stub(mockContractInstance as MockContract, encoder, func);
return {
...acc,
[func.name]: stubbed,
Expand All @@ -81,10 +169,10 @@ function createMock(abi: ABI, mockContractInstance: Contract) {
}, {} as MockContract['mock']);

mockedAbi.receive = {
returns: async () => { throw new Error('Receive function return is not implemented.'); },
returns: () => { throw new Error('Receive function return is not implemented.'); },
withArgs: () => { throw new Error('Receive function return is not implemented.'); },
reverts: async () => mockContractInstance.__waffle__receiveReverts('Mock Revert'),
revertsWithReason: async (reason: string) => mockContractInstance.__waffle__receiveReverts(reason)
reverts: () => mockContractInstance.__waffle__receiveReverts('Mock Revert'),
revertsWithReason: (reason: string) => mockContractInstance.__waffle__receiveReverts(reason)
};

return mockedAbi;
Expand Down
Loading

0 comments on commit 46b954e

Please sign in to comment.