diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index f2e5ac2d7fff..b6a4ecf73cbb 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -28,8 +28,8 @@ import ( "github.com/rubyist/circuitbreaker" "golang.org/x/net/context" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" @@ -247,7 +247,7 @@ func (ctx *Context) GRPCDial(target string, opts ...grpc.DialOption) (*grpc.Clie dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) } - dialOpts := make([]grpc.DialOption, 0, 2+len(opts)) + var dialOpts []grpc.DialOption dialOpts = append(dialOpts, dialOpt) dialOpts = append(dialOpts, grpc.WithBackoffMaxDelay(maxBackoff)) dialOpts = append(dialOpts, grpc.WithDecompressor(snappyDecompressor{})) @@ -256,6 +256,17 @@ func (ctx *Context) GRPCDial(target string, opts ...grpc.DialOption) (*grpc.Clie if ctx.rpcCompression { dialOpts = append(dialOpts, grpc.WithCompressor(snappyCompressor{})) } + dialOpts = append(dialOpts, grpc.WithKeepaliveParams(keepalive.ClientParameters{ + // Send periodic pings on the connection. + Time: base.NetworkTimeout, + // If the pings don't get a response within the timeout, we might be + // experiencing a network partition. gRPC will close the transport-level + // connection and all the pending RPCs (which may not have timeouts) will + // fail eagerly. gRPC will then reconnect the transport transparently. + Timeout: base.NetworkTimeout, + // Do the pings even when there are no ongoing RPCs. + PermitWithoutStream: true, + })) dialOpts = append(dialOpts, opts...) if SourceAddr != nil { @@ -359,24 +370,6 @@ func (ctx *Context) runHeartbeat(meta *connMeta, remoteAddr string) error { meta.heartbeatErr = err ctx.conns.Unlock() - // If we got a timeout, we might be experiencing a network partition. We - // close the connection so that all other pending RPCs (which may not have - // timeouts) fail eagerly. Any other error is likely to be noticed by - // other RPCs, so it's OK to leave the connection open while grpc - // internally reconnects if necessary. - // - // NB: This check is skipped when the connection is initiated from a CLI - // client since those clients aren't sensitive to partitions, are likely - // to be invoked while the server is starting (particularly in tests), and - // are not equipped with the retry logic necessary to deal with this - // connection termination. - // - // TODO(tamird): That we rely on the zero maxOffset to indicate a CLI - // client is a hack; we should do something more explicit. 
- if maxOffset != 0 && grpc.Code(err) == codes.DeadlineExceeded { - return err - } - // HACK: work around https://github.com/grpc/grpc-go/issues/1026 // Getting a "connection refused" error from the "write" system call // has confused grpc's error handling and this connection is permanently diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 6ea87f8bf055..b78826855118 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -17,6 +17,7 @@ package rpc import ( + "math" "net" "runtime" "sync" @@ -25,9 +26,11 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util" @@ -551,3 +554,127 @@ func TestRemoteOffsetUnhealthy(t *testing.T) { } } } + +// This is a smoketest for gRPC Keepalives: rpc.Context asks gRPC to perform +// periodic pings on the transport to check that it's still alive. If the ping +// doesn't get a pong within a timeout, the transport is supposed to be closed - +// that's what we're testing here. +func TestGRPCKeepaliveFailureFailsInflightRPCs(t *testing.T) { + defer leaktest.AfterTest(t)() + + stopper := stop.NewStopper() + defer stopper.Stop() + + clock := hlc.NewClock(time.Unix(0, 20).UnixNano, time.Nanosecond) + serverCtx := NewContext( + log.AmbientContext{}, + testutils.NewNodeTestBaseContext(), + clock, + stopper, + ) + s, ln := newTestServer(t, serverCtx, true) + remoteAddr := ln.Addr().String() + + RegisterHeartbeatServer(s, &HeartbeatService{ + clock: clock, + remoteClockMonitor: serverCtx.RemoteClocks, + }) + + clientCtx := NewContext( + log.AmbientContext{}, testutils.NewNodeTestBaseContext(), clock, stopper) + // Disable automatic heartbeats. We'll send them by hand. + clientCtx.heartbeatInterval = math.MaxInt64 + + var firstConn int32 = 1 + + // We're going to open RPC transport connections using a dialer that returns + // PartitionableConns. We'll partition the first opened connection. + dialerCh := make(chan *testutils.PartitionableConn, 1) + conn, err := clientCtx.GRPCDial(remoteAddr, + grpc.WithDialer( + func(addr string, timeout time.Duration) (net.Conn, error) { + if !atomic.CompareAndSwapInt32(&firstConn, 1, 0) { + // If we allow gRPC to open a 2nd transport connection, then our RPCs + // might succeed if they're sent on that one. In the spirit of a + // partition, we'll return errors for the attempt to open a new + // connection (albeit for a TCP connection the error would come after + // a socket connect timeout). + return nil, errors.Errorf("No more connections for you. We're partitioned.") + } + + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + return nil, err + } + transportConn := testutils.NewPartitionableConn(conn) + dialerCh <- transportConn + return transportConn, nil + }), + // Override the keepalive settings that the rpc.Context uses to more + // aggressive ones, so that the test doesn't take long. + grpc.WithKeepaliveParams( + keepalive.ClientParameters{ + // The aggressively low timeout we set here makes the connection very + // flaky for any RPC use, particularly when running under stress with -p + // 100. This test can't expect any RPCs to succeed reliably. 
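+					// (For reference: Time is the inactivity interval after which gRPC
+					// sends a keepalive ping; Timeout is how long it waits for the ack
+					// before tearing down the transport; and with PermitWithoutStream
+					// set to false, pings are only sent while some RPC is in flight.)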
+					Time:                time.Millisecond,
+					Timeout:             5 * time.Millisecond,
+					PermitWithoutStream: false,
+				}),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() { _ = conn.Close() }()
+
+	// We expect the RPC call to return one of the errors that testing has
+	// revealed can occur when an RPC's transport connection is closed because
+	// the heartbeats time out.
+	gRPCErrorsRegex := "transport is closing|" +
+		"rpc error: code = Unavailable desc = grpc: the connection is unavailable|" +
+		"rpc error: code = Internal desc = transport: io: read/write on closed pipe|" +
+		"rpc error: code = Internal desc = transport: tls: use of closed connection|" +
+		"rpc error: code = Internal desc = transport: EOF|" +
+		"use of closed network connection"
+
+	// Perform an RPC so that a connection gets opened. In theory this RPC should
+	// succeed (and it does when running without too much stress), but we can't
+	// rely on that - see the comment on the timeout above.
+	heartbeatClient := NewHeartbeatClient(conn)
+	request := PingRequest{}
+	if _, err := heartbeatClient.Ping(context.TODO(), &request); err != nil {
+		if !testutils.IsError(err, gRPCErrorsRegex) {
+			t.Fatal(err)
+		}
+		// In the rare event that we get the expected error, this test has already
+		// succeeded: even though we didn't partition the connection, the low gRPC
+		// keepalive timeout caused our RPC to fail (this happens occasionally
+		// under stress -p 100). We let the rest of the test code run anyway, to
+		// make sure it's exercised.
+		// If the heartbeats didn't time out (the normal case), we're going to
+		// simulate a network partition, after which the heartbeats must time out.
+		log.Infof(context.TODO(), "heartbeat failed without a partition; running the rest of the test anyway")
+	}
+
+	// Now partition client->server and attempt to perform an RPC. We expect it to
+	// fail once the gRPC keepalive fails to get a response from the server.
+
+	transportConn := <-dialerCh
+	defer transportConn.Finish()
+
+	transportConn.PartitionC2S()
+
+	if _, err := heartbeatClient.Ping(context.TODO(), &request); !testutils.IsError(
+		err, gRPCErrorsRegex) {
+		t.Fatal(err)
+	}
+
+	// If the DialOptions we passed to gRPC didn't prevent it from opening new
+	// connections, then subsequent RPCs would succeed since gRPC reconnects the
+	// transport (and that would succeed here since we've only partitioned one
+	// connection). We could further test that the status reported by
+	// Context.ConnHealth() for the remote node moves to UNAVAILABLE because of
+	// the (application-level) heartbeats performed by rpc.Context, but the
+	// behavior of our heartbeats in the face of transport failures is
+	// sufficiently tested in TestHeartbeatHealthTransport.
+}
diff --git a/pkg/testutils/net.go b/pkg/testutils/net.go
new file mode 100644
index 000000000000..4fb8c8f70f99
--- /dev/null
+++ b/pkg/testutils/net.go
@@ -0,0 +1,427 @@
+// Copyright 2017 The Cockroach Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+//
+// Author: Andrei Matei (andreimatei1@gmail.com)
+
+package testutils
+
+import (
+	"io"
+	"net"
+	"sync"
+
+	"github.com/pkg/errors"
+	"golang.org/x/net/context"
+
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
+)
+
+// bufferSize is the size of the buffer used by PartitionableConn. Writes to a
+// partitioned connection will block after the buffer gets filled.
+const bufferSize = 16 << 10 // 16 KB
+
+// PartitionableConn is an implementation of net.Conn that allows the
+// client->server and/or the server->client directions to be temporarily
+// partitioned.
+//
+// A PartitionableConn wraps a provided net.Conn (the serverConn member) and
+// forwards every read and write to it. It interposes an arbiter in front of it
+// that's used to block reads/writes while the PartitionableConn is in the
+// partitioned mode.
+//
+// While a direction is partitioned, data sent in that direction doesn't flow. A
+// write while partitioned will block after an internal buffer gets filled. Data
+// written to the conn after the partition has been established is not delivered
+// to the remote party until the partition is lifted. At that time, all the
+// buffered data is delivered. Since data is delivered asynchronously, data
+// written before the partition is established may or may not be blocked by the
+// partition; use application-level ACKs if that's important.
+type PartitionableConn struct {
+	// We embed a net.Conn so that we inherit the interface. Note that we override
+	// Read() and Write().
+	//
+	// This embedded Conn is half of a net.Pipe(). The other half is clientConn.
+	net.Conn
+
+	clientConn net.Conn
+	serverConn net.Conn
+
+	mu struct {
+		syncutil.Mutex
+
+		// err, if set, is returned by any subsequent call to Read or Write.
+		err error
+
+		// Is either of the two directions (client-to-server, server-to-client)
+		// currently partitioned?
+		c2sPartitioned bool
+		s2cPartitioned bool
+
+		c2sBuffer buf
+		s2cBuffer buf
+
+		// Conds to be signaled when the corresponding partition is lifted.
+		c2sWaiter *sync.Cond
+		s2cWaiter *sync.Cond
+	}
+}
+
+type buf struct {
+	// A mutex used to synchronize access to all the fields. It will be set to the
+	// parent PartitionableConn's mutex.
+	*syncutil.Mutex
+
+	data     []byte
+	capacity int
+	closed   bool
+	// The error that was passed to Close(err). See Close() for more info.
+	closedErr error
+	name      string // A human-readable name, useful for debugging.
+
+	// readerWait is signaled when the reader should wake up and check the
+	// buffer's state: when new data is put in the buffer, when the buffer is
+	// closed, and whenever the PartitionableConn wants to unblock all reads (i.e.
+	// on partition).
+	readerWait *sync.Cond
+
+	// capacityWait is signaled when a blocked writer should wake up because data
+	// is taken out of the buffer and there's now some capacity. It's also
+	// signaled when the buffer is closed.
+	capacityWait *sync.Cond
+}
+
+func makeBuf(name string, capacity int, mu *syncutil.Mutex) buf {
+	b := buf{
+		Mutex:    mu,
+		name:     name,
+		capacity: capacity,
+	}
+	b.readerWait = sync.NewCond(b.Mutex)
+	b.capacityWait = sync.NewCond(b.Mutex)
+	return b
+}
+
+// Write adds data to the buffer. If there's zero free capacity, it will block
+// until there's some capacity available or the buffer is closed. If there is
+// some capacity, but not enough for all of data, it performs a partial write.
+//
+// The number of bytes written is returned.
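+// For example (illustrative numbers): writing 10 bytes into a buf with 4
+// bytes of free capacity copies the first 4 bytes and returns (4, nil);
+// callers that need the whole payload delivered must loop on partial writes,
+// as readFrom below does.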
+func (b *buf) Write(data []byte) (int, error) {
+	b.Lock()
+	defer b.Unlock()
+	for b.capacity == len(b.data) && !b.closed {
+		// Block for capacity.
+		b.capacityWait.Wait()
+	}
+	if b.closed {
+		return 0, b.closedErr
+	}
+	available := b.capacity - len(b.data)
+	toCopy := available
+	if len(data) < available {
+		toCopy = len(data)
+	}
+	b.data = append(b.data, data[:toCopy]...)
+	b.wakeReaderLocked()
+	return toCopy, nil
+}
+
+// errEAgain is returned by buf.readLocked() when the read was blocked at the
+// time when buf.readerWait was signaled (in particular, after the
+// PartitionableConn interrupted the read because of a partition). The caller is
+// expected to try the read again after the partition is gone.
+var errEAgain = errors.New("try read again")
+
+// readLocked returns data from buf, up to "size" bytes. If there's no data in
+// the buffer, it blocks until either some data becomes available or the buffer
+// is closed.
+func (b *buf) readLocked(size int) ([]byte, error) {
+	if len(b.data) == 0 && !b.closed {
+		b.readerWait.Wait()
+		// We were unblocked either by data arriving, by a partition, or for some
+		// other uninteresting reason. Return to the caller, in case it's because
+		// of a partition.
+		return nil, errEAgain
+	}
+	if b.closed && len(b.data) == 0 {
+		return nil, b.closedErr
+	}
+	var ret []byte
+	if len(b.data) < size {
+		ret = b.data
+		b.data = nil
+	} else {
+		ret = b.data[:size]
+		b.data = b.data[size:]
+	}
+	b.capacityWait.Signal()
+	return ret, nil
+}
+
+// Close closes the buffer. All reads and writes that are currently blocked will
+// be woken and they'll all return err.
+func (b *buf) Close(err error) {
+	b.Lock()
+	b.closed = true
+	b.closedErr = err
+	b.readerWait.Signal()
+	b.capacityWait.Signal()
+	b.Unlock()
+}
+
+// wakeReaderLocked wakes the reader in case it's blocked.
+// See comments on readerWait.
+//
+// This needs to be called while holding the buffer's mutex.
+func (b *buf) wakeReaderLocked() {
+	b.readerWait.Signal()
+}
+
+// NewPartitionableConn wraps serverConn in a PartitionableConn.
+func NewPartitionableConn(serverConn net.Conn) *PartitionableConn {
+	clientEnd, clientConn := net.Pipe()
+	c := &PartitionableConn{
+		Conn:       clientEnd,
+		clientConn: clientConn,
+		serverConn: serverConn,
+	}
+	c.mu.c2sWaiter = sync.NewCond(&c.mu.Mutex)
+	c.mu.s2cWaiter = sync.NewCond(&c.mu.Mutex)
+	c.mu.c2sBuffer = makeBuf("c2sBuf", bufferSize, &c.mu.Mutex)
+	c.mu.s2cBuffer = makeBuf("s2cBuf", bufferSize, &c.mu.Mutex)
+
+	// Start copying from client to server.
+	go func() {
+		err := c.copy(
+			c.clientConn, // src
+			c.serverConn, // dst
+			&c.mu.c2sBuffer,
+			func() { // waitForNoPartitionLocked
+				for c.mu.c2sPartitioned {
+					c.mu.c2sWaiter.Wait()
+				}
+			})
+		c.mu.Lock()
+		c.mu.err = err
+		c.mu.Unlock()
+		if err := c.clientConn.Close(); err != nil {
+			log.Errorf(context.TODO(), "unexpected error closing internal pipe: %s", err)
+		}
+		if err := c.serverConn.Close(); err != nil {
+			log.Errorf(context.TODO(), "error closing server conn: %s", err)
+		}
+	}()
+
+	// Start copying from server to client.
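+	// This goroutine mirrors the client->server one above: same structure, but
+	// with the s2c buffer and partition flag, and the copy direction reversed.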
+ go func() { + err := c.copy( + c.serverConn, // src + c.clientConn, // dst + &c.mu.s2cBuffer, + func() { // waitForNoPartitionLocked + for c.mu.s2cPartitioned { + c.mu.s2cWaiter.Wait() + } + }) + c.mu.Lock() + c.mu.err = err + c.mu.Unlock() + if err := c.clientConn.Close(); err != nil { + log.Fatalf(context.TODO(), "unexpected error closing internal pipe: %s", err) + } + if err := c.serverConn.Close(); err != nil { + log.Errorf(context.TODO(), "error closing server conn: %s", err) + } + }() + + return c +} + +// Finish removes any partitions that may exist so that blocked goroutines can +// finish. +// Finish() must be called if a connection may have been left in a partitioned +// state. +func (c *PartitionableConn) Finish() { + c.mu.Lock() + c.mu.c2sPartitioned = false + c.mu.c2sWaiter.Signal() + c.mu.s2cPartitioned = false + c.mu.s2cWaiter.Signal() + c.mu.Unlock() +} + +// PartitionC2S partitions the client-to-server direction. +// If UnpartitionC2S() is not called, Finish() must be called. +func (c *PartitionableConn) PartitionC2S() { + c.mu.Lock() + if c.mu.c2sPartitioned { + panic("already partitioned") + } + c.mu.c2sPartitioned = true + c.mu.c2sBuffer.wakeReaderLocked() + c.mu.Unlock() +} + +// UnpartitionC2S lifts an existing client-to-server partition. +func (c *PartitionableConn) UnpartitionC2S() { + c.mu.Lock() + if !c.mu.c2sPartitioned { + panic("not partitioned") + } + c.mu.c2sPartitioned = false + c.mu.c2sWaiter.Signal() + c.mu.Unlock() +} + +// PartitionS2C partitions the server-to-client direction. +// If UnpartitionS2C() is not called, Finish() must be called. +func (c *PartitionableConn) PartitionS2C() { + c.mu.Lock() + if c.mu.s2cPartitioned { + panic("already partitioned") + } + c.mu.s2cPartitioned = true + c.mu.s2cBuffer.wakeReaderLocked() + c.mu.Unlock() +} + +// UnpartitionS2C lifts an existing server-to-client partition. +func (c *PartitionableConn) UnpartitionS2C() { + c.mu.Lock() + if !c.mu.s2cPartitioned { + panic("not partitioned") + } + c.mu.s2cPartitioned = false + c.mu.s2cWaiter.Signal() + c.mu.Unlock() +} + +// Read is part of the net.Conn interface. +func (c *PartitionableConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + err = c.mu.err + c.mu.Unlock() + if err != nil { + return 0, err + } + + // Forward to the embedded connection. + return c.Conn.Read(b) +} + +// Write is part of the net.Conn interface. +func (c *PartitionableConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + err = c.mu.err + c.mu.Unlock() + if err != nil { + return 0, err + } + + // Forward to the embedded connection. + return c.Conn.Write(b) +} + +// readFrom copies data from src into the buffer until src.Read() returns an +// error (e.g. io.EOF). That error is returned. +// +// readFrom is written in the spirit of interface io.ReaderFrom, except it +// returns the io.EOF error, and also doesn't guarantee that every byte that has +// been read from src is put into the buffer (as the buffer allows concurrent +// access and buf.Write can return an error). +func (b *buf) readFrom(src io.Reader) error { + data := make([]byte, 1024) + for { + nr, err := src.Read(data) + if err != nil { + return err + } + toSend := data[:nr] + for { + nw, ew := b.Write(toSend) + if ew != nil { + return ew + } + if nw == len(toSend) { + break + } + toSend = toSend[nw:] + } + } +} + +// copyFromBuffer copies data from src to dst until src.Read() returns EOF. +// The EOF is returned (i.e. the return value is always != nil). 
This is because
+// the PartitionableConn wants to hold on to any error, including EOF.
+//
+// waitForNoPartitionLocked is a function to be called before consuming data
+// from src, in order to make sure that we only consume data when we're not
+// partitioned. It needs to be called under src.Mutex, as the check needs to be
+// done atomically with consuming the buffer's data.
+func (c *PartitionableConn) copyFromBuffer(
+	src *buf, dst net.Conn, waitForNoPartitionLocked func(),
+) error {
+	for {
+		// Don't read from the buffer while we're partitioned.
+		src.Mutex.Lock()
+		waitForNoPartitionLocked()
+		data, err := src.readLocked(1024 * 1024)
+		src.Mutex.Unlock()
+
+		if len(data) > 0 {
+			nw, ew := dst.Write(data)
+			if ew != nil {
+				err = ew
+			}
+			if len(data) != nw {
+				err = io.ErrShortWrite
+			}
+		} else if err == nil {
+			err = io.EOF
+		} else if err == errEAgain {
+			continue
+		}
+		if err != nil {
+			return err
+		}
+	}
+}
+
+// copy copies data from src to dst while we're not partitioned and stops doing
+// so while partitioned.
+//
+// It runs two goroutines internally: one copying from src to an internal buffer
+// and one copying from the buffer to dst. The second one deals with partitions.
+func (c *PartitionableConn) copy(
+	src net.Conn, dst net.Conn, buf *buf, waitForNoPartitionLocked func(),
+) error {
+	tasks := make(chan error)
+	go func() {
+		err := buf.readFrom(src)
+		buf.Close(err)
+		tasks <- err
+	}()
+	go func() {
+		err := c.copyFromBuffer(buf, dst, waitForNoPartitionLocked)
+		buf.Close(err)
+		tasks <- err
+	}()
+	err := <-tasks
+	err2 := <-tasks
+	if err == nil {
+		err = err2
+	}
+	return err
+}
diff --git a/pkg/testutils/net_test.go b/pkg/testutils/net_test.go
new file mode 100644
index 000000000000..feb5174b1887
--- /dev/null
+++ b/pkg/testutils/net_test.go
@@ -0,0 +1,420 @@
+// Copyright 2017 The Cockroach Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+//
+// Author: Andrei Matei (andreimatei1@gmail.com)
+
+package testutils
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/pkg/errors"
+
+	"github.com/cockroachdb/cockroach/pkg/util"
+	"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/netutil"
+)
+
+// RunEchoServer runs a network server that accepts one connection from ln and
+// echoes the data sent on it.
+//
+// If serverSideCh != nil, every slice of data received by the server is also
+// sent on this channel before being echoed back on the connection it came on.
+// Useful to observe what the server has received when this server is used with
+// partitioned connections.
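+//
+// Note that the server handles a single connection and returns once that
+// connection's read loop ends, so a test that needs a fresh connection must
+// start a new echo server.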
+func RunEchoServer(ln net.Listener, serverSideCh chan<- []byte) error { + conn, err := ln.Accept() + if err != nil { + if grpcutil.IsClosedConnection(err) { + return nil + } + return err + } + if _, err := copyWithSideChan(conn, conn, serverSideCh); err != nil { + return err + } + return nil +} + +// copyWithSideChan is like io.Copy(), but also takes a channel on which data +// read from src is sent before being written to dst. +func copyWithSideChan(dst io.Writer, src io.Reader, ch chan<- []byte) (written int64, err error) { + buf := make([]byte, 32*1024) + for { + nr, er := src.Read(buf) + if nr > 0 { + if ch != nil { + ch <- buf[:nr] + } + + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} + +func TestPartitionableConnBasic(t *testing.T) { + defer leaktest.AfterTest(t)() + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + go func() { + if err := RunEchoServer(ln, nil); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + exp := "let's see if this value comes back\n" + fmt.Fprintf(pConn, exp) + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + t.Fatal(err) + } + if got != exp { + t.Fatalf("expecting: %q , got %q", exp, got) + } +} + +func TestPartitionableConnPartitionC2S(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverSideCh := make(chan []byte) + go func() { + if err := RunEchoServer(ln, serverSideCh); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + // Partition the client->server connection. Afterwards, we're going to send + // something and assert that the server doesn't get it (within a timeout) by + // snooping on the server's side channel. Then we'll resolve the partition and + // expect that the server gets the message that was pending and echoes it + // back. + + pConn.PartitionC2S() + + // Client sends data. + exp := "let's see when this value comes back\n" + fmt.Fprintf(pConn, exp) + + // In the background, the client waits on a read. + clientDoneCh := make(chan error) + go func() { + clientDoneCh <- func() error { + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + return err + } + if got != exp { + return errors.Errorf("expecting: %q , got %q", exp, got) + } + return nil + }() + }() + + timerDoneCh := make(chan error) + time.AfterFunc(3*time.Millisecond, func() { + var err error + select { + case err = <-clientDoneCh: + err = errors.Errorf("unexpected reply while partitioned: %v", err) + case buf := <-serverSideCh: + err = errors.Errorf("server was not supposed to have received data while partitioned: %q", buf) + default: + } + timerDoneCh <- err + }) + + if err := <-timerDoneCh; err != nil { + t.Fatal(err) + } + + // Now unpartition and expect the pending data to be sent and a reply to be + // received. 
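+	// (Unpartitioning signals the reader goroutine for the c2s buffer, which
+	// then drains whatever data had accumulated in it out to the server.)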
+ + pConn.UnpartitionC2S() + + // Expect the server to receive the data. + <-serverSideCh + + if err := <-clientDoneCh; err != nil { + t.Fatal(err) + } +} + +func TestPartitionableConnPartitionS2C(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverSideCh := make(chan []byte) + go func() { + if err := RunEchoServer(ln, serverSideCh); err != nil { + t.Error(err) + } + }() + defer func() { + netutil.FatalIfUnexpected(ln.Close()) + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + // We're going to partition the server->client connection. Then we'll send + // some data and assert that the server gets it (by snooping on the server's + // side-channel). Then we'll assert that the client doesn't get the reply + // (with a timeout). Then we resolve the partition and assert that the client + // gets the reply. + + pConn.PartitionS2C() + + // Client sends data. + exp := "let's see when this value comes back\n" + fmt.Fprintf(pConn, exp) + + if s := <-serverSideCh; string(s) != exp { + t.Fatalf("expected server to receive %q, got %q", exp, s) + } + + // In the background, the client waits on a read. + clientDoneCh := make(chan error) + go func() { + clientDoneCh <- func() error { + got, err := bufio.NewReader(pConn).ReadString('\n') + if err != nil { + return err + } + if got != exp { + return errors.Errorf("expecting: %q , got %q", exp, got) + } + return nil + }() + }() + + // Check that the client does not get the server's response. + time.AfterFunc(3*time.Millisecond, func() { + select { + case err := <-clientDoneCh: + t.Errorf("unexpected reply while partitioned: %v", err) + default: + } + }) + + // Now unpartition and expect the pending data to be sent and a reply to be + // received. + + pConn.UnpartitionS2C() + + if err := <-clientDoneCh; err != nil { + t.Fatal(err) + } +} + +// Test that, while partitioned, a sender doesn't block while the internal +// buffer is not full. +func TestPartitionableConnBuffering(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + + // In the background, the server reads everything. + exp := 5 * (bufferSize / 10) + serverDoneCh := make(chan error) + go func() { + serverDoneCh <- func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + received := 0 + for { + data := make([]byte, 1024*1024) + nr, err := conn.Read(data) + if err != nil { + if err == io.EOF { + break + } + return err + } + received += nr + } + if received != exp { + return errors.Errorf("server expecting: %d , got %d", exp, received) + } + return nil + }() + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + pConn.PartitionC2S() + defer pConn.Finish() + + // Send chunks such that they don't add up to the buffer size exactly. 
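+	// Five chunks of bufferSize/10 bytes each add up to about half of
+	// bufferSize, so none of these writes should block even though the
+	// connection stays partitioned throughout.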
+ data := make([]byte, bufferSize/10) + for i := 0; i < 5; i++ { + nw, err := pConn.Write(data) + if err != nil { + t.Fatal(err) + } + if nw != len(data) { + t.Fatal("unexpected partial write; PartitionableConn always writes fully") + } + } + pConn.UnpartitionC2S() + pConn.Close() + + if err := <-serverDoneCh; err != nil { + t.Fatal(err) + } +} + +// Test that, while partitioned, a party can close the connection and the other +// party will not observe this until after the partition is lifted. +func TestPartitionableConnCloseDeliveredAfterPartition(t *testing.T) { + defer leaktest.AfterTest(t)() + + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + + // In the background, the server reads everything. + serverDoneCh := make(chan error) + go func() { + serverDoneCh <- func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + received := 0 + for { + data := make([]byte, 1<<20 /* 1 MiB */) + nr, err := conn.Read(data) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + received += nr + } + }() + }() + + serverConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + pConn := NewPartitionableConn(serverConn) + defer pConn.Close() + + pConn.PartitionC2S() + defer pConn.Finish() + + pConn.Close() + + timerDoneCh := make(chan error) + time.AfterFunc(3*time.Millisecond, func() { + var err error + select { + case err = <-serverDoneCh: + err = errors.Wrapf(err, "server was not supposed to see the closing while partitioned") + default: + } + timerDoneCh <- err + }) + + if err := <-timerDoneCh; err != nil { + t.Fatal(err) + } + + pConn.UnpartitionC2S() + + if err := <-serverDoneCh; err != nil { + t.Fatal(err) + } +}
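For reference, a minimal sketch of how PartitionableConn might be used directly,
outside the tests in this change. The helper name, the echo-server address, and
the imports (net, testing, pkg/testutils) are illustrative assumptions, not part
of this patch:

	// usePartitionableConn is a hypothetical test helper: it wraps an outbound
	// TCP connection to an echo-style server assumed to listen at addr, and
	// stalls the client->server direction for a while.
	func usePartitionableConn(t *testing.T, addr string) {
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Fatal(err)
		}
		pConn := testutils.NewPartitionableConn(conn)
		defer pConn.Close()
		defer pConn.Finish() // runs before Close; lifts any partition left in place

		pConn.PartitionC2S() // client->server bytes now stop in the internal buffer
		if _, err := pConn.Write([]byte("ping\n")); err != nil {
			t.Fatal(err) // the write is buffered, not yet delivered to the server
		}
		pConn.UnpartitionC2S() // the buffered "ping\n" is flushed to the server
	}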