From e2f3fc8e02d8122ad627cbc68a3cb8b7e9f7ffe7 Mon Sep 17 00:00:00 2001 From: Facundo Date: Mon, 14 Oct 2024 11:38:26 -0300 Subject: [PATCH] only distribute whitelisted denoms and fix tests --- x/protocolpool/keeper/keeper.go | 13 ++++++--- x/protocolpool/keeper/keeper_test.go | 27 +++++++++++-------- .../testutil/expected_keepers_mocks.go | 14 ++++++++++ x/protocolpool/types/expected_keepers.go | 1 + 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/x/protocolpool/keeper/keeper.go b/x/protocolpool/keeper/keeper.go index b2b08d1736f8..668246d95c69 100644 --- a/x/protocolpool/keeper/keeper.go +++ b/x/protocolpool/keeper/keeper.go @@ -144,16 +144,23 @@ func (k Keeper) SetToDistribute(ctx context.Context) error { if moduleAccount == nil { return errorsmod.Wrapf(sdkerrors.ErrUnknownAddress, "module account %s does not exist", types.ProtocolPoolDistrAccount) } + params, err := k.Params.Get(ctx) + if err != nil { + return err + } - currentBalance := k.bankKeeper.GetAllBalances(ctx, moduleAccount.GetAddress()) + // only take into account the balances of denoms whitelisted in EnabledDistributionDenoms + currentBalance := sdk.NewCoins() + for _, denom := range params.EnabledDistributionDenoms { + bal := k.bankKeeper.GetBalance(ctx, moduleAccount.GetAddress(), denom) + currentBalance = currentBalance.Add(bal) + } // if the balance is zero, return early if currentBalance.IsZero() { return nil } - // if the balance does not have any of the allowed denoms, return early // TODO - lastBalance, err := k.LastBalance.Get(ctx) if err != nil { if errors.Is(err, collections.ErrNotFound) { diff --git a/x/protocolpool/keeper/keeper_test.go b/x/protocolpool/keeper/keeper_test.go index 894cd0903da7..a52f7b3edf17 100644 --- a/x/protocolpool/keeper/keeper_test.go +++ b/x/protocolpool/keeper/keeper_test.go @@ -79,6 +79,11 @@ func (s *KeeperTestSuite) SetupTest() { s.poolKeeper = poolKeeper s.environment = environment + err = s.poolKeeper.Params.Set(ctx, types.Params{ + EnabledDistributionDenoms: []string{sdk.DefaultBondDenom}, + }) + s.Require().NoError(err) + types.RegisterInterfaces(encCfg.InterfaceRegistry) queryHelper := baseapp.NewQueryServerTestHelper(ctx, encCfg.InterfaceRegistry) types.RegisterQueryServer(queryHelper, poolkeeper.Querier{Keeper: poolKeeper}) @@ -92,8 +97,8 @@ func (s *KeeperTestSuite) mockSendCoinsFromModuleToAccount(accAddr sdk.AccAddres func (s *KeeperTestSuite) mockWithdrawContinuousFund() { s.authKeeper.EXPECT().GetModuleAccount(gomock.Any(), types.ModuleName).Return(poolAcc).AnyTimes() - distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(100000))) - s.bankKeeper.EXPECT().GetAllBalances(gomock.Any(), gomock.Any()).Return(distrBal).AnyTimes() + distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(100000)) + s.bankKeeper.EXPECT().GetBalance(gomock.Any(), gomock.Any(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes() s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() } @@ -101,8 +106,8 @@ func (s *KeeperTestSuite) mockStreamFunds(distributed math.Int) { s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ModuleName).Return(poolAcc).AnyTimes() s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ProtocolPoolDistrAccount).Return(poolDistrAcc).AnyTimes() s.authKeeper.EXPECT().GetModuleAddress(types.StreamAccount).Return(streamAcc.GetAddress()).AnyTimes() - distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, distributed)) - s.bankKeeper.EXPECT().GetAllBalances(s.ctx, poolDistrAcc.GetAddress()).Return(distrBal).AnyTimes() + distrBal := sdk.NewCoin(sdk.DefaultBondDenom, distributed) + s.bankKeeper.EXPECT().GetBalance(s.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes() s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), streamAcc.GetName(), gomock.Any()).AnyTimes() s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), gomock.Any()).AnyTimes() } @@ -116,8 +121,8 @@ func (s *KeeperTestSuite) TestIterateAndUpdateFundsDistribution() { s.SetupTest() s.authKeeper.EXPECT().GetModuleAccount(s.ctx, types.ProtocolPoolDistrAccount).Return(poolAcc).AnyTimes() - distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000))) - s.bankKeeper.EXPECT().GetAllBalances(s.ctx, poolAcc.GetAddress()).Return(distrBal).AnyTimes() + distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000)) + s.bankKeeper.EXPECT().GetBalance(s.ctx, poolAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).AnyTimes() s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), streamAcc.GetName(), sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(600000)))) s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(s.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(400000)))) @@ -176,11 +181,11 @@ func (suite *KeeperTestSuite) TestSetToDistribute() { suite.SetupTest() suite.authKeeper.EXPECT().GetModuleAccount(suite.ctx, types.ProtocolPoolDistrAccount).Return(poolDistrAcc).AnyTimes() - distrBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000))) - suite.bankKeeper.EXPECT().GetAllBalances(suite.ctx, poolDistrAcc.GetAddress()).Return(distrBal).AnyTimes() + distrBal := sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(1000000)) + suite.bankKeeper.EXPECT().GetBalance(suite.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(distrBal).Times(2) // because there are no continuous funds, all are going to the community pool - suite.bankKeeper.EXPECT().SendCoinsFromModuleToModule(suite.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), distrBal) + suite.bankKeeper.EXPECT().SendCoinsFromModuleToModule(suite.ctx, poolDistrAcc.GetName(), poolAcc.GetName(), sdk.NewCoins(distrBal)) err := suite.poolKeeper.SetToDistribute(suite.ctx) suite.Require().NoError(err) @@ -220,8 +225,8 @@ func (suite *KeeperTestSuite) TestSetToDistribute() { suite.Require().Equal(sdk.NewCoins(sdk.NewCoin("stake", math.NewInt(1000000))), distribution.Amount) // Test case when balance is zero - zeroBal := sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, math.ZeroInt())) - suite.bankKeeper.EXPECT().GetAllBalances(suite.ctx, poolDistrAcc.GetAddress()).Return(zeroBal).AnyTimes() + zeroBal := sdk.NewCoin(sdk.DefaultBondDenom, math.ZeroInt()) + suite.bankKeeper.EXPECT().GetBalance(suite.ctx, poolDistrAcc.GetAddress(), sdk.DefaultBondDenom).Return(zeroBal) err = suite.poolKeeper.SetToDistribute(suite.ctx) suite.Require().NoError(err) diff --git a/x/protocolpool/testutil/expected_keepers_mocks.go b/x/protocolpool/testutil/expected_keepers_mocks.go index bd775d35bdc1..358fcad30bef 100644 --- a/x/protocolpool/testutil/expected_keepers_mocks.go +++ b/x/protocolpool/testutil/expected_keepers_mocks.go @@ -129,6 +129,20 @@ func (mr *MockBankKeeperMockRecorder) GetAllBalances(ctx, addr interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllBalances", reflect.TypeOf((*MockBankKeeper)(nil).GetAllBalances), ctx, addr) } +// GetBalance mocks base method. +func (m *MockBankKeeper) GetBalance(ctx context.Context, addr types.AccAddress, denom string) types.Coin { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalance", ctx, addr, denom) + ret0, _ := ret[0].(types.Coin) + return ret0 +} + +// GetBalance indicates an expected call of GetBalance. +func (mr *MockBankKeeperMockRecorder) GetBalance(ctx, addr, denom interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockBankKeeper)(nil).GetBalance), ctx, addr, denom) +} + // SendCoinsFromAccountToModule mocks base method. func (m *MockBankKeeper) SendCoinsFromAccountToModule(ctx context.Context, senderAddr types.AccAddress, recipientModule string, amt types.Coins) error { m.ctrl.T.Helper() diff --git a/x/protocolpool/types/expected_keepers.go b/x/protocolpool/types/expected_keepers.go index ae7adc3a3d80..c4191460b2fa 100644 --- a/x/protocolpool/types/expected_keepers.go +++ b/x/protocolpool/types/expected_keepers.go @@ -17,6 +17,7 @@ type AccountKeeper interface { // BankKeeper defines the expected interface needed to retrieve account balances. type BankKeeper interface { + GetBalance(ctx context.Context, addr sdk.AccAddress, denom string) sdk.Coin GetAllBalances(ctx context.Context, addr sdk.AccAddress) sdk.Coins SpendableCoins(ctx context.Context, addr sdk.AccAddress) sdk.Coins SendCoinsFromModuleToAccount(ctx context.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error