From beaaa3c53a72840b463f9ae24eb7a2685053160d Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Thu, 11 Apr 2024 14:12:23 -0600 Subject: [PATCH] Add safe read/write to route map (#1760) --- client/internal/engine.go | 1 + client/internal/peer/conn.go | 5 +++- client/internal/peer/status.go | 40 +++++++++++++++++++++++--- client/internal/peer/status_test.go | 5 ++++ client/internal/routemanager/client.go | 7 ++--- client/server/server.go | 2 +- 6 files changed, 49 insertions(+), 11 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index d6238c4b3ca..ba7074672c3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusDisconnected, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } } e.statusRecorder.ReplaceOfflinePeers(replacement) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad1f..9e7ee695932 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error { } conn.agent, err = ice.NewAgent(agentConfig) - if err != nil { return err } @@ -285,6 +284,7 @@ func (conn *Conn) Open() error { IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], ConnStatusUpdate: time.Now(), ConnStatus: conn.status, + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -344,6 +344,7 @@ func (conn *Conn) Open() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err = conn.statusRecorder.UpdatePeerState(peerState) if err != nil { @@ -468,6 +469,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()), Direct: !isRelayCandidate(pair.Local), RosenpassEnabled: rosenpassEnabled, + Mux: new(sync.RWMutex), } if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { peerState.Relayed = true @@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error { PubKey: conn.config.Key, ConnStatus: conn.status, ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index ca97c3ea497..ddea7d04e16 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -14,6 +14,7 @@ import ( // State contains the latest state of a peer type State struct { + Mux *sync.RWMutex IP string PubKey string FQDN string @@ -30,7 +31,38 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool - Routes map[string]struct{} + routes map[string]struct{} +} + +// AddRoute add a single route to routes map +func (s *State) AddRoute(network string) { + s.Mux.Lock() + if s.routes == nil { + s.routes = make(map[string]struct{}) + } + s.routes[network] = struct{}{} + s.Mux.Unlock() +} + +// SetRoutes set state routes +func (s *State) SetRoutes(routes map[string]struct{}) { + s.Mux.Lock() + s.routes = routes + s.Mux.Unlock() +} + +// DeleteRoute removes a route from the network amp +func (s *State) DeleteRoute(network string) { + s.Mux.Lock() + delete(s.routes, network) + s.Mux.Unlock() +} + +// GetRoutes return routes map +func (s *State) GetRoutes() map[string]struct{} { + s.Mux.RLock() + defer s.Mux.RUnlock() + return s.routes } // LocalPeerState contains the latest state of the local peer @@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error { PubKey: peerPubKey, ConnStatus: StatusDisconnected, FQDN: fqdn, + Mux: new(sync.RWMutex), } d.peerListChangedForNotification = true return nil @@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.IP = receivedState.IP } - if receivedState.Routes != nil { - peerState.Routes = receivedState.Routes + if receivedState.GetRoutes() != nil { + peerState.SetRoutes(receivedState.GetRoutes()) } skipNotification := shouldSkipNotify(receivedState, peerState) @@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool { s, ok := gstatus.FromError(d.managementError) if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return true - } return false } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 9038371bd1c..a4a6e608132 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -3,6 +3,7 @@ package peer import ( "errors" "testing" + "sync" "github.com/stretchr/testify/assert" ) @@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, + Mux: new(sync.RWMutex), } status.peers[key] = peerState diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 370ad5cf44b..d41ed422b81 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { return fmt.Errorf("get peer state: %v", err) } - delete(state.Routes, c.network.String()) + state.DeleteRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } @@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { if err != nil { log.Errorf("Failed to get peer state: %v", err) } else { - if state.Routes == nil { - state.Routes = map[string]struct{}{} - } - state.Routes[c.network.String()] = struct{}{} + state.AddRoute(c.network.String()) if err := c.statusRecorder.UpdatePeerState(state); err != nil { log.Warnf("Failed to update peer state: %v", err) } diff --git a/client/server/server.go b/client/server/server.go index d1d9dbda451..d33bb515582 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.Routes), + Routes: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)