Skip to content

Commit

Permalink
Refactor state tests to always use initialized state (#3310)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Aug 19, 2024
1 parent df91c2f commit acfcfe4
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 223 deletions.
107 changes: 107 additions & 0 deletions vms/platformvm/genesis/genesistest/genesis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package genesistest

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/units"
"github.com/ava-labs/avalanchego/vms/components/avax"
"github.com/ava-labs/avalanchego/vms/platformvm/genesis"
"github.com/ava-labs/avalanchego/vms/platformvm/reward"
"github.com/ava-labs/avalanchego/vms/platformvm/txs"
"github.com/ava-labs/avalanchego/vms/secp256k1fx"
)

var (
AVAXAssetID = ids.GenerateTestID()
AVAXAsset = avax.Asset{ID: AVAXAssetID}

ValidatorNodeID = ids.GenerateTestNodeID()
Time = time.Now().Round(time.Second)
TimeUnix = uint64(Time.Unix())
ValidatorDuration = 28 * 24 * time.Hour
ValidatorEndTime = Time.Add(ValidatorDuration)
ValidatorEndTimeUnix = uint64(ValidatorEndTime.Unix())
ValidatorWeight = units.Avax
ValidatorRewardsOwner = &secp256k1fx.OutputOwners{}
ValidatorDelegationShares uint32 = reward.PercentDenominator

XChainName = "x"

InitialBalance = units.Schmeckle
InitialSupply = ValidatorWeight + InitialBalance
)

func New(t testing.TB) *genesis.Genesis {
require := require.New(t)

genesisValidator := &txs.AddValidatorTx{
Validator: txs.Validator{
NodeID: ValidatorNodeID,
Start: TimeUnix,
End: ValidatorEndTimeUnix,
Wght: ValidatorWeight,
},
StakeOuts: []*avax.TransferableOutput{
{
Asset: AVAXAsset,
Out: &secp256k1fx.TransferOutput{
Amt: ValidatorWeight,
},
},
},
RewardsOwner: ValidatorRewardsOwner,
DelegationShares: ValidatorDelegationShares,
}
genesisValidatorTx := &txs.Tx{Unsigned: genesisValidator}
require.NoError(genesisValidatorTx.Initialize(txs.Codec))

genesisChain := &txs.CreateChainTx{
SubnetID: constants.PrimaryNetworkID,
ChainName: XChainName,
VMID: constants.AVMID,
SubnetAuth: &secp256k1fx.Input{},
}
genesisChainTx := &txs.Tx{Unsigned: genesisChain}
require.NoError(genesisChainTx.Initialize(txs.Codec))

return &genesis.Genesis{
UTXOs: []*genesis.UTXO{
{
UTXO: avax.UTXO{
UTXOID: avax.UTXOID{
TxID: AVAXAssetID,
OutputIndex: 0,
},
Asset: AVAXAsset,
Out: &secp256k1fx.TransferOutput{
Amt: InitialBalance,
},
},
Message: nil,
},
},
Validators: []*txs.Tx{
genesisValidatorTx,
},
Chains: []*txs.Tx{
genesisChainTx,
},
Timestamp: TimeUnix,
InitialSupply: InitialSupply,
}
}

func NewBytes(t testing.TB) []byte {
g := New(t)
genesisBytes, err := genesis.Codec.Marshal(genesis.CodecVersion, g)
require.NoError(t, err)
return genesisBytes
}
19 changes: 10 additions & 9 deletions vms/platformvm/state/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.uber.org/mock/gomock"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/database/memdb"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/constants"
Expand All @@ -36,7 +37,7 @@ func TestDiffMissingState(t *testing.T) {
func TestNewDiffOn(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand All @@ -47,7 +48,7 @@ func TestNewDiffOn(t *testing.T) {
func TestDiffFeeState(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand All @@ -68,7 +69,7 @@ func TestDiffFeeState(t *testing.T) {
func TestDiffCurrentSupply(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

d, err := NewDiffOn(state)
require.NoError(err)
Expand Down Expand Up @@ -256,7 +257,7 @@ func TestDiffSubnet(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

// Initialize parent with one subnet
parentStateCreateSubnetTx := &txs.Tx{
Expand Down Expand Up @@ -305,7 +306,7 @@ func TestDiffSubnet(t *testing.T) {
func TestDiffChain(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())
subnetID := ids.GenerateTestID()

// Initialize parent with one chain
Expand Down Expand Up @@ -402,7 +403,7 @@ func TestDiffTx(t *testing.T) {
func TestDiffRewardUTXO(t *testing.T) {
require := require.New(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

// Initialize parent with one reward UTXO
var (
Expand Down Expand Up @@ -531,7 +532,7 @@ func TestDiffSubnetOwner(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

var (
owner1 = fx.NewMockOwner(ctrl)
Expand Down Expand Up @@ -589,7 +590,7 @@ func TestDiffSubnetManager(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

states := NewMockVersions(ctrl)
lastAcceptedID := ids.GenerateTestID()
Expand Down Expand Up @@ -638,7 +639,7 @@ func TestDiffStacking(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := newInitializedState(require)
state := newTestState(t, memdb.New())

var (
owner1 = fx.NewMockOwner(ctrl)
Expand Down
3 changes: 2 additions & 1 deletion vms/platformvm/state/stakers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/vms/platformvm/genesis/genesistest"
"github.com/ava-labs/avalanchego/vms/platformvm/txs"
)

Expand Down Expand Up @@ -219,7 +220,7 @@ func TestDiffStakersDelegator(t *testing.T) {

func newTestStaker() *Staker {
startTime := time.Now().Round(time.Second)
endTime := startTime.Add(28 * 24 * time.Hour)
endTime := startTime.Add(genesistest.ValidatorDuration)
return &Staker{
TxID: ids.GenerateTestID(),
NodeID: ids.GenerateTestNodeID(),
Expand Down
72 changes: 24 additions & 48 deletions vms/platformvm/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,38 +458,6 @@ func New(
metrics metrics.Metrics,
rewards reward.Calculator,
) (State, error) {
s, err := newState(
db,
metrics,
cfg,
execCfg,
ctx,
metricsReg,
rewards,
)
if err != nil {
return nil, err
}

if err := s.sync(genesisBytes); err != nil {
// Drop any errors on close to return the first error
_ = s.Close()

return nil, err
}

return s, nil
}

func newState(
db database.Database,
metrics metrics.Metrics,
cfg *config.Config,
execCfg *config.ExecutionConfig,
ctx *snow.Context,
metricsReg prometheus.Registerer,
rewards reward.Calculator,
) (*state, error) {
blockIDCache, err := metercacher.New[uint64, ids.ID](
"block_id_cache",
metricsReg,
Expand Down Expand Up @@ -614,7 +582,7 @@ func newState(
return nil, err
}

return &state{
s := &state{
validatorState: newValidatorState(),

validators: cfg.Validators,
Expand Down Expand Up @@ -694,7 +662,16 @@ func newState(
chainDBCache: chainDBCache,

singletonDB: prefixdb.New(SingletonPrefix, baseDB),
}, nil
}

if err := s.sync(genesisBytes); err != nil {
return nil, errors.Join(
err,
s.Close(),
)
}

return s, nil
}

func (s *state) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) {
Expand Down Expand Up @@ -753,15 +730,6 @@ func (s *state) GetPendingStakerIterator() (StakerIterator, error) {
return s.pendingStakers.GetStakerIterator(), nil
}

func (s *state) shouldInit() (bool, error) {
has, err := s.singletonDB.Has(InitializedKey)
return !has, err
}

func (s *state) doneInit() error {
return s.singletonDB.Put(InitializedKey, nil)
}

func (s *state) GetSubnetIDs() ([]ids.ID, error) {
if s.cachedSubnetIDs != nil {
return s.cachedSubnetIDs, nil
Expand Down Expand Up @@ -1751,17 +1719,17 @@ func (s *state) Close() error {
}

func (s *state) sync(genesis []byte) error {
shouldInit, err := s.shouldInit()
wasInitialized, err := isInitialized(s.singletonDB)
if err != nil {
return fmt.Errorf(
"failed to check if the database is initialized: %w",
err,
)
}

// If the database is empty, create the platform chain anew using the
// provided genesis state
if shouldInit {
// If the database wasn't previously initialized, create the platform chain
// anew using the provided genesis state.
if !wasInitialized {
if err := s.init(genesis); err != nil {
return fmt.Errorf(
"failed to initialize the database: %w",
Expand Down Expand Up @@ -1797,7 +1765,7 @@ func (s *state) init(genesisBytes []byte) error {
return err
}

if err := s.doneInit(); err != nil {
if err := markInitialized(s.singletonDB); err != nil {
return err
}

Expand Down Expand Up @@ -2548,6 +2516,14 @@ func (s *state) ReindexBlocks(lock sync.Locker, log logging.Logger) error {
return s.Commit()
}

func markInitialized(db database.KeyValueWriter) error {
return db.Put(InitializedKey, nil)
}

func isInitialized(db database.KeyValueReader) (bool, error) {
return db.Has(InitializedKey)
}

func putFeeState(db database.KeyValueWriter, feeState fee.State) error {
feeStateBytes, err := block.GenesisCodec.Marshal(block.CodecVersion, feeState)
if err != nil {
Expand Down
Loading

0 comments on commit acfcfe4

Please sign in to comment.