diff --git a/pkg/sql/instrumentation.go b/pkg/sql/instrumentation.go index 14ddffdc608e..4f30c6047e71 100644 --- a/pkg/sql/instrumentation.go +++ b/pkg/sql/instrumentation.go @@ -363,9 +363,8 @@ func (ih *instrumentationHelper) Finish( p.SessionData(), ) phaseTimes := statsCollector.PhaseTimes() - if ih.stmtDiagnosticsRecorder.IsExecLatencyConditionMet( - ih.diagRequestID, ih.diagRequest, phaseTimes.GetServiceLatencyNoOverhead(), - ) { + execLatency := phaseTimes.GetServiceLatencyNoOverhead() + if ih.stmtDiagnosticsRecorder.IsConditionSatisfied(ih.diagRequest, execLatency) { placeholders := p.extendedEvalCtx.Placeholders ob := ih.emitExplainAnalyzePlanToOutputBuilder( explain.Flags{Verbose: true, ShowTypes: true}, @@ -377,9 +376,9 @@ func (ih *instrumentationHelper) Finish( ctx, cfg.DB, ie.(*InternalExecutor), &p.curPlan, ob.BuildString(), trace, placeholders, ) bundle.insert(ctx, ih.fingerprint, ast, cfg.StmtDiagnosticsRecorder, ih.diagRequestID, ih.diagRequest) - ih.stmtDiagnosticsRecorder.RemoveOngoing(ih.diagRequestID, ih.diagRequest) telemetry.Inc(sqltelemetry.StatementDiagnosticsCollectedCounter) } + ih.stmtDiagnosticsRecorder.MaybeRemoveRequest(ih.diagRequestID, ih.diagRequest, execLatency) } // If there was a communication error already, no point in setting any diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics.go b/pkg/sql/stmtdiagnostics/statement_diagnostics.go index 77e8ecb96eac..c8a07e73947d 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics.go @@ -128,6 +128,14 @@ func (r *Request) isConditional() bool { return r.minExecutionLatency != 0 } +// continueCollecting returns true if we want to continue collecting bundles for +// this request. +func (r *Request) continueCollecting(st *cluster.Settings) bool { + return collectUntilExpiration.Get(&st.SV) && // continuous collection must be enabled + r.samplingProbability != 0 && !r.expiresAt.IsZero() && // conditions for continuous collection must be set + !r.isExpired(timeutil.Now()) // the request must not have expired yet +} + // NewRegistry constructs a new Registry. func NewRegistry(ie sqlutil.InternalExecutor, db *kv.DB, st *cluster.Settings) *Registry { r := &Registry{ @@ -410,35 +418,28 @@ func (r *Registry) CancelRequest(ctx context.Context, requestID int64) error { return nil } -// IsExecLatencyConditionMet returns true if the completed request's execution -// latency satisfies the request's condition. If false is returned, it inlines -// the logic of RemoveOngoing. -func (r *Registry) IsExecLatencyConditionMet( - requestID RequestID, req Request, execLatency time.Duration, -) bool { - if req.minExecutionLatency <= execLatency { - return true - } - // This is a conditional request and the condition is not satisfied, so we - // only need to remove the request if it has expired. - if req.isExpired(timeutil.Now()) { - r.mu.Lock() - defer r.mu.Unlock() - delete(r.mu.requestFingerprints, requestID) - } - return false +// IsConditionSatisfied returns whether the completed request satisfies its +// condition. +func (r *Registry) IsConditionSatisfied(req Request, execLatency time.Duration) bool { + return req.minExecutionLatency <= execLatency } -// RemoveOngoing removes the given request from the list of ongoing queries. -func (r *Registry) RemoveOngoing(requestID RequestID, req Request) { - r.mu.Lock() - defer r.mu.Unlock() - if req.isConditional() { - if req.isExpired(timeutil.Now()) { +// MaybeRemoveRequest checks whether the request needs to be removed from the +// local Registry and removes it if so. Note that the registries on other nodes +// will learn about it via polling of the system table. +func (r *Registry) MaybeRemoveRequest(requestID RequestID, req Request, execLatency time.Duration) { + // We should remove the request from the registry if its condition is + // satisfied unless we want to continue collecting bundles for this request. + shouldRemove := r.IsConditionSatisfied(req, execLatency) && !req.continueCollecting(r.st) + // Always remove the expired requests. + if shouldRemove || req.isExpired(timeutil.Now()) { + r.mu.Lock() + defer r.mu.Unlock() + if req.isConditional() { delete(r.mu.requestFingerprints, requestID) + } else { + delete(r.mu.unconditionalOngoing, requestID) } - } else { - delete(r.mu.unconditionalOngoing, requestID) } } @@ -448,8 +449,7 @@ func (r *Registry) RemoveOngoing(requestID RequestID, req Request) { // case ShouldCollectDiagnostics will return true again on this node for the // same diagnostics request only for conditional requests. // -// If shouldCollect is true, RemoveOngoing needs to be called (which is inlined -// by IsExecLatencyConditionMet when that returns false). +// If shouldCollect is true, MaybeRemoveRequest needs to be called. func (r *Registry) ShouldCollectDiagnostics( ctx context.Context, fingerprint string, ) (shouldCollect bool, reqID RequestID, req Request) { diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go b/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go index d8ec19eba234..541b3414c94b 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics_helpers_test.go @@ -16,10 +16,10 @@ import ( ) // TestingFindRequest exports findRequest for testing purposes. -func (r *Registry) TestingFindRequest(requestID RequestID) bool { +func (r *Registry) TestingFindRequest(requestID int64) bool { r.mu.Lock() defer r.mu.Unlock() - return r.findRequestLocked(requestID) + return r.findRequestLocked(RequestID(requestID)) } // InsertRequestInternal exposes the form of insert which returns the request ID diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go index a9cbbd92d0f4..81d71d5d9a15 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go @@ -45,13 +45,24 @@ func TestDiagnosticsRequest(t *testing.T) { s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) ctx := context.Background() defer s.Stopper().Stop(ctx) + registry := s.ExecutorConfig().(sql.ExecutorConfig).StmtDiagnosticsRecorder _, err := db.Exec("CREATE TABLE test (x int PRIMARY KEY)") require.NoError(t, err) - completedQuery := "SELECT completed, statement_diagnostics_id FROM system.statement_diagnostics_requests WHERE ID = $1" + var collectUntilExpirationEnabled bool isCompleted := func(reqID int64) (completed bool, diagnosticsID gosql.NullInt64) { + completedQuery := "SELECT completed, statement_diagnostics_id FROM system.statement_diagnostics_requests WHERE ID = $1" reqRow := db.QueryRow(completedQuery, reqID) require.NoError(t, reqRow.Scan(&completed, &diagnosticsID)) + if completed && !collectUntilExpirationEnabled { + // Ensure that if the request was completed and the continuous + // collection is not enabled, the local registry no longer has the + // request. + require.False( + t, registry.TestingFindRequest(reqID), "request was "+ + "completed and should have been removed from the registry", + ) + } return completed, diagnosticsID } checkNotCompleted := func(reqID int64) { @@ -76,6 +87,7 @@ func TestDiagnosticsRequest(t *testing.T) { return nil } setCollectUntilExpiration := func(v bool) { + collectUntilExpirationEnabled = v _, err := db.Exec( fmt.Sprintf("SET CLUSTER SETTING sql.stmt_diagnostics.collect_continuously.enabled = %t", v)) require.NoError(t, err) @@ -86,7 +98,6 @@ func TestDiagnosticsRequest(t *testing.T) { require.NoError(t, err) } - registry := s.ExecutorConfig().(sql.ExecutorConfig).StmtDiagnosticsRecorder var minExecutionLatency, expiresAfter time.Duration var samplingProbability float64 @@ -445,7 +456,7 @@ func TestDiagnosticsRequest(t *testing.T) { // We should not find the request and a subsequent executions should not // capture anything. testutils.SucceedsSoon(t, func() error { - if found := registry.TestingFindRequest(stmtdiagnostics.RequestID(reqID)); found { + if found := registry.TestingFindRequest(reqID); found { return errors.New("expected expired request to no longer be tracked") } return nil