From b6f1fb08405b1e3fecf3a720ba3aa6a96c75dcb4 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Wed, 25 Sep 2024 11:55:26 -0400 Subject: [PATCH] feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS, and WSS transports --- p2p/transport/tcp/tcp.go | 32 ++- p2p/transport/tcp/tcp_test.go | 13 +- p2p/transport/tcpreuse/demultiplex.go | 240 ++++++++++++++++++ p2p/transport/tcpreuse/demultiplex_test.go | 50 ++++ p2p/transport/tcpreuse/dialer.go | 16 ++ p2p/transport/tcpreuse/listener.go | 250 +++++++++++++++++++ p2p/transport/{tcp => tcpreuse}/reuseport.go | 10 +- p2p/transport/websocket/addrs_test.go | 2 +- p2p/transport/websocket/listener.go | 34 ++- p2p/transport/websocket/websocket.go | 12 +- 10 files changed, 635 insertions(+), 24 deletions(-) create mode 100644 p2p/transport/tcpreuse/demultiplex.go create mode 100644 p2p/transport/tcpreuse/demultiplex_test.go create mode 100644 p2p/transport/tcpreuse/dialer.go create mode 100644 p2p/transport/tcpreuse/listener.go rename p2p/transport/{tcp => tcpreuse}/reuseport.go (81%) diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d52bb96019..66fe9b7631 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -33,6 +34,9 @@ type canKeepAlive interface { var _ canKeepAlive = &net.TCPConn{} +// Deprecated: Use tcpreuse.ReuseportIsAvailable +var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable + func tryKeepAlive(conn net.Conn, keepAlive bool) { keepAliveConn, ok := conn.(canKeepAlive) if !ok { @@ -113,6 +117,13 @@ func WithMetrics() Option { } } +func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { + return func(tr *TcpTransport) error { + tr.sharedTcp = mgr + return nil + } +} + // TcpTransport is the TCP transport. type TcpTransport struct { // Connection upgrader for upgrading insecure stream connections to @@ -122,6 +133,9 @@ type TcpTransport struct { disableReuseport bool // Explicitly disable reuseport. enableMetrics bool + // share and demultiplex TCP listeners across multiple transports + sharedTcp *tcpreuse.ConnMgr + // TCP connect timeout connectTimeout time.Duration @@ -168,6 +182,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co defer cancel() } + if t.sharedTcp != nil { + return t.sharedTcp.DialContext(ctx, raddr) + } + if t.UseReuseport() { return t.reuse.DialContext(ctx, raddr) } @@ -233,10 +251,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p // UseReuseport returns true if reuseport is enabled and available. func (t *TcpTransport) UseReuseport() bool { - return !t.disableReuseport && ReuseportIsAvailable() + return !t.disableReuseport && tcpreuse.ReuseportIsAvailable() } -func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) { if t.UseReuseport() { return t.reuse.Listen(laddr) } @@ -245,10 +263,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { // Listen listens on the given multiaddr. func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { - list, err := t.maListen(laddr) + var list manet.Listener + var err error + + if t.sharedTcp == nil { + list, err = t.unsharedMAListen(laddr) + } else { + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect) + } if err != nil { return nil, err } + if t.enableMetrics { list = newTracingListener(&tcpListener{list, 0}) } diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index a57a65e420..4c692fbf4c 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -14,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" @@ -41,9 +42,9 @@ func TestTcpTransport(t *testing.T) { zero := "/ip4/127.0.0.1/tcp/0" ttransport.SubtestTransport(t, ta, tb, zero, peerA) - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportWithMetrics(t *testing.T) { @@ -126,9 +127,9 @@ func TestTcpTransportCantDialDNS(t *testing.T) { t.Fatal("shouldn't be able to dial dns") } - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportCantListenUtp(t *testing.T) { @@ -143,9 +144,9 @@ func TestTcpTransportCantListenUtp(t *testing.T) { _, err = tpt.Listen(utpa) require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport") - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestDialWithUpdates(t *testing.T) { diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go new file mode 100644 index 0000000000..59e26a9aee --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -0,0 +1,240 @@ +package tcpreuse + +import ( + "bufio" + "errors" + "fmt" + "io" + "math" + "net" + "time" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type peekAble interface { + // Peek returns the next n bytes without advancing the reader. The bytes stop + // being valid at the next read call. If Peek returns fewer than n bytes, it + // also returns an error explaining why the read is short. The error is + // [ErrBufferFull] if n is larger than b's buffer size. + Peek(n int) ([]byte, error) +} + +var _ peekAble = (*bufio.Reader)(nil) + +type DemultiplexedConnType int + +const ( + Unknown DemultiplexedConnType = iota + MultistreamSelect + HTTP + TLS +) + +func (t DemultiplexedConnType) String() string { + switch t { + case MultistreamSelect: + return "MultistreamSelect" + case HTTP: + return "HTTP" + case TLS: + return "TLS" + default: + return fmt.Sprintf("Unknown(%d)", int(t)) + } +} + +func (t DemultiplexedConnType) IsKnown() bool { + return t >= 1 || t <= 3 +} + +func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { + if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + s, sc, err := ReadSampleFromConn(c) + if err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if IsMultistreamSelect(s) { + return MultistreamSelect, sc, nil + } + if IsTLS(s) { + return TLS, sc, nil + } + if IsHTTP(s) { + return HTTP, sc, nil + } + return Unknown, sc, nil +} + +// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged. +// If an error occurs it only return the error. +func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { + if peekAble, ok := c.(peekAble); ok { + b, err := peekAble.Peek(len(Sample{})) + switch { + case err == nil: + mac, err := manet.WrapNetConn(c) + if err != nil { + return Sample{}, nil, err + } + + return Sample(b), mac, nil + case errors.Is(err, bufio.ErrBufferFull): + // fallback to sampledConn + default: + return Sample{}, nil, err + } + } + + tcpConnLike, ok := c.(tcpConnInterface) + if !ok { + return Sample{}, nil, fmt.Errorf("expected tcp-like connection") + } + + laddr, err := manet.FromNetAddr(c.LocalAddr()) + if err != nil { + return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) + } + + raddr, err := manet.FromNetAddr(c.RemoteAddr()) + if err != nil { + return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) + } + + sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}} + _, err = io.ReadFull(c, sc.s[:]) + if err != nil { + return Sample{}, nil, err + } + + return sc.s, sc, nil +} + +// Try out best to mimic a TCPConn's functions +// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection +// If this is an issue here we can revisit the options. +type tcpConnInterface interface { + net.Conn + + CloseRead() error + CloseWrite() error + + SetLinger(sec int) error + SetKeepAlive(keepalive bool) error + SetKeepAlivePeriod(d time.Duration) error + SetNoDelay(noDelay bool) error + MultipathTCP() (bool, error) + + io.ReaderFrom + io.WriterTo +} + +type maEndpoints struct { + laddr ma.Multiaddr + raddr ma.Multiaddr +} + +// LocalMultiaddr returns the local address associated with +// this connection +func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr returns the remote address associated with +// this connection +func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { + return c.raddr +} + +type sampledConn struct { + tcpConnInterface + maEndpoints + + s Sample + readFromSample uint8 +} + +var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow +var _ io.ReaderFrom = (*sampledConn)(nil) +var _ io.WriterTo = (*sampledConn)(nil) + +func (sc *sampledConn) Read(b []byte) (int, error) { + if int(sc.readFromSample) != len(sc.s) { + red := copy(b, sc.s[sc.readFromSample:]) + sc.readFromSample += uint8(red) + return red, nil + } + + return sc.tcpConnInterface.Read(b) +} + +// forward optimizations +func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(sc.tcpConnInterface, r) +} + +// forward optimizations +func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { + if int(sc.readFromSample) != len(sc.s) { + b := sc.s[sc.readFromSample:] + written, err := w.Write(b) + if written < 0 || len(b) < written { + // buggy writer, harden against this + sc.readFromSample = uint8(len(sc.s)) + total = int64(len(sc.s)) + } else { + sc.readFromSample += uint8(written) + total += int64(written) + } + if err != nil { + return total, err + } + } + + written, err := io.Copy(w, sc.tcpConnInterface) + total += written + return total, err +} + +type Matcher interface { + Match(s Sample) bool +} + +// Sample might evolve over time. +type Sample [3]byte + +// Matchers are implemented here instead of in the transports so we can easily fuzz them together. + +func IsMultistreamSelect(s Sample) bool { + return string(s[:]) == "\x13/m" +} + +func IsHTTP(s Sample) bool { + switch string(s[:]) { + case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": + return true + default: + return false + } +} + +func IsTLS(s Sample) bool { + switch string(s[:]) { + case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04": + return true + default: + return false + } +} diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go new file mode 100644 index 0000000000..3d6e91f35a --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -0,0 +1,50 @@ +package tcpreuse + +import "testing" + +func FuzzClash(f *testing.F) { + // make untyped literals type correctly + add := func(a, b, c byte) { f.Add(a, b, c) } + + // multistream-select + add('\x13', '/', 'm') + // http + add('G', 'E', 'T') + add('H', 'E', 'A') + add('P', 'O', 'S') + add('P', 'U', 'T') + add('D', 'E', 'L') + add('C', 'O', 'N') + add('O', 'P', 'T') + add('T', 'R', 'A') + add('P', 'A', 'T') + // tls + add('\x16', '\x03', '\x01') + add('\x16', '\x03', '\x02') + add('\x16', '\x03', '\x03') + add('\x16', '\x03', '\x04') + + f.Fuzz(func(t *testing.T, a, b, c byte) { + s := Sample{a, b, c} + var total uint + + ms := IsMultistreamSelect(s) + if ms { + total++ + } + + http := IsHTTP(s) + if http { + total++ + } + + tls := IsTLS(s) + if tls { + total++ + } + + if total > 1 { + t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls) + } + }) +} diff --git a/p2p/transport/tcpreuse/dialer.go b/p2p/transport/tcpreuse/dialer.go new file mode 100644 index 0000000000..ad634583ed --- /dev/null +++ b/p2p/transport/tcpreuse/dialer.go @@ -0,0 +1,16 @@ +package tcpreuse + +import ( + "context" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// DialContext is like Dial but takes a context. +func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { + if t.useReuseport() { + return t.reuse.DialContext(ctx, raddr) + } + var d manet.Dialer + return d.DialContext(ctx, raddr) +} diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go new file mode 100644 index 0000000000..59aeed1f93 --- /dev/null +++ b/p2p/transport/tcpreuse/listener.go @@ -0,0 +1,250 @@ +package tcpreuse + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var log = logging.Logger("tcp-demultiplex") + +type ConnMgr struct { + disableReuseport bool + reuse reuseport.Transport + listeners map[string]*multiplexedListener + mx sync.Mutex +} + +func NewConnMgr(disableReuseport bool) *ConnMgr { + return &ConnMgr{ + disableReuseport: disableReuseport, + reuse: reuseport.Transport{}, + listeners: make(map[string]*multiplexedListener), + } +} + +func (t *ConnMgr) maListen(laddr ma.Multiaddr) (manet.Listener, error) { + if t.useReuseport() { + return t.reuse.Listen(laddr) + } else { + return manet.Listen(laddr) + } +} + +func (t *ConnMgr) useReuseport() bool { + return !t.disableReuseport && ReuseportIsAvailable() +} + +func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + t.mx.Lock() + defer t.mx.Unlock() + ml, ok := t.listeners[laddr.String()] + if ok { + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + return nil, err + } + return dl, nil + } + + l, err := t.maListen(laddr) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + cancelFunc := func() error { + cancel() + t.mx.Lock() + defer t.mx.Unlock() + delete(t.listeners, laddr.String()) + return l.Close() + } + ml = &multiplexedListener{ + Listener: l, + listeners: make(map[DemultiplexedConnType]*demultiplexedListener), + buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? + ctx: ctx, + closeFn: cancelFunc, + } + + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + cerr := ml.Close() + return nil, errors.Join(err, cerr) + } + + go func() { + err = ml.Run() + if err != nil { + log.Debugf("Error running multiplexed listener: %s", err.Error()) + } + }() + + t.listeners[laddr.String()] = ml + + return dl, nil +} + +var _ manet.Listener = &demultiplexedListener{} + +type multiplexedListener struct { + manet.Listener + listeners map[DemultiplexedConnType]*demultiplexedListener + mx sync.Mutex + listenerCounter int + buffer chan manet.Conn + + ctx context.Context + closeFn func() error +} + +func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + m.mx.Lock() + defer m.mx.Unlock() + l, ok := m.listeners[connType] + if ok { + return l, nil + } + + ctx, cancel := context.WithCancel(m.ctx) + closeFn := func() error { + cancel() + m.mx.Lock() + defer m.mx.Unlock() + m.listenerCounter-- + if m.listenerCounter == 0 { + return m.Close() + } + return nil + } + + l = &demultiplexedListener{ + buffer: make(chan manet.Conn, 16), // TODO: how big should this buffer be? + inner: m.Listener, + ctx: ctx, + closeFn: closeFn, + } + + m.listeners[connType] = l + m.listenerCounter++ + + return l, nil +} + +func (m *multiplexedListener) Run() error { + const numWorkers = 16 + for i := 0; i < numWorkers; i++ { + go func() { + m.background() + }() + } + + for { + c, err := m.Listener.Accept() + if err != nil { + return err + } + + select { + case m.buffer <- c: + case <-m.ctx.Done(): + return transport.ErrListenerClosed + } + } +} + +func (m *multiplexedListener) background() { + // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? + // Drop connection because the buffer is full + for { + select { + case c := <-m.buffer: + t, sampleC, err := ConnTypeFromConn(c) + if err != nil { + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error demultiplexing connection: %s", err.Error()) + continue + } + + demux, ok := m.listeners[t] + if !ok { + closeErr := c.Close() + if closeErr != nil { + log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("no registered listener for demultiplex connection %s", t) + } + continue + } + + select { + case demux.buffer <- sampleC: + case <-m.ctx.Done(): + return + default: + closeErr := c.Close() + if closeErr != nil { + log.Debugf("dropped connection due to full buffer of awaiting connections of type %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("dropped connection due to full buffer of awaiting connections of type %s", t) + } + continue + } + case <-m.ctx.Done(): + return + } + } +} + +func (m *multiplexedListener) Close() error { + cerr := m.closeFn() + lerr := m.Listener.Close() + return errors.Join(lerr, cerr) +} + +type demultiplexedListener struct { + buffer chan manet.Conn + inner manet.Listener + ctx context.Context + closeFn func() error +} + +func (m *demultiplexedListener) Accept() (manet.Conn, error) { + select { + case c := <-m.buffer: + return c, nil + case <-m.ctx.Done(): + return nil, transport.ErrListenerClosed + } +} + +func (m *demultiplexedListener) Close() error { + return m.closeFn() +} + +func (m *demultiplexedListener) Multiaddr() ma.Multiaddr { + // TODO: do we need to add a suffix for the rest of the transport? + return m.inner.Multiaddr() +} + +func (m *demultiplexedListener) Addr() net.Addr { + return m.inner.Addr() +} diff --git a/p2p/transport/tcp/reuseport.go b/p2p/transport/tcpreuse/reuseport.go similarity index 81% rename from p2p/transport/tcp/reuseport.go rename to p2p/transport/tcpreuse/reuseport.go index ba09304622..a2529c0bda 100644 --- a/p2p/transport/tcp/reuseport.go +++ b/p2p/transport/tcpreuse/reuseport.go @@ -1,4 +1,4 @@ -package tcp +package tcpreuse import ( "os" @@ -11,13 +11,13 @@ import ( // It default to true. const envReuseport = "LIBP2P_TCP_REUSEPORT" -// envReuseportVal stores the value of envReuseport. defaults to true. -var envReuseportVal = true +// EnvReuseportVal stores the value of envReuseport. defaults to true. +var EnvReuseportVal = true func init() { v := strings.ToLower(os.Getenv(envReuseport)) if v == "false" || v == "f" || v == "0" { - envReuseportVal = false + EnvReuseportVal = false log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v) } } @@ -31,5 +31,5 @@ func init() { // If this becomes a sought after feature, we could add this to the config. // In the end, reuseport is a stop-gap. func ReuseportIsAvailable() bool { - return envReuseportVal && reuseport.Available() + return EnvReuseportVal && reuseport.Available() } diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..50a8b9e823 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 3ff72830d1..2253b6597e 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -44,7 +45,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -54,19 +55,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err - } - nl, err := net.Listen(lnet, lnaddr) - if err != nil { - return nil, err + var nl net.Listener + + if sharedTcp == nil { + lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) + if err != nil { + return nil, err + } + nl, err = net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + } else { + var connType tcpreuse.DemultiplexedConnType + if parsed.isWSS { + connType = tcpreuse.TLS + } else { + connType = tcpreuse.HTTP + } + mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) + if err != nil { + return nil, err + } + nl = manet.NetListener(mal) } laddr, err := manet.FromNetAddr(nl.Addr()) if err != nil { return nil, err } + first, _ := ma.SplitFirst(a) // Don't resolve dns addresses. // We want to be able to announce domain names, so the peer can validate the TLS certificate. diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 36818decee..68ac5e77a4 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -80,6 +81,13 @@ func WithTLSConfig(conf *tls.Config) Option { } } +func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option { + return func(t *WebsocketTransport) error { + t.sharedTcp = mgr + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader @@ -87,6 +95,8 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + sharedTcp *tcpreuse.ConnMgr } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -233,7 +243,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf) + l, err := newListener(a, tlsConf, t.sharedTcp) if err != nil { return nil, err }