diff --git a/modules/apps/29-fee/keeper/msg_server.go b/modules/apps/29-fee/keeper/msg_server.go index 4722ad3c29d..77477057ae0 100644 --- a/modules/apps/29-fee/keeper/msg_server.go +++ b/modules/apps/29-fee/keeper/msg_server.go @@ -20,6 +20,15 @@ var _ types.MsgServer = Keeper{} func (k Keeper) RegisterPayee(goCtx context.Context, msg *types.MsgRegisterPayee) (*types.MsgRegisterPayeeResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + payee, err := sdk.AccAddressFromBech32(msg.Payee) + if err != nil { + return nil, err + } + + if k.bankKeeper.BlockedAddr(payee) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not authorized to be a payee", payee) + } + // only register payee address if the channel exists and is fee enabled if _, found := k.channelKeeper.GetChannel(ctx, msg.PortId, msg.ChannelId); !found { return nil, channeltypes.ErrChannelNotFound @@ -78,6 +87,15 @@ func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee) return nil, types.ErrFeeModuleLocked } + refundAcc, err := sdk.AccAddressFromBech32(msg.Signer) + if err != nil { + return nil, err + } + + if k.bankKeeper.BlockedAddr(refundAcc) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to escrow fees", refundAcc) + } + // get the next sequence sequence, found := k.GetNextSequenceSend(ctx, msg.SourcePortId, msg.SourceChannelId) if !found { @@ -110,6 +128,15 @@ func (k Keeper) PayPacketFeeAsync(goCtx context.Context, msg *types.MsgPayPacket return nil, types.ErrFeeModuleLocked } + refundAcc, err := sdk.AccAddressFromBech32(msg.PacketFee.RefundAddress) + if err != nil { + return nil, err + } + + if k.bankKeeper.BlockedAddr(refundAcc) { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "%s is not allowed to escrow fees", refundAcc) + } + nextSeqSend, found := k.GetNextSequenceSend(ctx, msg.PacketId.PortId, msg.PacketId.ChannelId) if !found { return nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound, "channel does not exist, portID: %s, channelID: %s", msg.PacketId.PortId, msg.PacketId.ChannelId) diff --git a/modules/apps/29-fee/keeper/msg_server_test.go b/modules/apps/29-fee/keeper/msg_server_test.go index 5ecd17cc4b8..3f3c19fd871 100644 --- a/modules/apps/29-fee/keeper/msg_server_test.go +++ b/modules/apps/29-fee/keeper/msg_server_test.go @@ -4,6 +4,7 @@ import ( sdk "github.com/Finschia/finschia-sdk/types" "github.com/cosmos/ibc-go/v4/modules/apps/29-fee/types" + transfertypes "github.com/cosmos/ibc-go/v4/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v4/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v4/modules/core/04-channel/types" ibctesting "github.com/cosmos/ibc-go/v4/testing" @@ -11,9 +12,7 @@ import ( ) func (suite *KeeperTestSuite) TestRegisterPayee() { - var ( - msg *types.MsgRegisterPayee - ) + var msg *types.MsgRegisterPayee testCases := []struct { name string @@ -39,6 +38,20 @@ func (suite *KeeperTestSuite) TestRegisterPayee() { suite.chainA.GetSimApp().IBCFeeKeeper.DeleteFeeEnabled(suite.chainA.GetContext(), suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) }, }, + { + "given payee is not an sdk address", + false, + func() { + msg.Payee = "invalid-addr" + }, + }, + { + "payee is a blocked address", + false, + func() { + msg.Payee = suite.chainA.GetSimApp().AccountKeeper.GetModuleAddress(transfertypes.ModuleName).String() + }, + }, } for _, tc := range testCases { @@ -222,6 +235,14 @@ func (suite *KeeperTestSuite) TestPayPacketFee() { }, false, }, + { + "refund account is a blocked address", + func() { + blockedAddr := suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress() + msg.Signer = blockedAddr.String() + }, + false, + }, { "acknowledgement fee balance not found", func() { @@ -401,6 +422,14 @@ func (suite *KeeperTestSuite) TestPayPacketFeeAsync() { }, false, }, + { + "refund account is a blocked address", + func() { + blockedAddr := suite.chainA.GetSimApp().AccountKeeper.GetModuleAccount(suite.chainA.GetContext(), transfertypes.ModuleName).GetAddress() + msg.PacketFee.RefundAddress = blockedAddr.String() + }, + false, + }, { "acknowledgement fee balance not found", func() { diff --git a/testing/simapp/app.go b/testing/simapp/app.go index 0a2de262f8b..365ae935615 100644 --- a/testing/simapp/app.go +++ b/testing/simapp/app.go @@ -653,7 +653,7 @@ func (app *SimApp) LoadHeight(height int64) error { func (app *SimApp) ModuleAccountAddrs() map[string]bool { modAccAddrs := make(map[string]bool) for acc := range maccPerms { - // do not add mock module to blocked addresses + // do not add the following modules to blocked addresses // this is only used for testing if acc == ibcmock.ModuleName { continue