Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yutianwu committed Dec 29, 2022
1 parent 7c4339a commit 99de73d
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 4 deletions.
12 changes: 11 additions & 1 deletion x/crosschain/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (k Keeper) RegisterChannel(name string, id sdk.ChannelID, app sdk.CrossChai
return nil
}

// RegisterDestChain registers a chain with name
// RegisterDestChain registers a dest chain
func (k Keeper) RegisterDestChain(chainID sdk.ChainID) error {
for _, chain := range k.cfg.destChains {
if chainID == chain {
Expand All @@ -143,6 +143,16 @@ func (k Keeper) RegisterDestChain(chainID sdk.ChainID) error {
return nil
}

// IsDestChainSupported returns the support status of a dest chain
func (k Keeper) IsDestChainSupported(chainID sdk.ChainID) bool {
for _, chain := range k.cfg.destChains {
if chainID == chain {
return true
}
}
return false
}

// SetChannelSendPermission sets the channel send permission
func (k Keeper) SetChannelSendPermission(ctx sdk.Context, destChainID sdk.ChainID, channelID sdk.ChannelID, permission sdk.ChannelPermission) {
kvStore := ctx.KVStore(k.storeKey)
Expand Down
6 changes: 5 additions & 1 deletion x/oracle/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ func (k Keeper) IsValidatorInturn(ctx sdk.Context, validators []stakingtypes.Val

// ProcessClaim checks the bls signature
func (k Keeper) ProcessClaim(ctx sdk.Context, claim *types.MsgClaim) error {
validators := k.StakingKeeper.GetLastValidators(ctx)
historicalInfo, ok := k.StakingKeeper.GetHistoricalInfo(ctx, ctx.BlockHeight())
if !ok {
return sdkerrors.Wrapf(types.ErrValidatorSet, fmt.Sprintf("get historical validators failed"))
}
validators := historicalInfo.Valset

inturn, err := k.IsValidatorInturn(ctx, validators, claim)
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions x/oracle/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ func (s *TestSuite) TestProcessClaim() {
_, _, newValidators, blsKeys := createValidators(s.T(), s.ctx, s.app, []int64{9, 8, 7})

validators := s.app.StakingKeeper.GetLastValidators(s.ctx)

s.app.StakingKeeper.SetHistoricalInfo(s.ctx, s.ctx.BlockHeight(), &stakingtypes.HistoricalInfo{
Header: s.ctx.BlockHeader(),
Valset: validators,
})

validatorMap := make(map[string]int, 0)
for idx, validator := range validators {
validatorMap[validator.RelayerAddress] = idx
Expand Down
7 changes: 6 additions & 1 deletion x/oracle/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg

// check dest chain id
if sdk.ChainID(req.DestChainId) != k.oracleKeeper.CrossChainKeeper.GetSrcChainID() {
return nil, sdkerrors.Wrapf(types.ErrInvalidDestChainId, fmt.Sprintf("dest chain id(%d) should be %d", req.SrcChainId, k.oracleKeeper.CrossChainKeeper.GetSrcChainID()))
return nil, sdkerrors.Wrapf(types.ErrInvalidDestChainId, fmt.Sprintf("dest chain id(%d) should be %d", req.DestChainId, k.oracleKeeper.CrossChainKeeper.GetSrcChainID()))
}

// check src chain id
if !k.oracleKeeper.CrossChainKeeper.IsDestChainSupported(sdk.ChainID(req.SrcChainId)) {
return nil, sdkerrors.Wrapf(types.ErrInvalidSrcChainId, fmt.Sprintf("src chain id(%d) is not supported", req.SrcChainId))
}

sequence := k.oracleKeeper.CrossChainKeeper.GetReceiveSequence(ctx, sdk.ChainID(req.SrcChainId), types.RelayPackagesChannelId)
Expand Down
23 changes: 23 additions & 0 deletions x/oracle/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/oracle/testutil"
"github.com/cosmos/cosmos-sdk/x/oracle/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
)

type DummyCrossChainApp struct {
Expand All @@ -29,6 +30,7 @@ func (ta *DummyCrossChainApp) ExecuteFailAckPackage(ctx sdk.Context, payload []b

func (s *TestSuite) TestClaim() {
s.app.CrossChainKeeper.RegisterChannel("test", sdk.ChannelID(1), &DummyCrossChainApp{})
s.app.CrossChainKeeper.RegisterDestChain(sdk.ChainID(56))

s.app.OracleKeeper.SetParams(s.ctx, types.Params{
RelayerTimeout: 5,
Expand All @@ -38,6 +40,12 @@ func (s *TestSuite) TestClaim() {
_, _, newValidators, blsKeys := createValidators(s.T(), s.ctx, s.app, []int64{9, 8, 7})

validators := s.app.StakingKeeper.GetLastValidators(s.ctx)

s.app.StakingKeeper.SetHistoricalInfo(s.ctx, s.ctx.BlockHeight(), &stakingtypes.HistoricalInfo{
Header: s.ctx.BlockHeader(),
Valset: validators,
})

validatorMap := make(map[string]int, 0)
for idx, validator := range validators {
validatorMap[validator.RelayerAddress] = idx
Expand Down Expand Up @@ -92,6 +100,12 @@ func (s *TestSuite) TestInvalidClaim() {
_, _, newValidators, blsKeys := createValidators(s.T(), s.ctx, s.app, []int64{9, 8, 7})

validators := s.app.StakingKeeper.GetLastValidators(s.ctx)

s.app.StakingKeeper.SetHistoricalInfo(s.ctx, s.ctx.BlockHeight(), &stakingtypes.HistoricalInfo{
Header: s.ctx.BlockHeader(),
Valset: validators,
})

validatorMap := make(map[string]int, 0)
for idx, validator := range validators {
validatorMap[validator.RelayerAddress] = idx
Expand Down Expand Up @@ -119,9 +133,18 @@ func (s *TestSuite) TestInvalidClaim() {
msgClaim.VoteAddressSet = valBitSet.Bytes()
msgClaim.AggSignature = blsSig

// invalid src chain id
s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0))
_, err := s.msgServer.Claim(s.ctx, &msgClaim)
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "src chain id is invalid")

s.app.CrossChainKeeper.RegisterDestChain(sdk.ChainID(56))

// invalid payload
s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0))
_, err = s.msgServer.Claim(s.ctx, &msgClaim)
s.Require().NotNil(err, "process claim should return error")
s.Require().Contains(err.Error(), "decode payload error")

// invalid timestamp
Expand Down
2 changes: 1 addition & 1 deletion x/oracle/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ var (
ErrInvalidReceiveSequence = sdkerrors.Register(ModuleName, 2, "receive sequence is invalid")
ErrInvalidPayloadHeader = sdkerrors.Register(ModuleName, 3, "payload header is invalid")
ErrInvalidPackageType = sdkerrors.Register(ModuleName, 4, "package type is invalid")
ErrFeeOverflow = sdkerrors.Register(ModuleName, 5, "fee is overflow")
ErrInvalidPackage = sdkerrors.Register(ModuleName, 6, "package is invalid")
ErrInvalidPayload = sdkerrors.Register(ModuleName, 7, "payload is invalid")
ErrValidatorSet = sdkerrors.Register(ModuleName, 8, "validator set is invalid")
Expand All @@ -17,4 +16,5 @@ var (
ErrNotValidator = sdkerrors.Register(ModuleName, 12, "sender is not validator")
ErrValidatorNotInTurn = sdkerrors.Register(ModuleName, 13, "validator is not in turn")
ErrInvalidDestChainId = sdkerrors.Register(ModuleName, 14, "dest chain id is invalid")
ErrInvalidSrcChainId = sdkerrors.Register(ModuleName, 15, "src chain id is invalid")
)
2 changes: 2 additions & 0 deletions x/oracle/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

type StakingKeeper interface {
GetLastValidators(ctx sdk.Context) (validators []types.Validator)
GetHistoricalInfo(ctx sdk.Context, height int64) (types.HistoricalInfo, bool)
BondDenom(ctx sdk.Context) (res string)
}

Expand All @@ -15,6 +16,7 @@ type CrossChainKeeper interface {
packageType sdk.CrossChainPackageType, packageLoad []byte) (uint64, error)
GetCrossChainApp(channelID sdk.ChannelID) sdk.CrossChainApplication
GetSrcChainID() sdk.ChainID
IsDestChainSupported(chainID sdk.ChainID) bool
GetReceiveSequence(ctx sdk.Context, destChainID sdk.ChainID, channelID sdk.ChannelID) uint64
IncrReceiveSequence(ctx sdk.Context, destChainID sdk.ChainID, channelID sdk.ChannelID)
}
Expand Down

0 comments on commit 99de73d

Please sign in to comment.