From a6cfc5d33307b1b2ba3aa08e77fb8d071ce30dfb Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 22 Mar 2023 20:06:29 +0800 Subject: [PATCH] chore: refine oracle module (#145) --- proto/cosmos/oracle/v1/oracle.proto | 8 +- x/oracle/client/cli/query.go | 32 +++++++ x/oracle/keeper/keeper.go | 50 ++++++----- x/oracle/keeper/msg_server.go | 128 +++++++++++++--------------- 4 files changed, 127 insertions(+), 91 deletions(-) diff --git a/proto/cosmos/oracle/v1/oracle.proto b/proto/cosmos/oracle/v1/oracle.proto index 673272fd80..87f7dedc43 100644 --- a/proto/cosmos/oracle/v1/oracle.proto +++ b/proto/cosmos/oracle/v1/oracle.proto @@ -5,10 +5,10 @@ option go_package = "github.com/cosmos/cosmos-sdk/x/oracle/types"; // Params holds parameters for the oracle module. message Params { - // Timeout for the in turn relayer - uint64 relayer_timeout = 1; // in s - // RelayInterval is for in-turn relayer - uint64 relayer_interval = 2; // in s + // Timeout for the in turn relayer in seconds + uint64 relayer_timeout = 1; + // RelayInterval is for in-turn relayer in seconds + uint64 relayer_interval = 2; // Reward share for the relayer sends the claim message, // the other relayers signed the bls message will share the reward evenly. uint32 relayer_reward_share = 3; // in percentage diff --git a/x/oracle/client/cli/query.go b/x/oracle/client/cli/query.go index f9f9ad0aa5..e21562cc98 100644 --- a/x/oracle/client/cli/query.go +++ b/x/oracle/client/cli/query.go @@ -22,6 +22,7 @@ func GetQueryCmd() *cobra.Command { cmd.AddCommand( QueryParamsCmd(), + QueryInturnRelayerCmd(), ) return cmd @@ -57,3 +58,34 @@ $ query oracle params return cmd } + +// QueryParamsCmd returns the command handler for evidence parameter querying. +func QueryInturnRelayerCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inturn-relayer", + Short: "Query the inturn relayer", + Args: cobra.NoArgs, + Long: strings.TrimSpace(`Query the inturn relayer: + +$ query oracle inturn-relayer +`), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + + queryClient := types.NewQueryClient(clientCtx) + res, err := queryClient.InturnRelayer(cmd.Context(), &types.QueryInturnRelayerRequest{}) + if err != nil { + return err + } + + return clientCtx.PrintProto(res) + }, + } + + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} diff --git a/x/oracle/keeper/keeper.go b/x/oracle/keeper/keeper.go index 8d53373099..cb78ad5b45 100644 --- a/x/oracle/keeper/keeper.go +++ b/x/oracle/keeper/keeper.go @@ -1,8 +1,8 @@ package keeper import ( + "bytes" "encoding/hex" - "fmt" sdkerrors "cosmossdk.io/errors" @@ -96,32 +96,36 @@ func (k Keeper) IsRelayerValid(ctx sdk.Context, relayer sdk.AccAddress, validato } if validatorIndex < 0 { - return false, sdkerrors.Wrapf(types.ErrNotRelayer, fmt.Sprintf("sender(%s) is not a relayer", relayer.String())) + return false, sdkerrors.Wrapf(types.ErrNotRelayer, "sender(%s) is not a relayer", relayer.String()) } inturnRelayerTimeout, relayerInterval := k.GetRelayerParams(ctx) // check whether submitter of msgClaim is an in-turn relayer - inturnRelayer, err := k.GetInturnRelayer(ctx, relayerInterval) + inturnRelayerBlsKey, _, err := k.getInturnRelayer(ctx, relayerInterval) if err != nil { return false, err } - if inturnRelayer.BlsPubKey == hex.EncodeToString(vldr.BlsKey) { + if bytes.Equal(inturnRelayerBlsKey, vldr.BlsKey) { return true, nil } // It is possible that claim comes from out-turn relayers when exceeding the inturnRelayerTimeout, all other // relayers can relay within the in-turn relayer's current interval curTime := ctx.BlockTime().Unix() + if uint64(curTime) < claimTimestamp { + return false, nil + } + return uint64(curTime)-claimTimestamp >= inturnRelayerTimeout, nil } // CheckClaim checks the bls signature -func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) (sdk.AccAddress, []string, error) { +func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) (sdk.AccAddress, []sdk.AccAddress, error) { relayer, err := sdk.AccAddressFromHexUnsafe(claim.FromAddress) if err != nil { - return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("from address (%s) is invalid", claim.FromAddress)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidAddress, "from address (%s) is invalid", claim.FromAddress) } historicalInfo, ok := k.StakingKeeper.GetHistoricalInfo(ctx, ctx.BlockHeight()) @@ -136,7 +140,7 @@ func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) (sdk.AccAddre } if !isValid { - return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrRelayerNotInTurn, fmt.Sprintf("relayer(%s) is not in turn", claim.FromAddress)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrRelayerNotInTurn, "relayer(%s) is not in turn", claim.FromAddress) } validatorsBitSet := bitset.From(claim.VoteAddressSet) @@ -144,31 +148,31 @@ func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) (sdk.AccAddre return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrValidatorSet, "number of validator set is larger than validators") } - signedRelayers := make([]string, 0, validatorsBitSet.Count()) + signedRelayers := make([]sdk.AccAddress, 0, validatorsBitSet.Count()) votedPubKeys := make([]bls.PublicKey, 0, validatorsBitSet.Count()) for index, val := range validators { if !validatorsBitSet.Test(uint(index)) { continue } - signedRelayers = append(signedRelayers, val.RelayerAddress) + signedRelayers = append(signedRelayers, sdk.MustAccAddressFromHex(val.RelayerAddress)) votePubKey, err := bls.PublicKeyFromBytes(val.BlsKey) if err != nil { - return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsPubKey, fmt.Sprintf("BLS public key converts failed: %v", err)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsPubKey, "BLS public key converts failed: %v", err) } votedPubKeys = append(votedPubKeys, votePubKey) } // The valid voted validators should be no less than 2/3 validators. if len(votedPubKeys) <= len(validators)*2/3 { - return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsVotesNotEnough, fmt.Sprintf("not enough validators voted, need: %d, voted: %d", len(validators)*2/3, len(votedPubKeys))) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsVotesNotEnough, "not enough validators voted, need: %d, voted: %d", len(validators)*2/3, len(votedPubKeys)) } // Verify the aggregated signature. aggSig, err := bls.SignatureFromBytes(claim.AggSignature) if err != nil { - return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, fmt.Sprintf("BLS signature converts failed: %v", err)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, "BLS signature converts failed: %v", err) } if !aggSig.FastAggregateVerify(votedPubKeys, claim.GetBlsSignBytes()) { @@ -184,10 +188,10 @@ func (k Keeper) GetParams(ctx sdk.Context) (params types.Params) { return params } -func (k Keeper) GetInturnRelayer(ctx sdk.Context, relayerInterval uint64) (*types.QueryInturnRelayerResponse, error) { +func (k Keeper) getInturnRelayer(ctx sdk.Context, relayerInterval uint64) ([]byte, *types.RelayInterval, error) { historicalInfo, ok := k.StakingKeeper.GetHistoricalInfo(ctx, ctx.BlockHeight()) if !ok { - return nil, sdkerrors.Wrapf(types.ErrValidatorSet, "get historical validators failed") + return nil, nil, sdkerrors.Wrapf(types.ErrValidatorSet, "get historical validators failed") } validators := historicalInfo.Valset @@ -207,12 +211,20 @@ func (k Keeper) GetInturnRelayer(ctx sdk.Context, relayerInterval uint64) (*type inturnRelayer := validators[inTurnRelayerIndex] + return inturnRelayer.BlsKey, &types.RelayInterval{ + Start: start, + End: end, + }, nil +} + +func (k Keeper) GetInturnRelayer(ctx sdk.Context, relayerInterval uint64) (*types.QueryInturnRelayerResponse, error) { + blsKey, interval, err := k.getInturnRelayer(ctx, relayerInterval) + if err != nil { + return nil, err + } res := &types.QueryInturnRelayerResponse{ - BlsPubKey: hex.EncodeToString(inturnRelayer.BlsKey), - RelayInterval: &types.RelayInterval{ - Start: start, - End: end, - }, + BlsPubKey: hex.EncodeToString(blsKey), + RelayInterval: interval, } return res, nil } diff --git a/x/oracle/keeper/msg_server.go b/x/oracle/keeper/msg_server.go index daa54719a2..551a3ff49d 100644 --- a/x/oracle/keeper/msg_server.go +++ b/x/oracle/keeper/msg_server.go @@ -4,10 +4,10 @@ import ( "context" "encoding/hex" "fmt" - "math/big" "runtime/debug" sdkerrors "cosmossdk.io/errors" + sdkmath "cosmossdk.io/math" "github.com/gogo/protobuf/proto" "github.com/cosmos/cosmos-sdk/bsc/rlp" @@ -17,14 +17,14 @@ import ( ) type msgServer struct { - oracleKeeper Keeper + Keeper } // NewMsgServerImpl returns an implementation of the oracle MsgServer interface // for the provided Keeper. func NewMsgServerImpl(k Keeper) types.MsgServer { return &msgServer{ - oracleKeeper: k, + k, } } @@ -33,24 +33,24 @@ var _ types.MsgServer = msgServer{} func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.MsgClaimResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - logger := k.oracleKeeper.Logger(ctx) + logger := k.Logger(ctx) // check dest chain id - if sdk.ChainID(req.DestChainId) != k.oracleKeeper.CrossChainKeeper.GetSrcChainID() { - return nil, sdkerrors.Wrapf(types.ErrInvalidDestChainId, fmt.Sprintf("dest chain id(%d) should be %d", req.DestChainId, k.oracleKeeper.CrossChainKeeper.GetSrcChainID())) + if sdk.ChainID(req.DestChainId) != k.CrossChainKeeper.GetSrcChainID() { + return nil, sdkerrors.Wrapf(types.ErrInvalidDestChainId, "dest chain id(%d) should be %d", req.DestChainId, k.CrossChainKeeper.GetSrcChainID()) } // check src chain id - if !k.oracleKeeper.CrossChainKeeper.IsDestChainSupported(sdk.ChainID(req.SrcChainId)) { - return nil, sdkerrors.Wrapf(types.ErrInvalidSrcChainId, fmt.Sprintf("src chain id(%d) is not supported", req.SrcChainId)) + if !k.CrossChainKeeper.IsDestChainSupported(sdk.ChainID(req.SrcChainId)) { + return nil, sdkerrors.Wrapf(types.ErrInvalidSrcChainId, "src chain id(%d) is not supported", req.SrcChainId) } - sequence := k.oracleKeeper.CrossChainKeeper.GetReceiveSequence(ctx, types.RelayPackagesChannelId) + sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, types.RelayPackagesChannelId) if sequence != req.Sequence { - return nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence, fmt.Sprintf("current sequence of channel %d is %d", types.RelayPackagesChannelId, sequence)) + return nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence, "current sequence of channel %d is %d", types.RelayPackagesChannelId, sequence) } - relayer, signedRelayers, err := k.oracleKeeper.CheckClaim(ctx, req) + relayer, signedRelayers, err := k.CheckClaim(ctx, req) if err != nil { return nil, err } @@ -62,13 +62,12 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg } events := make([]proto.Message, 0, len(packages)) - totalRelayerFee := big.NewInt(0) + totalRelayerFee := sdkmath.ZeroInt() for idx := range packages { pack := packages[idx] - relayerFee, event, err := handlePackage(ctx, k.oracleKeeper, &pack, req.SrcChainId, req.DestChainId, req.Timestamp) + relayerFee, event, err := k.handlePackage(ctx, &pack, req.SrcChainId, req.DestChainId, req.Timestamp) if err != nil { - // only do log, but let rest package get chance to execute. logger.Error("process package failed", "channel", pack.ChannelId, "sequence", pack.Sequence, "error", err.Error()) return nil, err } @@ -76,74 +75,71 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg events = append(events, event) - totalRelayerFee = totalRelayerFee.Add(totalRelayerFee, relayerFee) + totalRelayerFee = totalRelayerFee.Add(relayerFee) // increase channel sequence - k.oracleKeeper.CrossChainKeeper.IncrReceiveSequence(ctx, pack.ChannelId) + k.CrossChainKeeper.IncrReceiveSequence(ctx, pack.ChannelId) } - err = distributeReward(ctx, k.oracleKeeper, relayer, signedRelayers, totalRelayerFee) + err = k.distributeReward(ctx, relayer, signedRelayers, totalRelayerFee) if err != nil { return nil, err } - k.oracleKeeper.CrossChainKeeper.IncrReceiveSequence(ctx, types.RelayPackagesChannelId) + k.CrossChainKeeper.IncrReceiveSequence(ctx, types.RelayPackagesChannelId) - ctx.EventManager().EmitTypedEvents(events...) + err = ctx.EventManager().EmitTypedEvents(events...) + if err != nil { + return nil, err + } return &types.MsgClaimResponse{}, nil } // distributeReward will distribute reward to relayers -func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer sdk.AccAddress, signedRelayers []string, relayerFee *big.Int) error { - if relayerFee.Cmp(big.NewInt(0)) <= 0 { - oracleKeeper.Logger(ctx).Info("total relayer fee is zero") +func (k Keeper) distributeReward(ctx sdk.Context, relayer sdk.AccAddress, signedRelayers []sdk.AccAddress, relayerFee sdkmath.Int) error { + if !relayerFee.IsPositive() { + k.Logger(ctx).Info("total relayer fee is zero") return nil } otherRelayers := make([]sdk.AccAddress, 0, len(signedRelayers)) - for idx := range signedRelayers { - signedRelayerAddr, err := sdk.AccAddressFromHexUnsafe(signedRelayers[idx]) - if err != nil { - return sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("relayer address (%s) is invalid", relayer)) - } - if !signedRelayerAddr.Equals(relayer) { + for _, signedRelayer := range signedRelayers { + if !signedRelayer.Equals(relayer) { otherRelayers = append(otherRelayers, relayer) } } - totalDistributed, otherRelayerReward := big.NewInt(0), big.NewInt(0) + totalDistributed, otherRelayerReward := sdkmath.ZeroInt(), sdkmath.ZeroInt() - relayerRewardShare := oracleKeeper.GetRelayerRewardShare(ctx) + relayerRewardShare := k.GetRelayerRewardShare(ctx) // calculate the reward to distribute to each other relayer if len(otherRelayers) > 0 { - otherRelayerReward = otherRelayerReward.Mul(big.NewInt(100-int64(relayerRewardShare)), relayerFee) - otherRelayerReward = otherRelayerReward.Div(otherRelayerReward, big.NewInt(100)) - otherRelayerReward = otherRelayerReward.Div(otherRelayerReward, big.NewInt(int64(len(otherRelayers)))) + otherRelayerReward = relayerFee.Mul(sdkmath.NewInt(100 - int64(relayerRewardShare))).Mul(relayerFee).Quo(sdkmath.NewInt(100)).Quo(sdkmath.NewInt(int64(len(otherRelayers)))) } - bondDenom := oracleKeeper.StakingKeeper.BondDenom(ctx) - if otherRelayerReward.Cmp(big.NewInt(0)) > 0 { - for idx := range otherRelayers { - err := oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, + bondDenom := k.StakingKeeper.BondDenom(ctx) + if otherRelayerReward.IsPositive() { + for _, signedRelayer := range otherRelayers { + err := k.BankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, - otherRelayers[idx], - sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: sdk.NewIntFromBigInt(otherRelayerReward)}}, + signedRelayer, + sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: otherRelayerReward}}, ) if err != nil { return err } - totalDistributed = totalDistributed.Add(totalDistributed, otherRelayerReward) + totalDistributed = totalDistributed.Add(otherRelayerReward) } } - remainingReward := relayerFee.Sub(relayerFee, totalDistributed) - if remainingReward.Cmp(big.NewInt(0)) > 0 { - err := oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, + remainingReward := relayerFee.Sub(totalDistributed) + if remainingReward.IsPositive() { + err := k.BankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, relayer, - sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: sdk.NewIntFromBigInt(remainingReward)}}, + sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: remainingReward}}, ) if err != nil { return err @@ -153,40 +149,39 @@ func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer sdk.AccAddre return nil } -func handlePackage( +func (k Keeper) handlePackage( ctx sdk.Context, - oracleKeeper Keeper, pack *types.Package, srcChainId uint32, destChainId uint32, timestamp uint64, -) (*big.Int, *types.EventPackageClaim, error) { - logger := oracleKeeper.Logger(ctx) +) (sdkmath.Int, *types.EventPackageClaim, error) { + logger := k.Logger(ctx) - crossChainApp := oracleKeeper.CrossChainKeeper.GetCrossChainApp(pack.ChannelId) + crossChainApp := k.CrossChainKeeper.GetCrossChainApp(pack.ChannelId) if crossChainApp == nil { - return nil, nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId) + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrChannelNotRegistered, "channel %d not registered", pack.ChannelId) } - sequence := oracleKeeper.CrossChainKeeper.GetReceiveSequence(ctx, pack.ChannelId) + sequence := k.CrossChainKeeper.GetReceiveSequence(ctx, pack.ChannelId) if sequence != pack.Sequence { - return nil, nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence, - fmt.Sprintf("current sequence of channel %d is %d", pack.ChannelId, sequence)) + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence, + "current sequence of channel %d is %d", pack.ChannelId, sequence) } packageHeader, err := sdk.DecodePackageHeader(pack.Payload) if err != nil { - return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "payload header is invalid") + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "payload header is invalid") } if packageHeader.Timestamp != timestamp { - return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "timestamp(%d) is not the same in payload header(%d)", timestamp, packageHeader.Timestamp) } if !sdk.IsValidCrossChainPackageType(packageHeader.PackageType) { - return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPackageType, - fmt.Sprintf("package type %d is invalid", packageHeader.PackageType)) + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidPackageType, + "package type %d is invalid", packageHeader.PackageType) } cacheCtx, write := ctx.CacheContext() @@ -199,28 +194,25 @@ func handlePackage( var sendSequence int64 = -1 if packageHeader.PackageType == sdk.SynCrossChainPackageType { if crash { - var ibcErr error - var sendSeq uint64 - if len(pack.Payload) >= sdk.SynPackageHeaderLength { - sendSeq, ibcErr = oracleKeeper.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, pack.ChannelId, - sdk.FailAckCrossChainPackageType, pack.Payload[sdk.SynPackageHeaderLength:], packageHeader.AckRelayerFee, sdk.NilAckRelayerFee) - } else { + if len(pack.Payload) < sdk.SynPackageHeaderLength { logger.Error("found payload without header", "channelID", pack.ChannelId, "sequence", pack.Sequence, "payload", hex.EncodeToString(pack.Payload)) - return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPackage, "payload without header") + return sdkmath.ZeroInt(), nil, sdkerrors.Wrapf(types.ErrInvalidPackage, "payload without header") } + sendSeq, ibcErr := k.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, pack.ChannelId, + sdk.FailAckCrossChainPackageType, pack.Payload[sdk.SynPackageHeaderLength:], packageHeader.AckRelayerFee, sdk.NilAckRelayerFee) if ibcErr != nil { logger.Error("failed to write FailAckCrossChainPackage", "err", err) - return nil, nil, ibcErr + return sdkmath.ZeroInt(), nil, ibcErr } sendSequence = int64(sendSeq) } else if len(result.Payload) != 0 { - sendSeq, err := oracleKeeper.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, pack.ChannelId, + sendSeq, err := k.CrossChainKeeper.CreateRawIBCPackageWithFee(ctx, pack.ChannelId, sdk.AckCrossChainPackageType, result.Payload, packageHeader.AckRelayerFee, sdk.NilAckRelayerFee) if err != nil { logger.Error("failed to write AckCrossChainPackage", "err", err) - return nil, nil, err + return sdkmath.ZeroInt(), nil, err } sendSequence = int64(sendSeq) } @@ -239,7 +231,7 @@ func handlePackage( ErrorMsg: result.ErrMsg(), } - return packageHeader.RelayerFee, claimEvent, nil + return sdkmath.NewIntFromBigInt(packageHeader.RelayerFee), claimEvent, nil } func executeClaim(