Skip to content

Commit

Permalink
executor: support semi join (pingcap#57658)
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhangxian1008 authored Dec 4, 2024
1 parent 9812d85 commit 8eebb2d
Show file tree
Hide file tree
Showing 16 changed files with 1,180 additions and 120 deletions.
39 changes: 35 additions & 4 deletions pkg/executor/internal/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ type MockDataSourceParameters struct {
// and he can save provided test data at here.
Datums [][]any

Nulls [][]bool

Rows int
HasSel bool
}
Expand Down Expand Up @@ -80,6 +82,10 @@ func (mds *MockDataSource) GenColDatums(col int) (results []any) {
}
results = make([]any, 0, rows)

// ndv > 0: generate n rows of random values with `ndv` distinct values
// ndv == 0: generate n rows of random values
// ndv == -1: generate n rows from the values provided by the user, with `ndv` distinct values
// ndv == -2: use the rows provided by the user as-is
if ndv == 0 {
if mds.P.GenDataFunc == nil {
for i := 0; i < rows; i++ {
Expand All @@ -90,11 +96,19 @@ func (mds *MockDataSource) GenColDatums(col int) (results []any) {
results = append(results, mds.P.GenDataFunc(i, typ))
}
}
} else if ndv == -2 {
// Use data provided by user
if mds.P.Datums[col] == nil {
panic("need to provide data")
}

results = mds.P.Datums[col]
} else {
// Use ndv base data provided by user
datums := make([]any, 0, max(ndv, 0))
if ndv == -1 {
if mds.P.Datums[col] == nil {
panic("need to provid data")
panic("need to provide data")
}

datums = mds.P.Datums[col]
Expand Down Expand Up @@ -257,8 +271,9 @@ func BuildMockDataSource(opt MockDataSourceParameters) *MockDataSource {
Chunks: nil,
}
rTypes := exec.RetTypes(m)
colData := make([][]any, len(rTypes))
for i := 0; i < len(rTypes); i++ {
colNum := len(rTypes)
colData := make([][]any, colNum)
for i := 0; i < colNum; i++ {
colData[i] = m.GenColDatums(i)
}

Expand All @@ -267,10 +282,26 @@ func BuildMockDataSource(opt MockDataSourceParameters) *MockDataSource {
m.GenData[i] = chunk.NewChunkWithCapacity(exec.RetTypes(m), m.MaxChunkSize())
}

nulls := opt.Nulls
if nulls == nil {
nulls = make([][]bool, colNum)
for i := range colNum {
nulls[i] = make([]bool, m.P.Rows)
for j := range m.P.Rows {
nulls[i][j] = false
}
}
}

for i := 0; i < m.P.Rows; i++ {
idx := i / m.MaxChunkSize()
retTypes := exec.RetTypes(m)
for colIdx := 0; colIdx < len(rTypes); colIdx++ {
for colIdx := 0; colIdx < colNum; colIdx++ {
if nulls[colIdx][i] {
m.GenData[idx].AppendNull(colIdx)
continue
}

switch retTypes[colIdx].GetType() {
case mysql.TypeLong, mysql.TypeLonglong:
m.GenData[idx].AppendInt64(colIdx, colData[colIdx][i].(int64))
Expand Down
3 changes: 3 additions & 0 deletions pkg/executor/join/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go_library(
name = "join",
srcs = [
"base_join_probe.go",
"base_semi_join.go",
"concurrent_map.go",
"hash_join_base.go",
"hash_join_spill.go",
Expand All @@ -24,6 +25,7 @@ go_library(
"merge_join.go",
"outer_join_probe.go",
"row_table_builder.go",
"semi_join_probe.go",
"tagged_ptr.go",
],
importpath = "github.com/pingcap/tidb/pkg/executor/join",
Expand Down Expand Up @@ -93,6 +95,7 @@ go_test(
"outer_join_spill_test.go",
"right_outer_join_probe_test.go",
"row_table_builder_test.go",
"semi_join_probe_test.go",
"tagged_ptr_test.go",
],
embed = [":join"],
Expand Down
26 changes: 25 additions & 1 deletion pkg/executor/join/base_join_probe.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func (j *baseJoinProbe) prepareForProbe(chk *chunk.Chunk) (joinedChk *chunk.Chun
j.nextCachedBuildRowIndex = 0
j.matchedRowsForCurrentProbeRow = 0
joinedChk = chk
if j.ctx.OtherCondition != nil {
if j.ctx.hasOtherCondition() {
j.tmpChk.Reset()
j.rowIndexInfos = j.rowIndexInfos[:0]
j.selected = j.selected[:0]
Expand Down Expand Up @@ -693,6 +693,22 @@ func isKeyMatched(keyMode keyMode, serializedKey []byte, rowStart unsafe.Pointer
}
}

// commonInitForScanRowTable builds the row iterator covering the portion of
// the hash table's rows that belongs to the worker identified by base.workID.
//
// The total row count is divided evenly (integer division) among
// base.ctx.Concurrency workers; the worker with the highest id additionally
// receives whatever remainder is left, and no range ever extends past the
// total row count.
func commonInitForScanRowTable(base *baseJoinProbe) *rowIter {
	rowTotal := base.ctx.hashTableContext.hashTable.totalRowCount()
	workerCnt := base.ctx.Concurrency
	id := uint64(base.workID)

	perWorker := rowTotal / uint64(workerCnt)
	begin := id * perWorker
	end := begin + perWorker

	// The last worker picks up the remainder; any other range is clamped so
	// that it never runs past the end of the table.
	isLastWorker := id == uint64(workerCnt-1)
	if isLastWorker || end > rowTotal {
		end = rowTotal
	}

	return base.ctx.hashTableContext.hashTable.createRowIter(begin, end)
}

// NewJoinProbe create a join probe used for hash join v2
func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType logicalop.JoinType, keyIndex []int, joinedColumnTypes, probeKeyTypes []*types.FieldType, rightAsBuildSide bool) ProbeV2 {
base := baseJoinProbe{
Expand Down Expand Up @@ -747,7 +763,15 @@ func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType logicalop.JoinType,
return newOuterJoinProbe(base, !rightAsBuildSide, rightAsBuildSide)
case logicalop.RightOuterJoin:
return newOuterJoinProbe(base, rightAsBuildSide, rightAsBuildSide)
case logicalop.SemiJoin:
if len(base.rUsed) != 0 {
panic("len(base.rUsed) != 0 for semi join")
}
return newSemiJoinProbe(base, !rightAsBuildSide)
case logicalop.LeftOuterSemiJoin:
if len(base.rUsed) != 0 {
panic("len(base.rUsed) != 0 for left outer semi join")
}
if rightAsBuildSide {
return newLeftOuterSemiJoinProbe(base)
}
Expand Down
131 changes: 131 additions & 0 deletions pkg/executor/join/base_semi_join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package join

import (
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/queue"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
)

// maxMatchedRowNum is only relevant when the join has an other condition.
// During the probe, every time a probe row matches a build row, the two rows
// are put together to generate a new joined row. If one probe row matches n
// build rows, n new rows are generated, and when n is very large this
// produces too many rows. To avoid that, the number of rows generated for one
// probe row in a single pass is limited to maxMatchedRowNum.
// NOTE: Suppose the probe chunk has n rows and
// n*maxMatchedRowNum << chunkRemainingCapacity. Probe rows that have already
// been joined with build rows keep being processed, even though the probe row
// with idx i may already have produced maxMatchedRowNum rows before, so that
// as many rows as possible are processed.
var maxMatchedRowNum = 4

// baseSemiJoin holds the probe state shared by the semi join probe
// implementations. It embeds baseJoinProbe and adds the bookkeeping needed
// to track matched rows and, when an other condition exists, the probe rows
// that still have candidate build rows to examine.
type baseSemiJoin struct {
baseJoinProbe
// isLeftSideBuild is true when the left side is the build side.
isLeftSideBuild bool

// isMatchedRows marks whether the left side row is matched
// It's used only when right side is build side.
isMatchedRows []bool

// isNulls is initialized to an empty slice by newBaseSemiJoin.
// NOTE(review): presumably per-row NULL markers; its use is not visible
// in this file — confirm against the concrete probe implementations.
isNulls []bool

// used when left side is build side
rowIter *rowIter

// used in other condition to record which rows need to be processed
// (a queue of probe row indexes whose candidate build row chain has not
// been fully walked yet).
unFinishedProbeRowIdxQueue *queue.Queue[int]
}

// newBaseSemiJoin wraps base in a freshly allocated baseSemiJoin, records
// which side is the build side and starts with an empty (non-nil) isNulls
// slice. All remaining fields keep their zero values.
func newBaseSemiJoin(base baseJoinProbe, isLeftSideBuild bool) *baseSemiJoin {
	semi := new(baseSemiJoin)
	semi.baseJoinProbe = base
	semi.isLeftSideBuild = isLeftSideBuild
	semi.isNulls = []bool{}
	return semi
}

// resetProbeState prepares the per-chunk probe bookkeeping.
//
// When the right side is the build side, isMatchedRows is reset to hold one
// false entry per row of the current probe chunk. When the join has an other
// condition, the queue of unfinished probe rows is (re)created and filled
// with the index of every probe row whose matched-rows header is non-zero.
func (b *baseSemiJoin) resetProbeState() {
	rowCnt := b.chunkRows

	if !b.isLeftSideBuild {
		// Reuse the existing backing array and mark every row as unmatched.
		flags := b.isMatchedRows[:0]
		for remaining := rowCnt; remaining > 0; remaining-- {
			flags = append(flags, false)
		}
		b.isMatchedRows = flags
	}

	if !b.ctx.hasOtherCondition() {
		return
	}

	if b.unFinishedProbeRowIdxQueue != nil {
		b.unFinishedProbeRowIdxQueue.ClearAndExpandIfNeed(rowCnt)
	} else {
		b.unFinishedProbeRowIdxQueue = queue.NewQueue[int](rowCnt)
	}

	// Only probe rows that have at least one candidate build row are queued.
	for rowIdx := 0; rowIdx < rowCnt; rowIdx++ {
		if b.matchedRowsHeaders[rowIdx] == 0 {
			continue
		}
		b.unFinishedProbeRowIdxQueue.Push(rowIdx)
	}
}

// matchMultiBuildRows walks the chain of candidate build rows of the current
// probe row (b.currentProbeRow) and caches those whose key matches.
//
// The walk stops as soon as the chain is exhausted (header becomes 0),
// *joinedChkRemainCap drops to 0, or maxMatchedRowNum rows have been matched
// for the current probe row. When the right side is not the build side,
// candidate rows already flagged as used are skipped without a key
// comparison. Each matched row decrements *joinedChkRemainCap; each candidate
// whose key does not match is counted in b.probeCollision.
// finishLookupCurrentProbeRow is always called once the walk ends.
func (b *baseSemiJoin) matchMultiBuildRows(joinedChk *chunk.Chunk, joinedChkRemainCap *int, isRightSideBuild bool) {
	tagHelper := b.ctx.hashTableContext.tagHelper
	meta := b.ctx.hashTableMeta

	for {
		chainExhausted := b.matchedRowsHeaders[b.currentProbeRow] == 0
		if chainExhausted || *joinedChkRemainCap <= 0 || b.matchedRowsForCurrentProbeRow >= maxMatchedRowNum {
			break
		}

		buildRow := tagHelper.toUnsafePointer(b.matchedRowsHeaders[b.currentProbeRow])
		switch {
		case !isRightSideBuild && meta.isCurrentRowUsedWithAtomic(buildRow):
			// Already marked as used: nothing to do for this candidate.
		case isKeyMatched(meta.keyMode, b.serializedKeys[b.currentProbeRow], buildRow, meta):
			b.appendBuildRowToCachedBuildRowsV1(b.currentProbeRow, buildRow, joinedChk, 0, true)
			b.matchedRowsForCurrentProbeRow++
			*joinedChkRemainCap--
		default:
			b.probeCollision++
		}

		// Move on to the next candidate in the chain.
		b.matchedRowsHeaders[b.currentProbeRow] = getNextRowAddress(buildRow, tagHelper, b.matchedRowsHashValue[b.currentProbeRow])
	}

	b.finishLookupCurrentProbeRow()
}

// concatenateProbeAndBuildRows drains the queue of unfinished probe rows,
// matching each against its candidate build rows until either the queue is
// empty or joinedChk has no remaining capacity.
//
// When the right side is the build side, probe rows already recorded in
// isMatchedRows are dropped from the queue. A probe row that still has
// candidate build rows left after one matching pass is pushed back onto the
// queue. After the loop the SQL killer is checked; on a kill signal its error
// is returned, otherwise finishCurrentLookupLoop is invoked and nil returned.
func (b *baseSemiJoin) concatenateProbeAndBuildRows(joinedChk *chunk.Chunk, sqlKiller *sqlkiller.SQLKiller, isRightSideBuild bool) error {
	remainCap := joinedChk.Capacity()

	for remainCap > 0 && !b.unFinishedProbeRowIdxQueue.IsEmpty() {
		rowIdx := b.unFinishedProbeRowIdxQueue.Pop()

		alreadyMatched := isRightSideBuild && b.isMatchedRows[rowIdx]
		if alreadyMatched {
			continue
		}

		b.currentProbeRow = rowIdx
		b.matchMultiBuildRows(joinedChk, &remainCap, isRightSideBuild)

		// Re-queue the row when its candidate chain has not been fully walked.
		if b.matchedRowsHeaders[rowIdx] != 0 {
			b.unFinishedProbeRowIdxQueue.Push(rowIdx)
		}
	}

	if err := checkSQLKiller(sqlKiller, "killedDuringProbe"); err != nil {
		return err
	}

	b.finishCurrentLookupLoop(joinedChk)
	return nil
}
6 changes: 5 additions & 1 deletion pkg/executor/join/inner_join_probe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,11 @@ func testJoinProbe(t *testing.T, withSel bool, leftKeyIndex []int, rightKeyIndex
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case logicalop.LeftOuterSemiJoin:
expectedChunks := genLeftOuterSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, rightUsed, otherCondition, resultTypes)
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
case logicalop.SemiJoin:
expectedChunks := genSemiJoinResult(t, hashJoinCtx.SessCtx, leftFilter, leftChunks, rightChunks, leftKeyIndex, rightKeyIndex, leftTypes,
rightTypes, leftKeyTypes, rightKeyTypes, leftUsed, otherCondition, resultTypes)
checkChunksEqual(t, expectedChunks, resultChunks, resultTypes)
default:
require.NoError(t, errors.New("not supported join type"))
Expand Down
6 changes: 6 additions & 0 deletions pkg/executor/join/join_table_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ func (*joinTableMeta) isCurrentRowUsed(rowStart unsafe.Pointer) bool {
return (*(*uint32)(unsafe.Add(rowStart, sizeOfNextPtr)) & usedFlagMask) == usedFlagMask
}

// isCurrentRowUsedWithAtomic reports whether the used flag is set for the row
// starting at rowStart. It atomically loads the uint32 located sizeOfNextPtr
// bytes after rowStart and checks that every bit of usedFlagMask is set.
func (*joinTableMeta) isCurrentRowUsedWithAtomic(rowStart unsafe.Pointer) bool {
	flagWord := atomic.LoadUint32((*uint32)(unsafe.Add(rowStart, sizeOfNextPtr)))
	return flagWord&usedFlagMask == usedFlagMask
}

type keyProp struct {
canBeInlined bool
keyLength int
Expand Down
Loading

0 comments on commit 8eebb2d

Please sign in to comment.