diff --git a/build/teamcity/cockroach/ci/builds/build_macos_arm64.sh b/build/teamcity/cockroach/ci/builds/build_macos_arm64.sh new file mode 100755 index 000000000000..3eee56e074b6 --- /dev/null +++ b/build/teamcity/cockroach/ci/builds/build_macos_arm64.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -euo pipefail + +dir="$(dirname $(dirname $(dirname $(dirname $(dirname "${0}")))))" + +source "$dir/teamcity-support.sh" # For $root +source "$dir/teamcity-bazel-support.sh" # For run_bazel + +tc_start_block "Run Bazel build" +run_bazel build/teamcity/cockroach/ci/builds/build_impl.sh crossmacosarm +tc_end_block "Run Bazel build" diff --git a/pkg/kv/kvserver/consistency_queue.go b/pkg/kv/kvserver/consistency_queue.go index 0a40db7aa21c..3f40424fee89 100644 --- a/pkg/kv/kvserver/consistency_queue.go +++ b/pkg/kv/kvserver/consistency_queue.go @@ -55,31 +55,10 @@ const consistencyCheckRateBurstFactor = 8 // churn on timers. const consistencyCheckRateMinWait = 100 * time.Millisecond -// consistencyCheckAsyncConcurrency is the maximum number of asynchronous -// consistency checks to run concurrently per store below Raft. The -// server.consistency_check.max_rate limit is shared among these, so running too -// many at the same time will cause them to time out. The rate is multiplied by -// 10 (permittedRangeScanSlowdown) to obtain the per-check timeout. 7 gives -// reasonable headroom, and also handles clusters with high replication factor -// and/or many nodes -- recall that each node runs a separate consistency queue -// which can schedule checks on other nodes, e.g. a 7-node cluster with a -// replication factor of 7 could run 7 concurrent checks on every node. -// -// Note that checksum calculations below Raft are not tied to the caller's -// context, and may continue to run even after the caller has given up on them, -// which may cause them to build up. Although we do best effort to cancel the -// running task on the receiving end when the incoming request is aborted. -// -// CHECK_STATS checks do not count towards this limit, as they are cheap and the -// DistSender will parallelize them across all ranges (notably when calling -// crdb_internal.check_consistency()). -const consistencyCheckAsyncConcurrency = 7 - -// consistencyCheckAsyncTimeout is a below-Raft timeout for asynchronous -// consistency check calculations. These are not tied to the caller's context, -// and thus may continue to run even if the caller has given up on them, so we -// give them an upper timeout to prevent them from running forever. -const consistencyCheckAsyncTimeout = time.Hour +// consistencyCheckSyncTimeout is the max amount of time the consistency check +// computation and the checksum collection request will wait for each other +// before giving up. +const consistencyCheckSyncTimeout = 5 * time.Second var testingAggressiveConsistencyChecks = envutil.EnvOrDefaultBool("COCKROACH_CONSISTENCY_AGGRESSIVE", false) diff --git a/pkg/kv/kvserver/replica_consistency.go b/pkg/kv/kvserver/replica_consistency.go index ebe4ec6e78f3..130caa6cf415 100644 --- a/pkg/kv/kvserver/replica_consistency.go +++ b/pkg/kv/kvserver/replica_consistency.go @@ -51,10 +51,12 @@ import ( // Up to 22.1, the consistency check initiator used to synchronously collect the // first replica's checksum before all others, so checksum collection requests // could arrive late if the first one was slow. Since 22.2, all requests are -// parallel and likely arrive quickly. +// parallel and likely arrive quickly. 
Thus, in 23.1 the checksum task waits a +// short amount of time until the collection request arrives, and otherwise +// doesn't start. // -// TODO(pavelkalinnikov): Consider removing GC behaviour in 23.1+, when all the -// incoming requests are from 22.2+ nodes (hence arrive timely). +// We still need the delayed GC in order to help a late arriving participant to +// learn that the other one gave up, and fail fast instead of waiting. const replicaChecksumGCInterval = time.Hour // fatalOnStatsMismatch, if true, turns stats mismatches into fatal errors. A @@ -422,7 +424,7 @@ func (r *Replica) getReplicaChecksum(id uuid.UUID, now time.Time) (*replicaCheck c := r.mu.checksums[id] if c == nil { c = &replicaChecksum{ - started: make(chan context.CancelFunc, 1), // allow an async send + started: make(chan context.CancelFunc), // require send/recv sync result: make(chan CollectChecksumResponse, 1), // allow an async send } r.mu.checksums[id] = c @@ -496,13 +498,13 @@ func (r *Replica) getChecksum(ctx context.Context, id uuid.UUID) (CollectChecksu } // checksumInitialWait returns the amount of time to wait until the checksum -// computation has started. It is set to min of 5s and 10% of the remaining time -// in the passed-in context (if it has a deadline). +// computation has started. It is set to min of consistencyCheckSyncTimeout and +// 10% of the remaining time in the passed-in context (if it has a deadline). // // If it takes longer, chances are that the replica is being restored from // snapshots, or otherwise too busy to handle this request soon. func (*Replica) checksumInitialWait(ctx context.Context) time.Duration { - wait := 5 * time.Second + wait := consistencyCheckSyncTimeout if d, ok := ctx.Deadline(); ok { if dur := time.Duration(timeutil.Until(d).Nanoseconds() / 10); dur < wait { wait = dur @@ -747,72 +749,79 @@ func (r *Replica) computeChecksumPostApply( } // Compute SHA asynchronously and store it in a map by UUID. Concurrent checks - // share the rate limit in r.store.consistencyLimiter, so we also limit the - // number of concurrent checks via r.store.consistencySem. + // share the rate limit in r.store.consistencyLimiter, so if too many run at + // the same time, chances are they will time out. // - // Don't use the proposal's context for this, as it likely to be canceled very - // soon. + // Each node's consistency queue runs a check for one range at a time, which + // it broadcasts to all replicas, so the average number of incoming in-flight + // collection requests per node is equal to the replication factor (typ. 3-7). + // Abandoned tasks are canceled eagerly within a few seconds, so there is very + // limited room for running above this figure. Thus we don't limit the number + // of concurrent tasks here. + // + // NB: CHECK_STATS checks are cheap and the DistSender will parallelize them + // across all ranges (notably when calling crdb_internal.check_consistency()). const taskName = "kvserver.Replica: computing checksum" - sem := r.store.consistencySem - if cc.Mode == roachpb.ChecksumMode_CHECK_STATS { - // Stats-only checks are cheap, and the DistSender parallelizes these across - // ranges (in particular when calling crdb_internal.check_consistency()), so - // they don't count towards the semaphore limit. - sem = nil - } stopper := r.store.Stopper() + // Don't use the proposal's context, as it is likely to be canceled very soon. 
taskCtx, taskCancel := stopper.WithCancelOnQuiesce(r.AnnotateCtx(context.Background())) if err := stopper.RunAsyncTaskEx(taskCtx, stop.TaskOpts{ - TaskName: taskName, - Sem: sem, - WaitForSem: false, + TaskName: taskName, }, func(ctx context.Context) { defer taskCancel() - // There is only one writer to c.started (this task), so this doesn't block. - // But if by mistake there is another writer, one of us closes the channel - // eventually, and other send/close ops will crash. This is by design. - c.started <- taskCancel - close(c.started) - - if err := contextutil.RunWithTimeout(ctx, taskName, consistencyCheckAsyncTimeout, + defer snap.Close() + defer r.gcReplicaChecksum(cc.ChecksumID, c) + // Wait until the CollectChecksum request handler joins in and learns about + // the starting computation, and then start it. + if err := contextutil.RunWithTimeout(ctx, taskName, consistencyCheckSyncTimeout, func(ctx context.Context) error { - defer snap.Close() - var snapshot *roachpb.RaftSnapshotData - if cc.SaveSnapshot { - snapshot = &roachpb.RaftSnapshotData{} + // There is only one writer to c.started (this task), but if by mistake + // there is another writer, one of us closes the channel eventually, and + // other writes to c.started will crash. This is by design. + defer close(c.started) + select { + case <-ctx.Done(): + return ctx.Err() + case c.started <- taskCancel: + return nil } - - result, err := r.sha512(ctx, desc, snap, snapshot, cc.Mode, r.store.consistencyLimiter) - if err != nil { - result = nil - } - r.computeChecksumDone(c, result, snapshot) - r.gcReplicaChecksum(cc.ChecksumID, c) - return err }, ); err != nil { - log.Errorf(ctx, "checksum computation failed: %v", err) + log.Errorf(ctx, "checksum collection did not join: %v", err) + } else { + var snapshot *roachpb.RaftSnapshotData + if cc.SaveSnapshot { + snapshot = &roachpb.RaftSnapshotData{} + } + result, err := r.sha512(ctx, desc, snap, snapshot, cc.Mode, r.store.consistencyLimiter) + if err != nil { + log.Errorf(ctx, "checksum computation failed: %v", err) + result = nil + } + r.computeChecksumDone(c, result, snapshot) } var shouldFatal bool for _, rDesc := range cc.Terminate { if rDesc.StoreID == r.store.StoreID() && rDesc.ReplicaID == r.replicaID { shouldFatal = true + break } } + if !shouldFatal { + return + } - if shouldFatal { - // This node should fatal as a result of a previous consistency - // check (i.e. this round is carried out only to obtain a diff). - // If we fatal too early, the diff won't make it back to the lease- - // holder and thus won't be printed to the logs. Since we're already - // in a goroutine that's about to end, simply sleep for a few seconds - // and then terminate. - auxDir := r.store.engine.GetAuxiliaryDir() - _ = r.store.engine.MkdirAll(auxDir) - path := base.PreventedStartupFile(auxDir) + // This node should fatal as a result of a previous consistency check (i.e. + // this round is carried out only to obtain a diff). If we fatal too early, + // the diff won't make it back to the leaseholder and thus won't be printed + // to the logs. Since we're already in a goroutine that's about to end, + // simply sleep for a few seconds and then terminate. + auxDir := r.store.engine.GetAuxiliaryDir() + _ = r.store.engine.MkdirAll(auxDir) + path := base.PreventedStartupFile(auxDir) - const attentionFmt = `ATTENTION: + const attentionFmt = `ATTENTION: this node is terminating because a replica inconsistency was detected between %s and its other replicas.
Please check your cluster-wide log files for more @@ -825,19 +834,17 @@ A checkpoints directory to aid (expert) debugging should be present in: A file preventing this node from restarting was placed at: %s ` - preventStartupMsg := fmt.Sprintf(attentionFmt, r, auxDir, path) - if err := fs.WriteFile(r.store.engine, path, []byte(preventStartupMsg)); err != nil { - log.Warningf(ctx, "%v", err) - } - - if p := r.store.cfg.TestingKnobs.ConsistencyTestingKnobs.OnBadChecksumFatal; p != nil { - p(*r.store.Ident) - } else { - time.Sleep(10 * time.Second) - log.Fatalf(r.AnnotateCtx(context.Background()), attentionFmt, r, auxDir, path) - } + preventStartupMsg := fmt.Sprintf(attentionFmt, r, auxDir, path) + if err := fs.WriteFile(r.store.engine, path, []byte(preventStartupMsg)); err != nil { + log.Warningf(ctx, "%v", err) } + if p := r.store.cfg.TestingKnobs.ConsistencyTestingKnobs.OnBadChecksumFatal; p != nil { + p(*r.store.Ident) + } else { + time.Sleep(10 * time.Second) + log.Fatalf(r.AnnotateCtx(context.Background()), attentionFmt, r, auxDir, path) + } }); err != nil { taskCancel() snap.Close() diff --git a/pkg/kv/kvserver/replica_consistency_test.go b/pkg/kv/kvserver/replica_consistency_test.go index 7c5ad7db42ed..98bb6be6fdf2 100644 --- a/pkg/kv/kvserver/replica_consistency_test.go +++ b/pkg/kv/kvserver/replica_consistency_test.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) func TestReplicaChecksumVersion(t *testing.T) { @@ -54,8 +55,10 @@ func TestReplicaChecksumVersion(t *testing.T) { } else { cc.Version = 1 } - taskErr := tc.repl.computeChecksumPostApply(ctx, cc) + var g errgroup.Group + g.Go(func() error { return tc.repl.computeChecksumPostApply(ctx, cc) }) rc, err := tc.repl.getChecksum(ctx, cc.ChecksumID) + taskErr := g.Wait() if !matchingVersion { require.ErrorContains(t, taskErr, "incompatible versions") require.ErrorContains(t, err, "checksum task failed to start") @@ -79,13 +82,12 @@ func TestGetChecksumNotSuccessfulExitConditions(t *testing.T) { defer stopper.Stop(ctx) tc.Start(ctx, t, stopper) - requireChecksumTaskNotStarted := func(id uuid.UUID) { - require.ErrorContains(t, - tc.repl.computeChecksumPostApply(context.Background(), kvserverpb.ComputeChecksum{ - ChecksumID: id, - Mode: roachpb.ChecksumMode_CHECK_FULL, - Version: batcheval.ReplicaChecksumVersion, - }), "checksum collection request gave up") + startChecksumTask := func(ctx context.Context, id uuid.UUID) error { + return tc.repl.computeChecksumPostApply(ctx, kvserverpb.ComputeChecksum{ + ChecksumID: id, + Mode: roachpb.ChecksumMode_CHECK_FULL, + Version: batcheval.ReplicaChecksumVersion, + }) } // Checksum computation failed to start. @@ -99,28 +101,38 @@ func TestGetChecksumNotSuccessfulExitConditions(t *testing.T) { // Checksum computation started, but failed. id = uuid.FastMakeV4() c, _ = tc.repl.getReplicaChecksum(id, timeutil.Now()) - c.started <- func() {} - close(c.started) - close(c.result) + var g errgroup.Group + g.Go(func() error { + c.started <- func() {} + close(c.started) + close(c.result) + return nil + }) rc, err = tc.repl.getChecksum(ctx, id) require.ErrorContains(t, err, "no checksum found") require.Nil(t, rc.Checksum) + require.NoError(t, g.Wait()) // The initial wait for the task start expires. This will take 10ms. 
id = uuid.FastMakeV4() rc, err = tc.repl.getChecksum(ctx, id) require.ErrorContains(t, err, "checksum computation did not start") require.Nil(t, rc.Checksum) - requireChecksumTaskNotStarted(id) + require.ErrorContains(t, startChecksumTask(context.Background(), id), + "checksum collection request gave up") // The computation has started, but the request context timed out. id = uuid.FastMakeV4() c, _ = tc.repl.getReplicaChecksum(id, timeutil.Now()) - c.started <- func() {} - close(c.started) + g.Go(func() error { + c.started <- func() {} + close(c.started) + return nil + }) rc, err = tc.repl.getChecksum(ctx, id) require.ErrorIs(t, err, context.DeadlineExceeded) require.Nil(t, rc.Checksum) + require.NoError(t, g.Wait()) // Context is canceled during the initial waiting. id = uuid.FastMakeV4() @@ -129,7 +141,20 @@ func TestGetChecksumNotSuccessfulExitConditions(t *testing.T) { rc, err = tc.repl.getChecksum(ctx, id) require.ErrorIs(t, err, context.Canceled) require.Nil(t, rc.Checksum) - requireChecksumTaskNotStarted(id) + require.ErrorContains(t, startChecksumTask(context.Background(), id), + "checksum collection request gave up") + + // The task failed to start because the checksum collection request did not + // join. Later, when it joins, it finds out that the task gave up. + id = uuid.FastMakeV4() + c, _ = tc.repl.getReplicaChecksum(id, timeutil.Now()) + require.NoError(t, startChecksumTask(context.Background(), id)) + // TODO(pavelkalinnikov): Avoid this long wait in the test. + time.Sleep(2 * consistencyCheckSyncTimeout) // give the task time to give up + _, ok := <-c.started + require.False(t, ok) // ensure the task gave up + rc, err = tc.repl.getChecksum(context.Background(), id) + require.ErrorContains(t, err, "checksum task failed to start") } // TestReplicaChecksumSHA512 checks that a given dataset produces the expected diff --git a/pkg/kv/kvserver/store.go b/pkg/kv/kvserver/store.go index 6ac5f567acb0..fceda5df42db 100644 --- a/pkg/kv/kvserver/store.go +++ b/pkg/kv/kvserver/store.go @@ -744,7 +744,6 @@ type Store struct { scanner *replicaScanner // Replica scanner consistencyQueue *consistencyQueue // Replica consistency check queue consistencyLimiter *quotapool.RateLimiter // Rate limits consistency checks - consistencySem *quotapool.IntPool // Limit concurrent consistency checks metrics *StoreMetrics intentResolver *intentresolver.IntentResolver recoveryMgr txnrecovery.Manager @@ -2090,9 +2089,6 @@ func (s *Store) Start(ctx context.Context, stopper *stop.Stopper) error { rate := consistencyCheckRate.Get(&s.ClusterSettings().SV) s.consistencyLimiter.UpdateLimit(quotapool.Limit(rate), rate*consistencyCheckRateBurstFactor) }) - s.consistencySem = quotapool.NewIntPool("concurrent async consistency checks", - consistencyCheckAsyncConcurrency) - s.stopper.AddCloser(s.consistencySem.Closer("stopper")) // Set the started flag (for unittests). 
atomic.StoreInt32(&s.started, 1) diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index 560777c008f0..3ca860a2d729 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -993,6 +993,7 @@ func newSQLServer(ctx context.Context, cfg sqlServerArgs) (*SQLServer, error) { *cfg.collectionFactory = *collectionFactory cfg.internalExecutorFactory = ieFactory execCfg.InternalExecutor = cfg.circularInternalExecutor + stmtDiagnosticsRegistry := stmtdiagnostics.NewRegistry( cfg.circularInternalExecutor, cfg.db, diff --git a/pkg/sql/explain_bundle.go b/pkg/sql/explain_bundle.go index 3b8ba0a37c09..b2447ffcdf81 100644 --- a/pkg/sql/explain_bundle.go +++ b/pkg/sql/explain_bundle.go @@ -159,11 +159,13 @@ func (bundle *diagnosticsBundle) insert( ast tree.Statement, stmtDiagRecorder *stmtdiagnostics.Registry, diagRequestID stmtdiagnostics.RequestID, + req stmtdiagnostics.Request, ) { var err error bundle.diagID, err = stmtDiagRecorder.InsertStatementDiagnostics( ctx, diagRequestID, + req, fingerprint, tree.AsString(ast), bundle.zip, diff --git a/pkg/sql/instrumentation.go b/pkg/sql/instrumentation.go index c1e8cebb0974..e4eabcdadf76 100644 --- a/pkg/sql/instrumentation.go +++ b/pkg/sql/instrumentation.go @@ -374,7 +374,7 @@ func (ih *instrumentationHelper) Finish( bundle = buildStatementBundle( ih.origCtx, cfg.DB, ie.(*InternalExecutor), &p.curPlan, ob.BuildString(), trace, placeholders, ) - bundle.insert(ctx, ih.fingerprint, ast, cfg.StmtDiagnosticsRecorder, ih.diagRequestID) + bundle.insert(ctx, ih.fingerprint, ast, cfg.StmtDiagnosticsRecorder, ih.diagRequestID, ih.diagRequest) ih.stmtDiagnosticsRecorder.RemoveOngoing(ih.diagRequestID, ih.diagRequest) telemetry.Inc(sqltelemetry.StatementDiagnosticsCollectedCounter) } diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics.go b/pkg/sql/stmtdiagnostics/statement_diagnostics.go index a6843cb15049..0bd4351f7dfa 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics.go @@ -53,6 +53,35 @@ var bundleChunkSize = settings.RegisterByteSizeSetting( }, ) +// collectUntilExpiration enables continuous collection of statement bundles for +// requests that declare a sampling probability and have an expiration +// timestamp. +// +// This setting should be used with some caution: enabling it would start +// accruing diagnostic bundles that meet a certain latency threshold until the +// request expires. It's worth noting that there's no automatic GC of bundles +// today (best you can do is `cockroach statement-diag delete --all`). This +// setting also captures multiple bundles for a single diagnostic request which +// does not fit well with our current scheme of one bundle per completed request. What +// it does internally is refuse to mark a "continuous" request as completed +// until it has expired, accumulating bundles until that point. The UI +// integration is incomplete -- only the most recently collected bundle is shown +// once the request is marked as completed. The full set can be retrieved using +// `cockroach statement-diag download `. This setting is primarily +// intended for low-overhead trace capture during tail latency investigations, +// experiments, and escalations under supervision. +// +// TODO(irfansharif): Longer term we should rip this out in favor of keeping a +// bounded set of bundles around per-request/fingerprint. See #82896 for more +// details.
+var collectUntilExpiration = settings.RegisterBoolSetting( + settings.TenantWritable, + "sql.stmt_diagnostics.collect_continuously.enabled", + "collect diagnostic bundles continuously until request expiration (to be "+ + "used with care, only has an effect if the diagnostic request has an "+ + "expiration and a sampling probability set)", + false) + // Registry maintains a view on the statement fingerprints // on which data is to be collected (i.e. system.statement_diagnostics_requests) // and provides utilities for checking a query against this list and satisfying @@ -255,15 +284,17 @@ func (r *Registry) insertRequestInternal( "sampling probability only supported after 22.2 version migrations have completed", ) } - if samplingProbability < 0 || samplingProbability > 1 { - return 0, errors.AssertionFailedf( - "malformed input: expected sampling probability in range [0.0, 1.0], got %f", - samplingProbability) - } - if samplingProbability != 0 && minExecutionLatency.Nanoseconds() == 0 { - return 0, errors.AssertionFailedf( - "malformed input: got non-zero sampling probability %f and empty min exec latency", - samplingProbability) + if samplingProbability != 0 { + if samplingProbability < 0 || samplingProbability > 1 { + return 0, errors.Newf( + "expected sampling probability in range [0.0, 1.0], got %f", + samplingProbability) + } + if minExecutionLatency == 0 { + return 0, errors.Newf( + "got non-zero sampling probability %f and empty min exec latency", + samplingProbability) + } } var reqID RequestID @@ -473,6 +504,7 @@ func (r *Registry) ShouldCollectDiagnostics( func (r *Registry) InsertStatementDiagnostics( ctx context.Context, requestID RequestID, + req Request, stmtFingerprint string, stmt string, bundle []byte, @@ -537,7 +569,7 @@ func (r *Registry) InsertStatementDiagnostics( collectionTime := timeutil.Now() - // Insert the trace into system.statement_diagnostics. + // Insert the collection metadata into system.statement_diagnostics. row, err := r.ie.QueryRowEx( ctx, "stmt-diag-insert", txn, sessiondata.InternalExecutorOverride{User: username.RootUserName()}, @@ -555,12 +587,28 @@ func (r *Registry) InsertStatementDiagnostics( diagID = CollectedInstanceID(*row[0].(*tree.DInt)) if requestID != 0 { - // Mark the request from system.statement_diagnostics_request as completed. + // Link the request from system.statement_diagnostics_request to the + // diagnostic ID we just collected, marking it as completed if we're + // able. + shouldMarkCompleted := true + if collectUntilExpiration.Get(&r.st.SV) { + // Two other conditions need to hold true for us to continue + // capturing future traces, i.e. not mark this request as + // completed. + // - Requests need to be of the sampling sort (also implies + // there's a latency threshold) -- a crude measure to protect + // against unbounded collection; + // - Requests need to have an expiration set -- same reason as + // above.
+ if req.samplingProbability > 0 && !req.expiresAt.IsZero() { + shouldMarkCompleted = false + } + } _, err := r.ie.ExecEx(ctx, "stmt-diag-mark-completed", txn, sessiondata.InternalExecutorOverride{User: username.RootUserName()}, "UPDATE system.statement_diagnostics_requests "+ - "SET completed = true, statement_diagnostics_id = $1 WHERE id = $2", - diagID, requestID) + "SET completed = $1, statement_diagnostics_id = $2 WHERE id = $3", + shouldMarkCompleted, diagID, requestID) if err != nil { return err } @@ -652,6 +700,11 @@ func (r *Registry) pollRequests(ctx context.Context) error { if isSamplingProbabilitySupported { if prob, ok := row[4].(*tree.DFloat); ok { samplingProbability = float64(*prob) + if samplingProbability < 0 || samplingProbability > 1 { + log.Warningf(ctx, "malformed sampling probability for request %d: %f (expected in range [0, 1]), resetting to 1.0", + id, samplingProbability) + samplingProbability = 1.0 + } } } ids.Add(int(id)) diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go b/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go index fe9b3cb758b2..d8ec19eba234 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go @@ -15,6 +15,13 @@ import ( "time" ) +// TestingFindRequest exports findRequest for testing purposes. +func (r *Registry) TestingFindRequest(requestID RequestID) bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.findRequestLocked(requestID) +} + // InsertRequestInternal exposes the form of insert which returns the request ID // as an int64 to tests in this package. func (r *Registry) InsertRequestInternal( diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go index f129c1a51411..3659f8694f64 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go @@ -39,8 +39,8 @@ import ( func TestDiagnosticsRequest(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - params := base.TestServerArgs{} - s, db, _ := serverutils.StartServer(t, params) + + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) ctx := context.Background() defer s.Stopper().Stop(ctx) _, err := db.Exec("CREATE TABLE test (x int PRIMARY KEY)") @@ -73,6 +73,16 @@ func TestDiagnosticsRequest(t *testing.T) { require.True(t, diagnosticsID.Valid == expectedCompleted) return nil } + setCollectUntilExpiration := func(v bool) { + _, err := db.Exec( + fmt.Sprintf("SET CLUSTER SETTING sql.stmt_diagnostics.collect_continuously.enabled = %t", v)) + require.NoError(t, err) + } + setPollInterval := func(d time.Duration) { + _, err := db.Exec( + fmt.Sprintf("SET CLUSTER SETTING sql.stmt_diagnostics.poll_interval = '%s'", d)) + require.NoError(t, err) + } registry := s.ExecutorConfig().(sql.ExecutorConfig).StmtDiagnosticsRecorder var minExecutionLatency, expiresAfter time.Duration @@ -300,8 +310,127 @@ func TestDiagnosticsRequest(t *testing.T) { if completed { return nil } - return errors.New("expected to capture stmt bundle") + return errors.New("expected to capture diagnostic bundle") + }) + }) + + t.Run("sampling without latency threshold disallowed", func(t *testing.T) { + samplingProbability, expiresAfter := 0.5, time.Second + _, err := registry.InsertRequestInternal(ctx, "SELECT pg_sleep(_)", + samplingProbability, 0 /* minExecutionLatency */, expiresAfter) + require.True(t, testutils.IsError(err, "empty min exec latency")) + }) + +
t.Run("continuous capture disabled without sampling probability", func(t *testing.T) { + // We validate that continuous captures is disabled when a sampling + // probability of 0.0 is used. We know that it's disabled given the + // diagnostic request is marked as completed despite us not getting to + // the expiration point +1h from now (we don't mark continuous captures + // as completed until they've expired). + samplingProbability, minExecutionLatency, expiresAfter := 0.0, time.Microsecond, time.Hour + reqID, err := registry.InsertRequestInternal(ctx, "SELECT pg_sleep(_)", + samplingProbability, minExecutionLatency, expiresAfter) + require.NoError(t, err) + checkNotCompleted(reqID) + + setCollectUntilExpiration(true) + defer setCollectUntilExpiration(false) + + testutils.SucceedsSoon(t, func() error { + _, err := db.Exec("SELECT pg_sleep(0.01)") // run the query + require.NoError(t, err) + completed, _ := isCompleted(reqID) + if completed { + return nil + } + return errors.New("expected request to have been completed") + }) + }) + + t.Run("continuous capture disabled without expiration timestamp", func(t *testing.T) { + // We don't mark continuous captures as completed until they've expired, + // so we require an explicit expiration set. See previous test case for + // some commentary. + samplingProbability, minExecutionLatency, expiresAfter := 0.999, time.Microsecond, 0*time.Hour + reqID, err := registry.InsertRequestInternal(ctx, "SELECT pg_sleep(_)", + samplingProbability, minExecutionLatency, expiresAfter) + require.NoError(t, err) + checkNotCompleted(reqID) + + setCollectUntilExpiration(true) + defer setCollectUntilExpiration(false) + + testutils.SucceedsSoon(t, func() error { + _, err := db.Exec("SELECT pg_sleep(0.01)") // run the query + require.NoError(t, err) + completed, _ := isCompleted(reqID) + if completed { + return nil + } + return errors.New("expected request to have been completed") + }) + }) + + t.Run("continuous capture", func(t *testing.T) { + samplingProbability, minExecutionLatency, expiresAfter := 0.9999, time.Microsecond, time.Hour + reqID, err := registry.InsertRequestInternal(ctx, "SELECT pg_sleep(_)", + samplingProbability, minExecutionLatency, expiresAfter) + require.NoError(t, err) + checkNotCompleted(reqID) + + setCollectUntilExpiration(true) + defer setCollectUntilExpiration(false) + + var firstDiagnosticID int64 + testutils.SucceedsSoon(t, func() error { + _, err := db.Exec("SELECT pg_sleep(0.01)") // run the query + require.NoError(t, err) + completed, diagnosticID := isCompleted(reqID) + if !diagnosticID.Valid { + return errors.New("expected to capture diagnostic bundle") + } + require.False(t, completed) // should not be marked as completed + if firstDiagnosticID == 0 { + firstDiagnosticID = diagnosticID.Int64 + } + if firstDiagnosticID == diagnosticID.Int64 { + return errors.New("waiting to capture second bundle") + } + return nil }) + + require.NoError(t, registry.CancelRequest(ctx, reqID)) + }) + + t.Run("continuous capture until expiration", func(t *testing.T) { + samplingProbability, minExecutionLatency, expiresAfter := 0.9999, time.Microsecond, 100*time.Millisecond + reqID, err := registry.InsertRequestInternal( + ctx, "SELECT pg_sleep(_)", samplingProbability, minExecutionLatency, expiresAfter, + ) + require.NoError(t, err) + checkNotCompleted(reqID) + + setCollectUntilExpiration(true) + defer setCollectUntilExpiration(false) + + // Sleep until expiration (and then some), and then run the query. 
+ time.Sleep(expiresAfter + 100*time.Millisecond) + + setPollInterval(10 * time.Millisecond) + defer setPollInterval(stmtdiagnostics.PollingInterval.Default()) + + // We should not find the request, and subsequent executions should not + // capture anything. + testutils.SucceedsSoon(t, func() error { + if found := registry.TestingFindRequest(stmtdiagnostics.RequestID(reqID)); found { + return errors.New("expected expired request to no longer be tracked") + } + return nil + }) + + _, err = db.Exec("SELECT pg_sleep(0.01)") // run the query + require.NoError(t, err) + checkNotCompleted(reqID) }) }
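
The consistency-check half of this diff hinges on a send/recv rendezvous: the below-Raft checksum task offers its cancel function on the now-unbuffered c.started channel, and the CollectChecksum handler must receive it within consistencyCheckSyncTimeout, otherwise both sides give up and the expensive computation never starts. The following is a minimal standalone sketch of that handshake; the names (task, collect, syncTimeout) and the use of time.After instead of contextutil.RunWithTimeout are illustrative only, not the actual kvserver API.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

const syncTimeout = 5 * time.Second // stands in for consistencyCheckSyncTimeout

// task offers its cancel function on the unbuffered started channel. If no
// collector joins within syncTimeout, the expensive computation never starts.
func task(ctx context.Context, started chan<- context.CancelFunc, cancel context.CancelFunc) error {
	defer close(started) // late collectors observe a closed channel and fail fast
	select {
	case started <- cancel:
		return nil // a collector joined; proceed with the checksum computation
	case <-time.After(syncTimeout):
		return errors.New("checksum collection did not join")
	case <-ctx.Done():
		return ctx.Err()
	}
}

// collect waits for the task to start and takes ownership of its cancel
// function, so the computation can be aborted if the request is abandoned.
func collect(ctx context.Context, started <-chan context.CancelFunc) (context.CancelFunc, error) {
	select {
	case cancel, ok := <-started:
		if !ok {
			return nil, errors.New("checksum task failed to start")
		}
		return cancel, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	started := make(chan context.CancelFunc) // unbuffered: require send/recv sync
	go func() {
		if err := task(ctx, started, cancel); err != nil {
			fmt.Println("task gave up:", err)
		}
	}()

	if taskCancel, err := collect(ctx, started); err != nil {
		fmt.Println("collector gave up:", err)
	} else {
		fmt.Println("joined; collector can cancel the task via taskCancel")
		_ = taskCancel
	}
}

Closing the channel after the timeout is what lets a late-arriving collector fail immediately with "checksum task failed to start" instead of hanging until the hour-long GC, which is the scenario exercised by the last case added to TestGetChecksumNotSuccessfulExitConditions above.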