Skip to content

Commit

Permalink
sql/pgwire: implement pgwire query cancellation
Browse files Browse the repository at this point in the history
Release note (sql change): Added support for query cancellation via the
pgwire protocol. CockroachDB will now respond to a pgwire cancellation
request by forwarding it to the node that is running the targeted
query. Since all cancellation requests are unauthenticated, there is a
fixed per-node rate limit on the number of cancellation attempts.
See https://www.postgresql.org/docs/13/protocol-flow.html#id-1.10.5.7.9

Release note (ops change): Added warning logs for unsuccessful (and
possibly malicious) attempts to cancel a query using the pgwire
cancellation protocol.
  • Loading branch information
rafiss committed Jul 13, 2021
1 parent 3b672b3 commit b227227
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 18 deletions.
23 changes: 21 additions & 2 deletions pkg/server/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -2189,11 +2189,30 @@ func (s *statusServer) CancelQuery(
}

// PGWireCancelQuery responds to a pgwire query cancellation request, and
// cancels the target query's associated context and sets a cancellation flag.
// cancels the target query's associated context and sets a cancellation flag.
func (s *statusServer) PGWireCancelQuery(
ctx context.Context, req *serverpb.PGWireCancelQueryRequest,
) (*serverpb.PGWireCancelQueryResponse, error) {
var output = &serverpb.PGWireCancelQueryResponse{}
nodeID, local, err := s.parseNodeID(req.NodeId)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}
if !local {
// This request needs to be forwarded to another node.
ctx = propagateGatewayMetadata(ctx)
ctx = s.AnnotateCtx(ctx)
statusServer, err := s.dialNode(ctx, nodeID)
if err != nil {
return nil, err
}
return statusServer.PGWireCancelQuery(ctx, req)
}

output := &serverpb.PGWireCancelQueryResponse{}
output.Canceled, err = s.sessionRegistry.CancelQueryByPGWire(ctx, req.SecretID)
if err != nil {
output.Error = err.Error()
}
return output, nil
}

Expand Down
9 changes: 8 additions & 1 deletion pkg/server/tenant_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ func (t *tenantStatusServer) CancelQuery(
func (t *tenantStatusServer) PGWireCancelQuery(
ctx context.Context, request *serverpb.PGWireCancelQueryRequest,
) (*serverpb.PGWireCancelQueryResponse, error) {
var output = &serverpb.PGWireCancelQueryResponse{}
var (
output = &serverpb.PGWireCancelQueryResponse{}
err error
)
output.Canceled, err = t.sessionRegistry.CancelQueryByPGWire(ctx, request.SecretID)
if err != nil {
output.Error = err.Error()
}
return output, nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ go_library(
"//pkg/util/metric",
"//pkg/util/mon",
"//pkg/util/protoutil",
"//pkg/util/quotapool",
"//pkg/util/retry",
"//pkg/util/ring",
"//pkg/util/sequence",
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ func (h ConnectionHandler) GetParamStatus(ctx context.Context, varName string) s
// GetPGWireCancelInfo returns the node ID and pgwire secret ID used for
// query cancellation.
func (h ConnectionHandler) GetPGWireCancelInfo() (int32, int32) {
return h.ex.sessionID.GetNodeID(), int32(h.ex.pgwireSecretID)
return int32(h.ex.server.cfg.NodeID.SQLInstanceID()), int32(h.ex.pgwireSecretID)
}

// ServeConn serves a client connection by reading commands from the stmtBuf
Expand Down
39 changes: 36 additions & 3 deletions pkg/sql/exec_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/metric"
"github.com/cockroachdb/cockroach/pkg/util/mon"
"github.com/cockroachdb/cockroach/pkg/util/quotapool"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/cockroach/pkg/util/tracing/collector"
Expand Down Expand Up @@ -515,6 +516,19 @@ var errTransactionInProgress = errors.New("there is already a transaction in pro
const sqlTxnName string = "sql txn"
const metricsSampleInterval = 10 * time.Second

// pgwireCancelSem is a semaphore that limits the number of concurrent
// calls to the pgwire query cancellation endpoint. This is needed to avoid the
// risk of a DoS attack by malicious users who attempt to cancel random
// queries by spamming the endpoint with requests.
//
// We hard-code a limit of 256 concurrent pgwire cancel requests (per node).
// We also add a 1-second penalty for failed cancellation requests, during
// which the failed request keeps holding its slot in the semaphore. An
// attacker is therefore limited to at most 256 guesses per second on a node.
// With an attacker randomly guessing the 32-bit secret, it would take up to
// 2^32 / 2^8 = 2^24 seconds to hit one particular query. If we suppose there
// are 256 concurrent queries actively running on a node, then it takes about
// 2^16 seconds (18 hours) to hit any one of them.
var pgwireCancelSem = quotapool.NewIntPool("pgwire-cancel", 256)

// Fully-qualified names for metrics.
var (
MetaSQLExecLatency = metric.Metadata{
Expand Down Expand Up @@ -1567,16 +1581,35 @@ func (r *SessionRegistry) CancelQuery(queryIDStr string) (bool, error) {
}

// CancelQueryByPGWire looks up the associated query in the session registry and
// cancels it.
func (r *SessionRegistry) CancelQueryByPGWire(secretID uint32) (bool, error) {
// cancels it. It is rate-limited with a semaphore as per the documentation
// in sql/pgwire.
func (r *SessionRegistry) CancelQueryByPGWire(
ctx context.Context, secretID uint32,
) (canceled bool, err error) {
var alloc *quotapool.IntAlloc
alloc, err = pgwireCancelSem.TryAcquire(ctx, 1)
if err != nil {
return false, fmt.Errorf("exceeded rate limit of pgwire cancellation requests")
}
defer func() {
// If we acquired the semaphore but the cancellation request failed, then
// hold on to the semaphore for longer. This helps mitigate a DoS attack
// of random cancellation requests.
if !canceled {
time.Sleep(1 * time.Second)
}
alloc.Release()
}()

r.Lock()
defer r.Unlock()
if session, ok := r.sessionsByPGWireSecret[secretID]; ok {
if session.cancelCurrentQueries() {
return true, nil
}
return false, nil
}
return false, fmt.Errorf("query for secret ID %d not found", secretID)
return false, fmt.Errorf("session for secret ID %d not found", secretID)
}

// CancelSession looks up the specified session in the session registry and
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/pgwire/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ go_library(
"//pkg/col/coldata",
"//pkg/col/coldataext",
"//pkg/security",
"//pkg/server/serverpb",
"//pkg/server/telemetry",
"//pkg/settings",
"//pkg/settings/cluster",
Expand Down
9 changes: 5 additions & 4 deletions pkg/sql/pgwire/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1550,8 +1550,7 @@ func TestSetSessionArguments(t *testing.T) {
}

// TestCancelQuery uses the pgwire-level query cancellation protocol provided
// by lib/pq to make sure that canceling a query has no effect, and makes sure
// the dummy BackendKeyData does not cause problems.
// by lib/pq to make sure that canceling a query works correctly.
func TestCancelQuery(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
Expand Down Expand Up @@ -1580,8 +1579,10 @@ func TestCancelQuery(t *testing.T) {
require.NoError(t, err)
defer db.Close()

// Cancellation has no effect on ongoing query.
if _, err := db.QueryContext(cancelCtx, "select pg_sleep(0)"); err != nil {
// Cancellation should stop the query.
if _, err := db.QueryContext(cancelCtx, "select pg_sleep(30)"); err == nil {
t.Fatal("expected error")
} else if err.Error() != "pq: query execution canceled" {
t.Fatalf("unexpected error: %s", err)
}

Expand Down
64 changes: 58 additions & 6 deletions pkg/sql/pgwire/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/security"
"github.com/cockroachdb/cockroach/pkg/server/serverpb"
"github.com/cockroachdb/cockroach/pkg/server/telemetry"
"github.com/cockroachdb/cockroach/pkg/settings"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
Expand Down Expand Up @@ -597,7 +598,8 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket
case versionCancel:
// The cancel message is rather peculiar: it is sent without
// authentication, always over an unencrypted channel.
return handleCancel(conn)
s.handleCancel(ctx, conn, &buf)
return nil

case versionGSSENC:
// This is a request for an unsupported feature: GSS encryption.
Expand Down Expand Up @@ -647,7 +649,8 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket
// Yet, we've found clients in the wild that send the cancel
// after the TLS handshake, for example at
// https://github.com/cockroachlabs/support/issues/600.
return handleCancel(conn)
s.handleCancel(ctx, conn, &buf)
return nil

default:
// We don't know this protocol.
Expand Down Expand Up @@ -703,12 +706,61 @@ func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType Socket
return nil
}

func handleCancel(conn net.Conn) error {
// Since we don't support this, close the door in the client's
// face. Make a note of that use in telemetry.
// handleCancel handles a pgwire query cancellation request. Note that the
// request is unauthenticated. To mitigate the security risk (i.e., a
// malicious actor spamming this endpoint with random data to try to cancel
// a query), we rely on this part of the specification:
//
// The upshot of all this is that for reasons of both security and efficiency,
// the frontend has no direct way to tell whether a cancel request has
// succeeded. It must continue to wait for the backend to respond to the
// query. Issuing a cancel simply improves the odds that the current query
// will finish soon, and improves the odds that it will fail with an error
// message instead of succeeding.
//
// See https://www.postgresql.org/docs/13/protocol-flow.html#id-1.10.5.7.9
//
// Since the protocol is best-effort, we can rate limit the requests and
// ignore any requests that exceed the rate limit. The rate limit is implemented
// using a semaphore in sql.SessionRegistry. The most bullet-proof rate limit
// would be cluster-wide, but in practice a per-node rate limit is fine.
//
// Also, this function does not return an error, so the caller (and possible
// attacker) will not know if the cancellation attempt succeeded. Errors are
// logged so that an operator can be aware of any possibly malicious requests.
func (s *Server) handleCancel(ctx context.Context, conn net.Conn, buf *pgwirebase.ReadBuffer) {
var err error
defer func() {
if err != nil {
log.Ops.Warningf(ctx, "unexpected while handling pgwire cancellation request: %v", err)
}
}()

var nodeID, secretID uint32
nodeID, err = buf.GetUint32()
if err != nil {
return
}
secretID, err = buf.GetUint32()
if err != nil {
return
}
// The request is forwarded to the appropriate node. We implement the rate
// limit in the node to which the request is forwarded. This way, if an
// attacker spams all the nodes in the cluster with requests that all go to
// the same node, the per-node rate limit will prevent them from having
// too many guesses.
req := &serverpb.PGWireCancelQueryRequest{
NodeId: fmt.Sprintf("%d", nodeID),
SecretID: secretID,
}
var resp *serverpb.PGWireCancelQueryResponse
resp, err = s.execCfg.SQLStatusServer.PGWireCancelQuery(ctx, req)
if err == nil && len(resp.Error) > 0 {
err = fmt.Errorf(resp.Error)
}
telemetry.Inc(sqltelemetry.CancelRequestCounter)
_ = conn.Close()
return nil
}

// parseClientProvidedSessionParameters reads the incoming k/v pairs
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sqltelemetry/pgwire.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

// CancelRequestCounter is to be incremented every time a pgwire-level
// cancel request is received from a client.
var CancelRequestCounter = telemetry.GetCounterOnce("pgwire.unimplemented.cancel_request")
var CancelRequestCounter = telemetry.GetCounterOnce("pgwire.cancel_request")

// UnimplementedClientStatusParameterCounter is to be incremented
// every time a client attempts to configure a status parameter
Expand Down

0 comments on commit b227227

Please sign in to comment.