diff --git a/cache/cachetest/cacher.go b/cache/cachetest/cacher.go index 92fd72635c02..50bdc9ad9fd6 100644 --- a/cache/cachetest/cacher.go +++ b/cache/cachetest/cacher.go @@ -12,14 +12,14 @@ import ( "github.com/ava-labs/avalanchego/ids" ) -const TestIntSize = ids.IDLen + 8 +const IntSize = ids.IDLen + 8 -func TestIntSizeFunc(ids.ID, int64) int { - return TestIntSize +func IntSizeFunc(ids.ID, int64) int { + return IntSize } -// CacherTests is a list of all Cacher tests -var CacherTests = []struct { +// Tests is a list of all Cacher tests +var Tests = []struct { Size int Func func(t *testing.T, c cache.Cacher[ids.ID, int64]) }{ diff --git a/cache/lru_sized_cache_test.go b/cache/lru_sized_cache_test.go index 8993408c1507..565f1b3343b5 100644 --- a/cache/lru_sized_cache_test.go +++ b/cache/lru_sized_cache_test.go @@ -15,13 +15,13 @@ import ( ) func TestSizedLRU(t *testing.T) { - cache := NewSizedLRU[ids.ID, int64](cachetest.TestIntSize, cachetest.TestIntSizeFunc) + cache := NewSizedLRU[ids.ID, int64](cachetest.IntSize, cachetest.IntSizeFunc) cachetest.TestBasic(t, cache) } func TestSizedLRUEviction(t *testing.T) { - cache := NewSizedLRU[ids.ID, int64](2*cachetest.TestIntSize, cachetest.TestIntSizeFunc) + cache := NewSizedLRU[ids.ID, int64](2*cachetest.IntSize, cachetest.IntSizeFunc) cachetest.TestEviction(t, cache) } diff --git a/cache/metercacher/cache_test.go b/cache/metercacher/cache_test.go index 609210c68a79..281f0b9ccaca 100644 --- a/cache/metercacher/cache_test.go +++ b/cache/metercacher/cache_test.go @@ -30,13 +30,13 @@ func TestInterface(t *testing.T) { { description: "sized cache LRU", setup: func(size int) cache.Cacher[ids.ID, int64] { - return cache.NewSizedLRU[ids.ID, int64](size*cachetest.TestIntSize, cachetest.TestIntSizeFunc) + return cache.NewSizedLRU[ids.ID, int64](size*cachetest.IntSize, cachetest.IntSizeFunc) }, }, } for _, scenario := range scenarios { - for _, test := range cachetest.CacherTests { + for _, test := range cachetest.Tests { baseCache := scenario.setup(test.Size) c, err := New("", prometheus.NewRegistry(), baseCache) require.NoError(t, err) diff --git a/codec/codectest/codectest.go b/codec/codectest/codectest.go index 172d61f4e453..528782fe9139 100644 --- a/codec/codectest/codectest.go +++ b/codec/codectest/codectest.go @@ -1,6 +1,7 @@ // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. +// Package codectest provides a test suite for testing codec implementations. package codectest import ( @@ -12,45 +13,72 @@ import ( codecpkg "github.com/ava-labs/avalanchego/codec" ) +// A NamedTest couples a test in the suite with a human-readable name. +type NamedTest struct { + Name string + Test func(testing.TB, codecpkg.GeneralCodec) +} + +// Run runs the test on the GeneralCodec. +func (tt *NamedTest) Run(t *testing.T, c codecpkg.GeneralCodec) { + t.Run(tt.Name, func(t *testing.T) { + tt.Test(t, c) + }) +} + +// RunAll runs all [Tests], constructing a new GeneralCodec for each. +func RunAll(t *testing.T, ctor func() codecpkg.GeneralCodec) { + for _, tt := range Tests { + tt.Run(t, ctor()) + } +} + +// RunAll runs all [MultipleTagsTests], constructing a new GeneralCodec for each. +func RunAllMultipleTags(t *testing.T, ctor func() codecpkg.GeneralCodec) { + for _, tt := range MultipleTagsTests { + tt.Run(t, ctor()) + } +} + var ( - Tests = []func(c codecpkg.GeneralCodec, t testing.TB){ - TestStruct, - TestRegisterStructTwice, - TestUInt32, - TestUIntPtr, - TestSlice, - TestMaxSizeSlice, - TestBool, - TestArray, - TestBigArray, - TestPointerToStruct, - TestSliceOfStruct, - TestInterface, - TestSliceOfInterface, - TestArrayOfInterface, - TestPointerToInterface, - TestString, - TestNilSlice, - TestSerializeUnexportedField, - TestSerializeOfNoSerializeField, - TestNilSliceSerialization, - TestEmptySliceSerialization, - TestSliceWithEmptySerialization, - TestSliceWithEmptySerializationError, - TestMapWithEmptySerialization, - TestMapWithEmptySerializationError, - TestSliceTooLarge, - TestNegativeNumbers, - TestTooLargeUnmarshal, - TestUnmarshalInvalidInterface, - TestExtraSpace, - TestSliceLengthOverflow, - TestMap, - TestCanMarshalLargeSlices, + Tests = []NamedTest{ + {"Struct", TestStruct}, + {"Register Struct Twice", TestRegisterStructTwice}, + {"UInt32", TestUInt32}, + {"UIntPtr", TestUIntPtr}, + {"Slice", TestSlice}, + {"Max-Size Slice", TestMaxSizeSlice}, + {"Bool", TestBool}, + {"Array", TestArray}, + {"Big Array", TestBigArray}, + {"Pointer To Struct", TestPointerToStruct}, + {"Slice Of Struct", TestSliceOfStruct}, + {"Interface", TestInterface}, + {"Slice Of Interface", TestSliceOfInterface}, + {"Array Of Interface", TestArrayOfInterface}, + {"Pointer To Interface", TestPointerToInterface}, + {"String", TestString}, + {"Nil Slice", TestNilSlice}, + {"Serialize Unexported Field", TestSerializeUnexportedField}, + {"Serialize Of NoSerialize Field", TestSerializeOfNoSerializeField}, + {"Nil Slice Serialization", TestNilSliceSerialization}, + {"Empty Slice Serialization", TestEmptySliceSerialization}, + {"Slice With Empty Serialization", TestSliceWithEmptySerialization}, + {"Slice With Empty Serialization Error", TestSliceWithEmptySerializationError}, + {"Map With Empty Serialization", TestMapWithEmptySerialization}, + {"Map With Empty Serialization Error", TestMapWithEmptySerializationError}, + {"Slice Too Large", TestSliceTooLarge}, + {"Negative Numbers", TestNegativeNumbers}, + {"Too Large Unmarshal", TestTooLargeUnmarshal}, + {"Unmarshal Invalid Interface", TestUnmarshalInvalidInterface}, + {"Extra Space", TestExtraSpace}, + {"Slice Length Overflow", TestSliceLengthOverflow}, + {"Map", TestMap}, + {"Can Marshal Large Slices", TestCanMarshalLargeSlices}, } - MultipleTagsTests = []func(c codecpkg.GeneralCodec, t testing.TB){ - TestMultipleTags, + MultipleTagsTests = []NamedTest{ + {"Multiple Tags", TestMultipleTags}, } ) @@ -127,7 +155,7 @@ type myStruct struct { } // Test marshaling/unmarshaling a complicated struct -func TestStruct(codec codecpkg.GeneralCodec, t testing.TB) { +func TestStruct(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) temp := Foo(&MyInnerStruct{}) @@ -243,7 +271,7 @@ func TestStruct(codec codecpkg.GeneralCodec, t testing.TB) { require.Equal(myStructInstance, *myStructUnmarshaled) } -func TestRegisterStructTwice(codec codecpkg.GeneralCodec, t testing.TB) { +func TestRegisterStructTwice(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) require.NoError(codec.RegisterType(&MyInnerStruct{})) @@ -251,7 +279,7 @@ func TestRegisterStructTwice(codec codecpkg.GeneralCodec, t testing.TB) { require.ErrorIs(err, codecpkg.ErrDuplicateType) } -func TestUInt32(codec codecpkg.GeneralCodec, t testing.TB) { +func TestUInt32(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) number := uint32(500) @@ -273,7 +301,7 @@ func TestUInt32(codec codecpkg.GeneralCodec, t testing.TB) { require.Equal(number, numberUnmarshaled) } -func TestUIntPtr(codec codecpkg.GeneralCodec, t testing.TB) { +func TestUIntPtr(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) manager := codecpkg.NewDefaultManager() @@ -285,7 +313,7 @@ func TestUIntPtr(codec codecpkg.GeneralCodec, t testing.TB) { require.ErrorIs(err, codecpkg.ErrUnsupportedType) } -func TestSlice(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSlice(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) mySlice := []bool{true, false, true, true} @@ -307,7 +335,7 @@ func TestSlice(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling/unmarshalling largest possible slice -func TestMaxSizeSlice(codec codecpkg.GeneralCodec, t testing.TB) { +func TestMaxSizeSlice(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) mySlice := make([]string, math.MaxUint16) @@ -331,7 +359,7 @@ func TestMaxSizeSlice(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a bool -func TestBool(codec codecpkg.GeneralCodec, t testing.TB) { +func TestBool(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myBool := true @@ -353,7 +381,7 @@ func TestBool(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling an array -func TestArray(codec codecpkg.GeneralCodec, t testing.TB) { +func TestArray(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myArr := [5]uint64{5, 6, 7, 8, 9} @@ -375,7 +403,7 @@ func TestArray(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a really big array -func TestBigArray(codec codecpkg.GeneralCodec, t testing.TB) { +func TestBigArray(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myArr := [30000]uint64{5, 6, 7, 8, 9} @@ -397,7 +425,7 @@ func TestBigArray(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a pointer to a struct -func TestPointerToStruct(codec codecpkg.GeneralCodec, t testing.TB) { +func TestPointerToStruct(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myPtr := &MyInnerStruct{Str: "Hello!"} @@ -419,7 +447,7 @@ func TestPointerToStruct(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a slice of structs -func TestSliceOfStruct(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceOfStruct(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) mySlice := []MyInnerStruct3{ { @@ -453,7 +481,7 @@ func TestSliceOfStruct(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling an interface -func TestInterface(codec codecpkg.GeneralCodec, t testing.TB) { +func TestInterface(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) require.NoError(codec.RegisterType(&MyInnerStruct2{})) @@ -477,7 +505,7 @@ func TestInterface(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a slice of interfaces -func TestSliceOfInterface(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceOfInterface(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) mySlice := []Foo{ @@ -508,7 +536,7 @@ func TestSliceOfInterface(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling an array of interfaces -func TestArrayOfInterface(codec codecpkg.GeneralCodec, t testing.TB) { +func TestArrayOfInterface(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myArray := [2]Foo{ @@ -539,7 +567,7 @@ func TestArrayOfInterface(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a pointer to an interface -func TestPointerToInterface(codec codecpkg.GeneralCodec, t testing.TB) { +func TestPointerToInterface(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) var myinnerStruct Foo = &MyInnerStruct{Str: "Hello!"} @@ -565,7 +593,7 @@ func TestPointerToInterface(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshalling a string -func TestString(codec codecpkg.GeneralCodec, t testing.TB) { +func TestString(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) myString := "Ayy" @@ -587,7 +615,7 @@ func TestString(codec codecpkg.GeneralCodec, t testing.TB) { } // Ensure a nil slice is unmarshaled to slice with length 0 -func TestNilSlice(codec codecpkg.GeneralCodec, t testing.TB) { +func TestNilSlice(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type structWithSlice struct { @@ -614,7 +642,7 @@ func TestNilSlice(codec codecpkg.GeneralCodec, t testing.TB) { // Ensure that trying to serialize a struct with an unexported member // that has `serialize:"true"` returns error -func TestSerializeUnexportedField(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSerializeUnexportedField(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type s struct { @@ -637,7 +665,7 @@ func TestSerializeUnexportedField(codec codecpkg.GeneralCodec, t testing.TB) { require.ErrorIs(err, codecpkg.ErrUnexportedField) } -func TestSerializeOfNoSerializeField(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSerializeOfNoSerializeField(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type s struct { @@ -670,7 +698,7 @@ func TestSerializeOfNoSerializeField(codec codecpkg.GeneralCodec, t testing.TB) } // Test marshalling of nil slice -func TestNilSliceSerialization(codec codecpkg.GeneralCodec, t testing.TB) { +func TestNilSliceSerialization(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type simpleSliceStruct struct { @@ -698,7 +726,7 @@ func TestNilSliceSerialization(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshaling a slice that has 0 elements (but isn't nil) -func TestEmptySliceSerialization(codec codecpkg.GeneralCodec, t testing.TB) { +func TestEmptySliceSerialization(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type simpleSliceStruct struct { @@ -726,7 +754,7 @@ func TestEmptySliceSerialization(codec codecpkg.GeneralCodec, t testing.TB) { } // Test marshaling empty slice of zero length structs -func TestSliceWithEmptySerialization(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceWithEmptySerialization(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type emptyStruct struct{} @@ -757,7 +785,7 @@ func TestSliceWithEmptySerialization(codec codecpkg.GeneralCodec, t testing.TB) require.Empty(unmarshaled.Arr) } -func TestSliceWithEmptySerializationError(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceWithEmptySerializationError(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type emptyStruct struct{} @@ -786,7 +814,7 @@ func TestSliceWithEmptySerializationError(codec codecpkg.GeneralCodec, t testing } // Test marshaling empty map of zero length structs -func TestMapWithEmptySerialization(codec codecpkg.GeneralCodec, t testing.TB) { +func TestMapWithEmptySerialization(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type emptyStruct struct{} @@ -811,7 +839,7 @@ func TestMapWithEmptySerialization(codec codecpkg.GeneralCodec, t testing.TB) { require.Empty(unmarshaled) } -func TestMapWithEmptySerializationError(codec codecpkg.GeneralCodec, t testing.TB) { +func TestMapWithEmptySerializationError(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type emptyStruct struct{} @@ -835,7 +863,7 @@ func TestMapWithEmptySerializationError(codec codecpkg.GeneralCodec, t testing.T require.ErrorIs(err, codecpkg.ErrUnmarshalZeroLength) } -func TestSliceTooLarge(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceTooLarge(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) manager := codecpkg.NewDefaultManager() @@ -848,7 +876,7 @@ func TestSliceTooLarge(codec codecpkg.GeneralCodec, t testing.TB) { } // Ensure serializing structs with negative number members works -func TestNegativeNumbers(codec codecpkg.GeneralCodec, t testing.TB) { +func TestNegativeNumbers(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type s struct { @@ -877,7 +905,7 @@ func TestNegativeNumbers(codec codecpkg.GeneralCodec, t testing.TB) { } // Ensure deserializing structs with too many bytes errors correctly -func TestTooLargeUnmarshal(codec codecpkg.GeneralCodec, t testing.TB) { +func TestTooLargeUnmarshal(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type inner struct { @@ -910,7 +938,7 @@ func (*innerInterface) ToInt() int { type innerNoInterface struct{} // Ensure deserializing structs into the wrong interface errors gracefully -func TestUnmarshalInvalidInterface(codec codecpkg.GeneralCodec, t testing.TB) { +func TestUnmarshalInvalidInterface(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) manager := codecpkg.NewDefaultManager() @@ -934,7 +962,7 @@ func TestUnmarshalInvalidInterface(codec codecpkg.GeneralCodec, t testing.TB) { } // Test unmarshaling something with extra data -func TestExtraSpace(codec codecpkg.GeneralCodec, t testing.TB) { +func TestExtraSpace(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) manager := codecpkg.NewDefaultManager() @@ -948,7 +976,7 @@ func TestExtraSpace(codec codecpkg.GeneralCodec, t testing.TB) { } // Ensure deserializing slices whose lengths exceed MaxInt32 error correctly -func TestSliceLengthOverflow(codec codecpkg.GeneralCodec, t testing.TB) { +func TestSliceLengthOverflow(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) type inner struct { @@ -978,7 +1006,7 @@ type MultipleVersionsStruct struct { NoTags string `tag1:"false" tag2:"false"` } -func TestMultipleTags(codec codecpkg.GeneralCodec, t testing.TB) { +func TestMultipleTags(t testing.TB, codec codecpkg.GeneralCodec) { // received codec is expected to have both v1 and v2 registered as tags inputs := MultipleVersionsStruct{ BothTags: "both Tags", @@ -1011,7 +1039,7 @@ func TestMultipleTags(codec codecpkg.GeneralCodec, t testing.TB) { } } -func TestMap(codec codecpkg.GeneralCodec, t testing.TB) { +func TestMap(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) data1 := map[string]MyInnerStruct2{ @@ -1080,7 +1108,7 @@ func TestMap(codec codecpkg.GeneralCodec, t testing.TB) { require.Len(outerArrayBytes, outerArraySize) } -func TestCanMarshalLargeSlices(codec codecpkg.GeneralCodec, t testing.TB) { +func TestCanMarshalLargeSlices(t testing.TB, codec codecpkg.GeneralCodec) { require := require.New(t) data := make([]uint16, 1_000_000) diff --git a/codec/hierarchycodec/codec_test.go b/codec/hierarchycodec/codec_test.go index fc7c9e4d7696..6908f16d09f0 100644 --- a/codec/hierarchycodec/codec_test.go +++ b/codec/hierarchycodec/codec_test.go @@ -6,21 +6,20 @@ package hierarchycodec import ( "testing" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/codec/codectest" ) func TestVectors(t *testing.T) { - for _, test := range codectest.Tests { - c := NewDefault() - test(c, t) - } + codectest.RunAll(t, func() codec.GeneralCodec { + return NewDefault() + }) } func TestMultipleTags(t *testing.T) { - for _, test := range codectest.MultipleTagsTests { - c := New([]string{"tag1", "tag2"}) - test(c, t) - } + codectest.RunAllMultipleTags(t, func() codec.GeneralCodec { + return New([]string{"tag1", "tag2"}) + }) } func FuzzStructUnmarshalHierarchyCodec(f *testing.F) { diff --git a/codec/linearcodec/codec_test.go b/codec/linearcodec/codec_test.go index bc9d5fb42dd5..8e9152c25fd1 100644 --- a/codec/linearcodec/codec_test.go +++ b/codec/linearcodec/codec_test.go @@ -6,21 +6,20 @@ package linearcodec import ( "testing" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/codec/codectest" ) func TestVectors(t *testing.T) { - for _, test := range codectest.Tests { - c := NewDefault() - test(c, t) - } + codectest.RunAll(t, func() codec.GeneralCodec { + return NewDefault() + }) } func TestMultipleTags(t *testing.T) { - for _, test := range codectest.MultipleTagsTests { - c := New([]string{"tag1", "tag2"}) - test(c, t) - } + codectest.RunAllMultipleTags(t, func() codec.GeneralCodec { + return New([]string{"tag1", "tag2"}) + }) } func FuzzStructUnmarshalLinearCodec(f *testing.F) { diff --git a/ids/aliases_test.go b/ids/aliases_test.go index aeed141c2327..8cb7b765f0d7 100644 --- a/ids/aliases_test.go +++ b/ids/aliases_test.go @@ -14,11 +14,10 @@ import ( ) func TestAliaser(t *testing.T) { - require := require.New(t) - for _, test := range idstest.AliasTests { - aliaser := NewAliaser() - test(require, aliaser, aliaser) - } + idstest.RunAllAlias(t, func() (AliaserReader, AliaserWriter) { + a := NewAliaser() + return a, a + }) } func TestPrimaryAliasOrDefaultTest(t *testing.T) { diff --git a/ids/galiasreader/alias_reader_test.go b/ids/galiasreader/alias_reader_test.go index 59c6c4228984..1835bb1792bc 100644 --- a/ids/galiasreader/alias_reader_test.go +++ b/ids/galiasreader/alias_reader_test.go @@ -16,28 +16,29 @@ import ( ) func TestInterface(t *testing.T) { - require := require.New(t) - for _, test := range idstest.AliasTests { - listener, err := grpcutils.NewListener() - require.NoError(err) - serverCloser := grpcutils.ServerCloser{} - w := ids.NewAliaser() + t.Run(test.Name, func(t *testing.T) { + require := require.New(t) - server := grpcutils.NewServer() - aliasreaderpb.RegisterAliasReaderServer(server, NewServer(w)) - serverCloser.Add(server) + listener, err := grpcutils.NewListener() + require.NoError(err) + defer listener.Close() + serverCloser := grpcutils.ServerCloser{} + defer serverCloser.Stop() + w := ids.NewAliaser() - go grpcutils.Serve(listener, server) + server := grpcutils.NewServer() + aliasreaderpb.RegisterAliasReaderServer(server, NewServer(w)) + serverCloser.Add(server) - conn, err := grpcutils.Dial(listener.Addr().String()) - require.NoError(err) + go grpcutils.Serve(listener, server) - r := NewClient(aliasreaderpb.NewAliasReaderClient(conn)) - test(require, r, w) + conn, err := grpcutils.Dial(listener.Addr().String()) + require.NoError(err) + defer conn.Close() - serverCloser.Stop() - _ = conn.Close() - _ = listener.Close() + r := NewClient(aliasreaderpb.NewAliasReaderClient(conn)) + test.Test(t, r, w) + }) } } diff --git a/ids/idstest/aliases.go b/ids/idstest/aliases.go index 2dd44e564a2d..5b5c78e407cd 100644 --- a/ids/idstest/aliases.go +++ b/ids/idstest/aliases.go @@ -1,31 +1,58 @@ // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. +// Package codectest provides a test suite for testing functionality related to +// IDs. package idstest import ( + "testing" + "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" ) -var AliasTests = []func(require *require.Assertions, r ids.AliaserReader, w ids.AliaserWriter){ - AliaserLookupErrorTest, - AliaserLookupTest, - AliaserAliasesEmptyTest, - AliaserAliasesTest, - AliaserPrimaryAliasTest, - AliaserAliasClashTest, - AliaserRemoveAliasTest, +// An AliasTest couples a test in the Aliaser suite with a human-readable name. +type AliasTest struct { + Name string + Test func(testing.TB, ids.AliaserReader, ids.AliaserWriter) +} + +// Run runs the test on the Aliaser{Reader+Writer} pair. +func (tt *AliasTest) Run(t *testing.T, r ids.AliaserReader, w ids.AliaserWriter) { + t.Run(tt.Name, func(t *testing.T) { + tt.Test(t, r, w) + }) +} + +// RunAll runs all [AliasTests], constructing a new GeneralCodec for each. +func RunAllAlias(t *testing.T, ctor func() (ids.AliaserReader, ids.AliaserWriter)) { + for _, tt := range AliasTests { + r, w := ctor() + tt.Run(t, r, w) + } +} + +var AliasTests = []AliasTest{ + {"Lookup Error}", TestAliaserLookupError}, + {"Lookup}", TestAliaserLookup}, + {"Aliases Empty}", TestAliaserAliasesEmpty}, + {"Aliases}", TestAliaserAliases}, + {"Primary Alias}", TestAliaserPrimaryAlias}, + {"Alias Clash}", TestAliaserAliasClash}, + {"Remove Alias}", TestAliaserRemoveAlias}, } -func AliaserLookupErrorTest(require *require.Assertions, r ids.AliaserReader, _ ids.AliaserWriter) { +func TestAliaserLookupError(tb testing.TB, r ids.AliaserReader, _ ids.AliaserWriter) { + require := require.New(tb) _, err := r.Lookup("Batman") // TODO: require error to be errNoIDWithAlias require.Error(err) //nolint:forbidigo // currently returns grpc errors too } -func AliaserLookupTest(require *require.Assertions, r ids.AliaserReader, w ids.AliaserWriter) { +func TestAliaserLookup(tb testing.TB, r ids.AliaserReader, w ids.AliaserWriter) { + require := require.New(tb) id := ids.ID{'K', 'a', 't', 'e', ' ', 'K', 'a', 'n', 'e'} require.NoError(w.Alias(id, "Batwoman")) @@ -34,7 +61,8 @@ func AliaserLookupTest(require *require.Assertions, r ids.AliaserReader, w ids.A require.Equal(id, res) } -func AliaserAliasesEmptyTest(require *require.Assertions, r ids.AliaserReader, _ ids.AliaserWriter) { +func TestAliaserAliasesEmpty(tb testing.TB, r ids.AliaserReader, _ ids.AliaserWriter) { + require := require.New(tb) id := ids.ID{'J', 'a', 'm', 'e', 's', ' ', 'G', 'o', 'r', 'd', 'o', 'n'} aliases, err := r.Aliases(id) @@ -42,7 +70,8 @@ func AliaserAliasesEmptyTest(require *require.Assertions, r ids.AliaserReader, _ require.Empty(aliases) } -func AliaserAliasesTest(require *require.Assertions, r ids.AliaserReader, w ids.AliaserWriter) { +func TestAliaserAliases(tb testing.TB, r ids.AliaserReader, w ids.AliaserWriter) { + require := require.New(tb) id := ids.ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} require.NoError(w.Alias(id, "Batman")) @@ -55,7 +84,8 @@ func AliaserAliasesTest(require *require.Assertions, r ids.AliaserReader, w ids. require.Equal(expected, aliases) } -func AliaserPrimaryAliasTest(require *require.Assertions, r ids.AliaserReader, w ids.AliaserWriter) { +func TestAliaserPrimaryAlias(tb testing.TB, r ids.AliaserReader, w ids.AliaserWriter) { + require := require.New(tb) id1 := ids.ID{'J', 'a', 'm', 'e', 's', ' ', 'G', 'o', 'r', 'd', 'o', 'n'} id2 := ids.ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} @@ -72,7 +102,8 @@ func AliaserPrimaryAliasTest(require *require.Assertions, r ids.AliaserReader, w require.Equal(expected, res) } -func AliaserAliasClashTest(require *require.Assertions, _ ids.AliaserReader, w ids.AliaserWriter) { +func TestAliaserAliasClash(tb testing.TB, _ ids.AliaserReader, w ids.AliaserWriter) { + require := require.New(tb) id1 := ids.ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} id2 := ids.ID{'D', 'i', 'c', 'k', ' ', 'G', 'r', 'a', 'y', 's', 'o', 'n'} @@ -83,7 +114,8 @@ func AliaserAliasClashTest(require *require.Assertions, _ ids.AliaserReader, w i require.Error(err) //nolint:forbidigo // currently returns grpc errors too } -func AliaserRemoveAliasTest(require *require.Assertions, r ids.AliaserReader, w ids.AliaserWriter) { +func TestAliaserRemoveAlias(tb testing.TB, r ids.AliaserReader, w ids.AliaserWriter) { + require := require.New(tb) id1 := ids.ID{'B', 'r', 'u', 'c', 'e', ' ', 'W', 'a', 'y', 'n', 'e'} id2 := ids.ID{'J', 'a', 'm', 'e', 's', ' ', 'G', 'o', 'r', 'd', 'o', 'n'} diff --git a/network/p2p/gossip/gossip_test.go b/network/p2p/gossip/gossip_test.go index 5615fc2d35f3..be577dc6f18c 100644 --- a/network/p2p/gossip/gossip_test.go +++ b/network/p2p/gossip/gossip_test.go @@ -105,7 +105,7 @@ func TestGossiperGossip(t *testing.T) { require := require.New(t) ctx := context.Background() - responseSender := &enginetest.FakeSender{ + responseSender := &enginetest.SenderStub{ SentAppResponse: make(chan []byte, 1), } responseNetwork, err := p2p.NewNetwork(logging.NoLog{}, responseSender, prometheus.NewRegistry(), "") @@ -134,7 +134,7 @@ func TestGossiperGossip(t *testing.T) { require.NoError(err) require.NoError(responseNetwork.AddHandler(0x0, handler)) - requestSender := &enginetest.FakeSender{ + requestSender := &enginetest.SenderStub{ SentAppRequest: make(chan []byte, 1), } @@ -510,7 +510,7 @@ func TestPushGossiper(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := &enginetest.FakeSender{ + sender := &enginetest.SenderStub{ SentAppGossip: make(chan []byte, 2), } network, err := p2p.NewNetwork( @@ -525,7 +525,7 @@ func TestPushGossiper(t *testing.T) { &p2p.Peers{}, logging.NoLog{}, constants.PrimaryNetworkID, - &validatorstest.TestState{ + &validatorstest.State{ GetCurrentHeightF: func(context.Context) (uint64, error) { return 1, nil }, diff --git a/network/p2p/network_test.go b/network/p2p/network_test.go index 3afc9bbcd97a..5d4e746eaf1d 100644 --- a/network/p2p/network_test.go +++ b/network/p2p/network_test.go @@ -59,7 +59,7 @@ func TestMessageRouting(t *testing.T) { }, } - sender := &enginetest.FakeSender{ + sender := &enginetest.SenderStub{ SentAppGossip: make(chan []byte, 1), SentAppRequest: make(chan []byte, 1), SentCrossChainAppRequest: make(chan []byte, 1), @@ -94,7 +94,7 @@ func TestClientPrefixesMessages(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := enginetest.FakeSender{ + sender := enginetest.SenderStub{ SentAppRequest: make(chan []byte, 1), SentAppGossip: make(chan []byte, 1), SentCrossChainAppRequest: make(chan []byte, 1), @@ -153,7 +153,7 @@ func TestAppRequestResponse(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := enginetest.FakeSender{ + sender := enginetest.SenderStub{ SentAppRequest: make(chan []byte, 1), } network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") @@ -188,7 +188,7 @@ func TestAppRequestCancelledContext(t *testing.T) { ctx := context.Background() sentMessages := make(chan []byte, 1) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ SendAppRequestF: func(ctx context.Context, _ set.Set[ids.NodeID], _ uint32, msgBytes []byte) error { require.NoError(ctx.Err()) sentMessages <- msgBytes @@ -229,7 +229,7 @@ func TestAppRequestFailed(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := enginetest.FakeSender{ + sender := enginetest.SenderStub{ SentAppRequest: make(chan []byte, 1), } network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") @@ -259,7 +259,7 @@ func TestCrossChainAppRequestResponse(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := enginetest.FakeSender{ + sender := enginetest.SenderStub{ SentCrossChainAppRequest: make(chan []byte, 1), } network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") @@ -291,7 +291,7 @@ func TestCrossChainAppRequestCancelledContext(t *testing.T) { ctx := context.Background() sentMessages := make(chan []byte, 1) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ SendCrossChainAppRequestF: func(ctx context.Context, _ ids.ID, _ uint32, msgBytes []byte) { require.NoError(ctx.Err()) sentMessages <- msgBytes @@ -328,7 +328,7 @@ func TestCrossChainAppRequestFailed(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := enginetest.FakeSender{ + sender := enginetest.SenderStub{ SentCrossChainAppRequest: make(chan []byte, 1), } network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") @@ -426,7 +426,7 @@ func TestAppRequestMessageForUnregisteredHandler(t *testing.T) { wantRequestID := uint32(111) done := make(chan struct{}) - sender := &enginetest.SenderTest{} + sender := &enginetest.Sender{} sender.SendAppErrorF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, errorCode int32, errorMessage string) error { defer close(done) @@ -465,7 +465,7 @@ func TestAppError(t *testing.T) { wantRequestID := uint32(111) done := make(chan struct{}) - sender := &enginetest.SenderTest{} + sender := &enginetest.Sender{} sender.SendAppErrorF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, errorCode int32, errorMessage string) error { defer close(done) @@ -546,7 +546,7 @@ func TestAppRequestDuplicateRequestIDs(t *testing.T) { require := require.New(t) ctx := context.Background() - sender := &enginetest.FakeSender{ + sender := &enginetest.SenderStub{ SentAppRequest: make(chan []byte, 1), } @@ -632,7 +632,7 @@ func TestPeersSample(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - network, err := NewNetwork(logging.NoLog{}, &enginetest.FakeSender{}, prometheus.NewRegistry(), "") + network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "") require.NoError(err) for connected := range tt.connected { @@ -675,7 +675,7 @@ func TestAppRequestAnyNodeSelection(t *testing.T) { require := require.New(t) sent := set.Set[ids.NodeID]{} - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { sent = nodeIDs return nil @@ -721,7 +721,7 @@ func TestNodeSamplerClientOption(t *testing.T) { name: "validator connected", peers: []ids.NodeID{nodeID0, nodeID1}, option: func(_ *testing.T, n *Network) ClientOption { - state := &validatorstest.TestState{ + state := &validatorstest.State{ GetCurrentHeightF: func(context.Context) (uint64, error) { return 0, nil }, @@ -744,7 +744,7 @@ func TestNodeSamplerClientOption(t *testing.T) { name: "validator disconnected", peers: []ids.NodeID{nodeID0}, option: func(_ *testing.T, n *Network) ClientOption { - state := &validatorstest.TestState{ + state := &validatorstest.State{ GetCurrentHeightF: func(context.Context) (uint64, error) { return 0, nil }, @@ -770,7 +770,7 @@ func TestNodeSamplerClientOption(t *testing.T) { require := require.New(t) done := make(chan struct{}) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { require.Subset(tt.expected, nodeIDs.List()) close(done) @@ -800,7 +800,7 @@ func TestNodeSamplerClientOption(t *testing.T) { func TestMultipleClients(t *testing.T) { require := require.New(t) - n, err := NewNetwork(logging.NoLog{}, &enginetest.SenderTest{}, prometheus.NewRegistry(), "") + n, err := NewNetwork(logging.NoLog{}, &enginetest.Sender{}, prometheus.NewRegistry(), "") require.NoError(err) _ = n.NewClient(0) _ = n.NewClient(0) diff --git a/network/p2p/validators_test.go b/network/p2p/validators_test.go index 2c7787904f8c..16cdd92e18a3 100644 --- a/network/p2p/validators_test.go +++ b/network/p2p/validators_test.go @@ -195,7 +195,7 @@ func TestValidatorsSample(t *testing.T) { } gomock.InOrder(calls...) - network, err := NewNetwork(logging.NoLog{}, &enginetest.FakeSender{}, prometheus.NewRegistry(), "") + network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "") require.NoError(err) ctx := context.Background() @@ -315,7 +315,7 @@ func TestValidatorsTop(t *testing.T) { mockValidators.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(1), nil) mockValidators.EXPECT().GetValidatorSet(gomock.Any(), uint64(1), subnetID).Return(validatorSet, nil) - network, err := NewNetwork(logging.NoLog{}, &enginetest.FakeSender{}, prometheus.NewRegistry(), "") + network, err := NewNetwork(logging.NoLog{}, &enginetest.SenderStub{}, prometheus.NewRegistry(), "") require.NoError(err) ctx := context.Background() diff --git a/snow/engine/avalanche/bootstrap/bootstrapper_test.go b/snow/engine/avalanche/bootstrap/bootstrapper_test.go index abd6304f2651..6a185afa8e39 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper_test.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper_test.go @@ -56,7 +56,7 @@ func (t *testTx) Accept(ctx context.Context) error { return nil } -func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *vertextest.TestManager, *vertextest.TestVM) { +func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.Sender, *vertextest.Manager, *vertextest.VM) { require := require.New(t) snowCtx := snowtest.Context(t, snowtest.CChainID) @@ -64,9 +64,9 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *verte vdrs := validators.NewManager() db := memdb.New() - sender := &enginetest.SenderTest{T: t} - manager := vertextest.NewTestManager(t) - vm := &vertextest.TestVM{} + sender := &enginetest.Sender{T: t} + manager := vertextest.NewManager(t) + vm := &vertextest.VM{} vm.T = t sender.Default(true) diff --git a/snow/engine/avalanche/state/unique_vertex_test.go b/snow/engine/avalanche/state/unique_vertex_test.go index f37896474013..76326be21609 100644 --- a/snow/engine/avalanche/state/unique_vertex_test.go +++ b/snow/engine/avalanche/state/unique_vertex_test.go @@ -24,7 +24,7 @@ import ( var errUnknownTx = errors.New("unknown tx") func newTestSerializer(t *testing.T, parse func(context.Context, []byte) (snowstorm.Tx, error)) *Serializer { - vm := vertextest.TestVM{} + vm := vertextest.VM{} vm.T = t vm.Default(true) vm.ParseTxF = parse diff --git a/snow/engine/avalanche/vertex/vertextest/builder.go b/snow/engine/avalanche/vertex/vertextest/builder.go index 828fbaa00199..07d27382d63c 100644 --- a/snow/engine/avalanche/vertex/vertextest/builder.go +++ b/snow/engine/avalanche/vertex/vertextest/builder.go @@ -18,20 +18,20 @@ import ( var ( errBuild = errors.New("unexpectedly called Build") - _ vertex.Builder = (*TestBuilder)(nil) + _ vertex.Builder = (*Builder)(nil) ) -type TestBuilder struct { +type Builder struct { T *testing.T CantBuildVtx bool BuildStopVtxF func(ctx context.Context, parentIDs []ids.ID) (avalanche.Vertex, error) } -func (b *TestBuilder) Default(cant bool) { +func (b *Builder) Default(cant bool) { b.CantBuildVtx = cant } -func (b *TestBuilder) BuildStopVtx(ctx context.Context, parentIDs []ids.ID) (avalanche.Vertex, error) { +func (b *Builder) BuildStopVtx(ctx context.Context, parentIDs []ids.ID) (avalanche.Vertex, error) { if b.BuildStopVtxF != nil { return b.BuildStopVtxF(ctx, parentIDs) } diff --git a/snow/engine/avalanche/vertex/vertextest/manager.go b/snow/engine/avalanche/vertex/vertextest/manager.go index 0ccee48a8e19..eadd6a60e9e2 100644 --- a/snow/engine/avalanche/vertex/vertextest/manager.go +++ b/snow/engine/avalanche/vertex/vertextest/manager.go @@ -9,24 +9,24 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex" ) -var _ vertex.Manager = (*TestManager)(nil) +var _ vertex.Manager = (*Manager)(nil) -type TestManager struct { - TestBuilder - TestParser - TestStorage +type Manager struct { + Builder + Parser + Storage } -func NewTestManager(t *testing.T) *TestManager { - return &TestManager{ - TestBuilder: TestBuilder{T: t}, - TestParser: TestParser{T: t}, - TestStorage: TestStorage{T: t}, +func NewManager(t *testing.T) *Manager { + return &Manager{ + Builder: Builder{T: t}, + Parser: Parser{T: t}, + Storage: Storage{T: t}, } } -func (m *TestManager) Default(cant bool) { - m.TestBuilder.Default(cant) - m.TestParser.Default(cant) - m.TestStorage.Default(cant) +func (m *Manager) Default(cant bool) { + m.Builder.Default(cant) + m.Parser.Default(cant) + m.Storage.Default(cant) } diff --git a/snow/engine/avalanche/vertex/vertextest/parser.go b/snow/engine/avalanche/vertex/vertextest/parser.go index 21a308690bd6..c2cb5891e255 100644 --- a/snow/engine/avalanche/vertex/vertextest/parser.go +++ b/snow/engine/avalanche/vertex/vertextest/parser.go @@ -17,20 +17,20 @@ import ( var ( errParse = errors.New("unexpectedly called Parse") - _ vertex.Parser = (*TestParser)(nil) + _ vertex.Parser = (*Parser)(nil) ) -type TestParser struct { +type Parser struct { T *testing.T CantParseVtx bool ParseVtxF func(context.Context, []byte) (avalanche.Vertex, error) } -func (p *TestParser) Default(cant bool) { +func (p *Parser) Default(cant bool) { p.CantParseVtx = cant } -func (p *TestParser) ParseVtx(ctx context.Context, b []byte) (avalanche.Vertex, error) { +func (p *Parser) ParseVtx(ctx context.Context, b []byte) (avalanche.Vertex, error) { if p.ParseVtxF != nil { return p.ParseVtxF(ctx, b) } diff --git a/snow/engine/avalanche/vertex/vertextest/storage.go b/snow/engine/avalanche/vertex/vertextest/storage.go index 5dc1314d9a91..b58aa566569a 100644 --- a/snow/engine/avalanche/vertex/vertextest/storage.go +++ b/snow/engine/avalanche/vertex/vertextest/storage.go @@ -20,10 +20,10 @@ var ( errEdge = errors.New("unexpectedly called Edge") errStopVertexAccepted = errors.New("unexpectedly called StopVertexAccepted") - _ vertex.Storage = (*TestStorage)(nil) + _ vertex.Storage = (*Storage)(nil) ) -type TestStorage struct { +type Storage struct { T *testing.T CantGetVtx, CantEdge, CantStopVertexAccepted bool GetVtxF func(context.Context, ids.ID) (avalanche.Vertex, error) @@ -31,12 +31,12 @@ type TestStorage struct { StopVertexAcceptedF func(context.Context) (bool, error) } -func (s *TestStorage) Default(cant bool) { +func (s *Storage) Default(cant bool) { s.CantGetVtx = cant s.CantEdge = cant } -func (s *TestStorage) GetVtx(ctx context.Context, vtxID ids.ID) (avalanche.Vertex, error) { +func (s *Storage) GetVtx(ctx context.Context, vtxID ids.ID) (avalanche.Vertex, error) { if s.GetVtxF != nil { return s.GetVtxF(ctx, vtxID) } @@ -46,7 +46,7 @@ func (s *TestStorage) GetVtx(ctx context.Context, vtxID ids.ID) (avalanche.Verte return nil, errGet } -func (s *TestStorage) Edge(ctx context.Context) []ids.ID { +func (s *Storage) Edge(ctx context.Context) []ids.ID { if s.EdgeF != nil { return s.EdgeF(ctx) } @@ -56,7 +56,7 @@ func (s *TestStorage) Edge(ctx context.Context) []ids.ID { return nil } -func (s *TestStorage) StopVertexAccepted(ctx context.Context) (bool, error) { +func (s *Storage) StopVertexAccepted(ctx context.Context) (bool, error) { if s.StopVertexAcceptedF != nil { return s.StopVertexAcceptedF(ctx) } diff --git a/snow/engine/avalanche/vertex/vertextest/vm.go b/snow/engine/avalanche/vertex/vertextest/vm.go index 416472cb51d8..08138ffb918a 100644 --- a/snow/engine/avalanche/vertex/vertextest/vm.go +++ b/snow/engine/avalanche/vertex/vertextest/vm.go @@ -18,11 +18,11 @@ import ( var ( errLinearize = errors.New("unexpectedly called Linearize") - _ vertex.LinearizableVM = (*TestVM)(nil) + _ vertex.LinearizableVM = (*VM)(nil) ) -type TestVM struct { - blocktest.TestVM +type VM struct { + blocktest.VM CantLinearize, CantParse bool @@ -30,13 +30,13 @@ type TestVM struct { ParseTxF func(context.Context, []byte) (snowstorm.Tx, error) } -func (vm *TestVM) Default(cant bool) { - vm.TestVM.Default(cant) +func (vm *VM) Default(cant bool) { + vm.VM.Default(cant) vm.CantParse = cant } -func (vm *TestVM) Linearize(ctx context.Context, stopVertexID ids.ID) error { +func (vm *VM) Linearize(ctx context.Context, stopVertexID ids.ID) error { if vm.LinearizeF != nil { return vm.LinearizeF(ctx, stopVertexID) } @@ -46,7 +46,7 @@ func (vm *TestVM) Linearize(ctx context.Context, stopVertexID ids.ID) error { return errLinearize } -func (vm *TestVM) ParseTx(ctx context.Context, b []byte) (snowstorm.Tx, error) { +func (vm *VM) ParseTx(ctx context.Context, b []byte) (snowstorm.Tx, error) { if vm.ParseTxF != nil { return vm.ParseTxF(ctx, b) } diff --git a/snow/engine/enginetest/bootstrap_tracker.go b/snow/engine/enginetest/bootstrap_tracker.go index 577b06300403..481e28d10366 100644 --- a/snow/engine/enginetest/bootstrap_tracker.go +++ b/snow/engine/enginetest/bootstrap_tracker.go @@ -11,8 +11,8 @@ import ( "github.com/ava-labs/avalanchego/ids" ) -// BootstrapTrackerTest is a test subnet -type BootstrapTrackerTest struct { +// BootstrapTracker is a test subnet +type BootstrapTracker struct { T *testing.T CantIsBootstrapped, CantBootstrapped, CantOnBootstrapCompleted bool @@ -24,7 +24,7 @@ type BootstrapTrackerTest struct { } // Default set the default callable value to [cant] -func (s *BootstrapTrackerTest) Default(cant bool) { +func (s *BootstrapTracker) Default(cant bool) { s.CantIsBootstrapped = cant s.CantBootstrapped = cant s.CantOnBootstrapCompleted = cant @@ -33,7 +33,7 @@ func (s *BootstrapTrackerTest) Default(cant bool) { // IsBootstrapped calls IsBootstrappedF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. Defaults to returning false. -func (s *BootstrapTrackerTest) IsBootstrapped() bool { +func (s *BootstrapTracker) IsBootstrapped() bool { if s.IsBootstrappedF != nil { return s.IsBootstrappedF() } @@ -46,7 +46,7 @@ func (s *BootstrapTrackerTest) IsBootstrapped() bool { // Bootstrapped calls BootstrappedF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *BootstrapTrackerTest) Bootstrapped(chainID ids.ID) { +func (s *BootstrapTracker) Bootstrapped(chainID ids.ID) { if s.BootstrappedF != nil { s.BootstrappedF(chainID) } else if s.CantBootstrapped && s.T != nil { @@ -54,7 +54,7 @@ func (s *BootstrapTrackerTest) Bootstrapped(chainID ids.ID) { } } -func (s *BootstrapTrackerTest) OnBootstrapCompleted() chan struct{} { +func (s *BootstrapTracker) OnBootstrapCompleted() chan struct{} { if s.OnBootstrapCompletedF != nil { return s.OnBootstrapCompletedF() } else if s.CantOnBootstrapCompleted && s.T != nil { diff --git a/snow/engine/enginetest/bootstrapper.go b/snow/engine/enginetest/bootstrapper.go index da9ae3641bf8..2292d68bcacb 100644 --- a/snow/engine/enginetest/bootstrapper.go +++ b/snow/engine/enginetest/bootstrapper.go @@ -13,26 +13,26 @@ import ( ) var ( - _ common.BootstrapableEngine = (*BootstrapperTest)(nil) + _ common.BootstrapableEngine = (*Bootstrapper)(nil) errClear = errors.New("unexpectedly called Clear") ) -type BootstrapperTest struct { - EngineTest +type Bootstrapper struct { + Engine CantClear bool ClearF func(ctx context.Context) error } -func (b *BootstrapperTest) Default(cant bool) { - b.EngineTest.Default(cant) +func (b *Bootstrapper) Default(cant bool) { + b.Engine.Default(cant) b.CantClear = cant } -func (b *BootstrapperTest) Clear(ctx context.Context) error { +func (b *Bootstrapper) Clear(ctx context.Context) error { if b.ClearF != nil { return b.ClearF(ctx) } diff --git a/snow/engine/enginetest/engine.go b/snow/engine/enginetest/engine.go index 3cb9817d6880..53f791f19b25 100644 --- a/snow/engine/enginetest/engine.go +++ b/snow/engine/enginetest/engine.go @@ -46,11 +46,11 @@ var ( errChits = errors.New("unexpectedly called Chits") errStart = errors.New("unexpectedly called Start") - _ common.Engine = (*EngineTest)(nil) + _ common.Engine = (*Engine)(nil) ) -// EngineTest is a test engine -type EngineTest struct { +// Engine is a test engine +type Engine struct { T *testing.T CantStart, @@ -142,7 +142,7 @@ type EngineTest struct { CrossChainAppRequestFailedF func(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error } -func (e *EngineTest) Default(cant bool) { +func (e *Engine) Default(cant bool) { e.CantStart = cant e.CantIsBootstrapped = cant e.CantTimeout = cant @@ -186,7 +186,7 @@ func (e *EngineTest) Default(cant bool) { e.CantCrossChainAppResponse = cant } -func (e *EngineTest) Start(ctx context.Context, startReqID uint32) error { +func (e *Engine) Start(ctx context.Context, startReqID uint32) error { if e.StartF != nil { return e.StartF(ctx, startReqID) } @@ -199,7 +199,7 @@ func (e *EngineTest) Start(ctx context.Context, startReqID uint32) error { return errStart } -func (e *EngineTest) Context() *snow.ConsensusContext { +func (e *Engine) Context() *snow.ConsensusContext { if e.ContextF != nil { return e.ContextF() } @@ -212,7 +212,7 @@ func (e *EngineTest) Context() *snow.ConsensusContext { return nil } -func (e *EngineTest) Timeout(ctx context.Context) error { +func (e *Engine) Timeout(ctx context.Context) error { if e.TimeoutF != nil { return e.TimeoutF(ctx) } @@ -225,7 +225,7 @@ func (e *EngineTest) Timeout(ctx context.Context) error { return errTimeout } -func (e *EngineTest) Gossip(ctx context.Context) error { +func (e *Engine) Gossip(ctx context.Context) error { if e.GossipF != nil { return e.GossipF(ctx) } @@ -238,7 +238,7 @@ func (e *EngineTest) Gossip(ctx context.Context) error { return errGossip } -func (e *EngineTest) Halt(ctx context.Context) { +func (e *Engine) Halt(ctx context.Context) { if e.HaltF != nil { e.HaltF(ctx) return @@ -251,7 +251,7 @@ func (e *EngineTest) Halt(ctx context.Context) { } } -func (e *EngineTest) Shutdown(ctx context.Context) error { +func (e *Engine) Shutdown(ctx context.Context) error { if e.ShutdownF != nil { return e.ShutdownF(ctx) } @@ -264,7 +264,7 @@ func (e *EngineTest) Shutdown(ctx context.Context) error { return errShutdown } -func (e *EngineTest) Notify(ctx context.Context, msg common.Message) error { +func (e *Engine) Notify(ctx context.Context, msg common.Message) error { if e.NotifyF != nil { return e.NotifyF(ctx, msg) } @@ -277,7 +277,7 @@ func (e *EngineTest) Notify(ctx context.Context, msg common.Message) error { return errNotify } -func (e *EngineTest) GetStateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { +func (e *Engine) GetStateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { if e.GetStateSummaryFrontierF != nil { return e.GetStateSummaryFrontierF(ctx, validatorID, requestID) } @@ -290,7 +290,7 @@ func (e *EngineTest) GetStateSummaryFrontier(ctx context.Context, validatorID id return errGetStateSummaryFrontier } -func (e *EngineTest) StateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, summary []byte) error { +func (e *Engine) StateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, summary []byte) error { if e.StateSummaryFrontierF != nil { return e.StateSummaryFrontierF(ctx, validatorID, requestID, summary) } @@ -303,7 +303,7 @@ func (e *EngineTest) StateSummaryFrontier(ctx context.Context, validatorID ids.N return errStateSummaryFrontier } -func (e *EngineTest) GetStateSummaryFrontierFailed(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { +func (e *Engine) GetStateSummaryFrontierFailed(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { if e.GetStateSummaryFrontierFailedF != nil { return e.GetStateSummaryFrontierFailedF(ctx, validatorID, requestID) } @@ -316,7 +316,7 @@ func (e *EngineTest) GetStateSummaryFrontierFailed(ctx context.Context, validato return errGetStateSummaryFrontierFailed } -func (e *EngineTest) GetAcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, keys set.Set[uint64]) error { +func (e *Engine) GetAcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, keys set.Set[uint64]) error { if e.GetAcceptedStateSummaryF != nil { return e.GetAcceptedStateSummaryF(ctx, validatorID, requestID, keys) } @@ -329,7 +329,7 @@ func (e *EngineTest) GetAcceptedStateSummary(ctx context.Context, validatorID id return errGetAcceptedStateSummary } -func (e *EngineTest) AcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, summaryIDs set.Set[ids.ID]) error { +func (e *Engine) AcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, summaryIDs set.Set[ids.ID]) error { if e.AcceptedStateSummaryF != nil { return e.AcceptedStateSummaryF(ctx, validatorID, requestID, summaryIDs) } @@ -342,7 +342,7 @@ func (e *EngineTest) AcceptedStateSummary(ctx context.Context, validatorID ids.N return errAcceptedStateSummary } -func (e *EngineTest) GetAcceptedStateSummaryFailed(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { +func (e *Engine) GetAcceptedStateSummaryFailed(ctx context.Context, validatorID ids.NodeID, requestID uint32) error { if e.GetAcceptedStateSummaryFailedF != nil { return e.GetAcceptedStateSummaryFailedF(ctx, validatorID, requestID) } @@ -355,7 +355,7 @@ func (e *EngineTest) GetAcceptedStateSummaryFailed(ctx context.Context, validato return errGetAcceptedStateSummaryFailed } -func (e *EngineTest) GetAcceptedFrontier(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) GetAcceptedFrontier(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.GetAcceptedFrontierF != nil { return e.GetAcceptedFrontierF(ctx, nodeID, requestID) } @@ -368,7 +368,7 @@ func (e *EngineTest) GetAcceptedFrontier(ctx context.Context, nodeID ids.NodeID, return errGetAcceptedFrontier } -func (e *EngineTest) GetAcceptedFrontierFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) GetAcceptedFrontierFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.GetAcceptedFrontierFailedF != nil { return e.GetAcceptedFrontierFailedF(ctx, nodeID, requestID) } @@ -381,7 +381,7 @@ func (e *EngineTest) GetAcceptedFrontierFailed(ctx context.Context, nodeID ids.N return errGetAcceptedFrontierFailed } -func (e *EngineTest) AcceptedFrontier(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { +func (e *Engine) AcceptedFrontier(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { if e.AcceptedFrontierF != nil { return e.AcceptedFrontierF(ctx, nodeID, requestID, containerID) } @@ -394,7 +394,7 @@ func (e *EngineTest) AcceptedFrontier(ctx context.Context, nodeID ids.NodeID, re return errAcceptedFrontier } -func (e *EngineTest) GetAccepted(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerIDs set.Set[ids.ID]) error { +func (e *Engine) GetAccepted(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerIDs set.Set[ids.ID]) error { if e.GetAcceptedF != nil { return e.GetAcceptedF(ctx, nodeID, requestID, containerIDs) } @@ -407,7 +407,7 @@ func (e *EngineTest) GetAccepted(ctx context.Context, nodeID ids.NodeID, request return errGetAccepted } -func (e *EngineTest) GetAcceptedFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) GetAcceptedFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.GetAcceptedFailedF != nil { return e.GetAcceptedFailedF(ctx, nodeID, requestID) } @@ -420,7 +420,7 @@ func (e *EngineTest) GetAcceptedFailed(ctx context.Context, nodeID ids.NodeID, r return errGetAcceptedFailed } -func (e *EngineTest) Accepted(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerIDs set.Set[ids.ID]) error { +func (e *Engine) Accepted(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerIDs set.Set[ids.ID]) error { if e.AcceptedF != nil { return e.AcceptedF(ctx, nodeID, requestID, containerIDs) } @@ -433,7 +433,7 @@ func (e *EngineTest) Accepted(ctx context.Context, nodeID ids.NodeID, requestID return errAccepted } -func (e *EngineTest) Get(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { +func (e *Engine) Get(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { if e.GetF != nil { return e.GetF(ctx, nodeID, requestID, containerID) } @@ -446,7 +446,7 @@ func (e *EngineTest) Get(ctx context.Context, nodeID ids.NodeID, requestID uint3 return errGet } -func (e *EngineTest) GetAncestors(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { +func (e *Engine) GetAncestors(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID) error { if e.GetAncestorsF != nil { return e.GetAncestorsF(ctx, nodeID, requestID, containerID) } @@ -459,7 +459,7 @@ func (e *EngineTest) GetAncestors(ctx context.Context, nodeID ids.NodeID, reques return errGetAncestors } -func (e *EngineTest) GetFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) GetFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.GetFailedF != nil { return e.GetFailedF(ctx, nodeID, requestID) } @@ -472,7 +472,7 @@ func (e *EngineTest) GetFailed(ctx context.Context, nodeID ids.NodeID, requestID return errGetFailed } -func (e *EngineTest) GetAncestorsFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) GetAncestorsFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.GetAncestorsFailedF != nil { return e.GetAncestorsFailedF(ctx, nodeID, requestID) } @@ -485,7 +485,7 @@ func (e *EngineTest) GetAncestorsFailed(ctx context.Context, nodeID ids.NodeID, return errGetAncestorsFailed } -func (e *EngineTest) Put(ctx context.Context, nodeID ids.NodeID, requestID uint32, container []byte) error { +func (e *Engine) Put(ctx context.Context, nodeID ids.NodeID, requestID uint32, container []byte) error { if e.PutF != nil { return e.PutF(ctx, nodeID, requestID, container) } @@ -498,7 +498,7 @@ func (e *EngineTest) Put(ctx context.Context, nodeID ids.NodeID, requestID uint3 return errPut } -func (e *EngineTest) Ancestors(ctx context.Context, nodeID ids.NodeID, requestID uint32, containers [][]byte) error { +func (e *Engine) Ancestors(ctx context.Context, nodeID ids.NodeID, requestID uint32, containers [][]byte) error { if e.AncestorsF != nil { return e.AncestorsF(ctx, nodeID, requestID, containers) } @@ -511,7 +511,7 @@ func (e *EngineTest) Ancestors(ctx context.Context, nodeID ids.NodeID, requestID return errAncestors } -func (e *EngineTest) PushQuery(ctx context.Context, nodeID ids.NodeID, requestID uint32, container []byte, requestedHeight uint64) error { +func (e *Engine) PushQuery(ctx context.Context, nodeID ids.NodeID, requestID uint32, container []byte, requestedHeight uint64) error { if e.PushQueryF != nil { return e.PushQueryF(ctx, nodeID, requestID, container, requestedHeight) } @@ -524,7 +524,7 @@ func (e *EngineTest) PushQuery(ctx context.Context, nodeID ids.NodeID, requestID return errPushQuery } -func (e *EngineTest) PullQuery(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID, requestedHeight uint64) error { +func (e *Engine) PullQuery(ctx context.Context, nodeID ids.NodeID, requestID uint32, containerID ids.ID, requestedHeight uint64) error { if e.PullQueryF != nil { return e.PullQueryF(ctx, nodeID, requestID, containerID, requestedHeight) } @@ -537,7 +537,7 @@ func (e *EngineTest) PullQuery(ctx context.Context, nodeID ids.NodeID, requestID return errPullQuery } -func (e *EngineTest) QueryFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (e *Engine) QueryFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { if e.QueryFailedF != nil { return e.QueryFailedF(ctx, nodeID, requestID) } @@ -550,7 +550,7 @@ func (e *EngineTest) QueryFailed(ctx context.Context, nodeID ids.NodeID, request return errQueryFailed } -func (e *EngineTest) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { +func (e *Engine) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { if e.CrossChainAppRequestF != nil { return e.CrossChainAppRequestF(ctx, chainID, requestID, deadline, request) } @@ -563,7 +563,7 @@ func (e *EngineTest) CrossChainAppRequest(ctx context.Context, chainID ids.ID, r return errCrossChainAppRequest } -func (e *EngineTest) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error { +func (e *Engine) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error { if e.CrossChainAppRequestFailedF != nil { return e.CrossChainAppRequestFailedF(ctx, chainID, requestID, appErr) } @@ -576,7 +576,7 @@ func (e *EngineTest) CrossChainAppRequestFailed(ctx context.Context, chainID ids return errCrossChainAppRequestFailed } -func (e *EngineTest) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { +func (e *Engine) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { if e.CrossChainAppResponseF != nil { return e.CrossChainAppResponseF(ctx, chainID, requestID, response) } @@ -589,7 +589,7 @@ func (e *EngineTest) CrossChainAppResponse(ctx context.Context, chainID ids.ID, return errCrossChainAppResponse } -func (e *EngineTest) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { +func (e *Engine) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { if e.AppRequestF != nil { return e.AppRequestF(ctx, nodeID, requestID, deadline, request) } @@ -602,7 +602,7 @@ func (e *EngineTest) AppRequest(ctx context.Context, nodeID ids.NodeID, requestI return errAppRequest } -func (e *EngineTest) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { +func (e *Engine) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { if e.AppResponseF != nil { return e.AppResponseF(ctx, nodeID, requestID, response) } @@ -615,7 +615,7 @@ func (e *EngineTest) AppResponse(ctx context.Context, nodeID ids.NodeID, request return errAppResponse } -func (e *EngineTest) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error { +func (e *Engine) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error { if e.AppRequestFailedF != nil { return e.AppRequestFailedF(ctx, nodeID, requestID, appErr) } @@ -628,7 +628,7 @@ func (e *EngineTest) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, re return errAppRequestFailed } -func (e *EngineTest) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { +func (e *Engine) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { if e.AppGossipF != nil { return e.AppGossipF(ctx, nodeID, msg) } @@ -641,7 +641,7 @@ func (e *EngineTest) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byt return errAppGossip } -func (e *EngineTest) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { +func (e *Engine) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { if e.ChitsF != nil { return e.ChitsF(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID) } @@ -654,7 +654,7 @@ func (e *EngineTest) Chits(ctx context.Context, nodeID ids.NodeID, requestID uin return errChits } -func (e *EngineTest) Connected(ctx context.Context, nodeID ids.NodeID, nodeVersion *version.Application) error { +func (e *Engine) Connected(ctx context.Context, nodeID ids.NodeID, nodeVersion *version.Application) error { if e.ConnectedF != nil { return e.ConnectedF(ctx, nodeID, nodeVersion) } @@ -667,7 +667,7 @@ func (e *EngineTest) Connected(ctx context.Context, nodeID ids.NodeID, nodeVersi return errConnected } -func (e *EngineTest) Disconnected(ctx context.Context, nodeID ids.NodeID) error { +func (e *Engine) Disconnected(ctx context.Context, nodeID ids.NodeID) error { if e.DisconnectedF != nil { return e.DisconnectedF(ctx, nodeID) } @@ -680,7 +680,7 @@ func (e *EngineTest) Disconnected(ctx context.Context, nodeID ids.NodeID) error return errDisconnected } -func (e *EngineTest) HealthCheck(ctx context.Context) (interface{}, error) { +func (e *Engine) HealthCheck(ctx context.Context) (interface{}, error) { if e.HealthF != nil { return e.HealthF(ctx) } diff --git a/snow/engine/enginetest/sender.go b/snow/engine/enginetest/sender.go index 25c227969315..02a57caf4b6f 100644 --- a/snow/engine/enginetest/sender.go +++ b/snow/engine/enginetest/sender.go @@ -6,6 +6,7 @@ package enginetest import ( "context" "errors" + "testing" "github.com/stretchr/testify/require" @@ -15,8 +16,8 @@ import ( ) var ( - _ common.Sender = (*SenderTest)(nil) - _ common.AppSender = (*FakeSender)(nil) + _ common.Sender = (*Sender)(nil) + _ common.AppSender = (*SenderStub)(nil) errSendAppRequest = errors.New("unexpectedly called SendAppRequest") errSendAppResponse = errors.New("unexpectedly called SendAppResponse") @@ -24,9 +25,9 @@ var ( errSendAppGossip = errors.New("unexpectedly called SendAppGossip") ) -// SenderTest is a test sender -type SenderTest struct { - T require.TestingT +// Sender is a test sender +type Sender struct { + T *testing.T CantSendGetStateSummaryFrontier, CantSendStateSummaryFrontier, CantSendGetAcceptedStateSummary, CantSendAcceptedStateSummary, @@ -63,7 +64,7 @@ type SenderTest struct { } // Default set the default callable value to [cant] -func (s *SenderTest) Default(cant bool) { +func (s *Sender) Default(cant bool) { s.CantSendGetStateSummaryFrontier = cant s.CantSendStateSummaryFrontier = cant s.CantSendGetAcceptedStateSummary = cant @@ -89,7 +90,7 @@ func (s *SenderTest) Default(cant bool) { // SendGetStateSummaryFrontier calls SendGetStateSummaryFrontierF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendGetStateSummaryFrontier(ctx context.Context, validatorIDs set.Set[ids.NodeID], requestID uint32) { +func (s *Sender) SendGetStateSummaryFrontier(ctx context.Context, validatorIDs set.Set[ids.NodeID], requestID uint32) { if s.SendGetStateSummaryFrontierF != nil { s.SendGetStateSummaryFrontierF(ctx, validatorIDs, requestID) } else if s.CantSendGetStateSummaryFrontier && s.T != nil { @@ -100,7 +101,7 @@ func (s *SenderTest) SendGetStateSummaryFrontier(ctx context.Context, validatorI // SendStateSummaryFrontier calls SendStateSummaryFrontierF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendStateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, summary []byte) { +func (s *Sender) SendStateSummaryFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, summary []byte) { if s.SendStateSummaryFrontierF != nil { s.SendStateSummaryFrontierF(ctx, validatorID, requestID, summary) } else if s.CantSendStateSummaryFrontier && s.T != nil { @@ -111,7 +112,7 @@ func (s *SenderTest) SendStateSummaryFrontier(ctx context.Context, validatorID i // SendGetAcceptedStateSummary calls SendGetAcceptedStateSummaryF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendGetAcceptedStateSummary(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, heights []uint64) { +func (s *Sender) SendGetAcceptedStateSummary(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, heights []uint64) { if s.SendGetAcceptedStateSummaryF != nil { s.SendGetAcceptedStateSummaryF(ctx, nodeIDs, requestID, heights) } else if s.CantSendGetAcceptedStateSummary && s.T != nil { @@ -122,7 +123,7 @@ func (s *SenderTest) SendGetAcceptedStateSummary(ctx context.Context, nodeIDs se // SendAcceptedStateSummary calls SendAcceptedStateSummaryF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendAcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, summaryIDs []ids.ID) { +func (s *Sender) SendAcceptedStateSummary(ctx context.Context, validatorID ids.NodeID, requestID uint32, summaryIDs []ids.ID) { if s.SendAcceptedStateSummaryF != nil { s.SendAcceptedStateSummaryF(ctx, validatorID, requestID, summaryIDs) } else if s.CantSendAcceptedStateSummary && s.T != nil { @@ -133,7 +134,7 @@ func (s *SenderTest) SendAcceptedStateSummary(ctx context.Context, validatorID i // SendGetAcceptedFrontier calls SendGetAcceptedFrontierF if it was initialized. // If it wasn't initialized and this function shouldn't be called and testing // was initialized, then testing will fail. -func (s *SenderTest) SendGetAcceptedFrontier(ctx context.Context, validatorIDs set.Set[ids.NodeID], requestID uint32) { +func (s *Sender) SendGetAcceptedFrontier(ctx context.Context, validatorIDs set.Set[ids.NodeID], requestID uint32) { if s.SendGetAcceptedFrontierF != nil { s.SendGetAcceptedFrontierF(ctx, validatorIDs, requestID) } else if s.CantSendGetAcceptedFrontier && s.T != nil { @@ -144,7 +145,7 @@ func (s *SenderTest) SendGetAcceptedFrontier(ctx context.Context, validatorIDs s // SendAcceptedFrontier calls SendAcceptedFrontierF if it was initialized. If it // wasn't initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAcceptedFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerID ids.ID) { +func (s *Sender) SendAcceptedFrontier(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerID ids.ID) { if s.SendAcceptedFrontierF != nil { s.SendAcceptedFrontierF(ctx, validatorID, requestID, containerID) } else if s.CantSendAcceptedFrontier && s.T != nil { @@ -155,7 +156,7 @@ func (s *SenderTest) SendAcceptedFrontier(ctx context.Context, validatorID ids.N // SendGetAccepted calls SendGetAcceptedF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendGetAccepted(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, containerIDs []ids.ID) { +func (s *Sender) SendGetAccepted(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, containerIDs []ids.ID) { if s.SendGetAcceptedF != nil { s.SendGetAcceptedF(ctx, nodeIDs, requestID, containerIDs) } else if s.CantSendGetAccepted && s.T != nil { @@ -166,7 +167,7 @@ func (s *SenderTest) SendGetAccepted(ctx context.Context, nodeIDs set.Set[ids.No // SendAccepted calls SendAcceptedF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAccepted(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerIDs []ids.ID) { +func (s *Sender) SendAccepted(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerIDs []ids.ID) { if s.SendAcceptedF != nil { s.SendAcceptedF(ctx, validatorID, requestID, containerIDs) } else if s.CantSendAccepted && s.T != nil { @@ -177,7 +178,7 @@ func (s *SenderTest) SendAccepted(ctx context.Context, validatorID ids.NodeID, r // SendGet calls SendGetF if it was initialized. If it wasn't initialized and // this function shouldn't be called and testing was initialized, then testing // will fail. -func (s *SenderTest) SendGet(ctx context.Context, vdr ids.NodeID, requestID uint32, containerID ids.ID) { +func (s *Sender) SendGet(ctx context.Context, vdr ids.NodeID, requestID uint32, containerID ids.ID) { if s.SendGetF != nil { s.SendGetF(ctx, vdr, requestID, containerID) } else if s.CantSendGet && s.T != nil { @@ -188,7 +189,7 @@ func (s *SenderTest) SendGet(ctx context.Context, vdr ids.NodeID, requestID uint // SendGetAncestors calls SendGetAncestorsF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendGetAncestors(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerID ids.ID) { +func (s *Sender) SendGetAncestors(ctx context.Context, validatorID ids.NodeID, requestID uint32, containerID ids.ID) { if s.SendGetAncestorsF != nil { s.SendGetAncestorsF(ctx, validatorID, requestID, containerID) } else if s.CantSendGetAncestors && s.T != nil { @@ -199,7 +200,7 @@ func (s *SenderTest) SendGetAncestors(ctx context.Context, validatorID ids.NodeI // SendPut calls SendPutF if it was initialized. If it wasn't initialized and // this function shouldn't be called and testing was initialized, then testing // will fail. -func (s *SenderTest) SendPut(ctx context.Context, vdr ids.NodeID, requestID uint32, container []byte) { +func (s *Sender) SendPut(ctx context.Context, vdr ids.NodeID, requestID uint32, container []byte) { if s.SendPutF != nil { s.SendPutF(ctx, vdr, requestID, container) } else if s.CantSendPut && s.T != nil { @@ -210,7 +211,7 @@ func (s *SenderTest) SendPut(ctx context.Context, vdr ids.NodeID, requestID uint // SendAncestors calls SendAncestorsF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAncestors(ctx context.Context, vdr ids.NodeID, requestID uint32, containers [][]byte) { +func (s *Sender) SendAncestors(ctx context.Context, vdr ids.NodeID, requestID uint32, containers [][]byte) { if s.SendAncestorsF != nil { s.SendAncestorsF(ctx, vdr, requestID, containers) } else if s.CantSendAncestors && s.T != nil { @@ -221,7 +222,7 @@ func (s *SenderTest) SendAncestors(ctx context.Context, vdr ids.NodeID, requestI // SendPushQuery calls SendPushQueryF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendPushQuery(ctx context.Context, vdrs set.Set[ids.NodeID], requestID uint32, container []byte, requestedHeight uint64) { +func (s *Sender) SendPushQuery(ctx context.Context, vdrs set.Set[ids.NodeID], requestID uint32, container []byte, requestedHeight uint64) { if s.SendPushQueryF != nil { s.SendPushQueryF(ctx, vdrs, requestID, container, requestedHeight) } else if s.CantSendPushQuery && s.T != nil { @@ -232,7 +233,7 @@ func (s *SenderTest) SendPushQuery(ctx context.Context, vdrs set.Set[ids.NodeID] // SendPullQuery calls SendPullQueryF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendPullQuery(ctx context.Context, vdrs set.Set[ids.NodeID], requestID uint32, containerID ids.ID, requestedHeight uint64) { +func (s *Sender) SendPullQuery(ctx context.Context, vdrs set.Set[ids.NodeID], requestID uint32, containerID ids.ID, requestedHeight uint64) { if s.SendPullQueryF != nil { s.SendPullQueryF(ctx, vdrs, requestID, containerID, requestedHeight) } else if s.CantSendPullQuery && s.T != nil { @@ -243,7 +244,7 @@ func (s *SenderTest) SendPullQuery(ctx context.Context, vdrs set.Set[ids.NodeID] // SendChits calls SendChitsF if it was initialized. If it wasn't initialized // and this function shouldn't be called and testing was initialized, then // testing will fail. -func (s *SenderTest) SendChits(ctx context.Context, vdr ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) { +func (s *Sender) SendChits(ctx context.Context, vdr ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) { if s.SendChitsF != nil { s.SendChitsF(ctx, vdr, requestID, preferredID, preferredIDAtHeight, acceptedID) } else if s.CantSendChits && s.T != nil { @@ -254,7 +255,7 @@ func (s *SenderTest) SendChits(ctx context.Context, vdr ids.NodeID, requestID ui // SendCrossChainAppRequest calls SendCrossChainAppRequestF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendCrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, appRequestBytes []byte) error { +func (s *Sender) SendCrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, appRequestBytes []byte) error { if s.SendCrossChainAppRequestF != nil { s.SendCrossChainAppRequestF(ctx, chainID, requestID, appRequestBytes) } else if s.CantSendCrossChainAppRequest && s.T != nil { @@ -266,7 +267,7 @@ func (s *SenderTest) SendCrossChainAppRequest(ctx context.Context, chainID ids.I // SendCrossChainAppResponse calls SendCrossChainAppResponseF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendCrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, appResponseBytes []byte) error { +func (s *Sender) SendCrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, appResponseBytes []byte) error { if s.SendCrossChainAppResponseF != nil { s.SendCrossChainAppResponseF(ctx, chainID, requestID, appResponseBytes) } else if s.CantSendCrossChainAppResponse && s.T != nil { @@ -278,7 +279,7 @@ func (s *SenderTest) SendCrossChainAppResponse(ctx context.Context, chainID ids. // SendCrossChainAppError calls SendCrossChainAppErrorF if it was // initialized. If it wasn't initialized and this function shouldn't be called // and testing was initialized, then testing will fail. -func (s *SenderTest) SendCrossChainAppError(ctx context.Context, chainID ids.ID, requestID uint32, errorCode int32, errorMessage string) error { +func (s *Sender) SendCrossChainAppError(ctx context.Context, chainID ids.ID, requestID uint32, errorCode int32, errorMessage string) error { if s.SendCrossChainAppErrorF != nil { s.SendCrossChainAppErrorF(ctx, chainID, requestID, errorCode, errorMessage) } else if s.CantSendCrossChainAppError && s.T != nil { @@ -290,7 +291,7 @@ func (s *SenderTest) SendCrossChainAppError(ctx context.Context, chainID ids.ID, // SendAppRequest calls SendAppRequestF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAppRequest(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, appRequestBytes []byte) error { +func (s *Sender) SendAppRequest(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, appRequestBytes []byte) error { switch { case s.SendAppRequestF != nil: return s.SendAppRequestF(ctx, nodeIDs, requestID, appRequestBytes) @@ -303,7 +304,7 @@ func (s *SenderTest) SendAppRequest(ctx context.Context, nodeIDs set.Set[ids.Nod // SendAppResponse calls SendAppResponseF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, appResponseBytes []byte) error { +func (s *Sender) SendAppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, appResponseBytes []byte) error { switch { case s.SendAppResponseF != nil: return s.SendAppResponseF(ctx, nodeID, requestID, appResponseBytes) @@ -316,7 +317,7 @@ func (s *SenderTest) SendAppResponse(ctx context.Context, nodeID ids.NodeID, req // SendAppError calls SendAppErrorF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAppError(ctx context.Context, nodeID ids.NodeID, requestID uint32, code int32, message string) error { +func (s *Sender) SendAppError(ctx context.Context, nodeID ids.NodeID, requestID uint32, code int32, message string) error { switch { case s.SendAppErrorF != nil: return s.SendAppErrorF(ctx, nodeID, requestID, code, message) @@ -329,7 +330,7 @@ func (s *SenderTest) SendAppError(ctx context.Context, nodeID ids.NodeID, reques // SendAppGossip calls SendAppGossipF if it was initialized. If it wasn't // initialized and this function shouldn't be called and testing was // initialized, then testing will fail. -func (s *SenderTest) SendAppGossip( +func (s *Sender) SendAppGossip( ctx context.Context, config common.SendConfig, appGossipBytes []byte, @@ -343,8 +344,8 @@ func (s *SenderTest) SendAppGossip( return errSendAppGossip } -// FakeSender is used for testing -type FakeSender struct { +// SenderStub is a stub sender that returns values received on method-specific channels. +type SenderStub struct { SentAppRequest, SentAppResponse, SentAppGossip, SentCrossChainAppRequest, SentCrossChainAppResponse chan []byte @@ -352,7 +353,7 @@ type FakeSender struct { SentAppError, SentCrossChainAppError chan *common.AppError } -func (f FakeSender) SendAppRequest(_ context.Context, _ set.Set[ids.NodeID], _ uint32, bytes []byte) error { +func (f SenderStub) SendAppRequest(_ context.Context, _ set.Set[ids.NodeID], _ uint32, bytes []byte) error { if f.SentAppRequest == nil { return nil } @@ -361,7 +362,7 @@ func (f FakeSender) SendAppRequest(_ context.Context, _ set.Set[ids.NodeID], _ u return nil } -func (f FakeSender) SendAppResponse(_ context.Context, _ ids.NodeID, _ uint32, bytes []byte) error { +func (f SenderStub) SendAppResponse(_ context.Context, _ ids.NodeID, _ uint32, bytes []byte) error { if f.SentAppResponse == nil { return nil } @@ -370,7 +371,7 @@ func (f FakeSender) SendAppResponse(_ context.Context, _ ids.NodeID, _ uint32, b return nil } -func (f FakeSender) SendAppError(_ context.Context, _ ids.NodeID, _ uint32, errorCode int32, errorMessage string) error { +func (f SenderStub) SendAppError(_ context.Context, _ ids.NodeID, _ uint32, errorCode int32, errorMessage string) error { if f.SentAppError == nil { return nil } @@ -382,7 +383,7 @@ func (f FakeSender) SendAppError(_ context.Context, _ ids.NodeID, _ uint32, erro return nil } -func (f FakeSender) SendAppGossip(_ context.Context, _ common.SendConfig, bytes []byte) error { +func (f SenderStub) SendAppGossip(_ context.Context, _ common.SendConfig, bytes []byte) error { if f.SentAppGossip == nil { return nil } @@ -391,7 +392,7 @@ func (f FakeSender) SendAppGossip(_ context.Context, _ common.SendConfig, bytes return nil } -func (f FakeSender) SendCrossChainAppRequest(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { +func (f SenderStub) SendCrossChainAppRequest(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { if f.SentCrossChainAppRequest == nil { return nil } @@ -400,7 +401,7 @@ func (f FakeSender) SendCrossChainAppRequest(_ context.Context, _ ids.ID, _ uint return nil } -func (f FakeSender) SendCrossChainAppResponse(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { +func (f SenderStub) SendCrossChainAppResponse(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { if f.SentCrossChainAppResponse == nil { return nil } @@ -409,7 +410,7 @@ func (f FakeSender) SendCrossChainAppResponse(_ context.Context, _ ids.ID, _ uin return nil } -func (f FakeSender) SendCrossChainAppError(_ context.Context, _ ids.ID, _ uint32, errorCode int32, errorMessage string) error { +func (f SenderStub) SendCrossChainAppError(_ context.Context, _ ids.ID, _ uint32, errorCode int32, errorMessage string) error { if f.SentCrossChainAppError == nil { return nil } diff --git a/snow/engine/enginetest/timer.go b/snow/engine/enginetest/timer.go index f0bd1700dd66..f2161b5c8c17 100644 --- a/snow/engine/enginetest/timer.go +++ b/snow/engine/enginetest/timer.go @@ -12,10 +12,10 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/common" ) -var _ common.Timer = (*TimerTest)(nil) +var _ common.Timer = (*Timer)(nil) -// TimerTest is a test timer -type TimerTest struct { +// Timer is a test timer +type Timer struct { T *testing.T CantRegisterTimout bool @@ -24,11 +24,11 @@ type TimerTest struct { } // Default set the default callable value to [cant] -func (t *TimerTest) Default(cant bool) { +func (t *Timer) Default(cant bool) { t.CantRegisterTimout = cant } -func (t *TimerTest) RegisterTimeout(delay time.Duration) { +func (t *Timer) RegisterTimeout(delay time.Duration) { if t.RegisterTimeoutF != nil { t.RegisterTimeoutF(delay) } else if t.CantRegisterTimout && t.T != nil { diff --git a/snow/engine/enginetest/vm.go b/snow/engine/enginetest/vm.go index 2dadf512fe92..d6c5f4d0feb3 100644 --- a/snow/engine/enginetest/vm.go +++ b/snow/engine/enginetest/vm.go @@ -36,11 +36,11 @@ var ( errCrossChainAppResponse = errors.New("unexpectedly called CrossChainAppResponse") errCrossChainAppRequestFailed = errors.New("unexpectedly called CrossChainAppRequestFailed") - _ common.VM = (*TestVM)(nil) + _ common.VM = (*VM)(nil) ) -// TestVM is a test vm -type TestVM struct { +// VM is a test vm +type VM struct { T *testing.T CantInitialize, CantSetState, @@ -66,7 +66,7 @@ type TestVM struct { CrossChainAppRequestFailedF func(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error } -func (vm *TestVM) Default(cant bool) { +func (vm *VM) Default(cant bool) { vm.CantInitialize = cant vm.CantSetState = cant vm.CantShutdown = cant @@ -84,7 +84,7 @@ func (vm *TestVM) Default(cant bool) { vm.CantCrossChainAppResponse = cant } -func (vm *TestVM) Initialize( +func (vm *VM) Initialize( ctx context.Context, chainCtx *snow.Context, db database.Database, @@ -114,7 +114,7 @@ func (vm *TestVM) Initialize( return errInitialize } -func (vm *TestVM) SetState(ctx context.Context, state snow.State) error { +func (vm *VM) SetState(ctx context.Context, state snow.State) error { if vm.SetStateF != nil { return vm.SetStateF(ctx, state) } @@ -127,7 +127,7 @@ func (vm *TestVM) SetState(ctx context.Context, state snow.State) error { return nil } -func (vm *TestVM) Shutdown(ctx context.Context) error { +func (vm *VM) Shutdown(ctx context.Context) error { if vm.ShutdownF != nil { return vm.ShutdownF(ctx) } @@ -140,7 +140,7 @@ func (vm *TestVM) Shutdown(ctx context.Context) error { return nil } -func (vm *TestVM) CreateHandlers(ctx context.Context) (map[string]http.Handler, error) { +func (vm *VM) CreateHandlers(ctx context.Context) (map[string]http.Handler, error) { if vm.CreateHandlersF != nil { return vm.CreateHandlersF(ctx) } @@ -150,7 +150,7 @@ func (vm *TestVM) CreateHandlers(ctx context.Context) (map[string]http.Handler, return nil, nil } -func (vm *TestVM) HealthCheck(ctx context.Context) (interface{}, error) { +func (vm *VM) HealthCheck(ctx context.Context) (interface{}, error) { if vm.HealthCheckF != nil { return vm.HealthCheckF(ctx) } @@ -160,7 +160,7 @@ func (vm *TestVM) HealthCheck(ctx context.Context) (interface{}, error) { return nil, errHealthCheck } -func (vm *TestVM) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { +func (vm *VM) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { if vm.AppRequestF != nil { return vm.AppRequestF(ctx, nodeID, requestID, deadline, request) } @@ -173,7 +173,7 @@ func (vm *TestVM) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID u return errAppRequest } -func (vm *TestVM) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error { +func (vm *VM) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error { if vm.AppRequestFailedF != nil { return vm.AppRequestFailedF(ctx, nodeID, requestID, appErr) } @@ -186,7 +186,7 @@ func (vm *TestVM) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, reque return errAppRequestFailed } -func (vm *TestVM) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { +func (vm *VM) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { if vm.AppResponseF != nil { return vm.AppResponseF(ctx, nodeID, requestID, response) } @@ -199,7 +199,7 @@ func (vm *TestVM) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID return errAppResponse } -func (vm *TestVM) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { +func (vm *VM) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { if vm.AppGossipF != nil { return vm.AppGossipF(ctx, nodeID, msg) } @@ -212,7 +212,7 @@ func (vm *TestVM) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) return errAppGossip } -func (vm *TestVM) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { +func (vm *VM) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { if vm.CrossChainAppRequestF != nil { return vm.CrossChainAppRequestF(ctx, chainID, requestID, deadline, request) } @@ -225,7 +225,7 @@ func (vm *TestVM) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requ return errCrossChainAppRequest } -func (vm *TestVM) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error { +func (vm *VM) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error { if vm.CrossChainAppRequestFailedF != nil { return vm.CrossChainAppRequestFailedF(ctx, chainID, requestID, appErr) } @@ -238,7 +238,7 @@ func (vm *TestVM) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID return errCrossChainAppRequestFailed } -func (vm *TestVM) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { +func (vm *VM) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { if vm.CrossChainAppResponseF != nil { return vm.CrossChainAppResponseF(ctx, chainID, requestID, response) } @@ -251,7 +251,7 @@ func (vm *TestVM) CrossChainAppResponse(ctx context.Context, chainID ids.ID, req return errCrossChainAppResponse } -func (vm *TestVM) Connected(ctx context.Context, id ids.NodeID, nodeVersion *version.Application) error { +func (vm *VM) Connected(ctx context.Context, id ids.NodeID, nodeVersion *version.Application) error { if vm.ConnectedF != nil { return vm.ConnectedF(ctx, id, nodeVersion) } @@ -261,7 +261,7 @@ func (vm *TestVM) Connected(ctx context.Context, id ids.NodeID, nodeVersion *ver return nil } -func (vm *TestVM) Disconnected(ctx context.Context, id ids.NodeID) error { +func (vm *VM) Disconnected(ctx context.Context, id ids.NodeID) error { if vm.DisconnectedF != nil { return vm.DisconnectedF(ctx, id) } @@ -271,7 +271,7 @@ func (vm *TestVM) Disconnected(ctx context.Context, id ids.NodeID) error { return nil } -func (vm *TestVM) Version(ctx context.Context) (string, error) { +func (vm *VM) Version(ctx context.Context) (string, error) { if vm.VersionF != nil { return vm.VersionF(ctx) } diff --git a/snow/engine/snowman/block/batched_vm_test.go b/snow/engine/snowman/block/batched_vm_test.go index 86129d294084..a57ac13d5613 100644 --- a/snow/engine/snowman/block/batched_vm_test.go +++ b/snow/engine/snowman/block/batched_vm_test.go @@ -25,7 +25,7 @@ var errTest = errors.New("non-nil error") func TestGetAncestorsDatabaseNotFound(t *testing.T) { require := require.New(t) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} someID := ids.GenerateTestID() vm.GetBlockF = func(_ context.Context, id ids.ID) (snowman.Block, error) { require.Equal(someID, id) @@ -41,7 +41,7 @@ func TestGetAncestorsDatabaseNotFound(t *testing.T) { func TestGetAncestorsPropagatesErrors(t *testing.T) { require := require.New(t) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} someID := ids.GenerateTestID() vm.GetBlockF = func(_ context.Context, id ids.ID) (snowman.Block, error) { require.Equal(someID, id) diff --git a/snow/engine/snowman/block/blocktest/batched_vm.go b/snow/engine/snowman/block/blocktest/batched_vm.go index 1513df481bb2..16d70648ad2f 100644 --- a/snow/engine/snowman/block/blocktest/batched_vm.go +++ b/snow/engine/snowman/block/blocktest/batched_vm.go @@ -20,11 +20,11 @@ var ( errGetAncestor = errors.New("unexpectedly called GetAncestor") errBatchedParseBlock = errors.New("unexpectedly called BatchedParseBlock") - _ block.BatchedChainVM = (*TestBatchedVM)(nil) + _ block.BatchedChainVM = (*BatchedVM)(nil) ) -// TestBatchedVM is a BatchedVM that is useful for testing. -type TestBatchedVM struct { +// BatchedVM is a BatchedVM that is useful for testing. +type BatchedVM struct { T *testing.T CantGetAncestors bool @@ -44,12 +44,12 @@ type TestBatchedVM struct { ) ([]snowman.Block, error) } -func (vm *TestBatchedVM) Default(cant bool) { +func (vm *BatchedVM) Default(cant bool) { vm.CantGetAncestors = cant vm.CantBatchParseBlock = cant } -func (vm *TestBatchedVM) GetAncestors( +func (vm *BatchedVM) GetAncestors( ctx context.Context, blkID ids.ID, maxBlocksNum int, @@ -71,7 +71,7 @@ func (vm *TestBatchedVM) GetAncestors( return nil, errGetAncestor } -func (vm *TestBatchedVM) BatchedParseBlock( +func (vm *BatchedVM) BatchedParseBlock( ctx context.Context, blks [][]byte, ) ([]snowman.Block, error) { diff --git a/snow/engine/snowman/block/blocktest/state_summary.go b/snow/engine/snowman/block/blocktest/state_summary.go index d410c19a38b7..59256390c075 100644 --- a/snow/engine/snowman/block/blocktest/state_summary.go +++ b/snow/engine/snowman/block/blocktest/state_summary.go @@ -15,12 +15,12 @@ import ( ) var ( - _ block.StateSummary = (*TestStateSummary)(nil) + _ block.StateSummary = (*StateSummary)(nil) errAccept = errors.New("unexpectedly called Accept") ) -type TestStateSummary struct { +type StateSummary struct { IDV ids.ID HeightV uint64 BytesV []byte @@ -30,19 +30,19 @@ type TestStateSummary struct { AcceptF func(context.Context) (block.StateSyncMode, error) } -func (s *TestStateSummary) ID() ids.ID { +func (s *StateSummary) ID() ids.ID { return s.IDV } -func (s *TestStateSummary) Height() uint64 { +func (s *StateSummary) Height() uint64 { return s.HeightV } -func (s *TestStateSummary) Bytes() []byte { +func (s *StateSummary) Bytes() []byte { return s.BytesV } -func (s *TestStateSummary) Accept(ctx context.Context) (block.StateSyncMode, error) { +func (s *StateSummary) Accept(ctx context.Context) (block.StateSyncMode, error) { if s.AcceptF != nil { return s.AcceptF(ctx) } diff --git a/snow/engine/snowman/block/blocktest/state_syncable_vm.go b/snow/engine/snowman/block/blocktest/state_syncable_vm.go index 66b28c374be5..79b9e0a0bb26 100644 --- a/snow/engine/snowman/block/blocktest/state_syncable_vm.go +++ b/snow/engine/snowman/block/blocktest/state_syncable_vm.go @@ -14,7 +14,7 @@ import ( ) var ( - _ block.StateSyncableVM = (*TestStateSyncableVM)(nil) + _ block.StateSyncableVM = (*StateSyncableVM)(nil) errStateSyncEnabled = errors.New("unexpectedly called StateSyncEnabled") errStateSyncGetOngoingSummary = errors.New("unexpectedly called StateSyncGetOngoingSummary") @@ -23,7 +23,7 @@ var ( errGetStateSummary = errors.New("unexpectedly called GetStateSummary") ) -type TestStateSyncableVM struct { +type StateSyncableVM struct { T *testing.T CantStateSyncEnabled, @@ -39,7 +39,7 @@ type TestStateSyncableVM struct { GetStateSummaryF func(ctx context.Context, summaryHeight uint64) (block.StateSummary, error) } -func (vm *TestStateSyncableVM) StateSyncEnabled(ctx context.Context) (bool, error) { +func (vm *StateSyncableVM) StateSyncEnabled(ctx context.Context) (bool, error) { if vm.StateSyncEnabledF != nil { return vm.StateSyncEnabledF(ctx) } @@ -49,7 +49,7 @@ func (vm *TestStateSyncableVM) StateSyncEnabled(ctx context.Context) (bool, erro return false, errStateSyncEnabled } -func (vm *TestStateSyncableVM) GetOngoingSyncStateSummary(ctx context.Context) (block.StateSummary, error) { +func (vm *StateSyncableVM) GetOngoingSyncStateSummary(ctx context.Context) (block.StateSummary, error) { if vm.GetOngoingSyncStateSummaryF != nil { return vm.GetOngoingSyncStateSummaryF(ctx) } @@ -59,7 +59,7 @@ func (vm *TestStateSyncableVM) GetOngoingSyncStateSummary(ctx context.Context) ( return nil, errStateSyncGetOngoingSummary } -func (vm *TestStateSyncableVM) GetLastStateSummary(ctx context.Context) (block.StateSummary, error) { +func (vm *StateSyncableVM) GetLastStateSummary(ctx context.Context) (block.StateSummary, error) { if vm.GetLastStateSummaryF != nil { return vm.GetLastStateSummaryF(ctx) } @@ -69,7 +69,7 @@ func (vm *TestStateSyncableVM) GetLastStateSummary(ctx context.Context) (block.S return nil, errGetLastStateSummary } -func (vm *TestStateSyncableVM) ParseStateSummary(ctx context.Context, summaryBytes []byte) (block.StateSummary, error) { +func (vm *StateSyncableVM) ParseStateSummary(ctx context.Context, summaryBytes []byte) (block.StateSummary, error) { if vm.ParseStateSummaryF != nil { return vm.ParseStateSummaryF(ctx, summaryBytes) } @@ -79,7 +79,7 @@ func (vm *TestStateSyncableVM) ParseStateSummary(ctx context.Context, summaryByt return nil, errParseStateSummary } -func (vm *TestStateSyncableVM) GetStateSummary(ctx context.Context, summaryHeight uint64) (block.StateSummary, error) { +func (vm *StateSyncableVM) GetStateSummary(ctx context.Context, summaryHeight uint64) (block.StateSummary, error) { if vm.GetStateSummaryF != nil { return vm.GetStateSummaryF(ctx, summaryHeight) } diff --git a/snow/engine/snowman/block/blocktest/vm.go b/snow/engine/snowman/block/blocktest/vm.go index c12c617b0bcf..a05e0c3e1429 100644 --- a/snow/engine/snowman/block/blocktest/vm.go +++ b/snow/engine/snowman/block/blocktest/vm.go @@ -22,12 +22,12 @@ var ( errLastAccepted = errors.New("unexpectedly called LastAccepted") errGetBlockIDAtHeight = errors.New("unexpectedly called GetBlockIDAtHeight") - _ block.ChainVM = (*TestVM)(nil) + _ block.ChainVM = (*VM)(nil) ) -// TestVM is a ChainVM that is useful for testing. -type TestVM struct { - enginetest.TestVM +// VM is a ChainVM that is useful for testing. +type VM struct { + enginetest.VM CantBuildBlock, CantParseBlock, @@ -44,8 +44,8 @@ type TestVM struct { GetBlockIDAtHeightF func(ctx context.Context, height uint64) (ids.ID, error) } -func (vm *TestVM) Default(cant bool) { - vm.TestVM.Default(cant) +func (vm *VM) Default(cant bool) { + vm.VM.Default(cant) vm.CantBuildBlock = cant vm.CantParseBlock = cant @@ -54,7 +54,7 @@ func (vm *TestVM) Default(cant bool) { vm.CantLastAccepted = cant } -func (vm *TestVM) BuildBlock(ctx context.Context) (snowman.Block, error) { +func (vm *VM) BuildBlock(ctx context.Context) (snowman.Block, error) { if vm.BuildBlockF != nil { return vm.BuildBlockF(ctx) } @@ -64,7 +64,7 @@ func (vm *TestVM) BuildBlock(ctx context.Context) (snowman.Block, error) { return nil, errBuildBlock } -func (vm *TestVM) ParseBlock(ctx context.Context, b []byte) (snowman.Block, error) { +func (vm *VM) ParseBlock(ctx context.Context, b []byte) (snowman.Block, error) { if vm.ParseBlockF != nil { return vm.ParseBlockF(ctx, b) } @@ -74,7 +74,7 @@ func (vm *TestVM) ParseBlock(ctx context.Context, b []byte) (snowman.Block, erro return nil, errParseBlock } -func (vm *TestVM) GetBlock(ctx context.Context, id ids.ID) (snowman.Block, error) { +func (vm *VM) GetBlock(ctx context.Context, id ids.ID) (snowman.Block, error) { if vm.GetBlockF != nil { return vm.GetBlockF(ctx, id) } @@ -84,7 +84,7 @@ func (vm *TestVM) GetBlock(ctx context.Context, id ids.ID) (snowman.Block, error return nil, errGetBlock } -func (vm *TestVM) SetPreference(ctx context.Context, id ids.ID) error { +func (vm *VM) SetPreference(ctx context.Context, id ids.ID) error { if vm.SetPreferenceF != nil { return vm.SetPreferenceF(ctx, id) } @@ -94,7 +94,7 @@ func (vm *TestVM) SetPreference(ctx context.Context, id ids.ID) error { return nil } -func (vm *TestVM) LastAccepted(ctx context.Context) (ids.ID, error) { +func (vm *VM) LastAccepted(ctx context.Context) (ids.ID, error) { if vm.LastAcceptedF != nil { return vm.LastAcceptedF(ctx) } @@ -104,7 +104,7 @@ func (vm *TestVM) LastAccepted(ctx context.Context) (ids.ID, error) { return ids.Empty, errLastAccepted } -func (vm *TestVM) GetBlockIDAtHeight(ctx context.Context, height uint64) (ids.ID, error) { +func (vm *VM) GetBlockIDAtHeight(ctx context.Context, height uint64) (ids.ID, error) { if vm.GetBlockIDAtHeightF != nil { return vm.GetBlockIDAtHeightF(ctx, height) } diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 25713771c8ed..e35eb1d9990a 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -35,7 +35,7 @@ import ( var errUnknownBlock = errors.New("unknown block") -func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *blocktest.TestVM) { +func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.Sender, *blocktest.VM) { require := require.New(t) snowCtx := snowtest.Context(t, snowtest.CChainID) @@ -43,8 +43,8 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *block vdrs := validators.NewManager() - sender := &enginetest.SenderTest{} - vm := &blocktest.TestVM{} + sender := &enginetest.Sender{} + vm := &blocktest.VM{} sender.T = t vm.T = t @@ -53,7 +53,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *block vm.Default(true) isBootstrapped := false - bootstrapTracker := &enginetest.BootstrapTrackerTest{ + bootstrapTracker := &enginetest.BootstrapTracker{ T: t, IsBootstrappedF: func() bool { return isBootstrapped @@ -99,7 +99,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *block PeerTracker: peerTracker, Sender: sender, BootstrapTracker: bootstrapTracker, - Timer: &enginetest.TimerTest{}, + Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: memdb.New(), VM: vm, @@ -109,9 +109,9 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.SenderTest, *block func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { require := require.New(t) - sender := &enginetest.SenderTest{T: t} - vm := &blocktest.TestVM{ - TestVM: enginetest.TestVM{T: t}, + sender := &enginetest.Sender{T: t} + vm := &blocktest.VM{ + VM: enginetest.VM{T: t}, } sender.Default(true) @@ -147,8 +147,8 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { StartupTracker: startupTracker, PeerTracker: peerTracker, Sender: sender, - BootstrapTracker: &enginetest.BootstrapTrackerTest{}, - Timer: &enginetest.TimerTest{}, + BootstrapTracker: &enginetest.BootstrapTracker{}, + Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: memdb.New(), VM: vm, @@ -611,8 +611,8 @@ func TestBootstrapNoParseOnNew(t *testing.T) { ctx := snowtest.ConsensusContext(snowCtx) peers := validators.NewManager() - sender := &enginetest.SenderTest{} - vm := &blocktest.TestVM{} + sender := &enginetest.Sender{} + vm := &blocktest.VM{} sender.T = t vm.T = t @@ -621,7 +621,7 @@ func TestBootstrapNoParseOnNew(t *testing.T) { vm.Default(true) isBootstrapped := false - bootstrapTracker := &enginetest.BootstrapTrackerTest{ + bootstrapTracker := &enginetest.BootstrapTracker{ T: t, IsBootstrappedF: func() bool { return isBootstrapped @@ -680,7 +680,7 @@ func TestBootstrapNoParseOnNew(t *testing.T) { PeerTracker: peerTracker, Sender: sender, BootstrapTracker: bootstrapTracker, - Timer: &enginetest.TimerTest{}, + Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: intervalDB, VM: vm, @@ -773,7 +773,7 @@ func TestBootstrapperRollbackOnSetState(t *testing.T) { require.Equal(blks[0].HeightV, bs.startingHeight) } -func initializeVMWithBlockchain(vm *blocktest.TestVM, blocks []*snowmantest.Block) { +func initializeVMWithBlockchain(vm *blocktest.VM, blocks []*snowmantest.Block) { vm.CantSetState = false vm.LastAcceptedF = snowmantest.MakeLastAcceptedBlockF( blocks, diff --git a/snow/engine/snowman/config_test.go b/snow/engine/snowman/config_test.go index 7ee38616ca26..fe13d4f474c1 100644 --- a/snow/engine/snowman/config_test.go +++ b/snow/engine/snowman/config_test.go @@ -20,8 +20,8 @@ func DefaultConfig(t testing.TB) Config { return Config{ Ctx: snowtest.ConsensusContext(ctx), - VM: &blocktest.TestVM{}, - Sender: &enginetest.SenderTest{}, + VM: &blocktest.VM{}, + Sender: &enginetest.Sender{}, Validators: validators.NewManager(), ConnectedValidators: tracker.NewPeers(), Params: snowball.Parameters{ diff --git a/snow/engine/snowman/engine_test.go b/snow/engine/snowman/engine_test.go index 0d68577c5032..2619dcc727b0 100644 --- a/snow/engine/snowman/engine_test.go +++ b/snow/engine/snowman/engine_test.go @@ -65,7 +65,7 @@ func MakeParseBlockF(blks ...[]*snowmantest.Block) func(context.Context, []byte) } } -func setup(t *testing.T, config Config) (ids.NodeID, validators.Manager, *enginetest.SenderTest, *blocktest.TestVM, *Engine) { +func setup(t *testing.T, config Config) (ids.NodeID, validators.Manager, *enginetest.Sender, *blocktest.VM, *Engine) { require := require.New(t) vdr := ids.GenerateTestNodeID() @@ -73,11 +73,11 @@ func setup(t *testing.T, config Config) (ids.NodeID, validators.Manager, *engine require.NoError(config.ConnectedValidators.Connected(context.Background(), vdr, version.CurrentApp)) config.Validators.RegisterSetCallbackListener(config.Ctx.SubnetID, config.ConnectedValidators) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} config.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t config.VM = vm @@ -327,11 +327,11 @@ func TestEngineMultipleQuery(t *testing.T) { require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr2, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -651,11 +651,11 @@ func TestVoteCanceling(t *testing.T) { require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr2, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -720,11 +720,11 @@ func TestEngineNoQuery(t *testing.T) { engCfg := DefaultConfig(t) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t vm.LastAcceptedF = snowmantest.MakeLastAcceptedBlockF( []*snowmantest.Block{snowmantest.Genesis}, @@ -760,11 +760,11 @@ func TestEngineNoRepollQuery(t *testing.T) { engCfg := DefaultConfig(t) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t vm.LastAcceptedF = snowmantest.MakeLastAcceptedBlockF( []*snowmantest.Block{snowmantest.Genesis}, @@ -1401,11 +1401,11 @@ func TestEngineAggressivePolling(t *testing.T) { vdr := ids.GenerateTestNodeID() require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -1488,12 +1488,12 @@ func TestEngineDoubleChit(t *testing.T) { require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr0, nil, ids.Empty, 1)) require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -1575,11 +1575,11 @@ func TestEngineBuildBlockLimit(t *testing.T) { vdr := ids.GenerateTestNodeID() require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -2159,12 +2159,12 @@ func TestEngineApplyAcceptedFrontierInQueryFailed(t *testing.T) { vdr := ids.GenerateTestNodeID() require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -2253,12 +2253,12 @@ func TestEngineRepollsMisconfiguredSubnet(t *testing.T) { vals := validators.NewManager() engCfg.Validators = vals - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} engCfg.Sender = sender sender.Default(true) - vm := &blocktest.TestVM{} + vm := &blocktest.VM{} vm.T = t engCfg.VM = vm @@ -2392,7 +2392,7 @@ func TestEngineVoteStallRegression(t *testing.T) { require.NoError(config.Validators.AddStaker(config.Ctx.SubnetID, nodeID1, nil, ids.Empty, 1)) require.NoError(config.Validators.AddStaker(config.Ctx.SubnetID, nodeID2, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ T: t, SendChitsF: func(context.Context, ids.NodeID, uint32, ids.ID, ids.ID, ids.ID) {}, } @@ -2402,8 +2402,8 @@ func TestEngineVoteStallRegression(t *testing.T) { acceptedChain := snowmantest.BuildDescendants(snowmantest.Genesis, 3) rejectedChain := snowmantest.BuildDescendants(snowmantest.Genesis, 2) - vm := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + vm := &blocktest.VM{ + VM: enginetest.VM{ T: t, InitializeF: func( context.Context, @@ -2612,7 +2612,7 @@ func TestEngineEarlyTerminateVoterRegression(t *testing.T) { nodeID := ids.GenerateTestNodeID() require.NoError(config.Validators.AddStaker(config.Ctx.SubnetID, nodeID, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ T: t, SendChitsF: func(context.Context, ids.NodeID, uint32, ids.ID, ids.ID, ids.ID) {}, } @@ -2620,8 +2620,8 @@ func TestEngineEarlyTerminateVoterRegression(t *testing.T) { config.Sender = sender chain := snowmantest.BuildDescendants(snowmantest.Genesis, 3) - vm := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + vm := &blocktest.VM{ + VM: enginetest.VM{ T: t, InitializeF: func( context.Context, @@ -2757,7 +2757,7 @@ func TestEngineRegistersInvalidVoterDependencyRegression(t *testing.T) { nodeID := ids.GenerateTestNodeID() require.NoError(config.Validators.AddStaker(config.Ctx.SubnetID, nodeID, nil, ids.Empty, 1)) - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ T: t, SendChitsF: func(context.Context, ids.NodeID, uint32, ids.ID, ids.ID, ids.ID) {}, } @@ -2770,8 +2770,8 @@ func TestEngineRegistersInvalidVoterDependencyRegression(t *testing.T) { ) rejectedChain[1].VerifyV = errInvalid - vm := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + vm := &blocktest.VM{ + VM: enginetest.VM{ T: t, InitializeF: func( context.Context, diff --git a/snow/engine/snowman/getter/getter_test.go b/snow/engine/snowman/getter/getter_test.go index dbdd2a2f17f7..4454117f345e 100644 --- a/snow/engine/snowman/getter/getter_test.go +++ b/snow/engine/snowman/getter/getter_test.go @@ -27,19 +27,19 @@ import ( var errUnknownBlock = errors.New("unknown block") type StateSyncEnabledMock struct { - *blocktest.TestVM + *blocktest.VM *block.MockStateSyncableVM } -func newTest(t *testing.T) (common.AllGetsServer, StateSyncEnabledMock, *enginetest.SenderTest) { +func newTest(t *testing.T) (common.AllGetsServer, StateSyncEnabledMock, *enginetest.Sender) { ctrl := gomock.NewController(t) vm := StateSyncEnabledMock{ - TestVM: &blocktest.TestVM{}, + VM: &blocktest.VM{}, MockStateSyncableVM: block.NewMockStateSyncableVM(ctrl), } - sender := &enginetest.SenderTest{ + sender := &enginetest.Sender{ T: t, } sender.Default(true) diff --git a/snow/engine/snowman/syncer/state_syncer_test.go b/snow/engine/snowman/syncer/state_syncer_test.go index 54a4608dce79..fd062cb2d8b1 100644 --- a/snow/engine/snowman/syncer/state_syncer_test.go +++ b/snow/engine/snowman/syncer/state_syncer_test.go @@ -40,11 +40,11 @@ func TestStateSyncerIsEnabledIfVMSupportsStateSyncing(t *testing.T) { // Build state syncer snowCtx := snowtest.Context(t, snowtest.CChainID) ctx := snowtest.ConsensusContext(snowCtx) - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} // Non state syncableVM case - nonStateSyncableVM := &blocktest.TestVM{ - TestVM: enginetest.TestVM{T: t}, + nonStateSyncableVM := &blocktest.VM{ + VM: enginetest.VM{T: t}, } dummyGetter, err := getter.New( nonStateSyncableVM, @@ -68,10 +68,10 @@ func TestStateSyncerIsEnabledIfVMSupportsStateSyncing(t *testing.T) { // State syncableVM case fullVM := &fullVM{ - TestVM: &blocktest.TestVM{ - TestVM: enginetest.TestVM{T: t}, + VM: &blocktest.VM{ + VM: enginetest.VM{T: t}, }, - TestStateSyncableVM: &blocktest.TestStateSyncableVM{ + StateSyncableVM: &blocktest.StateSyncableVM{ T: t, }, } @@ -166,7 +166,7 @@ func TestStateSyncLocalSummaryIsIncludedAmongFrontiersIfAvailable(t *testing.T) syncer, fullVM, _ := buildTestsObjects(t, ctx, startup, beacons, (totalWeight+1)/2) // mock VM to simulate a valid summary is returned - localSummary := &blocktest.TestStateSummary{ + localSummary := &blocktest.StateSummary{ HeightV: 2000, IDV: summaryID, BytesV: summaryBytes, @@ -294,7 +294,7 @@ func TestUnRequestedStateSummaryFrontiersAreDropped(t *testing.T) { // mock VM to simulate a valid summary is returned fullVM.CantParseStateSummary = true fullVM.ParseStateSummaryF = func(_ context.Context, summaryBytes []byte) (block.StateSummary, error) { - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -480,7 +480,7 @@ func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { // mock VM to simulate a valid but late summary is returned fullVM.CantParseStateSummary = true fullVM.ParseStateSummaryF = func(_ context.Context, summaryBytes []byte) (block.StateSummary, error) { - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -529,7 +529,7 @@ func TestStateSyncIsRestartedIfTooManyFrontierSeedersTimeout(t *testing.T) { fullVM.ParseStateSummaryF = func(_ context.Context, b []byte) (block.StateSummary, error) { switch { case bytes.Equal(b, summaryBytes): - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -617,7 +617,7 @@ func TestVoteRequestsAreSentAsAllFrontierBeaconsResponded(t *testing.T) { fullVM.CantParseStateSummary = true fullVM.ParseStateSummaryF = func(_ context.Context, b []byte) (block.StateSummary, error) { require.Equal(summaryBytes, b) - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -687,7 +687,7 @@ func TestUnRequestedVotesAreDropped(t *testing.T) { // mock VM to simulate a valid summary is returned fullVM.CantParseStateSummary = true fullVM.ParseStateSummaryF = func(_ context.Context, summaryBytes []byte) (block.StateSummary, error) { - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -804,7 +804,7 @@ func TestVotesForUnknownSummariesAreDropped(t *testing.T) { // mock VM to simulate a valid summary is returned fullVM.CantParseStateSummary = true fullVM.ParseStateSummaryF = func(_ context.Context, summaryBytes []byte) (block.StateSummary, error) { - return &blocktest.TestStateSummary{ + return &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, @@ -906,13 +906,13 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { } // mock VM to simulate a valid summary is returned - summary := &blocktest.TestStateSummary{ + summary := &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, T: t, } - minoritySummary := &blocktest.TestStateSummary{ + minoritySummary := &blocktest.StateSummary{ HeightV: minorityKey, IDV: minoritySummaryID, BytesV: minoritySummaryBytes, @@ -1051,7 +1051,7 @@ func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { } // mock VM to simulate a valid summary is returned - minoritySummary := &blocktest.TestStateSummary{ + minoritySummary := &blocktest.StateSummary{ HeightV: minorityKey, IDV: minoritySummaryID, BytesV: minoritySummaryBytes, @@ -1157,13 +1157,13 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. } // mock VM to simulate a valid minoritySummary1 is returned - minoritySummary1 := &blocktest.TestStateSummary{ + minoritySummary1 := &blocktest.StateSummary{ HeightV: key, IDV: summaryID, BytesV: summaryBytes, T: t, } - minoritySummary2 := &blocktest.TestStateSummary{ + minoritySummary2 := &blocktest.StateSummary{ HeightV: minorityKey, IDV: minoritySummaryID, BytesV: minoritySummaryBytes, diff --git a/snow/engine/snowman/syncer/utils_test.go b/snow/engine/snowman/syncer/utils_test.go index 7a2ea5c28116..4cd6e58d840e 100644 --- a/snow/engine/snowman/syncer/utils_test.go +++ b/snow/engine/snowman/syncer/utils_test.go @@ -54,8 +54,8 @@ func init() { } type fullVM struct { - *blocktest.TestVM - *blocktest.TestStateSyncableVM + *blocktest.VM + *blocktest.StateSyncableVM } func buildTestPeers(t *testing.T, subnetID ids.ID) validators.Manager { @@ -78,19 +78,19 @@ func buildTestsObjects( ) ( *stateSyncer, *fullVM, - *enginetest.SenderTest, + *enginetest.Sender, ) { require := require.New(t) fullVM := &fullVM{ - TestVM: &blocktest.TestVM{ - TestVM: enginetest.TestVM{T: t}, + VM: &blocktest.VM{ + VM: enginetest.VM{T: t}, }, - TestStateSyncableVM: &blocktest.TestStateSyncableVM{ + StateSyncableVM: &blocktest.StateSyncableVM{ T: t, }, } - sender := &enginetest.SenderTest{T: t} + sender := &enginetest.Sender{T: t} dummyGetter, err := getter.New( fullVM, sender, diff --git a/snow/networking/handler/handler_test.go b/snow/networking/handler/handler_test.go index 3e145f39778a..faefa616ab91 100644 --- a/snow/networking/handler/handler_test.go +++ b/snow/networking/handler/handler_test.go @@ -83,8 +83,8 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { require.NoError(err) handler := handlerIntf.(*handler) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -195,8 +195,8 @@ func TestHandlerClosesOnError(t *testing.T) { closed <- struct{}{} }) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -208,7 +208,7 @@ func TestHandlerClosesOnError(t *testing.T) { return errFatal } - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx @@ -295,8 +295,8 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { handler.clock.Set(time.Now()) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -383,14 +383,14 @@ func TestHandlerDispatchInternal(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } bootstrapper.Default(false) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx @@ -469,14 +469,14 @@ func TestHandlerSubnetConnector(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } bootstrapper.Default(false) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx @@ -651,14 +651,14 @@ func TestDynamicEngineTypeDispatch(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } bootstrapper.Default(false) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx diff --git a/snow/networking/handler/health_test.go b/snow/networking/handler/health_test.go index 57c1b48b2f89..f37b5f551a6d 100644 --- a/snow/networking/handler/health_test.go +++ b/snow/networking/handler/health_test.go @@ -97,14 +97,14 @@ func TestHealthCheckSubnet(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } bootstrapper.Default(false) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go index f27931192200..dadbec2d93d7 100644 --- a/snow/networking/router/chain_router_test.go +++ b/snow/networking/router/chain_router_test.go @@ -118,8 +118,8 @@ func TestShutdown(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -137,7 +137,7 @@ func TestShutdown(t *testing.T) { } bootstrapper.HaltF = func(context.Context) {} - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(true) engine.CantGossip = false engine.ContextF = func() *snow.ConsensusContext { @@ -244,7 +244,7 @@ func TestConnectedAfterShutdownErrorLogRegression(t *testing.T) { ) require.NoError(err) - engine := enginetest.EngineTest{ + engine := enginetest.Engine{ T: t, StartF: func(context.Context, uint32) error { return nil @@ -263,9 +263,9 @@ func TestConnectedAfterShutdownErrorLogRegression(t *testing.T) { engine.Default(true) engine.CantGossip = false - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: engine, - CantClear: true, + bootstrapper := &enginetest.Bootstrapper{ + Engine: engine, + CantClear: true, } h.SetEngineManager(&handler.EngineManager{ @@ -378,8 +378,8 @@ func TestShutdownTimesOut(t *testing.T) { require.NoError(err) bootstrapFinished := make(chan struct{}, 1) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -399,7 +399,7 @@ func TestShutdownTimesOut(t *testing.T) { return nil } - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx @@ -547,8 +547,8 @@ func TestRouterTimeout(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -1130,8 +1130,8 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -1149,7 +1149,7 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { State: snow.Bootstrapping, // assumed bootstrapping is ongoing }) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.ContextF = func() *snow.ConsensusContext { return ctx } @@ -1411,8 +1411,8 @@ func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -1429,7 +1429,7 @@ func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) { Type: engineType, State: snow.Bootstrapping, // assumed bootstrapping is ongoing }) - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.ContextF = func() *snow.ConsensusContext { return ctx } @@ -1680,7 +1680,7 @@ func TestCrossChainAppRequest(t *testing.T) { } } -func newChainRouterTest(t *testing.T) (*ChainRouter, *enginetest.EngineTest) { +func newChainRouterTest(t *testing.T) (*ChainRouter, *enginetest.Engine) { // Create a timeout manager tm, err := timeout.NewManager( &timer.AdaptiveTimeoutConfig{ @@ -1751,8 +1751,8 @@ func newChainRouterTest(t *testing.T) (*ChainRouter, *enginetest.EngineTest) { ) require.NoError(t, err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -1761,7 +1761,7 @@ func newChainRouterTest(t *testing.T) (*ChainRouter, *enginetest.EngineTest) { return ctx } - engine := &enginetest.EngineTest{T: t} + engine := &enginetest.Engine{T: t} engine.Default(false) engine.ContextF = func() *snow.ConsensusContext { return ctx diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index c7b1f5530054..215de3c56fb5 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -92,7 +92,7 @@ func TestTimeout(t *testing.T) { prometheus.NewRegistry(), )) - externalSender := &sendertest.ExternalSenderTest{TB: t} + externalSender := &sendertest.External{TB: t} externalSender.Default(false) sender, err := New( @@ -140,8 +140,8 @@ func TestTimeout(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -369,7 +369,7 @@ func TestReliableMessages(t *testing.T) { prometheus.NewRegistry(), )) - externalSender := &sendertest.ExternalSenderTest{TB: t} + externalSender := &sendertest.External{TB: t} externalSender.Default(false) sender, err := New( @@ -417,8 +417,8 @@ func TestReliableMessages(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } @@ -526,7 +526,7 @@ func TestReliableMessagesToMyself(t *testing.T) { prometheus.NewRegistry(), )) - externalSender := &sendertest.ExternalSenderTest{TB: t} + externalSender := &sendertest.External{TB: t} externalSender.Default(false) sender, err := New( @@ -574,8 +574,8 @@ func TestReliableMessagesToMyself(t *testing.T) { ) require.NoError(err) - bootstrapper := &enginetest.BootstrapperTest{ - EngineTest: enginetest.EngineTest{ + bootstrapper := &enginetest.Bootstrapper{ + Engine: enginetest.Engine{ T: t, }, } diff --git a/snow/networking/sender/sendertest/external.go b/snow/networking/sender/sendertest/external.go index 56e5d37f31aa..bf4ea670eab6 100644 --- a/snow/networking/sender/sendertest/external.go +++ b/snow/networking/sender/sendertest/external.go @@ -16,13 +16,13 @@ import ( ) var ( - _ sender.ExternalSender = (*ExternalSenderTest)(nil) + _ sender.ExternalSender = (*External)(nil) errSend = errors.New("unexpectedly called Send") ) -// ExternalSenderTest is a test sender -type ExternalSenderTest struct { +// External is a test sender +type External struct { TB testing.TB CantSend bool @@ -31,11 +31,11 @@ type ExternalSenderTest struct { } // Default set the default callable value to [cant] -func (s *ExternalSenderTest) Default(cant bool) { +func (s *External) Default(cant bool) { s.CantSend = cant } -func (s *ExternalSenderTest) Send( +func (s *External) Send( msg message.OutboundMessage, config common.SendConfig, subnetID ids.ID, diff --git a/snow/snowtest/context.go b/snow/snowtest/context.go index 65e1eb87cc74..779636520376 100644 --- a/snow/snowtest/context.go +++ b/snow/snowtest/context.go @@ -63,7 +63,7 @@ func Context(tb testing.TB, chainID ids.ID) *snow.Context { require.NoError(aliaser.Alias(CChainID, "C")) require.NoError(aliaser.Alias(CChainID, CChainID.String())) - validatorState := &validatorstest.TestState{ + validatorState := &validatorstest.State{ GetSubnetIDF: func(_ context.Context, chainID ids.ID) (ids.ID, error) { subnetID, ok := map[ids.ID]ids.ID{ constants.PlatformChainID: constants.PrimaryNetworkID, diff --git a/snow/validators/validatorstest/state.go b/snow/validators/validatorstest/state.go index bc6cef2c667a..64abbc515e5b 100644 --- a/snow/validators/validatorstest/state.go +++ b/snow/validators/validatorstest/state.go @@ -21,9 +21,15 @@ var ( errGetValidatorSet = errors.New("unexpectedly called GetValidatorSet") ) -var _ validators.State = (*TestState)(nil) +var _ validators.State = (*State)(nil) -type TestState struct { +// TestState is an alias for State because ava-labs/coreth uses the original +// identifier and this change would otherwise break the build. +// +// Deprecated: use [State]. +type TestState = State + +type State struct { T testing.TB CantGetMinimumHeight, @@ -37,7 +43,7 @@ type TestState struct { GetValidatorSetF func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) } -func (vm *TestState) GetMinimumHeight(ctx context.Context) (uint64, error) { +func (vm *State) GetMinimumHeight(ctx context.Context) (uint64, error) { if vm.GetMinimumHeightF != nil { return vm.GetMinimumHeightF(ctx) } @@ -47,7 +53,7 @@ func (vm *TestState) GetMinimumHeight(ctx context.Context) (uint64, error) { return 0, errMinimumHeight } -func (vm *TestState) GetCurrentHeight(ctx context.Context) (uint64, error) { +func (vm *State) GetCurrentHeight(ctx context.Context) (uint64, error) { if vm.GetCurrentHeightF != nil { return vm.GetCurrentHeightF(ctx) } @@ -57,7 +63,7 @@ func (vm *TestState) GetCurrentHeight(ctx context.Context) (uint64, error) { return 0, errCurrentHeight } -func (vm *TestState) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { +func (vm *State) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { if vm.GetSubnetIDF != nil { return vm.GetSubnetIDF(ctx, chainID) } @@ -67,7 +73,7 @@ func (vm *TestState) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, e return ids.Empty, errSubnetID } -func (vm *TestState) GetValidatorSet( +func (vm *State) GetValidatorSet( ctx context.Context, height uint64, subnetID ids.ID, diff --git a/vms/avm/environment_test.go b/vms/avm/environment_test.go index 32db5ae00fd1..75cae23f347a 100644 --- a/vms/avm/environment_test.go +++ b/vms/avm/environment_test.go @@ -193,7 +193,7 @@ func setup(tb testing.TB, c *envConfig) *environment { }, c.additionalFxs..., ), - &enginetest.SenderTest{}, + &enginetest.Sender{}, )) stopVertexID := ids.GenerateTestID() diff --git a/vms/avm/network/network_test.go b/vms/avm/network/network_test.go index 9313209113eb..4d9b68f71f22 100644 --- a/vms/avm/network/network_test.go +++ b/vms/avm/network/network_test.go @@ -179,7 +179,7 @@ func TestNetworkIssueTxFromRPC(t *testing.T) { logging.NoLog{}, ids.EmptyNodeID, ids.Empty, - &validatorstest.TestState{ + &validatorstest.State{ GetCurrentHeightF: func(context.Context) (uint64, error) { return 0, nil }, @@ -273,7 +273,7 @@ func TestNetworkIssueTxFromRPCWithoutVerification(t *testing.T) { logging.NoLog{}, ids.EmptyNodeID, ids.Empty, - &validatorstest.TestState{ + &validatorstest.State{ GetCurrentHeightF: func(context.Context) (uint64, error) { return 0, nil }, diff --git a/vms/platformvm/block/builder/helpers_test.go b/vms/platformvm/block/builder/helpers_test.go index 7b3575e237e2..28220e60a1e7 100644 --- a/vms/platformvm/block/builder/helpers_test.go +++ b/vms/platformvm/block/builder/helpers_test.go @@ -108,7 +108,7 @@ type environment struct { blkManager blockexecutor.Manager mempool mempool.Mempool network *network.Network - sender *enginetest.SenderTest + sender *enginetest.Sender isBootstrapped *utils.Atomic[bool] config *config.Config @@ -169,7 +169,7 @@ func newEnvironment(t *testing.T, f fork) *environment { //nolint:unparam } registerer := prometheus.NewRegistry() - res.sender = &enginetest.SenderTest{T: t} + res.sender = &enginetest.Sender{T: t} res.sender.SendAppGossipF = func(context.Context, common.SendConfig, []byte) error { return nil } diff --git a/vms/platformvm/block/executor/helpers_test.go b/vms/platformvm/block/executor/helpers_test.go index d8c3b6bf510b..79f2e78eb19e 100644 --- a/vms/platformvm/block/executor/helpers_test.go +++ b/vms/platformvm/block/executor/helpers_test.go @@ -120,7 +120,7 @@ type test struct { type environment struct { blkManager Manager mempool mempool.Mempool - sender *enginetest.SenderTest + sender *enginetest.Sender isBootstrapped *utils.Atomic[bool] config *config.Config @@ -192,7 +192,7 @@ func newEnvironment(t *testing.T, ctrl *gomock.Controller, f fork) *environment } registerer := prometheus.NewRegistry() - res.sender = &enginetest.SenderTest{T: t} + res.sender = &enginetest.Sender{T: t} metrics := metrics.Noop diff --git a/vms/platformvm/validator_set_property_test.go b/vms/platformvm/validator_set_property_test.go index 9149d4023b7d..3411fb7e1d72 100644 --- a/vms/platformvm/validator_set_property_test.go +++ b/vms/platformvm/validator_set_property_test.go @@ -685,7 +685,7 @@ func buildVM(t *testing.T) (*VM, ids.ID, error) { ctx.Lock.Lock() defer ctx.Lock.Unlock() - appSender := &enginetest.SenderTest{} + appSender := &enginetest.Sender{} appSender.CantSendAppGossip = true appSender.SendAppGossipF = func(context.Context, common.SendConfig, []byte) error { return nil diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 11b0ebd38214..08d51ce4f135 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -287,7 +287,7 @@ func defaultVM(t *testing.T, f fork) (*VM, *txstest.WalletFactory, database.Data ctx.Lock.Lock() defer ctx.Lock.Unlock() _, genesisBytes := defaultGenesis(t, ctx.AVAXAssetID) - appSender := &enginetest.SenderTest{} + appSender := &enginetest.Sender{} appSender.CantSendAppGossip = true appSender.SendAppGossipF = func(context.Context, common.SendConfig, []byte) error { return nil @@ -1461,7 +1461,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { prometheus.NewRegistry(), )) - externalSender := &sendertest.ExternalSenderTest{TB: t} + externalSender := &sendertest.External{TB: t} externalSender.Default(true) // Passes messages from the consensus engine to the network @@ -1478,7 +1478,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { require.NoError(err) isBootstrapped := false - bootstrapTracker := &enginetest.BootstrapTrackerTest{ + bootstrapTracker := &enginetest.BootstrapTracker{ T: t, IsBootstrappedF: func() bool { return isBootstrapped @@ -2030,7 +2030,7 @@ func TestUptimeDisallowedAfterNeverConnecting(t *testing.T) { ctx.SharedMemory = m.NewSharedMemory(ctx.ChainID) msgChan := make(chan common.Message, 1) - appSender := &enginetest.SenderTest{T: t} + appSender := &enginetest.Sender{T: t} require.NoError(vm.Initialize( context.Background(), ctx, diff --git a/vms/platformvm/warp/validator_test.go b/vms/platformvm/warp/validator_test.go index be918f339c74..4e9d64cc3f67 100644 --- a/vms/platformvm/warp/validator_test.go +++ b/vms/platformvm/warp/validator_test.go @@ -330,7 +330,7 @@ func BenchmarkGetCanonicalValidatorSet(b *testing.B) { validator := getValidatorOutputs[i] getValidatorsOutput[validator.NodeID] = validator } - validatorState := &validatorstest.TestState{ + validatorState := &validatorstest.State{ GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { return getValidatorsOutput, nil }, diff --git a/vms/proposervm/batched_vm_test.go b/vms/proposervm/batched_vm_test.go index c61c6678842c..d8b50dc94e67 100644 --- a/vms/proposervm/batched_vm_test.go +++ b/vms/proposervm/batched_vm_test.go @@ -600,7 +600,7 @@ func TestBatchedParseBlockParallel(t *testing.T) { vm := VM{ ctx: &snow.Context{ChainID: chainID}, - ChainVM: &blocktest.TestVM{ + ChainVM: &blocktest.VM{ ParseBlockF: func(_ context.Context, rawBlock []byte) (snowman.Block, error) { return &snowmantest.Block{BytesV: rawBlock}, nil }, @@ -909,8 +909,8 @@ func TestBatchedParseBlockAtSnomanPlusPlusFork(t *testing.T) { } type TestRemoteProposerVM struct { - *blocktest.TestBatchedVM - *blocktest.TestVM + *blocktest.BatchedVM + *blocktest.VM } func initTestRemoteProposerVM( @@ -925,11 +925,11 @@ func initTestRemoteProposerVM( initialState := []byte("genesis state") coreVM := TestRemoteProposerVM{ - TestVM: &blocktest.TestVM{}, - TestBatchedVM: &blocktest.TestBatchedVM{}, + VM: &blocktest.VM{}, + BatchedVM: &blocktest.BatchedVM{}, } - coreVM.TestVM.T = t - coreVM.TestBatchedVM.T = t + coreVM.VM.T = t + coreVM.BatchedVM.T = t coreVM.InitializeF = func( context.Context, @@ -980,7 +980,7 @@ func initTestRemoteProposerVM( }, ) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, } valState.GetMinimumHeightF = func(context.Context) (uint64, error) { diff --git a/vms/proposervm/proposer/windower_test.go b/vms/proposervm/proposer/windower_test.go index ea426b916475..2d90efaa5e3a 100644 --- a/vms/proposervm/proposer/windower_test.go +++ b/vms/proposervm/proposer/windower_test.go @@ -58,7 +58,7 @@ func TestWindowerRepeatedValidator(t *testing.T) { nonValidatorID = ids.GenerateTestNodeID() ) - vdrState := &validatorstest.TestState{ + vdrState := &validatorstest.State{ T: t, GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { return map[ids.NodeID]*validators.GetValidatorOutput{ @@ -445,13 +445,13 @@ func TestProposerDistribution(t *testing.T) { require.Less(maxSTDDeviation, 3.) } -func makeValidators(t testing.TB, count int) ([]ids.NodeID, *validatorstest.TestState) { +func makeValidators(t testing.TB, count int) ([]ids.NodeID, *validatorstest.State) { validatorIDs := make([]ids.NodeID, count) for i := range validatorIDs { validatorIDs[i] = ids.BuildTestNodeID([]byte{byte(i) + 1}) } - vdrState := &validatorstest.TestState{ + vdrState := &validatorstest.State{ T: t, GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { vdrs := make(map[ids.NodeID]*validators.GetValidatorOutput, MaxVerifyWindows) diff --git a/vms/proposervm/state_syncable_vm_test.go b/vms/proposervm/state_syncable_vm_test.go index 9ef4dc3066f5..1017c35d8a40 100644 --- a/vms/proposervm/state_syncable_vm_test.go +++ b/vms/proposervm/state_syncable_vm_test.go @@ -33,12 +33,12 @@ func helperBuildStateSyncTestObjects(t *testing.T) (*fullVM, *VM) { require := require.New(t) innerVM := &fullVM{ - TestVM: &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + VM: &blocktest.VM{ + VM: enginetest.VM{ T: t, }, }, - TestStateSyncableVM: &blocktest.TestStateSyncableVM{ + StateSyncableVM: &blocktest.StateSyncableVM{ T: t, }, } @@ -125,7 +125,7 @@ func TestStateSyncGetOngoingSyncStateSummary(t *testing.T) { require.NoError(vm.Shutdown(context.Background())) }() - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: uint64(2022), BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, @@ -208,7 +208,7 @@ func TestStateSyncGetLastStateSummary(t *testing.T) { require.NoError(vm.Shutdown(context.Background())) }() - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: uint64(2022), BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, @@ -292,7 +292,7 @@ func TestStateSyncGetStateSummary(t *testing.T) { }() reqHeight := uint64(1969) - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: reqHeight, BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, @@ -377,7 +377,7 @@ func TestParseStateSummary(t *testing.T) { }() reqHeight := uint64(1969) - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: reqHeight, BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, @@ -454,7 +454,7 @@ func TestStateSummaryAccept(t *testing.T) { }() reqHeight := uint64(1969) - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: reqHeight, BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, @@ -521,7 +521,7 @@ func TestStateSummaryAcceptOlderBlock(t *testing.T) { }() reqHeight := uint64(1969) - innerSummary := &blocktest.TestStateSummary{ + innerSummary := &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: reqHeight, BytesV: []byte{'i', 'n', 'n', 'e', 'r'}, diff --git a/vms/proposervm/vm_test.go b/vms/proposervm/vm_test.go index 1b676d67e903..e5d5f142281c 100644 --- a/vms/proposervm/vm_test.go +++ b/vms/proposervm/vm_test.go @@ -46,8 +46,8 @@ var ( ) type fullVM struct { - *blocktest.TestVM - *blocktest.TestStateSyncableVM + *blocktest.VM + *blocktest.StateSyncableVM } var ( @@ -82,7 +82,7 @@ func initTestProposerVM( minPChainHeight uint64, ) ( *fullVM, - *validatorstest.TestState, + *validatorstest.State, *VM, database.Database, ) { @@ -90,12 +90,12 @@ func initTestProposerVM( initialState := []byte("genesis state") coreVM := &fullVM{ - TestVM: &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + VM: &blocktest.VM{ + VM: enginetest.VM{ T: t, }, }, - TestStateSyncableVM: &blocktest.TestStateSyncableVM{ + StateSyncableVM: &blocktest.StateSyncableVM{ T: t, }, } @@ -142,7 +142,7 @@ func initTestProposerVM( }, ) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, } valState.GetMinimumHeightF = func(context.Context) (uint64, error) { @@ -782,7 +782,7 @@ func TestPreFork_SetPreference(t *testing.T) { func TestExpiredBuildBlock(t *testing.T) { require := require.New(t) - coreVM := &blocktest.TestVM{} + coreVM := &blocktest.VM{} coreVM.T = t coreVM.LastAcceptedF = snowmantest.MakeLastAcceptedBlockF( @@ -821,7 +821,7 @@ func TestExpiredBuildBlock(t *testing.T) { }, ) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, } valState.GetMinimumHeightF = func(context.Context) (uint64, error) { @@ -1065,7 +1065,7 @@ func TestInnerBlockDeduplication(t *testing.T) { func TestInnerVMRollback(t *testing.T) { require := require.New(t) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, GetCurrentHeightF: func(context.Context) (uint64, error) { return defaultPChainHeight, nil @@ -1081,8 +1081,8 @@ func TestInnerVMRollback(t *testing.T) { }, } - coreVM := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + coreVM := &blocktest.VM{ + VM: enginetest.VM{ T: t, InitializeF: func( context.Context, @@ -1559,8 +1559,8 @@ func TestRejectedHeightNotIndexed(t *testing.T) { coreHeights := []ids.ID{snowmantest.GenesisID} initialState := []byte("genesis state") - coreVM := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + coreVM := &blocktest.VM{ + VM: enginetest.VM{ T: t, }, GetBlockIDAtHeightF: func(_ context.Context, height uint64) (ids.ID, error) { @@ -1613,7 +1613,7 @@ func TestRejectedHeightNotIndexed(t *testing.T) { }, ) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, } valState.GetMinimumHeightF = func(context.Context) (uint64, error) { @@ -1732,8 +1732,8 @@ func TestRejectedOptionHeightNotIndexed(t *testing.T) { coreHeights := []ids.ID{snowmantest.GenesisID} initialState := []byte("genesis state") - coreVM := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + coreVM := &blocktest.VM{ + VM: enginetest.VM{ T: t, }, GetBlockIDAtHeightF: func(_ context.Context, height uint64) (ids.ID, error) { @@ -1786,7 +1786,7 @@ func TestRejectedOptionHeightNotIndexed(t *testing.T) { }, ) - valState := &validatorstest.TestState{ + valState := &validatorstest.State{ T: t, } valState.GetMinimumHeightF = func(context.Context) (uint64, error) { @@ -2171,8 +2171,8 @@ func TestHistoricalBlockDeletion(t *testing.T) { currentHeight := uint64(0) initialState := []byte("genesis state") - coreVM := &blocktest.TestVM{ - TestVM: enginetest.TestVM{ + coreVM := &blocktest.VM{ + VM: enginetest.VM{ T: t, InitializeF: func(context.Context, *snow.Context, database.Database, []byte, []byte, []byte, chan<- common.Message, []*common.Fx, common.AppSender) error { return nil @@ -2207,7 +2207,7 @@ func TestHistoricalBlockDeletion(t *testing.T) { ctx := snowtest.Context(t, snowtest.CChainID) ctx.NodeID = ids.NodeIDFromCert(pTestCert) - ctx.ValidatorState = &validatorstest.TestState{ + ctx.ValidatorState = &validatorstest.State{ T: t, GetMinimumHeightF: func(context.Context) (uint64, error) { return snowmantest.GenesisHeight, nil @@ -2482,7 +2482,7 @@ func TestGetPostDurangoSlotTimeWithNoValidators(t *testing.T) { } func TestLocalParse(t *testing.T) { - innerVM := &blocktest.TestVM{ + innerVM := &blocktest.VM{ ParseBlockF: func(_ context.Context, rawBlock []byte) (snowman.Block, error) { return &snowmantest.Block{BytesV: rawBlock}, nil }, diff --git a/vms/rpcchainvm/state_syncable_vm_test.go b/vms/rpcchainvm/state_syncable_vm_test.go index 347fb78cc4e5..f987424cdc36 100644 --- a/vms/rpcchainvm/state_syncable_vm_test.go +++ b/vms/rpcchainvm/state_syncable_vm_test.go @@ -35,7 +35,7 @@ var ( SummaryHeight = uint64(2022) // a summary to be returned in some UTs - mockedSummary = &blocktest.TestStateSummary{ + mockedSummary = &blocktest.StateSummary{ IDV: ids.ID{'s', 'u', 'm', 'm', 'a', 'r', 'y', 'I', 'D'}, HeightV: SummaryHeight, BytesV: []byte("summary"), diff --git a/wallet/chain/p/builder_test.go b/wallet/chain/p/builder_test.go index 4fa86905a00c..67d51eaa7fcd 100644 --- a/wallet/chain/p/builder_test.go +++ b/wallet/chain/p/builder_test.go @@ -60,7 +60,7 @@ func TestBaseTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) backend = NewBackend(testContext, chainUTXOs, nil) @@ -107,7 +107,7 @@ func TestAddSubnetValidatorTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -165,7 +165,7 @@ func TestRemoveSubnetValidatorTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -217,7 +217,7 @@ func TestCreateChainTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -277,7 +277,7 @@ func TestCreateSubnetTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -326,7 +326,7 @@ func TestTransferSubnetOwnershipTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -380,7 +380,7 @@ func TestImportTx(t *testing.T) { utxos = makeTestUTXOs(utxosKey) sourceChainID = ids.GenerateTestID() importedUTXOs = utxos[:1] - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, sourceChainID: importedUTXOs, }) @@ -429,7 +429,7 @@ func TestExportTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) backend = NewBackend(testContext, chainUTXOs, nil) @@ -480,7 +480,7 @@ func TestTransformSubnetTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) @@ -576,7 +576,7 @@ func TestAddPermissionlessValidatorTx(t *testing.T) { makeUTXO(1 * units.NanoAvax), // small UTXO makeUTXO(9 * units.Avax), // large UTXO } - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) backend = NewBackend(testContext, chainUTXOs, nil) @@ -655,7 +655,7 @@ func TestAddPermissionlessDelegatorTx(t *testing.T) { // backend utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) - chainUTXOs = utxotest.NewDeterministicChainUTXOs(require, map[ids.ID][]*avax.UTXO{ + chainUTXOs = utxotest.NewDeterministicChainUTXOs(t, map[ids.ID][]*avax.UTXO{ constants.PlatformChainID: utxos, }) backend = NewBackend(testContext, chainUTXOs, nil) diff --git a/wallet/chain/x/builder_test.go b/wallet/chain/x/builder_test.go index 0fe9cd1e6caa..7ed1b425be12 100644 --- a/wallet/chain/x/builder_test.go +++ b/wallet/chain/x/builder_test.go @@ -52,7 +52,7 @@ func TestBaseTx(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -101,7 +101,7 @@ func TestCreateAssetTx(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -190,7 +190,7 @@ func TestMintNFTOperation(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -235,7 +235,7 @@ func TestMintFTOperation(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -282,7 +282,7 @@ func TestMintPropertyOperation(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -325,7 +325,7 @@ func TestBurnPropertyOperation(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, @@ -363,7 +363,7 @@ func TestImportTx(t *testing.T) { sourceChainID = ids.GenerateTestID() importedUTXOs = utxos[:1] genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, sourceChainID: importedUTXOs, @@ -413,7 +413,7 @@ func TestExportTx(t *testing.T) { utxosKey = testKeys[1] utxos = makeTestUTXOs(utxosKey) genericBackend = utxotest.NewDeterministicChainUTXOs( - require, + t, map[ids.ID][]*avax.UTXO{ xChainID: utxos, }, diff --git a/wallet/subnet/primary/common/utxotest/utxotest.go b/wallet/subnet/primary/common/utxotest/utxotest.go index 29436ac55eaf..172916fbd24c 100644 --- a/wallet/subnet/primary/common/utxotest/utxotest.go +++ b/wallet/subnet/primary/common/utxotest/utxotest.go @@ -6,6 +6,7 @@ package utxotest import ( "context" "slices" + "testing" "github.com/stretchr/testify/require" @@ -15,12 +16,12 @@ import ( "github.com/ava-labs/avalanchego/wallet/subnet/primary/common" ) -func NewDeterministicChainUTXOs(require *require.Assertions, utxoSets map[ids.ID][]*avax.UTXO) *DeterministicChainUTXOs { +func NewDeterministicChainUTXOs(t *testing.T, utxoSets map[ids.ID][]*avax.UTXO) *DeterministicChainUTXOs { globalUTXOs := common.NewUTXOs() for subnetID, utxos := range utxoSets { for _, utxo := range utxos { require.NoError( - globalUTXOs.AddUTXO(context.Background(), subnetID, constants.PlatformChainID, utxo), + t, globalUTXOs.AddUTXO(context.Background(), subnetID, constants.PlatformChainID, utxo), ) } }