diff --git a/core/network/network.go b/core/network/network.go index 0beaac0f71..47908b8e31 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -6,8 +6,10 @@ package network import ( + "bytes" "context" "io" + "sort" "time" "github.com/libp2p/go-libp2p/core/peer" @@ -184,3 +186,23 @@ type Dialer interface { Notify(Notifiee) StopNotify(Notifiee) } + +// DedupAddrs deduplicates addresses in place, leave only unique addresses. +// It doesn't allocate. +func DedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + if len(addrs) == 0 { + return addrs + } + sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 }) + idx := 1 + for i := 1; i < len(addrs); i++ { + if !addrs[i-1].Equal(addrs[i]) { + addrs[idx] = addrs[i] + idx++ + } + } + for i := idx; i < len(addrs); i++ { + addrs[i] = nil + } + return addrs[:idx] +} diff --git a/core/network/network_test.go b/core/network/network_test.go new file mode 100644 index 0000000000..a78e6d7044 --- /dev/null +++ b/core/network/network_test.go @@ -0,0 +1,36 @@ +package network + +import ( + "fmt" + "testing" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestDedupAddrs(t *testing.T) { + tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234") + quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") + wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws") + + type testcase struct { + in, out []ma.Multiaddr + } + + for i, tc := range []testcase{ + {in: nil, out: nil}, + {in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, + {in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, + {in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}}, + {in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}}, + } { + tc := tc + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + deduped := DedupAddrs(tc.in) + for _, a := range tc.out { + require.Contains(t, deduped, a) + } + }) + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index b328b74b0f..70c40bf18b 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -1,13 +1,11 @@ package basichost import ( - "bytes" "context" "errors" "fmt" "io" "net" - "sort" "sync" "time" @@ -816,26 +814,6 @@ func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { return addr } -// dedupAddrs deduplicates addresses in place, leave only unique addresses. -// It doesn't allocate. -func dedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - if len(addrs) == 0 { - return addrs - } - sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 }) - idx := 1 - for i := 1; i < len(addrs); i++ { - if !addrs[i-1].Equal(addrs[i]) { - addrs[idx] = addrs[i] - idx++ - } - } - for i := idx; i < len(addrs); i++ { - addrs[i] = nil - } - return addrs[:idx] -} - // AllAddrs returns all the addresses of BasicHost at this moment in time. // It's ok to not include addresses if they're not available to be used now. func (h *BasicHost) AllAddrs() []ma.Multiaddr { @@ -860,7 +838,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = append(finalAddrs, resolved...) } - finalAddrs = dedupAddrs(finalAddrs) + finalAddrs = network.DedupAddrs(finalAddrs) var natMappings []inat.Mapping @@ -1010,7 +988,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { } finalAddrs = append(finalAddrs, observedAddrs...) } - finalAddrs = dedupAddrs(finalAddrs) + finalAddrs = network.DedupAddrs(finalAddrs) finalAddrs = inferWebtransportAddrsFromQuic(finalAddrs) return finalAddrs diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 19ea2bd02f..5c1babd9d5 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -825,32 +825,6 @@ func TestNormalizeMultiaddr(t *testing.T) { require.Equal(t, "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", h1.NormalizeMultiaddr(ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")).String()) } -func TestDedupAddrs(t *testing.T) { - tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234") - quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1") - wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws") - - type testcase struct { - in, out []ma.Multiaddr - } - - for i, tc := range []testcase{ - {in: nil, out: nil}, - {in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, - {in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}}, - {in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}}, - {in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}}, - } { - tc := tc - t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { - deduped := dedupAddrs(tc.in) - for _, a := range tc.out { - require.Contains(t, deduped, a) - } - }) - } -} - func TestInferWebtransportAddrsFromQuic(t *testing.T) { type testCase struct { name string diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 4e37814ce3..02a7d63269 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -180,7 +180,12 @@ loop: case <-w.triggerDial: for _, addr := range w.nextDial { // spawn the dial - ad := w.pending[string(addr.Bytes())] + ad, ok := w.pending[string(addr.Bytes())] + if !ok { + log.Warn("unexpectedly missing pending addrDial for addr") + // Assume nothing to dial here + continue + } err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) if err != nil { w.dispatchError(ad, err) @@ -195,7 +200,12 @@ loop: w.connected = true } - ad := w.pending[string(res.Addr.Bytes())] + ad, ok := w.pending[string(res.Addr.Bytes())] + if !ok { + log.Warn("unexpectedly missing pending addrDial res") + // Assume nothing to do here + continue + } if res.Conn != nil { // we got a connection, add it to the swarm diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 5423a199b7..49c0fc7fd9 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -334,6 +334,7 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) } + goodAddrs = network.DedupAddrs(goodAddrs) if len(goodAddrs) == 0 { return nil, ErrNoGoodAddresses diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 566a2307f4..215ee6df9f 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -65,6 +65,51 @@ func TestAddrsForDial(t *testing.T) { require.NotZero(t, len(mas)) } +func TestDedupAddrsForDial(t *testing.T) { + mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)} + ipaddr, err := net.ResolveIPAddr("ip4", "1.2.3.4") + if err != nil { + t.Fatal(err) + } + mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr} + + resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver)) + if err != nil { + t.Fatal(err) + } + + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(id, priv.GetPublic()) + ps.AddPrivKey(id, priv) + t.Cleanup(func() { ps.Close() }) + + s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver)) + require.NoError(t, err) + defer s.Close() + + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + require.NoError(t, err) + err = s.AddTransport(tpt) + require.NoError(t, err) + + otherPeer := test.RandPeerIDFatal(t) + + ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234"), time.Hour) + ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour) + + ctx := context.Background() + mas, err := s.addrsForDial(ctx, otherPeer) + require.NoError(t, err) + + require.Equal(t, 1, len(mas)) +} + func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { priv, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) diff --git a/version.json b/version.json index c936d576a5..824d184e80 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.27.4" + "version": "v0.27.5" }