diff --git a/go.mod b/go.mod
index f47d8cf7928..e1620d96a84 100644
--- a/go.mod
+++ b/go.mod
@@ -58,6 +58,7 @@ require (
 	github.com/mitchellh/hashstructure/v2 v2.0.2
 	github.com/nadoo/ipset v0.5.0
 	github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
+	github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080
 	github.com/okta/okta-sdk-golang/v2 v2.18.0
 	github.com/oschwald/maxminddb-golang v1.12.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
diff --git a/go.sum b/go.sum
index 06df95a33a5..6e6196b8b43 100644
--- a/go.sum
+++ b/go.sum
@@ -477,6 +477,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
 github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
 github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
 github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
 github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
 github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
 github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
diff --git a/signal/peer/peer.go b/signal/peer/peer.go
index 3149526b2e7..85de9158152 100644
--- a/signal/peer/peer.go
+++ b/signal/peer/peer.go
@@ -18,16 +18,22 @@ type Peer struct {
 
 	StreamID int64
 
-	//a gRpc connection stream to the Peer
+	// a gRpc connection stream to the Peer
 	Stream proto.SignalExchange_ConnectStreamServer
+
+	// RegisteredAt is the time at which NewPeer created this Peer
+	RegisteredAt time.Time
 }
 
 // NewPeer creates a new instance of a connected Peer
 func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
+	// take a single timestamp so StreamID and RegisteredAt describe the same instant
+	now := time.Now()
 	return &Peer{
-		Id:       id,
-		Stream:   stream,
-		StreamID: time.Now().UnixNano(),
+		Id:           id,
+		Stream:       stream,
+		StreamID:     now.UnixNano(),
+		RegisteredAt: now,
 	}
 }
 
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 219bdcc4143..69387cc6952 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -13,6 +13,8 @@ import (
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
 
+	"github.com/netbirdio/signal-dispatcher/dispatcher"
+
 	"github.com/netbirdio/netbird/signal/metrics"
 	"github.com/netbirdio/netbird/signal/peer"
 	"github.com/netbirdio/netbird/signal/proto"
@@ -40,8 +42,8 @@ const (
 type Server struct {
 	registry *peer.Registry
 	proto.UnimplementedSignalExchangeServer
-
-	metrics *metrics.AppMetrics
+	dispatcher *dispatcher.Dispatcher
+	metrics    *metrics.AppMetrics
 }
 
 // NewServer creates a new Signal server
@@ -51,9 +53,15 @@ func NewServer(meter metric.Meter) (*Server, error) {
 		return nil, fmt.Errorf("creating app metrics: %v", err)
 	}
 
+	d, err := dispatcher.NewDispatcher()
+	if err != nil {
+		return nil, fmt.Errorf("creating dispatcher: %w", err)
+	}
+
 	s := &Server{
-		registry: peer.NewRegistry(appMetrics),
-		metrics:  appMetrics,
+		dispatcher: d,
+		registry:   peer.NewRegistry(appMetrics),
+		metrics:    appMetrics,
 	}
 
 	return s, nil
@@ -61,57 +69,31 @@ func NewServer(meter metric.Meter) (*Server, error) {
 
 // Send forwards a message to the signal peer
 func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
-	if !s.registry.IsPeerRegistered(msg.Key) {
-		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotRegistered)))
+	log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
 
-		return nil, fmt.Errorf("peer %s is not registered", msg.Key)
+	if msg.RemoteKey == "dummy" {
+		// Test message sent during netbird status
+		return &proto.EncryptedMessage{}, nil
 	}
 
-	getRegistrationStart := time.Now()
-
-	if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
-		s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
-		start := time.Now()
-		//forward the message to the target peer
-		if err := dstPeer.Stream.Send(msg); err != nil {
-			log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
-			//todo respond to the sender?
-
-			s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
-		} else {
-			s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
-			s.metrics.MessagesForwarded.Add(context.Background(), 1)
-		}
-	} else {
-		s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
-		log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
-		//todo respond to the sender?
-
-		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
+	if _, found := s.registry.Get(msg.RemoteKey); found {
+		s.forwardMessageToPeer(ctx, msg)
+		return &proto.EncryptedMessage{}, nil
 	}
-	return &proto.EncryptedMessage{}, nil
+
+	return s.dispatcher.SendMessage(context.Background(), msg)
 }
 
 // ConnectStream connects to the exchange stream
 func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
-	p, err := s.connectPeer(stream)
+	p, err := s.RegisterPeer(stream)
 	if err != nil {
 		return err
 	}
 
-	startRegister := time.Now()
-
-	s.metrics.ActivePeers.Add(stream.Context(), 1)
-
-	defer func() {
-		log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
-		s.registry.Deregister(p)
+	defer s.DeregisterPeer(p)
 
-		s.metrics.PeerConnectionDuration.Record(stream.Context(), int64(time.Since(startRegister).Seconds()))
-		s.metrics.ActivePeers.Add(context.Background(), -1)
-	}()
-
-	//needed to confirm that the peer has been registered so that the client can proceed
+	// needed to confirm that the peer has been registered so that the client can proceed
 	header := metadata.Pairs(proto.HeaderRegistered, "1")
 	err = stream.SendHeader(header)
 	if err != nil {
@@ -119,11 +101,10 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
 		return err
 	}
 
-	log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
+	log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
 
 	for {
-
-		//read incoming messages
+		// read incoming messages
 		msg, err := stream.Recv()
 		if err == io.EOF {
 			break
@@ -131,44 +112,32 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
 			return err
 		}
 
-		log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
-
-		getRegistrationStart := time.Now()
-
-		// lookup the target peer where the message is going to
-		if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
-			s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
-			start := time.Now()
-			//forward the message to the target peer
-			if err := dstPeer.Stream.Send(msg); err != nil {
-				log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
-				//todo respond to the sender?
-				s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
-			} else {
-				// in milliseconds
-				s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
-				s.metrics.MessagesForwarded.Add(stream.Context(), 1)
-			}
-		} else {
-			s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
-			s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
-			log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
-			//todo respond to the sender?
+		log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
+
+		_, err = s.dispatcher.SendMessage(stream.Context(), msg)
+		if err != nil {
+			log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
 		}
 	}
+
 	<-stream.Context().Done()
 	return stream.Context().Err()
 }
 
-// Handles initial Peer connection.
+// RegisterPeer handles the initial Peer connection.
 // Each connection must provide an Id header.
-// At this moment the connecting Peer will be registered in the peer.Registry
-func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+// At this moment the connecting Peer is registered in the peer.Registry, forwardMessageToPeer
+// is handed to the dispatcher as the listener for the Peer's Id, and ActivePeers is incremented.
+func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+	log.Debugf("registering new peer")
 	if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
 		if id, found := meta[proto.HeaderId]; found {
 			p := peer.NewPeer(id[0], stream)
 
 			s.registry.Register(p)
+			s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
+
+			s.metrics.ActivePeers.Add(stream.Context(), 1)
 
 			return p, nil
 		} else {
@@ -180,3 +149,45 @@ func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*p
 		return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
 	}
 }
+
+// DeregisterPeer removes the Peer from the registry, records how long it was
+// connected (measured from p.RegisteredAt) and decrements the ActivePeers metric.
+func (s *Server) DeregisterPeer(p *peer.Peer) {
+	log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
+	s.registry.Deregister(p)
+
+	s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
+	s.metrics.ActivePeers.Add(context.Background(), -1)
+}
+
+// forwardMessageToPeer looks up msg.RemoteKey in the local registry and, if found,
+// sends msg on that peer's stream. Failures are logged and counted in metrics only;
+// nothing is reported back to the caller.
+// NOTE(review): GetRegistrationDelay and MessageForwardLatency are always tagged
+// labelTypeStream here, even when the call originates from Send — confirm intended.
+func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
+	log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
+
+	getRegistrationStart := time.Now()
+
+	// lookup the target peer where the message is going to
+	if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
+		s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
+		start := time.Now()
+		// forward the message to the target peer
+		if err := dstPeer.Stream.Send(msg); err != nil {
+			log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
+			// todo respond to the sender?
+			s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
+		} else {
+			// in milliseconds
+			s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
+			s.metrics.MessagesForwarded.Add(ctx, 1)
+		}
+	} else {
+		s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
+		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
+		log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
+		// todo respond to the sender?
+	}
+}