diff --git a/types/address.go b/types/address.go index 9868755b32..931d549e06 100644 --- a/types/address.go +++ b/types/address.go @@ -6,8 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "math/big" - "strings" "sync" "github.com/hashicorp/golang-lru/simplelru" @@ -155,22 +153,22 @@ func MustAccAddressFromHex(address string) AccAddress { // Note, this function is considered unsafe as it may produce an AccAddress from // otherwise invalid input, such as a transaction hash. func AccAddressFromHexUnsafe(address string) (AccAddress, error) { - addr := strings.ToLower(address) - if len(addr) >= 2 && addr[:2] == "0x" { - addr = addr[2:] - } - if len(strings.TrimSpace(addr)) == 0 { + if len(address) == 0 { return AccAddress{}, ErrEmptyHexAddress } - if length := len(addr); length != 2*EthAddressLength { - return AccAddress{}, fmt.Errorf("invalid address hex length: %v != %v", length, 2*EthAddressLength) + + if len(address) >= 2 && address[0] == '0' && (address[1] == 'x' || address[1] == 'X') { + address = address[2:] + } + if len(address) != 2*EthAddressLength { + return AccAddress{}, fmt.Errorf("invalid address hex length: %v != %v", len(address), 2*EthAddressLength) } - bz, err := hex.DecodeString(addr) + bz, err := hex.DecodeString(address) if err != nil { return AccAddress{}, err } - return AccAddress(bz), nil + return bz, nil } // VerifyAddressFormat verifies that the provided bytes form a valid address @@ -219,10 +217,7 @@ func (aa AccAddress) Equals(aa2 Address) bool { // Returns boolean for whether an AccAddress is empty func (aa AccAddress) Empty() bool { - addrValue := big.NewInt(0) - addrValue.SetBytes(aa[:]) - - return addrValue.Cmp(big.NewInt(0)) == 0 + return len(aa) == 0 } // Marshal returns the raw address bytes. It is needed for protobuf diff --git a/types/address_test.go b/types/address_test.go index 49ab341c3b..5525b4d7db 100644 --- a/types/address_test.go +++ b/types/address_test.go @@ -484,3 +484,13 @@ func (s *addressTestSuite) TestGetFromBech32() { s.Require().Error(err) s.Require().Equal("invalid Bech32 prefix; expected x, got cosmos", err.Error()) } + +func (s *addressTestSuite) TestAddressCases() { + addrStr1 := "0x17d749d3e2ac204a07e19d8096d9a05c423ea3af" + addr1, _ := types.AccAddressFromHexUnsafe(addrStr1) + s.Require().False(addr1.Empty()) + + addrStr2 := "0x17D749D3E2ac204a07e19D8096d9a05c423ea3af" + addr2, _ := types.AccAddressFromHexUnsafe(addrStr2) + s.Require().True(addr1.Equals(addr2)) +} diff --git a/x/oracle/keeper/keeper.go b/x/oracle/keeper/keeper.go index c5bc8a2c13..52a39e3359 100644 --- a/x/oracle/keeper/keeper.go +++ b/x/oracle/keeper/keeper.go @@ -82,26 +82,21 @@ func (k Keeper) GetRelayerRewardShare(ctx sdk.Context) uint32 { } // IsRelayerValid returns true if the relayer is valid and allowed to send the claim message -func (k Keeper) IsRelayerValid(ctx sdk.Context, validators []stakingtypes.Validator, claim *types.MsgClaim) (bool, error) { - fromAddress, err := sdk.AccAddressFromHexUnsafe(claim.FromAddress) - if err != nil { - return false, sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("from address (%s) is invalid", claim.FromAddress)) - } - +func (k Keeper) IsRelayerValid(ctx sdk.Context, relayer sdk.AccAddress, validators []stakingtypes.Validator, claimTimestamp uint64) (bool, error) { var validatorIndex int64 = -1 for index, validator := range validators { - if validator.RelayerAddress == fromAddress.String() { + if validator.RelayerAddress == relayer.String() { validatorIndex = int64(index) break } } if validatorIndex < 0 { - return false, sdkerrors.Wrapf(types.ErrNotRelayer, fmt.Sprintf("sender(%s) is not a relayer", fromAddress.String())) + return false, sdkerrors.Wrapf(types.ErrNotRelayer, fmt.Sprintf("sender(%s) is not a relayer", relayer.String())) } // check inturn validator index - inturnValidatorIndex := claim.Timestamp % uint64(len(validators)) + inturnValidatorIndex := claimTimestamp % uint64(len(validators)) curTime := ctx.BlockTime().Unix() relayerTimeout, relayerBackoffTime := k.GetRelayerParam(ctx) @@ -112,32 +107,37 @@ func (k Keeper) IsRelayerValid(ctx sdk.Context, validators []stakingtypes.Valida } // not inturn validators can not relay in the timeout duration - if uint64(curTime)-claim.Timestamp <= relayerTimeout { + if uint64(curTime)-claimTimestamp <= relayerTimeout { return false, nil } validatorDistance := (validatorIndex - int64(inturnValidatorIndex) + int64(len(validators))) % int64(len(validators)) - return curTime > int64(claim.Timestamp+relayerTimeout)+(validatorDistance-1)*int64(relayerBackoffTime), nil + return curTime > int64(claimTimestamp+relayerTimeout)+(validatorDistance-1)*int64(relayerBackoffTime), nil } // CheckClaim checks the bls signature -func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) ([]string, error) { +func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) (sdk.AccAddress, []string, error) { + relayer, err := sdk.AccAddressFromHexUnsafe(claim.FromAddress) + if err != nil { + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("from address (%s) is invalid", claim.FromAddress)) + } + historicalInfo, ok := k.StakingKeeper.GetHistoricalInfo(ctx, ctx.BlockHeight()) if !ok { - return nil, sdkerrors.Wrapf(types.ErrValidatorSet, "get historical validators failed") + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrValidatorSet, "get historical validators failed") } validators := historicalInfo.Valset - isValid, err := k.IsRelayerValid(ctx, validators, claim) + isValid, err := k.IsRelayerValid(ctx, relayer, validators, claim.Timestamp) if err != nil { - return nil, err + return sdk.AccAddress{}, nil, err } if !isValid { - return nil, sdkerrors.Wrapf(types.ErrRelayerNotInTurn, fmt.Sprintf("relayer(%s) is not in turn", claim.FromAddress)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrRelayerNotInTurn, fmt.Sprintf("relayer(%s) is not in turn", claim.FromAddress)) } validatorsBitSet := bitset.From(claim.VoteAddressSet) if validatorsBitSet.Count() > uint(len(validators)) { - return nil, sdkerrors.Wrapf(types.ErrValidatorSet, "number of validator set is larger than validators") + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrValidatorSet, "number of validator set is larger than validators") } signedRelayers := make([]string, 0, validatorsBitSet.Count()) @@ -151,27 +151,27 @@ func (k Keeper) CheckClaim(ctx sdk.Context, claim *types.MsgClaim) ([]string, er votePubKey, err := bls.PublicKeyFromBytes(val.RelayerBlsKey) if err != nil { - return nil, sdkerrors.Wrapf(types.ErrBlsPubKey, fmt.Sprintf("BLS public key converts failed: %v", err)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsPubKey, fmt.Sprintf("BLS public key converts failed: %v", err)) } votedPubKeys = append(votedPubKeys, votePubKey) } // The valid voted validators should be no less than 2/3 validators. if len(votedPubKeys) <= len(validators)*2/3 { - return nil, sdkerrors.Wrapf(types.ErrBlsVotesNotEnough, fmt.Sprintf("not enough validators voted, need: %d, voted: %d", len(validators)*2/3, len(votedPubKeys))) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrBlsVotesNotEnough, fmt.Sprintf("not enough validators voted, need: %d, voted: %d", len(validators)*2/3, len(votedPubKeys))) } // Verify the aggregated signature. aggSig, err := bls.SignatureFromBytes(claim.AggSignature) if err != nil { - return nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, fmt.Sprintf("BLS signature converts failed: %v", err)) + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, fmt.Sprintf("BLS signature converts failed: %v", err)) } if !aggSig.FastAggregateVerify(votedPubKeys, claim.GetBlsSignBytes()) { - return nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, "signature verify failed") + return sdk.AccAddress{}, nil, sdkerrors.Wrapf(types.ErrInvalidBlsSignature, "signature verify failed") } - return signedRelayers, nil + return relayer, signedRelayers, nil } // GetParams returns the current params diff --git a/x/oracle/keeper/keeper_test.go b/x/oracle/keeper/keeper_test.go index 2ce61fbde5..28d488d7f5 100644 --- a/x/oracle/keeper/keeper_test.go +++ b/x/oracle/keeper/keeper_test.go @@ -117,7 +117,7 @@ func (s *TestSuite) TestProcessClaim() { msgClaim.AggSignature = blsSig s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0)) - _, err := s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) + _, _, err := s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) s.Require().Nil(err, "error should be nil") // wrong validator set @@ -127,7 +127,7 @@ func (s *TestSuite) TestProcessClaim() { } msgClaim.VoteAddressSet = wrongValBitSet.Bytes() s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0)) - _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) + _, _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) s.Require().NotNil(err, "error should not be nil") s.Require().Contains(err.Error(), "number of validator set is larger than validators") @@ -137,7 +137,7 @@ func (s *TestSuite) TestProcessClaim() { wrongValBitSet.Set(uint(validatorMap[newValidators[1].RelayerAddress])) msgClaim.VoteAddressSet = wrongValBitSet.Bytes() s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0)) - _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) + _, _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) s.Require().NotNil(err, "error should not be nil") s.Require().Contains(err.Error(), "not enough validators voted") @@ -146,7 +146,7 @@ func (s *TestSuite) TestProcessClaim() { msgClaim.AggSignature = bytes.Repeat([]byte{2}, 96) s.ctx = s.ctx.WithBlockTime(time.Unix(int64(msgClaim.Timestamp), 0)) - _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) + _, _, err = s.app.OracleKeeper.CheckClaim(s.ctx, &msgClaim) s.Require().NotNil(err, "error should not be nil") s.Require().Contains(err.Error(), "BLS signature converts failed") } @@ -257,7 +257,8 @@ func (s *TestSuite) TestKeeper_IsRelayerValid() { for idx, test := range tests { s.ctx = s.ctx.WithBlockTime(time.Unix(test.blockTime, 0)) - isValid, err := s.app.OracleKeeper.IsRelayerValid(s.ctx, vals, &test.claimMsg) + relayer := sdk.MustAccAddressFromHex(test.claimMsg.FromAddress) + isValid, err := s.app.OracleKeeper.IsRelayerValid(s.ctx, relayer, vals, test.claimMsg.Timestamp) if test.expectedPass { s.Require().Nil(err) diff --git a/x/oracle/keeper/msg_server.go b/x/oracle/keeper/msg_server.go index 25373d5724..daa54719a2 100644 --- a/x/oracle/keeper/msg_server.go +++ b/x/oracle/keeper/msg_server.go @@ -50,7 +50,7 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg return nil, sdkerrors.Wrapf(types.ErrInvalidReceiveSequence, fmt.Sprintf("current sequence of channel %d is %d", types.RelayPackagesChannelId, sequence)) } - signedRelayers, err := k.oracleKeeper.CheckClaim(ctx, req) + relayer, signedRelayers, err := k.oracleKeeper.CheckClaim(ctx, req) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg for idx := range packages { pack := packages[idx] - relayerFee, event, err := handlePackage(ctx, req, k.oracleKeeper, &pack) + relayerFee, event, err := handlePackage(ctx, k.oracleKeeper, &pack, req.SrcChainId, req.DestChainId, req.Timestamp) if err != nil { // only do log, but let rest package get chance to execute. logger.Error("process package failed", "channel", pack.ChannelId, "sequence", pack.Sequence, "error", err.Error()) @@ -82,7 +82,7 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg k.oracleKeeper.CrossChainKeeper.IncrReceiveSequence(ctx, pack.ChannelId) } - err = distributeReward(ctx, k.oracleKeeper, req.FromAddress, signedRelayers, totalRelayerFee) + err = distributeReward(ctx, k.oracleKeeper, relayer, signedRelayers, totalRelayerFee) if err != nil { return nil, err } @@ -95,25 +95,20 @@ func (k msgServer) Claim(goCtx context.Context, req *types.MsgClaim) (*types.Msg } // distributeReward will distribute reward to relayers -func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer string, signedRelayers []string, relayerFee *big.Int) error { +func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer sdk.AccAddress, signedRelayers []string, relayerFee *big.Int) error { if relayerFee.Cmp(big.NewInt(0)) <= 0 { oracleKeeper.Logger(ctx).Info("total relayer fee is zero") return nil } - relayerAddr, err := sdk.AccAddressFromHexUnsafe(relayer) - if err != nil { - return sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("relayer address (%s) is invalid", relayer)) - } - otherRelayers := make([]sdk.AccAddress, 0, len(signedRelayers)) for idx := range signedRelayers { signedRelayerAddr, err := sdk.AccAddressFromHexUnsafe(signedRelayers[idx]) if err != nil { return sdkerrors.Wrapf(types.ErrInvalidAddress, fmt.Sprintf("relayer address (%s) is invalid", relayer)) } - if !signedRelayerAddr.Equals(relayerAddr) { - otherRelayers = append(otherRelayers, relayerAddr) + if !signedRelayerAddr.Equals(relayer) { + otherRelayers = append(otherRelayers, relayer) } } @@ -131,7 +126,7 @@ func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer string, sign bondDenom := oracleKeeper.StakingKeeper.BondDenom(ctx) if otherRelayerReward.Cmp(big.NewInt(0)) > 0 { for idx := range otherRelayers { - err = oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, + err := oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, otherRelayers[idx], sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: sdk.NewIntFromBigInt(otherRelayerReward)}}, @@ -145,9 +140,9 @@ func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer string, sign remainingReward := relayerFee.Sub(relayerFee, totalDistributed) if remainingReward.Cmp(big.NewInt(0)) > 0 { - err = oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, + err := oracleKeeper.BankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, - relayerAddr, + relayer, sdk.Coins{sdk.Coin{Denom: bondDenom, Amount: sdk.NewIntFromBigInt(remainingReward)}}, ) if err != nil { @@ -155,14 +150,16 @@ func distributeReward(ctx sdk.Context, oracleKeeper Keeper, relayer string, sign } } - return err + return nil } func handlePackage( ctx sdk.Context, - req *types.MsgClaim, oracleKeeper Keeper, pack *types.Package, + srcChainId uint32, + destChainId uint32, + timestamp uint64, ) (*big.Int, *types.EventPackageClaim, error) { logger := oracleKeeper.Logger(ctx) @@ -182,9 +179,9 @@ func handlePackage( return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, "payload header is invalid") } - if packageHeader.Timestamp != req.Timestamp { + if packageHeader.Timestamp != timestamp { return nil, nil, sdkerrors.Wrapf(types.ErrInvalidPayloadHeader, - "timestamp(%d) is not the same in payload header(%d)", req.Timestamp, packageHeader.Timestamp) + "timestamp(%d) is not the same in payload header(%d)", timestamp, packageHeader.Timestamp) } if !sdk.IsValidCrossChainPackageType(packageHeader.PackageType) { @@ -230,8 +227,8 @@ func handlePackage( } claimEvent := &types.EventPackageClaim{ - SrcChainId: req.SrcChainId, - DestChainId: req.DestChainId, + SrcChainId: srcChainId, + DestChainId: destChainId, ChannelId: uint32(pack.ChannelId), PackageType: uint32(packageHeader.PackageType), ReceiveSequence: pack.Sequence,