Skip to content

Commit

Permalink
extract functions and methods for unit testing remaining coins, add c…
Browse files Browse the repository at this point in the history
…omments
  • Loading branch information
p0mvn committed Jun 8, 2022
1 parent 73f12d5 commit 5172ccd
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 27 deletions.
112 changes: 85 additions & 27 deletions x/gamm/pool-models/balancer/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package balancer

import (
"errors"
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -12,10 +13,12 @@ import (
)

const (
errMsgFormatSharesAmountNotPositive = "shares amount must be positive, was %d"
errMsgFormatTokenAmountNotPositive = "token amount must be positive, was %d"
errMsgFormatTokensLargerThanMax = "%d resulted tokens is larger than the max amount of %d"
errMsgFormatSharesLargerThanMax = "%d resulted shares is larger than the max amount of %d"
errMsgFormatSharesAmountNotPositive = "shares amount must be positive, was %d"
errMsgFormatTokenAmountNotPositive = "token amount must be positive, was %d"
errMsgFormatTokensLargerThanMax = "%d resulted tokens is larger than the max amount of %d"
errMsgFormatSharesLargerThanMax = "%d resulted shares is larger than the max amount of %d"
errMsgFormatFailedInterimLiquidityUpdate = "failed to update interim liquidity - pool asset %s does not exist"
errMsgFormatRepeatingPoolAssetsNotAllowed = "repeating pool assets not allowed, found %s"
)

// solveConstantFunctionInvariant solves the constant function of an AMM
Expand Down Expand Up @@ -259,10 +262,9 @@ func (p *Pool) JoinPool(_ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (

// CalcJoinPoolShares
func (p *Pool) CalcJoinPoolShares(_ sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, newLiquidity sdk.Coins, err error) {
poolAssets := p.GetAllPoolAssets()
poolAssetsByDenom := make(map[string]PoolAsset)
for _, poolAsset := range poolAssets {
poolAssetsByDenom[poolAsset.Token.Denom] = poolAsset
poolAssetsByDenom, err := getPoolAssetsByDenom(p.GetAllPoolAssets())
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

totalShares := p.GetTotalShares()
Expand All @@ -286,34 +288,90 @@ func (p *Pool) CalcJoinPoolShares(_ sdk.Context, tokensIn sdk.Coins, swapFee sdk
return sdk.ZeroInt(), sdk.NewCoins(), err
}

// update liquidity for accurate calcSingleAssetJoin calculation
newLiquidity = tokensIn.Sub(remCoins)
for _, coin := range newLiquidity {
poolAsset := poolAssetsByDenom[coin.Denom]
poolAsset.Token.Amount = poolAssetsByDenom[coin.Denom].Token.Amount.Add(coin.Amount)
poolAssetsByDenom[coin.Denom] = poolAsset
}

newTotalShares := totalShares.Add(numShares)

// If there are coins that couldn't be perfectly joined, do single asset joins
// for each of them.
if !remCoins.Empty() {
for _, coin := range remCoins {
newShares, err := p.calcSingleAssetJoin(coin, swapFee, poolAssetsByDenom[coin.Denom], newTotalShares)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

newLiquidity = newLiquidity.Add(coin)
newTotalShares = newTotalShares.Add(newShares)
numShares = numShares.Add(newShares)
// update liquidity for accurate calcSingleAssetJoin calculation
if err := updateIntermediaryLiquidity(newLiquidity, poolAssetsByDenom); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

// update total shares for accurate calcSingleAssetJoin calculation
newTotalShares := totalShares.Add(numShares)

newNumSharesFromRemaining, newLiquidityFromRemaining, err := p.calcJoinMultipleSingleAssetTokensIn(remCoins, newTotalShares, poolAssetsByDenom, swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}
numShares.Add(newNumSharesFromRemaining)
newLiquidity.Add(newLiquidityFromRemaining...)
}

return numShares, newLiquidity, nil
}

// getPoolAssetsByDenom return a mapping from pool asset
// denom to the pool asset itself. There must be no duplicates.
// Returns error, if any found.
func getPoolAssetsByDenom(poolAssets []PoolAsset) (map[string]PoolAsset, error) {
poolAssetsByDenom := make(map[string]PoolAsset)
for _, poolAsset := range poolAssets {
_, ok := poolAssetsByDenom[poolAsset.Token.Denom]
if ok {
return nil, fmt.Errorf(errMsgFormatRepeatingPoolAssetsNotAllowed, poolAsset.Token.Denom)
}

poolAssetsByDenom[poolAsset.Token.Denom] = poolAsset
}
return poolAssetsByDenom, nil
}

// updateIntermediaryLiquidity updates poolAssetsByDenom with newLiquidity.
//
// all liqidity coins must exist in poolAssetsByDenom. Returns error, if not.
//
// This is a helper function that is useful for updating the pool asset amounts
// as an intermediary step in a multi-join methods such as CalcJoinPoolShares.
// In CalcJoinPoolShares with multi-asset joins, we first attempt to do
// a MaximalExactRatioJoin that might leave out some tokens in.
// Then, for every remaining tokens in, we attempt to do a single asset join.
// Since the first step (MaximalExactRatioJoin) affects the pool liqudity due to slippage,
// we would like to account for that in the subsequent steps of single asset join
func updateIntermediaryLiquidity(liquidity sdk.Coins, poolAssetsByDenom map[string]PoolAsset) error {
for _, coin := range liquidity {
poolAsset, ok := poolAssetsByDenom[coin.Denom]
if !ok {
return fmt.Errorf(errMsgFormatFailedInterimLiquidityUpdate, coin.Denom)
}

poolAsset.Token.Amount = poolAssetsByDenom[coin.Denom].Token.Amount.Add(coin.Amount)
poolAssetsByDenom[coin.Denom] = poolAsset
}
return nil
}

// calcJoinMultipleSingleAssetTokensIn attemps to calculate single
// asset join for all tokensIn given totalSharesSoFar,
// poolAssetsByDenom and swapFee.
//
// Returns totalNumShares and totalNewLiquidity from joining all tokensIn
// or error if fails to calculate join for any of the tokensIn.
func (p *Pool) calcJoinMultipleSingleAssetTokensIn(tokensIn sdk.Coins, totalSharesSoFar sdk.Int, poolAssetsByDenom map[string]PoolAsset, swapFee sdk.Dec) (sdk.Int, sdk.Coins, error) {
totalNumShares := sdk.ZeroInt()
totalNewLiquidity := sdk.NewCoins()
for _, coin := range tokensIn {
newShares, err := p.calcSingleAssetJoin(coin, swapFee, poolAssetsByDenom[coin.Denom], totalSharesSoFar)
if err != nil {
return sdk.ZeroInt(), sdk.Coins{}, err
}

totalNewLiquidity = totalNewLiquidity.Add(coin)
totalSharesSoFar = totalSharesSoFar.Add(newShares)
totalNumShares.Add(newShares)
}
return totalNumShares, totalNewLiquidity, nil
}

func (p *Pool) ExitPool(ctx sdk.Context, exitingShares sdk.Int, exitFee sdk.Dec) (exitingCoins sdk.Coins, err error) {
exitingCoins, err = p.CalcExitPoolShares(ctx, exitingShares, exitFee)
if err != nil {
Expand Down
224 changes: 224 additions & 0 deletions x/gamm/pool-models/balancer/amm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,227 @@ func TestCalcJoinPoolShares(t *testing.T) {
})
}
}

func TestGetPoolAssetsByDenom(t *testing.T) {
testCases := []struct {
name string
poolAssets []balancer.PoolAsset
expectedPoolAssetsByDenom map[string]balancer.PoolAsset

err error
}{
{
name: "zero pool assets",
poolAssets: []balancer.PoolAsset{},
expectedPoolAssetsByDenom: make(map[string]balancer.PoolAsset),
},
{
name: "one pool asset",
poolAssets: []balancer.PoolAsset{
{
Token: sdk.NewInt64Coin("uosmo", 1_000_000_000_000),
Weight: sdk.NewInt(100),
},
},
expectedPoolAssetsByDenom: map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", 1_000_000_000_000),
Weight: sdk.NewInt(100),
},
},
},
{
name: "two pool assets",
poolAssets: []balancer.PoolAsset{
{
Token: sdk.NewInt64Coin("uosmo", 1_000_000_000_000),
Weight: sdk.NewInt(100),
},
{
Token: sdk.NewInt64Coin("atom", 123),
Weight: sdk.NewInt(400),
},
},
expectedPoolAssetsByDenom: map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", 1_000_000_000_000),
Weight: sdk.NewInt(100),
},
"atom": {
Token: sdk.NewInt64Coin("atom", 123),
Weight: sdk.NewInt(400),
},
},
},
{
name: "duplicate pool assets",
poolAssets: []balancer.PoolAsset{
{
Token: sdk.NewInt64Coin("uosmo", 1_000_000_000_000),
Weight: sdk.NewInt(100),
},
{
Token: sdk.NewInt64Coin("uosmo", 123),
Weight: sdk.NewInt(400),
},
},
err: fmt.Errorf(balancer.ErrMsgFormatRepeatingPoolAssetsNotAllowed, "uosmo"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualPoolAssetsByDenom, err := balancer.GetPoolAssetsByDenom(tc.poolAssets)

require.Equal(t, tc.err, err)

if tc.err != nil {
return
}

require.Equal(t, tc.expectedPoolAssetsByDenom, actualPoolAssetsByDenom)
})
}
}

func TestUpdateIntermediaryLiquidity(t *testing.T) {
testCases := []struct {
name string

// returns newLiquidity, originalPoolAssetsByDenom, expectedPoolAssetsByDenom
setup func() (sdk.Coins, map[string]balancer.PoolAsset, map[string]balancer.PoolAsset)

err error
}{
{
name: "regular case with multiple pool assets and a subset of newLqiduity to update",

setup: func() (sdk.Coins, map[string]balancer.PoolAsset, map[string]balancer.PoolAsset) {
const (
uosmoValueOriginal = 1_000_000_000_000
atomValueOriginal = 123
ionValueOriginal = 657

uosmoValueUpdate = 1_000
atomValueUpdate = 2_000
ionValueUpdate = 3_000

// Weight does not affect calculations so it is shared
weight = 100
)

newLiquidity := sdk.NewCoins(
sdk.NewInt64Coin("uosmo", uosmoValueUpdate),
sdk.NewInt64Coin("atom", atomValueUpdate),
sdk.NewInt64Coin("ion", ionValueUpdate))

originalPoolAssetsByDenom := map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", uosmoValueOriginal),
Weight: sdk.NewInt(weight),
},
"atom": {
Token: sdk.NewInt64Coin("atom", atomValueOriginal),
Weight: sdk.NewInt(weight),
},
"ion": {
Token: sdk.NewInt64Coin("ion", ionValueOriginal),
Weight: sdk.NewInt(weight),
},
}

expectedPoolAssetsByDenom := map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", uosmoValueOriginal+uosmoValueUpdate),
Weight: sdk.NewInt(weight),
},
"atom": {
Token: sdk.NewInt64Coin("atom", atomValueOriginal+atomValueUpdate),
Weight: sdk.NewInt(weight),
},
"ion": {
Token: sdk.NewInt64Coin("ion", ionValueOriginal+ionValueUpdate),
Weight: sdk.NewInt(weight),
},
}

return newLiquidity, originalPoolAssetsByDenom, expectedPoolAssetsByDenom
},
},
{
name: "new liquidity has no coins",

setup: func() (sdk.Coins, map[string]balancer.PoolAsset, map[string]balancer.PoolAsset) {
const (
uosmoValueOriginal = 1_000_000_000_000
atomValueOriginal = 123
ionValueOriginal = 657

// Weight does not affect calculations so it is shared
weight = 100
)

newLiquidity := sdk.NewCoins()

originalPoolAssetsByDenom := map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", uosmoValueOriginal),
Weight: sdk.NewInt(weight),
},
"atom": {
Token: sdk.NewInt64Coin("atom", atomValueOriginal),
Weight: sdk.NewInt(weight),
},
"ion": {
Token: sdk.NewInt64Coin("ion", ionValueOriginal),
Weight: sdk.NewInt(weight),
},
}

return newLiquidity, originalPoolAssetsByDenom, originalPoolAssetsByDenom
},
},
{
name: "newLiquidity has a coin that poolAssets don't",

setup: func() (sdk.Coins, map[string]balancer.PoolAsset, map[string]balancer.PoolAsset) {
const (
uosmoValueOriginal = 1_000_000_000_000

// Weight does not affect calculations so it is shared
weight = 100
)

newLiquidity := sdk.NewCoins(
sdk.NewInt64Coin("juno", 1_000))

originalPoolAssetsByDenom := map[string]balancer.PoolAsset{
"uosmo": {
Token: sdk.NewInt64Coin("uosmo", uosmoValueOriginal),
Weight: sdk.NewInt(weight),
},
}

return newLiquidity, originalPoolAssetsByDenom, originalPoolAssetsByDenom
},

err: fmt.Errorf(balancer.ErrMsgFormatFailedInterimLiquidityUpdate, "juno"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
newLiquidity, originalPoolAssetsByDenom, expectedPoolAssetsByDenom := tc.setup()

err := balancer.UpdateIntermediaryLiquidity(newLiquidity, originalPoolAssetsByDenom)

require.Equal(t, tc.err, err)

if tc.err != nil {
return
}

require.Equal(t, expectedPoolAssetsByDenom, originalPoolAssetsByDenom)
})
}
}
11 changes: 11 additions & 0 deletions x/gamm/pool-models/balancer/export_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
package balancer

import sdk "github.com/cosmos/cosmos-sdk/types"

var (
ErrMsgFormatFailedInterimLiquidityUpdate = errMsgFormatFailedInterimLiquidityUpdate
ErrMsgFormatRepeatingPoolAssetsNotAllowed = errMsgFormatRepeatingPoolAssetsNotAllowed

CalcPoolSharesOutGivenSingleAssetIn = calcPoolSharesOutGivenSingleAssetIn
CalcSingleAssetInGivenPoolSharesOut = calcSingleAssetInGivenPoolSharesOut
GetPoolAssetsByDenom = getPoolAssetsByDenom
UpdateIntermediaryLiquidity = updateIntermediaryLiquidity
)

func (p *Pool) CalcJoinMultipleSingleAssetTokensIn(tokensIn sdk.Coins, totalSharesSoFar sdk.Int, poolAssetsByDenom map[string]PoolAsset, swapFee sdk.Dec) (sdk.Int, sdk.Coins, error) {
return p.calcJoinMultipleSingleAssetTokensIn(tokensIn, totalSharesSoFar, poolAssetsByDenom, swapFee)
}

0 comments on commit 5172ccd

Please sign in to comment.