From 602627952c729c4f3b555387869f0197fc8a0ce7 Mon Sep 17 00:00:00 2001 From: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:01:21 -0400 Subject: [PATCH] Add p2p sdk (#1799) Co-authored-by: Stephen Buttolph --- network/p2p/client.go | 176 +++++++++++++ network/p2p/handler.go | 98 +++++++ network/p2p/mocks/mock_handler.go | 84 ++++++ network/p2p/router.go | 232 ++++++++++++++++ network/p2p/router_test.go | 425 ++++++++++++++++++++++++++++++ network/peer/ip.go | 5 +- scripts/mocks.mockgen.txt | 1 + staking/verification.go | 35 +++ utils/sampler/uniform_resample.go | 12 +- utils/set/sampleable_set.go | 231 ++++++++++++++++ utils/set/sampleable_set_test.go | 133 ++++++++++ vms/proposervm/block/block.go | 7 +- 12 files changed, 1430 insertions(+), 9 deletions(-) create mode 100644 network/p2p/client.go create mode 100644 network/p2p/handler.go create mode 100644 network/p2p/mocks/mock_handler.go create mode 100644 network/p2p/router.go create mode 100644 network/p2p/router_test.go create mode 100644 staking/verification.go create mode 100644 utils/set/sampleable_set.go create mode 100644 utils/set/sampleable_set_test.go diff --git a/network/p2p/client.go b/network/p2p/client.go new file mode 100644 index 000000000000..383002e58222 --- /dev/null +++ b/network/p2p/client.go @@ -0,0 +1,176 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "errors" + "fmt" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/set" +) + +var ( + ErrAppRequestFailed = errors.New("app request failed") + ErrRequestPending = errors.New("request pending") + ErrNoPeers = errors.New("no peers") +) + +// AppResponseCallback is called upon receiving an AppResponse for an AppRequest +// issued by Client. +// Callers should check [err] to see whether the AppRequest failed or not. +type AppResponseCallback func( + nodeID ids.NodeID, + responseBytes []byte, + err error, +) + +// CrossChainAppResponseCallback is called upon receiving an +// CrossChainAppResponse for a CrossChainAppRequest issued by Client. +// Callers should check [err] to see whether the AppRequest failed or not. +type CrossChainAppResponseCallback func( + chainID ids.ID, + responseBytes []byte, + err error, +) + +type Client struct { + handlerPrefix []byte + router *Router + sender common.AppSender +} + +// AppRequestAny issues an AppRequest to an arbitrary node decided by Client. +// If a specific node needs to be requested, use AppRequest instead. +// See AppRequest for more docs. +func (c *Client) AppRequestAny( + ctx context.Context, + appRequestBytes []byte, + onResponse AppResponseCallback, +) error { + c.router.lock.RLock() + peers := c.router.peers.Sample(1) + c.router.lock.RUnlock() + + if len(peers) != 1 { + return ErrNoPeers + } + + nodeIDs := set.Set[ids.NodeID]{ + peers[0]: struct{}{}, + } + return c.AppRequest(ctx, nodeIDs, appRequestBytes, onResponse) +} + +// AppRequest issues an arbitrary request to a node. +// [onResponse] is invoked upon an error or a response. +func (c *Client) AppRequest( + ctx context.Context, + nodeIDs set.Set[ids.NodeID], + appRequestBytes []byte, + onResponse AppResponseCallback, +) error { + c.router.lock.Lock() + defer c.router.lock.Unlock() + + appRequestBytes = c.prefixMessage(appRequestBytes) + for nodeID := range nodeIDs { + requestID := c.router.requestID + if _, ok := c.router.pendingAppRequests[requestID]; ok { + return fmt.Errorf( + "failed to issue request with request id %d: %w", + requestID, + ErrRequestPending, + ) + } + + if err := c.sender.SendAppRequest( + ctx, + set.Set[ids.NodeID]{nodeID: struct{}{}}, + requestID, + appRequestBytes, + ); err != nil { + return err + } + + c.router.pendingAppRequests[requestID] = onResponse + c.router.requestID++ + } + + return nil +} + +// AppGossip sends a gossip message to a random set of peers. +func (c *Client) AppGossip( + ctx context.Context, + appGossipBytes []byte, +) error { + return c.sender.SendAppGossip( + ctx, + c.prefixMessage(appGossipBytes), + ) +} + +// AppGossipSpecific sends a gossip message to a predetermined set of peers. +func (c *Client) AppGossipSpecific( + ctx context.Context, + nodeIDs set.Set[ids.NodeID], + appGossipBytes []byte, +) error { + return c.sender.SendAppGossipSpecific( + ctx, + nodeIDs, + c.prefixMessage(appGossipBytes), + ) +} + +// CrossChainAppRequest sends a cross chain app request to another vm. +// [onResponse] is invoked upon an error or a response. +func (c *Client) CrossChainAppRequest( + ctx context.Context, + chainID ids.ID, + appRequestBytes []byte, + onResponse CrossChainAppResponseCallback, +) error { + c.router.lock.Lock() + defer c.router.lock.Unlock() + + requestID := c.router.requestID + if _, ok := c.router.pendingCrossChainAppRequests[requestID]; ok { + return fmt.Errorf( + "failed to issue request with request id %d: %w", + requestID, + ErrRequestPending, + ) + } + + if err := c.sender.SendCrossChainAppRequest( + ctx, + chainID, + c.router.requestID, + c.prefixMessage(appRequestBytes), + ); err != nil { + return err + } + + c.router.pendingCrossChainAppRequests[requestID] = onResponse + c.router.requestID++ + + return nil +} + +// prefixMessage prefixes the original message with the handler identifier +// corresponding to this client. +// +// Only gossip and request messages need to be prefixed. +// Response messages don't need to be prefixed because request ids are tracked +// which map to the expected response handler. +func (c *Client) prefixMessage(src []byte) []byte { + messageBytes := make([]byte, len(c.handlerPrefix)+len(src)) + copy(messageBytes, c.handlerPrefix) + copy(messageBytes[len(c.handlerPrefix):], src) + return messageBytes +} diff --git a/network/p2p/handler.go b/network/p2p/handler.go new file mode 100644 index 000000000000..d7ae86faa029 --- /dev/null +++ b/network/p2p/handler.go @@ -0,0 +1,98 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/logging" +) + +// Handler is the server-side logic for virtual machine application protocols. +type Handler interface { + // AppGossip is called when handling an AppGossip message. + AppGossip( + ctx context.Context, + nodeID ids.NodeID, + gossipBytes []byte, + ) error + // AppRequest is called when handling an AppRequest message. + // Returns the bytes for the response corresponding to [requestBytes] + AppRequest( + ctx context.Context, + nodeID ids.NodeID, + deadline time.Time, + requestBytes []byte, + ) ([]byte, error) + // CrossChainAppRequest is called when handling a CrossChainAppRequest + // message. + // Returns the bytes for the response corresponding to [requestBytes] + CrossChainAppRequest( + ctx context.Context, + chainID ids.ID, + deadline time.Time, + requestBytes []byte, + ) ([]byte, error) +} + +// responder automatically sends the response for a given request +type responder struct { + handlerID uint64 + handler Handler + log logging.Logger + sender common.AppSender +} + +func (r *responder) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { + appResponse, err := r.handler.AppRequest(ctx, nodeID, deadline, request) + if err != nil { + r.log.Debug("failed to handle message", + zap.Stringer("messageOp", message.AppRequestOp), + zap.Stringer("nodeID", nodeID), + zap.Uint32("requestID", requestID), + zap.Time("deadline", deadline), + zap.Uint64("handlerID", r.handlerID), + zap.Binary("message", request), + ) + return nil + } + + return r.sender.SendAppResponse(ctx, nodeID, requestID, appResponse) +} + +func (r *responder) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { + err := r.handler.AppGossip(ctx, nodeID, msg) + if err != nil { + r.log.Debug("failed to handle message", + zap.Stringer("messageOp", message.AppGossipOp), + zap.Stringer("nodeID", nodeID), + zap.Uint64("handlerID", r.handlerID), + zap.Binary("message", msg), + ) + } + return nil +} + +func (r *responder) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { + appResponse, err := r.handler.CrossChainAppRequest(ctx, chainID, deadline, request) + if err != nil { + r.log.Debug("failed to handle message", + zap.Stringer("messageOp", message.CrossChainAppRequestOp), + zap.Stringer("chainID", chainID), + zap.Uint32("requestID", requestID), + zap.Time("deadline", deadline), + zap.Uint64("handlerID", r.handlerID), + zap.Binary("message", request), + ) + return nil + } + + return r.sender.SendCrossChainAppResponse(ctx, chainID, requestID, appResponse) +} diff --git a/network/p2p/mocks/mock_handler.go b/network/p2p/mocks/mock_handler.go new file mode 100644 index 000000000000..87b3ee58c197 --- /dev/null +++ b/network/p2p/mocks/mock_handler.go @@ -0,0 +1,84 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ava-labs/avalanchego/x/p2p (interfaces: Handler) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + time "time" + + ids "github.com/ava-labs/avalanchego/ids" + gomock "github.com/golang/mock/gomock" +) + +// MockHandler is a mock of Handler interface. +type MockHandler struct { + ctrl *gomock.Controller + recorder *MockHandlerMockRecorder +} + +// MockHandlerMockRecorder is the mock recorder for MockHandler. +type MockHandlerMockRecorder struct { + mock *MockHandler +} + +// NewMockHandler creates a new mock instance. +func NewMockHandler(ctrl *gomock.Controller) *MockHandler { + mock := &MockHandler{ctrl: ctrl} + mock.recorder = &MockHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHandler) EXPECT() *MockHandlerMockRecorder { + return m.recorder +} + +// AppGossip mocks base method. +func (m *MockHandler) AppGossip(arg0 context.Context, arg1 ids.NodeID, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppGossip", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// AppGossip indicates an expected call of AppGossip. +func (mr *MockHandlerMockRecorder) AppGossip(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppGossip", reflect.TypeOf((*MockHandler)(nil).AppGossip), arg0, arg1, arg2) +} + +// AppRequest mocks base method. +func (m *MockHandler) AppRequest(arg0 context.Context, arg1 ids.NodeID, arg2 time.Time, arg3 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppRequest", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AppRequest indicates an expected call of AppRequest. +func (mr *MockHandlerMockRecorder) AppRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequest", reflect.TypeOf((*MockHandler)(nil).AppRequest), arg0, arg1, arg2, arg3) +} + +// CrossChainAppRequest mocks base method. +func (m *MockHandler) CrossChainAppRequest(arg0 context.Context, arg1 ids.ID, arg2 time.Time, arg3 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CrossChainAppRequest", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CrossChainAppRequest indicates an expected call of CrossChainAppRequest. +func (mr *MockHandlerMockRecorder) CrossChainAppRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequest", reflect.TypeOf((*MockHandler)(nil).CrossChainAppRequest), arg0, arg1, arg2, arg3) +} diff --git a/network/p2p/router.go b/network/p2p/router.go new file mode 100644 index 000000000000..a685abc7e502 --- /dev/null +++ b/network/p2p/router.go @@ -0,0 +1,232 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/version" +) + +var ( + ErrExistingAppProtocol = errors.New("existing app protocol") + ErrUnrequestedResponse = errors.New("unrequested response") + + _ common.AppHandler = (*Router)(nil) + _ validators.Connector = (*Router)(nil) +) + +// Router routes incoming application messages to the corresponding registered +// app handler. App messages must be made using the registered handler's +// corresponding Client. +type Router struct { + log logging.Logger + sender common.AppSender + + lock sync.RWMutex + handlers map[uint64]*responder + pendingAppRequests map[uint32]AppResponseCallback + pendingCrossChainAppRequests map[uint32]CrossChainAppResponseCallback + requestID uint32 + peers set.SampleableSet[ids.NodeID] +} + +// NewRouter returns a new instance of Router +func NewRouter(log logging.Logger, sender common.AppSender) *Router { + return &Router{ + log: log, + sender: sender, + handlers: make(map[uint64]*responder), + pendingAppRequests: make(map[uint32]AppResponseCallback), + pendingCrossChainAppRequests: make(map[uint32]CrossChainAppResponseCallback), + } +} + +func (r *Router) Connected(_ context.Context, nodeID ids.NodeID, _ *version.Application) error { + r.lock.Lock() + defer r.lock.Unlock() + + r.peers.Add(nodeID) + return nil +} + +func (r *Router) Disconnected(_ context.Context, nodeID ids.NodeID) error { + r.lock.Lock() + defer r.lock.Unlock() + + r.peers.Remove(nodeID) + return nil +} + +// RegisterAppProtocol reserves an identifier for an application protocol and +// returns a Client that can be used to send messages for the corresponding +// protocol. +func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler) (*Client, error) { + r.lock.Lock() + defer r.lock.Unlock() + + if _, ok := r.handlers[handlerID]; ok { + return nil, fmt.Errorf("failed to register handler id %d: %w", handlerID, ErrExistingAppProtocol) + } + + r.handlers[handlerID] = &responder{ + handlerID: handlerID, + handler: handler, + log: r.log, + sender: r.sender, + } + + return &Client{ + handlerPrefix: binary.AppendUvarint(nil, handlerID), + sender: r.sender, + router: r, + }, nil +} + +func (r *Router) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { + parsedMsg, handler, ok := r.parse(request) + if !ok { + r.log.Debug("failed to process message", + zap.Stringer("messageOp", message.AppRequestOp), + zap.Stringer("nodeID", nodeID), + zap.Uint32("requestID", requestID), + zap.Time("deadline", deadline), + zap.Binary("message", request), + ) + return nil + } + + return handler.AppRequest(ctx, nodeID, requestID, deadline, parsedMsg) +} + +func (r *Router) AppRequestFailed(_ context.Context, nodeID ids.NodeID, requestID uint32) error { + callback, ok := r.clearAppRequest(requestID) + if !ok { + return ErrUnrequestedResponse + } + + callback(nodeID, nil, ErrAppRequestFailed) + return nil +} + +func (r *Router) AppResponse(_ context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { + callback, ok := r.clearAppRequest(requestID) + if !ok { + return ErrUnrequestedResponse + } + + callback(nodeID, response, nil) + return nil +} + +func (r *Router) AppGossip(ctx context.Context, nodeID ids.NodeID, gossip []byte) error { + parsedMsg, handler, ok := r.parse(gossip) + if !ok { + r.log.Debug("failed to process message", + zap.Stringer("messageOp", message.AppGossipOp), + zap.Stringer("nodeID", nodeID), + zap.Binary("message", gossip), + ) + return nil + } + + return handler.AppGossip(ctx, nodeID, parsedMsg) +} + +func (r *Router) CrossChainAppRequest( + ctx context.Context, + chainID ids.ID, + requestID uint32, + deadline time.Time, + msg []byte, +) error { + parsedMsg, handler, ok := r.parse(msg) + if !ok { + r.log.Debug("failed to process message", + zap.Stringer("messageOp", message.CrossChainAppRequestOp), + zap.Stringer("chainID", chainID), + zap.Uint32("requestID", requestID), + zap.Time("deadline", deadline), + zap.Binary("message", msg), + ) + return nil + } + + return handler.CrossChainAppRequest(ctx, chainID, requestID, deadline, parsedMsg) +} + +func (r *Router) CrossChainAppRequestFailed(_ context.Context, chainID ids.ID, requestID uint32) error { + callback, ok := r.clearCrossChainAppRequest(requestID) + if !ok { + return ErrUnrequestedResponse + } + + callback(chainID, nil, ErrAppRequestFailed) + return nil +} + +func (r *Router) CrossChainAppResponse(_ context.Context, chainID ids.ID, requestID uint32, response []byte) error { + callback, ok := r.clearCrossChainAppRequest(requestID) + if !ok { + return ErrUnrequestedResponse + } + + callback(chainID, response, nil) + return nil +} + +// Parse parses a gossip or request message and maps it to a corresponding +// handler if present. +// +// Returns: +// - The unprefixed protocol message. +// - The protocol responder. +// - A boolean indicating that parsing succeeded. +// +// Invariant: Assumes [r.lock] isn't held. +func (r *Router) parse(msg []byte) ([]byte, *responder, bool) { + handlerID, bytesRead := binary.Uvarint(msg) + if bytesRead <= 0 { + return nil, nil, false + } + + r.lock.RLock() + defer r.lock.RUnlock() + + handler, ok := r.handlers[handlerID] + return msg[bytesRead:], handler, ok +} + +// Invariant: Assumes [r.lock] isn't held. +func (r *Router) clearAppRequest(requestID uint32) (AppResponseCallback, bool) { + r.lock.Lock() + defer r.lock.Unlock() + + callback, ok := r.pendingAppRequests[requestID] + delete(r.pendingAppRequests, requestID) + return callback, ok +} + +// Invariant: Assumes [r.lock] isn't held. +func (r *Router) clearCrossChainAppRequest(requestID uint32) (CrossChainAppResponseCallback, bool) { + r.lock.Lock() + defer r.lock.Unlock() + + callback, ok := r.pendingCrossChainAppRequests[requestID] + delete(r.pendingCrossChainAppRequests, requestID) + return callback, ok +} diff --git a/network/p2p/router_test.go b/network/p2p/router_test.go new file mode 100644 index 000000000000..20bf0dcf2460 --- /dev/null +++ b/network/p2p/router_test.go @@ -0,0 +1,425 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p/mocks" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" +) + +func TestAppRequestResponse(t *testing.T) { + handlerID := uint64(0x0) + request := []byte("request") + response := []byte("response") + nodeID := ids.GenerateTestNodeID() + chainID := ids.GenerateTestID() + + tests := []struct { + name string + requestFunc func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) + }{ + { + name: "app request", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { + for range nodeIDs { + go func() { + require.NoError(t, router.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) + }() + } + }).AnyTimes() + sender.EXPECT().SendAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, _ ids.NodeID, requestID uint32, response []byte) { + go func() { + require.NoError(t, router.AppResponse(ctx, nodeID, requestID, response)) + }() + }).AnyTimes() + handler.EXPECT(). + AppRequest(context.Background(), nodeID, gomock.Any(), request). + DoAndReturn(func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + return response, nil + }) + + callback := func(actualNodeID ids.NodeID, actualResponse []byte, err error) { + defer wg.Done() + + require.NoError(t, err) + require.Equal(t, nodeID, actualNodeID) + require.Equal(t, response, actualResponse) + } + + require.NoError(t, client.AppRequestAny(context.Background(), request, callback)) + }, + }, + { + name: "app request failed", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { + for range nodeIDs { + go func() { + require.NoError(t, router.AppRequestFailed(ctx, nodeID, requestID)) + }() + } + }) + + callback := func(actualNodeID ids.NodeID, actualResponse []byte, err error) { + defer wg.Done() + + require.ErrorIs(t, err, ErrAppRequestFailed) + require.Equal(t, nodeID, actualNodeID) + require.Nil(t, actualResponse) + } + + require.NoError(t, client.AppRequest(context.Background(), set.Set[ids.NodeID]{nodeID: struct{}{}}, request, callback)) + }, + }, + { + name: "cross-chain app request", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + chainID := ids.GenerateTestID() + sender.EXPECT().SendCrossChainAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { + go func() { + require.NoError(t, router.CrossChainAppRequest(ctx, chainID, requestID, time.Time{}, request)) + }() + }).AnyTimes() + sender.EXPECT().SendCrossChainAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) { + go func() { + require.NoError(t, router.CrossChainAppResponse(ctx, chainID, requestID, response)) + }() + }).AnyTimes() + handler.EXPECT(). + CrossChainAppRequest(context.Background(), chainID, gomock.Any(), request). + DoAndReturn(func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + return response, nil + }) + + callback := func(actualChainID ids.ID, actualResponse []byte, err error) { + defer wg.Done() + require.NoError(t, err) + require.Equal(t, chainID, actualChainID) + require.Equal(t, response, actualResponse) + } + + require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) + }, + }, + { + name: "cross-chain app request failed", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.EXPECT().SendCrossChainAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { + go func() { + require.NoError(t, router.CrossChainAppRequestFailed(ctx, chainID, requestID)) + }() + }) + + callback := func(actualChainID ids.ID, actualResponse []byte, err error) { + defer wg.Done() + + require.ErrorIs(t, err, ErrAppRequestFailed) + require.Equal(t, chainID, actualChainID) + require.Nil(t, actualResponse) + } + + require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) + }, + }, + { + name: "app gossip", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.EXPECT().SendAppGossip(gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, gossip []byte) { + go func() { + require.NoError(t, router.AppGossip(ctx, nodeID, gossip)) + }() + }).AnyTimes() + handler.EXPECT(). + AppGossip(context.Background(), nodeID, request). + DoAndReturn(func(context.Context, ids.NodeID, []byte) error { + defer wg.Done() + return nil + }) + + require.NoError(t, client.AppGossip(context.Background(), request)) + }, + }, + { + name: "app gossip specific", + requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.EXPECT().SendAppGossipSpecific(gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], gossip []byte) { + for n := range nodeIDs { + nodeID := n + go func() { + require.NoError(t, router.AppGossip(ctx, nodeID, gossip)) + }() + } + }).AnyTimes() + handler.EXPECT(). + AppGossip(context.Background(), nodeID, request). + DoAndReturn(func(context.Context, ids.NodeID, []byte) error { + defer wg.Done() + return nil + }) + + require.NoError(t, client.AppGossipSpecific(context.Background(), set.Set[ids.NodeID]{nodeID: struct{}{}}, request)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + ctrl := gomock.NewController(t) + + sender := common.NewMockSender(ctrl) + handler := mocks.NewMockHandler(ctrl) + router := NewRouter(logging.NoLog{}, sender) + require.NoError(router.Connected(context.Background(), nodeID, nil)) + + client, err := router.RegisterAppProtocol(handlerID, handler) + require.NoError(err) + + wg := &sync.WaitGroup{} + wg.Add(1) + tt.requestFunc(t, router, client, sender, handler, wg) + wg.Wait() + }) + } +} + +func TestRouterDropMessage(t *testing.T) { + unregistered := byte(0x0) + + tests := []struct { + name string + requestFunc func(router *Router) error + err error + }{ + { + name: "drop unregistered app request message", + requestFunc: func(router *Router) error { + return router.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty app request message", + requestFunc: func(router *Router) error { + return router.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{}) + }, + err: nil, + }, + { + name: "drop unregistered cross-chain app request message", + requestFunc: func(router *Router) error { + return router.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty cross-chain app request message", + requestFunc: func(router *Router) error { + return router.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{}) + }, + err: nil, + }, + { + name: "drop unregistered gossip message", + requestFunc: func(router *Router) error { + return router.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty gossip message", + requestFunc: func(router *Router) error { + return router.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{}) + }, + err: nil, + }, + { + name: "drop unrequested app request failed", + requestFunc: func(router *Router) error { + return router.AppRequestFailed(context.Background(), ids.GenerateTestNodeID(), 0) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested app response", + requestFunc: func(router *Router) error { + return router.AppResponse(context.Background(), ids.GenerateTestNodeID(), 0, nil) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested cross-chain request failed", + requestFunc: func(router *Router) error { + return router.CrossChainAppRequestFailed(context.Background(), ids.GenerateTestID(), 0) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested cross-chain response", + requestFunc: func(router *Router) error { + return router.CrossChainAppResponse(context.Background(), ids.GenerateTestID(), 0, nil) + }, + err: ErrUnrequestedResponse, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + router := NewRouter(logging.NoLog{}, nil) + + err := tt.requestFunc(router) + require.ErrorIs(err, tt.err) + }) + } +} + +// It's possible for the request id to overflow and wrap around. +// If there are still pending requests with the same request id, we should +// not attempt to issue another request until the previous one has cleared. +func TestAppRequestDuplicateRequestIDs(t *testing.T) { + require := require.New(t) + ctrl := gomock.NewController(t) + + handler := mocks.NewMockHandler(ctrl) + sender := common.NewMockSender(ctrl) + router := NewRouter(logging.NoLog{}, sender) + nodeID := ids.GenerateTestNodeID() + + requestSent := &sync.WaitGroup{} + sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { + for range nodeIDs { + requestSent.Add(1) + go func() { + require.NoError(router.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) + requestSent.Done() + }() + } + }).AnyTimes() + + timeout := &sync.WaitGroup{} + response := []byte("response") + handler.EXPECT().AppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, request []byte) ([]byte, error) { + timeout.Wait() + return response, nil + }).AnyTimes() + sender.EXPECT().SendAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), response) + + client, err := router.RegisterAppProtocol(0x1, handler) + require.NoError(err) + + require.NoError(client.AppRequest(context.Background(), set.Set[ids.NodeID]{nodeID: struct{}{}}, []byte{}, nil)) + requestSent.Wait() + + // force the router to use the same requestID + router.requestID = 0 + timeout.Add(1) + err = client.AppRequest(context.Background(), set.Set[ids.NodeID]{nodeID: struct{}{}}, []byte{}, nil) + requestSent.Wait() + require.ErrorIs(err, ErrRequestPending) + + timeout.Done() +} + +func TestRouterConnected(t *testing.T) { + tests := []struct { + name string + connect []ids.NodeID + disconnect []ids.NodeID + }{ + { + name: "empty", + }, + { + name: "connect and disconnect", + connect: []ids.NodeID{ + {0x0}, + }, + disconnect: []ids.NodeID{ + {0x0}, + }, + }, + { + name: "two nodes connect", + connect: []ids.NodeID{ + {0x0, 0x1}, + }, + }, + { + name: "two nodes connect, last one disconnects", + connect: []ids.NodeID{ + {0x0, 0x1}, + }, + disconnect: []ids.NodeID{ + {0x1}, + }, + }, + { + name: "two nodes connect, first one disconnects", + connect: []ids.NodeID{ + {0x0, 0x1}, + }, + disconnect: []ids.NodeID{ + {0x0}, + }, + }, + { + name: "two nodes connect and disconnect", + connect: []ids.NodeID{ + {0x0, 0x1}, + }, + disconnect: []ids.NodeID{ + {0x0, 0x1}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + router := NewRouter(logging.NoLog{}, nil) + + expected := set.Set[ids.NodeID]{} + + for _, connect := range tt.connect { + expected.Add(connect) + require.NoError(router.Connected(context.Background(), connect, nil)) + } + + for _, disconnect := range tt.disconnect { + expected.Remove(disconnect) + require.NoError(router.Disconnected(context.Background(), disconnect)) + } + + require.Len(expected, router.peers.Len()) + for _, peer := range router.peers.List() { + require.Contains(expected, peer) + } + }) + } +} diff --git a/network/peer/ip.go b/network/peer/ip.go index 720a1cd89814..711e28a104da 100644 --- a/network/peer/ip.go +++ b/network/peer/ip.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "crypto/x509" + "github.com/ava-labs/avalanchego/staking" "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/utils/ips" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -50,8 +51,8 @@ type SignedIP struct { } func (ip *SignedIP) Verify(cert *x509.Certificate) error { - return cert.CheckSignature( - cert.SignatureAlgorithm, + return staking.CheckSignature( + cert, ip.UnsignedIP.bytes(), ip.Signature, ) diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 7b40e0224c73..0ebe66b46103 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -48,4 +48,5 @@ github.com/ava-labs/avalanchego/vms/registry=VMRegisterer=vms/registry/mock_vm_r github.com/ava-labs/avalanchego/vms/registry=VMRegistry=vms/registry/mock_vm_registry.go github.com/ava-labs/avalanchego/vms=Factory,Manager=vms/mock_manager.go github.com/ava-labs/avalanchego/x/merkledb=MerkleDB=x/merkledb/mock_db.go +github.com/ava-labs/avalanchego/x/p2p=Handler=x/p2p/mocks/mock_handler.go github.com/ava-labs/avalanchego/x/sync=Client=x/sync/mock_client.go diff --git a/staking/verification.go b/staking/verification.go new file mode 100644 index 000000000000..ef87e15ae863 --- /dev/null +++ b/staking/verification.go @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package staking + +import ( + "crypto/rsa" + "crypto/x509" + "errors" + "fmt" +) + +// MaxRSAKeyBitLen is the maximum RSA key size in bits that we are willing to +// parse. +const MaxRSAKeyBitLen = 8192 + +var ErrInvalidPublicKey = errors.New("invalid public key") + +func CheckSignature(cert *x509.Certificate, message []byte, signature []byte) error { + if cert.PublicKeyAlgorithm == x509.RSA { + pk, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("%w: %T", ErrInvalidPublicKey, cert.PublicKey) + } + if bitLen := pk.N.BitLen(); bitLen > MaxRSAKeyBitLen { + return fmt.Errorf("%w: bitLen=%d > maxBitLen=%d", ErrInvalidPublicKey, bitLen, MaxRSAKeyBitLen) + } + } + + return cert.CheckSignature( + cert.SignatureAlgorithm, + message, + signature, + ) +} diff --git a/utils/sampler/uniform_resample.go b/utils/sampler/uniform_resample.go index e0557bbd88d1..8f09e95f777c 100644 --- a/utils/sampler/uniform_resample.go +++ b/utils/sampler/uniform_resample.go @@ -3,7 +3,7 @@ package sampler -import "github.com/ava-labs/avalanchego/utils/set" +import "golang.org/x/exp/maps" // uniformResample allows for sampling over a uniform distribution without // replacement. @@ -18,14 +18,14 @@ type uniformResample struct { rng *rng seededRNG *rng length uint64 - drawn set.Set[uint64] + drawn map[uint64]struct{} } func (s *uniformResample) Initialize(length uint64) { s.rng = globalRNG s.seededRNG = newRNG() s.length = length - s.drawn.Clear() + s.drawn = make(map[uint64]struct{}) } func (s *uniformResample) Sample(count int) ([]uint64, error) { @@ -52,7 +52,7 @@ func (s *uniformResample) ClearSeed() { } func (s *uniformResample) Reset() { - s.drawn.Clear() + maps.Clear(s.drawn) } func (s *uniformResample) Next() (uint64, error) { @@ -63,10 +63,10 @@ func (s *uniformResample) Next() (uint64, error) { for { draw := s.rng.Uint64Inclusive(s.length - 1) - if s.drawn.Contains(draw) { + if _, ok := s.drawn[draw]; ok { continue } - s.drawn.Add(draw) + s.drawn[draw] = struct{}{} return draw, nil } } diff --git a/utils/set/sampleable_set.go b/utils/set/sampleable_set.go new file mode 100644 index 000000000000..fa96151348ad --- /dev/null +++ b/utils/set/sampleable_set.go @@ -0,0 +1,231 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "bytes" + + stdjson "encoding/json" + + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/json" + "github.com/ava-labs/avalanchego/utils/math" + "github.com/ava-labs/avalanchego/utils/sampler" + "github.com/ava-labs/avalanchego/utils/wrappers" +) + +var _ stdjson.Marshaler = (*Set[int])(nil) + +// SampleableSet is a set of elements that supports sampling. +type SampleableSet[T comparable] struct { + // indices maps the element in the set to the index that it appears in + // elements. + indices map[T]int + elements []T +} + +// Return a new sampleable set with initial capacity [size]. +// More or less than [size] elements can be added to this set. +// Using NewSampleableSet() rather than SampleableSet[T]{} is just an +// optimization that can be used if you know how many elements will be put in +// this set. +func NewSampleableSet[T comparable](size int) SampleableSet[T] { + if size < 0 { + return SampleableSet[T]{} + } + return SampleableSet[T]{ + indices: make(map[T]int, size), + elements: make([]T, 0, size), + } +} + +// Add all the elements to this set. +// If the element is already in the set, nothing happens. +func (s *SampleableSet[T]) Add(elements ...T) { + s.resize(2 * len(elements)) + for _, e := range elements { + s.add(e) + } +} + +// Union adds all the elements from the provided set to this set. +func (s *SampleableSet[T]) Union(set SampleableSet[T]) { + s.resize(2 * set.Len()) + for _, e := range set.elements { + s.add(e) + } +} + +// Difference removes all the elements in [set] from [s]. +func (s *SampleableSet[T]) Difference(set SampleableSet[T]) { + for _, e := range set.elements { + s.remove(e) + } +} + +// Contains returns true iff the set contains this element. +func (s SampleableSet[T]) Contains(e T) bool { + _, contains := s.indices[e] + return contains +} + +// Overlaps returns true if the intersection of the set is non-empty +func (s SampleableSet[T]) Overlaps(big SampleableSet[T]) bool { + small := s + if small.Len() > big.Len() { + small, big = big, small + } + + for _, e := range small.elements { + if _, ok := big.indices[e]; ok { + return true + } + } + return false +} + +// Len returns the number of elements in this set. +func (s SampleableSet[_]) Len() int { + return len(s.elements) +} + +// Remove all the given elements from this set. +// If an element isn't in the set, it's ignored. +func (s *SampleableSet[T]) Remove(elements ...T) { + for _, e := range elements { + s.remove(e) + } +} + +// Clear empties this set +func (s *SampleableSet[T]) Clear() { + maps.Clear(s.indices) + for i := range s.elements { + s.elements[i] = utils.Zero[T]() + } + s.elements = s.elements[:0] +} + +// List converts this set into a list +func (s SampleableSet[T]) List() []T { + return slices.Clone(s.elements) +} + +// Equals returns true if the sets contain the same elements +func (s SampleableSet[T]) Equals(other SampleableSet[T]) bool { + if len(s.indices) != len(other.indices) { + return false + } + for k := range s.indices { + if _, ok := other.indices[k]; !ok { + return false + } + } + return true +} + +func (s SampleableSet[T]) Sample(numToSample int) []T { + if numToSample <= 0 { + return nil + } + + uniform := sampler.NewUniform() + uniform.Initialize(uint64(len(s.elements))) + indices, _ := uniform.Sample(math.Min(len(s.elements), numToSample)) + elements := make([]T, len(indices)) + for i, index := range indices { + elements[i] = s.elements[index] + } + return elements +} + +func (s *SampleableSet[T]) UnmarshalJSON(b []byte) error { + str := string(b) + if str == json.Null { + return nil + } + var elements []T + if err := stdjson.Unmarshal(b, &elements); err != nil { + return err + } + s.Clear() + s.Add(elements...) + return nil +} + +func (s *SampleableSet[_]) MarshalJSON() ([]byte, error) { + var ( + elementBytes = make([][]byte, len(s.elements)) + err error + ) + for i, e := range s.elements { + elementBytes[i], err = stdjson.Marshal(e) + if err != nil { + return nil, err + } + } + // Sort for determinism + utils.SortBytes(elementBytes) + + // Build the JSON + var ( + jsonBuf = bytes.Buffer{} + errs = wrappers.Errs{} + ) + _, err = jsonBuf.WriteString("[") + errs.Add(err) + for i, elt := range elementBytes { + _, err := jsonBuf.Write(elt) + errs.Add(err) + if i != len(elementBytes)-1 { + _, err := jsonBuf.WriteString(",") + errs.Add(err) + } + } + _, err = jsonBuf.WriteString("]") + errs.Add(err) + + return jsonBuf.Bytes(), errs.Err +} + +func (s *SampleableSet[T]) resize(size int) { + if s.elements == nil { + if minSetSize > size { + size = minSetSize + } + s.indices = make(map[T]int, size) + } +} + +func (s *SampleableSet[T]) add(e T) { + _, ok := s.indices[e] + if ok { + return + } + + s.indices[e] = len(s.elements) + s.elements = append(s.elements, e) +} + +func (s *SampleableSet[T]) remove(e T) { + indexToRemove, ok := s.indices[e] + if !ok { + return + } + + lastIndex := len(s.elements) - 1 + if indexToRemove != lastIndex { + lastElement := s.elements[lastIndex] + + s.indices[lastElement] = indexToRemove + s.elements[indexToRemove] = lastElement + } + + delete(s.indices, e) + s.elements[lastIndex] = utils.Zero[T]() + s.elements = s.elements[:lastIndex] +} diff --git a/utils/set/sampleable_set_test.go b/utils/set/sampleable_set_test.go new file mode 100644 index 000000000000..0cda8c23c79d --- /dev/null +++ b/utils/set/sampleable_set_test.go @@ -0,0 +1,133 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package set + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSampleableSet(t *testing.T) { + require := require.New(t) + id1 := 1 + + s := SampleableSet[int]{} + + s.Add(id1) + require.True(s.Contains(id1)) + + s.Remove(id1) + require.False(s.Contains(id1)) + + s.Add(id1) + require.True(s.Contains(id1)) + require.Len(s.List(), 1) + require.Equal(id1, s.List()[0]) + + s.Clear() + require.False(s.Contains(id1)) + + s.Add(id1) + + s2 := SampleableSet[int]{} + + require.False(s.Overlaps(s2)) + + s2.Union(s) + require.True(s2.Contains(id1)) + require.True(s.Overlaps(s2)) + + s2.Difference(s) + require.False(s2.Contains(id1)) + require.False(s.Overlaps(s2)) +} + +func TestSampleableSetClear(t *testing.T) { + require := require.New(t) + + set := SampleableSet[int]{} + for i := 0; i < 25; i++ { + set.Add(i) + } + set.Clear() + require.Zero(set.Len()) + set.Add(1337) + require.Equal(1, set.Len()) +} + +func TestSampleableSetMarshalJSON(t *testing.T) { + require := require.New(t) + set := SampleableSet[int]{} + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal("[]", string(asJSON)) + } + id1, id2 := 1, 2 + id1JSON, err := json.Marshal(id1) + require.NoError(err) + id2JSON, err := json.Marshal(id2) + require.NoError(err) + set.Add(id1) + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal(fmt.Sprintf("[%s]", string(id1JSON)), string(asJSON)) + } + set.Add(id2) + { + asJSON, err := set.MarshalJSON() + require.NoError(err) + require.Equal(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON)), string(asJSON)) + } +} + +func TestSampleableSetUnmarshalJSON(t *testing.T) { + require := require.New(t) + set := SampleableSet[int]{} + { + require.NoError(set.UnmarshalJSON([]byte("[]"))) + require.Zero(set.Len()) + } + id1, id2 := 1, 2 + id1JSON, err := json.Marshal(id1) + require.NoError(err) + id2JSON, err := json.Marshal(id2) + require.NoError(err) + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%s]", string(id1JSON))))) + require.Equal(1, set.Len()) + require.True(set.Contains(id1)) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON))))) + require.Equal(2, set.Len()) + require.True(set.Contains(id1)) + require.True(set.Contains(id2)) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%d,%d,%d]", 3, 4, 5)))) + require.Equal(3, set.Len()) + require.True(set.Contains(3)) + require.True(set.Contains(4)) + require.True(set.Contains(5)) + } + { + require.NoError(set.UnmarshalJSON([]byte(fmt.Sprintf("[%d,%d,%d, %d]", 3, 4, 5, 3)))) + require.Equal(3, set.Len()) + require.True(set.Contains(3)) + require.True(set.Contains(4)) + require.True(set.Contains(5)) + } + { + set1 := SampleableSet[int]{} + set2 := SampleableSet[int]{} + require.NoError(set1.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id1JSON), string(id2JSON))))) + require.NoError(set2.UnmarshalJSON([]byte(fmt.Sprintf("[%s,%s]", string(id2JSON), string(id1JSON))))) + require.True(set1.Equals(set2)) + } +} diff --git a/vms/proposervm/block/block.go b/vms/proposervm/block/block.go index a64ec2e87979..d3e63b87eae6 100644 --- a/vms/proposervm/block/block.go +++ b/vms/proposervm/block/block.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/staking" "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/utils/wrappers" ) @@ -129,5 +130,9 @@ func (b *statelessBlock) Verify(shouldHaveProposer bool, chainID ids.ID) error { } headerBytes := header.Bytes() - return b.cert.CheckSignature(b.cert.SignatureAlgorithm, headerBytes, b.Signature) + return staking.CheckSignature( + b.cert, + headerBytes, + b.Signature, + ) }