From 4f1a90207bda191c05b03fe6d1285026ce5d7f7a Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 15 Oct 2021 15:57:14 +0530 Subject: [PATCH] fix precedence test for subquery Signed-off-by: Harshit Gangal --- go/vt/sqlparser/ast.go | 9 ++-- go/vt/sqlparser/ast_clone.go | 1 + go/vt/sqlparser/ast_equals.go | 9 ++-- go/vt/sqlparser/ast_format.go | 28 +--------- go/vt/sqlparser/ast_format_fast.go | 28 +--------- go/vt/sqlparser/ast_funcs.go | 53 +++++++++++++++++++ go/vt/sqlparser/ast_rewrite.go | 5 ++ go/vt/sqlparser/ast_visit.go | 3 ++ go/vt/sqlparser/cached_size.go | 14 +++-- go/vt/sqlparser/precedence.go | 2 +- go/vt/sqlparser/precedence_test.go | 39 ++++++++++---- .../planbuilder/abstract/operator_test.go | 4 +- .../planbuilder/querytree_transformers.go | 8 +-- go/vt/vtgate/planbuilder/rewrite.go | 10 ++-- go/vt/vtgate/planbuilder/rewrite_test.go | 2 +- go/vt/vtgate/planbuilder/routetree.go | 4 +- .../planbuilder/testdata/filter_cases.txt | 6 +-- 17 files changed, 132 insertions(+), 93 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index e42227d52d6..eef6382227e 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -1966,14 +1966,17 @@ type ( // ExtractedSubquery is a subquery that has been extracted from the original AST // This is a struct that the parser will never produce - it's written and read by the gen4 planner + // CAUTION: you should only change argName and hasValuesArg through the setter methods ExtractedSubquery struct { Original Expr // original expression that was replaced by this ExtractedSubquery - ArgName string - HasValuesArg string - OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :( + OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :( Subquery *Subquery OtherSide Expr // represents the side of the comparison, this field will be nil if Original is not a comparison NeedsRewrite bool // tells whether we need to rewrite this subquery to Original or not + + hasValuesArg string + argName string + alternative Expr // this is what will be used to Format this struct } ) diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 96060511316..0cd891306c8 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -864,6 +864,7 @@ func CloneRefOfExtractedSubquery(n *ExtractedSubquery) *ExtractedSubquery { out.Original = CloneExpr(n.Original) out.Subquery = CloneRefOfSubquery(n.Subquery) out.OtherSide = CloneExpr(n.OtherSide) + out.alternative = CloneExpr(n.alternative) return &out } diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 20794604648..0c2d7d2b4f9 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1563,13 +1563,14 @@ func EqualsRefOfExtractedSubquery(a, b *ExtractedSubquery) bool { if a == nil || b == nil { return false } - return a.ArgName == b.ArgName && - a.HasValuesArg == b.HasValuesArg && - a.OpCode == b.OpCode && + return a.OpCode == b.OpCode && a.NeedsRewrite == b.NeedsRewrite && + a.hasValuesArg == b.hasValuesArg && + a.argName == b.argName && EqualsExpr(a.Original, b.Original) && EqualsRefOfSubquery(a.Subquery, b.Subquery) && - EqualsExpr(a.OtherSide, b.OtherSide) + EqualsExpr(a.OtherSide, b.OtherSide) && + EqualsExpr(a.alternative, b.alternative) } // EqualsRefOfFlush does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 89c24e24002..a65f65e91ef 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1726,31 +1726,5 @@ func (node *RenameTable) Format(buf *TrackedBuffer) { // it will be formatted as if the subquery has been extracted, and instead // show up like argument comparisons func (node *ExtractedSubquery) Format(buf *TrackedBuffer) { - switch original := node.Original.(type) { - case *ExistsExpr: - buf.formatter(NewArgument(node.ArgName)) - case *Subquery: - buf.formatter(NewArgument(node.ArgName)) - case *ComparisonExpr: - // other_side = :__sq - cmp := &ComparisonExpr{ - Left: node.OtherSide, - Right: NewArgument(node.ArgName), - Operator: original.Operator, - } - var expr Expr = cmp - switch original.Operator { - case InOp: - // :__sq_has_values = 1 and other_side in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} - expr = AndExpressions(hasValue, cmp) - case NotInOp: - // :__sq_has_values = 0 or other_side not in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} - expr = &OrExpr{hasValue, cmp} - } - buf.formatter(expr) - } + node.alternative.Format(buf) } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 494ed57fde3..22cdef2f4a9 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -2256,31 +2256,5 @@ func (node *RenameTable) formatFast(buf *TrackedBuffer) { // it will be formatted as if the subquery has been extracted, and instead // show up like argument comparisons func (node *ExtractedSubquery) formatFast(buf *TrackedBuffer) { - switch original := node.Original.(type) { - case *ExistsExpr: - buf.formatter(NewArgument(node.ArgName)) - case *Subquery: - buf.formatter(NewArgument(node.ArgName)) - case *ComparisonExpr: - // other_side = :__sq - cmp := &ComparisonExpr{ - Left: node.OtherSide, - Right: NewArgument(node.ArgName), - Operator: original.Operator, - } - var expr Expr = cmp - switch original.Operator { - case InOp: - // :__sq_has_values = 1 and other_side in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} - expr = AndExpressions(hasValue, cmp) - case NotInOp: - // :__sq_has_values = 0 or other_side not in ::__sq - cmp.Right = NewListArg(node.ArgName) - hasValue := &ComparisonExpr{Left: NewArgument(node.HasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} - expr = &OrExpr{hasValue, cmp} - } - buf.formatter(expr) - } + node.alternative.Format(buf) } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 596012d9243..bc1f1e22eda 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1455,3 +1455,56 @@ func GetAllSelects(selStmt SelectStatement) []*Select { } panic("[BUG]: unknown type for SelectStatement") } + +// SetArgName sets argument name. +func (es *ExtractedSubquery) SetArgName(n string) { + es.argName = n + es.updateAlternative() +} + +// SetHasValuesArg sets has_values argument. +func (es *ExtractedSubquery) SetHasValuesArg(n string) { + es.hasValuesArg = n + es.updateAlternative() +} + +// GetArgName returns argument name. +func (es *ExtractedSubquery) GetArgName() string { + return es.argName +} + +// GetHasValuesArg returns has values argument. +func (es *ExtractedSubquery) GetHasValuesArg() string { + return es.hasValuesArg + +} + +func (es *ExtractedSubquery) updateAlternative() { + switch original := es.Original.(type) { + case *ExistsExpr: + es.alternative = NewArgument(es.argName) + case *Subquery: + es.alternative = NewArgument(es.argName) + case *ComparisonExpr: + // other_side = :__sq + cmp := &ComparisonExpr{ + Left: es.OtherSide, + Right: NewArgument(es.argName), + Operator: original.Operator, + } + var expr Expr = cmp + switch original.Operator { + case InOp: + // :__sq_has_values = 1 and other_side in ::__sq + cmp.Right = NewListArg(es.argName) + hasValue := &ComparisonExpr{Left: NewArgument(es.hasValuesArg), Right: NewIntLiteral("1"), Operator: EqualOp} + expr = AndExpressions(hasValue, cmp) + case NotInOp: + // :__sq_has_values = 0 or other_side not in ::__sq + cmp.Right = NewListArg(es.argName) + hasValue := &ComparisonExpr{Left: NewArgument(es.hasValuesArg), Right: NewIntLiteral("0"), Operator: EqualOp} + expr = &OrExpr{hasValue, cmp} + } + es.alternative = expr + } +} diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 747d7bfb6f2..11fdb7691c8 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1916,6 +1916,11 @@ func (a *application) rewriteRefOfExtractedSubquery(parent SQLNode, node *Extrac }) { return false } + if !a.rewriteExpr(node, node.alternative, func(newNode, parent SQLNode) { + parent.(*ExtractedSubquery).alternative = newNode.(Expr) + }) { + return false + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index c75bc920069..8511e40c1fe 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1045,6 +1045,9 @@ func VisitRefOfExtractedSubquery(in *ExtractedSubquery, f Visit) error { if err := VisitExpr(in.OtherSide, f); err != nil { return err } + if err := VisitExpr(in.alternative, f); err != nil { + return err + } return nil } func VisitRefOfFlush(in *Flush, f Visit) error { diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 7c4f311a396..0b5b2d72268 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -934,22 +934,26 @@ func (cached *ExtractedSubquery) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(96) + size += int64(112) } // field Original vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.Original.(cachedObject); ok { size += cc.CachedSize(true) } - // field ArgName string - size += hack.RuntimeAllocSize(int64(len(cached.ArgName))) - // field HasValuesArg string - size += hack.RuntimeAllocSize(int64(len(cached.HasValuesArg))) // field Subquery *vitess.io/vitess/go/vt/sqlparser.Subquery size += cached.Subquery.CachedSize(true) // field OtherSide vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.OtherSide.(cachedObject); ok { size += cc.CachedSize(true) } + // field hasValuesArg string + size += hack.RuntimeAllocSize(int64(len(cached.hasValuesArg))) + // field argName string + size += hack.RuntimeAllocSize(int64(len(cached.argName))) + // field alternative vitess.io/vitess/go/vt/sqlparser.Expr + if cc, ok := cached.alternative.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *Flush) CachedSize(alloc bool) int64 { diff --git a/go/vt/sqlparser/precedence.go b/go/vt/sqlparser/precedence.go index b6e7397a6a0..e1e6cce9c42 100644 --- a/go/vt/sqlparser/precedence.go +++ b/go/vt/sqlparser/precedence.go @@ -91,7 +91,7 @@ func precedenceFor(in Expr) Precendence { case *IntervalExpr: return P1 case *ExtractedSubquery: - return precedenceFor(node.Original) + return precedenceFor(node.alternative) } return Syntactic diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index d6cbb8a845a..cbc481bb4d3 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -72,22 +72,43 @@ func TestNotInSubqueryPrecedence(t *testing.T) { not := tree.(*Select).Where.Expr.(*NotExpr) cmp := not.Expr.(*ComparisonExpr) subq := cmp.Right.(*Subquery) - fmt.Println(subq) extracted := &ExtractedSubquery{ - Original: cmp, - ArgName: "arg1", - HasValuesArg: "has_values1", - OpCode: 1, - Subquery: subq, - OtherSide: cmp.Left, + Original: cmp, + OpCode: 1, + Subquery: subq, + OtherSide: cmp.Left, } - not.Expr = extracted + extracted.SetArgName("arg1") + extracted.SetHasValuesArg("has_values1") + not.Expr = extracted output := readable(not) assert.Equal(t, "not (:has_values1 = 1 and id in ::arg1)", output) } +func TestSubqueryPrecedence(t *testing.T) { + tree, err := Parse("select * from a where id in (select 42) and false") + require.NoError(t, err) + where := tree.(*Select).Where + andExpr := where.Expr.(*AndExpr) + cmp := andExpr.Left.(*ComparisonExpr) + subq := cmp.Right.(*Subquery) + + extracted := &ExtractedSubquery{ + Original: andExpr.Left, + OpCode: 1, + Subquery: subq, + OtherSide: cmp.Left, + } + extracted.SetArgName("arg1") + extracted.SetHasValuesArg("has_values1") + + andExpr.Left = extracted + output := readable(extracted) + assert.Equal(t, ":has_values1 = 1 and id in ::arg1", output) +} + func TestPlusStarPrecedence(t *testing.T) { validSQL := []struct { input string @@ -193,7 +214,7 @@ func TestRandom(t *testing.T) { // The purpose of this test is to find discrepancies between Format and parsing. If for example our precedence rules are not consistent between the two, this test should find it. // The idea is to generate random queries, and pass them through the parser and then the unparser, and one more time. The result of the first unparse should be the same as the second result. seed := time.Now().UnixNano() - fmt.Println(fmt.Sprintf("seed is %d", seed)) //nolint + fmt.Println(fmt.Sprintf("seed is %d", seed)) // nolint g := newGenerator(seed, 5) endBy := time.Now().Add(1 * time.Second) diff --git a/go/vt/vtgate/planbuilder/abstract/operator_test.go b/go/vt/vtgate/planbuilder/abstract/operator_test.go index 704808b956c..b86abfabd31 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator_test.go +++ b/go/vt/vtgate/planbuilder/abstract/operator_test.go @@ -133,8 +133,8 @@ func testString(op Operator) string { var inners []string for _, sqOp := range op.Inner { subquery := fmt.Sprintf("{\n\tType: %s", engine.PulloutOpcode(sqOp.ExtractedSubquery.OpCode).String()) - if sqOp.ExtractedSubquery.ArgName != "" { - subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.ArgName) + if sqOp.ExtractedSubquery.GetArgName() != "" { + subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.GetArgName()) } subquery += fmt.Sprintf("\n\tQuery: %s\n}", indent(testString(sqOp.Inner))) subquery = indent(subquery) diff --git a/go/vt/vtgate/planbuilder/querytree_transformers.go b/go/vt/vtgate/planbuilder/querytree_transformers.go index b427a769741..1bf593f4cd9 100644 --- a/go/vt/vtgate/planbuilder/querytree_transformers.go +++ b/go/vt/vtgate/planbuilder/querytree_transformers.go @@ -98,8 +98,8 @@ func transformSubqueryTree(ctx *planningContext, n *subqueryTree) (logicalPlan, return nil, err } - argName := n.extracted.ArgName - hasValuesArg := n.extracted.HasValuesArg + argName := n.extracted.GetArgName() + hasValuesArg := n.extracted.GetHasValuesArg() outerPlan, err := transformToLogicalPlan(ctx, n.outer) merged := mergeSubQueryPlan(ctx, innerPlan, outerPlan, n) @@ -345,7 +345,7 @@ func transformRoutePlan(ctx *planningContext, n *routeTree) (*route, error) { if subq, isSubq := predicate.Right.(*sqlparser.Subquery); isSubq { extractedSubquery := ctx.semTable.FindSubqueryReference(subq) if extractedSubquery != nil { - extractedSubquery.ArgName = engine.ListVarName + extractedSubquery.SetArgName(engine.ListVarName) } } predicate.Right = sqlparser.ListArg(engine.ListVarName) @@ -541,7 +541,7 @@ func (sqr *subQReplacer) replacer(cursor *sqlparser.Cursor) bool { } for _, replaceByExpr := range sqr.subqueryToReplace { // we are comparing the ArgNames in case the expressions have been cloned - if ext.ArgName == replaceByExpr.ArgName { + if ext.GetArgName() == replaceByExpr.GetArgName() { cursor.Replace(ext.Original) sqr.replaced = true return false diff --git a/go/vt/vtgate/planbuilder/rewrite.go b/go/vt/vtgate/planbuilder/rewrite.go index 90c66333434..99d725f5e97 100644 --- a/go/vt/vtgate/planbuilder/rewrite.go +++ b/go/vt/vtgate/planbuilder/rewrite.go @@ -119,8 +119,8 @@ func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Co r.inSubquery++ argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues() - semTableSQ.ArgName = argName - semTableSQ.HasValuesArg = hasValuesArg + semTableSQ.SetArgName(argName) + semTableSQ.SetHasValuesArg(hasValuesArg) cursor.Replace(semTableSQ) return nil } @@ -130,12 +130,12 @@ func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subq if !found { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map") } - if semTableSQ.ArgName != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { + if semTableSQ.GetArgName() != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { return nil } r.inSubquery++ argName := r.reservedVars.ReserveSubQuery() - semTableSQ.ArgName = argName + semTableSQ.SetArgName(argName) cursor.Replace(semTableSQ) return nil } @@ -148,7 +148,7 @@ func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlpars r.inSubquery++ argName := r.reservedVars.ReserveHasValuesSubQuery() - semTableSQ.ArgName = argName + semTableSQ.SetArgName(argName) cursor.Replace(semTableSQ) return nil } diff --git a/go/vt/vtgate/planbuilder/rewrite_test.go b/go/vt/vtgate/planbuilder/rewrite_test.go index e035bc56662..f774fa3d568 100644 --- a/go/vt/vtgate/planbuilder/rewrite_test.go +++ b/go/vt/vtgate/planbuilder/rewrite_test.go @@ -140,7 +140,7 @@ func TestHavingRewrite(t *testing.T) { assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched") } for _, sq := range squeries { - assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery.Select)) + assert.Equal(t, tcase.sqs[sq.GetArgName()], sqlparser.String(sq.Subquery.Select)) } }) } diff --git a/go/vt/vtgate/planbuilder/routetree.go b/go/vt/vtgate/planbuilder/routetree.go index 4e14c776707..0f085fb7414 100644 --- a/go/vt/vtgate/planbuilder/routetree.go +++ b/go/vt/vtgate/planbuilder/routetree.go @@ -452,9 +452,9 @@ func (rp *routeTree) makePlanValue(ctx *planningContext, n sqlparser.Expr) (*sql } switch engine.PulloutOpcode(extractedSubquery.OpCode) { case engine.PulloutIn, engine.PulloutNotIn: - expr = sqlparser.NewListArg(extractedSubquery.ArgName) + expr = sqlparser.NewListArg(extractedSubquery.GetArgName()) case engine.PulloutValue, engine.PulloutExists: - expr = sqlparser.NewArgument(extractedSubquery.ArgName) + expr = sqlparser.NewArgument(extractedSubquery.GetArgName()) } } pv, err := makePlanValue(expr) diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt index 4b8f3567f0e..ae778ca4498 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt @@ -2434,7 +2434,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where :__sq_has_values1 = 0 or id not in ::__sq1 and :__sq_has_values2 = 1 and id in ::__vals", + "Query": "select id from `user` where (:__sq_has_values1 = 0 or id not in ::__sq1) and (:__sq_has_values2 = 1 and id in ::__vals)", "Table": "`user`", "Values": [ "::__sq2" @@ -2867,7 +2867,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where id = 5 and id not in (select user_extra.col from user_extra where user_extra.user_id = 5) and :__sq_has_values2 = 1 and id in ::__sq2", + "Query": "select id from `user` where id = 5 and id not in (select user_extra.col from user_extra where user_extra.user_id = 5) and (:__sq_has_values2 = 1 and id in ::__sq2)", "Table": "`user`", "Values": [ 5 @@ -2958,7 +2958,7 @@ Gen4 plan same as above "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where id = 5 and :__sq_has_values1 = 0 or id not in ::__sq1 and id in (select user_extra.col from user_extra where user_extra.user_id = 5)", + "Query": "select id from `user` where id = 5 and (:__sq_has_values1 = 0 or id not in ::__sq1) and id in (select user_extra.col from user_extra where user_extra.user_id = 5)", "Table": "`user`", "Values": [ 5