Skip to content

Commit

Permalink
Use schema for the information_schema views (vitessio#12171)
Browse files Browse the repository at this point in the history
* feat: fail queries from information_schema which aren't querying valid table names

Signed-off-by: Manan Gupta <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>

* more columns that can be treated as table and schema names

Signed-off-by: Andres Taylor <[email protected]>

* renamed test case files

Signed-off-by: Andres Taylor <[email protected]>

* refactor: clean up code

Signed-off-by: Andres Taylor <[email protected]>

* clean up types

Signed-off-by: Andres Taylor <[email protected]>

* test: add query with missing info_schema table

Signed-off-by: Andres Taylor <[email protected]>

* refactor: clean up method

Signed-off-by: Andres Taylor <[email protected]>

---------

Signed-off-by: Manan Gupta <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Manan Gupta <[email protected]>
  • Loading branch information
systay and GuptaManan100 authored Feb 1, 2023
1 parent 4f5ab22 commit cf82274
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 150 deletions.
5 changes: 3 additions & 2 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,13 +459,14 @@ func TestGen4SelectDBA(t *testing.T) {
executor.normalize = true
executor.pv = querypb.ExecuteOptions_Gen4

query := "select * from INFORMATION_SCHEMA.foo"
query := "select * from INFORMATION_SCHEMA.TABLE_CONSTRAINTS"
_, err := executor.Execute(context.Background(), "TestSelectDBA",
NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}),
query, map[string]*querypb.BindVariable{},
)
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{Sql: query, BindVariables: map[string]*querypb.BindVariable{}}}
expected := "select CONSTRAINT_CATALOG, CONSTRAINT_SCHEMA, CONSTRAINT_NAME, TABLE_SCHEMA, TABLE_NAME, CONSTRAINT_TYPE, `ENFORCED` from INFORMATION_SCHEMA.TABLE_CONSTRAINTS"
wantQueries := []*querypb.BoundQuery{{Sql: expected, BindVariables: map[string]*querypb.BindVariable{}}}
utils.MustMatch(t, wantQueries, sbc1.Queries)

sbc1.Queries = nil
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ func createRoute(
ctx *plancontext.PlanningContext,
queryTable *QueryTable,
solves semantics.TableSet,
) (*Route, error) {
) (ops.Operator, error) {
if queryTable.IsInfSchema {
return createInfSchemaRoute(ctx, queryTable)
}
Expand Down
21 changes: 10 additions & 11 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func seedOperatorList(ctx *plancontext.PlanningContext, qg *QueryGraph) ([]ops.O
return nil, err
}
if qg.NoDeps != nil {
plan.Source, err = plan.Source.AddPredicate(ctx, qg.NoDeps)
plan, err = plan.AddPredicate(ctx, qg.NoDeps)
if err != nil {
return nil, err
}
Expand All @@ -279,22 +279,21 @@ func seedOperatorList(ctx *plancontext.PlanningContext, qg *QueryGraph) ([]ops.O
return plans, nil
}

func createInfSchemaRoute(ctx *plancontext.PlanningContext, table *QueryTable) (*Route, error) {
func createInfSchemaRoute(ctx *plancontext.PlanningContext, table *QueryTable) (ops.Operator, error) {
ks, err := ctx.VSchema.AnyKeyspace()
if err != nil {
return nil, err
}
var src ops.Operator = &Table{
QTable: table,
VTable: &vindexes.Table{
Name: table.Table.Name,
Keyspace: ks,
},
}
r := &Route{
RouteOpCode: engine.DBA,
Source: src,
Keyspace: ks,
Source: &Table{
QTable: table,
VTable: &vindexes.Table{
Name: table.Table.Name,
Keyspace: ks,
},
},
Keyspace: ks,
}
for _, pred := range table.Predicates {
isTableSchema, bvName, out, err := extractInfoSchemaRoutingPredicate(pred, ctx.ReservedVars)
Expand Down
91 changes: 50 additions & 41 deletions go/vt/vtgate/planbuilder/operators/system_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,60 +37,69 @@ func (r *Route) findSysInfoRoutingPredicatesGen4(predicates []sqlparser.Expr, re
continue
}

if r.SysTableTableName == nil {
r.SysTableTableName = map[string]evalengine.Expr{}
}

if isTableSchema {
r.SysTableTableSchema = append(r.SysTableTableSchema, out)
} else {
if r.SysTableTableName == nil {
r.SysTableTableName = map[string]evalengine.Expr{}
}
r.SysTableTableName[bvName] = out
}
}
return nil
}

func extractInfoSchemaRoutingPredicate(in sqlparser.Expr, reservedVars *sqlparser.ReservedVars) (bool, string, evalengine.Expr, error) {
switch cmp := in.(type) {
case *sqlparser.ComparisonExpr:
if cmp.Operator == sqlparser.EqualOp {
isSchemaName, col, other, replaceOther := findOtherComparator(cmp)
if col != nil && shouldRewrite(other) {
evalExpr, err := evalengine.Translate(other, &notImplementedSchemaInfoConverter{})
if err != nil {
if strings.Contains(err.Error(), evalengine.ErrTranslateExprNotSupported) {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
return false, "", nil, nil
}
return false, "", nil, err
}
var name string
if isSchemaName {
name = sqltypes.BvSchemaName
} else {
name = reservedVars.ReserveColName(col.(*sqlparser.ColName))
}
replaceOther(sqlparser.NewArgument(name))
return isSchemaName, name, evalExpr, nil
}
func extractInfoSchemaRoutingPredicate(
in sqlparser.Expr,
reservedVars *sqlparser.ReservedVars,
) (isSchemaName bool, name string, evalExpr evalengine.Expr, err error) {
cmp, ok := in.(*sqlparser.ComparisonExpr)
if !ok || cmp.Operator != sqlparser.EqualOp {
return
}

isSchemaName, col := isTableOrSchemaRouteable(cmp)
if col == nil || !shouldRewrite(cmp.Right) {
return
}

evalExpr, err = evalengine.Translate(cmp.Right, &notImplementedSchemaInfoConverter{})
if err != nil {
if strings.Contains(err.Error(), evalengine.ErrTranslateExprNotSupported) {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
err = nil
return
}
return
}
if isSchemaName {
name = sqltypes.BvSchemaName
} else {
name = reservedVars.ReserveColName(col)
}
return false, "", nil, nil
cmp.Right = sqlparser.NewArgument(name)
return isSchemaName, name, evalExpr, nil
}

func findOtherComparator(cmp *sqlparser.ComparisonExpr) (bool, sqlparser.Expr, sqlparser.Expr, func(arg sqlparser.Argument)) {
if schema, table := isTableSchemaOrName(cmp.Left); schema || table {
return schema, cmp.Left, cmp.Right, func(arg sqlparser.Argument) {
cmp.Right = arg
}
// isTableOrSchemaRouteable searches for a comparison where one side is a table or schema name column.
// if it finds the correct column name being used,
// it also makes sure that the LHS of the comparison contains the column, and the RHS the value sought after
func isTableOrSchemaRouteable(cmp *sqlparser.ComparisonExpr) (
isSchema bool, // tells if we are dealing with a table or a schema name comparator
col *sqlparser.ColName, // which is the colName we are comparing against
) {
if col, schema, table := isTableSchemaOrName(cmp.Left); schema || table {
return schema, col
}
if schema, table := isTableSchemaOrName(cmp.Right); schema || table {
return schema, cmp.Right, cmp.Left, func(arg sqlparser.Argument) {
cmp.Left = arg
}
if col, schema, table := isTableSchemaOrName(cmp.Right); schema || table {
// to make the rest of the code easier, we shuffle these around so the ColName is always on the LHS
cmp.Right, cmp.Left = cmp.Left, cmp.Right
return schema, col
}

return false, nil, nil, nil
return false, nil
}

func shouldRewrite(e sqlparser.Expr) bool {
Expand All @@ -102,12 +111,12 @@ func shouldRewrite(e sqlparser.Expr) bool {
return true
}

func isTableSchemaOrName(e sqlparser.Expr) (isTableSchema bool, isTableName bool) {
func isTableSchemaOrName(e sqlparser.Expr) (col *sqlparser.ColName, isTableSchema bool, isTableName bool) {
col, ok := e.(*sqlparser.ColName)
if !ok {
return false, false
return nil, false, false
}
return isDbNameCol(col), isTableNameCol(col)
return col, isDbNameCol(col), isTableNameCol(col)
}

func isDbNameCol(col *sqlparser.ColName) bool {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestPlan(t *testing.T) {
testFile(t, "flush_cases_no_default_keyspace.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "show_cases_no_default_keyspace.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "stream_cases.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "systemtables_cases80.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "info_schema80_cases.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "reference_cases.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "vexplain_cases.json", testOutputTempDir, vschemaWrapper, false)
}
Expand All @@ -261,7 +261,7 @@ func TestSystemTables57(t *testing.T) {
defer servenv.SetMySQLServerVersionForTest("")
vschemaWrapper := &vschemaWrapper{v: loadSchema(t, "vschemas/schema.json", true)}
testOutputTempDir := makeTestOutput(t)
testFile(t, "systemtables_cases57.json", testOutputTempDir, vschemaWrapper, false)
testFile(t, "info_schema57_cases.json", testOutputTempDir, vschemaWrapper, false)
}

func TestSysVarSetDisabled(t *testing.T) {
Expand Down
73 changes: 46 additions & 27 deletions go/vt/vtgate/planbuilder/system_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,41 +87,60 @@ func isTableSchemaOrName(e sqlparser.Expr) (isTableSchema bool, isTableName bool
return isDbNameCol(col), isTableNameCol(col)
}

var schemaColumns = map[string]any{
"table_schema": nil,
"constraint_schema": nil,
"schema_name": nil,
"routine_schema": nil,
"specific_schema": nil,
"event_schema": nil,
"referenced_table_schema": nil,
"index_schema": nil,
"trigger_schema": nil,
"event_object_schema": nil,
}

func isDbNameCol(col *sqlparser.ColName) bool {
return col.Name.EqualString("table_schema") || col.Name.EqualString("constraint_schema") || col.Name.EqualString("schema_name") || col.Name.EqualString("routine_schema")
_, found := schemaColumns[col.Name.Lowered()]
return found
}

func isTableNameCol(col *sqlparser.ColName) bool {
return col.Name.EqualString("table_name")
return col.Name.EqualString("table_name") || col.Name.EqualString("referenced_table_name")
}

func extractInfoSchemaRoutingPredicate(in sqlparser.Expr, reservedVars *sqlparser.ReservedVars) (bool, string, evalengine.Expr, error) {
switch cmp := in.(type) {
case *sqlparser.ComparisonExpr:
if cmp.Operator == sqlparser.EqualOp {
isSchemaName, col, other, replaceOther := findOtherComparator(cmp)
if col != nil && shouldRewrite(other) {
evalExpr, err := evalengine.Translate(other, &notImplementedSchemaInfoConverter{})
if err != nil {
if strings.Contains(err.Error(), evalengine.ErrTranslateExprNotSupported) {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
return false, "", nil, nil
}
return false, "", nil, err
}
var name string
if isSchemaName {
name = sqltypes.BvSchemaName
} else {
name = reservedVars.ReserveColName(col.(*sqlparser.ColName))
}
replaceOther(sqlparser.NewArgument(name))
return isSchemaName, name, evalExpr, nil
}
func extractInfoSchemaRoutingPredicate(
in sqlparser.Expr,
reservedVars *sqlparser.ReservedVars,
) (isSchemaName bool, name string, evalExpr evalengine.Expr, err error) {
cmp, ok := in.(*sqlparser.ComparisonExpr)
if !ok || cmp.Operator != sqlparser.EqualOp {
return
}

isSchemaName, col, other, replaceOther := findOtherComparator(cmp)
if col == nil || !shouldRewrite(other) {
return
}

evalExpr, err = evalengine.Translate(other, &notImplementedSchemaInfoConverter{})
if err != nil {
if strings.Contains(err.Error(), evalengine.ErrTranslateExprNotSupported) {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
err = nil
return
}
return false, "", nil, err
}

if isSchemaName {
name = sqltypes.BvSchemaName
} else {
name = reservedVars.ReserveColName(col.(*sqlparser.ColName))
}
return false, "", nil, nil
replaceOther(sqlparser.NewArgument(name))
return isSchemaName, name, evalExpr, nil
}

func shouldRewrite(e sqlparser.Expr) bool {
Expand Down
Loading

0 comments on commit cf82274

Please sign in to comment.