diff --git a/modules/core/04-channel/keeper/export_test.go b/modules/core/04-channel/keeper/export_test.go index d97fd5e5a67..161205bf81a 100644 --- a/modules/core/04-channel/keeper/export_test.go +++ b/modules/core/04-channel/keeper/export_test.go @@ -30,8 +30,3 @@ func (k Keeper) StartFlushUpgradeHandshake( func (k Keeper) ValidateUpgradeFields(ctx sdk.Context, proposedUpgrade types.UpgradeFields, currentChannel types.Channel) error { return k.validateUpgradeFields(ctx, proposedUpgrade, currentChannel) } - -// WriteUpgradeOpenChannel is a wrapper around writeUpgradeOpenChannel to allow the function to be directly called in tests. -func (k Keeper) WriteUpgradeOpenChannel(ctx sdk.Context, portID, channelID string) { - k.writeUpgradeOpenChannel(ctx, portID, channelID) -} diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 581989fbe62..d14579d763e 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -408,7 +408,7 @@ func (k Keeper) ChanUpgradeOpen( // WriteUpgradeOpenChannel writes the agreed upon upgrade fields to the channel, sets the channel flush status to NOTINFLUSH and sets the channel state back to OPEN. This can be called in one of two cases: // - In the UpgradeAck step of the handshake if both sides have already flushed all in-flight packets. // - In the UpgradeOpen step of the handshake. -func (k Keeper) writeUpgradeOpenChannel(ctx sdk.Context, portID, channelID string) { +func (k Keeper) WriteUpgradeOpenChannel(ctx sdk.Context, portID, channelID string) { channel, found := k.GetChannel(ctx, portID, channelID) if !found { panic(fmt.Sprintf("could not find existing channel when updating channel state, channelID: %s, portID: %s", channelID, portID)) diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 4e202658d53..b7c68c0045b 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -861,7 +861,32 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh // ChannelUpgradeOpen defines a rpc handler method for MsgChannelUpgradeOpen. func (k Keeper) ChannelUpgradeOpen(goCtx context.Context, msg *channeltypes.MsgChannelUpgradeOpen) (*channeltypes.MsgChannelUpgradeOpenResponse, error) { - return nil, nil + ctx := sdk.UnwrapSDKContext(goCtx) + + module, _, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId) + if err != nil { + ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrap(err, "could not retrieve module from port-id")) + return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") + } + + cbs, ok := k.Router.GetRoute(module) + if !ok { + ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) + } + + if err = k.ChannelKeeper.ChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.ProofChannel, msg.ProofHeight); err != nil { + ctx.Logger().Error("channel upgrade open failed", "error", errorsmod.Wrap(err, "channel upgrade open failed")) + return nil, errorsmod.Wrap(err, "channel upgrade open failed") + } + + cbs.OnChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId) + + k.ChannelKeeper.WriteUpgradeOpenChannel(ctx, msg.PortId, msg.ChannelId) + + ctx.Logger().Info("channel upgrade open succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId) + + return &channeltypes.MsgChannelUpgradeOpenResponse{}, nil } // ChannelUpgradeTimeout defines a rpc handler method for MsgChannelUpgradeTimeout. diff --git a/modules/core/keeper/msg_server_test.go b/modules/core/keeper/msg_server_test.go index 6285d8bfae2..ba62d025cfe 100644 --- a/modules/core/keeper/msg_server_test.go +++ b/modules/core/keeper/msg_server_test.go @@ -1074,6 +1074,103 @@ func (suite *KeeperTestSuite) TestChannelUpgradeAck() { } } +func (suite *KeeperTestSuite) TestChannelUpgradeOpen() { + var ( + path *ibctesting.Path + msg *channeltypes.MsgChannelUpgradeOpen + ) + + cases := []struct { + name string + malleate func() + expResult func(res *channeltypes.MsgChannelUpgradeOpenResponse, err error) + }{ + { + "success", + func() {}, + func(res *channeltypes.MsgChannelUpgradeOpenResponse, err error) { + suite.Require().NoError(err) + suite.Require().NotNil(res) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(channeltypes.OPEN, channel.State) + suite.Require().Equal(channeltypes.NOTINFLUSH, channel.FlushStatus) + }, + }, + { + "module capability not found", + func() { + msg.PortId = ibctesting.InvalidID + msg.ChannelId = ibctesting.InvalidID + }, + func(res *channeltypes.MsgChannelUpgradeOpenResponse, err error) { + suite.Require().Error(err) + suite.Require().Nil(res) + + suite.Require().ErrorIs(err, capabilitytypes.ErrCapabilityNotFound) + }, + }, + { + "core handler fails", + func() { + portID := path.EndpointA.ChannelConfig.PortID + channelID := path.EndpointA.ChannelID + // Set a dummy packet commitment to simulate in-flight packets + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetPacketCommitment(suite.chainA.GetContext(), portID, channelID, 1, []byte("hash")) + }, + func(res *channeltypes.MsgChannelUpgradeOpenResponse, err error) { + suite.Require().Error(err) + suite.Require().Nil(res) + + suite.Require().ErrorIs(err, channeltypes.ErrPendingInflightPackets) + }, + }, + } + + for _, tc := range cases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + // configure the channel upgrade version on testing endpoints + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion + + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + + counterpartyChannel := path.EndpointB.GetChannel() + proofChannel, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + + msg = &channeltypes.MsgChannelUpgradeOpen{ + PortId: path.EndpointA.ChannelConfig.PortID, + ChannelId: path.EndpointA.ChannelID, + CounterpartyChannelState: counterpartyChannel.State, + ProofChannel: proofChannel, + ProofHeight: proofHeight, + Signer: suite.chainA.SenderAccount.GetAddress().String(), + } + + tc.malleate() + res, err := suite.chainA.GetSimApp().GetIBCKeeper().ChannelUpgradeOpen(suite.chainA.GetContext(), msg) + + tc.expResult(res, err) + }) + } +} + func (suite *KeeperTestSuite) TestChannelUpgradeCancel() { var ( path *ibctesting.Path