Skip to content

Commit

Permalink
only distribute whitelisted denoms and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
facundomedica committed Oct 14, 2024
1 parent c92e812 commit e2f3fc8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 14 deletions.
13 changes: 10 additions & 3 deletions x/protocolpool/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,23 @@ func (k Keeper) SetToDistribute(ctx context.Context) error {
if moduleAccount == nil {
return errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", types.ProtocolPoolDistrAccount)
}
params, err := k.Params.Get(ctx)
if err != nil {
return err
}

currentBalance := k.bankKeeper.GetAllBalances(ctx, moduleAccount.GetAddress())
// only take into account the balances of denoms whitelisted in EnabledDistributionDenoms
currentBalance := sdk.NewCoins()
for _, denom := range params.EnabledDistributionDenoms {
bal := k.bankKeeper.GetBalance(ctx, moduleAccount.GetAddress(), denom)
currentBalance = currentBalance.Add(bal)
}

// if the balance is zero, return early
if currentBalance.IsZero() {
return nil
}

// if the balance does not have any of the allowed denoms, return early // TODO

lastBalance, err := k.LastBalance.Get(ctx)
if err != nil {
if errors.Is(err, collections.ErrNotFound) {
Expand Down
27 changes: 16 additions & 11 deletions x/protocolpool/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ func (s *KeeperTestSuite) SetupTest() {
s.poolKeeper = poolKeeper
s.environment = environment

err = s.poolKeeper.Params.Set(ctx, types.Params{
EnabledDistributionDenoms: []string{sdk.DefaultBondDenom},
})
s.Require().NoError(err)

types.RegisterInterfaces(encCfg.InterfaceRegistry)
queryHelper := baseapp.NewQueryServerTestHelper(ctx, encCfg.InterfaceRegistry)
types.RegisterQueryServer(queryHelper, poolkeeper.Querier{Keeper: poolKeeper})
Expand All @@ -92,17 +97,17 @@ func (s *KeeperTestSuite) mockSendCoinsFromModuleToAccount(accAddr sdk.AccAddres

func (s *KeeperTestSuite) mockWithdrawContinuousFund() {
s.authKeeper.EXPECT().GetModuleAccount(gomock.Any(), types.ModuleName).Return(poolAcc).AnyTimes()
distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(100000)))
s.bankKeeper.EXPECT().GetAllBalances(gomock.Any(), gomock.Any()).Return(distrBal).AnyTimes()
distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(100000))
s.bankKeeper.EXPECT().GetBalance(gomock.Any(), gomock.Any(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes()
s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
}

func (s *KeeperTestSuite) mockStreamFunds(distributed math.Int) {
s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ModuleName).Return(poolAcc).AnyTimes()
s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ProtocolPoolDistrAccount).Return(poolDistrAcc).AnyTimes()
s.authKeeper.EXPECT().GetModuleAddress(types.StreamAccount).Return(streamAcc.GetAddress()).AnyTimes()
distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, distributed))
s.bankKeeper.EXPECT().GetAllBalances(s.ctx, poolDistrAcc.GetAddress()).Return(distrBal).AnyTimes()
distrBal := sdk.NewCoin(sdk.DefaultBondDenom, distributed)
s.bankKeeper.EXPECT().GetBalance(s.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes()
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), streamAcc.GetName(), gomock.Any()).AnyTimes()
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), gomock.Any()).AnyTimes()
}
Expand All @@ -116,8 +121,8 @@ func (s *KeeperTestSuite) TestIterateAndUpdateFundsDistribution() {

s.SetupTest()
s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ProtocolPoolDistrAccount).Return(poolAcc).AnyTimes()
distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000)))
s.bankKeeper.EXPECT().GetAllBalances(s.ctx, poolAcc.GetAddress()).Return(distrBal).AnyTimes()
distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000))
s.bankKeeper.EXPECT().GetBalance(s.ctx, poolAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes()
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), streamAcc.GetName(), sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(600000))))
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(400000))))

Expand Down Expand Up @@ -176,11 +181,11 @@ func (suite *KeeperTestSuite) TestSetToDistribute() {
suite.SetupTest()

suite.authKeeper.EXPECT().GetModuleAccount(suite.ctx, types.ProtocolPoolDistrAccount).Return(poolDistrAcc).AnyTimes()
distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000)))
suite.bankKeeper.EXPECT().GetAllBalances(suite.ctx, poolDistrAcc.GetAddress()).Return(distrBal).AnyTimes()
distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000))
suite.bankKeeper.EXPECT().GetBalance(suite.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).Times(2)

// because there are no continuous funds, all are going to the community pool
suite.bankKeeper.EXPECT().SendCoinsFromModuleToModule(suite.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), distrBal)
suite.bankKeeper.EXPECT().SendCoinsFromModuleToModule(suite.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), sdk.NewCoins(distrBal))

err := suite.poolKeeper.SetToDistribute(suite.ctx)
suite.Require().NoError(err)
Expand Down Expand Up @@ -220,8 +225,8 @@ func (suite *KeeperTestSuite) TestSetToDistribute() {
suite.Require().Equal(sdk.NewCoins(sdk.NewCoin("stake", math.NewInt(1000000))), distribution.Amount)

// Test case when balance is zero
zeroBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.ZeroInt()))
suite.bankKeeper.EXPECT().GetAllBalances(suite.ctx, poolDistrAcc.GetAddress()).Return(zeroBal).AnyTimes()
zeroBal := sdk.NewCoin(sdk.DefaultBondDenom, math.ZeroInt())
suite.bankKeeper.EXPECT().GetBalance(suite.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(zeroBal)

err = suite.poolKeeper.SetToDistribute(suite.ctx)
suite.Require().NoError(err)
Expand Down
14 changes: 14 additions & 0 deletions x/protocolpool/testutil/expected_keepers_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions x/protocolpool/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type AccountKeeper interface {

// BankKeeper defines the expected interface needed to retrieve account balances.
type BankKeeper interface {
GetBalance(ctx context.Context, addr sdk.AccAddress, denom string) sdk.Coin
GetAllBalances(ctx context.Context, addr sdk.AccAddress) sdk.Coins
SpendableCoins(ctx context.Context, addr sdk.AccAddress) sdk.Coins
SendCoinsFromModuleToAccount(ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error
Expand Down

0 comments on commit e2f3fc8

Please sign in to comment.