diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index b18352040374..839a0058bb93 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -32,9 +32,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/grpcutil" @@ -93,6 +95,33 @@ func spanInclusionFunc( return parentSpanCtx != nil && !tracing.IsNoopContext(parentSpanCtx) } +func requireSuperUser(ctx context.Context) error { + // TODO(marc): grpc's authentication model (which gives credential access in + // the request handler) doesn't really fit with the current design of the + // security package (which assumes that TLS state is only given at connection + // time) - that should be fixed. + if grpcutil.IsLocalRequestContext(ctx) { + // This is a in-process request. Bypass authentication check. + } else if peer, ok := peer.FromContext(ctx); ok { + if tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo); ok { + certUser, err := security.GetCertificateUser(&tlsInfo.State) + if err != nil { + return err + } + // TODO(benesch): the vast majority of RPCs should be limited to just + // NodeUser. This is not a security concern, as RootUser has access to + // read and write all data, merely good hygiene. For example, there is + // no reason to permit the root user to send raw Raft RPCs. + if certUser != security.NodeUser && certUser != security.RootUser { + return errors.Errorf("user %s is not allowed to perform this RPC", certUser) + } + } + } else { + return errors.New("internal authentication error: TLSInfo is not available in request context") + } + return nil +} + // NewServer is a thin wrapper around grpc.NewServer that registers a heartbeat // service. func NewServer(ctx *Context) *grpc.Server { @@ -186,6 +215,33 @@ func NewServerWithInterceptor( } } + if !ctx.Insecure { + prevUnaryInterceptor := unaryInterceptor + unaryInterceptor = func( + ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, + ) (interface{}, error) { + if err := requireSuperUser(ctx); err != nil { + return nil, err + } + if prevUnaryInterceptor != nil { + return prevUnaryInterceptor(ctx, req, info, handler) + } + return handler(ctx, req) + } + prevStreamInterceptor := streamInterceptor + streamInterceptor = func( + srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, + ) error { + if err := requireSuperUser(stream.Context()); err != nil { + return err + } + if prevStreamInterceptor != nil { + return prevStreamInterceptor(srv, stream, info, handler) + } + return handler(srv, stream) + } + } + if unaryInterceptor != nil { opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor)) } diff --git a/pkg/server/authentication_test.go b/pkg/server/authentication_test.go index 7cfdb1ff7bb1..70769048137a 100644 --- a/pkg/server/authentication_test.go +++ b/pkg/server/authentication_test.go @@ -28,13 +28,22 @@ import ( "testing" "time" + "github.com/gogo/protobuf/jsonpb" + "github.com/lib/pq" + "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/gossip" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/server/debug" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/sql/distsqlrun" + "github.com/cockroachdb/cockroach/pkg/storage" + "github.com/cockroachdb/cockroach/pkg/storage/closedts/ctpb" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/ts" @@ -42,9 +51,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/httputil" "github.com/cockroachdb/cockroach/pkg/util/leaktest" - "github.com/gogo/protobuf/jsonpb" - "github.com/lib/pq" - "github.com/pkg/errors" ) type ctxI interface { @@ -647,3 +653,117 @@ func TestAuthenticationMux(t *testing.T) { }) } } + +func TestGRPCAuthentication(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + // For each subsystem we pick a representative RPC. The idea is not to + // exhaustively test each RPC but to prevent server startup from being + // refactored in such a way that an entire subsystem becomes inadvertently + // exempt from authentication checks. + subsystems := []struct { + name string + sendRPC func(context.Context, *grpc.ClientConn) error + }{ + {"gossip", func(ctx context.Context, conn *grpc.ClientConn) error { + stream, err := gossip.NewGossipClient(conn).Gossip(ctx) + if err != nil { + return err + } + _ = stream.Send(&gossip.Request{}) + _, err = stream.Recv() + return err + }}, + {"internal", func(ctx context.Context, conn *grpc.ClientConn) error { + _, err := roachpb.NewInternalClient(conn).Batch(ctx, &roachpb.BatchRequest{}) + return err + }}, + {"perReplica", func(ctx context.Context, conn *grpc.ClientConn) error { + _, err := storage.NewPerReplicaClient(conn).CollectChecksum(ctx, &storage.CollectChecksumRequest{}) + return err + }}, + {"raft", func(ctx context.Context, conn *grpc.ClientConn) error { + stream, err := storage.NewMultiRaftClient(conn).RaftMessageBatch(ctx) + if err != nil { + return err + } + _ = stream.Send(&storage.RaftMessageRequestBatch{}) + _, err = stream.Recv() + return err + }}, + {"closedTimestamp", func(ctx context.Context, conn *grpc.ClientConn) error { + stream, err := ctpb.NewClosedTimestampClient(conn).Get(ctx) + if err != nil { + return err + } + _ = stream.Send(&ctpb.Reaction{}) + _, err = stream.Recv() + return err + }}, + {"distSQL", func(ctx context.Context, conn *grpc.ClientConn) error { + stream, err := distsqlrun.NewDistSQLClient(conn).RunSyncFlow(ctx) + if err != nil { + return err + } + _ = stream.Send(&distsqlrun.ConsumerSignal{}) + _, err = stream.Recv() + return err + }}, + {"init", func(ctx context.Context, conn *grpc.ClientConn) error { + _, err := serverpb.NewInitClient(conn).Bootstrap(ctx, &serverpb.BootstrapRequest{}) + return err + }}, + {"admin", func(ctx context.Context, conn *grpc.ClientConn) error { + _, err := serverpb.NewAdminClient(conn).Databases(ctx, &serverpb.DatabasesRequest{}) + return err + }}, + {"status", func(ctx context.Context, conn *grpc.ClientConn) error { + _, err := serverpb.NewStatusClient(conn).ListSessions(ctx, &serverpb.ListSessionsRequest{}) + return err + }}, + } + + conn, err := grpc.DialContext(ctx, s.Addr(), + grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ + InsecureSkipVerify: true, + }))) + if err != nil { + t.Fatal(err) + } + defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn) + for _, subsystem := range subsystems { + t.Run(fmt.Sprintf("no-cert/%s", subsystem.name), func(t *testing.T) { + err := subsystem.sendRPC(ctx, conn) + if exp := "no client certificates in request"; !testutils.IsError(err, exp) { + t.Errorf("expected %q error, but got %v", exp, err) + } + }) + } + + certManager, err := s.RPCContext().GetCertificateManager() + if err != nil { + t.Fatal(err) + } + tlsConfig, err := certManager.GetClientTLSConfig("testuser") + if err != nil { + t.Fatal(err) + } + conn, err = grpc.DialContext(ctx, s.Addr(), + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + if err != nil { + t.Fatal(err) + } + defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn) + for _, subsystem := range subsystems { + t.Run(fmt.Sprintf("bad-user/%s", subsystem.name), func(t *testing.T) { + err := subsystem.sendRPC(ctx, conn) + if exp := "user testuser is not allowed to perform this RPC"; !testutils.IsError(err, exp) { + t.Errorf("expected %q error, but got %v", exp, err) + } + }) + } +} diff --git a/pkg/server/node.go b/pkg/server/node.go index 4fdca31b4f40..885f7cc22766 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -21,9 +21,6 @@ import ( "net" "time" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" - opentracing "github.com/opentracing/opentracing-go" "github.com/pkg/errors" @@ -34,7 +31,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/roachpb" - "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/server/status" "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -972,31 +968,12 @@ func (n *Node) batchInternal( return &br, nil } - isLocalRequest := grpcutil.IsLocalRequestContext(ctx) - // TODO(marc): grpc's authentication model (which gives credential access in - // the request handler) doesn't really fit with the current design of the - // security package (which assumes that TLS state is only given at connection - // time) - that should be fixed. - if isLocalRequest { - // this is a in-process request, bypass checks. - } else if peer, ok := peer.FromContext(ctx); ok { - if tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo); ok { - certUser, err := security.GetCertificateUser(&tlsInfo.State) - if err != nil { - return nil, err - } - if certUser != security.NodeUser { - return nil, errors.Errorf("user %s is not allowed", certUser) - } - } - } - var br *roachpb.BatchResponse if err := n.stopper.RunTaskWithErr(ctx, "node.Node: batch", func(ctx context.Context) error { var finishSpan func(*roachpb.BatchResponse) // Shadow ctx from the outer function. Written like this to pass the linter. - ctx, finishSpan = n.setupSpanForIncomingRPC(ctx, isLocalRequest) + ctx, finishSpan = n.setupSpanForIncomingRPC(ctx, grpcutil.IsLocalRequestContext(ctx)) defer func(br **roachpb.BatchResponse) { finishSpan(*br) }(&br)