Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use OZ checkpointing library #20

Merged
merged 9 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[submodule "lib/solmate"]
path = lib/solmate
url = https://github.com/rari-capital/solmate
[submodule "lib/openzeppelin-contracts"]
path = lib/openzeppelin-contracts
url = https://github.com/openzeppelin/openzeppelin-contracts
[submodule "lib/aave-v3-core"]
path = lib/aave-v3-core
url = https://github.com/aave/aave-v3-core
[submodule "lib/forge-std"]
path = lib/forge-std
url = https://github.com/foundry-rs/forge-std
[submodule "lib/openzeppelin-contracts"]
path = lib/openzeppelin-contracts
url = https://github.com/openzeppelin/openzeppelin-contracts
branch = v4.8.0
1 change: 1 addition & 0 deletions foundry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# auto-detect solc versions.
optimizer = true
optimizer_runs = 10_000_000
remappings = ["openzeppelin-contracts/=lib/openzeppelin-contracts"]
verbosity = 3

[profile.ci]
Expand Down
2 changes: 1 addition & 1 deletion lib/openzeppelin-contracts
130 changes: 29 additions & 101 deletions src/ATokenNaive.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {IPool} from "aave-v3-core/contracts/interfaces/IPool.sol";
import {WadRayMath} from "aave-v3-core/contracts/protocol/libraries/math/WadRayMath.sol";
import {SafeCast} from "openzeppelin-contracts/contracts/utils/math/SafeCast.sol";
import {Math} from "openzeppelin-contracts/contracts/utils/math/Math.sol";
import {Checkpoints} from "openzeppelin-contracts/contracts/utils/Checkpoints.sol";

interface IFractionalGovernor {
function token() external returns (address);
Expand All @@ -34,6 +35,7 @@ contract ATokenNaive is AToken {
using WadRayMath for uint256;
using SafeCast for uint256;
using GPv2SafeERC20 for IERC20;
using Checkpoints for Checkpoints.History;

/// @notice The voting options corresponding to those used in the Governor.
enum VoteType {
Expand All @@ -56,7 +58,7 @@ contract ATokenNaive is AToken {
uint32 public immutable CAST_VOTE_WINDOW;

/// @notice Map proposalId to an address to whether they have voted on this proposal.
mapping(uint256 => mapping(address => bool)) private _proposalVotersHasVoted;
mapping(uint256 => mapping(address => bool)) private proposalVotersHasVoted;

/// @notice Map proposalId to whether or not this contract has cast votes on it.
mapping(uint256 => bool) public hasCastVotesOnProposal;
Expand All @@ -69,6 +71,12 @@ contract ATokenNaive is AToken {
/// GovernorCountingFractional.
IFractionalGovernor public immutable governor;

/// @notice Mapping from address to deposit checkpoint history.
mapping(address => Checkpoints.History) private depositCheckpoints;

/// @notice History of total underlying asset balance.
Checkpoints.History private totalDepositCheckpoints;

/// @dev Constructor.
/// @param _pool The address of the Pool contract
/// @param _governor The address of the flex-voting-compatible governance contract.
Expand Down Expand Up @@ -101,8 +109,6 @@ contract ATokenNaive is AToken {
_lastVotingBlock = governor.proposalDeadline(proposalId) - CAST_VOTE_WINDOW;
}

/// TODO how to handle onBehalfOf?
/// TODO should this revert if the vote has been cast?
/// @notice Allow a depositor to express their voting preference for a given
/// proposal. Their preference is recorded internally but not moved to the
/// Governor until `castVote` is called. We deliberately do NOT revert if the
Expand All @@ -114,8 +120,8 @@ contract ATokenNaive is AToken {
uint256 weight = getPastDeposits(msg.sender, governor.proposalSnapshot(proposalId));
require(weight > 0, "no weight");

require(!_proposalVotersHasVoted[proposalId][msg.sender], "already voted");
_proposalVotersHasVoted[proposalId][msg.sender] = true;
require(!proposalVotersHasVoted[proposalId][msg.sender], "already voted");
proposalVotersHasVoted[proposalId][msg.sender] = true;

if (support == uint8(VoteType.Against)) {
proposalVotes[proposalId].againstVotes += SafeCast.toUint128(weight);
Expand Down Expand Up @@ -198,15 +204,22 @@ contract ATokenNaive is AToken {
uint256 amount,
uint256 index
) internal returns (bool) {
bool _returnVar = _mintScaled(caller, onBehalfOf, amount, index);

// We increment by `amount` instead of any computed/rebased amounts because
// `amount` is what actually gets transferred of the underlying asset. We
// need our checkpoints to still match up with the underlying asset balance.
_writeCheckpoint(_checkpoints[onBehalfOf], _additionFn, amount);
_writeCheckpoint(_totalDepositCheckpoints, _additionFn, amount);
// need our checkpoints to still match up with underlying asset transactions.
Checkpoints.History storage _depositHistory = depositCheckpoints[onBehalfOf];
_depositHistory.push(_depositHistory.latest() + amount);
totalDepositCheckpoints.push(totalDepositCheckpoints.latest() + amount);

return _mintScaled(caller, onBehalfOf, amount, index);
}

function getPastDeposits(address _voter, uint256 _blockNumber) public returns (uint256) {
return depositCheckpoints[_voter].getAtBlock(_blockNumber);
}

return _returnVar;
function getPastTotalDeposits(uint256 _blockNumber) public returns (uint256) {
return totalDepositCheckpoints.getAtBlock(_blockNumber);
}

// forgefmt: disable-start
Expand Down Expand Up @@ -252,11 +265,12 @@ contract ATokenNaive is AToken {
//
// We decrement by `amount` instead of any computed/rebased amounts because
// `amount` is what actually gets transferred of the underlying asset. We
// need our checkpoints to still match up with the underlying asset balance.
_writeCheckpoint(_checkpoints[from], _subtractionFn, amount);
_writeCheckpoint(_totalDepositCheckpoints, _subtractionFn, amount);
// End modifications.
// need our checkpoints to still match up with underlying asset transactions.
Checkpoints.History storage _depositHistory = depositCheckpoints[from];
_depositHistory.push(_depositHistory.latest() - amount);
totalDepositCheckpoints.push(totalDepositCheckpoints.latest() - amount);

// End modifications.
_burnScaled(from, receiverOfUnderlying, amount, index);
if (receiverOfUnderlying != address(this)) {
IERC20(_underlyingAsset).safeTransfer(receiverOfUnderlying, amount);
Expand All @@ -265,91 +279,5 @@ contract ATokenNaive is AToken {
//===========================================================================
// END: Aave overrides
//===========================================================================

//===========================================================================
// BEGIN: Checkpointing code.
//===========================================================================
// This was been copied from OZ's ERC20Votes checkpointing system with minor
// revisions:
// * Replace "Vote" with "Deposit", as deposits are what we need to track
// * Make some variable names longer for readability
// * Break lines at 80-characters
struct Checkpoint {
uint32 fromBlock;
uint224 deposits;
}
mapping(address => Checkpoint[]) private _checkpoints;
Checkpoint[] private _totalDepositCheckpoints;
function checkpoints(
address account,
uint32 pos
) public view virtual returns (Checkpoint memory) {
return _checkpoints[account][pos];
}
function getDeposits(address account) public view virtual returns (uint256) {
uint256 pos = _checkpoints[account].length;
return pos == 0 ? 0 : _checkpoints[account][pos - 1].deposits;
}
function getPastDeposits(
address account,
uint256 blockNumber
) public view virtual returns (uint256) {
require(blockNumber < block.number, "block not yet mined");
return _checkpointsLookup(_checkpoints[account], blockNumber);
}
function getPastTotalDeposits(
uint256 blockNumber
) public view virtual returns (uint256) {
require(blockNumber < block.number, "block not yet mined");
return _checkpointsLookup(_totalDepositCheckpoints, blockNumber);
}
function _checkpointsLookup(
Checkpoint[] storage ckpts,
uint256 blockNumber
) private view returns (uint256) {
// We run a binary search to look for the earliest checkpoint taken after
// `blockNumber`.
uint256 high = ckpts.length;
uint256 low = 0;
while (low < high) {
uint256 mid = Math.average(low, high);
if (ckpts[mid].fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}
return high == 0 ? 0 : ckpts[high - 1].deposits;
}
function _writeCheckpoint(
Checkpoint[] storage ckpts,
function(uint256, uint256) view returns (uint256) operation,
uint256 delta
) private returns (uint256 oldWeight, uint256 newWeight) {
uint256 position = ckpts.length;
oldWeight = position == 0 ? 0 : ckpts[position - 1].deposits;
newWeight = operation(oldWeight, delta);

if (position > 0 && ckpts[position - 1].fromBlock == block.number) {
ckpts[position - 1].deposits = SafeCast.toUint224(newWeight);
} else {
ckpts.push(
Checkpoint({
fromBlock: SafeCast.toUint32(block.number),
deposits: SafeCast.toUint224(newWeight)
})
);
}
}
function _additionFn(uint256 a, uint256 b) private pure returns (uint256) {
return a + b;
}

function _subtractionFn(uint256 a, uint256 b) private pure returns (uint256) {
return a - b;
}
//===========================================================================
// END: Checkpointing code.
//===========================================================================
// forgefmt: disable-end
}
15 changes: 11 additions & 4 deletions test/FractionalPool.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ contract Deposit is FractionalPoolTest {
_amountA = bound(_amountA, 1, type(uint128).max);
_amountB = bound(_amountB, 1, type(uint128).max);

// Deposit some gov.
_mintGovAndDepositIntoPool(_holder, _amountA);
assertEq(token.balanceOf(_holder), 0); // they've all been deposited
assert(token.balanceOf(address(pool)) > token.balanceOf(_holder));

vm.roll(block.number + 42); // advance so that we can look at checkpoints

// We can still retrieve the user's balance at the given time.
assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA);
assertEq(
pool.getPastDeposits(_holder, block.number - 1),
_amountA,
"user's first deposit was not properly checkpointed"
);

uint256 newBlockNum = block.number + _depositDelay;
vm.roll(newBlockNum);
Expand All @@ -146,7 +149,11 @@ contract Deposit is FractionalPoolTest {
_mintGovAndDepositIntoPool(_holder, _amountB);

vm.roll(block.number + 42); // advance so that we can look at checkpoints
assertEq(pool.getPastDeposits(_holder, block.number - 1), _amountA + _amountB);
assertEq(
pool.getPastDeposits(_holder, block.number - 1),
_amountA + _amountB,
"user's second deposit was not properly checkpointed"
);
}
}

Expand Down