Skip to content

Commit

Permalink
Support DNS routes on iOS (#2254)
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored Jul 15, 2024
1 parent 58fbc12 commit 47752e1
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 26 deletions.
20 changes: 11 additions & 9 deletions client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
Expand Down Expand Up @@ -50,7 +51,7 @@ type DefaultManager struct {
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
Expand All @@ -65,7 +66,8 @@ func NewManager(
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)

dm := &DefaultManager{
ctx: mCTX,
Expand All @@ -77,7 +79,7 @@ func NewManager(
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
notifier: notifier,
}

dm.routeRefCounter = refcounter.New(
Expand Down Expand Up @@ -107,7 +109,7 @@ func NewManager(

if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
Expand Down Expand Up @@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro

filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)

if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
Expand All @@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
}
}

// SetRouteChangeListener set RouteListener for route change notifier
// SetRouteChangeListener set RouteListener for route change Notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener)
m.notifier.SetListener(listener)
}

// InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges()
return m.notifier.GetInitialRouteRanges()
}

// GetRouteSelector returns the route selector
Expand All @@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {

networks = m.routeSelector.FilterSelected(networks)

m.notifier.onNewRoutes(networks)
m.notifier.OnNewRoutes(networks)

m.stopObsoleteClients(networks)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routemanager
package notifier

import (
"net/netip"
"runtime"
"sort"
"strings"
Expand All @@ -10,25 +11,25 @@ import (
"github.com/netbirdio/netbird/route"
)

type notifier struct {
type Notifier struct {
initialRouteRanges []string
routeRanges []string

listener listener.NetworkChangeListener
listenerMux sync.Mutex
}

func newNotifier() *notifier {
return &notifier{}
func NewNotifier() *Notifier {
return &Notifier{}
}

func (n *notifier) setListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}

func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
nets = append(nets, r.Network.String())
Expand All @@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets
}

func (n *notifier) onNewRoutes(idMap route.HAMap) {
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
Expand All @@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
n.notify()
}

func (n *notifier) notify() {
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}

sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}

n.routeRanges = newNets

n.notify()
}

func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
Expand All @@ -74,7 +101,7 @@ func (n *notifier) notify() {
}(n.listener)
}

func (n *notifier) hasDiff(a []string, b []string) bool {
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
Expand All @@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false
}

func (n *notifier) getInitialRouteRanges() []string {
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}

Expand Down
13 changes: 12 additions & 1 deletion client/internal/routemanager/systemops/systemops.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package systemops
import (
"net"
"net/netip"
"sync"

"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
Expand All @@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
}

func NewSysOps(wgInterface *iface.WGIface) *SysOps {
func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:build ios || android
//go:build android

package systemops

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}

r := NewSysOps(nil)
r := NewSysOps(nil, nil)

var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")

r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)

_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}

r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)

// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
Expand Down Expand Up @@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})

r := NewSysOps(wgInterface)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
Expand Down
64 changes: 64 additions & 0 deletions client/internal/routemanager/systemops/systemops_ios.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//go:build ios

package systemops

import (
"net"
"net/netip"
"runtime"

log "github.com/sirupsen/logrus"

nbnet "github.com/netbirdio/netbird/util/net"
)

func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}

func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()

r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}

func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()

r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}

func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()

delete(r.prefixes, prefix)
r.notify()
return nil
}

func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}

func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}
28 changes: 27 additions & 1 deletion client/ios/NetBirdSDK/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)

Expand Down Expand Up @@ -47,6 +48,7 @@ type CustomLogger interface {
type selectRoute struct {
NetID string
Network netip.Prefix
Domains domain.List
Selected bool
}

Expand Down Expand Up @@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{
NetID: string(id),
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
Expand All @@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix
})

resolvedDomains := c.recorder.GetResolvedDomainsStates()

return prepareRouteSelectionDetails(routes, resolvedDomains), nil

}

func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected,
})
}

routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil
return &routeSelectionDetails
}

func (c *Client) SelectRoute(id string) error {
Expand Down
Loading

0 comments on commit 47752e1

Please sign in to comment.