Skip to content

Commit

Permalink
fix some issue and add msg server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yutianwu committed Dec 29, 2022
1 parent 86de683 commit b306473
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 80 deletions.
2 changes: 1 addition & 1 deletion proto/cosmos/oracle/v1/event.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ message EventPackageClaim {
int64 send_sequence = 5;
bool crash = 6;
string error_msg = 7;
int64 fee = 8;
string relay_fee = 8;
}
1 change: 1 addition & 0 deletions simapp/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ var (
stakingtypes.NotBondedPoolName: {authtypes.Burner, authtypes.Staking},
govtypes.ModuleName: {authtypes.Burner},
nft.ModuleName: nil,
crosschaintypes.ModuleName: nil,
}
)

Expand Down
6 changes: 3 additions & 3 deletions types/cross_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func ParseChainID(input string) (ChainID, error) {
}

type CrossChainApplication interface {
ExecuteSynPackage(ctx Context, payload []byte, relayerFee int64) ExecuteResult
ExecuteSynPackage(ctx Context, payload []byte, relayerFee *big.Int) ExecuteResult
ExecuteAckPackage(ctx Context, payload []byte) ExecuteResult
// When the ack application crash, payload is the payload of the origin package.
ExecuteFailAckPackage(ctx Context, payload []byte) ExecuteResult
Expand Down Expand Up @@ -89,7 +89,7 @@ func EncodePackageHeader(packageType CrossChainPackageType, timestamp uint64, re

timestampBytes := make([]byte, TimestampLength)
binary.BigEndian.PutUint64(timestampBytes, timestamp)
copy(packageHeader[CrossChainFeeLength:CrossChainFeeLength+TimestampLength], timestampBytes)
copy(packageHeader[PackageTypeLength:PackageTypeLength+TimestampLength], timestampBytes)

length := len(relayerFee.Bytes())
copy(packageHeader[PackageHeaderLength-length:PackageHeaderLength], relayerFee.Bytes())
Expand All @@ -103,7 +103,7 @@ func DecodePackageHeader(packageHeader []byte) (packageType CrossChainPackageTyp
}
packageType = CrossChainPackageType(packageHeader[0])

timestamp = binary.BigEndian.Uint64(packageHeader[PackageTypeLength : CrossChainFeeLength+TimestampLength])
timestamp = binary.BigEndian.Uint64(packageHeader[PackageTypeLength : PackageTypeLength+TimestampLength])

relayFee.SetBytes(packageHeader[PackageTypeLength+TimestampLength : PackageHeaderLength])
return
Expand Down
6 changes: 4 additions & 2 deletions x/crosschain/testutil/mockapp.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions x/oracle/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
tmtime "github.com/tendermint/tendermint/types/time"
"github.com/willf/bitset"

crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types"
minttypes "github.com/cosmos/cosmos-sdk/x/mint/types"

"github.com/cosmos/cosmos-sdk/x/oracle/keeper"

"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/simapp"
Expand All @@ -28,6 +33,8 @@ type TestSuite struct {

app *simapp.SimApp
ctx sdk.Context

msgServer types.MsgServer
}

func (s *TestSuite) SetupTest() {
Expand All @@ -39,6 +46,16 @@ func (s *TestSuite) SetupTest() {

s.app = app
s.ctx = ctx

s.app.CrossChainKeeper.SetSrcChainID(sdk.ChainID(1))

coins := sdk.NewCoins(sdk.NewCoin("stake", sdk.NewInt(100000)))
err := s.app.BankKeeper.MintCoins(ctx, minttypes.ModuleName, coins)
s.NoError(err)
err = app.BankKeeper.SendCoinsFromModuleToModule(ctx, minttypes.ModuleName, crosschaintypes.ModuleName, coins)
s.NoError(err)

s.msgServer = keeper.NewMsgServerImpl(s.app.OracleKeeper)
}

func TestTestSuite(t *testing.T) {
Expand Down
25 changes: 12 additions & 13 deletions x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/hex"
"fmt"
"math/big"
"runtime/debug"

sdkerrors "cosmossdk.io/errors"
Expand Down Expand Up @@ -31,6 +32,8 @@ var _ types.MsgServer = msgServer{}
func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.MsgClaimResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

logger := k.oracleKeeper.Logger(ctx)

// check dest chain id
if sdk.ChainID(req.DestChainId) != k.oracleKeeper.CrossChainKeeper.GetSrcChainID() {
return nil, sdkerrors.Wrapf(types.ErrInvalidDestChainId, fmt.Sprintf("dest chain id(%d) should be %d", req.SrcChainId, k.oracleKeeper.CrossChainKeeper.GetSrcChainID()))
Expand All @@ -57,10 +60,10 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg
event, err := handlePackage(ctx, req, k.oracleKeeper, sdk.ChainID(req.SrcChainId), &pack)
if err != nil {
// only do log, but let reset package get chance to execute.
ctx.Logger().With("module", "oracle").Error(fmt.Sprintf("process package failed, channel=%d, sequence=%d, error=%v", pack.ChannelId, pack.Sequence, err))
logger.Error(fmt.Sprintf("process package failed, channel=%d, sequence=%d, error=%v", pack.ChannelId, pack.Sequence, err))
return nil, err
}
ctx.Logger().With("module", "oracle").Info(fmt.Sprintf("process package success, channel=%d, sequence=%d", pack.ChannelId, pack.Sequence))
logger.Info(fmt.Sprintf("process package success, channel=%d, sequence=%d", pack.ChannelId, pack.Sequence))

events = append(events, event)

Expand All @@ -74,7 +77,7 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg
}

func handlePackage(ctx sdk.Context, req *types.MsgClaim, oracleKeeper Keeper, chainId sdk.ChainID, pack *types.Package) (*types.EventPackageClaim, error) {
logger := ctx.Logger().With("module", "x/oracle")
logger := oracleKeeper.Logger(ctx)

crossChainApp := oracleKeeper.CrossChainKeeper.GetCrossChainApp(pack.ChannelId)
if crossChainApp == nil {
Expand All @@ -92,26 +95,22 @@ func handlePackage(ctx sdk.Context, req *types.MsgClaim, oracleKeeper Keeper, ch
}

if timestamp != req.Timestamp {
return nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "timestamp is not the same in payload header")
return nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "timestamp(%d) is not the same in payload header(%d)", req.Timestamp, timestamp)
}

if !sdk.IsValidCrossChainPackageType(packageType) {
return nil, sdkerrors.Wrapf(types.ErrInvalidPackageType, fmt.Sprintf("package type %d is invalid", packageType))
}

feeAmount := relayFee.Int64()
if feeAmount < 0 {
return nil, sdkerrors.Wrapf(types.ErrFeeOverflow, fmt.Sprintf("fee(%s) is overflow", relayFee.String()))
}

fee := sdk.Coins{sdk.Coin{Denom: sdk.NativeTokenSymbol, Amount: sdk.NewInt(feeAmount)}}
bondDenom := oracleKeeper.StakingKeeper.BondDenom(ctx)
fee := sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: sdk.NewIntFromBigInt(&relayFee)}}
err = oracleKeeper.SendCoinsToFeeCollector(ctx, fee)
if err != nil {
return nil, err
}

cacheCtx, write := ctx.CacheContext()
crash, result := executeClaim(cacheCtx, crossChainApp, pack.Payload, packageType, feeAmount)
crash, result := executeClaim(cacheCtx, crossChainApp, pack.Payload, packageType, &relayFee)
if result.IsOk() {
write()
} else {
Expand Down Expand Up @@ -156,15 +155,15 @@ func handlePackage(ctx sdk.Context, req *types.MsgClaim, oracleKeeper Keeper, ch
PackageType: uint32(packageType),
ReceiveSequence: pack.Sequence,
SendSequence: sendSequence,
Fee: feeAmount,
RelayFee: relayFee.String(),
Crash: crash,
ErrorMsg: result.ErrMsg(),
}

return claimEvent, nil
}

func executeClaim(ctx sdk.Context, app sdk.CrossChainApplication, payload []byte, packageType sdk.CrossChainPackageType, relayerFee int64) (crash bool, result sdk.ExecuteResult) {
func executeClaim(ctx sdk.Context, app sdk.CrossChainApplication, payload []byte, packageType sdk.CrossChainPackageType, relayerFee *big.Int) (crash bool, result sdk.ExecuteResult) {
defer func() {
if r := recover(); r != nil {
log := fmt.Sprintf("recovered: %v\nstack:\n%v", r, string(debug.Stack()))
Expand Down
147 changes: 147 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package keeper_test

import (
"math/big"
"time"

"github.com/willf/bitset"

"github.com/cosmos/cosmos-sdk/bsc/rlp"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
)

type DummyCrossChainApp struct {
}

func (ta *DummyCrossChainApp) ExecuteSynPackage(ctx sdk.Context, payload []byte, relayerFee *big.Int) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}

func (ta *DummyCrossChainApp) ExecuteAckPackage(ctx sdk.Context, payload []byte) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}

func (ta *DummyCrossChainApp) ExecuteFailAckPackage(ctx sdk.Context, payload []byte) sdk.ExecuteResult {
return sdk.ExecuteResult{}
}

func (s *TestSuite) TestClaim() {
s.app.CrossChainKeeper.RegisterChannel("test", sdk.ChannelID(1), &DummyCrossChainApp{})

s.app.OracleKeeper.SetParams(s.ctx, types.Params{
RelayerTimeout: 5,
RelayerBackoffTime: 3,
})

_, _, newValidators, blsKeys := createValidators(s.T(), s.ctx, s.app, []int64{9, 8, 7})

validators := s.app.StakingKeeper.GetLastValidators(s.ctx)
validatorMap := make(map[string]int, 0)
for idx, validator := range validators {
validatorMap[validator.RelayerAddress] = idx
}

payloadHeader := sdk.EncodePackageHeader(sdk.SynCrossChainPackageType, 1992, *big.NewInt(1))

testPackage := types.Package{
ChannelId: 1,
Sequence: 0,
Payload: append(payloadHeader, []byte("test payload")...),
}

packageBytes, err := rlp.EncodeToBytes([]types.Package{testPackage})
s.Require().Nil(err, "encode package error")

msgClaim := types.MsgClaim{
FromAddress: validators[0].RelayerAddress,
SrcChainId: 56,
DestChainId: 1,
Sequence: 0,
Timestamp: 1992,
Payload: packageBytes,
VoteAddressSet: []uint64{0, 1},
AggSignature: []byte("test sig"),
}

blsSignBytes := msgClaim.GetBlsSignBytes()

valBitSet := bitset.New(256)
for _, newValidator := range newValidators {
valBitSet.Set(uint(validatorMap[newValidator.RelayerAddress]))
}

blsSig := testutil.GenerateBlsSig(blsKeys, blsSignBytes[:])
msgClaim.VoteAddressSet = valBitSet.Bytes()
msgClaim.AggSignature = blsSig

s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0))
_, err = s.msgServer.Claim(s.ctx, &msgClaim)
s.Require().Nil(err, "process claim msg error")
}

func (s *TestSuite) TestInvalidClaim() {
s.app.CrossChainKeeper.RegisterChannel("test", sdk.ChannelID(1), &DummyCrossChainApp{})

s.app.OracleKeeper.SetParams(s.ctx, types.Params{
RelayerTimeout: 5,
RelayerBackoffTime: 3,
})

_, _, newValidators, blsKeys := createValidators(s.T(), s.ctx, s.app, []int64{9, 8, 7})

validators := s.app.StakingKeeper.GetLastValidators(s.ctx)
validatorMap := make(map[string]int, 0)
for idx, validator := range validators {
validatorMap[validator.RelayerAddress] = idx
}

msgClaim := types.MsgClaim{
FromAddress: validators[0].RelayerAddress,
SrcChainId: 56,
DestChainId: 1,
Sequence: 0,
Timestamp: 1992,
Payload: []byte("invalid payload"),
VoteAddressSet: []uint64{0, 1},
AggSignature: []byte("test sig"),
}

blsSignBytes := msgClaim.GetBlsSignBytes()

valBitSet := bitset.New(256)
for _, newValidator := range newValidators {
valBitSet.Set(uint(validatorMap[newValidator.RelayerAddress]))
}

blsSig := testutil.GenerateBlsSig(blsKeys, blsSignBytes[:])
msgClaim.VoteAddressSet = valBitSet.Bytes()
msgClaim.AggSignature = blsSig

s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0))
_, err := s.msgServer.Claim(s.ctx, &msgClaim)
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "decode payload error")

// invalid timestamp
payloadHeader := sdk.EncodePackageHeader(sdk.SynCrossChainPackageType, 1993, *big.NewInt(1))
testPackage := types.Package{
ChannelId: 1,
Sequence: 0,
Payload: append(payloadHeader, []byte("test payload")...),
}

packageBytes, err := rlp.EncodeToBytes([]types.Package{testPackage})
s.Require().Nil(err, "encode package error")

msgClaim.Payload = packageBytes
blsSignBytes = msgClaim.GetBlsSignBytes()
blsSig = testutil.GenerateBlsSig(blsKeys, blsSignBytes[:])
msgClaim.AggSignature = blsSig

s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0))
_, err = s.msgServer.Claim(s.ctx, &msgClaim)
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "is not the same in payload header")
}
Loading

0 comments on commit b306473

Please sign in to comment.