Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize votes lookups for recent checkpoints #3673

Merged
merged 13 commits into from
Sep 4, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* `Address`: optimize `functionCall` functions by checking contract size only if there is no returned data. ([#3469](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3469))
* `GovernorCompatibilityBravo`: remove unused `using` statements. ([#3506](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3506))
* `ERC20`: optimize `_transfer`, `_mint` and `_burn` by using `unchecked` arithmetic when possible. ([#3513](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3513))
* `ERC20Votes`, `ERC721Votes`: optimize `getPastVotes` for looking up recent checkpoints. ([#3673](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3673))
* `ERC20FlashMint`: add an internal `_flashFee` function for overriding. ([#3551](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3551))
* `ERC4626`: use the same `decimals()` as the underlying asset by default (if available). ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))
* `ERC4626`: add internal `_initialConvertToShares` and `_initialConvertToAssets` functions to customize empty vaults behavior. ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))
Expand Down
4 changes: 2 additions & 2 deletions contracts/governance/utils/Votes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
* - `blockNumber` must have been already mined
*/
function getPastVotes(address account, uint256 blockNumber) public view virtual override returns (uint256) {
return _delegateCheckpoints[account].getAtBlock(blockNumber);
return _delegateCheckpoints[account].getAtProbablyRecentBlock(blockNumber);
}

/**
Expand All @@ -72,7 +72,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
*/
function getPastTotalSupply(uint256 blockNumber) public view virtual override returns (uint256) {
require(blockNumber < block.number, "Votes: block not yet mined");
return _totalCheckpoints.getAtBlock(blockNumber);
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}

/**
Expand Down
12 changes: 2 additions & 10 deletions contracts/mocks/CheckpointsMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ contract CheckpointsMock {
return _totalCheckpoints.getAtBlock(blockNumber);
}

function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtRecentBlock(blockNumber);
function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}

function length() public view returns (uint256) {
Expand Down Expand Up @@ -52,10 +52,6 @@ contract Checkpoints224Mock {
return _totalCheckpoints.upperLookup(key);
}

function upperLookupRecent(uint32 key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}

function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand All @@ -82,10 +78,6 @@ contract Checkpoints160Mock {
return _totalCheckpoints.upperLookup(key);
}

function upperLookupRecent(uint96 key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}

function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand Down
35 changes: 29 additions & 6 deletions contracts/token/ERC20/extensions/ERC20Votes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
function _checkpointsLookup(Checkpoint[] storage ckpts, uint256 blockNumber) private view returns (uint256) {
// We run a binary search to look for the earliest checkpoint taken after `blockNumber`.
//
// Initially we check if the block is recent to narrow the search range.
// During the loop, the index of the wanted checkpoint remains in the range [low-1, high).
// With each iteration, either `low` or `high` is moved towards the middle of the range to maintain the invariant.
// - If the middle checkpoint is after `blockNumber`, we look in [low, mid)
Expand All @@ -106,18 +107,30 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
// Note that if the latest checkpoint available is exactly for `blockNumber`, we end up with an index that is
// past the end of the array, so we technically don't find a checkpoint after `blockNumber`, but it works out
// the same.
uint256 high = ckpts.length;
uint256 length = ckpts.length;

uint256 low = 0;
uint256 high = length;

if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

while (low < high) {
uint256 mid = Math.average(low, high);
if (ckpts[mid].fromBlock > blockNumber) {
if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

return high == 0 ? 0 : ckpts[high - 1].votes;
return high == 0 ? 0 : _unsafeAccess(ckpts, high - 1).votes;
}

/**
Expand Down Expand Up @@ -229,11 +242,14 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
uint256 delta
) private returns (uint256 oldWeight, uint256 newWeight) {
uint256 pos = ckpts.length;
oldWeight = pos == 0 ? 0 : ckpts[pos - 1].votes;

Checkpoint memory oldCkpt = pos == 0 ? Checkpoint(0, 0) : _unsafeAccess(ckpts, pos - 1);

oldWeight = oldCkpt.votes;
newWeight = op(oldWeight, delta);

if (pos > 0 && ckpts[pos - 1].fromBlock == block.number) {
ckpts[pos - 1].votes = SafeCast.toUint224(newWeight);
if (pos > 0 && oldCkpt.fromBlock == block.number) {
_unsafeAccess(ckpts, pos - 1).votes = SafeCast.toUint224(newWeight);
} else {
ckpts.push(Checkpoint({fromBlock: SafeCast.toUint32(block.number), votes: SafeCast.toUint224(newWeight)}));
}
Expand All @@ -246,4 +262,11 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
function _subtract(uint256 a, uint256 b) private pure returns (uint256) {
return a - b;
}

function _unsafeAccess(Checkpoint[] storage ckpts, uint256 pos) private view returns (Checkpoint storage result) {
assembly {
mstore(0, ckpts.slot)
result.slot := add(keccak256(0, 0x20), pos)
}
}
}
60 changes: 14 additions & 46 deletions contracts/utils/Checkpoints.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,28 @@ library Checkpoints {

/**
* @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
* before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
* key is known to be recent.
* before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
* checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
* checkpoints.
*/
function getAtRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
function getAtProbablyRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
require(blockNumber < block.number, "Checkpoints: block not yet mined");
uint32 key = SafeCast.toUint32(blockNumber);

uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._blockNumber > key) {
offset <<= 1;
uint256 low = 0;
uint256 high = length;

if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (key < _unsafeAccess(self._checkpoints, mid)._blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

uint256 low = offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
Expand Down Expand Up @@ -225,25 +231,6 @@ library Checkpoints {
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(Trace224 storage self, uint32 key) internal view returns (uint224) {
uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
offset <<= 1;
}

uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one.
Expand Down Expand Up @@ -380,25 +367,6 @@ library Checkpoints {
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(Trace160 storage self, uint96 key) internal view returns (uint160) {
uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
offset <<= 1;
}

uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one.
Expand Down
41 changes: 14 additions & 27 deletions scripts/generate/templates/Checkpoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,6 @@ function upperLookup(${opts.historyTypeName} storage self, ${opts.keyTypeName} k
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, 0, length);
return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
}

/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(${opts.historyTypeName} storage self, ${opts.keyTypeName} key) internal view returns (${opts.valueTypeName}) {
uint256 length = self.${opts.checkpointFieldName}.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
offset <<= 1;
}

uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
}
`;

const legacyOperations = opts => `\
Expand All @@ -115,22 +96,28 @@ function getAtBlock(${opts.historyTypeName} storage self, uint256 blockNumber) i

/**
* @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
* before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
* key is known to be recent.
* before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
* checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
* checkpoints.
*/
function getAtRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
function getAtProbablyRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
require(blockNumber < block.number, "Checkpoints: block not yet mined");
uint32 key = SafeCast.toUint32(blockNumber);

uint256 length = self.${opts.checkpointFieldName}.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
offset <<= 1;
uint256 low = 0;
uint256 high = length;

if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (key < _unsafeAccess(self.${opts.checkpointFieldName}, mid)._blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

uint256 low = offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
Expand Down
8 changes: 2 additions & 6 deletions scripts/generate/templates/CheckpointsMock.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ contract CheckpointsMock {
return _totalCheckpoints.getAtBlock(blockNumber);
}

function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtRecentBlock(blockNumber);
function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}

function length() public view returns (uint256) {
Expand Down Expand Up @@ -58,10 +58,6 @@ contract Checkpoints${length}Mock {
return _totalCheckpoints.upperLookup(key);
}

function upperLookupRecent(uint${256 - length} key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}

function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand Down
10 changes: 5 additions & 5 deletions test/utils/Checkpoints.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ contract('Checkpoints', function (accounts) {

it('returns zero as past value', async function () {
await time.advanceBlock();
expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtRecentBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1))
.to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtProbablyRecentBlock(await web3.eth.getBlockNumber() - 1))
.to.be.bignumber.equal('0');
});
});

Expand All @@ -41,7 +43,7 @@ contract('Checkpoints', function (accounts) {
expect(await this.checkpoint.latest()).to.be.bignumber.equal('3');
});

for (const fn of [ 'getAtBlock(uint256)', 'getAtRecentBlock(uint256)' ]) {
for (const fn of [ 'getAtBlock(uint256)', 'getAtProbablyRecentBlock(uint256)' ]) {
describe(`lookup: ${fn}`, function () {
it('returns past values', async function () {
expect(await this.checkpoint.methods[fn](this.tx1.receipt.blockNumber - 1)).to.be.bignumber.equal('0');
Expand Down Expand Up @@ -95,7 +97,6 @@ contract('Checkpoints', function (accounts) {
it('lookup returns 0', async function () {
expect(await this.contract.lowerLookup(0)).to.be.bignumber.equal('0');
expect(await this.contract.upperLookup(0)).to.be.bignumber.equal('0');
expect(await this.contract.upperLookupRecent(0)).to.be.bignumber.equal('0');
});
});

Expand Down Expand Up @@ -149,7 +150,6 @@ contract('Checkpoints', function (accounts) {
const value = last(this.checkpoints.filter(x => i >= x.key))?.value || '0';

expect(await this.contract.upperLookup(i)).to.be.bignumber.equal(value);
expect(await this.contract.upperLookupRecent(i)).to.be.bignumber.equal(value);
}
});
});
Expand Down