Skip to content

Commit

Permalink
Merge pull request #8394 from planetscale/orderby-gen4
Browse files Browse the repository at this point in the history
Gen4: order by and enhanced scoping information for table exprs
  • Loading branch information
systay authored Jul 6, 2021
2 parents a383e84 + acdb794 commit 95ece2f
Show file tree
Hide file tree
Showing 24 changed files with 1,580 additions and 493 deletions.
1 change: 1 addition & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type (
SetLimit(*Limit)
SetLock(lock Lock)
MakeDistinct()
GetColumnCount() int
}

// DDLStatement represents any DDL Statement
Expand Down
15 changes: 15 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ func (node *Select) MakeDistinct() {
node.Distinct = true
}

// GetColumnCount return SelectExprs count.
func (node *Select) GetColumnCount() int {
return len(node.SelectExprs)
}

// AddWhere adds the boolean expression to the
// WHERE clause as an AND condition.
func (node *Select) AddWhere(expr Expr) {
Expand Down Expand Up @@ -796,6 +801,11 @@ func (node *ParenSelect) MakeDistinct() {
node.Select.MakeDistinct()
}

// GetColumnCount implements the SelectStatement interface
func (node *ParenSelect) GetColumnCount() int {
return node.Select.GetColumnCount()
}

// AddWhere adds the boolean expression to the
// WHERE clause as an AND condition.
func (node *Update) AddWhere(expr Expr) {
Expand Down Expand Up @@ -832,6 +842,11 @@ func (node *Union) MakeDistinct() {
node.UnionSelects[len(node.UnionSelects)-1].Distinct = true
}

// GetColumnCount implements the SelectStatement interface
func (node *Union) GetColumnCount() int {
return node.FirstStatement.GetColumnCount()
}

//Unionize returns a UNION, either creating one or adding SELECT to an existing one
func Unionize(lhs, rhs SelectStatement, distinct bool, by OrderBy, limit *Limit, lock Lock) *Union {
union, isUnion := lhs.(*Union)
Expand Down
113 changes: 113 additions & 0 deletions go/vt/vtgate/planbuilder/expand_star.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package planbuilder

import (
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

type (
expandStarInfo struct {
proceed bool
tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs
}
starRewriter struct {
err error
semTable *semantics.SemTable
}
)

func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case *sqlparser.Select:
tables := sr.semTable.GetSelectTables(node)
var selExprs sqlparser.SelectExprs
for _, selectExpr := range node.SelectExprs {
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
if !isStarExpr {
selExprs = append(selExprs, selectExpr)
continue
}
colNames, expStar, err := expandTableColumns(tables, starExpr)
if err != nil {
sr.err = err
return false
}
if !expStar.proceed {
selExprs = append(selExprs, selectExpr)
continue
}
selExprs = append(selExprs, colNames...)
for tbl, cols := range expStar.tblColMap {
sr.semTable.AddExprs(tbl, cols)
}
}
node.SelectExprs = selExprs
}
return true
}

func expandTableColumns(tables []semantics.TableInfo, starExpr *sqlparser.StarExpr) (sqlparser.SelectExprs, *expandStarInfo, error) {
unknownTbl := true
var colNames sqlparser.SelectExprs
expStar := &expandStarInfo{
tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{},
}

for _, tbl := range tables {
if !starExpr.TableName.IsEmpty() && !tbl.Matches(starExpr.TableName) {
continue
}
unknownTbl = false
if !tbl.Authoritative() {
expStar.proceed = false
break
}
expStar.proceed = true
tblName, err := tbl.Name()
if err != nil {
return nil, nil, err
}
for _, col := range tbl.GetColumns() {
colNames = append(colNames, &sqlparser.AliasedExpr{
Expr: sqlparser.NewColNameWithQualifier(col.Name, tblName),
As: sqlparser.NewColIdent(col.Name),
})
}
expStar.tblColMap[tbl.GetExpr()] = colNames
}

if unknownTbl {
// This will only happen for case when starExpr has qualifier.
return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName))
}
return colNames, expStar, nil
}

func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) {
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
sr := &starRewriter{semTable: semTable}

_ = sqlparser.Rewrite(sel, sr.starRewrite, nil)
if sr.err != nil {
return nil, sr.err
}
return sel, nil
}
180 changes: 180 additions & 0 deletions go/vt/vtgate/planbuilder/expand_star_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package planbuilder

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/semantics"
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

func TestExpandStar(t *testing.T) {
schemaInfo := &semantics.FakeSI{
Tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewTableIdent("t1"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("a"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("b"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("c"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
},
"t2": {
Name: sqlparser.NewTableIdent("t2"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("c1"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("c2"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
},
"t3": { // non authoritative table.
Name: sqlparser.NewTableIdent("t3"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("col"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: false,
},
},
}
cDB := "db"
tcases := []struct {
sql string
expSQL string
expErr string
}{{
sql: "select * from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select t1.* from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select *, 42, t1.* from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select 42, t1.* from t1",
expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select * from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2",
}, {
sql: "select t1.* from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, {
sql: "select *, t1.* from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, { // aliased table
sql: "select * from t1 a, t2 b",
expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b",
}, { // t3 is non-authoritative table
sql: "select * from t3",
expSQL: "select * from t3",
}, { // t3 is non-authoritative table
sql: "select * from t1, t2, t3",
expSQL: "select * from t1, t2, t3",
}, { // t3 is non-authoritative table
sql: "select t1.*, t2.*, t3.* from t1, t2, t3",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3",
}, {
sql: "select foo.* from t1, t2",
expErr: "Unknown table 'foo'",
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, cDB, schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
if tcase.expErr == "" {
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
} else {
require.EqualError(t, err, tcase.expErr)
}
})
}
}

func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
schemaInfo := &semantics.FakeSI{Tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewTableIdent("t1"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("a"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
}}}
tcases := []struct {
sql string
expSQL string
sameTbl int
otherTbl int
expandedCol int
}{{
sql: "select a, * from t1",
expSQL: "select a, t1.a as a from t1",
otherTbl: -1, sameTbl: 0, expandedCol: 1,
}, {
sql: "select t2.a, t1.a, t1.* from t1, t2",
expSQL: "select t2.a, t1.a, t1.a as a from t1, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}, {
sql: "select t2.a, t.a, t.* from t1 t, t2",
expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, "", schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
if tcase.otherTbl != -1 {
assert.NotEqual(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
if tcase.sameTbl != -1 {
assert.Equal(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
})
}
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/grouping.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou
return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: group by column must reference column in SELECT list")
}
case *sqlparser.Literal:
num, err := ResultFromNumber(node.resultColumns, e)
num, err := ResultFromNumber(node.resultColumns, e, "group statement")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func newMemorySort(plan logicalPlan, orderBy sqlparser.OrderBy) (*memorySort, er
switch expr := order.Expr.(type) {
case *sqlparser.Literal:
var err error
if colNumber, err = ResultFromNumber(ms.ResultColumns(), expr); err != nil {
if colNumber, err = ResultFromNumber(ms.ResultColumns(), expr, "order clause"); err != nil {
return nil, err
}
case *sqlparser.ColName:
Expand Down
Loading

0 comments on commit 95ece2f

Please sign in to comment.