From 022530c41555839e27aec3868cc480fb7b5e33d4 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 17 May 2024 15:53:22 -0700 Subject: [PATCH] http2: add a more full-featured test net.Conn Add a net.Conn implementation that plays nicely with testsyncGroup, implements read/write timeouts, and gives control over buffering to let us write tests that cause writes to a Conn to block at specific points in time. Change-Id: I9d870b211ac9d938a8c4a221277981cdb821a6e4 Reviewed-on: https://go-review.googlesource.com/c/net/+/586246 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- http2/clientconn_test.go | 116 ++----------- http2/netconn_test.go | 350 +++++++++++++++++++++++++++++++++++++++ http2/sync_test.go | 7 + http2/transport_test.go | 4 +- 4 files changed, 376 insertions(+), 101 deletions(-) create mode 100644 http2/netconn_test.go diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 855f44e6b..36f080b9b 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -10,11 +10,10 @@ package http2 import ( "bytes" "context" - "errors" "fmt" "io" - "net" "net/http" + "os" "reflect" "runtime" "slices" @@ -104,7 +103,7 @@ type testClientConn struct { roundtrips []*testRoundTrip - netconn testClientConnNetConn + netconn *synctestNetConn } func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { @@ -114,22 +113,21 @@ func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientCo cc: cc, group: cc.t.transportTestHooks.group.(*synctestGroup), } + cli, srv := synctestNetPipe(tc.group) + srv.SetReadDeadline(tc.group.Now()) + tc.netconn = srv tc.enc = hpack.NewEncoder(&tc.encbuf) - tc.netconn.gate = newGate() // all writes and reads are finished. // // cli is the ClientConn's side, srv is the side controlled by the test. - cc.tconn = &tc.netconn - tc.fr = NewFramer( - (*testClientConnNetConnWriteToClient)(&tc.netconn), - (*testClientConnNetConnReadFromClient)(&tc.netconn), - ) + cc.tconn = cli + tc.fr = NewFramer(srv, srv) tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) tc.fr.SetMaxReadFrameSize(10 << 20) t.Cleanup(func() { - tc.closeWrite(io.EOF) + tc.closeWrite() }) return tc } @@ -138,8 +136,7 @@ func (tc *testClientConn) readClientPreface() { tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. buf := make([]byte, len(clientPreface)) - r := (*testClientConnNetConnReadFromClient)(&tc.netconn) - if _, err := io.ReadFull(r, buf); err != nil { + if _, err := io.ReadFull(tc.netconn, buf); err != nil { tc.t.Fatalf("reading preface: %v", err) } if !bytes.Equal(buf, clientPreface) { @@ -174,26 +171,23 @@ func (tc *testClientConn) advance(d time.Duration) { // hasFrame reports whether a frame is available to be read. func (tc *testClientConn) hasFrame() bool { - tc.netconn.lock() - defer tc.netconn.unlock() - return tc.netconn.fromConn.Len() > 0 + return len(tc.netconn.Peek()) > 0 } func (tc *testClientConn) isClosed() bool { - tc.netconn.lock() - defer tc.netconn.unlock() - return tc.netconn.fromConnClosed + return tc.netconn.IsClosedByPeer() } // readFrame reads the next frame from the conn. func (tc *testClientConn) readFrame() Frame { + tc.t.Helper() tc.sync() fr, err := tc.fr.ReadFrame() - if err == io.EOF { + if err == io.EOF || err == os.ErrDeadlineExceeded { return nil } if err != nil { - return nil + tc.t.Fatalf("ReadFrame: %v", err) } return fr } @@ -597,10 +591,8 @@ func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) { // closeWrite causes the net.Conn used by the ClientConn to return a error // from Read calls. -func (tc *testClientConn) closeWrite(err error) { - tc.netconn.lock() - tc.netconn.toConnErr = err - tc.netconn.unlock() +func (tc *testClientConn) closeWrite() { + tc.netconn.Close() tc.sync() } @@ -746,80 +738,6 @@ func diffHeaders(got, want http.Header) string { return fmt.Sprintf("got: %v\nwant: %v", got, want) } -// testClientConnNetConn implements net.Conn, -// and is the Conn used by a ClientConn under test. -type testClientConnNetConn struct { - gate gate - toConn bytes.Buffer - toConnErr error - fromConn bytes.Buffer - fromConnClosed bool -} - -func (c *testClientConnNetConn) lock() { - c.gate.lock() -} - -func (c *testClientConnNetConn) unlock() { - c.gate.unlock(c.toConn.Len() > 0 || c.toConnErr != nil) -} - -func (c *testClientConnNetConn) Read(b []byte) (n int, err error) { - if err := c.gate.waitAndLock(context.Background()); err != nil { - return 0, err - } - defer c.unlock() - if c.toConn.Len() == 0 && c.toConnErr != nil { - return 0, c.toConnErr - } - return c.toConn.Read(b) -} - -func (c *testClientConnNetConn) Write(b []byte) (n int, err error) { - c.lock() - defer c.unlock() - return c.fromConn.Write(b) -} - -func (c *testClientConnNetConn) Close() error { - c.lock() - defer c.unlock() - c.fromConnClosed = true - c.toConn.Reset() - if c.toConnErr == nil { - c.toConnErr = errors.New("connection closed") - } - return nil -} - -func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return } -func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } -func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } -func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } -func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } - -// testClientConnNetConnWriteToClient is a view on a testClientConnNetConn -// that implements an io.Writer that sends to the client conn under test. -type testClientConnNetConnWriteToClient testClientConnNetConn - -func (w *testClientConnNetConnWriteToClient) Write(b []byte) (n int, err error) { - c := (*testClientConnNetConn)(w) - c.gate.lock() - defer c.unlock() - return c.toConn.Write(b) -} - -// testClientConnNetConnReadFromClient is a view on a testClientConnNetConn -// that implements an io.Reader that reads data sent by the client conn under test. -type testClientConnNetConnReadFromClient testClientConnNetConn - -func (w *testClientConnNetConnReadFromClient) Read(b []byte) (n int, err error) { - c := (*testClientConnNetConn)(w) - c.gate.lock() - defer c.unlock() - return c.fromConn.Read(b) -} - // A testTransport allows testing Transport.RoundTrip against fake servers. // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling // should use testClientConn instead. @@ -861,7 +779,7 @@ func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { buf := make([]byte, 16*1024) n := runtime.Stack(buf, true) t.Logf("stacks:\n%s", buf[:n]) - t.Fatalf("%v goroutines still running after test completed, expect 1", count-1) + t.Fatalf("%v goroutines still running after test completed, expect 1", count) } }) diff --git a/http2/netconn_test.go b/http2/netconn_test.go new file mode 100644 index 000000000..8a61fbef1 --- /dev/null +++ b/http2/netconn_test.go @@ -0,0 +1,350 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "bytes" + "context" + "errors" + "io" + "math" + "net" + "net/netip" + "os" + "sync" + "time" +) + +// synctestNetPipe creates an in-memory, full duplex network connection. +// Read and write timeouts are managed by the synctest group. +// +// Unlike net.Pipe, the connection is not synchronous. +// Writes are made to a buffer, and return immediately. +// By default, the buffer size is unlimited. +func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) { + s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")) + s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")) + s1 := newSynctestNetConnHalf(s1addr) + s2 := newSynctestNetConnHalf(s2addr) + return &synctestNetConn{group: group, loc: s1, rem: s2}, + &synctestNetConn{group: group, loc: s2, rem: s1} +} + +// A synctestNetConn is one endpoint of the connection created by synctestNetPipe. +type synctestNetConn struct { + group *synctestGroup + + // local and remote connection halves. + // Each half contains a buffer. + // Reads pull from the local buffer, and writes push to the remote buffer. + loc, rem *synctestNetConnHalf + + // When set, group.Wait is automatically called before reads and after writes. + autoWait bool +} + +// Read reads data from the connection. +func (c *synctestNetConn) Read(b []byte) (n int, err error) { + if c.autoWait { + c.group.Wait() + } + return c.loc.read(b) +} + +// Peek returns the available unread read buffer, +// without consuming its contents. +func (c *synctestNetConn) Peek() []byte { + if c.autoWait { + c.group.Wait() + } + return c.loc.peek() +} + +// Write writes data to the connection. +func (c *synctestNetConn) Write(b []byte) (n int, err error) { + if c.autoWait { + defer c.group.Wait() + } + return c.rem.write(b) +} + +// IsClosed reports whether the peer has closed its end of the connection. +func (c *synctestNetConn) IsClosedByPeer() bool { + if c.autoWait { + c.group.Wait() + } + return c.loc.isClosedByPeer() +} + +// Close closes the connection. +func (c *synctestNetConn) Close() error { + c.loc.setWriteError(errors.New("connection closed by peer")) + c.rem.setReadError(io.EOF) + if c.autoWait { + c.group.Wait() + } + return nil +} + +// LocalAddr returns the (fake) local network address. +func (c *synctestNetConn) LocalAddr() net.Addr { + return c.loc.addr +} + +// LocalAddr returns the (fake) remote network address. +func (c *synctestNetConn) RemoteAddr() net.Addr { + return c.rem.addr +} + +// SetDeadline sets the read and write deadlines for the connection. +func (c *synctestNetConn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + c.SetWriteDeadline(t) + return nil +} + +// SetReadDeadline sets the read deadline for the connection. +func (c *synctestNetConn) SetReadDeadline(t time.Time) error { + c.loc.rctx.setDeadline(c.group, t) + return nil +} + +// SetWriteDeadline sets the write deadline for the connection. +func (c *synctestNetConn) SetWriteDeadline(t time.Time) error { + c.rem.wctx.setDeadline(c.group, t) + return nil +} + +// SetReadBufferSize sets the read buffer limit for the connection. +// Writes by the peer will block so long as the buffer is full. +func (c *synctestNetConn) SetReadBufferSize(size int) { + c.loc.setReadBufferSize(size) +} + +// synctestNetConnHalf is one data flow in the connection created by synctestNetPipe. +// Each half contains a buffer. Writes to the half push to the buffer, and reads pull from it. +type synctestNetConnHalf struct { + addr net.Addr + + // Read and write timeouts. + rctx, wctx deadlineContext + + // A half can be readable and/or writable. + // + // These four channels act as a lock, + // and allow waiting for readability/writability. + // When the half is unlocked, exactly one channel contains a value. + // When the half is locked, all channels are empty. + lockr chan struct{} // readable + lockw chan struct{} // writable + lockrw chan struct{} // readable and writable + lockc chan struct{} // neither readable nor writable + + bufMax int // maximum buffer size + buf bytes.Buffer + readErr error // error returned by reads + writeErr error // error returned by writes +} + +func newSynctestNetConnHalf(addr net.Addr) *synctestNetConnHalf { + h := &synctestNetConnHalf{ + addr: addr, + lockw: make(chan struct{}, 1), + lockr: make(chan struct{}, 1), + lockrw: make(chan struct{}, 1), + lockc: make(chan struct{}, 1), + bufMax: math.MaxInt, // unlimited + } + h.unlock() + return h +} + +func (h *synctestNetConnHalf) lock() { + select { + case <-h.lockw: + case <-h.lockr: + case <-h.lockrw: + case <-h.lockc: + } +} + +func (h *synctestNetConnHalf) unlock() { + canRead := h.readErr != nil || h.buf.Len() > 0 + canWrite := h.writeErr != nil || h.bufMax > h.buf.Len() + switch { + case canRead && canWrite: + h.lockrw <- struct{}{} + case canRead: + h.lockr <- struct{}{} + case canWrite: + h.lockw <- struct{}{} + default: + h.lockc <- struct{}{} + } +} + +func (h *synctestNetConnHalf) readWaitAndLock() error { + select { + case <-h.lockr: + return nil + case <-h.lockrw: + return nil + default: + } + ctx := h.rctx.context() + select { + case <-h.lockr: + return nil + case <-h.lockrw: + return nil + case <-ctx.Done(): + return context.Cause(ctx) + } +} + +func (h *synctestNetConnHalf) writeWaitAndLock() error { + select { + case <-h.lockw: + return nil + case <-h.lockrw: + return nil + default: + } + ctx := h.wctx.context() + select { + case <-h.lockw: + return nil + case <-h.lockrw: + return nil + case <-ctx.Done(): + return context.Cause(ctx) + } +} + +func (h *synctestNetConnHalf) peek() []byte { + h.lock() + defer h.unlock() + return h.buf.Bytes() +} + +func (h *synctestNetConnHalf) isClosedByPeer() bool { + h.lock() + defer h.unlock() + return h.readErr != nil +} + +func (h *synctestNetConnHalf) read(b []byte) (n int, err error) { + if err := h.readWaitAndLock(); err != nil { + return 0, err + } + defer h.unlock() + if h.buf.Len() == 0 && h.readErr != nil { + return 0, h.readErr + } + return h.buf.Read(b) +} + +func (h *synctestNetConnHalf) setReadBufferSize(size int) { + h.lock() + defer h.unlock() + h.bufMax = size +} + +func (h *synctestNetConnHalf) write(b []byte) (n int, err error) { + for n < len(b) { + nn, err := h.writePartial(b[n:]) + n += nn + if err != nil { + return n, err + } + } + return n, nil +} + +func (h *synctestNetConnHalf) writePartial(b []byte) (n int, err error) { + if err := h.writeWaitAndLock(); err != nil { + return 0, err + } + defer h.unlock() + if h.writeErr != nil { + return 0, h.writeErr + } + writeMax := h.bufMax - h.buf.Len() + if writeMax < len(b) { + b = b[:writeMax] + } + return h.buf.Write(b) +} + +func (h *synctestNetConnHalf) setReadError(err error) { + h.lock() + defer h.unlock() + if h.readErr == nil { + h.readErr = err + } +} + +func (h *synctestNetConnHalf) setWriteError(err error) { + h.lock() + defer h.unlock() + if h.writeErr == nil { + h.writeErr = err + } +} + +// deadlineContext converts a changable deadline (as in net.Conn.SetDeadline) into a Context. +type deadlineContext struct { + mu sync.Mutex + ctx context.Context + cancel context.CancelCauseFunc + timer timer +} + +// context returns a Context which expires when the deadline does. +func (t *deadlineContext) context() context.Context { + t.mu.Lock() + defer t.mu.Unlock() + if t.ctx == nil { + t.ctx, t.cancel = context.WithCancelCause(context.Background()) + } + return t.ctx +} + +// setDeadline sets the current deadline. +func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled + // and we should create a new one. + if t.ctx == nil || t.cancel == nil { + t.ctx, t.cancel = context.WithCancelCause(context.Background()) + } + // Stop any existing deadline from expiring. + if t.timer != nil { + t.timer.Stop() + } + if deadline.IsZero() { + // No deadline. + return + } + if !deadline.After(group.Now()) { + // Deadline has already expired. + t.cancel(os.ErrDeadlineExceeded) + t.cancel = nil + return + } + if t.timer != nil { + // Reuse existing deadline timer. + t.timer.Reset(deadline.Sub(group.Now())) + return + } + // Create a new timer to cancel the context at the deadline. + t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() { + t.mu.Lock() + defer t.mu.Unlock() + t.cancel(os.ErrDeadlineExceeded) + t.cancel = nil + }) +} diff --git a/http2/sync_test.go b/http2/sync_test.go index 3f5cf31f1..bcbbe66ac 100644 --- a/http2/sync_test.go +++ b/http2/sync_test.go @@ -166,6 +166,13 @@ func (g *synctestGroup) AdvanceTime(d time.Duration) { } } +// Now returns the current synthetic time. +func (g *synctestGroup) Now() time.Time { + g.mu.Lock() + defer g.mu.Unlock() + return g.now +} + // TimeUntilEvent returns the amount of time until the next scheduled timer. func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) { g.mu.Lock() diff --git a/http2/transport_test.go b/http2/transport_test.go index 2171359ca..d62407b47 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2491,7 +2491,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { // the interesting parts of both. tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) tc.writeGoAway(5, goAwayErrCode, nil) - tc.closeWrite(io.EOF) + tc.closeWrite() res, err := rt.result() whence := "RoundTrip" @@ -5151,7 +5151,7 @@ func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { }) } - tc.closeWrite(io.EOF) + tc.closeWrite() err := rt.err() if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)