From 5d1d35f8513ff7e40e964faecae5071644414f6f Mon Sep 17 00:00:00 2001 From: nisdas Date: Tue, 1 Nov 2022 20:27:31 +0800 Subject: [PATCH] Revert "fix: don't prefer local ports from other addresses when dialing (#1673)" This reverts commit bbd28365c5d7d86b259403a379fd34a8279e4b4b. --- p2p/net/reuseport/dial.go | 57 +++++++++++++- p2p/net/reuseport/dialer.go | 114 ---------------------------- p2p/net/reuseport/multidialer.go | 90 ++++++++++++++++++++++ p2p/net/reuseport/singledialer.go | 16 ++++ p2p/net/reuseport/transport.go | 2 +- p2p/net/reuseport/transport_test.go | 8 ++ 6 files changed, 169 insertions(+), 118 deletions(-) delete mode 100644 p2p/net/reuseport/dialer.go create mode 100644 p2p/net/reuseport/multidialer.go create mode 100644 p2p/net/reuseport/singledialer.go diff --git a/p2p/net/reuseport/dial.go b/p2p/net/reuseport/dial.go index 6a3d18ff21..b998be7d29 100644 --- a/p2p/net/reuseport/dial.go +++ b/p2p/net/reuseport/dial.go @@ -2,11 +2,18 @@ package reuseport import ( "context" + "net" + "github.com/libp2p/go-reuseport" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) +type dialer interface { + Dial(network, addr string) (net.Conn, error) + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + // Dial dials the given multiaddr, reusing ports we're currently listening on if // possible. // @@ -24,7 +31,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. if err != nil { return nil, err } - var d *dialer + var d dialer switch network { case "tcp4": d = t.v4.getDialer(network) @@ -45,7 +52,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. return maconn, nil } -func (n *network) getDialer(network string) *dialer { +func (n *network) getDialer(network string) dialer { n.mu.RLock() d := n.dialer n.mu.RUnlock() @@ -54,9 +61,53 @@ func (n *network) getDialer(network string) *dialer { defer n.mu.Unlock() if n.dialer == nil { - n.dialer = newDialer(n.listeners) + n.dialer = n.makeDialer(network) } d = n.dialer } return d } + +func (n *network) makeDialer(network string) dialer { + if !reuseport.Available() { + log.Debug("reuseport not available") + return &net.Dialer{} + } + + var unspec net.IP + switch network { + case "tcp4": + unspec = net.IPv4zero + case "tcp6": + unspec = net.IPv6unspecified + default: + panic("invalid network: must be either tcp4 or tcp6") + } + + // How many ports are we listening on. + var port = 0 + for l := range n.listeners { + newPort := l.Addr().(*net.TCPAddr).Port + switch { + case newPort == 0: // Any port, ignore (really, we shouldn't get this case...). + case port == 0: // Haven't selected a port yet, choose this one. + port = newPort + case newPort == port: // Same as the selected port, continue... + default: // Multiple ports, use the multi dialer + return newMultiDialer(unspec, n.listeners) + } + } + + // None. + if port == 0 { + return &net.Dialer{} + } + + // One. Always dial from the single port we're listening on. + laddr := &net.TCPAddr{ + IP: unspec, + Port: port, + } + + return (*singleDialer)(laddr) +} diff --git a/p2p/net/reuseport/dialer.go b/p2p/net/reuseport/dialer.go deleted file mode 100644 index 2efc02d393..0000000000 --- a/p2p/net/reuseport/dialer.go +++ /dev/null @@ -1,114 +0,0 @@ -package reuseport - -import ( - "context" - "fmt" - "math/rand" - "net" - - "github.com/libp2p/go-netroute" -) - -type dialer struct { - // All address that are _not_ loopback or unspecified (0.0.0.0 or ::). - specific []*net.TCPAddr - // All loopback addresses (127.*.*.*, ::1). - loopback []*net.TCPAddr - // Unspecified addresses (0.0.0.0, ::) - unspecified []*net.TCPAddr -} - -func (d *dialer) Dial(network, addr string) (net.Conn, error) { - return d.DialContext(context.Background(), network, addr) -} - -func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { - if len(addrs) > 0 { - return addrs[rand.Intn(len(addrs))] - } - return nil -} - -// DialContext dials a target addr. -// -// In-order: -// -// 1. If we're _explicitly_ listening on the prefered source address for the destination address -// (per the system's routes), we'll use that listener's port as the source port. -// 2. If we're listening on one or more _unspecified_ addresses (zero address), we'll pick a source -// port from one of these listener's. -// 3. Otherwise, we'll let the system pick the source port. -func (d *dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - // We only check this case if the user is listening on a specific address (loopback or - // otherwise). Generally, users will listen on the "unspecified" address (0.0.0.0 or ::) and - // we can skip this section. - // - // This lets us avoid resolving the address twice, in most cases. - if len(d.specific) > 0 || len(d.loopback) > 0 { - tcpAddr, err := net.ResolveTCPAddr(network, addr) - if err != nil { - return nil, err - } - ip := tcpAddr.IP - if !ip.IsLoopback() && !ip.IsGlobalUnicast() { - return nil, fmt.Errorf("undialable IP: %s", ip) - } - - // If we're listening on some specific address and that specific address happens to - // be the preferred source address for the target destination address, we try to - // dial with that address/port. - // - // We skip this check if we _aren't_ listening on any specific addresses, because - // checking routing tables can be expensive and users rarely listen on specific IP - // addresses. - if len(d.specific) > 0 { - if router, err := netroute.New(); err == nil { - if _, _, preferredSrc, err := router.Route(ip); err == nil { - for _, optAddr := range d.specific { - if optAddr.IP.Equal(preferredSrc) { - return reuseDial(ctx, optAddr, network, addr) - } - } - } - } - } - - // Otherwise, if we are listening on a loopback address and the destination is also - // a loopback address, use the port from our loopback listener. - if len(d.loopback) > 0 && ip.IsLoopback() { - return reuseDial(ctx, randAddr(d.loopback), network, addr) - } - } - - // If we're listening on any uspecified addresses, use a randomly chosen port from one of - // these listeners. - if len(d.unspecified) > 0 { - return reuseDial(ctx, randAddr(d.unspecified), network, addr) - } - - // Finally, just pick a random port. - var dialer net.Dialer - return dialer.DialContext(ctx, network, addr) -} - -func newDialer(listeners map[*listener]struct{}) *dialer { - specific := make([]*net.TCPAddr, 0) - loopback := make([]*net.TCPAddr, 0) - unspecified := make([]*net.TCPAddr, 0) - - for l := range listeners { - addr := l.Addr().(*net.TCPAddr) - if addr.IP.IsLoopback() { - loopback = append(loopback, addr) - } else if addr.IP.IsUnspecified() { - unspecified = append(unspecified, addr) - } else { - specific = append(specific, addr) - } - } - return &dialer{ - specific: specific, - loopback: loopback, - unspecified: unspecified, - } -} diff --git a/p2p/net/reuseport/multidialer.go b/p2p/net/reuseport/multidialer.go new file mode 100644 index 0000000000..a3b5e2e99f --- /dev/null +++ b/p2p/net/reuseport/multidialer.go @@ -0,0 +1,90 @@ +package reuseport + +import ( + "context" + "fmt" + "math/rand" + "net" + + "github.com/libp2p/go-netroute" +) + +type multiDialer struct { + listeningAddresses []*net.TCPAddr + loopback []*net.TCPAddr + unspecified []*net.TCPAddr + fallback net.TCPAddr +} + +func (d *multiDialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { + if len(addrs) > 0 { + return addrs[rand.Intn(len(addrs))] + } + return nil +} + +// DialContext dials a target addr. +// Dialing preference is +// * If there is a listener on the local interface the OS expects to use to route towards addr, use that. +// * If there is a listener on a loopback address, addr is loopback, use that. +// * If there is a listener on an undefined address (0.0.0.0 or ::), use that. +// * Use the fallback IP specified during construction, with a port that's already being listened on, if one exists. +func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + tcpAddr, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return nil, err + } + ip := tcpAddr.IP + if !ip.IsLoopback() && !ip.IsGlobalUnicast() { + return nil, fmt.Errorf("undialable IP: %s", ip) + } + + if router, err := netroute.New(); err == nil { + if _, _, preferredSrc, err := router.Route(ip); err == nil { + for _, optAddr := range d.listeningAddresses { + if optAddr.IP.Equal(preferredSrc) { + return reuseDial(ctx, optAddr, network, addr) + } + } + } + } + + if ip.IsLoopback() && len(d.loopback) > 0 { + return reuseDial(ctx, randAddr(d.loopback), network, addr) + } + if len(d.unspecified) == 0 { + return reuseDial(ctx, &d.fallback, network, addr) + } + + return reuseDial(ctx, randAddr(d.unspecified), network, addr) +} + +func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) (m dialer) { + addrs := make([]*net.TCPAddr, 0) + loopback := make([]*net.TCPAddr, 0) + unspecified := make([]*net.TCPAddr, 0) + existingPort := 0 + + for l := range listeners { + addr := l.Addr().(*net.TCPAddr) + addrs = append(addrs, addr) + if addr.IP.IsLoopback() { + loopback = append(loopback, addr) + } else if addr.IP.IsGlobalUnicast() && existingPort == 0 { + existingPort = addr.Port + } else if addr.IP.IsUnspecified() { + unspecified = append(unspecified, addr) + } + } + m = &multiDialer{ + listeningAddresses: addrs, + loopback: loopback, + unspecified: unspecified, + fallback: net.TCPAddr{IP: unspec, Port: existingPort}, + } + return +} diff --git a/p2p/net/reuseport/singledialer.go b/p2p/net/reuseport/singledialer.go new file mode 100644 index 0000000000..b15dae80b9 --- /dev/null +++ b/p2p/net/reuseport/singledialer.go @@ -0,0 +1,16 @@ +package reuseport + +import ( + "context" + "net" +) + +type singleDialer net.TCPAddr + +func (d *singleDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *singleDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return reuseDial(ctx, (*net.TCPAddr)(d), network, address) +} diff --git a/p2p/net/reuseport/transport.go b/p2p/net/reuseport/transport.go index ba7b9debf7..37fb446cb7 100644 --- a/p2p/net/reuseport/transport.go +++ b/p2p/net/reuseport/transport.go @@ -31,5 +31,5 @@ type Transport struct { type network struct { mu sync.RWMutex listeners map[*listener]struct{} - dialer *dialer + dialer dialer } diff --git a/p2p/net/reuseport/transport_test.go b/p2p/net/reuseport/transport_test.go index 88f9cdb98f..b99b583824 100644 --- a/p2p/net/reuseport/transport_test.go +++ b/p2p/net/reuseport/transport_test.go @@ -141,6 +141,7 @@ func TestGlobalPreferenceV4(t *testing.T) { testPrefer(t, loopbackV4, loopbackV4, globalV4) t.Logf("when listening on %v, should prefer %v over %v", loopbackV4, unspecV4, globalV4) testPrefer(t, loopbackV4, unspecV4, globalV4) + t.Logf("when listening on %v, should prefer %v over %v", globalV4, unspecV4, loopbackV4) testPrefer(t, globalV4, unspecV4, loopbackV4) } @@ -176,6 +177,8 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { } defer listenerB1.Close() + dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port) + listenerB2, err := trB.Listen(prefer) if err != nil { t.Fatal(err) @@ -183,6 +186,11 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { defer listenerB2.Close() dialOne(t, &trB, listenerA, listenerB2.Addr().(*net.TCPAddr).Port) + + // Closing the listener should reset the dialer. + listenerB2.Close() + + dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port) } func TestV6V4(t *testing.T) {