server,rpc: validate node IDs in RPC heartbeats

Prior to this patch, it was possible for an RPC client to dial a node
ID and get a connection to another node instead. This is because the
mapping of node ID -> address may be stale, and a different node could
take over the address of the intended node from "under" the dialer.

(See the previous commit for a scenario.)
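
To make the hazard concrete, here is a small illustrative sketch (not
code from this patch; the addresses and node IDs are made up) of a
dialer trusting a stale node ID -> address snapshot:

    package main

    import "fmt"

    func main() {
        // Snapshot of the gossiped node ID -> address mapping, taken
        // while it was still accurate.
        nodeAddrs := map[int]string{1: "10.0.0.1:26257", 2: "10.0.0.2:26257"}
        addr := nodeAddrs[2] // the dialer resolves node 2

        // Meanwhile node 2 went down and node 3 restarted on node 2's
        // old address; the dialer's snapshot never sees this.
        liveNodeAt := map[string]int{"10.0.0.1:26257": 1, "10.0.0.2:26257": 3}

        // Without a node ID check in the heartbeat, the dial succeeds
        // and silently reaches the wrong node.
        fmt.Printf("dialed node 2 at %s, actually reached node %d\n",
            addr, liveNodeAt[addr])
    }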

This happened to be "safe" in many cases where it matters because:

- RPC requests for DistSQL are OK with being served on a different
  node than intended (at a potential performance cost);
- RPC requests to the KV layer are OK with being served on a different
  node than intended (the KV layer re-routes them internally);
- RPC requests to the storage layer are rejected by the
  remote node, because the store ID in the request would not match.

However, this safety is largely accidental, and we should not operate
on the assumption that every RPC request is safe to mis-route. (In
fact, we have not audited all the RPC endpoints and cannot establish
that this safety holds throughout.)

This patch works to prevent such mis-routing by checking the intended
node ID during RPC heartbeats (including the initial heartbeat),
whenever the intended node ID is known. A new API, `GRPCDialNode()`,
is introduced to establish these connections.
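
As an illustration, here is a hedged sketch of a dial site under the
new API (the package, helper name, and arguments are hypothetical;
only `GRPCDialNode` and `Connect` come from this patch):

    package example

    import (
        "context"

        "github.com/cockroachdb/cockroach/pkg/roachpb"
        "github.com/cockroachdb/cockroach/pkg/rpc"
        "google.golang.org/grpc"
    )

    // dialNode sketches the intended usage: if the server listening at
    // addr is not the node with the given ID, Connect fails with an
    // "initial connection heartbeat failed" error instead of silently
    // returning a mis-routed connection. Passing a node ID of 0 (as
    // GRPCDial now does) skips the check.
    func dialNode(
        ctx context.Context, rpcCtx *rpc.Context, addr string, nodeID roachpb.NodeID,
    ) (*grpc.ClientConn, error) {
        return rpcCtx.GRPCDialNode(addr, nodeID).Connect(ctx)
    }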

Release note (bug fix): CockroachDB now makes fewer attempts to
communicate with the wrong node when a node is restarted with another
node's address.

knz committed Jan 22, 2019
1 parent 321a146 commit 99b5af2
Showing 17 changed files with 293 additions and 72 deletions.
1 change: 1 addition & 0 deletions pkg/gossip/client_test.go
@@ -66,6 +66,7 @@ func startGossipAtAddr(
     registry *metric.Registry,
 ) *Gossip {
     rpcContext := newInsecureRPCContext(stopper)
+    rpcContext.NodeID.Set(context.TODO(), nodeID)
     server := rpc.NewServer(rpcContext)
     g := NewTest(nodeID, rpcContext, server, stopper, registry)
     ln, err := netutil.ListenAndServeGRPC(stopper, server, addr)
3 changes: 1 addition & 2 deletions pkg/gossip/gossip_test.go
@@ -61,8 +61,7 @@ func TestGossipInfoStore(t *testing.T) {
 }

 // TestGossipMoveNode verifies that if a node is moved to a new address, it
-// gets properly updated in gossip (including that any other node that was
-// previously at that address gets removed from the cluster).
+// gets properly updated in gossip.
 func TestGossipMoveNode(t *testing.T) {
     defer leaktest.AfterTest(t)()
     stopper := stop.NewStopper()
74 changes: 57 additions & 17 deletions pkg/rpc/context.go
@@ -253,6 +253,7 @@ func NewServerWithInterceptor(
         clock:              ctx.LocalClock,
         remoteClockMonitor: ctx.RemoteClocks,
         clusterID:          &ctx.ClusterID,
+        nodeID:             &ctx.NodeID,
         version:            ctx.version,
     })
     return s
@@ -271,15 +272,17 @@ type Connection struct {
     heartbeatResult      atomic.Value  // result of latest heartbeat
     initialHeartbeatDone chan struct{} // closed after first heartbeat
     stopper              *stop.Stopper
+    remoteNodeID         roachpb.NodeID // 0 when unknown, non-zero to check with remote node.

     initOnce      sync.Once
     validatedOnce sync.Once
 }

-func newConnection(stopper *stop.Stopper) *Connection {
+func newConnectionToNodeID(stopper *stop.Stopper, remoteNodeID roachpb.NodeID) *Connection {
     c := &Connection{
         initialHeartbeatDone: make(chan struct{}),
         stopper:              stopper,
+        remoteNodeID:         remoteNodeID,
     }
     c.heartbeatResult.Store(heartbeatResult{err: ErrNotHeartbeated})
     return c
@@ -346,13 +349,22 @@ type Context struct {
     stats StatsHandler

     ClusterID base.ClusterIDContainer
+    NodeID    base.NodeIDContainer
     version   *cluster.ExposedClusterVersion

     // For unittesting.
     BreakerFactory  func() *circuit.Breaker
     testingDialOpts []grpc.DialOption
 }

+// connKey is used as key in the Context.conns map. Different remote
+// node IDs get different *Connection objects, to ensure that we don't
+// mis-route RPC requests.
+type connKey struct {
+    targetAddr string
+    nodeID     roachpb.NodeID
+}
+
 // NewContext creates an rpc Context with the supplied values.
 func NewContext(
     ambient log.AmbientContext,
@@ -396,7 +408,7 @@ func NewContext(
                 conn.dialErr = &roachpb.NodeUnavailableError{}
             }
         })
-        ctx.removeConn(k.(string), conn)
+        ctx.removeConn(k.(connKey), conn)
         return true
     })
 })
@@ -518,15 +530,15 @@ func (ctx *Context) SetLocalInternalServer(internalServer roachpb.InternalServer
     ctx.localInternalClient = internalClientAdapter{internalServer}
 }

-func (ctx *Context) removeConn(key string, conn *Connection) {
+func (ctx *Context) removeConn(key connKey, conn *Connection) {
     ctx.conns.Delete(key)
     if log.V(1) {
-        log.Infof(ctx.masterCtx, "closing %s", key)
+        log.Infof(ctx.masterCtx, "closing %+v", key)
     }
     if grpcConn := conn.grpcConn; grpcConn != nil {
         if err := grpcConn.Close(); err != nil && !grpcutil.IsClosedConnection(err) {
             if log.V(1) {
-                log.Errorf(ctx.masterCtx, "failed to close client connection: %s", err)
+                log.Errorf(ctx.masterCtx, "failed to close client connection: %v", err)
             }
         }
     }
@@ -650,10 +662,32 @@ func (ctx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, <-chan struct{
 }

 // GRPCDial calls grpc.Dial with options appropriate for the context.
+//
+// It does not require validation of the node ID between client and server:
+// if a connection existed already with some node ID requirement, that
+// requirement will remain; if no connection existed yet,
+// a new one is created without a node ID requirement.
 func (ctx *Context) GRPCDial(target string) *Connection {
-    value, ok := ctx.conns.Load(target)
+    return ctx.GRPCDialNode(target, 0)
+}
+
+// GRPCDialNode calls grpc.Dial with options appropriate for the context.
+//
+// The remoteNodeID, if non-zero, becomes a constraint on the expected
+// node ID of the remote node; this is checked during heartbeats. If
+// a connection already existed with a different node ID requirement
+// it is first dropped and a new one with the proper node ID
+// requirement is created.
+//
+// This is done to ensure we always use a recent/up-to-date node ID
+// requirement, for cases when e.g. a connection is initially
+// established to a target address before the remote node ID is known,
+// and we learn the proper required node ID afterwards.
+func (ctx *Context) GRPCDialNode(target string, remoteNodeID roachpb.NodeID) *Connection {
+    thisConnKey := connKey{target, remoteNodeID}
+    value, ok := ctx.conns.Load(thisConnKey)
     if !ok {
-        value, _ = ctx.conns.LoadOrStore(target, newConnection(ctx.Stopper))
+        value, _ = ctx.conns.LoadOrStore(thisConnKey, newConnectionToNodeID(ctx.Stopper, remoteNodeID))
     }

     conn := value.(*Connection)
@@ -668,11 +702,11 @@ func (ctx *Context) GRPCDial(target string) *Connection {
             if err != nil && !grpcutil.IsClosedConnection(err) {
                 log.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err)
             }
-            ctx.removeConn(target, conn)
+            ctx.removeConn(thisConnKey, conn)
         })
     }); err != nil {
         conn.dialErr = err
-        ctx.removeConn(target, conn)
+        ctx.removeConn(thisConnKey, conn)
     }
     }
 })
@@ -703,6 +737,9 @@ var ErrNotHeartbeated = errors.New("not yet heartbeated")
 // error will be returned. This method should therefore be used to
 // prioritize among a list of candidate nodes, but not to filter out
 // "unhealthy" nodes.
+//
+// This is used in tests only; in clusters use (*Dialer).ConnHealth()
+// instead which validates the node ID.
 func (ctx *Context) ConnHealth(target string) error {
     if ctx.GetLocalInternalClientForAddr(target) != nil {
         // The local server is always considered healthy.
@@ -716,14 +753,8 @@ func (ctx *Context) runHeartbeat(
     conn *Connection, target string, redialChan <-chan struct{},
 ) error {
     maxOffset := ctx.LocalClock.MaxOffset()
-    clusterID := ctx.ClusterID.Get()
+    maxOffsetNanos := maxOffset.Nanoseconds()

-    request := PingRequest{
-        Addr:           ctx.Addr,
-        MaxOffsetNanos: maxOffset.Nanoseconds(),
-        ClusterID:      &clusterID,
-        ServerVersion:  ctx.version.ServerVersion,
-    }
     heartbeatClient := NewHeartbeatClient(conn.grpcConn)

     var heartbeatTimer timeutil.Timer
@@ -748,9 +779,18 @@
             goCtx, cancel = context.WithTimeout(goCtx, hbTimeout)
         }
         sendTime := ctx.LocalClock.PhysicalTime()
+        // We re-mint the PingRequest to pick up any asynchronous update to clusterID.
+        clusterID := ctx.ClusterID.Get()
+        request := &PingRequest{
+            Addr:           ctx.Addr,
+            MaxOffsetNanos: maxOffsetNanos,
+            ClusterID:      &clusterID,
+            NodeID:         conn.remoteNodeID,
+            ServerVersion:  ctx.version.ServerVersion,
+        }
         // NB: We want the request to fail-fast (the default), otherwise we won't
         // be notified of transport failures.
-        response, err := heartbeatClient.Ping(goCtx, &request)
+        response, err := heartbeatClient.Ping(goCtx, request)
         if cancel != nil {
             cancel()
         }
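
A consequence of the new (address, node ID) keying is that the same
address can now hold separate validated and unvalidated connections.
The sketch below uses a self-contained mirror of the unexported
connKey type with illustrative values; it is not code from this patch:

    package main

    import "fmt"

    // connKey mirrors the unexported key type added above: the conns
    // map is keyed on (address, node ID) rather than address alone.
    type connKey struct {
        targetAddr string
        nodeID     int32
    }

    func main() {
        unvalidated := connKey{"10.0.0.2:26257", 0} // GRPCDial: no requirement
        validated := connKey{"10.0.0.2:26257", 2}   // GRPCDialNode: must be n2
        // Struct keys compare by value, so these map to two distinct
        // *Connection entries that never share heartbeat state.
        fmt.Println(unvalidated == validated) // false
    }
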
79 changes: 79 additions & 0 deletions pkg/rpc/context_test.go
@@ -93,6 +93,7 @@ func TestHeartbeatCB(t *testing.T) {
         clock:              clock,
         remoteClockMonitor: serverCtx.RemoteClocks,
         clusterID:          &serverCtx.ClusterID,
+        nodeID:             &serverCtx.NodeID,
         version:            serverCtx.version,
     })

@@ -341,6 +342,7 @@ func TestHeartbeatHealthTransport(t *testing.T) {
         clock:              clock,
         remoteClockMonitor: serverCtx.RemoteClocks,
         clusterID:          &serverCtx.ClusterID,
+        nodeID:             &serverCtx.NodeID,
         version:            serverCtx.version,
     })

@@ -515,6 +517,7 @@ func TestOffsetMeasurement(t *testing.T) {
         clock:              serverClock,
         remoteClockMonitor: serverCtx.RemoteClocks,
         clusterID:          &serverCtx.ClusterID,
+        nodeID:             &serverCtx.NodeID,
         version:            serverCtx.version,
     })

@@ -682,6 +685,7 @@ func TestRemoteOffsetUnhealthy(t *testing.T) {
             clock:              clock,
             remoteClockMonitor: nodeCtxs[i].ctx.RemoteClocks,
             clusterID:          &nodeCtxs[i].ctx.ClusterID,
+            nodeID:             &nodeCtxs[i].ctx.NodeID,
             version:            nodeCtxs[i].ctx.version,
         })
         ln, err := netutil.ListenAndServeGRPC(nodeCtxs[i].ctx.Stopper, s, util.TestAddr)
@@ -829,6 +833,7 @@ func TestGRPCKeepaliveFailureFailsInflightRPCs(t *testing.T) {
             clock:              clock,
             remoteClockMonitor: serverCtx.RemoteClocks,
             clusterID:          &serverCtx.ClusterID,
+            nodeID:             &serverCtx.NodeID,
             version:            serverCtx.version,
         },
         interval: msgInterval,
@@ -1011,6 +1016,50 @@ func TestClusterIDMismatch(t *testing.T) {
     wg.Wait()
 }

+func TestNodeIDMismatch(t *testing.T) {
+    defer leaktest.AfterTest(t)()
+
+    stopper := stop.NewStopper()
+    defer stopper.Stop(context.TODO())
+
+    clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
+    serverCtx := newTestContext(clock, stopper)
+    uuid1 := uuid.MakeV4()
+    serverCtx.ClusterID.Set(context.TODO(), uuid1)
+    serverCtx.NodeID.Set(context.TODO(), 1)
+    s := newTestServer(t, serverCtx)
+    RegisterHeartbeatServer(s, &HeartbeatService{
+        clock:              clock,
+        remoteClockMonitor: serverCtx.RemoteClocks,
+        clusterID:          &serverCtx.ClusterID,
+        nodeID:             &serverCtx.NodeID,
+        version:            serverCtx.version,
+    })
+
+    ln, err := netutil.ListenAndServeGRPC(serverCtx.Stopper, s, util.TestAddr)
+    if err != nil {
+        t.Fatal(err)
+    }
+    remoteAddr := ln.Addr().String()
+
+    clientCtx := newTestContext(clock, stopper)
+    clientCtx.ClusterID.Set(context.TODO(), uuid1)
+
+    var wg sync.WaitGroup
+    for i := 0; i < 10; i++ {
+        wg.Add(1)
+        go func() {
+            _, err := clientCtx.GRPCDialNode(remoteAddr, 2).Connect(context.Background())
+            expected := "initial connection heartbeat failed.*doesn't match server node ID"
+            if !testutils.IsError(err, expected) {
+                t.Errorf("expected %s error, got %v", expected, err)
+            }
+            wg.Done()
+        }()
+    }
+    wg.Wait()
+}
+
 func setVersion(c *Context, v roachpb.Version) error {
     settings := cluster.MakeClusterSettings(v, v)
     cv := cluster.ClusterVersion{Version: v}
@@ -1048,6 +1097,7 @@ func TestVersionCheckBidirectional(t *testing.T) {
     clock := hlc.NewClock(timeutil.Unix(0, 20).UnixNano, time.Nanosecond)
     serverCtx := newTestContext(clock, stopper)
     serverCtx.ClusterID.Set(context.TODO(), uuid.MakeV4())
+    serverCtx.NodeID.Set(context.TODO(), 1)
     if err := setVersion(serverCtx, td.serverVersion); err != nil {
         t.Fatal(err)
     }
@@ -1056,6 +1106,7 @@
         clock:              clock,
         remoteClockMonitor: serverCtx.RemoteClocks,
         clusterID:          &serverCtx.ClusterID,
+        nodeID:             &serverCtx.NodeID,
         version:            serverCtx.version,
     })

@@ -1110,3 +1161,31 @@ func BenchmarkGRPCDial(b *testing.B) {
         }
     })
 }
+
+func BenchmarkGRPCDialNode(b *testing.B) {
+    if testing.Short() {
+        b.Skip("TODO: fix benchmark")
+    }
+    stopper := stop.NewStopper()
+    defer stopper.Stop(context.TODO())
+
+    clock := hlc.NewClock(hlc.UnixNano, 250*time.Millisecond)
+    ctx := newTestContext(clock, stopper)
+    ctx.NodeID.Set(context.TODO(), 1)
+
+    s := newTestServer(b, ctx)
+    ln, err := netutil.ListenAndServeGRPC(ctx.Stopper, s, util.TestAddr)
+    if err != nil {
+        b.Fatal(err)
+    }
+    remoteAddr := ln.Addr().String()
+
+    b.RunParallel(func(pb *testing.PB) {
+        for pb.Next() {
+            _, err := ctx.GRPCDialNode(remoteAddr, 1).Connect(context.Background())
+            if err != nil {
+                b.Fatal(err)
+            }
+        }
+    })
+}
14 changes: 14 additions & 0 deletions pkg/rpc/heartbeat.go
@@ -23,6 +23,7 @@ import (
     "github.com/cockroachdb/cockroach/pkg/roachpb"
     "github.com/cockroachdb/cockroach/pkg/settings/cluster"
     "github.com/cockroachdb/cockroach/pkg/util/hlc"
+    "github.com/cockroachdb/cockroach/pkg/util/log"
     "github.com/cockroachdb/cockroach/pkg/util/timeutil"
     "github.com/cockroachdb/cockroach/pkg/util/uuid"
     "github.com/pkg/errors"
@@ -48,6 +49,7 @@ type HeartbeatService struct {
     // shared by rpc clients, to keep track of remote clock measurements.
     remoteClockMonitor *RemoteClockMonitor
     clusterID          *base.ClusterIDContainer
+    nodeID             *base.NodeIDContainer
     version            *cluster.ExposedClusterVersion
 }

@@ -79,13 +81,25 @@ func checkVersion(
 // The requester should also estimate its offset from this server along
 // with the requester's address.
 func (hs *HeartbeatService) Ping(ctx context.Context, args *PingRequest) (*PingResponse, error) {
+    if log.V(2) {
+        log.Infof(ctx, "received heartbeat: %+v vs local cluster %+v node %+v", args, hs.clusterID, hs.nodeID)
+    }
     // Check that cluster IDs match.
     clusterID := hs.clusterID.Get()
     if args.ClusterID != nil && *args.ClusterID != uuid.Nil && clusterID != uuid.Nil &&
         *args.ClusterID != clusterID {
         return nil, errors.Errorf(
             "client cluster ID %q doesn't match server cluster ID %q", args.ClusterID, clusterID)
     }
+    // Check that node IDs match.
+    var nodeID roachpb.NodeID
+    if hs.nodeID != nil {
+        nodeID = hs.nodeID.Get()
+    }
+    if args.NodeID != 0 && nodeID != 0 && args.NodeID != nodeID {
+        return nil, errors.Errorf(
+            "client requested node ID %d doesn't match server node ID %d", args.NodeID, nodeID)
+    }

     // Check version compatibility.
     if err := checkVersion(hs.version, args.ServerVersion); err != nil {
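
The validation above is deliberately lenient about zero values. Here
is a standalone restatement of the rule (a hypothetical helper, not
part of the patch):

    package main

    import "fmt"

    // checkNodeID restates the Ping check: a zero on either side means
    // "no requirement" (a GRPCDial caller, or a server that has not
    // learned its own node ID yet), so only two non-zero, unequal IDs
    // produce an error.
    func checkNodeID(requested, actual int32) error {
        if requested != 0 && actual != 0 && requested != actual {
            return fmt.Errorf(
                "client requested node ID %d doesn't match server node ID %d",
                requested, actual)
        }
        return nil
    }

    func main() {
        fmt.Println(checkNodeID(2, 1)) // mismatch: error
        fmt.Println(checkNodeID(0, 1)) // client has no requirement: <nil>
        fmt.Println(checkNodeID(2, 0)) // server ID not yet known: <nil>
    }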