diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index e42227d52d6..eef6382227e 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -1966,14 +1966,17 @@ type ( // ExtractedSubquery is a subquery that has been extracted from the original AST // This is a struct that the parser will never produce - it's written and read by the gen4 planner + // CAUTION: you should only change argName and hasValuesArg through the setter methods ExtractedSubquery struct { Original Expr // original expression that was replaced by this ExtractedSubquery - ArgName string - HasValuesArg string - OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :( + OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :( Subquery *Subquery OtherSide Expr // represents the side of the comparison, this field will be nil if Original is not a comparison NeedsRewrite bool // tells whether we need to rewrite this subquery to Original or not + + hasValuesArg string + argName string + alternative Expr // this is what will be used to Format this struct } ) diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 96060511316..0cd891306c8 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -864,6 +864,7 @@ func CloneRefOfExtractedSubquery(n *ExtractedSubquery) *ExtractedSubquery { out.Original = CloneExpr(n.Original) out.Subquery = CloneRefOfSubquery(n.Subquery) out.OtherSide = CloneExpr(n.OtherSide) + out.alternative = CloneExpr(n.alternative) return &out } diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 20794604648..0c2d7d2b4f9 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1563,13 +1563,14 @@ func EqualsRefOfExtractedSubquery(a, b *ExtractedSubquery) bool { if a == nil || b == nil { return false } - return a.ArgName == b.ArgName && - a.HasValuesArg == b.HasValuesArg && - a.OpCode == b.OpCode && + return a.OpCode == b.OpCode && a.NeedsRewrite == b.NeedsRewrite && + a.hasValuesArg == b.hasValuesArg && + a.argName == b.argName && EqualsExpr(a.Original, b.Original) && EqualsRefOfSubquery(a.Subquery, b.Subquery) && - EqualsExpr(a.OtherSide, b.OtherSide) + EqualsExpr(a.OtherSide, b.OtherSide) && + EqualsExpr(a.alternative, b.alternative) } // EqualsRefOfFlush does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index ffcfd62944f..a65f65e91ef 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1726,31 +1726,5 @@ func (node *RenameTable) Format(buf *TrackedBuffer) { // it will be formatted as if the subquery has been extracted, and instead // show up like argument comparisons func (node *ExtractedSubquery) Format(buf *TrackedBuffer) { - switch original := node.Original.(type) { - case *ExistsExpr: - buf.astPrintf(node, "%v", NewArgument(node.ArgName)) - case *ComparisonExpr: - // other_side = :__sq - cmp := &ComparisonExpr{ - Left: node.OtherSide, - Right: NewArgument(node.ArgName), - Operator: original.Operator, - } - var expr Expr = cmp - switch original.Operator { - case InOp: - // :__sq_has_values = 1 and other_side in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} - expr = AndExpressions(hasValue, cmp) - case NotInOp: - // :__sq_has_values = 0 or other_side not in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} - expr = &OrExpr{hasValue, cmp} - } - buf.astPrintf(node, "%v", expr) - case *Subquery: - buf.astPrintf(node, "%v", NewArgument(node.ArgName)) - } + node.alternative.Format(buf) } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 7753fe520cd..22cdef2f4a9 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -2256,31 +2256,5 @@ func (node *RenameTable) formatFast(buf *TrackedBuffer) { // it will be formatted as if the subquery has been extracted, and instead // show up like argument comparisons func (node *ExtractedSubquery) formatFast(buf *TrackedBuffer) { - switch original := node.Original.(type) { - case *ExistsExpr: - buf.printExpr(node, NewArgument(node.ArgName), true) - case *ComparisonExpr: - // other_side = :__sq - cmp := &ComparisonExpr{ - Left: node.OtherSide, - Right: NewArgument(node.ArgName), - Operator: original.Operator, - } - var expr Expr = cmp - switch original.Operator { - case InOp: - // :__sq_has_values = 1 and other_side in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} - expr = AndExpressions(hasValue, cmp) - case NotInOp: - // :__sq_has_values = 0 or other_side not in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} - expr = &OrExpr{hasValue, cmp} - } - buf.printExpr(node, expr, true) - case *Subquery: - buf.printExpr(node, NewArgument(node.ArgName), true) - } + node.alternative.Format(buf) } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 596012d9243..bc1f1e22eda 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1455,3 +1455,56 @@ func GetAllSelects(selStmt SelectStatement) []*Select { } panic("[BUG]: unknown type for SelectStatement") } + +// SetArgName sets argument name. +func (es *ExtractedSubquery) SetArgName(n string) { + es.argName = n + es.updateAlternative() +} + +// SetHasValuesArg sets has_values argument. +func (es *ExtractedSubquery) SetHasValuesArg(n string) { + es.hasValuesArg = n + es.updateAlternative() +} + +// GetArgName returns argument name. +func (es *ExtractedSubquery) GetArgName() string { + return es.argName +} + +// GetHasValuesArg returns has values argument. +func (es *ExtractedSubquery) GetHasValuesArg() string { + return es.hasValuesArg + +} + +func (es *ExtractedSubquery) updateAlternative() { + switch original := es.Original.(type) { + case *ExistsExpr: + es.alternative = NewArgument(es.argName) + case *Subquery: + es.alternative = NewArgument(es.argName) + case *ComparisonExpr: + // other_side = :__sq + cmp := &ComparisonExpr{ + Left: es.OtherSide, + Right: NewArgument(es.argName), + Operator: original.Operator, + } + var expr Expr = cmp + switch original.Operator { + case InOp: + // :__sq_has_values = 1 and other_side in ::__sq + cmp.Right = NewListArg(es.argName) + hasValue := &ComparisonExpr{Left: NewArgument(es.hasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} + expr = AndExpressions(hasValue, cmp) + case NotInOp: + // :__sq_has_values = 0 or other_side not in ::__sq + cmp.Right = NewListArg(es.argName) + hasValue := &ComparisonExpr{Left: NewArgument(es.hasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} + expr = &OrExpr{hasValue, cmp} + } + es.alternative = expr + } +} diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 747d7bfb6f2..11fdb7691c8 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1916,6 +1916,11 @@ func (a *application) rewriteRefOfExtractedSubquery(parent SQLNode, node *Extrac }) { return false } + if !a.rewriteExpr(node, node.alternative, func(newNode, parent SQLNode) { + parent.(*ExtractedSubquery).alternative = newNode.(Expr) + }) { + return false + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index c75bc920069..8511e40c1fe 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1045,6 +1045,9 @@ func VisitRefOfExtractedSubquery(in *ExtractedSubquery, f Visit) error { if err := VisitExpr(in.OtherSide, f); err != nil { return err } + if err := VisitExpr(in.alternative, f); err != nil { + return err + } return nil } func VisitRefOfFlush(in *Flush, f Visit) error { diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 7c4f311a396..0b5b2d72268 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -934,22 +934,26 @@ func (cached *ExtractedSubquery) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(96) + size += int64(112) } // field Original vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Original.(cachedObject); ok { size += cc.CachedSize(true) } - // field ArgName string - size += hack.RuntimeAllocSize(int64(len(cached.ArgName))) - // field HasValuesArg string - size += hack.RuntimeAllocSize(int64(len(cached.HasValuesArg))) // field Subquery *vitess.io/vitess/go/vt/sqlparser.Subquery size += cached.Subquery.CachedSize(true) // field OtherSide vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.OtherSide.(cachedObject); ok { size += cc.CachedSize(true) } + // field hasValuesArg string + size += hack.RuntimeAllocSize(int64(len(cached.hasValuesArg))) + // field argName string + size += hack.RuntimeAllocSize(int64(len(cached.argName))) + // field alternative vitess.io/vitess/go/vt/sqlparser.Expr + if cc, ok := cached.alternative.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *Flush) CachedSize(alloc bool) int64 { diff --git a/go/vt/sqlparser/precedence.go b/go/vt/sqlparser/precedence.go index b6e7397a6a0..e1e6cce9c42 100644 --- a/go/vt/sqlparser/precedence.go +++ b/go/vt/sqlparser/precedence.go @@ -91,7 +91,7 @@ func precedenceFor(in Expr) Precendence { case *IntervalExpr: return P1 case *ExtractedSubquery: - return precedenceFor(node.Original) + return precedenceFor(node.alternative) } return Syntactic diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index d6cbb8a845a..cbc481bb4d3 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -72,22 +72,43 @@ func TestNotInSubqueryPrecedence(t *testing.T) { not := tree.(*Select).Where.Expr.(*NotExpr) cmp := not.Expr.(*ComparisonExpr) subq := cmp.Right.(*Subquery) - fmt.Println(subq) extracted := &ExtractedSubquery{ - Original: cmp, - ArgName: "arg1", - HasValuesArg: "has_values1", - OpCode: 1, - Subquery: subq, - OtherSide: cmp.Left, + Original: cmp, + OpCode: 1, + Subquery: subq, + OtherSide: cmp.Left, } - not.Expr = extracted + extracted.SetArgName("arg1") + extracted.SetHasValuesArg("has_values1") + not.Expr = extracted output := readable(not) assert.Equal(t, "not (:has_values1 = 1 and id in ::arg1)", output) } +func TestSubqueryPrecedence(t *testing.T) { + tree, err := Parse("select * from a where id in (select 42) and false") + require.NoError(t, err) + where := tree.(*Select).Where + andExpr := where.Expr.(*AndExpr) + cmp := andExpr.Left.(*ComparisonExpr) + subq := cmp.Right.(*Subquery) + + extracted := &ExtractedSubquery{ + Original: andExpr.Left, + OpCode: 1, + Subquery: subq, + OtherSide: cmp.Left, + } + extracted.SetArgName("arg1") + extracted.SetHasValuesArg("has_values1") + + andExpr.Left = extracted + output := readable(extracted) + assert.Equal(t, ":has_values1 = 1 and id in ::arg1", output) +} + func TestPlusStarPrecedence(t *testing.T) { validSQL := []struct { input string @@ -193,7 +214,7 @@ func TestRandom(t *testing.T) { // The purpose of this test is to find discrepancies between Format and parsing. If for example our precedence rules are not consistent between the two, this test should find it. // The idea is to generate random queries, and pass them through the parser and then the unparser, and one more time. The result of the first unparse should be the same as the second result. seed := time.Now().UnixNano() - fmt.Println(fmt.Sprintf("seed is %d", seed)) //nolint + fmt.Println(fmt.Sprintf("seed is %d", seed)) // nolint g := newGenerator(seed, 5) endBy := time.Now().Add(1 * time.Second) diff --git a/go/vt/vtgate/planbuilder/abstract/concatenate.go b/go/vt/vtgate/planbuilder/abstract/concatenate.go index 0cd3b231c72..f522c4f4b0f 100644 --- a/go/vt/vtgate/planbuilder/abstract/concatenate.go +++ b/go/vt/vtgate/planbuilder/abstract/concatenate.go @@ -50,7 +50,7 @@ func (c *Concatenate) PushPredicate(sqlparser.Expr, *semantics.SemTable) error { // UnsolvedPredicates implements the Operator interface func (c *Concatenate) UnsolvedPredicates(*semantics.SemTable) []sqlparser.Expr { - panic("implement me") + return nil } // CheckValid implements the Operator interface diff --git a/go/vt/vtgate/planbuilder/abstract/operator.go b/go/vt/vtgate/planbuilder/abstract/operator.go index 001459ac039..4894b96c9b7 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator.go +++ b/go/vt/vtgate/planbuilder/abstract/operator.go @@ -190,17 +190,13 @@ func createOperatorFromSelect(sel *sqlparser.Select, semTable *semantics.SemTabl if len(semTable.SubqueryMap[sel]) > 0 { resultantOp = &SubQuery{} for _, sq := range semTable.SubqueryMap[sel] { - subquerySelectStatement, isSel := sq.Subquery.Select.(*sqlparser.Select) - if !isSel { - return nil, semantics.Gen4NotSupportedF("UNION in subquery") - } - opInner, err := createOperatorFromSelect(subquerySelectStatement, semTable) + opInner, err := CreateOperatorFromAST(sq.Subquery.Select, semTable) if err != nil { return nil, err } resultantOp.Inner = append(resultantOp.Inner, &SubQueryInner{ - Inner: opInner, ExtractedSubquery: sq, + Inner: opInner, }) } } diff --git a/go/vt/vtgate/planbuilder/abstract/operator_test.go b/go/vt/vtgate/planbuilder/abstract/operator_test.go index 704808b956c..b86abfabd31 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator_test.go +++ b/go/vt/vtgate/planbuilder/abstract/operator_test.go @@ -133,8 +133,8 @@ func testString(op Operator) string { var inners []string for _, sqOp := range op.Inner { subquery := fmt.Sprintf("{\n\tType: %s", engine.PulloutOpcode(sqOp.ExtractedSubquery.OpCode).String()) - if sqOp.ExtractedSubquery.ArgName != "" { - subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.ArgName) + if sqOp.ExtractedSubquery.GetArgName() != "" { + subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.GetArgName()) } subquery += fmt.Sprintf("\n\tQuery: %s\n}", indent(testString(sqOp.Inner))) subquery = indent(subquery) diff --git a/go/vt/vtgate/planbuilder/querytree_transformers.go b/go/vt/vtgate/planbuilder/querytree_transformers.go index fa669ee6b5c..1bf593f4cd9 100644 --- a/go/vt/vtgate/planbuilder/querytree_transformers.go +++ b/go/vt/vtgate/planbuilder/querytree_transformers.go @@ -98,8 +98,15 @@ func transformSubqueryTree(ctx *planningContext, n *subqueryTree) (logicalPlan, return nil, err } - plan := newPulloutSubquery(engine.PulloutOpcode(n.extracted.OpCode), n.extracted.ArgName, n.extracted.HasValuesArg, innerPlan) + argName := n.extracted.GetArgName() + hasValuesArg := n.extracted.GetHasValuesArg() outerPlan, err := transformToLogicalPlan(ctx, n.outer) + + merged := mergeSubQueryPlan(ctx, innerPlan, outerPlan, n) + if merged != nil { + return merged, nil + } + plan := newPulloutSubquery(engine.PulloutOpcode(n.extracted.OpCode), argName, hasValuesArg, innerPlan) if err != nil { return nil, err } @@ -107,6 +114,27 @@ func transformSubqueryTree(ctx *planningContext, n *subqueryTree) (logicalPlan, return plan, err } +func mergeSubQueryPlan(ctx *planningContext, inner, outer logicalPlan, n *subqueryTree) logicalPlan { + iroute, ok := inner.(*route) + if !ok { + return nil + } + oroute, ok := outer.(*route) + if !ok { + return nil + } + + if canMergeSubqueryPlans(ctx, iroute, oroute) { + // n.extracted is an expression that lives in oroute.Select. + // Instead of looking for it in the AST, we have a copy in the subquery tree that we can update + n.extracted.NeedsRewrite = true + replaceSubQuery(ctx, oroute.Select) + + return oroute + } + return nil +} + func transformDerivedPlan(ctx *planningContext, n *derivedTree) (logicalPlan, error) { // transforming the inner part of the derived table into a logical plan // so that we can do horizon planning on the inner. If the logical plan @@ -267,7 +295,7 @@ func mergeUnionLogicalPlans(ctx *planningContext, left logicalPlan, right logica return nil } - if canMergePlans(ctx, lroute, rroute) { + if canMergeUnionPlans(ctx, lroute, rroute) { lroute.Select = &sqlparser.Union{Left: lroute.Select, Distinct: false, Right: rroute.Select} return lroute } @@ -317,7 +345,7 @@ func transformRoutePlan(ctx *planningContext, n *routeTree) (*route, error) { if subq, isSubq := predicate.Right.(*sqlparser.Subquery); isSubq { extractedSubquery := ctx.semTable.FindSubqueryReference(subq) if extractedSubquery != nil { - extractedSubquery.ArgName = engine.ListVarName + extractedSubquery.SetArgName(engine.ListVarName) } } predicate.Right = sqlparser.ListArg(engine.ListVarName) @@ -513,7 +541,7 @@ func (sqr *subQReplacer) replacer(cursor *sqlparser.Cursor) bool { } for _, replaceByExpr := range sqr.subqueryToReplace { // we are comparing the ArgNames in case the expressions have been cloned - if ext.ArgName == replaceByExpr.ArgName { + if ext.GetArgName() == replaceByExpr.GetArgName() { cursor.Replace(ext.Original) sqr.replaced = true return false @@ -522,7 +550,7 @@ func (sqr *subQReplacer) replacer(cursor *sqlparser.Cursor) bool { return true } -func replaceSubQuery(ctx *planningContext, sel *sqlparser.Select) { +func replaceSubQuery(ctx *planningContext, sel sqlparser.SelectStatement) { extractedSubqueries := ctx.semTable.GetSubqueryNeedingRewrite() if len(extractedSubqueries) == 0 { return diff --git a/go/vt/vtgate/planbuilder/rewrite.go b/go/vt/vtgate/planbuilder/rewrite.go index 90c66333434..99d725f5e97 100644 --- a/go/vt/vtgate/planbuilder/rewrite.go +++ b/go/vt/vtgate/planbuilder/rewrite.go @@ -119,8 +119,8 @@ func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Co r.inSubquery++ argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues() - semTableSQ.ArgName = argName - semTableSQ.HasValuesArg = hasValuesArg + semTableSQ.SetArgName(argName) + semTableSQ.SetHasValuesArg(hasValuesArg) cursor.Replace(semTableSQ) return nil } @@ -130,12 +130,12 @@ func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subq if !found { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map") } - if semTableSQ.ArgName != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { + if semTableSQ.GetArgName() != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { return nil } r.inSubquery++ argName := r.reservedVars.ReserveSubQuery() - semTableSQ.ArgName = argName + semTableSQ.SetArgName(argName) cursor.Replace(semTableSQ) return nil } @@ -148,7 +148,7 @@ func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlpars r.inSubquery++ argName := r.reservedVars.ReserveHasValuesSubQuery() - semTableSQ.ArgName = argName + semTableSQ.SetArgName(argName) cursor.Replace(semTableSQ) return nil } diff --git a/go/vt/vtgate/planbuilder/rewrite_test.go b/go/vt/vtgate/planbuilder/rewrite_test.go index afa9d680abd..f774fa3d568 100644 --- a/go/vt/vtgate/planbuilder/rewrite_test.go +++ b/go/vt/vtgate/planbuilder/rewrite_test.go @@ -41,7 +41,7 @@ func TestSubqueryRewrite(t *testing.T) { output: "select 1 from t1 where :__sq_has_values1", }, { input: "select id from t1 where id in (select 1)", - output: "select id from t1 where (:__sq_has_values1 = 1 and id in ::__sq1)", + output: "select id from t1 where :__sq_has_values1 = 1 and id in ::__sq1", }, { input: "select id from t1 where id not in (select 1)", output: "select id from t1 where :__sq_has_values1 = 0 or id not in ::__sq1", @@ -65,7 +65,7 @@ func TestSubqueryRewrite(t *testing.T) { output: "select id from t1 where not :__sq_has_values1 and :__sq_has_values2", }, { input: "select (select 1), (select 2) from t1 join t2 on t1.id = (select 1) where t1.id in (select 1)", - output: "select :__sq2, :__sq3 from t1 join t2 on t1.id = :__sq1 where (:__sq_has_values4 = 1 and t1.id in ::__sq4)", + output: "select :__sq2, :__sq3 from t1 join t2 on t1.id = :__sq1 where :__sq_has_values4 = 1 and t1.id in ::__sq4", }} for _, tcase := range tcases { t.Run(tcase.input, func(t *testing.T) { @@ -123,7 +123,7 @@ func TestHavingRewrite(t *testing.T) { output: "select count(*) as k from t1 where (x = 1 or y = 2) and a = 1 having count(*) = 1", }, { input: "select 1 from t1 where x in (select 1 from t2 having a = 1)", - output: "select 1 from t1 where (:__sq_has_values1 = 1 and x in ::__sq1)", + output: "select 1 from t1 where :__sq_has_values1 = 1 and x in ::__sq1", sqs: map[string]string{"__sq1": "select 1 from t2 where a = 1"}, }, {input: "select 1 from t1 group by a having a = 1 and count(*) > 1", output: "select 1 from t1 where a = 1 group by a having count(*) > 1", @@ -140,7 +140,7 @@ func TestHavingRewrite(t *testing.T) { assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched") } for _, sq := range squeries { - assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery.Select)) + assert.Equal(t, tcase.sqs[sq.GetArgName()], sqlparser.String(sq.Subquery.Select)) } }) } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index c1a5bd01da4..7d8ef25c746 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -889,7 +889,7 @@ func canMergeOnFilters(ctx *planningContext, a, b *routeTree, joinPredicates []s type mergeFunc func(a, b *routeTree) (*routeTree, error) -func canMergePlans(ctx *planningContext, a, b *route) bool { +func canMergeUnionPlans(ctx *planningContext, a, b *route) bool { // this method should be close to tryMerge below. it does the same thing, but on logicalPlans instead of queryTrees if a.eroute.Keyspace.Name != b.eroute.Keyspace.Name { return false @@ -919,6 +919,32 @@ func canMergePlans(ctx *planningContext, a, b *route) bool { } return false } +func canMergeSubqueryPlans(ctx *planningContext, a, b *route) bool { + // this method should be close to tryMerge below. it does the same thing, but on logicalPlans instead of queryTrees + if a.eroute.Keyspace.Name != b.eroute.Keyspace.Name { + return false + } + switch a.eroute.Opcode { + case engine.SelectUnsharded, engine.SelectReference: + return a.eroute.Opcode == b.eroute.Opcode + case engine.SelectDBA: + return b.eroute.Opcode == engine.SelectDBA && + len(a.eroute.SysTableTableSchema) == 0 && + len(a.eroute.SysTableTableName) == 0 && + len(b.eroute.SysTableTableSchema) == 0 && + len(b.eroute.SysTableTableName) == 0 + case engine.SelectEqualUnique: + // Check if they target the same shard. + if b.eroute.Opcode == engine.SelectEqualUnique && + a.eroute.Vindex == b.eroute.Vindex && + a.condition != nil && + b.condition != nil && + gen4ValuesEqual(ctx, []sqlparser.Expr{a.condition}, []sqlparser.Expr{b.condition}) { + return true + } + } + return false +} func tryMerge(ctx *planningContext, a, b queryTree, joinPredicates []sqlparser.Expr, merger mergeFunc) (queryTree, error) { aRoute, bRoute := queryTreesToRoutes(a.clone(), b.clone()) diff --git a/go/vt/vtgate/planbuilder/routetree.go b/go/vt/vtgate/planbuilder/routetree.go index 4e14c776707..0f085fb7414 100644 --- a/go/vt/vtgate/planbuilder/routetree.go +++ b/go/vt/vtgate/planbuilder/routetree.go @@ -452,9 +452,9 @@ func (rp *routeTree) makePlanValue(ctx *planningContext, n sqlparser.Expr) (*sql } switch engine.PulloutOpcode(extractedSubquery.OpCode) { case engine.PulloutIn, engine.PulloutNotIn: - expr = sqlparser.NewListArg(extractedSubquery.ArgName) + expr = sqlparser.NewListArg(extractedSubquery.GetArgName()) case engine.PulloutValue, engine.PulloutExists: - expr = sqlparser.NewArgument(extractedSubquery.ArgName) + expr = sqlparser.NewArgument(extractedSubquery.GetArgName()) } } pv, err := makePlanValue(expr) diff --git a/go/vt/vtgate/planbuilder/testdata/ddl_cases.txt b/go/vt/vtgate/planbuilder/testdata/ddl_cases.txt index 404bfad40e1..4e39c7bb83a 100644 --- a/go/vt/vtgate/planbuilder/testdata/ddl_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/ddl_cases.txt @@ -264,6 +264,7 @@ Gen4 plan same as above "Query": "create view view_a as select id, `name` from unsharded where id in (select id from unsharded where id = 1 union select id from unsharded where id = 3)" } } +Gen4 plan same as above # create view with subquery in unsharded keyspace with UNION clause "create view view_a as (select id from unsharded) union (select id from unsharded_auto) order by id limit 5" diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt index a2aeec7f6ed..ae778ca4498 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt @@ -1738,46 +1738,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select id from user where id in (select col from user)", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col from `user` where 1 != 1", - "Query": "select col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "SelectIN", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where (:__sq_has_values1 = 1 and id in ::__vals)", - "Table": "`user`", - "Values": [ - "::__sq1" - ], - "Vindex": "user_index" - } - ] - } -} +Gen4 plan same as above # cross-shard subquery in NOT IN clause. "select id from user where id not in (select col from user)" @@ -2038,7 +1999,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id2 from `user` where 1 != 1", - "Query": "select id2 from `user` where (:__sq_has_values2 = 1 and id2 in ::__sq2)", + "Query": "select id2 from `user` where :__sq_has_values2 = 1 and id2 in ::__sq2", "Table": "`user`" } ] @@ -2473,7 +2434,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where :__sq_has_values1 = 0 or id not in ::__sq1 and (:__sq_has_values2 = 1 and id in ::__vals)", + "Query": "select id from `user` where (:__sq_has_values1 = 0 or id not in ::__sq1) and (:__sq_has_values2 = 1 and id in ::__vals)", "Table": "`user`", "Values": [ "::__sq2" @@ -2611,46 +2572,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select id from user where id in (select col from unsharded where col = id)", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectUnsharded", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select col from unsharded where 1 != 1", - "Query": "select col from unsharded where col = id", - "Table": "unsharded" - }, - { - "OperatorType": "Route", - "Variant": "SelectIN", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where (:__sq_has_values1 = 1 and id in ::__vals)", - "Table": "`user`", - "Values": [ - "::__sq1" - ], - "Vindex": "user_index" - } - ] - } -} +Gen4 plan same as above # correlated subquery with different keyspace tables involved "select id from user where id in (select col from unsharded where col = user.id)" @@ -3036,7 +2958,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where id = 5 and :__sq_has_values1 = 0 or id not in ::__sq1 and id in (select user_extra.col from user_extra where user_extra.user_id = 5)", + "Query": "select id from `user` where id = 5 and (:__sq_has_values1 = 0 or id not in ::__sq1) and id in (select user_extra.col from user_extra where user_extra.user_id = 5)", "Table": "`user`", "Values": [ 5 @@ -3393,7 +3315,7 @@ Gen4 plan same as above }, "FieldQuery": "select `user`.id, `user`.col, weight_string(`user`.id), weight_string(`user`.col) from `user` where 1 != 1", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select `user`.id, `user`.col, weight_string(`user`.id), weight_string(`user`.col) from `user` where (:__sq_has_values1 = 1 and `user`.col in ::__sq1) order by `user`.id asc, `user`.col asc", + "Query": "select `user`.id, `user`.col, weight_string(`user`.id), weight_string(`user`.col) from `user` where :__sq_has_values1 = 1 and `user`.col in ::__sq1 order by `user`.id asc, `user`.col asc", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index de0cb57ff30..827523a672f 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2818,7 +2818,7 @@ Gen4 plan same as above "Sharded": false }, "FieldQuery": "select unsharded_a.col from unsharded_a, unsharded_b where 1 != 1", - "Query": "select unsharded_a.col from unsharded_a, unsharded_b where (:__sq_has_values1 = 1 and unsharded_a.col in ::__sq1)", + "Query": "select unsharded_a.col from unsharded_a, unsharded_b where :__sq_has_values1 = 1 and unsharded_a.col in ::__sq1", "Table": "unsharded_a, unsharded_b" } ] @@ -2882,61 +2882,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select unsharded.col from unsharded join user on user.col in (select col from user)", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col from `user` where 1 != 1", - "Query": "select col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1", - "TableName": "unsharded_`user`", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectUnsharded", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select unsharded.col from unsharded where 1 != 1", - "Query": "select unsharded.col from unsharded", - "Table": "unsharded" - }, - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where (:__sq_has_values1 = 1 and `user`.col in ::__sq1)", - "Table": "`user`" - } - ] - } - ] - } -} +Gen4 plan same as above # subquery in ON clause, with left join primitives # The subquery is not pulled all the way out. @@ -3043,7 +2989,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where (:__sq_has_values1 = 1 and `user`.col in ::__sq1)", + "Query": "select 1 from `user` where :__sq_has_values1 = 1 and `user`.col in ::__sq1", "Table": "`user`" } ] @@ -3165,7 +3111,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where (:__sq_has_values1 = 1 and `user`.col in ::__sq1)", + "Query": "select 1 from `user` where :__sq_has_values1 = 1 and `user`.col in ::__sq1", "Table": "`user`" }, { @@ -4285,7 +4231,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where (:__sq_has_values1 = 1 and id in ::__vals) and col = :__sq2", + "Query": "select id from `user` where :__sq_has_values1 = 1 and id in ::__vals and col = :__sq2", "Table": "`user`", "Values": [ "::__sq1" diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt index eea474e0c09..2bc80f59033 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt @@ -251,7 +251,7 @@ Gen4 error: Column 'col1' in field list is ambiguous "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where (:__sq_has_values1 = 1 and id in ::__vals)", + "Query": "select id from `user` where :__sq_has_values1 = 1 and id in ::__vals", "Table": "`user`", "Values": [ "::__sq1" @@ -589,44 +589,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select col from user where col in (select col2 from user) order by col", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col2 from `user` where 1 != 1", - "Query": "select col2 from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` where (:__sq_has_values1 = 1 and col in ::__sq1) order by col asc", - "ResultColumns": 1, - "Table": "`user`" - } - ] - } -} +Gen4 plan same as above # ORDER BY NULL for join "select user.col1 as a, user.col2, music.col3 from user join music on user.id = music.id where user.id = 1 order by null" @@ -947,42 +910,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select col from user where col in (select col2 from user) order by null", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col2 from `user` where 1 != 1", - "Query": "select col2 from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col from `user` where 1 != 1", - "Query": "select col from `user` where (:__sq_has_values1 = 1 and col in ::__sq1) order by null", - "Table": "`user`" - } - ] - } -} +Gen4 plan same as above # ORDER BY RAND() "select col from user order by RAND()" @@ -1134,42 +1062,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select col from user where col in (select col2 from user) order by rand()", - "Instructions": { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col2 from `user` where 1 != 1", - "Query": "select col2 from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col from `user` where 1 != 1", - "Query": "select col from `user` where (:__sq_has_values1 = 1 and col in ::__sq1) order by rand()", - "Table": "`user`" - } - ] - } -} +Gen4 plan same as above # Order by, '*' expression "select * from user where id = 5 order by col" @@ -1694,48 +1587,7 @@ Gen4 plan same as above ] } } -{ - "QueryType": "SELECT", - "Original": "select col from user where col in (select col1 from user) limit 1", - "Instructions": { - "OperatorType": "Limit", - "Count": 1, - "Inputs": [ - { - "OperatorType": "Subquery", - "Variant": "PulloutIn", - "PulloutVars": [ - "__sq_has_values1", - "__sq1" - ], - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col1 from `user` where 1 != 1", - "Query": "select col1 from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col from `user` where 1 != 1", - "Query": "select col from `user` where (:__sq_has_values1 = 1 and col in ::__sq1) limit :__upper_limit", - "Table": "`user`" - } - ] - } - ] - } -} +Gen4 plan same as above # limit on reference table "select col from ref limit 1" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 21ec7e3eb2f..d160744ca94 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -1492,6 +1492,7 @@ Gen4 plan same as above "Table": "unsharded" } } +Gen4 plan same as above "(select id from unsharded) union (select id from unsharded_auto) order by id limit 5" { diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt b/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt index 200efb1c916..3c93c0cd647 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt @@ -823,7 +823,7 @@ Gen4 error: unsupported: cross-shard correlated subquery }, "FieldQuery": "select o_custkey, o_orderkey, o_orderdate, o_totalprice, weight_string(o_orderkey), weight_string(o_orderdate), weight_string(o_totalprice) from orders where 1 != 1", "OrderBy": "(3|6) DESC, (2|5) ASC", - "Query": "select o_custkey, o_orderkey, o_orderdate, o_totalprice, weight_string(o_orderkey), weight_string(o_orderdate), weight_string(o_totalprice) from orders where (:__sq_has_values1 = 1 and o_orderkey in ::__vals) order by o_totalprice desc, o_orderdate asc", + "Query": "select o_custkey, o_orderkey, o_orderdate, o_totalprice, weight_string(o_orderkey), weight_string(o_orderdate), weight_string(o_totalprice) from orders where :__sq_has_values1 = 1 and o_orderkey in ::__vals order by o_totalprice desc, o_orderdate asc", "Table": "orders", "Values": [ "::__sq1" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 08feb0c6598..22e32aa5d56 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -1,6 +1,7 @@ # union operations in subqueries (expressions) "select * from user where id in (select * from user union select * from user_extra)" "unsupported: '*' expression in cross-shard query" +Gen4 plan same as above # TODO: Implement support for select with a target destination "select * from `user[-]`.user_metadata" diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt index d0e232c2184..19186f53e91 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt @@ -1192,7 +1192,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where (:__sq_has_values1 = 1 and id in ::__vals)", + "Query": "select 1 from `user` where :__sq_has_values1 = 1 and id in ::__vals", "Table": "`user`", "Values": [ "::__sq1" diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 0dd65ebe7c0..625ff9f8391 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -17,10 +17,6 @@ limitations under the License. package semantics import ( - "fmt" - "runtime/debug" - "strings" - "vitess.io/vitess/go/vt/vtgate/vindexes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -330,17 +326,6 @@ func (a *analyzer) tableSetFor(t *sqlparser.AliasedTableExpr) TableSet { return a.tables.tableSetFor(t) } -// Gen4NotSupportedF returns a common error for shortcomings in the gen4 planner -func Gen4NotSupportedF(format string, args ...interface{}) error { - message := fmt.Sprintf("gen4 does not yet support: "+format, args...) - - // add the line that this happens in so it is easy to find it - stack := string(debug.Stack()) - lines := strings.Split(stack, "\n") - message += "\n" + lines[6] - return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, message) -} - // ProjError is used to mark an error as something that should only be returned // if the planner fails to merge everything down to a single route type ProjError struct { diff --git a/go/vt/vtgate/semantics/tabletset.go b/go/vt/vtgate/semantics/tabletset.go index 23f1858b144..0bc0fc01a00 100644 --- a/go/vt/vtgate/semantics/tabletset.go +++ b/go/vt/vtgate/semantics/tabletset.go @@ -26,7 +26,7 @@ type largeTableSet struct { } func (ts *largeTableSet) overlapsSmall(small uint64) bool { - return ts.tables[0]&uint64(small) != 0 + return ts.tables[0]&small != 0 } func minlen(a, b []uint64) int { @@ -171,6 +171,7 @@ type TableSet struct { large *largeTableSet } +// Format formats the TableSet. func (ts TableSet) Format(f fmt.State, verb rune) { first := true fmt.Fprintf(f, "TableSet{") @@ -356,6 +357,7 @@ func MergeTableSets(tss ...TableSet) (result TableSet) { return } +// TableSetFromIds returns TableSet for all the id passed in argument. func TableSetFromIds(tids ...int) (ts TableSet) { for _, tid := range tids { ts.AddTable(tid)