diff --git a/libp2p_test.go b/libp2p_test.go index 0b81e33bb5..7d947afb82 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -329,3 +329,30 @@ func TestTransportCustomAddressWebTransport(t *testing.T) { require.Equal(t, secondToLastComp.Protocol().Code, ma.P_CERTHASH) require.True(t, restOfAddr.Equal(customAddr)) } + +// TestTransportCustomAddressWebTransportDoesNotStall tests that if the user +// manually returns a webtransport address from AddrsFactory, but we aren't +// listening on a webtranport address, we don't stall. +func TestTransportCustomAddressWebTransportDoesNotStall(t *testing.T) { + customAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1/webtransport") + if err != nil { + t.Fatal(err) + } + h, err := New( + Transport(webtransport.New), + // Purposely not listening on the custom address so that we make sure the node doesn't stall if it fails to add a certhash to the multiaddr + // ListenAddrs(customAddr), + DisableRelay(), + AddrsFactory(func(multiaddrs []ma.Multiaddr) []ma.Multiaddr { + return []ma.Multiaddr{customAddr} + }), + ) + require.NoError(t, err) + defer h.Close() + addrs := h.Addrs() + require.Len(t, addrs, 1) + _, lastComp := ma.SplitLast(addrs[0]) + require.NotEqual(t, lastComp.Protocol().Code, ma.P_CERTHASH) + // We did not add the certhash to the multiaddr + require.Equal(t, addrs[0], customAddr) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index ce82ad331d..b328b74b0f 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -771,7 +771,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { } type addCertHasher interface { - AddCertHashes(m ma.Multiaddr) ma.Multiaddr + AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) } addrs := h.AddrsFactory(h.AllAddrs()) @@ -793,7 +793,11 @@ func (h *BasicHost) Addrs() []ma.Multiaddr { if !ok { continue } - addrs[i] = tpt.AddCertHashes(addr) + addrWithCerthash, added := tpt.AddCertHashes(addr) + addrs[i] = addrWithCerthash + if !added { + log.Debug("Couldn't add certhashes to webtransport multiaddr because we aren't listening on webtransport") + } } } return addrs diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index e44ac33252..f9c68ddf30 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/connmgr" @@ -68,12 +69,12 @@ type transport struct { rcmgr network.ResourceManager gater connmgr.ConnectionGater - listenOnce sync.Once - listenOnceErr error - certManager *certManager - certManagerReady chan struct{} // Closed when the certManager has been instantiated. - staticTLSConf *tls.Config - tlsClientConf *tls.Config + listenOnce sync.Once + listenOnceErr error + certManager *certManager + hasCertManager atomic.Bool // set to true once the certManager is initialized + staticTLSConf *tls.Config + tlsClientConf *tls.Config noise *noise.Transport @@ -98,14 +99,13 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - connManager: connManager, - conns: map[uint64]*conn{}, - certManagerReady: make(chan struct{}), + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[uint64]*conn{}, } for _, opt := range opts { if err := opt(t); err != nil { @@ -300,13 +300,12 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if t.staticTLSConf == nil { t.listenOnce.Do(func() { t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock) - close(t.certManagerReady) + t.hasCertManager.Store(true) }) if t.listenOnceErr != nil { return nil, t.listenOnceErr } } else { - close(t.certManagerReady) return nil, errors.New("static TLS config not supported on WebTransport") } tlsConf := t.staticTLSConf.Clone() @@ -405,10 +404,11 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad return []ma.Multiaddr{beforeQuicMA.Encapsulate(quicComponent).Encapsulate(sniComponent).Encapsulate(afterQuicMA)}, nil } -func (t *transport) AddCertHashes(m ma.Multiaddr) ma.Multiaddr { - <-t.certManagerReady - if t.certManager == nil { - return m +// AddCertHashes adds the current certificate hashes to a multiaddress. +// If called before Listen, it's a no-op. +func (t *transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) { + if !t.hasCertManager.Load() { + return m, false } - return m.Encapsulate(t.certManager.AddrComponent()) + return m.Encapsulate(t.certManager.AddrComponent()), true }