diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 1bcbdec81f..ddd358d761 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -306,7 +306,8 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { underlyingListener = &l acceptRunner = &acceptLoopRunner{ - muxer: make(map[quic.VersionNumber]chan acceptVal), + acceptSem: make(chan struct{}, 1), + muxer: make(map[quic.VersionNumber]chan acceptVal), } } @@ -316,7 +317,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { udpAddr: udpAddr.String(), t: t, acceptRunnner: acceptRunner, - acceptChan: acceptRunner.acceptForVersion(version), + acceptChan: acceptRunner.AcceptForVersion(version), } listeners = append(listeners, l) diff --git a/p2p/transport/quic/virtuallistener.go b/p2p/transport/quic/virtuallistener.go index 9bf508d765..1f730bc67a 100644 --- a/p2p/transport/quic/virtuallistener.go +++ b/p2p/transport/quic/virtuallistener.go @@ -29,7 +29,7 @@ func (l *virtualListener) Multiaddr() ma.Multiaddr { } func (l *virtualListener) Close() error { - l.acceptRunnner.rmAcceptForVersion(l.version) + l.acceptRunnner.RmAcceptForVersion(l.version) l.t.listenersMu.Lock() defer l.t.listenersMu.Unlock() @@ -40,37 +40,23 @@ func (l *virtualListener) Close() error { // This is the last virtual listener here, so we can close the underlying listener err = l.listener.Close() delete(l.t.listeners, l.udpAddr) - } else { - for i := 0; i < len(listeners); i++ { - // Swap remove - if l == listeners[i] { - listeners[i] = listeners[len(listeners)-1] - listeners = listeners[0 : len(listeners)-1] - l.t.listeners[l.udpAddr] = listeners - break - } - } + return err } - return err + for i := 0; i < len(listeners); i++ { + // Swap remove + if l == listeners[i] { + listeners[i] = listeners[len(listeners)-1] + listeners = listeners[0 : len(listeners)-1] + l.t.listeners[l.udpAddr] = listeners + break + } + } + return nil } func (l *virtualListener) Accept() (tpt.CapableConn, error) { - var v acceptVal - var ok bool - select { - // Check if we have a pending connection first - case v, ok = <-l.acceptChan: - default: - // No? Let's call Accept and wait for a connection - go l.acceptRunnner.Accept(l.listener, l.version) - v, ok = <-l.acceptChan - } - if !ok { - return nil, errors.New("listener closed") - } - - return v.conn, v.err + return l.acceptRunnner.Accept(l.listener, l.version, l.acceptChan) } type acceptVal struct { @@ -79,13 +65,13 @@ type acceptVal struct { } type acceptLoopRunner struct { - listenerMu sync.Mutex + acceptSem chan struct{} muxerMu sync.Mutex muxer map[quic.VersionNumber]chan acceptVal } -func (r *acceptLoopRunner) acceptForVersion(v quic.VersionNumber) chan acceptVal { +func (r *acceptLoopRunner) AcceptForVersion(v quic.VersionNumber) chan acceptVal { r.muxerMu.Lock() defer r.muxerMu.Unlock() @@ -99,7 +85,7 @@ func (r *acceptLoopRunner) acceptForVersion(v quic.VersionNumber) chan acceptVal return ch } -func (r *acceptLoopRunner) rmAcceptForVersion(v quic.VersionNumber) { +func (r *acceptLoopRunner) RmAcceptForVersion(v quic.VersionNumber) { r.muxerMu.Lock() defer r.muxerMu.Unlock() @@ -124,45 +110,70 @@ func (r *acceptLoopRunner) sendErrAndClose(err error) { } } -func (r *acceptLoopRunner) Accept(l *listener, expectedVersion quic.VersionNumber) error { - for { - r.listenerMu.Lock() - conn, err := l.Accept() - r.listenerMu.Unlock() - - if err != nil { - r.sendErrAndClose(err) - return err +// innerAccept is the inner logic of the Accept loop. Assume caller holds the +// acceptSemaphore. May return both a nil conn and nil error if it didn't find a +// conn with the expected version +func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) { + select { + // Check if we have a buffered connection first from an earlier Accept call + case v, ok := <-bufferedConnChan: + if !ok { + return nil, errors.New("listener closed") } + return v.conn, v.err + default: + } - _, version, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr()) - if err != nil { - r.sendErrAndClose(err) - return err - } + conn, err := l.Accept() - r.muxerMu.Lock() - ch, ok := r.muxer[version] - r.muxerMu.Unlock() + if err != nil { + r.sendErrAndClose(err) + return nil, err + } - if !ok { - // Nothing to handle this connection version. Close it - conn.Close() - continue - } + _, version, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr()) + if err != nil { + r.sendErrAndClose(err) + return nil, err + } - // Non blocking - select { - case ch <- acceptVal{conn: conn}: - default: - // We dropped the connection, close it - conn.Close() - continue - } + if version == expectedVersion { + return conn, nil + } + + // This wasn't the version we were expecting, lets queue it up for a + // future Accept call with a different version + r.muxerMu.Lock() + ch, ok := r.muxer[version] + r.muxerMu.Unlock() + + if !ok { + // Nothing to handle this connection version. Close it + conn.Close() + return nil, nil + } + + // Non blocking + select { + case ch <- acceptVal{conn: conn}: + default: + // We dropped the connection, close it + conn.Close() + } - if version == expectedVersion { - // We got the version we were expecting, we can exit. - return nil + return nil, nil +} + +func (r *acceptLoopRunner) Accept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) { + for { + r.acceptSem <- struct{}{} + conn, err := r.innerAccept(l, expectedVersion, bufferedConnChan) + <-r.acceptSem + + if conn == nil && err == nil { + // Didn't find a conn for the expected version and there was no error, lets try again + continue } + return conn, err } }