Skip to content

Commit

Permalink
feat: add multi message support for greenfield crosschain app (#417)
Browse files Browse the repository at this point in the history
* feat: add multi message package

* fix: decode multi-message methods

* chores: modify decode tests

* chores: rename payloads to ackMessages

* fix: no ack package in multi-message if no need
  • Loading branch information
cosinlink authored Apr 7, 2024
1 parent e0a150a commit 308bb78
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 7 deletions.
201 changes: 194 additions & 7 deletions x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,56 @@ import (
"context"
"encoding/hex"
"fmt"
"math/big"
"runtime/debug"
"strings"

sdkerrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"
proto "github.com/cosmos/gogoproto/proto"

govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
proto "github.com/cosmos/gogoproto/proto"
"github.com/ethereum/go-ethereum/accounts/abi"

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
)

const (
ChannelIdLength = 1
AckRelayFeeLength = 32
SoliditySelectorLength = 4
)

type msgServer struct {
Keeper
}

type MessagesType [][]byte

var (
Uint8, _ = abi.NewType("uint8", "", nil)
Bytes, _ = abi.NewType("bytes", "", nil)
Uint256, _ = abi.NewType("uint256", "", nil)
Address, _ = abi.NewType("address", "", nil)

MessageTypeArgs = abi.Arguments{
{Name: "ChannelId", Type: Uint8},
{Name: "MsgBytes", Type: Bytes},
{Name: "RelayFee", Type: Uint256},
{Name: "AckRelayFee", Type: Uint256},
{Name: "Sender", Type: Address},
}

MessagesAbiDefinition = `[{ "name" : "method", "type": "function", "outputs": [{"type": "bytes[]"}]}]`
MessagesAbi, _ = abi.JSON(strings.NewReader(MessagesAbiDefinition))

AckMessagesAbiDefinition = `[{ "name" : "method", "type": "function", "inputs": [{"type": "bytes[]"}]}]`
AckMessagesAbi, _ = abi.JSON(strings.NewReader(AckMessagesAbiDefinition))
)

// NewMsgServerImpl returns an implementation of the oracle MsgServer interface
// for the provided Keeper.
func NewMsgServerImpl(k Keeper) types.MsgServer {
Expand Down Expand Up @@ -166,6 +198,79 @@ func (k Keeper) distributeReward(ctx sdk.Context, relayer sdk.AccAddress, signed
return nil
}

func (k Keeper) handleMultiMessagePackage(
ctx sdk.Context,
pack *types.Package,
packageHeader *sdk.PackageHeader,
srcChainId uint32,
) (crash bool, result sdk.ExecuteResult) {
defer func() {
if r := recover(); r != nil {
log := fmt.Sprintf("recovered: %v\nstack:\n%v", r, string(debug.Stack()))
logger := ctx.Logger().With("module", "oracle")
logger.Error("execute handleMultiMessagePackage panic", "err_log", log)
crash = true
result = sdk.ExecuteResult{
Err: fmt.Errorf("execute handleMultiMessagePackage failed: %v", r),
}
}
}()

messages, err := DecodeMultiMessage(pack.Payload[sdk.SynPackageHeaderLength+sdk.PackageTypeLength:])
if err != nil {
return true, sdk.ExecuteResult{
Err: err,
}
}

crash = false
result = sdk.ExecuteResult{}
ackMessages := make([][]byte, 0)
for i, message := range messages {
channelId, msgBytes, ackRelayFee, err := DecodeMessage(message)
if err != nil {
return true, sdk.ExecuteResult{
Err: err,
}
}

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(sdk.ChannelID(channelId))
if crossChainApp == nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrChannelNotRegistered, "message %d, channel %d not registered", i, channelId),
}
}

msgHeader := sdk.PackageHeader{
PackageType: packageHeader.PackageType,
Timestamp: packageHeader.Timestamp,
RelayerFee: big.NewInt(0),
AckRelayerFee: ackRelayFee,
}

payload := append(make([]byte, sdk.SynPackageHeaderLength), msgBytes...)
crashSingleMsg, resultSingleMsg := executeClaim(ctx, crossChainApp, srcChainId, 0, payload, &msgHeader)
if crashSingleMsg {
return true, resultSingleMsg
}

if len(resultSingleMsg.Payload) != 0 {
ackMessages = append(ackMessages, EncodeAckMessage(channelId, ackRelayFee, resultSingleMsg.Payload))
}
}

if len(ackMessages) > 0 {
result.Payload, err = EncodeMultiAckMessage(ackMessages)
if err != nil {
return true, sdk.ExecuteResult{
Err: sdkerrors.Wrapf(types.ErrInvalidMessagesResult, "messages result pack failed, payloads=%v, error=%s", ackMessages, err),
}
}
}

return crash, result
}

func (k Keeper) handlePackage(
ctx sdk.Context,
pack *types.Package,
Expand All @@ -175,11 +280,6 @@ func (k Keeper) handlePackage(
) (sdkmath.Int, *types.EventPackageClaim, error) {
logger := k.Logger(ctx)

crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}

sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, sdk.ChainID(srcChainId), pack.ChannelId)
if sequence != pack.Sequence {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence,
Expand All @@ -201,8 +301,20 @@ func (k Keeper) handlePackage(
"package type %d is invalid", packageHeader.PackageType)
}

crash := false
var result sdk.ExecuteResult
cacheCtx, write := ctx.CacheContext()
crash, result := executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)

if pack.ChannelId == types.MultiMessageChannelId {
crash, result = k.handleMultiMessagePackage(cacheCtx, pack, &packageHeader, srcChainId)
} else {
crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId)
}
crash, result = executeClaim(cacheCtx, crossChainApp, srcChainId, sequence, pack.Payload, &packageHeader)
}

if result.IsOk() {
write()
}
Expand Down Expand Up @@ -295,3 +407,78 @@ func executeClaim(
}
return crash, result
}

func DecodeMultiMessage(multiMessagePayload []byte) (messages [][]byte, err error) {
out, err := MessagesAbi.Unpack("method", multiMessagePayload)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages unpack failed, payload=%s", hex.EncodeToString(multiMessagePayload))
}

unpacked := abi.ConvertType(out[0], MessagesType{})
messages, ok := unpacked.(MessagesType)
if !ok {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "messages ConvertType failed, payload=%v", multiMessagePayload)
}

if len(messages) == 0 {
return nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "empty messages, payload=%v", multiMessagePayload)
}

return messages, nil
}

func DecodeMessage(message []byte) (channelId uint8, msgBytes []byte, ackRelayFee *big.Int, err error) {
unpacked, err := MessageTypeArgs.Unpack(message)
if err != nil || len(unpacked) != 5 {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode message error, message=%v, error: %s", message, err)
}

channelIdType := abi.ConvertType(unpacked[0], uint8(0))
msgBytesType := abi.ConvertType(unpacked[1], []byte{})
ackRelayFeeType := abi.ConvertType(unpacked[3], big.NewInt(0))

channelId, ok := channelIdType.(uint8)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode channelId error, message=%v, error: %v", message, err)
}

msgBytes, ok = msgBytesType.([]byte)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode msgBytes error, message=%v, error: %v", message, err)
}

ackRelayFee, ok = ackRelayFeeType.(*big.Int)
if !ok {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "decode ackRelayFee error, message=%v, error: %v", message, err)
}

if len(ackRelayFee.Bytes()) > 32 {
return 0, nil, nil, sdkerrors.Wrapf(types.ErrInvalidMultiMessage, "ackRelayFee too large, ackRelayFee=%v ", ackRelayFee.Bytes())
}

return channelId, msgBytes, ackRelayFee, nil
}

func EncodeAckMessage(channelId uint8, ackRelayFee *big.Int, result []byte) (ackMessage []byte) {
resultPayloadLength := len(result)
ackMessage = make([]byte, ChannelIdLength+AckRelayFeeLength+resultPayloadLength)
ackMessage[0] = channelId

ackRelayFeeBytes := ackRelayFee.Bytes()
copy(ackMessage[ChannelIdLength+AckRelayFeeLength-len(ackRelayFeeBytes):], ackRelayFeeBytes)

if resultPayloadLength > 0 {
copy(ackMessage[ChannelIdLength+AckRelayFeeLength:], result)
}

return ackMessage
}

func EncodeMultiAckMessage(ackMessages [][]byte) (encoded []byte, err error) {
encoded, err = AckMessagesAbi.Pack("method", ackMessages)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalidMessagesResult, "ack messages pack failed, payloads=%v, error=%s", ackMessages, err)
}

return encoded[SoliditySelectorLength:], nil
}
50 changes: 50 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
package keeper_test

import (
"encoding/hex"
"fmt"
"math/big"
"time"

"github.com/ethereum/go-ethereum/common/hexutil"

"github.com/golang/mock/gomock"
"github.com/willf/bitset"

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/oracle/keeper"
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
)

type DummyCrossChainApp struct{}

type packUnpackTest struct {
def string
unpacked interface{}
packed string
}

func (ta *DummyCrossChainApp) ExecuteSynPackage(ctx sdk.Context, header *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}
Expand Down Expand Up @@ -179,3 +190,42 @@ func (s *TestSuite) TestInvalidClaim() {
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "is not the same in payload header")
}

func (s *TestSuite) TestMultiMessageDecode() {
msg1, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737431000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
msg2, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")

tests := []packUnpackTest{
{
def: `[{"type": "bytes[]"}]`,
packed: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000005746573743100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
unpacked: [][]byte{msg1, msg2},
},
}

for i, test := range tests {
encb, err := hex.DecodeString(test.packed)
s.Require().Nilf(err, "invalid hex %s: %v", test.packed, err)

messages, err := keeper.DecodeMultiMessage(encb)
s.Require().Nilf(err, "test %d (%v) failed: %v", i, test.def, err)

for _, message := range messages {
fmt.Println("message", hex.EncodeToString(message))

channelId, msgBytes, ackRelayFee, err := keeper.DecodeMessage(message)
s.Require().Nil(err, "unpack error")

fmt.Println(channelId, msgBytes, ackRelayFee)
}
}
}

func (s *TestSuite) TestMultiAckMessageEncode() {
msg1, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737431000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
msg2, _ := hexutil.Decode("000000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000e35fa931a00000000000000000000000000000000000000000000000000001626218b45860000000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e149600000000000000000000000000000000000000000000000000000000000000e10200000000000000000000000000000000000000000000000000000000000000200000000000000000000000007fa9385be102ac3eac297483dd6233d62b3e1496000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000057465737432000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
data := [][]byte{msg1, msg2}

_, err := keeper.EncodeMultiAckMessage(data)
s.Require().Nil(err, "EncodeMultiAckMessage error")
}
2 changes: 2 additions & 0 deletions x/oracle/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ var (
ErrInvalidDestChainId = errors.Register(ModuleName, 14, "dest chain id is invalid")
ErrInvalidSrcChainId = errors.Register(ModuleName, 15, "src chain id is invalid")
ErrInvalidAddress = errors.Register(ModuleName, 16, "address is invalid")
ErrInvalidMultiMessage = errors.Register(ModuleName, 17, "multi message is invalid")
ErrInvalidMessagesResult = errors.Register(ModuleName, 18, "multi message result is invalid")
)
1 change: 1 addition & 0 deletions x/oracle/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
// RelayPackagesChannelId is not a communication channel actually, we just use it to record sequence.
RelayPackagesChannelName = "relayPackages"
RelayPackagesChannelId sdk.ChannelID = 0x00
MultiMessageChannelId sdk.ChannelID = 0x08
)

0 comments on commit 308bb78

Please sign in to comment.