Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multi message support for greenfield crosschain app #417

Merged
merged 5 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 194 additions & 7 deletions x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,56 @@ import (
"context"
"encoding/hex"
unclezoro marked this conversation as resolved.
Show resolved Hide resolved
"fmt"
"math/big"
"runtime/debug"
"strings"

sdkerrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"
proto "github.com/cosmos/gogoproto/proto"

govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
proto "github.com/cosmos/gogoproto/proto"
"github.com/ethereum/go-ethereum/accounts/abi"

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
)

const (
ChannelIdLength = 1
AckRelayFeeLength = 32
SoliditySelectorLength = 4
)

type msgServer struct {
Keeper
}

type MessagesType [][]byte

var (
Uint8, _ = abi.NewType("uint8", "", nil)
Bytes, _ = abi.NewType("bytes", "", nil)
Uint256, _ = abi.NewType("uint256", "", nil)
Address, _ = abi.NewType("address", "", nil)

MessageTypeArgs = abi.Arguments{
{Name: "ChannelId", Type: Uint8},
Copy link
Contributor

@pythonberg1997 pythonberg1997 Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could use {Name: "ChannelId", Type: "uint8"} directly. No need to define a new type Uint8. So do other primary types.

{Name: "MsgBytes", Type: Bytes},
{Name: "RelayFee", Type: Uint256},
{Name: "AckRelayFee", Type: Uint256},
{Name: "Sender", Type: Address},
}

MessagesAbiDefinition = `[{ "name" : "method", "type": "function", "outputs": [{"type": "bytes[]"}]}]`
MessagesAbi, _ = abi.JSON(strings.NewReader(MessagesAbiDefinition))

AckMessagesAbiDefinition = `[{ "name" : "method", "type": "function", "inputs": [{"type": "bytes[]"}]}]`
AckMessagesAbi, _ = abi.JSON(strings.NewReader(AckMessagesAbiDefinition))
)

// NewMsgServerImpl returns an implementation of the oracle MsgServer interface
// for the provided Keeper.
func NewMsgServerImpl(k Keeper) types.MsgServer {
Expand Down Expand Up @@ -166,6 +198,79 @@ func (k Keeper) distributeReward(ctx sdk.Context, relayer sdk.AccAddress, signed
return nil
}

func (k Keeper) handleMultiMessagePackage(
ctx sdk.Context,
pack *types.Package,
packageHeader *sdk.PackageHeader,
srcChainId uint32,
) (crash bool, result sdk.ExecuteResult) {
defer func() {
if r := recover(); r != nil {
log := fmt.Sprintf("recovered: %v\nstack:\n%v", r, string(debug.Stack()))
logger := ctx.Logger().With("module", "oracle")
logger.Error("execute handleMultiMessagePackage panic", "err_log", log)
crash = true
result = sdk.ExecuteResult{
Err: fmt.Errorf("execute handleMultiMessagePackage failed: %v", r),
}
}
}()

messages, err := DecodeMultiMessage(pack.Payload[sdk.SynPackageHeaderLength+sdk.PackageTypeLength:])
if err != nil {
return true, sdk.ExecuteResult{
Err: err,
}
}

crash = false
result = sdk.ExecuteResult{}
ackMessages := make([][]byte, 0)
for i, message := range messages {
channelId, msgBytes, ackRelayFee, err := DecodeMessage(message)
if err != nil {
return true, sdk.ExecuteResult{
Err: err,
}
}

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(sdk.ChannelID(channelId))
if crossChainApp == nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrChannelNotRegistered, "message %d, channel %d not registered", i, channelId),
}
}

msgHeader := sdk.PackageHeader{
PackageType: packageHeader.PackageType,
Timestamp: packageHeader.Timestamp,
RelayerFee: big.NewInt(0),
AckRelayerFee: ackRelayFee,
}

payload := append(make([]byte, sdk.SynPackageHeaderLength), msgBytes...)
crashSingleMsg, resultSingleMsg := executeClaim(ctx, crossChainApp, srcChainId, 0, payload, &msgHeader)
if crashSingleMsg {
return true, resultSingleMsg
}

if len(resultSingleMsg.Payload) != 0 {
ackMessages = append(ackMessages, EncodeAckMessage(channelId, ackRelayFee, resultSingleMsg.Payload))
}
}

if len(ackMessages) > 0 {
result.Payload, err = EncodeMultiAckMessage(ackMessages)
if err != nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMessagesResult, "messages result pack failed, payloads=%v, error=%s", ackMessages, err),
}
}
}

return crash, result
}

func (k Keeper) handlePackage(
ctx sdk.Context,
pack *types.Package,
Expand All @@ -175,11 +280,6 @@ func (k Keeper) handlePackage(
) (sdkmath.Int, *types.EventPackageClaim, error) {
logger := k.Logger(ctx)

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}

sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, sdk.ChainID(srcChainId), pack.ChannelId)
if sequence != pack.Sequence {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence,
Expand All @@ -201,8 +301,20 @@ func (k Keeper) handlePackage(
"package type %d is invalid", packageHeader.PackageType)
}

crash := false
var result sdk.ExecuteResult
cacheCtx, write := ctx.CacheContext()
crash, result := executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)

if pack.ChannelId == types.MultiMessageChannelId {
unclezoro marked this conversation as resolved.
Show resolved Hide resolved
crash, result = k.handleMultiMessagePackage(cacheCtx, pack, &packageHeader, srcChainId)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a hardfork, isn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, will add upgrade for it

unclezoro marked this conversation as resolved.
Show resolved Hide resolved
} else {
crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}
crash, result = executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)
}

if result.IsOk() {
write()
}
Expand Down Expand Up @@ -295,3 +407,78 @@ func executeClaim(
}
return crash, result
}

func DecodeMultiMessage(multiMessagePayload []byte) (messages [][]byte, err error) {
out, err := MessagesAbi.Unpack("method", multiMessagePayload)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages unpack failed, payload=%s", hex.EncodeToString(multiMessagePayload))
}

unpacked := abi.ConvertType(out[0], MessagesType{})
messages, ok := unpacked.(MessagesType)
if !ok {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages ConvertType failed, payload=%v", multiMessagePayload)
}

if len(messages) == 0 {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "empty messages, payload=%v", multiMessagePayload)
}

return messages, nil
}

func DecodeMessage(message []byte) (channelId uint8, msgBytes []byte, ackRelayFee *big.Int, err error) {
unpacked, err := MessageTypeArgs.Unpack(message)
if err != nil || len(unpacked) != 5 {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode message error, message=%v, error: %s", message, err)
}

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode channelId error, message=%v, error: %v", message, err)
}

msgBytes, ok = msgBytesType.([]byte)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode msgBytes error, message=%v, error: %v", message, err)
}

ackRelayFee, ok = ackRelayFeeType.(*big.Int)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode ackRelayFee error, message=%v, error: %v", message, err)
}

if len(ackRelayFee.Bytes()) > 32 {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "ackRelayFee too large, ackRelayFee=%v ", ackRelayFee.Bytes())
}

return channelId, msgBytes, ackRelayFee, nil
}

func EncodeAckMessage(channelId uint8, ackRelayFee *big.Int, result []byte) (ackMessage []byte) {
resultPayloadLength := len(result)
ackMessage = make([]byte, ChannelIdLength+AckRelayFeeLength+resultPayloadLength)
ackMessage[0] = channelId

ackRelayFeeBytes := ackRelayFee.Bytes()
copy(ackMessage[ChannelIdLength+AckRelayFeeLength-len(ackRelayFeeBytes):], ackRelayFeeBytes)

if resultPayloadLength > 0 {
copy(ackMessage[ChannelIdLength+AckRelayFeeLength:], result)
}

return ackMessage
}

func EncodeMultiAckMessage(ackMessages [][]byte) (encoded []byte, err error) {
encoded, err = AckMessagesAbi.Pack("method", ackMessages)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalidMessagesResult, "ack messages pack failed, payloads=%v, error=%s", ackMessages, err)
}

return encoded[SoliditySelectorLength:], nil
}
50 changes: 50 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
package keeper_test

import (
"encoding/hex"
"fmt"
"math/big"
"time"

"github.com/ethereum/go-ethereum/common/hexutil"

"github.com/golang/mock/gomock"
"github.com/willf/bitset"

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/oracle/keeper"
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
)

type DummyCrossChainApp struct{}

type packUnpackTest struct {
def string
unpacked interface{}
packed string
}

func (ta *DummyCrossChainApp) ExecuteSynPackage(ctx sdk.Context, header *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}
Expand Down Expand Up @@ -179,3 +190,42 @@ func (s *TestSuite) TestInvalidClaim() {
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "is not the same in payload header")
}

func (s *TestSuite) TestMultiMessageDecode() {
msg1, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737431000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
msg2, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")

tests := []packUnpackTest{
{
def: `[{"type": "bytes[]"}]`,
packed: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000005746573743100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
unpacked: [][]byte{msg1, msg2},
},
}

for i, test := range tests {
encb, err := hex.DecodeString(test.packed)
s.Require().Nilf(err, "invalid hex %s: %v", test.packed, err)

messages, err := keeper.DecodeMultiMessage(encb)
s.Require().Nilf(err, "test %d (%v) failed: %v", i, test.def, err)

for _, message := range messages {
fmt.Println("message", hex.EncodeToString(message))

channelId, msgBytes, ackRelayFee, err := keeper.DecodeMessage(message)
s.Require().Nil(err, "unpack error")

fmt.Println(channelId, msgBytes, ackRelayFee)
}
}
}

func (s *TestSuite) TestMultiAckMessageEncode() {
msg1, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737431000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
msg2, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
data := [][]byte{msg1, msg2}

_, err := keeper.EncodeMultiAckMessage(data)
s.Require().Nil(err, "EncodeMultiAckMessage error")
}
2 changes: 2 additions & 0 deletions x/oracle/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ var (
ErrInvalidDestChainId = errors.Register(ModuleName, 14, "dest chain id is invalid")
ErrInvalidSrcChainId = errors.Register(ModuleName, 15, "src chain id is invalid")
ErrInvalidAddress = errors.Register(ModuleName, 16, "address is invalid")
ErrInvalidMultiMessage = errors.Register(ModuleName, 17, "multi message is invalid")
ErrInvalidMessagesResult = errors.Register(ModuleName, 18, "multi message result is invalid")
)
1 change: 1 addition & 0 deletions x/oracle/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
// RelayPackagesChannelId is not a communication channel actually, we just use it to record sequence.
RelayPackagesChannelName = "relayPackages"
RelayPackagesChannelId sdk.ChannelID = 0x00
MultiMessageChannelId sdk.ChannelID = 0x08
)
Loading