Skip to content

Commit

Permalink
fix: already claimed tokens should not be in claim proof (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
shrimalmadhur authored Nov 20, 2024
1 parent 680b434 commit 5f6537f
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 31 deletions.
63 changes: 48 additions & 15 deletions pkg/rewards/claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type elChainReader interface {
ctx context.Context,
) (rewardscoordinator.IRewardsCoordinatorDistributionRoot, error)
CurrRewardsCalculationEndTimestamp(ctx context.Context) (uint32, error)
GetCumulativeClaimed(ctx context.Context, earnerAddress, tokenAddress gethcommon.Address) (*big.Int, error)
}

func ClaimCmd(p utils.Prompter) *cli.Command {
Expand Down Expand Up @@ -123,15 +124,22 @@ func Claim(cCtx *cli.Context, p utils.Prompter) error {
return eigenSdkUtils.WrapError("failed to fetch claim amounts for date", err)
}

claimableTokens, present := proofData.Distribution.GetTokensForEarner(config.EarnerAddress)
claimableTokensOrderMap, present := proofData.Distribution.GetTokensForEarner(config.EarnerAddress)
if !present {
return errors.New("no tokens claimable by earner")
}

claimableTokensMap := getTokensToClaim(claimableTokensOrderMap, config.TokenAddresses)

claimableTokens, err := filterClaimableTokens(ctx, elReader, config.EarnerAddress, claimableTokensMap)
if err != nil {
return eigenSdkUtils.WrapError("failed to get claimable tokens", err)
}

cg := claimgen.NewClaimgen(proofData.Distribution)
accounts, claim, err := cg.GenerateClaimProofForEarner(
config.EarnerAddress,
getTokensToClaim(claimableTokens, config.TokenAddresses),
claimableTokens,
rootIndex,
)
if err != nil {
Expand Down Expand Up @@ -270,6 +278,30 @@ func Claim(cCtx *cli.Context, p utils.Prompter) error {
return nil
}

// filterClaimableTokens to filter out tokens that have been fully claimed
func filterClaimableTokens(
ctx context.Context,
elReader elChainReader,
earnerAddress gethcommon.Address,
claimableTokensMap map[gethcommon.Address]*big.Int,
) ([]gethcommon.Address, error) {
claimableTokens := make([]gethcommon.Address, 0)
for token, claimedAmount := range claimableTokensMap {
amount, err := getCummulativeClaimedRewards(ctx, elReader, earnerAddress, token)
if err != nil {
return nil, err
}
// If the token has been claimed fully, we don't need to include it in the claim
// This is because contracts reject claims for tokens that have been fully claimed
// https://github.com/Layr-Labs/eigenlayer-contracts/blob/ac57bc1b28c83d9d7143c0da19167c148c3596a3/src/contracts/core/RewardsCoordinator.sol#L575-L578
if claimedAmount.Cmp(amount) == 0 {
continue
}
claimableTokens = append(claimableTokens, token)
}
return claimableTokens, nil
}

func getClaimDistributionRoot(
ctx context.Context,
claimTimestamp string,
Expand Down Expand Up @@ -312,39 +344,40 @@ func getClaimDistributionRoot(
func getTokensToClaim(
claimableTokens *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
tokenAddresses []gethcommon.Address,
) []gethcommon.Address {
) map[gethcommon.Address]*big.Int {
var tokenMap map[gethcommon.Address]*big.Int
if len(tokenAddresses) == 0 {
tokenAddresses = getAllClaimableTokenAddresses(claimableTokens)
tokenMap = getAllClaimableTokenAddresses(claimableTokens)
} else {
tokenAddresses = filterClaimableTokenAddresses(claimableTokens, tokenAddresses)
tokenMap = filterClaimableTokenAddresses(claimableTokens, tokenAddresses)
}

return tokenAddresses
return tokenMap
}

func getAllClaimableTokenAddresses(
addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
) []gethcommon.Address {
var addresses []gethcommon.Address
) map[gethcommon.Address]*big.Int {
tokens := make(map[gethcommon.Address]*big.Int)
for pair := addressesMap.Oldest(); pair != nil; pair = pair.Next() {
addresses = append(addresses, pair.Key)
tokens[pair.Key] = pair.Value.Int
}

return addresses
return tokens
}

func filterClaimableTokenAddresses(
addressesMap *orderedmap.OrderedMap[gethcommon.Address, *distribution.BigInt],
providedAddresses []gethcommon.Address,
) []gethcommon.Address {
var addresses []gethcommon.Address
) map[gethcommon.Address]*big.Int {
tokens := make(map[gethcommon.Address]*big.Int)
for _, address := range providedAddresses {
if _, ok := addressesMap.Get(address); ok {
addresses = append(addresses, address)
if val, ok := addressesMap.Get(address); ok {
tokens[address] = val.Int
}
}

return addresses
return tokens
}

func convertClaimTokenLeaves(
Expand Down
114 changes: 102 additions & 12 deletions pkg/rewards/claim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ import (
)

type fakeELReader struct {
roots []rewardscoordinator.IRewardsCoordinatorDistributionRoot
roots []rewardscoordinator.IRewardsCoordinatorDistributionRoot
earnerTokenClaimedMap map[common.Address]map[common.Address]*big.Int
}

func newFakeELReader(now time.Time) *fakeELReader {
func newFakeELReader(
now time.Time,
earnerTokenClaimedMap map[common.Address]map[common.Address]*big.Int,
) *fakeELReader {
roots := make([]rewardscoordinator.IRewardsCoordinatorDistributionRoot, 0)
rootOne := rewardscoordinator.IRewardsCoordinatorDistributionRoot{
Root: [32]byte{0x01},
Expand Down Expand Up @@ -60,7 +64,8 @@ func newFakeELReader(now time.Time) *fakeELReader {
return roots[i].ActivatedAt < roots[j].ActivatedAt
})
return &fakeELReader{
roots: roots,
roots: roots,
earnerTokenClaimedMap: earnerTokenClaimedMap,
}
}

Expand Down Expand Up @@ -91,6 +96,21 @@ func (f *fakeELReader) GetCurrentClaimableDistributionRoot(
return rewardscoordinator.IRewardsCoordinatorDistributionRoot{}, errors.New("no active distribution root found")
}

func (f *fakeELReader) GetCumulativeClaimed(
ctx context.Context,
earnerAddress,
tokenAddress common.Address,
) (*big.Int, error) {
if f.earnerTokenClaimedMap == nil {
return big.NewInt(0), nil
}
claimed, ok := f.earnerTokenClaimedMap[earnerAddress][tokenAddress]
if !ok {
return big.NewInt(0), nil
}
return claimed, nil
}

func (f *fakeELReader) CurrRewardsCalculationEndTimestamp(ctx context.Context) (uint32, error) {
rootLen, err := f.GetDistributionRootsLength(ctx)
if err != nil {
Expand Down Expand Up @@ -246,7 +266,7 @@ func TestGetClaimDistributionRoot(t *testing.T) {
},
}

reader := newFakeELReader(now)
reader := newFakeELReader(now, nil)
logger := logging.NewJsonSLogger(os.Stdout, &logging.SLoggerOptions{})

for _, tt := range tests {
Expand Down Expand Up @@ -280,13 +300,18 @@ func TestGetTokensToClaim(t *testing.T) {

// Case 1: No token addresses provided, should return all addresses in claimableTokens
result := getTokensToClaim(claimableTokens, []common.Address{})
expected := []common.Address{addr1, addr2}
assert.ElementsMatch(t, result, expected)
expected := map[common.Address]*big.Int{
addr1: big.NewInt(100),
addr2: big.NewInt(200),
}
assert.Equal(t, result, expected)

// Case 2: Provided token addresses, should return only those present in claimableTokens
result = getTokensToClaim(claimableTokens, []common.Address{addr2, addr3})
expected = []common.Address{addr2}
assert.ElementsMatch(t, result, expected)
expected = map[common.Address]*big.Int{
addr2: big.NewInt(200),
}
assert.Equal(t, result, expected)
}

func TestGetTokenAddresses(t *testing.T) {
Expand All @@ -300,8 +325,11 @@ func TestGetTokenAddresses(t *testing.T) {

// Test that the function returns all addresses in the map
result := getAllClaimableTokenAddresses(addressesMap)
expected := []common.Address{addr1, addr2}
assert.ElementsMatch(t, result, expected)
expected := map[common.Address]*big.Int{
addr1: big.NewInt(100),
addr2: big.NewInt(200),
}
assert.Equal(t, result, expected)
}

func TestFilterClaimableTokenAddresses(t *testing.T) {
Expand All @@ -321,8 +349,70 @@ func TestFilterClaimableTokenAddresses(t *testing.T) {
}

result := filterClaimableTokenAddresses(addressesMap, providedAddresses)
expected := []common.Address{addr1}
assert.ElementsMatch(t, result, expected)
expected := map[common.Address]*big.Int{
addr1: big.NewInt(100),
}
assert.Equal(t, result, expected)
}

func TestFilterClaimableTokens(t *testing.T) {
// Set up a mock claimableTokens map
earnerAddress := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
tokenAddress1 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
tokenAddress2 := common.HexToAddress(testutils.GenerateRandomEthereumAddressString())
amountClaimed1 := big.NewInt(100)
amountClaimed2 := big.NewInt(200)
elReaderClaimedMap := map[common.Address]map[common.Address]*big.Int{
earnerAddress: {
tokenAddress1: amountClaimed1,
tokenAddress2: amountClaimed2,
},
}
now := time.Now()
reader := newFakeELReader(now, elReaderClaimedMap)
tests := []struct {
name string
earnerAddress common.Address
claimableTokensMap map[common.Address]*big.Int
expectedClaimableTokens []common.Address
}{
{
name: "all tokens are claimable and non zero",
earnerAddress: earnerAddress,
claimableTokensMap: map[common.Address]*big.Int{
tokenAddress1: big.NewInt(2345),
tokenAddress2: big.NewInt(3345),
},
expectedClaimableTokens: []common.Address{
tokenAddress1,
tokenAddress2,
},
},
{
name: "one token is already claimed",
earnerAddress: earnerAddress,
claimableTokensMap: map[common.Address]*big.Int{
tokenAddress1: amountClaimed1,
tokenAddress2: big.NewInt(1234),
},
expectedClaimableTokens: []common.Address{
tokenAddress2,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := filterClaimableTokens(
context.Background(),
reader,
tt.earnerAddress,
tt.claimableTokensMap,
)
assert.NoError(t, err)
assert.ElementsMatch(t, tt.expectedClaimableTokens, result)
})
}
}

func newBigInt(value int64) *distribution.BigInt {
Expand Down
21 changes: 17 additions & 4 deletions pkg/rewards/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,31 @@ func getClaimedRewards(
) (map[gethcommon.Address]*big.Int, error) {
claimedRewards := make(map[gethcommon.Address]*big.Int)
for address := range allRewards {
claimed, err := elReader.GetCumulativeClaimed(ctx, earnerAddress, address)
claimed, err := getCummulativeClaimedRewards(ctx, elReader, earnerAddress, address)
if err != nil {
return nil, err
}
if claimed == nil {
claimed = big.NewInt(0)
}
claimedRewards[address] = claimed
}
return claimedRewards, nil
}

func getCummulativeClaimedRewards(
ctx context.Context,
elReader ELReader,
earnerAddress gethcommon.Address,
tokenAddress gethcommon.Address,
) (*big.Int, error) {
claimed, err := elReader.GetCumulativeClaimed(ctx, earnerAddress, tokenAddress)
if err != nil {
return nil, err
}
if claimed == nil {
claimed = big.NewInt(0)
}
return claimed, nil
}

func calculateUnclaimedRewards(
allRewards,
claimedRewards map[gethcommon.Address]*big.Int,
Expand Down

0 comments on commit 5f6537f

Please sign in to comment.