Skip to content

Commit

Permalink
Use OZ checkpointing library (#20)
Browse files Browse the repository at this point in the history
* Bump OZ to get newer checkpointing code

* Take pass at using OZ checkpointing

* Fix OZ checkpointing usage

* Fix comments

* scopelint fmt

* Fix OZ dep loading + pin to latest official release

* Remove underscores from private storage vars

* Use Checkpoints for Checkpoints.History

* Fix unrelated failing test
  • Loading branch information
davidlaprade authored Dec 2, 2022
1 parent 61a839c commit 1cd6967
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 109 deletions.
7 changes: 4 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[submodule "lib/solmate"]
path = lib/solmate
url = https://github.com/rari-capital/solmate
[submodule "lib/openzeppelin-contracts"]
path = lib/openzeppelin-contracts
url = https://github.com/openzeppelin/openzeppelin-contracts
[submodule "lib/aave-v3-core"]
path = lib/aave-v3-core
url = https://github.com/aave/aave-v3-core
[submodule "lib/forge-std"]
path = lib/forge-std
url = https://github.com/foundry-rs/forge-std
[submodule "lib/openzeppelin-contracts"]
path = lib/openzeppelin-contracts
url = https://github.com/openzeppelin/openzeppelin-contracts
branch = v4.8.0
1 change: 1 addition & 0 deletions foundry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# auto-detect solc versions.
optimizer = true
optimizer_runs = 10_000_000
remappings = ["openzeppelin-contracts/=lib/openzeppelin-contracts"]
verbosity = 3

[profile.ci]
Expand Down
2 changes: 1 addition & 1 deletion lib/openzeppelin-contracts
130 changes: 29 additions & 101 deletions src/ATokenNaive.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {IPool} from "aave-v3-core/contracts/interfaces/IPool.sol";
import {WadRayMath} from "aave-v3-core/contracts/protocol/libraries/math/WadRayMath.sol";
import {SafeCast} from "openzeppelin-contracts/contracts/utils/math/SafeCast.sol";
import {Math} from "openzeppelin-contracts/contracts/utils/math/Math.sol";
import {Checkpoints} from "openzeppelin-contracts/contracts/utils/Checkpoints.sol";

interface IFractionalGovernor {
function token() external returns (address);
Expand All @@ -34,6 +35,7 @@ contract ATokenNaive is AToken {
using WadRayMath for uint256;
using SafeCast for uint256;
using GPv2SafeERC20 for IERC20;
using Checkpoints for Checkpoints.History;

/// @notice The voting options corresponding to those used in the Governor.
enum VoteType {
Expand All @@ -56,7 +58,7 @@ contract ATokenNaive is AToken {
uint32 public immutable CAST_VOTE_WINDOW;

/// @notice Map proposalId to an address to whether they have voted on this proposal.
mapping(uint256 => mapping(address => bool)) private _proposalVotersHasVoted;
mapping(uint256 => mapping(address => bool)) private proposalVotersHasVoted;

/// @notice Map proposalId to whether or not this contract has cast votes on it.
mapping(uint256 => bool) public hasCastVotesOnProposal;
Expand All @@ -69,6 +71,12 @@ contract ATokenNaive is AToken {
/// GovernorCountingFractional.
IFractionalGovernor public immutable governor;

/// @notice Mapping from address to deposit checkpoint history.
mapping(address => Checkpoints.History) private depositCheckpoints;

/// @notice History of total underlying asset balance.
Checkpoints.History private totalDepositCheckpoints;

/// @dev Constructor.
/// @param _pool The address of the Pool contract
/// @param _governor The address of the flex-voting-compatible governance contract.
Expand Down Expand Up @@ -101,8 +109,6 @@ contract ATokenNaive is AToken {
_lastVotingBlock = governor.proposalDeadline(proposalId) - CAST_VOTE_WINDOW;
}

/// TODO how to handle onBehalfOf?
/// TODO should this revert if the vote has been cast?
/// @notice Allow a depositor to express their voting preference for a given
/// proposal. Their preference is recorded internally but not moved to the
/// Governor until `castVote` is called. We deliberately do NOT revert if the
Expand All @@ -114,8 +120,8 @@ contract ATokenNaive is AToken {
uint256 weight = getPastDeposits(msg.sender, governor.proposalSnapshot(proposalId));
require(weight > 0, "no weight");

require(!_proposalVotersHasVoted[proposalId][msg.sender], "already voted");
_proposalVotersHasVoted[proposalId][msg.sender] = true;
require(!proposalVotersHasVoted[proposalId][msg.sender], "already voted");
proposalVotersHasVoted[proposalId][msg.sender] = true;

if (support == uint8(VoteType.Against)) {
proposalVotes[proposalId].againstVotes += SafeCast.toUint128(weight);
Expand Down Expand Up @@ -198,15 +204,22 @@ contract ATokenNaive is AToken {
uint256 amount,
uint256 index
) internal returns (bool) {
bool _returnVar = _mintScaled(caller, onBehalfOf, amount, index);

// We increment by `amount` instead of any computed/rebased amounts because
// `amount` is what actually gets transferred of the underlying asset. We
// need our checkpoints to still match up with the underlying asset balance.
_writeCheckpoint(_checkpoints[onBehalfOf], _additionFn, amount);
_writeCheckpoint(_totalDepositCheckpoints, _additionFn, amount);
// need our checkpoints to still match up with underlying asset transactions.
Checkpoints.History storage _depositHistory = depositCheckpoints[onBehalfOf];
_depositHistory.push(_depositHistory.latest() + amount);
totalDepositCheckpoints.push(totalDepositCheckpoints.latest() + amount);

return _mintScaled(caller, onBehalfOf, amount, index);
}

function getPastDeposits(address _voter, uint256 _blockNumber) public returns (uint256) {
return depositCheckpoints[_voter].getAtBlock(_blockNumber);
}

return _returnVar;
function getPastTotalDeposits(uint256 _blockNumber) public returns (uint256) {
return totalDepositCheckpoints.getAtBlock(_blockNumber);
}

// forgefmt: disable-start
Expand Down Expand Up @@ -252,11 +265,12 @@ contract ATokenNaive is AToken {
//
// We decrement by `amount` instead of any computed/rebased amounts because
// `amount` is what actually gets transferred of the underlying asset. We
// need our checkpoints to still match up with the underlying asset balance.
_writeCheckpoint(_checkpoints[from], _subtractionFn, amount);
_writeCheckpoint(_totalDepositCheckpoints, _subtractionFn, amount);
// End modifications.
// need our checkpoints to still match up with underlying asset transactions.
Checkpoints.History storage _depositHistory = depositCheckpoints[from];
_depositHistory.push(_depositHistory.latest() - amount);
totalDepositCheckpoints.push(totalDepositCheckpoints.latest() - amount);

// End modifications.
_burnScaled(from, receiverOfUnderlying, amount, index);
if (receiverOfUnderlying != address(this)) {
IERC20(_underlyingAsset).safeTransfer(receiverOfUnderlying, amount);
Expand All @@ -265,91 +279,5 @@ contract ATokenNaive is AToken {
//===========================================================================
// END: Aave overrides
//===========================================================================

//===========================================================================
// BEGIN: Checkpointing code.
//===========================================================================
// This was been copied from OZ's ERC20Votes checkpointing system with minor
// revisions:
// * Replace "Vote" with "Deposit", as deposits are what we need to track
// * Make some variable names longer for readability
// * Break lines at 80-characters
struct Checkpoint {
uint32 fromBlock;
uint224 deposits;
}
mapping(address => Checkpoint[]) private _checkpoints;
Checkpoint[] private _totalDepositCheckpoints;
function checkpoints(
address account,
uint32 pos
) public view virtual returns (Checkpoint memory) {
return _checkpoints[account][pos];
}
function getDeposits(address account) public view virtual returns (uint256) {
uint256 pos = _checkpoints[account].length;
return pos == 0 ? 0 : _checkpoints[account][pos - 1].deposits;
}
function getPastDeposits(
address account,
uint256 blockNumber
) public view virtual returns (uint256) {
require(blockNumber < block.number, "block not yet mined");
return _checkpointsLookup(_checkpoints[account], blockNumber);
}
function getPastTotalDeposits(
uint256 blockNumber
) public view virtual returns (uint256) {
require(blockNumber < block.number, "block not yet mined");
return _checkpointsLookup(_totalDepositCheckpoints, blockNumber);
}
function _checkpointsLookup(
Checkpoint[] storage ckpts,
uint256 blockNumber
) private view returns (uint256) {
// We run a binary search to look for the earliest checkpoint taken after
// `blockNumber`.
uint256 high = ckpts.length;
uint256 low = 0;
while (low < high) {
uint256 mid = Math.average(low, high);
if (ckpts[mid].fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}
return high == 0 ? 0 : ckpts[high - 1].deposits;
}
function _writeCheckpoint(
Checkpoint[] storage ckpts,
function(uint256, uint256) view returns (uint256) operation,
uint256 delta
) private returns (uint256 oldWeight, uint256 newWeight) {
uint256 position = ckpts.length;
oldWeight = position == 0 ? 0 : ckpts[position - 1].deposits;
newWeight = operation(oldWeight, delta);

if (position > 0 && ckpts[position - 1].fromBlock == block.number) {
ckpts[position - 1].deposits = SafeCast.toUint224(newWeight);
} else {
ckpts.push(
Checkpoint({
fromBlock: SafeCast.toUint32(block.number),
deposits: SafeCast.toUint224(newWeight)
})
);
}
}
function _additionFn(uint256 a, uint256 b) private pure returns (uint256) {
return a + b;
}

function _subtractionFn(uint256 a, uint256 b) private pure returns (uint256) {
return a - b;
}
//===========================================================================
// END: Checkpointing code.
//===========================================================================
// forgefmt: disable-end
}
15 changes: 11 additions & 4 deletions test/FractionalPool.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ contract Deposit is FractionalPoolTest {
_amountA = bound(_amountA, 1, type(uint128).max);
_amountB = bound(_amountB, 1, type(uint128).max);

// Deposit some gov.
_mintGovAndDepositIntoPool(_holder, _amountA);
assertEq(token.balanceOf(_holder), 0); // they've all been deposited
assert(token.balanceOf(address(pool)) > token.balanceOf(_holder));

vm.roll(block.number + 42); // advance so that we can look at checkpoints

// We can still retrieve the user's balance at the given time.
assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA);
assertEq(
pool.getPastDeposits(_holder, block.number - 1),
_amountA,
"user's first deposit was not properly checkpointed"
);

uint256 newBlockNum = block.number + _depositDelay;
vm.roll(newBlockNum);
Expand All @@ -146,7 +149,11 @@ contract Deposit is FractionalPoolTest {
_mintGovAndDepositIntoPool(_holder, _amountB);

vm.roll(block.number + 42); // advance so that we can look at checkpoints
assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA + _amountB);
assertEq(
pool.getPastDeposits(_holder, block.number - 1),
_amountA + _amountB,
"user's second deposit was not properly checkpointed"
);
}
}

Expand Down

0 comments on commit 1cd6967

Please sign in to comment.