Skip to content

Commit

Permalink
feat: Add minimum first deposit, refactor convert/preview functions (…
Browse files Browse the repository at this point in the history
…SC-446) (#2)

* feat: first test working

* feat: use larger numbers:

* feat: test with initial burn amount passing

* feat: update tests to work with updated burn logic, move conversion functions around and use previews

* feat: remove todos

* fix: update to remove console and update comment

* feat: remove all share burn logic, get all non inflation attack tests to pass

* fix: cleanup diff

* fix: update to use initial deposit instead of burn

* feat: add readme section explaining attack

* fix: minimize diff

* feat: update to address comments outside sharesToBurn

* feat: update inflation attack test and readme

* fix: update readme

* feat: update test to constrain deposit/withdraw

* feat: update to add both cases
  • Loading branch information
lucas-manuel authored Jun 4, 2024
1 parent 0b3d256 commit 6c44569
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 64 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ PSM contracts to either:
- Convert between a tokenization of an asset (ex. USDC) and a yield-bearing version of the asset (ex. sDAI).
- Convert one to one between directly correlated assets (ex. USDC-DAI).

## [CRITICAL]: First Depositor Attack Prevention on Deployment

On the deployment of the PSM, the deployer **MUST make an initial deposit to get AT LEAST 1e18 shares in order to protect the first depositor from getting attacked with a share inflation attack**. This is outlined further [here](https://github.com/marsfoundation/spark-automations/assets/44272939/9472a6d2-0361-48b0-b534-96a0614330d3). Technical details related to this can be found in `test/InflationAttack.t.sol`. The deployment script [TODO] in this repo contains logic for the deployer to perform this initial deposit, so it is **HIGHLY RECOMMENDED** to use this deployment script when deploying the PSM. Reasoning for the technical implementation approach taken is outlined in more detail [here](https://github.com/marsfoundation/spark-psm/pull/2).

## Usage

```bash
Expand Down
132 changes: 81 additions & 51 deletions src/PSM.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ interface IRateProviderLike {
// TODO: Refactor into inheritance structure
// TODO: Add interface with natspec and inherit
// TODO: Prove that we're always rounding against user
// TODO: Frontrunning attack, donation attack, virtual balances?
// TODO: Figure out how to optimize require checks for assets in view functions
// TODO: Discuss if we should add ERC20 functionality
// TODO: Add receiver to deposit/withdraw
contract PSM {

using SafeERC20 for IERC20;
Expand Down Expand Up @@ -80,11 +78,10 @@ contract PSM {
/*** Liquidity provision functions ***/
/**********************************************************************************************/

function deposit(address asset, uint256 assetsToDeposit) external {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

// Convert amount to 1e18 precision denominated in value of asset0 then convert to shares.
uint256 newShares = convertToShares(_getAssetValue(asset, assetsToDeposit));
function deposit(address asset, uint256 assetsToDeposit)
external returns (uint256 newShares)
{
newShares = previewDeposit(asset, assetsToDeposit);

shares[msg.sender] += newShares;
totalShares += newShares;
Expand All @@ -94,6 +91,32 @@ contract PSM {

function withdraw(address asset, uint256 maxAssetsToWithdraw)
external returns (uint256 assetsWithdrawn)
{
uint256 sharesToBurn;

( sharesToBurn, assetsWithdrawn ) = previewWithdraw(asset, maxAssetsToWithdraw);

unchecked {
shares[msg.sender] -= sharesToBurn;
totalShares -= sharesToBurn;
}

IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn);
}

/**********************************************************************************************/
/*** Deposit/withdraw preview functions ***/
/**********************************************************************************************/

function previewDeposit(address asset, uint256 assetsToDeposit) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

// Convert amount to 1e18 precision denominated in value of asset0 then convert to shares.
return convertToShares(_getAssetValue(asset, assetsToDeposit));
}

function previewWithdraw(address asset, uint256 maxAssetsToWithdraw)
public view returns (uint256 sharesToBurn, uint256 assetsWithdrawn)
{
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");

Expand All @@ -103,36 +126,43 @@ contract PSM {
? assetBalance
: maxAssetsToWithdraw;

uint256 sharesToBurn = convertToShares(asset, assetsWithdrawn);
sharesToBurn = _convertToSharesRoundUp(_getAssetValue(asset, assetsWithdrawn));

if (sharesToBurn > shares[msg.sender]) {
assetsWithdrawn = convertToAssets(asset, shares[msg.sender]);
sharesToBurn = convertToShares(asset, assetsWithdrawn);
}
uint256 userShares = shares[msg.sender];

unchecked {
shares[msg.sender] -= sharesToBurn;
totalShares -= sharesToBurn;
if (sharesToBurn > userShares) {
assetsWithdrawn = convertToAssets(asset, userShares);
sharesToBurn = userShares;
}

IERC20(asset).safeTransfer(msg.sender, assetsWithdrawn);
}

/**********************************************************************************************/
/*** Conversion functions ***/
/*** Swap preview functions ***/
/**********************************************************************************************/

function convertToShares(uint256 assetValue) public view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return assetValue * totalShares / totalValue;
}
return assetValue;
function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) {
return amountIn
* 1e27
* asset1Precision
/ IRateProviderLike(rateProvider).getConversionRate()
/ asset0Precision;
}

function convertToShares(address asset, uint256 assets) public view returns (uint256) {
function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) {
return amountIn
* IRateProviderLike(rateProvider).getConversionRate()
* asset0Precision
/ 1e27
/ asset1Precision;
}

/**********************************************************************************************/
/*** Conversion functions ***/
/**********************************************************************************************/

function convertToAssets(address asset, uint256 numShares) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");
return convertToShares(_getAssetValue(asset, assets));
return _getAssetsByValue(asset, convertToAssetValue(numShares));
}

function convertToAssetValue(uint256 numShares) public view returns (uint256) {
Expand All @@ -144,9 +174,17 @@ contract PSM {
return numShares;
}

function convertToAssets(address asset, uint256 numShares) public view returns (uint256) {
function convertToShares(uint256 assetValue) public view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return assetValue * totalShares / totalValue;
}
return assetValue;
}

function convertToShares(address asset, uint256 assets) public view returns (uint256) {
require(asset == address(asset0) || asset == address(asset1), "PSM/invalid-asset");
return _getAssetsByValue(asset, convertToAssetValue(numShares));
return convertToShares(_getAssetValue(asset, assets));
}

/**********************************************************************************************/
Expand All @@ -159,35 +197,27 @@ contract PSM {
}

/**********************************************************************************************/
/*** Swap preview functions ***/
/*** Internal helper functions ***/
/**********************************************************************************************/

function previewSwapAssetZeroToOne(uint256 amountIn) public view returns (uint256) {
return amountIn
* 1e27
* asset1Precision
/ IRateProviderLike(rateProvider).getConversionRate()
/ asset0Precision;
function _convertToSharesRoundUp(uint256 assetValue) internal view returns (uint256) {
uint256 totalValue = getPsmTotalValue();
if (totalValue != 0) {
return _divUp(assetValue * totalShares, totalValue);
}
return assetValue;
}

function previewSwapAssetOneToZero(uint256 amountIn) public view returns (uint256) {
return amountIn
* IRateProviderLike(rateProvider).getConversionRate()
* asset0Precision
/ 1e27
/ asset1Precision;
function _divUp(uint256 x, uint256 y) internal pure returns (uint256 z) {
unchecked {
z = x != 0 ? ((x - 1) / y) + 1 : 0;
}
}

/**********************************************************************************************/
/*** Internal helper functions ***/
/**********************************************************************************************/

function _getAssetValue(address asset, uint256 amount) internal view returns (uint256) {
if (asset == address(asset0)) {
return _getAsset0Value(amount);
}

return _getAsset1Value(amount);
return asset == address(asset0)
? _getAsset0Value(amount)
: _getAsset1Value(amount);
}

function _getAssetsByValue(address asset, uint256 assetValue) internal view returns (uint256) {
Expand Down
111 changes: 111 additions & 0 deletions test/InflationAttack.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import { PSM } from "src/PSM.sol";

import { PSMTestBase } from "test/PSMTestBase.sol";

contract InflationAttackTests is PSMTestBase {

// TODO: Add DOS attack test outlined here: https://github.com/marsfoundation/spark-psm/pull/2#pullrequestreview-2085880206

function test_inflationAttack_noInitialDeposit() public {
psm = new PSM(address(usdc), address(sDai), address(rateProvider));

address firstDepositor = makeAddr("firstDepositor");
address frontRunner = makeAddr("frontRunner");

// Step 1: Front runner deposits 1 sDAI to get 1 share

// Have to use sDai because 1 USDC mints 1e12 shares
_deposit(frontRunner, address(sDai), 1);

assertEq(psm.shares(frontRunner), 1);

// Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1)

deal(address(usdc), frontRunner, 10_000_000e6);

assertEq(psm.convertToAssetValue(1), 1);

vm.prank(frontRunner);
usdc.transfer(address(psm), 10_000_000e6);

// Highly inflated exchange rate
assertEq(psm.convertToAssetValue(1), 10_000_000e18 + 1);

// Step 3: First depositor deposits 20 million USDC, only gets one share because rounding
// error gives them 1 instead of 2 shares, worth 15m USDC

_deposit(firstDepositor, address(usdc), 20_000_000e6);

assertEq(psm.shares(firstDepositor), 1);

// 1 share = 3 million USDC / 2 shares = 1.5 million USDC
assertEq(psm.convertToAssetValue(1), 15_000_000e18);

// Step 4: Both users withdraw the max amount of funds they can

_withdraw(firstDepositor, address(usdc), type(uint256).max);
_withdraw(frontRunner, address(usdc), type(uint256).max);

assertEq(usdc.balanceOf(address(psm)), 0);

// Front runner profits 5m USDC, first depositor loses 5m USDC
assertEq(usdc.balanceOf(firstDepositor), 15_000_000e6);
assertEq(usdc.balanceOf(frontRunner), 15_000_000e6);
}

function test_inflationAttack_useInitialDeposit() public {
psm = new PSM(address(usdc), address(sDai), address(rateProvider));

address firstDepositor = makeAddr("firstDepositor");
address frontRunner = makeAddr("frontRunner");
address deployer = address(this); // TODO: Update to use non-deployer receiver

_deposit(address(this), address(sDai), 0.8e18); /// 1e18 shares

// Step 1: Front runner deposits sDAI to get 1 share

// User tries to do the same attack, depositing one sDAI for 1 share
_deposit(frontRunner, address(sDai), 1);

assertEq(psm.shares(frontRunner), 1);

// Step 2: Front runner transfers 10m USDC to inflate the exchange rate to 1:(10m + 1)

assertEq(psm.convertToAssetValue(1), 1);

deal(address(usdc), frontRunner, 10_000_000e6);

vm.prank(frontRunner);
usdc.transfer(address(psm), 10_000_000e6);

// Still inflated, but all value is transferred to existing holder, deployer
assertEq(psm.convertToAssetValue(1), 0.00000000001e18);

// Step 3: First depositor deposits 20 million USDC, this time rounding is not an issue
// so value reflected is much more accurate

_deposit(firstDepositor, address(usdc), 20_000_000e6);

assertEq(psm.shares(firstDepositor), 1.999999800000020001e18);

// Higher amount of initial shares means lower rounding error
assertEq(psm.convertToAssetValue(1.999999800000020001e18), 19_999_999.999999999996673334e18);

// Step 4: Both users withdraw the max amount of funds they can

_withdraw(firstDepositor, address(usdc), type(uint256).max);
_withdraw(frontRunner, address(usdc), type(uint256).max);
_withdraw(deployer, address(usdc), type(uint256).max);

// Front runner loses full 10m USDC to the deployer that had all shares at the beginning, first depositor loses nothing (1e-6 USDC)
assertEq(usdc.balanceOf(firstDepositor), 19_999_999.999999e6);
assertEq(usdc.balanceOf(frontRunner), 0);
assertEq(usdc.balanceOf(deployer), 10_000_000.000001e6);
}

}
2 changes: 1 addition & 1 deletion test/PSMTestBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pragma solidity ^0.8.13;

import "forge-std/Test.sol";

import { PSM } from "../src/PSM.sol";
import { PSM } from "src/PSM.sol";

import { MockERC20 } from "erc20-helpers/MockERC20.sol";

Expand Down
Loading

0 comments on commit 6c44569

Please sign in to comment.