From 1cd6967167928a48deb796c2c8f1a2d820a70d31 Mon Sep 17 00:00:00 2001 From: David Laprade Date: Fri, 2 Dec 2022 11:19:45 -0500 Subject: [PATCH] Use OZ checkpointing library (#20) * Bump OZ to get newer checkpointing code * Take pass at using OZ checkpointing * Fix OZ checkpointing usage * Fix comments * scopelint fmt * Fix OZ dep loading + pin to latest official release * Remove underscores from private storage vars * Use Checkpoints for Checkpoints.History * Fix unrelated failing test --- .gitmodules | 7 +- foundry.toml | 1 + lib/openzeppelin-contracts | 2 +- src/ATokenNaive.sol | 130 +++++++++---------------------------- test/FractionalPool.t.sol | 15 +++-- 5 files changed, 46 insertions(+), 109 deletions(-) diff --git a/.gitmodules b/.gitmodules index 211cb53..608aa52 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,13 @@ [submodule "lib/solmate"] path = lib/solmate url = https://github.com/rari-capital/solmate -[submodule "lib/openzeppelin-contracts"] - path = lib/openzeppelin-contracts - url = https://github.com/openzeppelin/openzeppelin-contracts [submodule "lib/aave-v3-core"] path = lib/aave-v3-core url = https://github.com/aave/aave-v3-core [submodule "lib/forge-std"] path = lib/forge-std url = https://github.com/foundry-rs/forge-std +[submodule "lib/openzeppelin-contracts"] + path = lib/openzeppelin-contracts + url = https://github.com/openzeppelin/openzeppelin-contracts + branch = v4.8.0 diff --git a/foundry.toml b/foundry.toml index 5cf310f..da53074 100644 --- a/foundry.toml +++ b/foundry.toml @@ -4,6 +4,7 @@ # auto-detect solc versions. optimizer = true optimizer_runs = 10_000_000 + remappings = ["openzeppelin-contracts/=lib/openzeppelin-contracts"] verbosity = 3 [profile.ci] diff --git a/lib/openzeppelin-contracts b/lib/openzeppelin-contracts index 65b4572..49c0e43 160000 --- a/lib/openzeppelin-contracts +++ b/lib/openzeppelin-contracts @@ -1 +1 @@ -Subproject commit 65b45726b34dafe8fc3ef78c3d4b7b3f404f61ad +Subproject commit 49c0e4370d0cc50ea6090709e3835a3091e33ee2 diff --git a/src/ATokenNaive.sol b/src/ATokenNaive.sol index a9dc754..06af3bb 100644 --- a/src/ATokenNaive.sol +++ b/src/ATokenNaive.sol @@ -10,6 +10,7 @@ import {IPool} from "aave-v3-core/contracts/interfaces/IPool.sol"; import {WadRayMath} from "aave-v3-core/contracts/protocol/libraries/math/WadRayMath.sol"; import {SafeCast} from "openzeppelin-contracts/contracts/utils/math/SafeCast.sol"; import {Math} from "openzeppelin-contracts/contracts/utils/math/Math.sol"; +import {Checkpoints} from "openzeppelin-contracts/contracts/utils/Checkpoints.sol"; interface IFractionalGovernor { function token() external returns (address); @@ -34,6 +35,7 @@ contract ATokenNaive is AToken { using WadRayMath for uint256; using SafeCast for uint256; using GPv2SafeERC20 for IERC20; + using Checkpoints for Checkpoints.History; /// @notice The voting options corresponding to those used in the Governor. enum VoteType { @@ -56,7 +58,7 @@ contract ATokenNaive is AToken { uint32 public immutable CAST_VOTE_WINDOW; /// @notice Map proposalId to an address to whether they have voted on this proposal. - mapping(uint256 => mapping(address => bool)) private _proposalVotersHasVoted; + mapping(uint256 => mapping(address => bool)) private proposalVotersHasVoted; /// @notice Map proposalId to whether or not this contract has cast votes on it. mapping(uint256 => bool) public hasCastVotesOnProposal; @@ -69,6 +71,12 @@ contract ATokenNaive is AToken { /// GovernorCountingFractional. IFractionalGovernor public immutable governor; + /// @notice Mapping from address to deposit checkpoint history. + mapping(address => Checkpoints.History) private depositCheckpoints; + + /// @notice History of total underlying asset balance. + Checkpoints.History private totalDepositCheckpoints; + /// @dev Constructor. /// @param _pool The address of the Pool contract /// @param _governor The address of the flex-voting-compatible governance contract. @@ -101,8 +109,6 @@ contract ATokenNaive is AToken { _lastVotingBlock = governor.proposalDeadline(proposalId) - CAST_VOTE_WINDOW; } - /// TODO how to handle onBehalfOf? - /// TODO should this revert if the vote has been cast? /// @notice Allow a depositor to express their voting preference for a given /// proposal. Their preference is recorded internally but not moved to the /// Governor until `castVote` is called. We deliberately do NOT revert if the @@ -114,8 +120,8 @@ contract ATokenNaive is AToken { uint256 weight = getPastDeposits(msg.sender, governor.proposalSnapshot(proposalId)); require(weight > 0, "no weight"); - require(!_proposalVotersHasVoted[proposalId][msg.sender], "already voted"); - _proposalVotersHasVoted[proposalId][msg.sender] = true; + require(!proposalVotersHasVoted[proposalId][msg.sender], "already voted"); + proposalVotersHasVoted[proposalId][msg.sender] = true; if (support == uint8(VoteType.Against)) { proposalVotes[proposalId].againstVotes += SafeCast.toUint128(weight); @@ -198,15 +204,22 @@ contract ATokenNaive is AToken { uint256 amount, uint256 index ) internal returns (bool) { - bool _returnVar = _mintScaled(caller, onBehalfOf, amount, index); - // We increment by `amount` instead of any computed/rebased amounts because // `amount` is what actually gets transferred of the underlying asset. We - // need our checkpoints to still match up with the underlying asset balance. - _writeCheckpoint(_checkpoints[onBehalfOf], _additionFn, amount); - _writeCheckpoint(_totalDepositCheckpoints, _additionFn, amount); + // need our checkpoints to still match up with underlying asset transactions. + Checkpoints.History storage _depositHistory = depositCheckpoints[onBehalfOf]; + _depositHistory.push(_depositHistory.latest() + amount); + totalDepositCheckpoints.push(totalDepositCheckpoints.latest() + amount); + + return _mintScaled(caller, onBehalfOf, amount, index); + } + + function getPastDeposits(address _voter, uint256 _blockNumber) public returns (uint256) { + return depositCheckpoints[_voter].getAtBlock(_blockNumber); + } - return _returnVar; + function getPastTotalDeposits(uint256 _blockNumber) public returns (uint256) { + return totalDepositCheckpoints.getAtBlock(_blockNumber); } // forgefmt: disable-start @@ -252,11 +265,12 @@ contract ATokenNaive is AToken { // // We decrement by `amount` instead of any computed/rebased amounts because // `amount` is what actually gets transferred of the underlying asset. We - // need our checkpoints to still match up with the underlying asset balance. - _writeCheckpoint(_checkpoints[from], _subtractionFn, amount); - _writeCheckpoint(_totalDepositCheckpoints, _subtractionFn, amount); - // End modifications. + // need our checkpoints to still match up with underlying asset transactions. + Checkpoints.History storage _depositHistory = depositCheckpoints[from]; + _depositHistory.push(_depositHistory.latest() - amount); + totalDepositCheckpoints.push(totalDepositCheckpoints.latest() - amount); + // End modifications. _burnScaled(from, receiverOfUnderlying, amount, index); if (receiverOfUnderlying != address(this)) { IERC20(_underlyingAsset).safeTransfer(receiverOfUnderlying, amount); @@ -265,91 +279,5 @@ contract ATokenNaive is AToken { //=========================================================================== // END: Aave overrides //=========================================================================== - - //=========================================================================== - // BEGIN: Checkpointing code. - //=========================================================================== - // This was been copied from OZ's ERC20Votes checkpointing system with minor - // revisions: - // * Replace "Vote" with "Deposit", as deposits are what we need to track - // * Make some variable names longer for readability - // * Break lines at 80-characters - struct Checkpoint { - uint32 fromBlock; - uint224 deposits; - } - mapping(address => Checkpoint[]) private _checkpoints; - Checkpoint[] private _totalDepositCheckpoints; - function checkpoints( - address account, - uint32 pos - ) public view virtual returns (Checkpoint memory) { - return _checkpoints[account][pos]; - } - function getDeposits(address account) public view virtual returns (uint256) { - uint256 pos = _checkpoints[account].length; - return pos == 0 ? 0 : _checkpoints[account][pos - 1].deposits; - } - function getPastDeposits( - address account, - uint256 blockNumber - ) public view virtual returns (uint256) { - require(blockNumber < block.number, "block not yet mined"); - return _checkpointsLookup(_checkpoints[account], blockNumber); - } - function getPastTotalDeposits( - uint256 blockNumber - ) public view virtual returns (uint256) { - require(blockNumber < block.number, "block not yet mined"); - return _checkpointsLookup(_totalDepositCheckpoints, blockNumber); - } - function _checkpointsLookup( - Checkpoint[] storage ckpts, - uint256 blockNumber - ) private view returns (uint256) { - // We run a binary search to look for the earliest checkpoint taken after - // `blockNumber`. - uint256 high = ckpts.length; - uint256 low = 0; - while (low < high) { - uint256 mid = Math.average(low, high); - if (ckpts[mid].fromBlock > blockNumber) { - high = mid; - } else { - low = mid + 1; - } - } - return high == 0 ? 0 : ckpts[high - 1].deposits; - } - function _writeCheckpoint( - Checkpoint[] storage ckpts, - function(uint256, uint256) view returns (uint256) operation, - uint256 delta - ) private returns (uint256 oldWeight, uint256 newWeight) { - uint256 position = ckpts.length; - oldWeight = position == 0 ? 0 : ckpts[position - 1].deposits; - newWeight = operation(oldWeight, delta); - - if (position > 0 && ckpts[position - 1].fromBlock == block.number) { - ckpts[position - 1].deposits = SafeCast.toUint224(newWeight); - } else { - ckpts.push( - Checkpoint({ - fromBlock: SafeCast.toUint32(block.number), - deposits: SafeCast.toUint224(newWeight) - }) - ); - } - } - function _additionFn(uint256 a, uint256 b) private pure returns (uint256) { - return a + b; - } - - function _subtractionFn(uint256 a, uint256 b) private pure returns (uint256) { - return a - b; - } - //=========================================================================== - // END: Checkpointing code. - //=========================================================================== // forgefmt: disable-end } diff --git a/test/FractionalPool.t.sol b/test/FractionalPool.t.sol index 031f314..4c6b238 100644 --- a/test/FractionalPool.t.sol +++ b/test/FractionalPool.t.sol @@ -130,14 +130,17 @@ contract Deposit is FractionalPoolTest { _amountA = bound(_amountA, 1, type(uint128).max); _amountB = bound(_amountB, 1, type(uint128).max); + // Deposit some gov. _mintGovAndDepositIntoPool(_holder, _amountA); - assertEq(token.balanceOf(_holder), 0); // they've all been deposited - assert(token.balanceOf(address(pool)) > token.balanceOf(_holder)); vm.roll(block.number + 42); // advance so that we can look at checkpoints // We can still retrieve the user's balance at the given time. - assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA); + assertEq( + pool.getPastDeposits(_holder, block.number - 1), + _amountA, + "user's first deposit was not properly checkpointed" + ); uint256 newBlockNum = block.number + _depositDelay; vm.roll(newBlockNum); @@ -146,7 +149,11 @@ contract Deposit is FractionalPoolTest { _mintGovAndDepositIntoPool(_holder, _amountB); vm.roll(block.number + 42); // advance so that we can look at checkpoints - assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA + _amountB); + assertEq( + pool.getPastDeposits(_holder, block.number - 1), + _amountA + _amountB, + "user's second deposit was not properly checkpointed" + ); } }