Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check that MsgChannelUpgradeInit is signed by authority #4773

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/golangci-feature.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ jobs:
steps:
- uses: actions/setup-go@v4
with:
go-version: '1.20'
- uses: actions/checkout@v3
go-version: '1.21'
- uses: actions/checkout@v4
- name: golangci-lint
uses: golangci/golangci-lint-action@v3.6.0
uses: golangci/golangci-lint-action@v3.7.0
with:
version: v1.53.1
args: --timeout 5m --exclude unused
version: v1.54.2
args: --timeout 10m
10 changes: 5 additions & 5 deletions modules/apps/callbacks/ibc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,27 +364,27 @@ func (im IBCMiddleware) OnChanCloseConfirm(ctx sdk.Context, portID, channelID st
}

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) (string, error) {
func (IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) (string, error) {
panic("implement me")
}

// OnChanUpgradeTry implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string) (string, error) {
func (IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string) (string, error) {
panic("implement me")
}

// OnChanUpgradeAck implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error {
func (IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error {
panic("implement me")
}

// OnChanUpgradeOpen implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) {
func (IBCMiddleware) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) {
panic("implement me")
}

// OnChanUpgradeRestore implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeRestore(ctx sdk.Context, portID, channelID string) {
func (IBCMiddleware) OnChanUpgradeRestore(ctx sdk.Context, portID, channelID string) {
panic("implement me")
}

Expand Down
1 change: 0 additions & 1 deletion modules/apps/transfer/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ func (suite *TransferTestSuite) TestOnChanUpgradeInit() {
expPass := tc.expError == nil
if expPass {
suite.Require().NoError(err)

upgrade := path.EndpointA.GetChannelUpgrade()
suite.Require().Equal(upgradePath.EndpointA.ConnectionID, upgrade.Fields.ConnectionHops[0])
} else {
Expand Down
4 changes: 4 additions & 0 deletions modules/core/04-channel/types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ var (
_ sdk.HasValidateBasic = (*MsgAcknowledgement)(nil)
_ sdk.HasValidateBasic = (*MsgTimeout)(nil)
_ sdk.HasValidateBasic = (*MsgTimeoutOnClose)(nil)
_ sdk.HasValidateBasic = (*MsgChannelUpgradeInit)(nil)
_ sdk.HasValidateBasic = (*MsgChannelUpgradeTry)(nil)
_ sdk.HasValidateBasic = (*MsgChannelUpgradeAck)(nil)
_ sdk.HasValidateBasic = (*MsgChannelUpgradeConfirm)(nil)
)

// NewMsgChannelOpenInit creates a new MsgChannelOpenInit. It sets the counterparty channel
Expand Down
4 changes: 4 additions & 0 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,10 @@ func (k Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAckn
func (k Keeper) ChannelUpgradeInit(goCtx context.Context, msg *channeltypes.MsgChannelUpgradeInit) (*channeltypes.MsgChannelUpgradeInitResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

if k.GetAuthority() != msg.Signer {
return nil, errorsmod.Wrapf(ibcerrors.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Signer)
}

module, _, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId)
if err != nil {
ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
Expand Down
69 changes: 68 additions & 1 deletion modules/core/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,72 @@ func (suite *KeeperTestSuite) TestUpgradeClient() {
}
}

func (suite *KeeperTestSuite) TestChannelUpgradeInit() {
var (
path *ibctesting.Path
msg *channeltypes.MsgChannelUpgradeInit
)

cases := []struct {
name string
malleate func()
expResult func(res *channeltypes.MsgChannelUpgradeInitResponse, err error)
}{
{
"success",
func() {
msg = channeltypes.NewMsgChannelUpgradeInit(
path.EndpointA.ChannelConfig.PortID,
path.EndpointA.ChannelID,
path.EndpointA.GetProposedUpgrade().Fields,
path.EndpointA.Chain.GetSimApp().IBCKeeper.GetAuthority(),
)
},
func(res *channeltypes.MsgChannelUpgradeInitResponse, err error) {
suite.Require().NoError(err)
suite.Require().NotNil(res)
suite.Require().Equal(uint64(1), res.UpgradeSequence)
},
},
{
"authority is not signer of the upgrade init msg",
func() {
msg = channeltypes.NewMsgChannelUpgradeInit(
path.EndpointA.ChannelConfig.PortID,
path.EndpointA.ChannelID,
path.EndpointA.GetProposedUpgrade().Fields,
path.EndpointA.Chain.SenderAccount.String(),
)
},
func(res *channeltypes.MsgChannelUpgradeInitResponse, err error) {
suite.Require().Error(err)
suite.Require().ErrorContains(err, ibcerrors.ErrUnauthorized.Error())
suite.Require().Nil(res)
},
},
}

for _, tc := range cases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

// configure the channel upgrade version on testing endpoints
path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion

tc.malleate()

res, err := keeper.Keeper.ChannelUpgradeInit(*suite.chainA.App.GetIBCKeeper(), suite.chainA.GetContext(), msg)

tc.expResult(res, err)
})
}
}

func (suite *KeeperTestSuite) TestChannelUpgradeTry() {
var (
path *ibctesting.Path
Expand Down Expand Up @@ -1511,7 +1577,8 @@ func (suite *KeeperTestSuite) TestChannelUpgradeCancel() {
path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion

suite.Require().NoError(path.EndpointA.ChanUpgradeInit())
err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

// cause the upgrade to fail on chain b so an error receipt is written.
// if the counterparty (chain A) upgrade sequence is less than the current sequence, (chain B)
Expand Down
35 changes: 33 additions & 2 deletions testing/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package ibctesting

import (
"fmt"
"strconv"
"strings"

"github.com/stretchr/testify/require"

"github.com/cosmos/cosmos-sdk/baseapp"
sdk "github.com/cosmos/cosmos-sdk/types"
govtypesv1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1"

abci "github.com/cometbft/cometbft/abci/types"

Expand Down Expand Up @@ -587,18 +590,46 @@ func (endpoint *Endpoint) QueryChannelUpgradeProof() ([]byte, []byte, clienttype

// ChanUpgradeInit sends a MsgChannelUpgradeInit on the associated endpoint.
// A default upgrade proposal is used with overrides from the ProposedUpgrade
// in the channel config.
// in the channel config, and submitted via governance proposal
func (endpoint *Endpoint) ChanUpgradeInit() error {
upgrade := endpoint.GetProposedUpgrade()

// create upgrade init message via gov proposal and submit the proposal
msg := channeltypes.NewMsgChannelUpgradeInit(
endpoint.ChannelConfig.PortID,
endpoint.ChannelID,
upgrade.Fields,
endpoint.Chain.GetSimApp().IBCKeeper.GetAuthority(),
)

proposal, err := govtypesv1.NewMsgSubmitProposal(
[]sdk.Msg{msg},
sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, govtypesv1.DefaultMinDepositTokens)),
endpoint.Chain.SenderAccount.GetAddress().String(),
endpoint.ChannelID,
"upgrade-init",
fmt.Sprintf("gov proposal for initialising channel upgrade: %s", endpoint.ChannelID),
false,
)
require.NoError(endpoint.Chain.TB, err)

return endpoint.Chain.sendMsgs(msg)
var proposalID uint64
res, err := endpoint.Chain.SendMsgs(proposal)
if err != nil {
return err
}

events := res.Events
for _, event := range events {
for _, attribute := range event.Attributes {
if attribute.Key == "proposal_id" {
proposalID, err = strconv.ParseUint(attribute.Value, 10, 64)
require.NoError(endpoint.Chain.TB, err)
}
}
}
Comment on lines +622 to +630
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could sweep this under a ibctesting.ParseProposalIDFromEvents() func but can be a future improvement 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created #4836

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's like 3 issues now for this 😄 #4803, #4809

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOL, I guess we should do it then :D


return VoteAndCheckProposalStatus(endpoint, proposalID)
}

// ChanUpgradeTry sends a MsgChannelUpgradeTry on the associated endpoint.
Expand Down
26 changes: 26 additions & 0 deletions testing/utils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package ibctesting

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

govtypesv1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1"

abci "github.com/cometbft/cometbft/abci/types"
tmtypes "github.com/cometbft/cometbft/types"
)
Expand All @@ -23,3 +26,26 @@ func ApplyValSetChanges(tb testing.TB, valSet *tmtypes.ValidatorSet, valUpdates

return newVals
}

// VoteAndCheckProposalStatus votes on a gov proposal, checks if the proposal has passed, and returns an error if it has not with the failure reason.
func VoteAndCheckProposalStatus(endpoint *Endpoint, proposalID uint64) error {
// vote on proposal
ctx := endpoint.Chain.GetContext()
require.NoError(endpoint.Chain.TB, endpoint.Chain.GetSimApp().GovKeeper.AddVote(ctx, proposalID, endpoint.Chain.SenderAccount.GetAddress(), govtypesv1.NewNonSplitVoteOption(govtypesv1.OptionYes), ""))

// fast forward the chain context to end the voting period
params, err := endpoint.Chain.GetSimApp().GovKeeper.Params.Get(ctx)
require.NoError(endpoint.Chain.TB, err)

endpoint.Chain.Coordinator.IncrementTimeBy(*params.VotingPeriod + *params.MaxDepositPeriod)
endpoint.Chain.NextBlock()

// check if proposal passed or failed on msg execution
// we need to grab the context again since the previous context is no longer valid as the chain header time has been incremented
p, err := endpoint.Chain.GetSimApp().GovKeeper.Proposals.Get(endpoint.Chain.GetContext(), proposalID)
require.NoError(endpoint.Chain.TB, err)
if p.Status != govtypesv1.StatusPassed {
return fmt.Errorf("proposal failed: %s", p.FailedReason)
}
return nil
}