From 753e48e5ced653011b519653815d0f2814626f17 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 1 Feb 2023 09:59:30 -0800 Subject: [PATCH] rcmgr: *: Always close connscope (#2037) * rcmgr: Fix connection accounting * Always close conn scope in the case of errors * circuitv2: fix resource accounting when connection upgrading fails --------- Co-authored-by: Marten Seemann --- p2p/net/swarm/swarm_conn.go | 12 +++++++++-- p2p/protocol/circuitv2/client/transport.go | 11 ++++++++-- p2p/transport/quic/conn.go | 6 +++++- p2p/transport/quic/listener.go | 21 ++++++++++++------ p2p/transport/quic/transport.go | 17 +++++++++++---- p2p/transport/tcp/tcp.go | 13 ++++++++--- p2p/transport/websocket/websocket.go | 10 ++++++++- p2p/transport/webtransport/listener.go | 22 +++++++++++-------- p2p/transport/webtransport/transport.go | 25 ++++++++++++++-------- 9 files changed, 99 insertions(+), 38 deletions(-) diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index c24ddee310..146305beca 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -207,11 +207,20 @@ func (c *Conn) NewStream(ctx context.Context) (network.Stream, error) { if err != nil { return nil, err } - ts, err := c.conn.OpenStream(ctx) + + s, err := c.openAndAddStream(ctx, scope) if err != nil { scope.Done() return nil, err } + return s, nil +} + +func (c *Conn) openAndAddStream(ctx context.Context, scope network.StreamManagementScope) (network.Stream, error) { + ts, err := c.conn.OpenStream(ctx) + if err != nil { + return nil, err + } return c.addStream(ts, network.DirOutbound, scope) } @@ -220,7 +229,6 @@ func (c *Conn) addStream(ts network.MuxedStream, dir network.Direction, scope ne // Are we still online? if c.streams.m == nil { c.streams.Unlock() - scope.Done() ts.Reset() return nil, ErrConnClosed } diff --git a/p2p/protocol/circuitv2/client/transport.go b/p2p/protocol/circuitv2/client/transport.go index 5c94fa7ebd..e08d55707e 100644 --- a/p2p/protocol/circuitv2/client/transport.go +++ b/p2p/protocol/circuitv2/client/transport.go @@ -53,13 +53,20 @@ func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport if err != nil { return nil, err } - if err := connScope.SetPeer(p); err != nil { + conn, err := c.dialAndUpgrade(ctx, a, p, connScope) + if err != nil { connScope.Done() return nil, err } + return conn, nil +} + +func (c *Client) dialAndUpgrade(ctx context.Context, a ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { + if err := connScope.SetPeer(p); err != nil { + return nil, err + } conn, err := c.dial(ctx, a, p) if err != nil { - connScope.Done() return nil, err } conn.tagHop() diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index 999615ceb8..261e19a2d1 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -32,8 +32,12 @@ var _ tpt.CapableConn = &conn{} // It must be called even if the peer closed the connection in order for // garbage collection to properly work in this package. func (c *conn) Close() error { + return c.closeWithError(0, "") +} + +func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error { c.transport.removeConn(c.quicConn) - err := c.quicConn.CloseWithError(0, "") + err := c.quicConn.CloseWithError(errCode, errString) c.scope.Done() return err } diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index ea6b68bd6c..9c00026f78 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -56,15 +56,13 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } c, err := l.setupConn(qconn) if err != nil { - qconn.CloseWithError(1, err.Error()) continue } + l.transport.addConn(qconn, c) if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) { - c.scope.Done() - qconn.CloseWithError(errorCodeConnectionGating, "connection gated") + c.closeWithError(errorCodeConnectionGating, "connection gated") continue } - l.transport.addConn(qconn, c) // return through active hole punching if any key := holePunchKey{addr: qconn.RemoteAddr().String(), peer: c.remotePeerID} @@ -95,23 +93,32 @@ func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err) return nil, err } + c, err := l.setupConnWithScope(qconn, connScope, remoteMultiaddr) + if err != nil { + connScope.Done() + qconn.CloseWithError(1, err.Error()) + return nil, err + } + + return c, nil +} + +func (l *listener) setupConnWithScope(qconn quic.Connection, connScope network.ConnManagementScope, remoteMultiaddr ma.Multiaddr) (*conn, error) { + // The tls.Config used to establish this connection already verified the certificate chain. // Since we don't have any way of knowing which tls.Config was used though, // we have to re-determine the peer's identity here. // Therefore, this is expected to never fail. remotePubKey, err := p2ptls.PubKeyFromCertChain(qconn.ConnectionState().TLS.PeerCertificates) if err != nil { - connScope.Done() return nil, err } remotePeerID, err := peer.IDFromPublicKey(remotePubKey) if err != nil { - connScope.Done() return nil, err } if err := connScope.SetPeer(remotePeerID); err != nil { log.Debugw("resource manager blocked incoming connection for peer", "peer", remotePeerID, "addr", qconn.RemoteAddr(), "error", err) - connScope.Done() return nil, err } diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 325471973e..cc967e16c2 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -102,8 +102,7 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P } // Dial dials a new QUIC connection -func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - tlsConf, keyCh := t.identity.ConfigForPeer(p) +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) { if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { return t.holePunch(ctx, raddr, p) } @@ -113,11 +112,22 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) return nil, err } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + scope.Done() + return nil, err + } + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { if err := scope.SetPeer(p); err != nil { log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) - scope.Done() return nil, err } + + tlsConf, keyCh := t.identity.ConfigForPeer(p) pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease) if err != nil { return nil, err @@ -131,7 +141,6 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp } if remotePubKey == nil { pconn.CloseWithError(1, "") - scope.Done() return nil, errors.New("p2p/transport/quic BUG: expected remote pub key to be set") } diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index b41fe7bf1a..f277b3f8f3 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -181,14 +181,22 @@ func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) return nil, err } + + c, err := t.dialWithScope(ctx, raddr, p, connScope) + if err != nil { + connScope.Done() + return nil, err + } + return c, nil +} + +func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { if err := connScope.SetPeer(p); err != nil { log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) - connScope.Done() return nil, err } conn, err := t.maDial(ctx, raddr) if err != nil { - connScope.Done() return nil, err } // Set linger to 0 so we never get stuck in the TIME-WAIT state. When @@ -201,7 +209,6 @@ func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) var err error c, err = newTracingConn(conn, true) if err != nil { - connScope.Done() return nil, err } } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 03941013a1..e1965123d9 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -161,11 +161,19 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee if err != nil { return nil, err } - macon, err := t.maDial(ctx, raddr) + c, err := t.dialWithScope(ctx, raddr, p, connScope) if err != nil { connScope.Done() return nil, err } + return c, nil +} + +func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { + macon, err := t.maDial(ctx, raddr) + if err != nil { + return nil, err + } conn, err := t.upgrader.Upgrade(ctx, t, macon, network.DirOutbound, p, connScope) if err != nil { return nil, err diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index c604022837..7fb66dda9d 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -103,14 +103,19 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) return } + err = l.httpHandlerWithConnScope(w, r, connScope) + if err != nil { + connScope.Done() + } +} +func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Request, connScope network.ConnManagementScope) error { sess, err := l.server.Upgrade(w, r) if err != nil { log.Debugw("upgrade failed", "error", err) // TODO: think about the status code to use here w.WriteHeader(500) - connScope.Done() - return + return err } ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) sconn, err := l.handshake(ctx, sess) @@ -118,23 +123,20 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { cancel() log.Debugw("handshake failed", "error", err) sess.CloseWithError(1, "") - connScope.Done() - return + return err } cancel() if l.transport.gater != nil && !l.transport.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) { // TODO: can we close with a specific error here? sess.CloseWithError(errorCodeConnectionGating, "") - connScope.Done() - return + return errors.New("gater blocked connection") } if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) sess.CloseWithError(1, "") - connScope.Done() - return + return err } conn := newConn(l.transport, sess, sconn, connScope) @@ -144,8 +146,10 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { default: log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) sess.CloseWithError(1, "") - connScope.Done() + return errors.New("accept queue full") } + + return nil } func (l *listener) Accept() (tpt.CapableConn, error) { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 65c6c1402b..eae3772d33 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -119,6 +119,22 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + scope.Done() + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { _, addr, err := manet.DialArgs(raddr) if err != nil { return nil, err @@ -135,32 +151,23 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp sni, _ := extractSNI(raddr) - scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) - if err != nil { - log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) - return nil, err - } if err := scope.SetPeer(p); err != nil { log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) - scope.Done() return nil, err } maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT }) sess, err := t.dial(ctx, maddr, url, sni, certHashes) if err != nil { - scope.Done() return nil, err } sconn, err := t.upgrade(ctx, sess, p, certHashes) if err != nil { sess.CloseWithError(1, "") - scope.Done() return nil, err } if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) { sess.CloseWithError(errorCodeConnectionGating, "") - scope.Done() return nil, fmt.Errorf("secured connection gated") } conn := newConn(t, sess, sconn, scope)