From 5f638d68baf5753b40169d12ce47794320b4befd Mon Sep 17 00:00:00 2001 From: rkonfj Date: Wed, 4 Sep 2024 20:34:04 +0800 Subject: [PATCH] disco: add UPDATE_NAT_INFO control code --- disco/disco.go | 8 ++++++++ disco/tp/udp.go | 37 ++++++++++++++++++++++++++----------- disco/tp/ws.go | 15 +++++++++++++++ p2p/conn.go | 5 +++++ peermap/peermap.go | 36 +++++++++++++----------------------- 5 files changed, 67 insertions(+), 34 deletions(-) diff --git a/disco/disco.go b/disco/disco.go index 3298763..47c8ebe 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -26,6 +26,8 @@ func (code ControlCode) String() string { return "LEAD_DISCO" case CONTROL_UPDATE_NETWORK_SECRET: return "UPDATE_NETWORK_SECRET" + case CONTROL_UPDATE_NAT_INFO: + return "UPDATE_NAT_INFO" case CONTROL_CONN: return "CONTROL_CONN" default: @@ -43,6 +45,7 @@ const ( CONTROL_NEW_PEER_UDP_ADDR ControlCode = 2 CONTROL_LEAD_DISCO ControlCode = 3 CONTROL_UPDATE_NETWORK_SECRET ControlCode = 20 + CONTROL_UPDATE_NAT_INFO ControlCode = 21 CONTROL_CONN ControlCode = 30 ) @@ -98,6 +101,11 @@ const ( Internal NATType = "internal" ) +type NATInfo struct { + Type NATType + Addrs []*net.UDPAddr +} + type Disco struct { Magic func() []byte } diff --git a/disco/tp/udp.go b/disco/tp/udp.go index 2e7f674..67c33ea 100644 --- a/disco/tp/udp.go +++ b/disco/tp/udp.go @@ -51,11 +51,6 @@ func SetModifyDiscoConfig(modify func(cfg *DiscoConfig)) { defaultDiscoConfig.ChallengesBackoffRate = max(1, defaultDiscoConfig.ChallengesBackoffRate) } -type NATInfo struct { - Type disco.NATType - Addrs []*net.UDPAddr -} - var ( ErrUDPConnNotReady = errors.New("udpConn not ready yet") @@ -79,6 +74,7 @@ type UDPConn struct { disco *disco.Disco closedSig chan int datagrams chan *disco.Datagram + natEvents chan *disco.NATInfo udpAddrSends chan *disco.PeerUDPAddr peersIndex map[disco.PeerID]*peerkeeper @@ -86,7 +82,7 @@ type UDPConn struct { stunResponseMapMutex sync.RWMutex stunResponseMap map[string]chan stunResponse // key is stun txid - natInfo atomic.Pointer[NATInfo] + natInfo atomic.Pointer[disco.NATInfo] upnpDeleteMapping func() } @@ -102,6 +98,7 @@ func (c *UDPConn) Close() error { c.udpConnsMutex.RUnlock() close(c.closedSig) + close(c.natEvents) close(c.datagrams) close(c.udpAddrSends) return nil @@ -129,6 +126,10 @@ func (c *UDPConn) SetWriteBuffer(bytes int) error { return nil } +func (c *UDPConn) NATEvents() <-chan *disco.NATInfo { + return c.natEvents +} + func (c *UDPConn) Datagrams() <-chan *disco.Datagram { return c.datagrams } @@ -158,11 +159,16 @@ func (c *UDPConn) GenerateLocalAddrsSends(peerID disco.PeerID, stunServers []str continue } c.upnpDeleteMapping = func() { nat.DeletePortMapping("udp", mappedPort, udpPort) } + addr := &net.UDPAddr{IP: externalIP, Port: mappedPort} c.udpAddrSends <- &disco.PeerUDPAddr{ ID: peerID, - Addr: &net.UDPAddr{IP: externalIP, Port: mappedPort}, + Addr: addr, Type: disco.UPnP, } + select { + case c.natEvents <- &disco.NATInfo{Type: disco.UPnP, Addrs: []*net.UDPAddr{addr}}: + default: + } return } }() @@ -181,6 +187,10 @@ func (c *UDPConn) GenerateLocalAddrsSends(peerID disco.PeerID, stunServers []str } else { natType = disco.IP6 } + select { + case c.natEvents <- &disco.NATInfo{Type: natType, Addrs: []*net.UDPAddr{uaddr}}: + default: + } } c.udpAddrSends <- &disco.PeerUDPAddr{ ID: peerID, @@ -247,10 +257,14 @@ func (c *UDPConn) RoundTripSTUN(stunServer string) (*net.UDPAddr, error) { } } -func (c *UDPConn) DetectNAT(stunServers []string) (info NATInfo) { +func (c *UDPConn) DetectNAT(stunServers []string) (info disco.NATInfo) { defer func() { slog.Log(context.Background(), -1, "[NAT] DetectNAT", "type", info.Type) c.natInfo.Store(&info) + select { + case c.natEvents <- &info: + default: + } if info.Type == disco.Hard { if lastNATInfo := c.natInfo.Load(); lastNATInfo == nil || lastNATInfo.Type != disco.Hard { c.RestartListener() @@ -281,15 +295,15 @@ func (c *UDPConn) DetectNAT(stunServers []string) (info NATInfo) { wg.Wait() if len(udpAddrs) <= 1 { - return NATInfo{Type: disco.Unknown, Addrs: udpAddrs} + return disco.NATInfo{Type: disco.Unknown, Addrs: udpAddrs} } lastAddr := udpAddrs[0].String() for _, addr := range udpAddrs { if lastAddr != addr.String() { - return NATInfo{Type: disco.Hard, Addrs: udpAddrs} + return disco.NATInfo{Type: disco.Hard, Addrs: udpAddrs} } } - return NATInfo{Type: disco.Easy, Addrs: udpAddrs} + return disco.NATInfo{Type: disco.Easy, Addrs: udpAddrs} } func (c *UDPConn) RunDiscoMessageSendLoop(udpAddr disco.PeerUDPAddr) { @@ -698,6 +712,7 @@ func ListenUDP(cfg UDPConfig) (*UDPConn, error) { cfg: cfg, disco: &disco.Disco{Magic: cfg.DiscoMagic}, closedSig: make(chan int), + natEvents: make(chan *disco.NATInfo), datagrams: make(chan *disco.Datagram), udpAddrSends: make(chan *disco.PeerUDPAddr, 10), peersIndex: make(map[disco.PeerID]*peerkeeper), diff --git a/disco/tp/ws.go b/disco/tp/ws.go index 53f95d9..0be916d 100644 --- a/disco/tp/ws.go +++ b/disco/tp/ws.go @@ -139,6 +139,21 @@ func (c *WSConn) LeadDisco(peerID disco.PeerID) error { return c.WriteTo(nil, peerID, disco.CONTROL_LEAD_DISCO) } +func (c *WSConn) UpdateNATInfo(natInfo disco.NATInfo) error { + if natInfo.Type == disco.Hard { + return nil + } + if natInfo.Type == disco.Easy { + natInfo.Addrs = natInfo.Addrs[:1] + } + controlPacket := []byte{byte(disco.CONTROL_UPDATE_NAT_INFO), 0} + b, err := json.Marshal(natInfo) + if err != nil { + return fmt.Errorf("marshal nat info: %w", err) + } + return c.write(append(controlPacket, b...)) +} + func (c *WSConn) Datagrams() <-chan *disco.Datagram { return c.datagrams } diff --git a/p2p/conn.go b/p2p/conn.go index 29b0eb8..9ea500d 100644 --- a/p2p/conn.go +++ b/p2p/conn.go @@ -279,6 +279,11 @@ func (c *PeerPacketConn) runControlEventLoop() { if onPeer := c.cfg.OnPeer; onPeer != nil { go onPeer(peer.ID, peer.Metadata) } + case natEvent, ok := <-c.udpConn.NATEvents(): + if !ok { + return + } + go c.wsConn.UpdateNATInfo(*natEvent) case revcUDPAddr, ok := <-c.wsConn.PeersUDPAddrs(): if !ok { return diff --git a/peermap/peermap.go b/peermap/peermap.go index 7b60088..81f88e8 100644 --- a/peermap/peermap.go +++ b/peermap/peermap.go @@ -9,7 +9,6 @@ import ( "io" "log/slog" "math" - "net" "net/http" "net/url" "os" @@ -216,6 +215,10 @@ func (p *peerConn) readMessageLoop() { p.connData <- b[1:] continue } + if b[0] == disco.CONTROL_UPDATE_NAT_INFO.Byte() { + p.updateNATInfo(b) + continue + } tgtPeerID := disco.PeerID(b[2 : b[1]+2]) slog.Debug("PeerEvent", "op", disco.ControlCode(b[0]), "from", p.id, "to", tgtPeerID) tgtPeer, err := p.peerMap.getPeer(p.networkSecret.Network, tgtPeerID) @@ -223,13 +226,10 @@ func (p *peerConn) readMessageLoop() { slog.Debug("FindPeer failed", "detail", err) continue } - if disco.ControlCode(b[0]) == disco.CONTROL_LEAD_DISCO { + if b[0] == disco.CONTROL_LEAD_DISCO.Byte() { p.leadDisco(tgtPeer) continue } - if disco.ControlCode(b[0]) == disco.CONTROL_NEW_PEER_UDP_ADDR { - p.updatePeerUDPAddr(b) - } data := b[b[1]+2:] bb := make([]byte, 2+len(p.id)+len(data)) bb[0] = b[0] @@ -241,26 +241,16 @@ func (p *peerConn) readMessageLoop() { } } -func (p *peerConn) updatePeerUDPAddr(b []byte) { - if b[b[1]+2] != 'a' { - return - } - addrLen := b[b[1]+3] - s := b[1] + 4 - addr, err := net.ResolveUDPAddr("udp", string(b[s:s+addrLen])) - if err != nil { - slog.Error("Resolve udp addr error", "err", err) +func (p *peerConn) updateNATInfo(b []byte) { + var natInfo disco.NATInfo + if err := json.Unmarshal(b[2:], &natInfo); err != nil { + slog.Error("UpdateNATInfo", "peer", p.id, "err", err) return } - natType := disco.NATType(b[s+addrLen:]) - slog.Debug("ExchangeUDPAddr", "nat", natType, "addr", addr.String()) - if slices.Contains([]disco.NATType{disco.Easy, disco.IP6, disco.IP4}, natType) { - if natType.AccurateThan(disco.NATType(p.metadata.Get("nat"))) { - p.metadata.Set("nat", natType.String()) - } - if !slices.Contains(p.metadata["addr"], addr.String()) { - p.metadata.Add("addr", addr.String()) - } + p.metadata.Del("addr") + p.metadata.Set("nat", natInfo.Type.String()) + for _, addr := range natInfo.Addrs { + p.metadata.Add("addr", addr.String()) } }