Skip to content

Commit

Permalink
[client] Remove loop after route calculation (#2856)
Browse files Browse the repository at this point in the history
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
  • Loading branch information
pappz authored Nov 11, 2024
1 parent 08b6e9d commit b4d7605
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 84 deletions.
88 changes: 53 additions & 35 deletions client/internal/peer/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}

if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}

skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)

if receivedState.ConnStatus != peerState.ConnStatus {
Expand All @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
d.notifyPeerListChanged()
return nil
}

func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()

peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}

peerState.AddRoute(route)
d.peers[peer] = peerState

// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}

func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()

peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}

peerState.DeleteRoute(route)
d.peers[peer] = peerState

// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand All @@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()

ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
if found {
return ch
}

ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}

Expand Down Expand Up @@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}

// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}

close(ch)
delete(d.changeNotify, peerID)
}

func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/peer/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {

peerState.IP = ip

err := status.UpdatePeerState(peerState)
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")

select {
Expand Down
38 changes: 27 additions & 11 deletions client/internal/peer/worker_ice.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type WorkerICE struct {

localUfrag string
localPwd string

// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}

func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
Expand Down Expand Up @@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
return
}

err := w.agent.Close()
if err != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
Expand All @@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i

err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)

w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil

w.muxAgent.Unlock()
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
Expand All @@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}

func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()

cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}

func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
Expand Down
64 changes: 27 additions & 37 deletions client/internal/routemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}

// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
}

// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}

// higher latency is worse score
tempScore += 1 - latency.Seconds()

if !peerStatus.relayed {
Expand All @@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}

log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)

switch {
case chosen == "":
var peers []string
Expand Down Expand Up @@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
if found {
continue
}

closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}

func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}

if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
Expand All @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {

var merr *multierror.Error

if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
Expand Down Expand Up @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
Expand All @@ -268,35 +282,11 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}

c.addStateRoute()

return nil
}

func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}

state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}

func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}

state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}

func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
Expand Down
5 changes: 5 additions & 0 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {

// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()

return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
Expand Down

0 comments on commit b4d7605

Please sign in to comment.