diff --git a/consensus/consortium/v2/consortium.go b/consensus/consortium/v2/consortium.go index b6f02199d9..772e5dd263 100644 --- a/consensus/consortium/v2/consortium.go +++ b/consensus/consortium/v2/consortium.go @@ -208,6 +208,15 @@ func (c *Consortium) GetRecents(chain consensus.ChainHeaderReader, number uint64 // VerifyVote check if the finality voter is in the validator set, it assumes the signature is // already verified func (c *Consortium) VerifyVote(chain consensus.ChainHeaderReader, vote *types.VoteEnvelope) error { + header := chain.GetHeaderByHash(vote.Data.TargetHash) + if header == nil { + return errors.New("header not found") + } + + if header.Number.Uint64() != vote.Data.TargetNumber { + return finality.ErrInvalidTargetNumber + } + // Look at the comment assembleFinalityVote in function for the // detailed explanation on the snapshot we need to get to verify the // finality vote. @@ -1176,9 +1185,6 @@ func (c *Consortium) assembleFinalityVote(header *types.Header, snap *Snapshot) log.Warn("Malformed public key from vote pool", "err", err) continue } - if vote.Data.TargetNumber != header.Number.Uint64()-1 { - continue - } authorized := false for valPosition, validator := range snap.ValidatorsWithBlsPub { if publicKey.Equals(validator.BlsPublicKey) { diff --git a/consensus/consortium/v2/consortium_test.go b/consensus/consortium/v2/consortium_test.go index 0d5ac7d293..653eca4854 100644 --- a/consensus/consortium/v2/consortium_test.go +++ b/consensus/consortium/v2/consortium_test.go @@ -12,8 +12,11 @@ import ( "github.com/ethereum/go-ethereum/consensus" consortiumCommon "github.com/ethereum/go-ethereum/consensus/consortium/common" "github.com/ethereum/go-ethereum/consensus/consortium/v2/finality" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto/bls/blst" blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/params" @@ -539,8 +542,8 @@ func TestVerifyFinalitySignature(t *testing.T) { valWithBlsPub := make([]finality.ValidatorWithBlsPub, numValidator) for i := 0; i < len(valWithBlsPub); i++ { valWithBlsPub[i] = finality.ValidatorWithBlsPub{ - common.BigToAddress(big.NewInt(int64(i))), - secretKey[i].PublicKey(), + Address: common.BigToAddress(big.NewInt(int64(i))), + BlsPublicKey: secretKey[i].PublicKey(), } } @@ -803,12 +806,6 @@ func TestAssembleFinalityVote(t *testing.T) { }, }) } - // Wrong target number vote - malformedVoteData := types.VoteData{ - TargetNumber: 6, - TargetHash: common.Hash{0x1}, - } - votes[4].Data = &malformedVoteData mock := mockVotePool{ vote: votes, @@ -846,9 +843,7 @@ func TestAssembleFinalityVote(t *testing.T) { bitSet := finality.FinalityVoteBitSet(0) for i := 0; i < 9; i++ { - if i != 4 { - bitSet.SetBit(i) - } + bitSet.SetBit(i) } if uint64(bitSet) != uint64(extraData.FinalityVotedValidators) { @@ -861,9 +856,7 @@ func TestAssembleFinalityVote(t *testing.T) { var includedSignatures []blsCommon.Signature for i := 0; i < 9; i++ { - if i != 4 { - includedSignatures = append(includedSignatures, signatures[i]) - } + includedSignatures = append(includedSignatures, signatures[i]) } aggregatedSignature := blst.AggregateSignatures(includedSignatures) @@ -872,3 +865,110 @@ func TestAssembleFinalityVote(t *testing.T) { t.Fatal("Mismatch signature") } } + +func TestVerifyVote(t *testing.T) { + const numValidator = 3 + var err error + + secretKey := make([]blsCommon.SecretKey, numValidator+1) + for i := 0; i < len(secretKey); i++ { + secretKey[i], err = blst.RandKey() + if err != nil { + t.Fatalf("Failed to generate secret key, err %s", err) + } + } + + valWithBlsPub := make([]finality.ValidatorWithBlsPub, numValidator) + for i := 0; i < len(valWithBlsPub); i++ { + valWithBlsPub[i] = finality.ValidatorWithBlsPub{ + Address: common.BigToAddress(big.NewInt(int64(i))), + BlsPublicKey: secretKey[i].PublicKey(), + } + } + + db := rawdb.NewMemoryDatabase() + genesis := (&core.Genesis{ + Config: params.TestChainConfig, + BaseFee: big.NewInt(params.InitialBaseFee), + }).MustCommit(db) + chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}, nil, nil) + + bs, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), db, 1, nil, true) + if _, err := chain.InsertChain(bs[:]); err != nil { + panic(err) + } + + snap := newSnapshot(nil, nil, nil, 10, common.Hash{}, nil, valWithBlsPub, nil) + recents, _ := lru.NewARC(inmemorySnapshots) + c := Consortium{ + chainConfig: ¶ms.ChainConfig{ + ShillinBlock: big.NewInt(0), + }, + config: ¶ms.ConsortiumConfig{ + EpochV2: 300, + }, + recents: recents, + } + snap.Hash = bs[0].Hash() + c.recents.Add(snap.Hash, snap) + + // invalid vote number + voteData := types.VoteData{ + TargetNumber: 2, + TargetHash: bs[0].Hash(), + } + signature := secretKey[0].Sign(voteData.Hash().Bytes()) + + vote := types.VoteEnvelope{ + RawVoteEnvelope: types.RawVoteEnvelope{ + PublicKey: types.BLSPublicKey(secretKey[0].PublicKey().Marshal()), + Signature: types.BLSSignature(signature.Marshal()), + Data: &voteData, + }, + } + + err = c.VerifyVote(chain, &vote) + if !errors.Is(err, finality.ErrInvalidTargetNumber) { + t.Errorf("Expect error %v have %v", finality.ErrInvalidTargetNumber, err) + } + + // invalid public key + voteData = types.VoteData{ + TargetNumber: 1, + TargetHash: bs[0].Hash(), + } + signature = secretKey[numValidator].Sign(voteData.Hash().Bytes()) + + vote = types.VoteEnvelope{ + RawVoteEnvelope: types.RawVoteEnvelope{ + PublicKey: types.BLSPublicKey(secretKey[numValidator].PublicKey().Marshal()), + Signature: types.BLSSignature(signature.Marshal()), + Data: &voteData, + }, + } + + err = c.VerifyVote(chain, &vote) + if !errors.Is(err, finality.ErrUnauthorizedFinalityVoter) { + t.Errorf("Expect error %v have %v", finality.ErrUnauthorizedFinalityVoter, err) + } + + // sucessful case + voteData = types.VoteData{ + TargetNumber: 1, + TargetHash: bs[0].Hash(), + } + signature = secretKey[0].Sign(voteData.Hash().Bytes()) + + vote = types.VoteEnvelope{ + RawVoteEnvelope: types.RawVoteEnvelope{ + PublicKey: types.BLSPublicKey(secretKey[0].PublicKey().Marshal()), + Signature: types.BLSSignature(signature.Marshal()), + Data: &voteData, + }, + } + + err = c.VerifyVote(chain, &vote) + if err != nil { + t.Errorf("Expect sucessful verification have %s", err) + } +} diff --git a/consensus/consortium/v2/finality/consortium_header.go b/consensus/consortium/v2/finality/consortium_header.go index 801dfc7daa..e19e85ed15 100644 --- a/consensus/consortium/v2/finality/consortium_header.go +++ b/consensus/consortium/v2/finality/consortium_header.go @@ -62,6 +62,10 @@ var ( // ErrInvalidSpanValidators is returned if a block contains an // invalid list of validators (i.e. non divisible by 20 bytes). ErrInvalidSpanValidators = errors.New("invalid validator list on sprint end block") + + // ErrInvalidTargetNumber is returned if the vote contains invalid + // target number + ErrInvalidTargetNumber = errors.New("invalid target number in vote") ) type ValidatorWithBlsPub struct { diff --git a/core/vote/vote_pool_test.go b/core/vote/vote_pool_test.go index 8151695b65..2bd1b57fc1 100644 --- a/core/vote/vote_pool_test.go +++ b/core/vote/vote_pool_test.go @@ -363,7 +363,7 @@ func generateVote( secretKey blsCommon.SecretKey, ) *types.VoteEnvelope { voteData := types.VoteData{ - TargetNumber: 1, + TargetNumber: uint64(blockNumber), TargetHash: blockHash, } digest := voteData.Hash() @@ -465,3 +465,67 @@ func TestVotePoolDosProtection(t *testing.T) { t.Fatalf("Number of future vote per peer, expect %d have %d", 0, votePool.numFutureVotePerPeer["AAAA"]) } } + +type mockPOSAv2 struct { + consensus.FastFinalityPoSA +} + +func (p *mockPOSAv2) GetJustifiedNumberAndHash(chain consensus.ChainHeaderReader, header *types.Header) (uint64, common.Hash, error) { + parentHeader := chain.GetHeaderByHash(header.ParentHash) + if parentHeader == nil { + return 0, common.Hash{}, fmt.Errorf("unexpected error") + } + return parentHeader.Number.Uint64(), parentHeader.Hash(), nil +} + +func (m *mockPOSAv2) VerifyVote(chain consensus.ChainHeaderReader, vote *types.VoteEnvelope) error { + header := chain.GetHeaderByHash(vote.Data.TargetHash) + if header == nil { + return errors.New("header not found") + } + + if header.Number.Uint64() != vote.Data.TargetNumber { + return errors.New("wrong target number in vote") + } + + return nil +} + +func (m *mockPOSAv2) IsActiveValidatorAt(chain consensus.ChainHeaderReader, header *types.Header) bool { + return true +} + +func TestVotePoolWrongTargetNumber(t *testing.T) { + secretKey, err := bls.RandKey() + if err != nil { + t.Fatalf("Failed to create secret key, err %s", err) + } + + // Create a database pre-initialize with a genesis block + db := rawdb.NewMemoryDatabase() + genesis := (&core.Genesis{ + Config: params.TestChainConfig, + Alloc: core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}}, + BaseFee: big.NewInt(params.InitialBaseFee), + }).MustCommit(db) + chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}, nil, nil) + + bs, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), db, 1, nil, true) + if _, err := chain.InsertChain(bs[:1]); err != nil { + panic(err) + } + mockEngine := &mockPOSAv2{} + + // Create vote pool + votePool := NewVotePool(chain, mockEngine, 22) + + // bs[0] is the block 1 so the target block number must be 1. + // Here we provide wrong target number 0 + vote := generateVote(0, bs[0].Hash(), secretKey) + votePool.PutVote("AAAA", vote) + time.Sleep(100 * time.Millisecond) + + if len(votePool.curVotes) != 0 { + t.Fatalf("Current vote length, expect %d have %d", 0, len(votePool.curVotes)) + } +}