From 955fa2df39fd18dcd47255eba605178bec34a9f4 Mon Sep 17 00:00:00 2001 From: yutianwu Date: Fri, 25 Aug 2023 15:32:43 +0800 Subject: [PATCH] feat: restrict token transfers to payment accounts (#277) * feat: restrict token transfers to payment accounts * fix comments --- simapp/app.go | 5 +-- tests/integration/bank/keeper/keeper_test.go | 4 +-- x/bank/keeper/keeper_test.go | 2 +- x/bank/keeper/msg_server.go | 19 ++++++++-- x/bank/module.go | 8 +++-- x/bank/testutil/expected_keepers_mocks.go | 37 ++++++++++++++++++++ x/bank/types/expected_keepers.go | 7 ++++ 7 files changed, 72 insertions(+), 10 deletions(-) diff --git a/simapp/app.go b/simapp/app.go index d6a7d3805f..a17df531cb 100644 --- a/simapp/app.go +++ b/simapp/app.go @@ -13,9 +13,10 @@ import ( dbm "github.com/cometbft/cometbft-db" abci "github.com/cometbft/cometbft/abci/types" "github.com/cometbft/cometbft/libs/log" + "github.com/spf13/cast" + "github.com/cosmos/cosmos-sdk/x/crosschain" "github.com/cosmos/cosmos-sdk/x/oracle" - "github.com/spf13/cast" crosschainkeeper "github.com/cosmos/cosmos-sdk/x/crosschain/keeper" crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types" @@ -415,7 +416,7 @@ func NewSimApp( ), auth.NewAppModule(appCodec, app.AccountKeeper, authsims.RandomGenesisAccounts, app.GetSubspace(authtypes.ModuleName)), vesting.NewAppModule(app.AccountKeeper, app.BankKeeper), - bank.NewAppModule(appCodec, app.BankKeeper, app.AccountKeeper, app.GetSubspace(banktypes.ModuleName)), + bank.NewAppModule(appCodec, app.BankKeeper, app.AccountKeeper, nil, app.GetSubspace(banktypes.ModuleName)), capability.NewAppModule(appCodec, *app.CapabilityKeeper, false), crisis.NewAppModule(app.CrisisKeeper, skipGenesisInvariants, app.GetSubspace(crisistypes.ModuleName)), feegrantmodule.NewAppModule(appCodec, app.AccountKeeper, app.BankKeeper, app.FeeGrantKeeper, app.interfaceRegistry), diff --git a/tests/integration/bank/keeper/keeper_test.go b/tests/integration/bank/keeper/keeper_test.go index f38d4a2298..8278881186 100644 --- a/tests/integration/bank/keeper/keeper_test.go +++ b/tests/integration/bank/keeper/keeper_test.go @@ -75,7 +75,7 @@ func newBarCoin(amt int64) sdk.Coin { return sdk.NewInt64Coin(barDenom, amt) } -//nolint: interfacer +// nolint: interfacer func getCoinsByName(ctx sdk.Context, bk keeper.Keeper, ak types.AccountKeeper, moduleName string) sdk.Coins { moduleAddress := ak.GetModuleAddress(moduleName) macc := ak.GetAccount(ctx, moduleAddress) @@ -150,7 +150,7 @@ func (suite *IntegrationTestSuite) SetupTest() { types.RegisterInterfaces(interfaceRegistry) suite.queryClient = queryClient - suite.msgServer = keeper.NewMsgServerImpl(suite.bankKeeper) + suite.msgServer = keeper.NewMsgServerImpl(suite.bankKeeper, nil) } func (suite *IntegrationTestSuite) TestSupply() { diff --git a/x/bank/keeper/keeper_test.go b/x/bank/keeper/keeper_test.go index 1b9d1a184c..530c590a4d 100644 --- a/x/bank/keeper/keeper_test.go +++ b/x/bank/keeper/keeper_test.go @@ -111,7 +111,7 @@ func (suite *KeeperTestSuite) SetupTest() { queryClient := banktypes.NewQueryClient(queryHelper) suite.queryClient = queryClient - suite.msgServer = keeper.NewMsgServerImpl(suite.bankKeeper) + suite.msgServer = keeper.NewMsgServerImpl(suite.bankKeeper, nil) suite.encCfg = encCfg } diff --git a/x/bank/keeper/msg_server.go b/x/bank/keeper/msg_server.go index b74285b83c..e8914069ed 100644 --- a/x/bank/keeper/msg_server.go +++ b/x/bank/keeper/msg_server.go @@ -10,18 +10,20 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/x/bank/types" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + upgradetypes "github.com/cosmos/cosmos-sdk/x/upgrade/types" ) type msgServer struct { Keeper + PaymentKeeper types.PaymentKeeper } var _ types.MsgServer = msgServer{} // NewMsgServerImpl returns an implementation of the bank MsgServer interface // for the provided Keeper. -func NewMsgServerImpl(keeper Keeper) types.MsgServer { - return &msgServer{Keeper: keeper} +func NewMsgServerImpl(keeper Keeper, paymentKeeper types.PaymentKeeper) types.MsgServer { + return &msgServer{Keeper: keeper, PaymentKeeper: paymentKeeper} } func (k msgServer) Send(goCtx context.Context, msg *types.MsgSend) (*types.MsgSendResponse, error) { @@ -44,6 +46,12 @@ func (k msgServer) Send(goCtx context.Context, msg *types.MsgSend) (*types.MsgSe return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", msg.ToAddress) } + if ctx.IsUpgraded(upgradetypes.Nagqu) { + if k.PaymentKeeper != nil && k.PaymentKeeper.IsPaymentAccount(ctx, to) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "payment account %s is not allowed to receive funds", msg.ToAddress) + } + } + err = k.SendCoins(ctx, from, to, msg.Amount) if err != nil { return nil, err @@ -74,12 +82,19 @@ func (k msgServer) MultiSend(goCtx context.Context, msg *types.MsgMultiSend) (*t } } + isUpgradeNagqu := ctx.IsUpgraded(upgradetypes.Nagqu) for _, out := range msg.Outputs { accAddr := sdk.MustAccAddressFromHex(out.Address) if k.BlockedAddr(accAddr) { return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to receive funds", out.Address) } + + if isUpgradeNagqu { + if k.PaymentKeeper != nil && k.PaymentKeeper.IsPaymentAccount(ctx, accAddr) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "payment account %s is not allowed to receive funds", accAddr) + } + } } err := k.InputOutputCoins(ctx, msg.Inputs, msg.Outputs) diff --git a/x/bank/module.go b/x/bank/module.go index 6bd264bf62..65bb879491 100644 --- a/x/bank/module.go +++ b/x/bank/module.go @@ -100,6 +100,7 @@ type AppModule struct { keeper keeper.Keeper accountKeeper types.AccountKeeper + paymentKeeper types.PaymentKeeper // legacySubspace is used solely for migration of x/params managed parameters legacySubspace exported.Subspace @@ -115,7 +116,7 @@ func (am AppModule) IsAppModule() {} // RegisterServices registers module services. func (am AppModule) RegisterServices(cfg module.Configurator) { - types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper)) + types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper, am.paymentKeeper)) types.RegisterQueryServer(cfg.QueryServer(), am.keeper) m := keeper.NewMigrator(am.keeper.(keeper.BaseKeeper), am.legacySubspace) @@ -133,11 +134,12 @@ func (am AppModule) RegisterServices(cfg module.Configurator) { } // NewAppModule creates a new AppModule object -func NewAppModule(cdc codec.Codec, keeper keeper.Keeper, accountKeeper types.AccountKeeper, ss exported.Subspace) AppModule { +func NewAppModule(cdc codec.Codec, keeper keeper.Keeper, accountKeeper types.AccountKeeper, paymentKeeper types.PaymentKeeper, ss exported.Subspace) AppModule { return AppModule{ AppModuleBasic: AppModuleBasic{cdc: cdc}, keeper: keeper, accountKeeper: accountKeeper, + paymentKeeper: paymentKeeper, legacySubspace: ss, } } @@ -254,7 +256,7 @@ func ProvideModule(in BankInputs) BankOutputs { blockedAddresses, authority.String(), ) - m := NewAppModule(in.Cdc, bankKeeper, in.AccountKeeper, in.LegacySubspace) + m := NewAppModule(in.Cdc, bankKeeper, in.AccountKeeper, nil, in.LegacySubspace) return BankOutputs{BankKeeper: bankKeeper, Module: m} } diff --git a/x/bank/testutil/expected_keepers_mocks.go b/x/bank/testutil/expected_keepers_mocks.go index c27addd028..e7697e32c1 100644 --- a/x/bank/testutil/expected_keepers_mocks.go +++ b/x/bank/testutil/expected_keepers_mocks.go @@ -226,3 +226,40 @@ func (mr *MockAccountKeeperMockRecorder) ValidatePermissions(macc interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatePermissions", reflect.TypeOf((*MockAccountKeeper)(nil).ValidatePermissions), macc) } + +// MockPaymentKeeper is a mock of PaymentKeeper interface. +type MockPaymentKeeper struct { + ctrl *gomock.Controller + recorder *MockPaymentKeeperMockRecorder +} + +// MockPaymentKeeperMockRecorder is the mock recorder for MockPaymentKeeper. +type MockPaymentKeeperMockRecorder struct { + mock *MockPaymentKeeper +} + +// NewMockPaymentKeeper creates a new mock instance. +func NewMockPaymentKeeper(ctrl *gomock.Controller) *MockPaymentKeeper { + mock := &MockPaymentKeeper{ctrl: ctrl} + mock.recorder = &MockPaymentKeeperMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPaymentKeeper) EXPECT() *MockPaymentKeeperMockRecorder { + return m.recorder +} + +// IsPaymentAccount mocks base method. +func (m *MockPaymentKeeper) IsPaymentAccount(ctx types.Context, addr types.AccAddress) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPaymentAccount", ctx, addr) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPaymentAccount indicates an expected call of IsPaymentAccount. +func (mr *MockPaymentKeeperMockRecorder) IsPaymentAccount(ctx, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPaymentAccount", reflect.TypeOf((*MockPaymentKeeper)(nil).IsPaymentAccount), ctx, addr) +} diff --git a/x/bank/types/expected_keepers.go b/x/bank/types/expected_keepers.go index 191d6b5a83..5b54b4186c 100644 --- a/x/bank/types/expected_keepers.go +++ b/x/bank/types/expected_keepers.go @@ -27,3 +27,10 @@ type AccountKeeper interface { SetModuleAccount(ctx sdk.Context, macc types.ModuleAccountI) GetModulePermissions() map[string]types.PermissionsForAddress } + +type PaymentKeeper interface { + IsPaymentAccount( + ctx sdk.Context, + addr sdk.AccAddress, + ) bool +}