diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 0b29bf655d..aa5c69818a 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -61,7 +61,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.PacketConn, config webrtc.Configuration) (*listener, error) { +func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.PacketConn, config webrtc.Configuration, mux *udpmux.UDPMux) (*listener, error) { localFingerprints, err := config.Certificates[0].GetFingerprints() if err != nil { return nil, err @@ -91,9 +91,7 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack } l.ctx, l.cancel = context.WithCancel(context.Background()) - mux := udpmux.NewUDPMux(socket) l.mux = mux - mux.Start() go l.listen() @@ -284,7 +282,7 @@ func (l *listener) setupConnection( localMultiaddrWithoutCerthash, "", // remotePeer nil, // remoteKey - remoteMultiaddr, + remoteMultiaddr.Encapsulate(webrtcComponent), ) if err != nil { return nil, err @@ -321,11 +319,9 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } func (l *listener) Close() error { - select { - case <-l.ctx.Done(): - default: - l.cancel() - } + l.cancel() + l.mux.Close() + l.transport.RemoveMux(l.mux) return nil } diff --git a/p2p/transport/webrtc/reuseudpmux.go b/p2p/transport/webrtc/reuseudpmux.go new file mode 100644 index 0000000000..ad58f1f8e9 --- /dev/null +++ b/p2p/transport/webrtc/reuseudpmux.go @@ -0,0 +1,97 @@ +package libp2pwebrtc + +import ( + "fmt" + "net" + "sync" + + "github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux" + "github.com/libp2p/go-netroute" +) + +// reuseUDPMux provides ability to reuse listening udpMux for dialing. This helps with address +// discovery for nodes that don't have access to their public ip address +type reuseUDPMux struct { + mu sync.RWMutex + loopback map[int]*udpmux.UDPMux + specific map[string]map[int]*udpmux.UDPMux // IP.String() => Port => Mux + unspecified map[int]*udpmux.UDPMux +} + +// Put stores mux for reuse later in Get calls. +func (r *reuseUDPMux) Put(mux *udpmux.UDPMux) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, a := range mux.GetListenAddresses() { + udpAddr, err := net.ResolveUDPAddr(a.Network(), a.String()) + if err != nil { + return fmt.Errorf("udpmux ResolveUDPAddr failed for %s: %w", a, err) + } + if udpAddr.IP.IsLoopback() { + r.loopback[udpAddr.Port] = mux + continue + } + if udpAddr.IP.IsUnspecified() { + r.unspecified[udpAddr.Port] = mux + continue + } + if r.specific[udpAddr.IP.String()] == nil { + r.specific[udpAddr.IP.String()] = make(map[int]*udpmux.UDPMux) + } + r.specific[udpAddr.IP.String()][udpAddr.Port] = mux + } + return nil +} + +// Get retrieves a mux capable of dialing addr. Returns nil if no capable mux is present. If +// multiple muxes capable of dialing addr are available, it returns one arbitrarily +func (r *reuseUDPMux) Get(addr *net.UDPAddr) *udpmux.UDPMux { + r.mu.RLock() + defer r.mu.RUnlock() + if addr.IP.IsLoopback() { + for _, m := range r.loopback { + return m + } + } + if len(r.specific) > 0 { + if router, err := netroute.New(); err == nil { + if _, _, preferredSrc, err := router.Route(addr.IP); err == nil { + if len(r.specific[preferredSrc.String()]) != 0 { + for _, m := range r.specific[preferredSrc.String()] { + return m + } + } + } + } + } + for _, m := range r.unspecified { + return m + } + return nil +} + +// Delete removes a mux from the reuse pool. +func (r *reuseUDPMux) Delete(mux *udpmux.UDPMux) { + r.mu.Lock() + defer r.mu.Unlock() + for p, m := range r.loopback { + if mux == m { + delete(r.loopback, p) + } + } + for p, m := range r.unspecified { + if mux == m { + delete(r.unspecified, p) + } + } + for ip, mp := range r.specific { + for p, m := range mp { + if m == mux { + delete(mp, p) + } + } + if len(mp) == 0 { + delete(r.specific, ip) + } + } +} diff --git a/p2p/transport/webrtc/reuseudpmux_test.go b/p2p/transport/webrtc/reuseudpmux_test.go new file mode 100644 index 0000000000..484d42a2aa --- /dev/null +++ b/p2p/transport/webrtc/reuseudpmux_test.go @@ -0,0 +1,80 @@ +package libp2pwebrtc + +import ( + "net" + "testing" + + "github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux" + "github.com/stretchr/testify/require" +) + +func newReuseUDPMux(t *testing.T) reuseUDPMux { + return reuseUDPMux{ + loopback: make(map[int]*udpmux.UDPMux), + specific: make(map[string]map[int]*udpmux.UDPMux), + unspecified: make(map[int]*udpmux.UDPMux), + } +} + +func udpAddr(t *testing.T, s string) *net.UDPAddr { + a, err := net.ResolveUDPAddr("udp", s) + require.NoError(t, err) + return a +} + +func TestReuseUDPMuxLoopback(t *testing.T) { + socket, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer socket.Close() + r := newReuseUDPMux(t) + + mux := r.Get(udpAddr(t, "127.0.0.1:1")) + require.Nil(t, mux) + + originalMux := udpmux.NewUDPMux(socket) + err = r.Put(originalMux) + require.NoError(t, err) + + mux = r.Get(udpAddr(t, "127.0.0.1:1")) + require.Equal(t, originalMux, mux) + + mux = r.Get(udpAddr(t, "1.2.3.4:1")) + require.Nil(t, mux) + + r.Delete(originalMux) + mux = r.Get(udpAddr(t, "127.0.0.1:1")) + require.Nil(t, mux) +} + +func TestReuseUDPMuxUnspecified(t *testing.T) { + s1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + defer s1.Close() + + s2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + defer s2.Close() + + r := newReuseUDPMux(t) + + loMux := udpmux.NewUDPMux(s1) + err = r.Put(loMux) + require.NoError(t, err) + + mux := r.Get(udpAddr(t, "1.2.3.4:1")) + require.Nil(t, mux) + + unMux := udpmux.NewUDPMux(s2) + err = r.Put(unMux) + require.NoError(t, err) + + mux = r.Get(udpAddr(t, "127.0.0.1:1")) + require.Equal(t, loMux, mux) + + mux = r.Get(udpAddr(t, "1.2.3.4:1")) + require.Equal(t, unMux, mux) + + r.Delete(loMux) + mux = r.Get(udpAddr(t, "127.0.0.1:1")) + require.Equal(t, unMux, mux) +} diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index dd4028d1f2..677e60dbcd 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -35,6 +35,7 @@ import ( "github.com/libp2p/go-libp2p/core/sec" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/security/noise" + "github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -94,6 +95,9 @@ type WebRTCTransport struct { // in-flight connections maxInFlightConnections uint32 + + v4Reuse reuseUDPMux + v6Reuse reuseUDPMux } var _ tpt.Transport = &WebRTCTransport{} @@ -156,6 +160,16 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr }, maxInFlightConnections: DefaultMaxInFlightConnections, + v4Reuse: reuseUDPMux{ + loopback: make(map[int]*udpmux.UDPMux), + specific: make(map[string]map[int]*udpmux.UDPMux), + unspecified: make(map[int]*udpmux.UDPMux), + }, + v6Reuse: reuseUDPMux{ + loopback: make(map[int]*udpmux.UDPMux), + specific: make(map[string]map[int]*udpmux.UDPMux), + unspecified: make(map[int]*udpmux.UDPMux), + }, } for _, opt := range opts { if err := opt(transport); err != nil { @@ -197,7 +211,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return nil, fmt.Errorf("listen on udp: %w", err) } - listener, err := t.listenSocket(socket) + listener, err := t.listenSocket(socket, nw) if err != nil { socket.Close() return nil, err @@ -205,7 +219,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return listener, nil } -func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error) { +func (t *WebRTCTransport) listenSocket(socket *net.UDPConn, network string) (tpt.Listener, error) { listenerMultiaddr, err := manet.FromNetAddr(socket.LocalAddr()) if err != nil { return nil, err @@ -225,6 +239,12 @@ func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error if err != nil { return nil, err } + + mux, err := t.newMux(socket, network) + if err != nil { + return nil, err + } + listenerMultiaddr = listenerMultiaddr.Encapsulate(webrtcComponent).Encapsulate(certComp) return newListener( @@ -232,6 +252,7 @@ func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error listenerMultiaddr, socket, t.webrtcConfig, + mux, ) } @@ -306,6 +327,17 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement t.peerConnectionTimeouts.Failed, t.peerConnectionTimeouts.Keepalive, ) + + if rnw == "udp4" { + if mux := t.v4Reuse.Get(raddr); mux != nil { + settingEngine.SetICEUDPMux(mux) + } + } else { + if mux := t.v6Reuse.Get(raddr); mux != nil { + settingEngine.SetICEUDPMux(mux) + } + } + // By default, webrtc will not collect candidates on the loopback address. // This is disallowed in the ICE specification. However, implementations // do not strictly follow this, for eg. Chrome gathers TCP loopback candidates. @@ -387,6 +419,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, err } + localAddr = localAddr.Encapsulate(webrtcComponent) remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) @@ -521,6 +554,28 @@ func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerCon return secureConn.RemotePublicKey(), nil } +func (t *WebRTCTransport) newMux(socket *net.UDPConn, network string) (*udpmux.UDPMux, error) { + mux := udpmux.NewUDPMux(socket) + if network == "udp4" { + if err := t.v4Reuse.Put(mux); err != nil { + t.v4Reuse.Delete(mux) + return nil, err + } + } else { + if err := t.v6Reuse.Put(mux); err != nil { + t.v6Reuse.Delete(mux) + return nil, err + } + } + mux.Start() + return mux, nil +} + +func (t *WebRTCTransport) RemoveMux(mux *udpmux.UDPMux) { + t.v4Reuse.Delete(mux) + t.v6Reuse.Delete(mux) +} + type fakeStreamConn struct{ *stream } func (fakeStreamConn) LocalAddr() net.Addr { return nil } diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 7f4df94fc1..98643ba29d 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -711,3 +711,50 @@ func TestMaxInFlightRequests(t *testing.T) { require.Equal(t, count, int(success.Load()), "expected exactly 3 dial successes") require.Equal(t, 1, int(fails.Load()), "expected exactly 1 dial failure") } + +func TestTransportWebRTC_ReuseUDPMux(t *testing.T) { + tr, listeningPeer := getTransport(t) + listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") + + listener, err := tr.Listen(listenMultiaddr) + require.NoError(t, err) + + tr1, connectingPeer := getTransport(t) + listener1, err := tr1.Listen(listenMultiaddr) + // For dialing localhost this address should be preferred + addr1 := listener1.Multiaddr() + require.NoError(t, err) + defer listener1.Close() + + // For dialing localhost this address should be ignored + listener2, err := tr1.Listen(ma.StringCast("/ip4/0.0.0.0/udp/0/webrtc-direct")) + require.NoError(t, err) + defer listener2.Close() + + done := make(chan struct{}) + go func() { + conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) + assert.NoError(t, err) + _, err = conn.LocalMultiaddr().ValueForProtocol(ma.P_WEBRTC_DIRECT) + assert.NoError(t, err) + close(done) + }() + + conn, err := listener.Accept() + require.NoError(t, err) + require.NotNil(t, conn) + + require.Equal(t, connectingPeer, conn.RemotePeer()) + + // remote address on connection will not have /certhash + expectedAddr, _ := ma.SplitFunc(addr1, func(c ma.Component) bool { + return c.Protocol().Code == ma.P_CERTHASH + }) + require.Equal(t, conn.RemoteMultiaddr(), expectedAddr, "%s\n%s", conn.RemoteMultiaddr(), expectedAddr) + + select { + case <-done: + case <-time.After(10 * time.Second): + t.FailNow() + } +}