server,rpc: validate node IDs in RPC heartbeats
Prior to this patch, it was possible for an RPC client to dial a node
ID and get a connection to another node instead. This is because the
mapping of node ID -> address may be stale, and a different node could
take the address of the intended node from "under" the dialer.

(See the previous commit for a scenario.)

This happened to be "safe" in many of the cases that matter, because:

- RPC requests for distSQL are OK with being served on a different
  node than intended (with a potential performance drop);
- RPC requests to the KV layer are OK with being served on a different
  node than intended (they are re-routed to the right node
  underneath);
- RPC requests to the storage layer are rejected by the remote node,
  because the store ID in the request would not match.

However, this safety is largely accidental, and we should not work
under the assumption that any RPC request is safe to mis-route. (In
fact, we have not audited all the RPC endpoints and cannot establish
that this safety holds throughout.)

This patch works to prevent these mis-routings by checking the
intended node ID during RPC heartbeats (including the initial
heartbeat), whenever the intended node ID is known. A new API,
`GRPCDialNode()`, is introduced to establish such validated
connections.

Release note (bug fix): CockroachDB now performs fewer attempts to
communicate with the wrong node when a node is restarted with another
node's address.
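
For illustration, a minimal sketch of what the dial-side change looks
like for a caller. The `dialValidated` helper is hypothetical; only
`GRPCDialNode` and `Connect` come from this patch:

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/rpc"
	"google.golang.org/grpc"
)

// dialValidated is a hypothetical helper showing the new API in use.
func dialValidated(
	ctx context.Context, rpcCtx *rpc.Context, addr string, nodeID roachpb.NodeID,
) (*grpc.ClientConn, error) {
	// Before this patch, callers used rpcCtx.GRPCDial(addr), which did not
	// verify that the process listening at addr was in fact nodeID.
	// GRPCDialNode attaches nodeID to every heartbeat (including the
	// initial one); the server rejects the connection if its own node ID
	// differs, so a stale address mapping fails fast instead of silently
	// reaching the wrong node.
	return rpcCtx.GRPCDialNode(addr, nodeID).Connect(ctx)
}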
knz committed Jan 23, 2019
1 parent 5bce267 commit 33afd7f
Showing 17 changed files with 281 additions and 73 deletions.
1 change: 1 addition & 0 deletions pkg/gossip/client_test.go
@@ -66,6 +66,7 @@ func startGossipAtAddr(
 	registry *metric.Registry,
 ) *Gossip {
 	rpcContext := newInsecureRPCContext(stopper)
+	rpcContext.NodeID.Set(context.TODO(), nodeID)
 	server := rpc.NewServer(rpcContext)
 	g := NewTest(nodeID, rpcContext, server, stopper, registry)
 	ln, err := netutil.ListenAndServeGRPC(stopper, server, addr)
3 changes: 1 addition & 2 deletions pkg/gossip/gossip_test.go
@@ -61,8 +61,7 @@ func TestGossipInfoStore(t *testing.T) {
 }
 
 // TestGossipMoveNode verifies that if a node is moved to a new address, it
-// gets properly updated in gossip (including that any other node that was
-// previously at that address gets removed from the cluster).
+// gets properly updated in gossip.
 func TestGossipMoveNode(t *testing.T) {
 	defer leaktest.AfterTest(t)()
 	stopper := stop.NewStopper()
81 changes: 64 additions & 17 deletions pkg/rpc/context.go
@@ -253,6 +253,7 @@ func NewServerWithInterceptor(
 		clock:              ctx.LocalClock,
 		remoteClockMonitor: ctx.RemoteClocks,
 		clusterID:          &ctx.ClusterID,
+		nodeID:             &ctx.NodeID,
 		version:            ctx.version,
 	})
 	return s
@@ -272,14 +273,20 @@ type Connection struct {
 	initialHeartbeatDone chan struct{} // closed after first heartbeat
 	stopper              *stop.Stopper
 
+	// remoteNodeID implies checking the remote node ID. 0 when unknown,
+	// non-zero to check with remote node. This is constant throughout
+	// the lifetime of a Connection object.
+	remoteNodeID roachpb.NodeID
+
 	initOnce      sync.Once
 	validatedOnce sync.Once
 }
 
-func newConnection(stopper *stop.Stopper) *Connection {
+func newConnectionToNodeID(stopper *stop.Stopper, remoteNodeID roachpb.NodeID) *Connection {
 	c := &Connection{
 		initialHeartbeatDone: make(chan struct{}),
 		stopper:              stopper,
+		remoteNodeID:         remoteNodeID,
 	}
 	c.heartbeatResult.Store(heartbeatResult{err: ErrNotHeartbeated})
 	return c
@@ -346,13 +353,22 @@ type Context struct {
 	stats StatsHandler
 
 	ClusterID base.ClusterIDContainer
+	NodeID    base.NodeIDContainer
 	version   *cluster.ExposedClusterVersion
 
 	// For unittesting.
 	BreakerFactory  func() *circuit.Breaker
 	testingDialOpts []grpc.DialOption
 }
 
+// connKey is used as key in the Context.conns map. Different remote
+// node IDs get different *Connection objects, to ensure that we don't
+// mis-route RPC requests.
+type connKey struct {
+	targetAddr string
+	nodeID     roachpb.NodeID
+}
+
 // NewContext creates an rpc Context with the supplied values.
 func NewContext(
 	ambient log.AmbientContext,
@@ -396,7 +412,7 @@ func NewContext(
 					conn.dialErr = &roachpb.NodeUnavailableError{}
 				}
 			})
-			ctx.removeConn(k.(string), conn)
+			ctx.removeConn(k.(connKey), conn)
 			return true
 		})
 	})
@@ -518,15 +534,15 @@ func (ctx *Context) SetLocalInternalServer(internalServer roachpb.InternalServer) {
 	ctx.localInternalClient = internalClientAdapter{internalServer}
 }
 
-func (ctx *Context) removeConn(key string, conn *Connection) {
+func (ctx *Context) removeConn(key connKey, conn *Connection) {
 	ctx.conns.Delete(key)
 	if log.V(1) {
-		log.Infof(ctx.masterCtx, "closing %s", key)
+		log.Infof(ctx.masterCtx, "closing %+v", key)
 	}
 	if grpcConn := conn.grpcConn; grpcConn != nil {
 		if err := grpcConn.Close(); err != nil && !grpcutil.IsClosedConnection(err) {
 			if log.V(1) {
-				log.Errorf(ctx.masterCtx, "failed to close client connection: %s", err)
+				log.Errorf(ctx.masterCtx, "failed to close client connection: %v", err)
 			}
 		}
 	}
@@ -650,10 +666,35 @@ func (ctx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, <-chan struct{}, error) {
 }
 
 // GRPCDial calls grpc.Dial with options appropriate for the context.
+//
+// It does not require validation of the node ID between client and server:
+// if a connection existed already with some node ID requirement, that
+// requirement will remain; if no connection existed yet,
+// a new one is created without a node ID requirement.
 func (ctx *Context) GRPCDial(target string) *Connection {
-	value, ok := ctx.conns.Load(target)
+	return ctx.GRPCDialNode(target, 0)
+}
+
+// GRPCDialNode calls grpc.Dial with options appropriate for the context.
+//
+// The remoteNodeID, if non-zero, becomes a constraint on the expected
+// node ID of the remote node; this is checked during heartbeats.
+func (ctx *Context) GRPCDialNode(target string, remoteNodeID roachpb.NodeID) *Connection {
+	thisConnKey := connKey{target, remoteNodeID}
+	value, ok := ctx.conns.Load(thisConnKey)
 	if !ok {
-		value, _ = ctx.conns.LoadOrStore(target, newConnection(ctx.Stopper))
+		value, _ = ctx.conns.LoadOrStore(thisConnKey, newConnectionToNodeID(ctx.Stopper, remoteNodeID))
+		if remoteNodeID != 0 {
+			// If the first connection established at a target address is
+			// for a specific node ID, then we want to reuse that connection
+			// also for other dials (eg for gossip) which don't require a
+			// specific node ID. (We do this as an optimization to reduce
+			// the number of TCP connections alive between nodes. This is
+			// not strictly required for correctness.) This LoadOrStore will
+			// ensure we're registering the connection we just created for
+			// future use by these other dials.
+			_, _ = ctx.conns.LoadOrStore(connKey{target, 0}, value)
+		}
 	}
 
 	conn := value.(*Connection)
@@ -668,11 +709,11 @@ func (ctx *Context) GRPCDial(target string) *Connection {
 			if err != nil && !grpcutil.IsClosedConnection(err) {
 				log.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err)
 			}
-			ctx.removeConn(target, conn)
+			ctx.removeConn(thisConnKey, conn)
 		})
 	}); err != nil {
 		conn.dialErr = err
-		ctx.removeConn(target, conn)
+		ctx.removeConn(thisConnKey, conn)
 	}
 }
@@ -703,6 +744,9 @@ var ErrNotHeartbeated = errors.New("not yet heartbeated")
 // error will be returned. This method should therefore be used to
 // prioritize among a list of candidate nodes, but not to filter out
 // "unhealthy" nodes.
+//
+// This is used in tests only; in clusters use (*Dialer).ConnHealth()
+// instead which validates the node ID.
 func (ctx *Context) ConnHealth(target string) error {
 	if ctx.GetLocalInternalClientForAddr(target) != nil {
 		// The local server is always considered healthy.
@@ -716,14 +760,8 @@ func (ctx *Context) runHeartbeat(
 	conn *Connection, target string, redialChan <-chan struct{},
 ) error {
 	maxOffset := ctx.LocalClock.MaxOffset()
-	clusterID := ctx.ClusterID.Get()
+	maxOffsetNanos := maxOffset.Nanoseconds()
 
-	request := PingRequest{
-		Addr:           ctx.Addr,
-		MaxOffsetNanos: maxOffset.Nanoseconds(),
-		ClusterID:      &clusterID,
-		ServerVersion:  ctx.version.ServerVersion,
-	}
 	heartbeatClient := NewHeartbeatClient(conn.grpcConn)
 
 	var heartbeatTimer timeutil.Timer
@@ -748,9 +786,18 @@ func (ctx *Context) runHeartbeat(
 			goCtx, cancel = context.WithTimeout(goCtx, hbTimeout)
 		}
 		sendTime := ctx.LocalClock.PhysicalTime()
+		// We re-mint the PingRequest to pick up any asynchronous update to clusterID.
+		clusterID := ctx.ClusterID.Get()
+		request := &PingRequest{
+			Addr:           ctx.Addr,
+			MaxOffsetNanos: maxOffsetNanos,
+			ClusterID:      &clusterID,
+			NodeID:         conn.remoteNodeID,
+			ServerVersion:  ctx.version.ServerVersion,
+		}
 		// NB: We want the request to fail-fast (the default), otherwise we won't
 		// be notified of transport failures.
-		response, err := heartbeatClient.Ping(goCtx, &request)
+		response, err := heartbeatClient.Ping(goCtx, request)
 		if cancel != nil {
 			cancel()
 		}
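
The `connKey{target, 0}` aliasing above relies on `sync.Map.LoadOrStore`
never clobbering an existing entry. A self-contained sketch of that
idiom (illustrative only; the key type and values here are made up, not
CockroachDB code):

package main

import (
	"fmt"
	"sync"
)

type connKey struct {
	targetAddr string
	nodeID     int
}

func main() {
	var conns sync.Map

	// A validated connection is registered under {addr, 5}.
	conns.LoadOrStore(connKey{"10.0.0.1:26257", 5}, "conn->n5")

	// Alias it under the "any node" key {addr, 0} so later node-ID-agnostic
	// dials (e.g. gossip) reuse the same underlying TCP connection.
	// LoadOrStore leaves any pre-existing entry in place, so an earlier
	// unvalidated connection is never overwritten.
	v, loaded := conns.LoadOrStore(connKey{"10.0.0.1:26257", 0}, "conn->n5")
	fmt.Println(v, loaded) // conn->n5 false (newly stored)

	// A subsequent GRPCDial(target), i.e. node ID 0, finds the shared conn.
	v, _ = conns.Load(connKey{"10.0.0.1:26257", 0})
	fmt.Println(v) // conn->n5
}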
51 changes: 51 additions & 0 deletions pkg/rpc/context_test.go
@@ -93,6 +93,7 @@ func TestHeartbeatCB(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})
 
@@ -341,6 +342,7 @@ func TestHeartbeatHealthTransport(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})
 
@@ -515,6 +517,7 @@ func TestOffsetMeasurement(t *testing.T) {
 		clock:              serverClock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})
 
@@ -682,6 +685,7 @@ func TestRemoteOffsetUnhealthy(t *testing.T) {
 			clock:              clock,
 			remoteClockMonitor: nodeCtxs[i].ctx.RemoteClocks,
 			clusterID:          &nodeCtxs[i].ctx.ClusterID,
+			nodeID:             &nodeCtxs[i].ctx.NodeID,
 			version:            nodeCtxs[i].ctx.version,
 		})
 		ln, err := netutil.ListenAndServeGRPC(nodeCtxs[i].ctx.Stopper, s, util.TestAddr)
@@ -829,6 +833,7 @@ func TestGRPCKeepaliveFailureFailsInflightRPCs(t *testing.T) {
 			clock:              clock,
 			remoteClockMonitor: serverCtx.RemoteClocks,
 			clusterID:          &serverCtx.ClusterID,
+			nodeID:             &serverCtx.NodeID,
 			version:            serverCtx.version,
 		},
 		interval: msgInterval,
@@ -1011,6 +1016,50 @@ func TestClusterIDMismatch(t *testing.T) {
 	wg.Wait()
 }
 
+func TestNodeIDMismatch(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	stopper := stop.NewStopper()
+	defer stopper.Stop(context.TODO())
+
+	clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
+	serverCtx := newTestContext(clock, stopper)
+	uuid1 := uuid.MakeV4()
+	serverCtx.ClusterID.Set(context.TODO(), uuid1)
+	serverCtx.NodeID.Set(context.TODO(), 1)
+	s := newTestServer(t, serverCtx)
+	RegisterHeartbeatServer(s, &HeartbeatService{
+		clock:              clock,
+		remoteClockMonitor: serverCtx.RemoteClocks,
+		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
+		version:            serverCtx.version,
+	})
+
+	ln, err := netutil.ListenAndServeGRPC(serverCtx.Stopper, s, util.TestAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	remoteAddr := ln.Addr().String()
+
+	clientCtx := newTestContext(clock, stopper)
+	clientCtx.ClusterID.Set(context.TODO(), uuid1)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			_, err := clientCtx.GRPCDialNode(remoteAddr, 2).Connect(context.Background())
+			expected := "initial connection heartbeat failed.*doesn't match server node ID"
+			if !testutils.IsError(err, expected) {
+				t.Errorf("expected %s error, got %v", expected, err)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+}
+
 func setVersion(c *Context, v roachpb.Version) error {
 	settings := cluster.MakeClusterSettings(v, v)
 	cv := cluster.ClusterVersion{Version: v}
@@ -1048,6 +1097,7 @@ func TestVersionCheckBidirectional(t *testing.T) {
 	clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
 	serverCtx := newTestContext(clock, stopper)
 	serverCtx.ClusterID.Set(context.TODO(), uuid.MakeV4())
+	serverCtx.NodeID.Set(context.TODO(), 1)
 	if err := setVersion(serverCtx, td.serverVersion); err != nil {
 		t.Fatal(err)
 	}
@@ -1056,6 +1106,7 @@ func TestVersionCheckBidirectional(t *testing.T) {
 		clock:              clock,
 		remoteClockMonitor: serverCtx.RemoteClocks,
 		clusterID:          &serverCtx.ClusterID,
+		nodeID:             &serverCtx.NodeID,
 		version:            serverCtx.version,
 	})
 
22 changes: 22 additions & 0 deletions pkg/rpc/heartbeat.go
@@ -23,6 +23,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
 	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
 	"github.com/cockroachdb/cockroach/pkg/util/hlc"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/cockroach/pkg/util/uuid"
 	"github.com/pkg/errors"
@@ -48,6 +49,7 @@ type HeartbeatService struct {
 	// shared by rpc clients, to keep track of remote clock measurements.
 	remoteClockMonitor *RemoteClockMonitor
 	clusterID          *base.ClusterIDContainer
+	nodeID             *base.NodeIDContainer
 	version            *cluster.ExposedClusterVersion
 }
 
@@ -79,13 +81,33 @@ func checkVersion(
 // The requester should also estimate its offset from this server along
 // with the requester's address.
 func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingResponse, error) {
+	if log.V(2) {
+		log.Infof(ctx, "received heartbeat: %+v vs local cluster %+v node %+v", args, hs.clusterID, hs.nodeID)
+	}
 	// Check that cluster IDs match.
 	clusterID := hs.clusterID.Get()
 	if args.ClusterID != nil && *args.ClusterID != uuid.Nil && clusterID != uuid.Nil &&
 		*args.ClusterID != clusterID {
 		return nil, errors.Errorf(
 			"client cluster ID %q doesn't match server cluster ID %q", args.ClusterID, clusterID)
 	}
+	// Check that node IDs match.
+	var nodeID roachpb.NodeID
+	if hs.nodeID != nil {
+		nodeID = hs.nodeID.Get()
+	}
+	if args.NodeID != 0 && args.NodeID != nodeID {
+		// If nodeID != 0, the situation is clear (we are checking that
+		// the other side is talking to the right node).
+		//
+		// If nodeID == 0 this means that this node (serving the
+		// heartbeat) doesn't have a node ID yet. Then we can't serve
+		// connections for other nodes that want a specific node ID,
+		// however we can still serve connections that don't need a node
+		// ID, e.g. during initial gossip.
+		return nil, errors.Errorf(
+			"client requested node ID %d doesn't match server node ID %d", args.NodeID, nodeID)
+	}
 
 	// Check version compatibility.
 	if err := checkVersion(hs.version, args.ServerVersion); err != nil {
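
The check above has a single failure mode: the client pinned a node ID
and the server's ID (possibly still unset, i.e. zero) differs. A
standalone restatement of the predicate with the cases spelled out (an
illustrative sketch, not the package's code):

package main

import "fmt"

// pingAllowed restates the node ID check in Ping: a request with
// NodeID 0 imposes no constraint; a non-zero request must match the
// server's ID, and a server without an ID yet (0) rejects every
// node-specific request.
func pingAllowed(requestedID, serverID int) bool {
	return requestedID == 0 || requestedID == serverID
}

func main() {
	cases := []struct {
		requested, server int
		note              string
	}{
		{0, 7, "gossip-style dial, no constraint"},
		{7, 7, "intended node"},
		{7, 8, "address taken over by another node"},
		{7, 0, "server has no node ID yet"},
	}
	for _, c := range cases {
		fmt.Printf("client wants n%d, server is n%d (%s): ok=%v\n",
			c.requested, c.server, c.note, pingAllowed(c.requested, c.server))
	}
}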
(diffs for the remaining 12 changed files not shown)
