diff --git a/executor/builder.go b/executor/builder.go index 037123cc1fa96..3ef6d96358e0b 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -49,6 +49,7 @@ import ( "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/cteutil" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/logutil" @@ -83,6 +84,14 @@ type executorBuilder struct { hasLock bool } +// CTEStorages stores resTbl and iterInTbl for CTEExec. +// There will be a map[CTEStorageID]*CTEStorages in StmtCtx, +// which will store all CTEStorages to make all shared CTEs use same the CTEStorages. +type CTEStorages struct { + ResTbl cteutil.Storage + IterInTbl cteutil.Storage +} + func newExecutorBuilder(ctx sessionctx.Context, is infoschema.InfoSchema) *executorBuilder { return &executorBuilder{ ctx: ctx, @@ -235,6 +244,10 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor { return b.buildAdminShowTelemetry(v) case *plannercore.AdminResetTelemetryID: return b.buildAdminResetTelemetryID(v) + case *plannercore.PhysicalCTE: + return b.buildCTE(v) + case *plannercore.PhysicalCTETable: + return b.buildCTETableReader(v) default: if mp, ok := p.(MockPhysicalPlan); ok { return mp.GetExecutor() @@ -4072,3 +4085,90 @@ func (b *executorBuilder) buildTableSample(v *plannercore.PhysicalTableSample) * } return e } + +func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { + // 1. Build seedPlan. + seedExec := b.build(v.SeedPlan) + if b.err != nil { + return nil + } + + // 2. Build iterInTbl. + chkSize := b.ctx.GetSessionVars().MaxChunkSize + tps := seedExec.base().retFieldTypes + iterOutTbl := cteutil.NewStorageRowContainer(tps, chkSize) + if err := iterOutTbl.OpenAndRef(); err != nil { + b.err = err + return nil + } + + var resTbl cteutil.Storage + var iterInTbl cteutil.Storage + storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) + if !ok { + b.err = errors.New("type assertion for CTEStorageMap failed") + return nil + } + storages, ok := storageMap[v.CTE.IDForStorage] + if ok { + // Storage already setup. + resTbl = storages.ResTbl + iterInTbl = storages.IterInTbl + } else { + resTbl = cteutil.NewStorageRowContainer(tps, chkSize) + if err := resTbl.OpenAndRef(); err != nil { + b.err = err + return nil + } + iterInTbl = cteutil.NewStorageRowContainer(tps, chkSize) + if err := iterInTbl.OpenAndRef(); err != nil { + b.err = err + return nil + } + storageMap[v.CTE.IDForStorage] = &CTEStorages{ResTbl: resTbl, IterInTbl: iterInTbl} + } + + // 3. Build recursive part. + recursiveExec := b.build(v.RecurPlan) + if b.err != nil { + return nil + } + + var sel []int + if v.CTE.IsDistinct { + sel = make([]int, chkSize) + for i := 0; i < chkSize; i++ { + sel[i] = i + } + } + + return &CTEExec{ + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID()), + seedExec: seedExec, + recursiveExec: recursiveExec, + resTbl: resTbl, + iterInTbl: iterInTbl, + iterOutTbl: iterOutTbl, + chkIdx: 0, + isDistinct: v.CTE.IsDistinct, + sel: sel, + } +} + +func (b *executorBuilder) buildCTETableReader(v *plannercore.PhysicalCTETable) Executor { + storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) + if !ok { + b.err = errors.New("type assertion for CTEStorageMap failed") + return nil + } + storages, ok := storageMap[v.IDForStorage] + if !ok { + b.err = errors.Errorf("iterInTbl should already be set up by CTEExec(id: %d)", v.IDForStorage) + return nil + } + return &CTETableReaderExec{ + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ID()), + iterInTbl: storages.IterInTbl, + chkIdx: 0, + } +} diff --git a/executor/cte.go b/executor/cte.go new file mode 100644 index 0000000000000..a5e063e9dc9ee --- /dev/null +++ b/executor/cte.go @@ -0,0 +1,490 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/cteutil" + "github.com/pingcap/tidb/util/memory" +) + +var _ Executor = &CTEExec{} + +// CTEExec implements CTE. +// Following diagram describes how CTEExec works. +// +// `iterInTbl` is shared by `CTEExec` and `CTETableReaderExec`. +// `CTETableReaderExec` reads data from `iterInTbl`, +// and its output will be stored `iterOutTbl` by `CTEExec`. +// +// When an iteration ends, `CTEExec` will move all data from `iterOutTbl` into `iterInTbl`, +// which will be the input for new iteration. +// At the end of each iteration, data in `iterOutTbl` will also be added into `resTbl`. +// `resTbl` stores data of all iteration. +// +----------+ +// write |iterOutTbl| +// CTEExec ------------------->| | +// | +----+-----+ +// ------------- | write +// | | v +// other op other op +----------+ +// (seed) (recursive) | resTbl | +// ^ | | +// | +----------+ +// CTETableReaderExec +// ^ +// | read +----------+ +// +---------------+iterInTbl | +// | | +// +----------+ +type CTEExec struct { + baseExecutor + + seedExec Executor + recursiveExec Executor + + // `resTbl` and `iterInTbl` are shared by all CTEExec which reference to same the CTE. + // `iterInTbl` is also shared by CTETableReaderExec. + resTbl cteutil.Storage + iterInTbl cteutil.Storage + iterOutTbl cteutil.Storage + + hashTbl baseHashTable + + // Index of chunk to read from `resTbl`. + chkIdx int + + // UNION ALL or UNION DISTINCT. + isDistinct bool + curIter int + hCtx *hashContext + sel []int +} + +// Open implements the Executor interface. +func (e *CTEExec) Open(ctx context.Context) (err error) { + e.reset() + if err := e.baseExecutor.Open(ctx); err != nil { + return err + } + + if e.seedExec == nil { + return errors.New("seedExec for CTEExec is nil") + } + if err = e.seedExec.Open(ctx); err != nil { + return err + } + + if e.recursiveExec != nil { + if err = e.recursiveExec.Open(ctx); err != nil { + return err + } + recursiveTypes := e.recursiveExec.base().retFieldTypes + e.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, e.maxChunkSize) + if err = e.iterOutTbl.OpenAndRef(); err != nil { + return err + } + + setupCTEStorageTracker(e.iterOutTbl, e.ctx) + } + + if e.isDistinct { + e.hashTbl = newConcurrentMapHashTable() + e.hCtx = &hashContext{ + allTypes: e.base().retFieldTypes, + } + // We use all columns to compute hash. + e.hCtx.keyColIdx = make([]int, len(e.hCtx.allTypes)) + for i := range e.hCtx.keyColIdx { + e.hCtx.keyColIdx[i] = i + } + } + return nil +} + +// Next implements the Executor interface. +func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.Reset() + e.resTbl.Lock() + if !e.resTbl.Done() { + defer e.resTbl.Unlock() + resAction := setupCTEStorageTracker(e.resTbl, e.ctx) + iterInAction := setupCTEStorageTracker(e.iterInTbl, e.ctx) + + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val.(bool) && config.GetGlobalConfig().OOMUseTmpStorage { + defer resAction.WaitForTest() + defer iterInAction.WaitForTest() + } + }) + + if err = e.computeSeedPart(ctx); err != nil { + // Don't put it in defer. + // Because it should be called only when the filling process is not completed. + if err1 := e.reopenTbls(); err1 != nil { + return err1 + } + return err + } + if err = e.computeRecursivePart(ctx); err != nil { + if err1 := e.reopenTbls(); err1 != nil { + return err1 + } + return err + } + e.resTbl.SetDone() + } else { + e.resTbl.Unlock() + } + + if e.chkIdx < e.resTbl.NumChunks() { + res, err := e.resTbl.GetChunk(e.chkIdx) + if err != nil { + return err + } + // Need to copy chunk to make sure upper operator will not change chunk in resTbl. + // Also we ignore copying rows not selected, because some operators like Projection + // doesn't support swap column if chunk.sel is no nil. + req.SwapColumns(res.CopyConstructSel()) + e.chkIdx++ + } + return nil +} + +// Close implements the Executor interface. +func (e *CTEExec) Close() (err error) { + e.reset() + if err = e.seedExec.Close(); err != nil { + return err + } + if e.recursiveExec != nil { + if err = e.recursiveExec.Close(); err != nil { + return err + } + } + + // `iterInTbl` and `resTbl` are shared by multiple operators, + // so will be closed when the SQL finishes. + if err = e.iterOutTbl.DerefAndClose(); err != nil { + return err + } + return e.baseExecutor.Close() +} + +func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { + e.curIter = 0 + e.iterInTbl.SetIter(e.curIter) + // This means iterInTbl's can be read. + defer close(e.iterInTbl.GetBegCh()) + chks := make([]*chunk.Chunk, 0, 10) + for { + chk := newFirstChunk(e.seedExec) + if err = Next(ctx, e.seedExec, chk); err != nil { + return err + } + if chk.NumRows() == 0 { + break + } + if chk, err = e.tryDedupAndAdd(chk, e.iterInTbl, e.hashTbl); err != nil { + return err + } + chks = append(chks, chk) + } + // Initial resTbl is empty, so no need to deduplicate chk using resTbl. + // Just adding is ok. + for _, chk := range chks { + if err = e.resTbl.Add(chk); err != nil { + return err + } + } + e.curIter++ + e.iterInTbl.SetIter(e.curIter) + + return nil +} + +func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { + if e.recursiveExec == nil || e.iterInTbl.NumChunks() == 0 { + return nil + } + + if e.curIter > e.ctx.GetSessionVars().CTEMaxRecursionDepth { + return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter) + } + + for { + chk := newFirstChunk(e.recursiveExec) + if err = Next(ctx, e.recursiveExec, chk); err != nil { + return err + } + if chk.NumRows() == 0 { + if err = e.setupTblsForNewIteration(); err != nil { + return err + } + if e.iterInTbl.NumChunks() == 0 { + break + } + // Next iteration begins. Need use iterOutTbl as input of next iteration. + e.curIter++ + e.iterInTbl.SetIter(e.curIter) + if e.curIter > e.ctx.GetSessionVars().CTEMaxRecursionDepth { + return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter) + } + // Make sure iterInTbl is setup before Close/Open, + // because some executors will read iterInTbl in Open() (like IndexLookupJoin). + if err = e.recursiveExec.Close(); err != nil { + return err + } + if err = e.recursiveExec.Open(ctx); err != nil { + return err + } + } else { + if err = e.iterOutTbl.Add(chk); err != nil { + return err + } + } + } + return nil +} + +func (e *CTEExec) setupTblsForNewIteration() (err error) { + num := e.iterOutTbl.NumChunks() + chks := make([]*chunk.Chunk, 0, num) + // Setup resTbl's data. + for i := 0; i < num; i++ { + chk, err := e.iterOutTbl.GetChunk(i) + if err != nil { + return err + } + // Data should be copied in UNION DISTINCT. + // Because deduplicate() will change data in iterOutTbl, + // which will cause panic when spilling data into disk concurrently. + if e.isDistinct { + chk = chk.CopyConstruct() + } + chk, err = e.tryDedupAndAdd(chk, e.resTbl, e.hashTbl) + if err != nil { + return err + } + chks = append(chks, chk) + } + + // Setup new iteration data in iterInTbl. + if err = e.iterInTbl.Reopen(); err != nil { + return err + } + defer close(e.iterInTbl.GetBegCh()) + if e.isDistinct { + // Already deduplicated by resTbl, adding directly is ok. + for _, chk := range chks { + if err = e.iterInTbl.Add(chk); err != nil { + return err + } + } + } else { + if err = e.iterInTbl.SwapData(e.iterOutTbl); err != nil { + return err + } + } + + // Clear data in iterOutTbl. + return e.iterOutTbl.Reopen() +} + +func (e *CTEExec) reset() { + e.curIter = 0 + e.chkIdx = 0 + e.hashTbl = nil +} + +func (e *CTEExec) reopenTbls() (err error) { + e.hashTbl = newConcurrentMapHashTable() + if err := e.resTbl.Reopen(); err != nil { + return err + } + return e.iterInTbl.Reopen() +} + +func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context) (actionSpill *chunk.SpillDiskAction) { + memTracker := tbl.GetMemTracker() + memTracker.SetLabel(memory.LabelForCTEStorage) + memTracker.AttachTo(ctx.GetSessionVars().StmtCtx.MemTracker) + + diskTracker := tbl.GetDiskTracker() + diskTracker.SetLabel(memory.LabelForCTEStorage) + diskTracker.AttachTo(ctx.GetSessionVars().StmtCtx.DiskTracker) + + if config.GetGlobalConfig().OOMUseTmpStorage { + actionSpill = tbl.ActionSpill() + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val.(bool) { + actionSpill = tbl.(*cteutil.StorageRC).ActionSpillForTest() + } + }) + ctx.GetSessionVars().StmtCtx.MemTracker.FallbackOldAndSetNewAction(actionSpill) + } + return actionSpill +} + +func (e *CTEExec) tryDedupAndAdd(chk *chunk.Chunk, + storage cteutil.Storage, + hashTbl baseHashTable) (res *chunk.Chunk, err error) { + if e.isDistinct { + if chk, err = e.deduplicate(chk, storage, hashTbl); err != nil { + return nil, err + } + } + return chk, storage.Add(chk) +} + +// Compute hash values in chk and put it in hCtx.hashVals. +// Use the returned sel to choose the computed hash values. +func (e *CTEExec) computeChunkHash(chk *chunk.Chunk) (sel []int, err error) { + numRows := chk.NumRows() + e.hCtx.initHash(numRows) + // Continue to reset to make sure all hasher is new. + for i := numRows; i < len(e.hCtx.hashVals); i++ { + e.hCtx.hashVals[i].Reset() + } + sel = chk.Sel() + var hashBitMap []bool + if sel != nil { + hashBitMap = make([]bool, chk.Capacity()) + for _, val := range sel { + hashBitMap[val] = true + } + } else { + // All rows is selected, sel will be [0....numRows). + // e.sel is setup when building executor. + sel = e.sel + } + + for i := 0; i < chk.NumCols(); i++ { + if err = codec.HashChunkSelected(e.ctx.GetSessionVars().StmtCtx, e.hCtx.hashVals, + chk, e.hCtx.allTypes[i], i, e.hCtx.buf, e.hCtx.hasNull, + hashBitMap, false); err != nil { + return nil, err + } + } + return sel, nil +} + +// Use hashTbl to deduplicate rows, and unique rows will be added to hashTbl. +// Duplicated rows are only marked to be removed by sel in Chunk, instead of really deleted. +func (e *CTEExec) deduplicate(chk *chunk.Chunk, + storage cteutil.Storage, + hashTbl baseHashTable) (chkNoDup *chunk.Chunk, err error) { + numRows := chk.NumRows() + if numRows == 0 { + return chk, nil + } + + // 1. Compute hash values for chunk. + chkHashTbl := newConcurrentMapHashTable() + selOri, err := e.computeChunkHash(chk) + if err != nil { + return nil, err + } + + // 2. Filter rows duplicated in input chunk. + // This sel is for filtering rows duplicated in cur chk. + selChk := make([]int, 0, numRows) + for i := 0; i < numRows; i++ { + key := e.hCtx.hashVals[selOri[i]].Sum64() + row := chk.GetRow(i) + + hasDup, err := e.checkHasDup(key, row, chk, storage, chkHashTbl) + if err != nil { + return nil, err + } + if hasDup { + continue + } + + selChk = append(selChk, selOri[i]) + + rowPtr := chunk.RowPtr{ChkIdx: uint32(0), RowIdx: uint32(i)} + chkHashTbl.Put(key, rowPtr) + } + chk.SetSel(selChk) + chkIdx := storage.NumChunks() + + // 3. Filter rows duplicated in RowContainer. + // This sel is for filtering rows duplicated in cteutil.Storage. + selStorage := make([]int, 0, len(selChk)) + for i := 0; i < len(selChk); i++ { + key := e.hCtx.hashVals[selChk[i]].Sum64() + row := chk.GetRow(i) + + hasDup, err := e.checkHasDup(key, row, nil, storage, hashTbl) + if err != nil { + return nil, err + } + if hasDup { + continue + } + + rowIdx := len(selStorage) + selStorage = append(selStorage, selChk[i]) + + rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} + hashTbl.Put(key, rowPtr) + } + + chk.SetSel(selStorage) + return chk, nil +} + +// Use the row's probe key to check if it already exists in chk or storage. +// We also need to compare the row's real encoding value to avoid hash collision. +func (e *CTEExec) checkHasDup(probeKey uint64, + row chunk.Row, + curChk *chunk.Chunk, + storage cteutil.Storage, + hashTbl baseHashTable) (hasDup bool, err error) { + ptrs := hashTbl.Get(probeKey) + + if len(ptrs) == 0 { + return false, nil + } + + for _, ptr := range ptrs { + var matchedRow chunk.Row + if curChk != nil { + matchedRow = curChk.GetRow(int(ptr.RowIdx)) + } else { + matchedRow, err = storage.GetRow(ptr) + } + if err != nil { + return false, err + } + isEqual, err := codec.EqualChunkRow(e.ctx.GetSessionVars().StmtCtx, + row, e.hCtx.allTypes, e.hCtx.keyColIdx, + matchedRow, e.hCtx.allTypes, e.hCtx.keyColIdx) + if err != nil { + return false, err + } + if isEqual { + return true, nil + } + } + return false, nil +} diff --git a/executor/cte_table_reader.go b/executor/cte_table_reader.go new file mode 100644 index 0000000000000..94fedf01fd93e --- /dev/null +++ b/executor/cte_table_reader.go @@ -0,0 +1,78 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/cteutil" +) + +// CTETableReaderExec scans data in iterInTbl, which is filled by corresponding CTEExec. +type CTETableReaderExec struct { + baseExecutor + + iterInTbl cteutil.Storage + chkIdx int + curIter int +} + +// Open implements the Executor interface. +func (e *CTETableReaderExec) Open(ctx context.Context) error { + e.reset() + return e.baseExecutor.Open(ctx) +} + +// Next implements the Executor interface. +func (e *CTETableReaderExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.Reset() + + // Wait until iterInTbl can be read. This is controlled by corresponding CTEExec. + <-e.iterInTbl.GetBegCh() + + // We should read `iterInTbl` from the beginning when the next iteration starts. + // Can not directly judge whether to start the next iteration based on e.chkIdx, + // because some operators(Selection) may use forloop to read all data in `iterInTbl`. + if e.curIter != e.iterInTbl.GetIter() { + if e.curIter > e.iterInTbl.GetIter() { + return errors.Errorf("invalid iteration for CTETableReaderExec (e.curIter: %d, e.iterInTbl.GetIter(): %d)", + e.curIter, e.iterInTbl.GetIter()) + } + e.chkIdx = 0 + e.curIter = e.iterInTbl.GetIter() + } + if e.chkIdx < e.iterInTbl.NumChunks() { + res, err := e.iterInTbl.GetChunk(e.chkIdx) + if err != nil { + return err + } + // Need to copy chunk to make sure upper operators will not change chunk in iterInTbl. + req.SwapColumns(res.CopyConstructSel()) + e.chkIdx++ + } + return nil +} + +// Close implements the Executor interface. +func (e *CTETableReaderExec) Close() (err error) { + e.reset() + return e.baseExecutor.Close() +} + +func (e *CTETableReaderExec) reset() { + e.chkIdx = 0 + e.curIter = 0 +} diff --git a/executor/cte_test.go b/executor/cte_test.go new file mode 100644 index 0000000000000..e5789170627f7 --- /dev/null +++ b/executor/cte_test.go @@ -0,0 +1,244 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "context" + "fmt" + "math/rand" + "sort" + + "github.com/pingcap/check" + + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/testkit" +) + +var _ = check.Suite(&CTETestSuite{}) + +type CTETestSuite struct { + store kv.Storage + dom *domain.Domain + sessionCtx sessionctx.Context + session session.Session + ctx context.Context +} + +func (test *CTETestSuite) SetUpSuite(c *check.C) { + var err error + test.store, err = mockstore.NewMockStore() + c.Assert(err, check.IsNil) + + test.dom, err = session.BootstrapSession(test.store) + c.Assert(err, check.IsNil) + + test.sessionCtx = mock.NewContext() + + test.session, err = session.CreateSession4Test(test.store) + c.Assert(err, check.IsNil) + test.session.SetConnectionID(0) + + test.ctx = context.Background() +} + +func (test *CTETestSuite) TearDownSuite(c *check.C) { + test.dom.Close() + test.store.Close() +} + +func (test *CTETestSuite) TestBasicCTE(c *check.C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test") + + rows := tk.MustQuery("with recursive cte1 as (" + + "select 1 c1 " + + "union all " + + "select c1 + 1 c1 from cte1 where c1 < 5) " + + "select * from cte1") + rows.Check(testkit.Rows("1", "2", "3", "4", "5")) + + // Two seed parts. + rows = tk.MustQuery("with recursive cte1 as (" + + "select 1 c1 " + + "union all " + + "select 2 c1 " + + "union all " + + "select c1 + 1 c1 from cte1 where c1 < 10) " + + "select * from cte1 order by c1") + rows.Check(testkit.Rows("1", "2", "2", "3", "3", "4", "4", "5", "5", "6", "6", "7", "7", "8", "8", "9", "9", "10", "10")) + + // Two recursive parts. + rows = tk.MustQuery("with recursive cte1 as (" + + "select 1 c1 " + + "union all " + + "select 2 c1 " + + "union all " + + "select c1 + 1 c1 from cte1 where c1 < 3 " + + "union all " + + "select c1 + 2 c1 from cte1 where c1 < 5) " + + "select * from cte1 order by c1") + rows.Check(testkit.Rows("1", "2", "2", "3", "3", "3", "4", "4", "5", "5", "5", "6", "6")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(a int);") + tk.MustExec("insert into t1 values(1);") + tk.MustExec("insert into t1 values(2);") + rows = tk.MustQuery("SELECT * FROM t1 dt WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a*0 AS b UNION ALL SELECT b+1 FROM qn WHERE b=0) SELECT * FROM qn WHERE b=a);") + rows.Check(testkit.Rows("1")) + rows = tk.MustQuery("SELECT * FROM t1 dt WHERE EXISTS( WITH RECURSIVE qn AS (SELECT a*0 AS b UNION ALL SELECT b+1 FROM qn WHERE b=0 or b = 1) SELECT * FROM qn WHERE b=a );") + rows.Check(testkit.Rows("1", "2")) +} + +func (test *CTETestSuite) TestSpillToDisk(c *check.C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test;") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/testCTEStorageSpill", "return(true)"), check.IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/testCTEStorageSpill"), check.IsNil) + }() + + insertStr := "insert into t1 values(0, 0)" + for i := 1; i < 5000; i++ { + insertStr += fmt.Sprintf(", (%d, %d)", i, i) + } + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int, c2 int);") + tk.MustExec(insertStr) + tk.MustExec("set tidb_mem_quota_query = 80000;") + rows := tk.MustQuery("with recursive cte1 as ( " + + "select c1 from t1 " + + "union " + + "select c1 + 1 c1 from cte1 where c1 < 5000) " + + "select c1 from cte1;") + + memTracker := tk.Se.GetSessionVars().StmtCtx.MemTracker + diskTracker := tk.Se.GetSessionVars().StmtCtx.DiskTracker + c.Assert(memTracker.MaxConsumed(), check.Greater, int64(0)) + c.Assert(diskTracker.MaxConsumed(), check.Greater, int64(0)) + + rowNum := 5000 + var resRows []string + for i := 0; i <= rowNum; i++ { + resRows = append(resRows, fmt.Sprintf("%d", i)) + } + rows.Check(testkit.Rows(resRows...)) + + // Use duplicated rows to test UNION DISTINCT. + tk.MustExec("set tidb_mem_quota_query = 1073741824;") + insertStr = "insert into t1 values(0, 0)" + vals := make([]int, rowNum) + vals[0] = 0 + for i := 1; i < rowNum; i++ { + v := rand.Intn(100) + vals[i] = v + insertStr += fmt.Sprintf(", (%d, %d)", v, v) + } + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int, c2 int);") + tk.MustExec(insertStr) + tk.MustExec("set tidb_mem_quota_query = 80000;") + tk.MustExec("set cte_max_recursion_depth = 500000;") + rows = tk.MustQuery("with recursive cte1 as ( " + + "select c1 from t1 " + + "union " + + "select c1 + 1 c1 from cte1 where c1 < 5000) " + + "select c1 from cte1 order by c1;") + + memTracker = tk.Se.GetSessionVars().StmtCtx.MemTracker + diskTracker = tk.Se.GetSessionVars().StmtCtx.DiskTracker + c.Assert(memTracker.MaxConsumed(), check.Greater, int64(0)) + c.Assert(diskTracker.MaxConsumed(), check.Greater, int64(0)) + + sort.Ints(vals) + resRows = make([]string, 0, rowNum) + for i := vals[0]; i <= rowNum; i++ { + resRows = append(resRows, fmt.Sprintf("%d", i)) + } + rows.Check(testkit.Rows(resRows...)) +} + +func (test *CTETestSuite) TestUnionDistinct(c *check.C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test;") + + // Basic test. UNION/UNION ALL intersects. + rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select 1 union select 1 union all select c1 + 1 from cte1 where c1 < 3) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2", "3")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union all select 1 union select 1 union all select c1 + 1 from cte1 where c1 < 3) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2", "3")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int, c2 int);") + tk.MustExec("insert into t1 values(1, 1), (1, 2), (2, 2);") + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from t1) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2", "3")) + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(1), (1), (1), (2), (2), (2);") + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 where c1 < 4) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2", "3", "4")) +} + +func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test;") + + tk.MustExec("set @@cte_max_recursion_depth = -1;") + err := tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 100) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + // If there is no recursive part, query runs ok. + rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) + rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) + + tk.MustExec("set @@cte_max_recursion_depth = 0;") + err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 0) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + // If there is no recursive part, query runs ok. + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) + rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) + + tk.MustExec("set @@cte_max_recursion_depth = 1;") + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 0) select * from cte1;") + rows.Check(testkit.Rows("1")) + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;") + rows.Check(testkit.Rows("1")) + err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 2) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 2 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + // If there is no recursive part, query runs ok. + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) + rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") + rows.Check(testkit.Rows("1", "2")) +} diff --git a/executor/executor.go b/executor/executor.go index f9cffcdfdba54..ecb9102ba2f39 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1629,10 +1629,11 @@ func (e *UnionExec) Close() error { func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars := ctx.GetSessionVars() sc := &stmtctx.StatementContext{ - TimeZone: vars.Location(), - MemTracker: memory.NewTracker(memory.LabelForSQLText, vars.MemQuotaQuery), - DiskTracker: disk.NewTracker(memory.LabelForSQLText, -1), - TaskID: stmtctx.AllocateTaskID(), + TimeZone: vars.Location(), + MemTracker: memory.NewTracker(memory.LabelForSQLText, vars.MemQuotaQuery), + DiskTracker: disk.NewTracker(memory.LabelForSQLText, -1), + TaskID: stmtctx.AllocateTaskID(), + CTEStorageMap: map[int]*CTEStorages{}, } sc.MemTracker.AttachToGlobalTracker(GlobalMemoryUsageTracker) globalConfig := config.GetGlobalConfig() diff --git a/session/session.go b/session/session.go index 6e177f5cca368..0b8dce4ef9a22 100644 --- a/session/session.go +++ b/session/session.go @@ -1680,8 +1680,40 @@ type execStmtResult struct { func (rs *execStmtResult) Close() error { se := rs.se - err := rs.RecordSet.Close() - return finishStmt(context.Background(), se, err, rs.sql) + if err := resetCTEStorageMap(se); err != nil { + return finishStmt(context.Background(), se, err, rs.sql) + } + if err := rs.RecordSet.Close(); err != nil { + return finishStmt(context.Background(), se, err, rs.sql) + } + return finishStmt(context.Background(), se, nil, rs.sql) +} + +func resetCTEStorageMap(se *session) error { + tmp := se.GetSessionVars().StmtCtx.CTEStorageMap + if tmp == nil { + // Close() is already called, so no need to reset. Such as TraceExec. + return nil + } + storageMap, ok := tmp.(map[int]*executor.CTEStorages) + if !ok { + return errors.New("type assertion for CTEStorageMap failed") + } + for _, v := range storageMap { + // No need to lock IterInTbl. + v.ResTbl.Lock() + defer v.ResTbl.Unlock() + err1 := v.ResTbl.DerefAndClose() + err2 := v.IterInTbl.DerefAndClose() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + } + se.GetSessionVars().StmtCtx.CTEStorageMap = nil + return nil } // rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index ea8bd70b8c0f2..23fc0f52664f6 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -168,6 +168,9 @@ type StatementContext struct { stmtCache map[StmtCacheKey]interface{} // resourceGroupTag cache for the current statement resource group tag. resourceGroupTag atomic.Value + // Map to store all CTE storages of current SQL. + // Will clean up at the end of the execution. + CTEStorageMap interface{} } // StmtHints are SessionVars related sql hints. diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go index e91cff2559d79..a15279e1a7ada 100644 --- a/util/chunk/chunk.go +++ b/util/chunk/chunk.go @@ -291,6 +291,21 @@ func (c *Chunk) CopyConstruct() *Chunk { return newChk } +// CopyConstructSel is just like CopyConstruct, +// but ignore the rows that was not selected. +func (c *Chunk) CopyConstructSel() *Chunk { + if c.sel == nil { + return c.CopyConstruct() + } + newChk := renewWithCapacity(c, c.capacity, c.requiredRows) + for colIdx, dstCol := range newChk.columns { + for _, rowIdx := range c.sel { + appendCellByCell(dstCol, c.columns[colIdx], rowIdx) + } + } + return newChk +} + // GrowAndReset resets the Chunk and doubles the capacity of the Chunk. // The doubled capacity should not be larger than maxChunkSize. // TODO: this method will be used in following PR. diff --git a/util/chunk/disk.go b/util/chunk/disk.go index 47e6b25c0b492..a9f7eecec3641 100644 --- a/util/chunk/disk.go +++ b/util/chunk/disk.go @@ -14,7 +14,6 @@ package chunk import ( - "errors" "io" "os" "strconv" @@ -132,7 +131,7 @@ func (l *ListInDisk) flush() (err error) { // Warning: do not mix Add and GetRow (always use GetRow after you have added all the chunks), and do not use Add concurrently. func (l *ListInDisk) Add(chk *Chunk) (err error) { if chk.NumRows() == 0 { - return errors.New("chunk appended to List should have at least 1 row") + return errors2.New("chunk appended to List should have at least 1 row") } if l.disk == nil { err = l.initDiskFile() diff --git a/util/cteutil/storage.go b/util/cteutil/storage.go index 3397e81fa7c05..9d42b1a11c015 100644 --- a/util/cteutil/storage.go +++ b/util/cteutil/storage.go @@ -51,6 +51,7 @@ type Storage interface { Reopen() error // Add chunk into underlying storage. + // Should return directly if chk is empty. Add(chk *chunk.Chunk) error // Get Chunk by index. @@ -84,7 +85,7 @@ type Storage interface { GetMemTracker() *memory.Tracker GetDiskTracker() *disk.Tracker - ActionSpill() memory.ActionOnExceed + ActionSpill() *chunk.SpillDiskAction } // StorageRC implements Storage interface using RowContainer. @@ -101,8 +102,8 @@ type StorageRC struct { rc *chunk.RowContainer } -// NewStorageRC create a new StorageRC. -func NewStorageRC(tp []*types.FieldType, chkSize int) *StorageRC { +// NewStorageRowContainer create a new StorageRC. +func NewStorageRowContainer(tp []*types.FieldType, chkSize int) *StorageRC { return &StorageRC{tp: tp, chkSize: chkSize} } @@ -245,7 +246,7 @@ func (s *StorageRC) GetDiskTracker() *memory.Tracker { } // ActionSpill impls Storage ActionSpill interface. -func (s *StorageRC) ActionSpill() memory.ActionOnExceed { +func (s *StorageRC) ActionSpill() *chunk.SpillDiskAction { return s.rc.ActionSpill() } diff --git a/util/cteutil/storage_test.go b/util/cteutil/storage_test.go index 89376fc8580b8..0e494978f2f84 100644 --- a/util/cteutil/storage_test.go +++ b/util/cteutil/storage_test.go @@ -35,7 +35,7 @@ type StorageRCTestSuite struct{} func (test *StorageRCTestSuite) TestStorageBasic(c *check.C) { fields := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 1 - storage := NewStorageRC(fields, chkSize) + storage := NewStorageRowContainer(fields, chkSize) c.Assert(storage, check.NotNil) // Close before open. @@ -67,7 +67,7 @@ func (test *StorageRCTestSuite) TestStorageBasic(c *check.C) { func (test *StorageRCTestSuite) TestOpenAndClose(c *check.C) { fields := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 1 - storage := NewStorageRC(fields, chkSize) + storage := NewStorageRowContainer(fields, chkSize) for i := 0; i < 10; i++ { err := storage.OpenAndRef() @@ -89,7 +89,7 @@ func (test *StorageRCTestSuite) TestAddAndGetChunk(c *check.C) { fields := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 10 - storage := NewStorageRC(fields, chkSize) + storage := NewStorageRowContainer(fields, chkSize) inChk := chunk.NewChunkWithCapacity(fields, chkSize) for i := 0; i < chkSize; i++ { @@ -117,7 +117,7 @@ func (test *StorageRCTestSuite) TestAddAndGetChunk(c *check.C) { func (test *StorageRCTestSuite) TestSpillToDisk(c *check.C) { fields := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 10 - storage := NewStorageRC(fields, chkSize) + storage := NewStorageRowContainer(fields, chkSize) var tmp interface{} = storage inChk := chunk.NewChunkWithCapacity(fields, chkSize) @@ -171,7 +171,7 @@ func (test *StorageRCTestSuite) TestSpillToDisk(c *check.C) { func (test *StorageRCTestSuite) TestReopen(c *check.C) { fields := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 10 - storage := NewStorageRC(fields, chkSize) + storage := NewStorageRowContainer(fields, chkSize) err := storage.OpenAndRef() c.Assert(err, check.IsNil) @@ -216,7 +216,7 @@ func (test *StorageRCTestSuite) TestReopen(c *check.C) { func (test *StorageRCTestSuite) TestSwapData(c *check.C) { tp1 := []*types.FieldType{types.NewFieldType(mysql.TypeLong)} chkSize := 10 - storage1 := NewStorageRC(tp1, chkSize) + storage1 := NewStorageRowContainer(tp1, chkSize) err := storage1.OpenAndRef() c.Assert(err, check.IsNil) inChk1 := chunk.NewChunkWithCapacity(tp1, chkSize) @@ -228,7 +228,7 @@ func (test *StorageRCTestSuite) TestSwapData(c *check.C) { c.Assert(err, check.IsNil) tp2 := []*types.FieldType{types.NewFieldType(mysql.TypeVarString)} - storage2 := NewStorageRC(tp2, chkSize) + storage2 := NewStorageRowContainer(tp2, chkSize) err = storage2.OpenAndRef() c.Assert(err, check.IsNil) diff --git a/util/memory/tracker.go b/util/memory/tracker.go index 3c369724c229e..2525ee76c2e0a 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -491,4 +491,6 @@ const ( LabelForApplyCache int = -17 // LabelForSimpleTask represents the label of the simple task LabelForSimpleTask int = -18 + // LabelForCTEStorage represents the label of CTE storage + LabelForCTEStorage int = -19 )