Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, statistics: maintain histogram for inner join #8097

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions planner/core/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,31 @@ func (s *testAnalyzeSuite) TestInconsistentEstimation(c *C) {
))
}

// TestJoinWithHistogram checks that, with tidb_optimizer_selectivity_level >= 1,
// the planner estimates the inner join row count from the merged join-key
// histograms: t and tt share 5 key values on column a (1,3,5,7,9), and the
// expected plan carries that 5.00 estimate on the IndexJoin.
func (s *testAnalyzeSuite) TestJoinWithHistogram(c *C) {
	defer testleak.AfterTest(c)()
	store, dom, err := newStoreWithBootstrap()
	c.Assert(err, IsNil)
	tk := testkit.NewTestKit(c, store)
	defer func() {
		dom.Close()
		store.Close()
	}()
	tk.MustExec("use test")
	tk.MustExec("create table t(a int primary key, b int, index idx(b))")
	tk.MustExec("create table tt(a int primary key, b int, index idx(b))")
	tk.MustExec("insert into t values(1, 1), (2, 1), (3, 1), (4, 2), (5, 2), (6, 2), (7, 3), (8, 4), (9, 5)")
	tk.MustExec("insert into tt values(1, 1), (3, 1), (5, 1), (7, 2), (9, 3), (15, 4)")
	// Analyze both tables so real histograms exist for the optimizer to merge.
	tk.MustExec("analyze table t, tt")
	// Selectivity level >= 1 enables the histogram-based inner join estimation path.
	tk.MustExec("set @@session.tidb_optimizer_selectivity_level=1")
	tk.MustQuery("explain select * from t t1 join tt t2 where t1.a=t2.a").Check(testkit.Rows(
		"IndexJoin_14 5.00 root inner join, inner:TableReader_13, outer key:t2.a, inner key:t1.a",
		"├─TableReader_13 1.00 root data:TableScan_12",
		"│ └─TableScan_12 1.00 cop table:t1, range: decided by [t2.a], keep order:false",
		"└─TableReader_24 6.00 root data:TableScan_23",
		" └─TableScan_23 6.00 cop table:t2, range:[-inf,+inf], keep order:false",
	))
}

func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) {
store, err := mockstore.NewMockTikvStore()
if err != nil {
Expand Down
92 changes: 92 additions & 0 deletions planner/core/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/statistics"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -278,6 +279,9 @@ func (p *LogicalJoin) deriveStats() (*property.StatsInfo, error) {
leftKeys = append(leftKeys, eqCond.GetArgs()[0].(*expression.Column))
rightKeys = append(rightKeys, eqCond.GetArgs()[1].(*expression.Column))
}
if p.JoinType == InnerJoin && p.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 {
return p.deriveInnerJoinStatsWithHist(leftKeys, rightKeys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we pass the childStats parameter down to deriveInnerJoinStatsWithHist? so this DeriveStats function can be used by cascades planner as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, using childStats in new commits. But it uses the schema of join's children inside this method. Seems that i need some way to not rely on it.

}
leftKeyCardinality := getCardinality(leftKeys, p.children[0].Schema(), leftProfile)
rightKeyCardinality := getCardinality(rightKeys, p.children[1].Schema(), rightProfile)
count := leftProfile.RowCount * rightProfile.RowCount / math.Max(leftKeyCardinality, rightKeyCardinality)
Expand All @@ -299,6 +303,94 @@ func (p *LogicalJoin) deriveStats() (*property.StatsInfo, error) {
return p.stats, nil
}

func (p *LogicalJoin) deriveInnerJoinStatsWithHist(leftKeys, rightKeys []*expression.Column) (*property.StatsInfo, error) {
leftChild, rightChild := p.children[0], p.children[1]
leftProfile, rightProfile := leftChild.statsInfo(), rightChild.statsInfo()

cardinality := make([]float64, 0, p.schema.Len())
cardinality = append(cardinality, leftProfile.Cardinality...)
cardinality = append(cardinality, rightProfile.Cardinality...)

ndv, leftNdv, rightNdv := float64(1), float64(1), float64(1)
idxHistID := int64(0)
newColID2Hist := make(map[int64]*statistics.Column)
newIdxID2Hist := make(map[int64]*statistics.Index)
leftIndexReLabel := make(map[int64]int64)
rightIndexReLabel := make(map[int64]int64)
newIdx2ColumnIDs := make(map[int64][]int64)

// TODO: Support using index histogram to calculate the NDV after join and the final row count.
for i := range leftKeys {
leftHist, ok1 := leftChild.statsInfo().HistColl.Columns[leftKeys[i].UniqueID]
rightHist, ok2 := rightChild.statsInfo().HistColl.Columns[rightKeys[i].UniqueID]
lPos := leftChild.Schema().ColumnIndex(leftKeys[i])
rPos := rightChild.Schema().ColumnIndex(rightKeys[i])
leftNdv *= leftProfile.Cardinality[lPos]
rightNdv *= rightProfile.Cardinality[rPos]
if ok1 && ok2 {
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
newHist := statistics.MergeHistogramForInnerJoin(&leftHist.Histogram, &rightHist.Histogram, leftKeys[i].RetType)
leftCol := &statistics.Column{Info: leftHist.Info, Histogram: *newHist}
rightCol := &statistics.Column{Info: rightHist.Info, Histogram: *newHist}
ndv *= float64(newHist.NDV)
lPosNew := p.schema.ColumnIndex(leftKeys[i])
rPosNew := p.schema.ColumnIndex(rightKeys[i])
cardinality[lPosNew] = ndv
cardinality[rPosNew] = ndv
winoros marked this conversation as resolved.
Show resolved Hide resolved
newColID2Hist[leftKeys[i].UniqueID] = leftCol
newColID2Hist[rightKeys[i].UniqueID] = rightCol
continue
}
keyNdv := math.Min(leftChild.statsInfo().Cardinality[lPos], rightChild.statsInfo().Cardinality[rPos])
ndv *= keyNdv
}
count := leftProfile.RowCount / leftNdv * rightProfile.RowCount / rightNdv * ndv

for uniqID, colHist := range leftProfile.HistColl.Columns {
_, ok := newColID2Hist[uniqID]
if ok {
continue
}
newColID2Hist[uniqID] = colHist
winoros marked this conversation as resolved.
Show resolved Hide resolved
}

for oldID, idxHist := range leftProfile.HistColl.Indices {
newIdxID2Hist[idxHistID] = idxHist
leftIndexReLabel[oldID] = idxHistID
newIdx2ColumnIDs[idxHistID] = leftProfile.HistColl.Idx2ColumnIDs[oldID]
idxHistID++
}

for oldID, idxHist := range rightProfile.HistColl.Indices {
newIdxID2Hist[idxHistID] = idxHist
leftIndexReLabel[oldID] = idxHistID
newIdx2ColumnIDs[idxHistID] = rightProfile.HistColl.Idx2ColumnIDs[oldID]
idxHistID++
}
eurekaka marked this conversation as resolved.
Show resolved Hide resolved

newColID2IdxID := make(map[int64]int64)

for colID, oldIdxID := range leftProfile.HistColl.ColID2IdxID {
newColID2IdxID[colID] = leftIndexReLabel[oldIdxID]
}
for colID, oldIdxID := range rightProfile.HistColl.ColID2IdxID {
newColID2IdxID[colID] = rightIndexReLabel[oldIdxID]
}

newHistColl := statistics.HistColl{
Count: int64(count),
Columns: newColID2Hist,
Indices: newIdxID2Hist,
Idx2ColumnIDs: newIdx2ColumnIDs,
ColID2IdxID: newColID2IdxID,
}
p.stats = &property.StatsInfo{
RowCount: count,
Cardinality: cardinality,
HistColl: newHistColl,
}
return p.stats, nil
}

func (la *LogicalApply) deriveStats() (*property.StatsInfo, error) {
leftProfile, err := la.children[0].deriveStats()
if err != nil {
Expand Down
152 changes: 152 additions & 0 deletions statistics/histogram.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/util/sqlexec"
"github.com/pingcap/tipb/go-tipb"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
)

Expand Down Expand Up @@ -465,6 +466,10 @@ func (hg *Histogram) totalRowCount() float64 {
return float64(hg.Buckets[hg.Len()-1].Count + hg.NullCount)
}

// notNullRowCount returns the number of rows covered by the histogram
// excluding NULLs, i.e. the cumulative count of the last bucket (unlike
// totalRowCount, which also adds hg.NullCount).
// NOTE(review): assumes at least one bucket exists, as totalRowCount does —
// an empty histogram would panic here; confirm callers guarantee that.
func (hg *Histogram) notNullRowCount() float64 {
	return float64(hg.Buckets[hg.Len()-1].Count)
}

// mergeBuckets is used to merge every two neighbor buckets.
func (hg *Histogram) mergeBuckets(bucketIdx int) {
curBuck := 0
Expand Down Expand Up @@ -686,6 +691,153 @@ func (hg *Histogram) outOfRange(val types.Datum) bool {
chunk.Compare(hg.Bounds.GetRow(hg.Bounds.NumRows()-1), 0, &val) < 0
}

// MergeHistogramForInnerJoin merges two histogram into one for inner join.
func MergeHistogramForInnerJoin(lSide *Histogram, rSide *Histogram, tp *types.FieldType) *Histogram {
alivxxx marked this conversation as resolved.
Show resolved Hide resolved
cmpFunc := chunk.GetCompareFunc(tp)
calcOverlapFunc := getOverlapCalculateFunc(tp)
lAvgPerVal := lSide.notNullRowCount() / float64(lSide.NDV)
rAvgPerVal := rSide.notNullRowCount() / float64(rSide.NDV)
logrus.Warnf("left avg: %v, right avg: %v", lAvgPerVal, rAvgPerVal)
winoros marked this conversation as resolved.
Show resolved Hide resolved
totCnt := int64(0)
totNdv := float64(0)
newHist := NewHistogram(0, 0, 0, 0, tp, 256, 0)
winoros marked this conversation as resolved.
Show resolved Hide resolved
for i, j := 0, 0; i < lSide.Bounds.NumRows() && j < rSide.Bounds.NumRows(); {
winoros marked this conversation as resolved.
Show resolved Hide resolved
lLow := lSide.Bounds.GetRow(i)
lHigh := lSide.Bounds.GetRow(i + 1)
rLow := rSide.Bounds.GetRow(j)
rHigh := rSide.Bounds.GetRow(j + 1)
// If [lLow, lHigh] is totally behind [rLow, rHigh], move r point.
if cmpFunc(lLow, 0, rHigh, 0) > 0 {
j += 2
continue
}
// If [rLow, rHigh] is totally behind [lLow, lHigh], move l point.
if cmpFunc(rLow, 0, lHigh, 0) > 0 {
i += 2
continue
}
var overlapLow, overLapHigh chunk.Row
if cmpFunc(lLow, 0, rLow, 0) > 0 {
overlapLow = lLow
} else {
overlapLow = rLow
}
if cmpFunc(lHigh, 0, rHigh, 0) > 0 {
overLapHigh = rHigh
} else {
overLapHigh = lHigh
}
// Calculate overlap ratio.
leftOverLap := calcOverlapFunc(lLow, lHigh, overlapLow, overLapHigh)
rightOverLap := calcOverlapFunc(rLow, rHigh, overlapLow, overLapHigh)
lCount := float64(lSide.bucketCount(i/2)) * leftOverLap
rCount := float64(rSide.bucketCount(j/2)) * rightOverLap
lNdv := lCount / lAvgPerVal
rNdv := rCount / rAvgPerVal
// bucketCount is lCount/lNdv * rCount/rNdv * finalNdv, where finalNdv is min(lNdv, rNdv).
bucketCount := lCount * rCount / math.Max(lNdv, rNdv)
winoros marked this conversation as resolved.
Show resolved Hide resolved
// Update histogram.
totCnt += int64(bucketCount)
newHist.Bounds.AppendRow(overlapLow)
newHist.Bounds.AppendRow(overLapHigh)
newHist.Buckets = append(newHist.Buckets, Bucket{Count: totCnt, Repeat: int64(bucketCount / math.Min(lNdv, rNdv))})
totNdv += math.Min(lNdv, rNdv)
// Move i and j by compare result.
switch cmpFunc(lHigh, 0, rHigh, 0) {
case -1:
i += 2
case 0:
i += 2
j += 2
case 1:
j += 2
}
}
newHist.NDV = int64(totNdv)
return newHist
}

// getOverlapCalculateFunc returns the function that computes which fraction of
// a bucket [l, r] is covered by the overlap range [lInner, rInner], specialized
// by the scalar representation of the field type.
func getOverlapCalculateFunc(tp *types.FieldType) func(l, r, lInner, rInner chunk.Row) float64 {
	switch tp.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		if mysql.HasUnsignedFlag(tp.Flag) {
			return calculateUintOverlap
		}
		return calculateIntOverlap
	case mysql.TypeFloat:
		return calculateFloat32Overlap
	case mysql.TypeDouble:
		return calculateFloat64Overlap
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		return calculateStringOverlap
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		return calculateTimeOverlap
	case mysql.TypeDuration:
		return calculateDurationOverlap
	case mysql.TypeNewDecimal:
		return calculateDecimalOverlap

	// Types with no usable linear ordering fall through to the default below.
	// case mysql.TypeSet, mysql.TypeEnum, mysql.TypeBit, mysql.TypeJSON:
	}
	return calculateOverlapDefault
}

// calculateOverlapDefault is the fallback estimator for types that have no
// scalar representation: it simply assumes half of the bucket overlaps.
func calculateOverlapDefault(_, _, _, _ chunk.Row) float64 {
	return 0.5
}

// calculateIntOverlap returns the fraction of the signed-integer bucket
// [left, right] covered by the overlap range [lInner, rInner]. Both ranges are
// closed, hence the +1 on each width.
func calculateIntOverlap(left, right, lInner, rInner chunk.Row) float64 {
	overlapWidth := rInner.GetInt64(0) - lInner.GetInt64(0) + 1
	bucketWidth := right.GetInt64(0) - left.GetInt64(0) + 1
	return float64(overlapWidth) / float64(bucketWidth)
}

// calculateUintOverlap returns the fraction of the unsigned-integer bucket
// [left, right] covered by the overlap range [lInner, rInner]. Both ranges are
// closed, hence the +1 on each width.
func calculateUintOverlap(left, right, lInner, rInner chunk.Row) float64 {
	overlapWidth := rInner.GetUint64(0) - lInner.GetUint64(0) + 1
	bucketWidth := right.GetUint64(0) - left.GetUint64(0) + 1
	return float64(overlapWidth) / float64(bucketWidth)
}

func calculateFloat32Overlap(left, right, lInner, rInner chunk.Row) float64 {
return float64(rInner.GetFloat32(0)-lInner.GetFloat32(0)) / float64(right.GetFloat32(0)-left.GetFloat32(0))
winoros marked this conversation as resolved.
Show resolved Hide resolved
}

// calculateFloat64Overlap returns the fraction of the float64 bucket
// [left, right] covered by the overlap range [lInner, rInner].
// The original expression `a - b/c - d` lacked parentheses and computed a
// meaningless value instead of (a-b)/(c-d); this is the corrected ratio, with
// the same zero-width guard as calculateFloat32Overlap.
func calculateFloat64Overlap(left, right, lInner, rInner chunk.Row) float64 {
	bucketWidth := right.GetFloat64(0) - left.GetFloat64(0)
	if bucketWidth == 0 {
		return 1
	}
	return (rInner.GetFloat64(0) - lInner.GetFloat64(0)) / bucketWidth
}

func calculateStringOverlap(left, right, lInner, rInner chunk.Row) float64 {
common := commonPrefixLength(left.GetBytes(0), right.GetBytes(0))
return (convertBytesToScalar(right.GetBytes(0)[common:]) - convertBytesToScalar(left.GetBytes(0)[common:])) / (convertBytesToScalar(rInner.GetBytes(0)[common:]) - convertBytesToScalar(lInner.GetBytes(0)[common:]))
winoros marked this conversation as resolved.
Show resolved Hide resolved
}

// calculateTimeOverlap returns the fraction of the time-typed bucket
// [left, right] covered by the overlap range [lInner, rInner].
// GetTime returns values; locals are taken so their addresses can be passed
// to convertTimeToScalar.
func calculateTimeOverlap(left, right, lInner, rInner chunk.Row) float64 {
	bucketLow, bucketHigh := left.GetTime(0), right.GetTime(0)
	overlapLow, overlapHigh := lInner.GetTime(0), rInner.GetTime(0)
	overlapWidth := convertTimeToScalar(&overlapHigh) - convertTimeToScalar(&overlapLow)
	bucketWidth := convertTimeToScalar(&bucketHigh) - convertTimeToScalar(&bucketLow)
	return overlapWidth / bucketWidth
}

// calculateDurationOverlap returns the fraction of the duration bucket
// [left, right] covered by the overlap range [lInner, rInner], comparing the
// underlying time.Duration values.
func calculateDurationOverlap(left, right, lInner, rInner chunk.Row) float64 {
	overlapWidth := rInner.GetDuration(0, 0).Duration - lInner.GetDuration(0, 0).Duration
	bucketWidth := right.GetDuration(0, 0).Duration - left.GetDuration(0, 0).Duration
	return float64(overlapWidth) / float64(bucketWidth)
}

// calculateDecimalOverlap returns the fraction of the decimal bucket
// [left, right] covered by the overlap range [lInner, rInner]. Any bound whose
// decimal cannot be converted to float64 makes the estimate fall back to 0,
// checked in the same order as before: left, right, lInner, rInner.
func calculateDecimalOverlap(left, right, lInner, rInner chunk.Row) float64 {
	toFloat := func(row chunk.Row) (float64, bool) {
		v, err := row.GetMyDecimal(0).ToFloat64()
		return v, err == nil
	}
	bucketLow, ok := toFloat(left)
	if !ok {
		return 0
	}
	bucketHigh, ok := toFloat(right)
	if !ok {
		return 0
	}
	overlapLow, ok := toFloat(lInner)
	if !ok {
		return 0
	}
	overlapHigh, ok := toFloat(rInner)
	if !ok {
		return 0
	}
	return (overlapHigh - overlapLow) / (bucketHigh - bucketLow)
}

// ErrorRate is the error rate of estimate row count by bucket and cm sketch.
type ErrorRate struct {
ErrorTotal float64
Expand Down
61 changes: 61 additions & 0 deletions statistics/histogram_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package statistics

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
)

var _ = Suite(&HistogramTestSuite{})

// HistogramTestSuite groups the gocheck unit tests for histogram operations.
type HistogramTestSuite struct {
}

// TestMergeHistogramForInnerJoinIntCase merges two hand-built integer
// histograms and pins the merged result: the output buckets are the pairwise
// overlap ranges ([100,120], [130,160], [180,200], [210,220]), and the
// expected counts/NDV follow from the lCount/lNdv * rCount/rNdv * min-NDV
// formula applied per overlap.
func (s *HistogramTestSuite) TestMergeHistogramForInnerJoinIntCase(c *C) {
	intTp := types.NewFieldType(mysql.TypeLonglong)
	// aHist: 60 distinct value, each value repeats 2 times.
	aHist := NewHistogram(0, 60, 0, 0, intTp, chunk.InitialCapacity, 0)
	// [100, 200]
	aHist.Bounds.AppendInt64(0, 100)
	aHist.Bounds.AppendInt64(0, 200)
	aHist.Buckets = append(aHist.Buckets, Bucket{Repeat: 2, Count: 100})
	// [210, 230]
	aHist.Bounds.AppendInt64(0, 210)
	aHist.Bounds.AppendInt64(0, 230)
	aHist.Buckets = append(aHist.Buckets, Bucket{Repeat: 2, Count: 120})
	// bHist: 100 distinct value, each value repeats 100 times.
	bHist := NewHistogram(0, 100, 0, 0, intTp, chunk.InitialCapacity, 0)
	// [90, 120]
	bHist.Bounds.AppendInt64(0, 90)
	bHist.Bounds.AppendInt64(0, 120)
	bHist.Buckets = append(bHist.Buckets, Bucket{Repeat: 100, Count: 3000})
	// [130, 160]
	bHist.Bounds.AppendInt64(0, 130)
	bHist.Bounds.AppendInt64(0, 160)
	bHist.Buckets = append(bHist.Buckets, Bucket{Repeat: 100, Count: 6000})
	// [180, 220]
	bHist.Bounds.AppendInt64(0, 180)
	bHist.Bounds.AppendInt64(0, 220)
	bHist.Buckets = append(bHist.Buckets, Bucket{Repeat: 100, Count: 10000})
	finalHist := MergeHistogramForInnerJoin(aHist, bHist, intTp)

	// The exact numbers below were produced by the merge algorithm itself;
	// they serve as a regression pin rather than an independently derived oracle.
	c.Assert(finalHist.ToString(0), Equals, `column:0 ndv:41 totColSize:0
num: 2079 lower_bound: 100 upper_bound: 120 repeats: 200
num: 3069 lower_bound: 130 upper_bound: 160 repeats: 200
num: 2079 lower_bound: 180 upper_bound: 200 repeats: 200
num: 1047 lower_bound: 210 upper_bound: 220 repeats: 200`)
}
Loading