diff --git a/pkg/sql/colflow/vectorized_flow.go b/pkg/sql/colflow/vectorized_flow.go
index 82c57e4777ea..fec35e23c785 100644
--- a/pkg/sql/colflow/vectorized_flow.go
+++ b/pkg/sql/colflow/vectorized_flow.go
@@ -369,6 +369,10 @@ func (f *vectorizedFlow) MemUsage() int64 {
 
 // Cleanup is part of the flowinfra.Flow interface.
 func (f *vectorizedFlow) Cleanup(ctx context.Context) {
+	startCleanup, endCleanup := f.FlowBase.GetOnCleanupFns()
+	startCleanup()
+	defer endCleanup()
+
 	// This cleans up all the memory and disk monitoring of the vectorized flow
 	// as well as closes all the closers.
 	f.creator.cleanup(ctx)
diff --git a/pkg/sql/distsql/server.go b/pkg/sql/distsql/server.go
index ac7d23891318..a13dc5329452 100644
--- a/pkg/sql/distsql/server.go
+++ b/pkg/sql/distsql/server.go
@@ -216,7 +216,7 @@ func (ds *ServerImpl) setupFlow(
 ) (retCtx context.Context, _ flowinfra.Flow, _ execopnode.OpChains, retErr error) {
 	var sp *tracing.Span          // will be Finish()ed by Flow.Cleanup()
 	var monitor *mon.BytesMonitor // will be closed in Flow.Cleanup()
-	var onFlowCleanup func()
+	var onFlowCleanupEnd func() // will be called at the very end of Flow.Cleanup()
 	// Make sure that we clean up all resources (which in the happy case are
 	// cleaned up in Flow.Cleanup()) if an error is encountered.
 	defer func() {
@@ -224,8 +224,8 @@ func (ds *ServerImpl) setupFlow(
 			if monitor != nil {
 				monitor.Stop(ctx)
 			}
-			if onFlowCleanup != nil {
-				onFlowCleanup()
+			if onFlowCleanupEnd != nil {
+				onFlowCleanupEnd()
 			} else {
 				reserved.Close(ctx)
 			}
@@ -307,7 +307,7 @@ func (ds *ServerImpl) setupFlow(
 			// the whole evalContext, but that isn't free, so we choose to restore
 			// the original state in order to avoid performance regressions.
 			origTxn := evalCtx.Txn
-			onFlowCleanup = func() {
+			onFlowCleanupEnd = func() {
 				evalCtx.Txn = origTxn
 				reserved.Close(ctx)
 			}
@@ -322,7 +322,7 @@
 				evalCtx.Txn = leafTxn
 			}
 		}
 	} else {
-		onFlowCleanup = func() {
+		onFlowCleanupEnd = func() {
 			reserved.Close(ctx)
 		}
 		if localState.IsLocal {
@@ -388,7 +388,7 @@ func (ds *ServerImpl) setupFlow(
 	isVectorized := req.EvalContext.SessionData.VectorizeMode != sessiondatapb.VectorizeOff
 	f := newFlow(
 		flowCtx, sp, ds.flowRegistry, rowSyncFlowConsumer, batchSyncFlowConsumer,
-		localState.LocalProcs, isVectorized, onFlowCleanup, req.StatementSQL,
+		localState.LocalProcs, isVectorized, onFlowCleanupEnd, req.StatementSQL,
 	)
 	opt := flowinfra.FuseNormally
 	if !localState.MustUseLeafTxn() {
@@ -521,10 +521,10 @@ func newFlow(
 	batchSyncFlowConsumer execinfra.BatchReceiver,
 	localProcessors []execinfra.LocalProcessor,
 	isVectorized bool,
-	onFlowCleanup func(),
+	onFlowCleanupEnd func(),
 	statementSQL string,
 ) flowinfra.Flow {
-	base := flowinfra.NewFlowBase(flowCtx, sp, flowReg, rowSyncFlowConsumer, batchSyncFlowConsumer, localProcessors, onFlowCleanup, statementSQL)
+	base := flowinfra.NewFlowBase(flowCtx, sp, flowReg, rowSyncFlowConsumer, batchSyncFlowConsumer, localProcessors, onFlowCleanupEnd, statementSQL)
 	if isVectorized {
 		return colflow.NewVectorizedFlow(base)
 	}
diff --git a/pkg/sql/distsql_running.go b/pkg/sql/distsql_running.go
index bfb1a582594b..520300805010 100644
--- a/pkg/sql/distsql_running.go
+++ b/pkg/sql/distsql_running.go
@@ -43,7 +43,6 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
 	"github.com/cockroachdb/cockroach/pkg/sql/types"
-	"github.com/cockroachdb/cockroach/pkg/util/buildutil"
"github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -100,23 +99,25 @@ type runnerResult struct { // run executes the request. An error, if encountered, is both sent on the // result channel and returned. func (req runnerRequest) run() error { - defer physicalplan.ReleaseFlowSpec(&req.flowReq.Flow) res := runnerResult{nodeID: req.sqlInstanceID} + defer func() { + req.resultChan <- res + physicalplan.ReleaseFlowSpec(&req.flowReq.Flow) + }() conn, err := req.podNodeDialer.Dial(req.ctx, roachpb.NodeID(req.sqlInstanceID), rpc.DefaultClass) + if err != nil { + res.err = err + return err + } + client := execinfrapb.NewDistSQLClient(conn) + // TODO(radu): do we want a timeout here? + resp, err := client.SetupFlow(req.ctx, req.flowReq) if err != nil { res.err = err } else { - client := execinfrapb.NewDistSQLClient(conn) - // TODO(radu): do we want a timeout here? - resp, err := client.SetupFlow(req.ctx, req.flowReq) - if err != nil { - res.err = err - } else { - res.err = resp.Error.ErrorDetail(req.ctx) - } + res.err = resp.Error.ErrorDetail(req.ctx) } - req.resultChan <- res return res.err } @@ -464,12 +465,13 @@ func (dsp *DistSQLPlanner) setupFlows( // Start all the remote flows. // - // numAsyncRequests tracks the number of the SetupFlow RPCs that were - // delegated to the DistSQL runner goroutines. - var numAsyncRequests int - // numSerialRequests tracks the number of the SetupFlow RPCs that were - // issued by the current goroutine on its own. - var numSerialRequests int + // usedWorker indicates whether we used at least one DistSQL worker + // goroutine to issue the SetupFlow RPC. + var usedWorker bool + // numIssuedRequests tracks the number of the SetupFlow RPCs that were + // issued (either by the current goroutine directly or delegated to the + // DistSQL workers). + var numIssuedRequests int if sp := tracing.SpanFromContext(origCtx); sp != nil && !sp.IsNoop() { setupReq.TraceInfo = sp.Meta().ToProto() } @@ -515,7 +517,7 @@ func (dsp *DistSQLPlanner) setupFlows( // // Note that even in case of an error in runnerRequest.run we still // send on the result channel. - for i := 0; i < numAsyncRequests+numSerialRequests; i++ { + for i := 0; i < numIssuedRequests; i++ { <-resultChan } // At this point, we know that all concurrent requests (if there @@ -541,11 +543,11 @@ func (dsp *DistSQLPlanner) setupFlows( // Send out a request to the workers; if no worker is available, run // directly. + numIssuedRequests++ select { case dsp.runnerCoordinator.runnerChan <- runReq: - numAsyncRequests++ + usedWorker = true default: - numSerialRequests++ // Use the context of the local flow since we're executing this // SetupFlow RPC synchronously. runReq.ctx = ctx @@ -554,16 +556,8 @@ func (dsp *DistSQLPlanner) setupFlows( } } } - if buildutil.CrdbTestBuild { - if numAsyncRequests+numSerialRequests != len(flows)-1 { - return ctx, flow, errors.AssertionFailedf( - "expected %d requests, found only %d async and %d serial", - len(flows)-1, numAsyncRequests, numSerialRequests, - ) - } - } - if numAsyncRequests == 0 { + if !usedWorker { // We executed all SetupFlow RPCs in the current goroutine, and all RPCs // succeeded. 
 		return ctx, flow, nil
@@ -586,7 +580,7 @@
 		syncutil.Mutex
 		called bool
 	}{}
-	flow.AddOnCleanup(func() {
+	flow.AddOnCleanupStart(func() {
 		cleanupCalledMu.Lock()
 		defer cleanupCalledMu.Unlock()
 		cleanupCalledMu.called = true
@@ -605,24 +599,29 @@
 		for i := 0; i < len(flows)-1; i++ {
 			res := <-resultChan
 			if res.err != nil && !seenError {
-				seenError = true
 				// The setup of at least one remote flow failed.
-				cleanupCalledMu.Lock()
-				skipCancel := cleanupCalledMu.called
-				cleanupCalledMu.Unlock()
-				if skipCancel {
-					continue
-				}
-				// First, we update the DistSQL receiver with the error to be
-				// returned to the client eventually.
-				//
-				// In order to not protect DistSQLReceiver.status with a mutex,
-				// we do not update the status here and, instead, rely on the
-				// DistSQLReceiver detecting the error the next time an object
-				// is pushed into it.
-				recv.setErrorWithoutStatusUpdate(res.err, true /* willDeferStatusUpdate */)
-				// Now explicitly cancel the local flow.
-				flow.Cancel()
+				seenError = true
+				func() {
+					cleanupCalledMu.Lock()
+					// Flow.Cancel cannot be called after or concurrently with
+					// Flow.Cleanup.
+					defer cleanupCalledMu.Unlock()
+					if cleanupCalledMu.called {
+						// Cleanup of the local flow has already been performed,
+						// so there is nothing to do.
+						return
+					}
+					// First, we update the DistSQL receiver with the error to
+					// be returned to the client eventually.
+					//
+					// In order to not protect DistSQLReceiver.status with a
+					// mutex, we do not update the status here and, instead,
+					// rely on the DistSQLReceiver detecting the error the next
+					// time an object is pushed into it.
+					recv.setErrorWithoutStatusUpdate(res.err, true /* willDeferStatusUpdate */)
+					// Now explicitly cancel the local flow.
+					flow.Cancel()
+				}()
 			}
 		}
 	})
diff --git a/pkg/sql/flowinfra/flow.go b/pkg/sql/flowinfra/flow.go
index af55c0b3de35..19c29a9295f8 100644
--- a/pkg/sql/flowinfra/flow.go
+++ b/pkg/sql/flowinfra/flow.go
@@ -140,10 +140,14 @@ type Flow interface {
 	// Cleanup.
 	Cancel()
 
-	// AddOnCleanup adds a callback to be executed at the very end of Cleanup.
-	// Callbacks are put on the stack meaning that AddOnCleanup is called
-	// multiple times, then the "later" callbacks are executed first.
-	AddOnCleanup(fn func())
+	// AddOnCleanupStart adds a callback to be executed at the very beginning of
+	// Cleanup.
+	AddOnCleanupStart(fn func())
+
+	// GetOnCleanupFns returns a couple of functions that should be called at
+	// the very beginning and the very end of Cleanup, respectively. Both will
+	// be non-nil.
+	GetOnCleanupFns() (startCleanup, endCleanup func())
 
 	// Cleanup must be called whenever the flow is done (meaning it either
 	// completes gracefully after all processors and mailboxes exited or an
@@ -200,7 +204,10 @@ type FlowBase struct {
 	// - outboxes
 	waitGroup sync.WaitGroup
 
-	onFlowCleanup func()
+	// onCleanupStart and onCleanupEnd will be called in the very beginning and
+	// the very end of Cleanup(), respectively.
+	onCleanupStart func()
+	onCleanupEnd   func()
 
 	statementSQL string
 
@@ -276,7 +283,7 @@ func NewFlowBase(
 	rowSyncFlowConsumer execinfra.RowReceiver,
 	batchSyncFlowConsumer execinfra.BatchReceiver,
 	localProcessors []execinfra.LocalProcessor,
-	onFlowCleanup func(),
+	onFlowCleanupEnd func(),
 	statementSQL string,
 ) *FlowBase {
 	// We are either in a single tenant cluster, or a SQL node in a multi-tenant
@@ -300,7 +307,7 @@
 		batchSyncFlowConsumer: batchSyncFlowConsumer,
 		localProcessors:       localProcessors,
 		admissionInfo:         admissionInfo,
-		onFlowCleanup:         onFlowCleanup,
+		onCleanupEnd:          onFlowCleanupEnd,
 		status:                flowNotStarted,
 		statementSQL:          statementSQL,
 	}
@@ -527,17 +534,31 @@ func (f *FlowBase) Cancel() {
 	f.ctxCancel()
 }
 
-// AddOnCleanup is part of the Flow interface.
-func (f *FlowBase) AddOnCleanup(fn func()) {
-	if f.onFlowCleanup != nil {
-		oldOnFlowCleanup := f.onFlowCleanup
-		f.onFlowCleanup = func() {
+// AddOnCleanupStart is part of the Flow interface.
+func (f *FlowBase) AddOnCleanupStart(fn func()) {
+	if f.onCleanupStart != nil {
+		oldOnCleanupStart := f.onCleanupStart
+		f.onCleanupStart = func() {
 			fn()
-			oldOnFlowCleanup()
+			oldOnCleanupStart()
 		}
 	} else {
-		f.onFlowCleanup = fn
+		f.onCleanupStart = fn
+	}
+}
+
+var noopFn = func() {}
+
+// GetOnCleanupFns is part of the Flow interface.
+func (f *FlowBase) GetOnCleanupFns() (startCleanup, endCleanup func()) {
+	onCleanupStart, onCleanupEnd := f.onCleanupStart, f.onCleanupEnd
+	if onCleanupStart == nil {
+		onCleanupStart = noopFn
 	}
+	if onCleanupEnd == nil {
+		onCleanupEnd = noopFn
+	}
+	return onCleanupStart, onCleanupEnd
 }
 
 // Cleanup is part of the Flow interface.
@@ -594,9 +615,6 @@ func (f *FlowBase) Cleanup(ctx context.Context) {
 	}
 	f.status = flowFinished
 	f.ctxCancel()
-	if f.onFlowCleanup != nil {
-		f.onFlowCleanup()
-	}
 }
 
 // cancel cancels all unconnected streams of this flow. This function is called
diff --git a/pkg/sql/rowflow/row_based_flow.go b/pkg/sql/rowflow/row_based_flow.go
index 489d2e348447..3e604504c246 100644
--- a/pkg/sql/rowflow/row_based_flow.go
+++ b/pkg/sql/rowflow/row_based_flow.go
@@ -439,6 +439,9 @@ func (f *rowBasedFlow) Release() {
 
 // Cleanup is part of the flowinfra.Flow interface.
 func (f *rowBasedFlow) Cleanup(ctx context.Context) {
+	startCleanup, endCleanup := f.FlowBase.GetOnCleanupFns()
+	startCleanup()
+	defer endCleanup()
 	f.FlowBase.Cleanup(ctx)
 	f.Release()
 }
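
As a reading aid rather than part of the patch, the Go sketch below condenses the cleanup-callback contract the diff establishes: a callback registered with AddOnCleanupStart runs before any resources are released, while the end callback handed to NewFlowBase is no longer invoked by FlowBase.Cleanup itself and instead runs, via the deferred endCleanup(), only after the concrete flow's Cleanup has finished. The flowBase type here is a simplified, self-contained stand-in for flowinfra.FlowBase, not the real implementation.

package main

import "fmt"

// flowBase is a simplified stand-in for flowinfra.FlowBase: it stores the
// start/end cleanup callbacks and hands them out as non-nil functions.
type flowBase struct {
	onCleanupStart func()
	onCleanupEnd   func()
}

var noopFn = func() {}

// AddOnCleanupStart stacks callbacks so that the most recently added one
// runs first, matching the behavior in the patch above.
func (f *flowBase) AddOnCleanupStart(fn func()) {
	if f.onCleanupStart != nil {
		old := f.onCleanupStart
		f.onCleanupStart = func() {
			fn()
			old()
		}
	} else {
		f.onCleanupStart = fn
	}
}

// GetOnCleanupFns returns non-nil start and end callbacks.
func (f *flowBase) GetOnCleanupFns() (startCleanup, endCleanup func()) {
	startCleanup, endCleanup = f.onCleanupStart, f.onCleanupEnd
	if startCleanup == nil {
		startCleanup = noopFn
	}
	if endCleanup == nil {
		endCleanup = noopFn
	}
	return startCleanup, endCleanup
}

// Cleanup brackets the actual resource release with the two callbacks, the
// way vectorized_flow.go and row_based_flow.go do it in the diff.
func (f *flowBase) Cleanup() {
	startCleanup, endCleanup := f.GetOnCleanupFns()
	startCleanup()
	defer endCleanup()
	fmt.Println("releasing flow resources")
}

func main() {
	f := &flowBase{onCleanupEnd: func() { fmt.Println("end: restore txn, close reservation") }}
	f.AddOnCleanupStart(func() { fmt.Println("start: mark local flow cleanup as called") })
	f.Cleanup()
	// Prints the start callback, then the resource release, then the end
	// callback, which is the ordering the patch relies on so that Flow.Cancel
	// is never called after cleanup has begun.
}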