diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index 42dd304e75a..e63c0e4c3d0 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -334,3 +334,31 @@ type ViewDependencyUnresolvedError struct { func (e *ViewDependencyUnresolvedError) Error() string { return fmt.Sprintf("view %s has unresolved/loop dependencies", sqlescape.EscapeID(e.View)) } + +type InvalidColumnReferencedInViewError struct { + View string + Table string + Column string + NonUnique bool +} + +func (e *InvalidColumnReferencedInViewError) Error() string { + switch { + case e.Column == "": + return fmt.Sprintf("view %s references non-existent table %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table)) + case e.Table != "": + return fmt.Sprintf("view %s references non-existent column %s.%s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table), sqlescape.EscapeID(e.Column)) + case e.NonUnique: + return fmt.Sprintf("view %s references unqualified but non unique column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column)) + default: + return fmt.Sprintf("view %s references unqualified but non-existent column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column)) + } +} + +type EntityNotFoundError struct { + Name string +} + +func (e *EntityNotFoundError) Error() string { + return fmt.Sprintf("entity %s not found", sqlescape.EscapeID(e.Name)) +} diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 0e9ae4c4df1..3c491c13b81 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -24,9 +24,13 @@ import ( "sort" "strings" + "vitess.io/vitess/go/vt/concurrency" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" ) +type tablesColumnsMap map[string]map[string]struct{} + // Schema represents a database schema, which may contain entities such as tables and views. // Schema is not in itself an Entity, since it is more of a collection of entities. type Schema struct { @@ -318,6 +322,11 @@ func (s *Schema) normalize() error { } } + // Validate views' referenced columns: do these columns actually exist in referenced tables/views? + if err := s.ValidateViewReferences(); err != nil { + return err + } + // Validate table definitions for _, t := range s.tables { if err := t.validate(); err != nil { @@ -761,3 +770,257 @@ func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) { } return dup, nil } + +func (s *Schema) ValidateViewReferences() error { + errs := &concurrency.AllErrorRecorder{} + availableColumns := tablesColumnsMap{} + + for _, e := range s.Entities() { + entityColumns, err := s.getEntityColumnNames(e.Name(), availableColumns) + if err != nil { + errs.RecordError(err) + continue + } + availableColumns[e.Name()] = map[string]struct{}{} + for _, col := range entityColumns { + availableColumns[e.Name()][col.Lowered()] = struct{}{} + } + } + + // Add dual table with no explicit columns for dual style expressions in views. + availableColumns["dual"] = map[string]struct{}{} + + for _, view := range s.Views() { + // First gather all referenced tables and table aliases + tableAliases := map[string]string{} + tableReferences := map[string]struct{}{} + err := gatherTableInformationForView(view, availableColumns, tableReferences, tableAliases) + errs.RecordError(err) + + // Now we can walk the view again and check each column expression + // to see if there's an existing column referenced. + err = gatherColumnReferenceInformationForView(view, availableColumns, tableReferences, tableAliases) + errs.RecordError(err) + } + return errs.AggrError(vterrors.Aggregate) +} + +func gatherTableInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error { + errs := &concurrency.AllErrorRecorder{} + tableErrors := make(map[string]struct{}) + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.AliasedTableExpr: + aliased := sqlparser.GetTableName(node.Expr).String() + if aliased == "" { + return true, nil + } + + if _, ok := availableColumns[aliased]; !ok { + if _, ok := tableErrors[aliased]; ok { + // Only show a missing table reference once per view. + return true, nil + } + err := &InvalidColumnReferencedInViewError{ + View: view.Name(), + Table: aliased, + } + errs.RecordError(err) + tableErrors[aliased] = struct{}{} + return true, nil + } + tableReferences[aliased] = struct{}{} + if node.As.String() != "" { + tableAliases[node.As.String()] = aliased + } + } + return true, nil + }, view.Select) + if err != nil { + // parsing error. Forget about any view dependency issues we may have found. This is way more important + return err + } + return errs.AggrError(vterrors.Aggregate) +} + +func gatherColumnReferenceInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error { + errs := &concurrency.AllErrorRecorder{} + qualifiedColumnErrors := make(map[string]map[string]struct{}) + unqualifiedColumnErrors := make(map[string]struct{}) + + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.ColName: + if node.Qualifier.IsEmpty() { + err := verifyUnqualifiedColumn(view, availableColumns, tableReferences, node.Name, unqualifiedColumnErrors) + errs.RecordError(err) + } else { + err := verifyQualifiedColumn(view, availableColumns, tableAliases, node, qualifiedColumnErrors) + errs.RecordError(err) + } + } + return true, nil + }, view.Select) + if err != nil { + // parsing error. Forget about any view dependency issues we may have found. This is way more important + return err + } + return errs.AggrError(vterrors.Aggregate) +} + +func verifyUnqualifiedColumn(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, nodeName sqlparser.IdentifierCI, unqualifiedColumnErrors map[string]struct{}) error { + // In case we have a non-qualified column reference, it needs + // to be unique across all referenced tables if it is supposed + // to work. + columnFound := false + for table := range tableReferences { + cols, ok := availableColumns[table] + if !ok { + // We already dealt with an error for a missing table reference + // earlier, so we can ignore it at this point here. + return nil + } + _, columnInTable := cols[nodeName.Lowered()] + if !columnInTable { + continue + } + if columnFound { + // We already have seen the column before in another table, so + // if we see it again here, that's an error case. + if _, ok := unqualifiedColumnErrors[nodeName.Lowered()]; ok { + return nil + } + unqualifiedColumnErrors[nodeName.Lowered()] = struct{}{} + return &InvalidColumnReferencedInViewError{ + View: view.Name(), + Column: nodeName.String(), + NonUnique: true, + } + } + columnFound = true + } + + // If we've seen the desired column here once, we're all good + if columnFound { + return nil + } + + if _, ok := unqualifiedColumnErrors[nodeName.Lowered()]; ok { + return nil + } + unqualifiedColumnErrors[nodeName.Lowered()] = struct{}{} + return &InvalidColumnReferencedInViewError{ + View: view.Name(), + Column: nodeName.String(), + } +} + +func verifyQualifiedColumn( + view *CreateViewEntity, + availableColumns tablesColumnsMap, + tableAliases map[string]string, node *sqlparser.ColName, + columnErrors map[string]map[string]struct{}, +) error { + tableName := node.Qualifier.Name.String() + if aliased, ok := tableAliases[tableName]; ok { + tableName = aliased + } + cols, ok := availableColumns[tableName] + if !ok { + // Already dealt with missing tables earlier on, we don't have + // any error to add here. + return nil + } + _, ok = cols[node.Name.Lowered()] + if ok { + // Found the column in the table, all good. + return nil + } + + if _, ok := columnErrors[tableName]; !ok { + columnErrors[tableName] = make(map[string]struct{}) + } + + if _, ok := columnErrors[tableName][node.Name.Lowered()]; ok { + return nil + } + columnErrors[tableName][node.Name.Lowered()] = struct{}{} + return &InvalidColumnReferencedInViewError{ + View: view.Name(), + Table: tableName, + Column: node.Name.String(), + } +} + +// getTableColumnNames returns the names of columns in given table. +func (s *Schema) getEntityColumnNames(entityName string, availableColumns tablesColumnsMap) ( + columnNames []*sqlparser.IdentifierCI, + err error, +) { + entity := s.Entity(entityName) + if entity == nil { + if strings.ToLower(entityName) == "dual" { + // this is fine. DUAL does not exist but is allowed + return nil, nil + } + return nil, &EntityNotFoundError{Name: entityName} + } + // The entity is either a table or a view + switch entity := entity.(type) { + case *CreateTableEntity: + return s.getTableColumnNames(entity), nil + case *CreateViewEntity: + return s.getViewColumnNames(entity, availableColumns) + } + return nil, &EntityNotFoundError{Name: entityName} +} + +// getTableColumnNames returns the names of columns in given table. +func (s *Schema) getTableColumnNames(t *CreateTableEntity) (columnNames []*sqlparser.IdentifierCI) { + for _, c := range t.TableSpec.Columns { + columnNames = append(columnNames, &c.Name) + } + return columnNames +} + +// getViewColumnNames returns the names of aliased columns returned by a given view. +func (s *Schema) getViewColumnNames(v *CreateViewEntity, availableColumns tablesColumnsMap) ( + columnNames []*sqlparser.IdentifierCI, + err error, +) { + err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.StarExpr: + if tableName := node.TableName.Name.String(); tableName != "" { + for colName := range availableColumns[tableName] { + name := sqlparser.NewIdentifierCI(colName) + columnNames = append(columnNames, &name) + } + } else { + dependentNames, err := getViewDependentTableNames(v.CreateView) + if err != nil { + return false, err + } + // add all columns from all referenced tables and views + for _, entityName := range dependentNames { + for colName := range availableColumns[entityName] { + name := sqlparser.NewIdentifierCI(colName) + columnNames = append(columnNames, &name) + } + } + } + case *sqlparser.AliasedExpr: + if node.As.String() != "" { + columnNames = append(columnNames, &node.As) + } else { + name := sqlparser.NewIdentifierCI(sqlparser.String(node.Expr)) + columnNames = append(columnNames, &name) + } + } + return true, nil + }, v.Select.GetColumns()) + if err != nil { + return nil, err + } + return columnNames, nil +} diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 1a24b862b1c..25931d90e6f 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -17,6 +17,7 @@ limitations under the License. package schemadiff import ( + "sort" "strings" "testing" @@ -36,7 +37,7 @@ var createQueries = []string{ "create table t5(id int)", "create view v2 as select * from v3, t2", "create view v1 as select * from v3", - "create view v3 as select * from t3 as t3", + "create view v3 as select *, id+1 as id_plus, id+2 from t3 as t3", "create view v0 as select 1 from DUAL", "create view v9 as select 1", } @@ -74,12 +75,12 @@ var expectSortedViewNames = []string{ "v6", // level 3 } -var toSQL = "CREATE TABLE `t1` (\n\t`id` int\n);\nCREATE TABLE `t2` (\n\t`id` int\n);\nCREATE TABLE `t3` (\n\t`id` int,\n\t`type` enum('foo', 'bar') NOT NULL DEFAULT 'foo'\n);\nCREATE TABLE `t5` (\n\t`id` int\n);\nCREATE VIEW `v0` AS SELECT 1 FROM `dual`;\nCREATE VIEW `v3` AS SELECT * FROM `t3` AS `t3`;\nCREATE VIEW `v9` AS SELECT 1 FROM `dual`;\nCREATE VIEW `v1` AS SELECT * FROM `v3`;\nCREATE VIEW `v2` AS SELECT * FROM `v3`, `t2`;\nCREATE VIEW `v4` AS SELECT * FROM `t2` AS `something_else`, `v3`;\nCREATE VIEW `v5` AS SELECT * FROM `t1`, (SELECT * FROM `v3`) AS `some_alias`;\nCREATE VIEW `v6` AS SELECT * FROM `v4`;\n" +var toSQL = "CREATE TABLE `t1` (\n\t`id` int\n);\nCREATE TABLE `t2` (\n\t`id` int\n);\nCREATE TABLE `t3` (\n\t`id` int,\n\t`type` enum('foo', 'bar') NOT NULL DEFAULT 'foo'\n);\nCREATE TABLE `t5` (\n\t`id` int\n);\nCREATE VIEW `v0` AS SELECT 1 FROM `dual`;\nCREATE VIEW `v3` AS SELECT *, `id` + 1 AS `id_plus`, `id` + 2 FROM `t3` AS `t3`;\nCREATE VIEW `v9` AS SELECT 1 FROM `dual`;\nCREATE VIEW `v1` AS SELECT * FROM `v3`;\nCREATE VIEW `v2` AS SELECT * FROM `v3`, `t2`;\nCREATE VIEW `v4` AS SELECT * FROM `t2` AS `something_else`, `v3`;\nCREATE VIEW `v5` AS SELECT * FROM `t1`, (SELECT * FROM `v3`) AS `some_alias`;\nCREATE VIEW `v6` AS SELECT * FROM `v4`;\n" func TestNewSchemaFromQueries(t *testing.T) { schema, err := NewSchemaFromQueries(createQueries) assert.NoError(t, err) - assert.NotNil(t, schema) + require.NotNil(t, schema) assert.Equal(t, expectSortedNames, schema.EntityNames()) assert.Equal(t, expectSortedTableNames, schema.TableNames()) @@ -89,7 +90,7 @@ func TestNewSchemaFromQueries(t *testing.T) { func TestNewSchemaFromSQL(t *testing.T) { schema, err := NewSchemaFromSQL(strings.Join(createQueries, ";")) assert.NoError(t, err) - assert.NotNil(t, schema) + require.NotNil(t, schema) assert.Equal(t, expectSortedNames, schema.EntityNames()) assert.Equal(t, expectSortedTableNames, schema.TableNames()) @@ -158,7 +159,7 @@ func TestNewSchemaFromQueriesLoop(t *testing.T) { func TestToSQL(t *testing.T) { schema, err := NewSchemaFromQueries(createQueries) assert.NoError(t, err) - assert.NotNil(t, schema) + require.NotNil(t, schema) sql := schema.ToSQL() assert.Equal(t, toSQL, sql) @@ -167,7 +168,7 @@ func TestToSQL(t *testing.T) { func TestCopy(t *testing.T) { schema, err := NewSchemaFromQueries(createQueries) assert.NoError(t, err) - assert.NotNil(t, schema) + require.NotNil(t, schema) schemaClone := schema.copy() assert.Equal(t, schema, schemaClone) @@ -398,3 +399,286 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { assert.EqualError(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t11"}).Error()) } } + +func TestGetEntityColumnNames(t *testing.T) { + var queries = []string{ + "create table t1(id int, state int, some char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select id as id from t1", + "create view v2 as select 3+1 as `id`, state as state, some as thing from t1", + "create view v3 as select `id` as `id`, state as state, thing as another from v2", + "create view v4 as select 1 as `ok` from dual", + "create view v5 as select 1 as `ok` from DUAL", + "create view v6 as select ok as `ok` from v5", + "create view v7 as select * from t1", + "create view v8 as select * from v7", + "create view v9 as select * from v8, v6", + "create view va as select * from v6, v8", + "create view vb as select *, now() from v8", + } + + schema, err := NewSchemaFromQueries(queries) + require.NoError(t, err) + require.NotNil(t, schema) + + expectedColNames := map[string]([]string){ + "t1": []string{"id", "state", "some"}, + "t2": []string{"id", "c"}, + "v1": []string{"id"}, + "v2": []string{"id", "state", "thing"}, + "v3": []string{"id", "state", "another"}, + "v4": []string{"ok"}, + "v5": []string{"ok"}, + "v6": []string{"ok"}, + "v7": []string{"id", "state", "some"}, + "v8": []string{"id", "state", "some"}, + "v9": []string{"id", "state", "some", "ok"}, + "va": []string{"ok", "id", "state", "some"}, + "vb": []string{"id", "state", "some", "now()"}, + } + entities := schema.Entities() + require.Equal(t, len(entities), len(expectedColNames)) + + tcmap := tablesColumnsMap{} + // we test by order of dependency: + for _, e := range entities { + tbl := e.Name() + t.Run(tbl, func(t *testing.T) { + identifiers, err := schema.getEntityColumnNames(tbl, tcmap) + assert.NoError(t, err) + names := []string{} + for _, ident := range identifiers { + names = append(names, ident.String()) + } + // compare columns. We disregard order. + expectNames := expectedColNames[tbl][:] + sort.Strings(names) + sort.Strings(expectNames) + assert.Equal(t, expectNames, names) + // emulate the logic that fills known columns for known entities: + tcmap[tbl] = map[string]struct{}{} + for _, name := range names { + tcmap[tbl][name] = struct{}{} + } + }) + } +} + +func TestViewReferences(t *testing.T) { + tt := []struct { + name string + queries []string + expectErr error + }{ + { + name: "valid", + queries: []string{ + "create table t1(id int, state int, some char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select id as id from t1", + "create view v2 as select 3+1 as `id`, state as state, some as thing from t1", + "create view v3 as select `id` as `id`, state as state, thing as another from v2", + "create view v4 as select 1 as `ok` from dual", + "create view v5 as select 1 as `ok` from DUAL", + "create view v6 as select ok as `ok` from v5", + }, + }, + { + name: "valid WHERE", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select c from t1 where id=3", + }, + }, + { + name: "invalid unqualified referenced column", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select unexpected from t1", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v1", Column: "unexpected"}, + }, + { + name: "invalid unqualified referenced column in where clause", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select 1 from t1 where unexpected=3", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v1", Column: "unexpected"}, + }, + { + name: "valid qualified", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select t1.c from t1 where t1.id=3", + }, + }, + { + name: "valid qualified, multiple tables", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select t1.c from t1, t2 where t2.id=3", + }, + }, + { + name: "invalid unqualified, multiple tables", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select c from t1, t2 where t2.id=3", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v1", Column: "c", NonUnique: true}, + }, + { + name: "invalid unqualified in WHERE clause, multiple tables", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5))", + "create view v1 as select t2.c from t1, t2 where id=3", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v1", Column: "id", NonUnique: true}, + }, + { + name: "valid unqualified, multiple tables", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5), only_in_t2 int)", + "create view v1 as select only_in_t2 from t1, t2 where t1.id=3", + }, + }, + { + name: "valid unqualified in WHERE clause, multiple tables", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, c char(5), only_in_t2 int)", + "create view v1 as select t1.id from t1, t2 where only_in_t2=3", + }, + }, + { + name: "valid cascaded views", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id, c from t1 where id > 0", + "create view v2 as select * from v1 where id > 0", + }, + }, + { + name: "valid cascaded views, column aliases", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id, c as ch from t1 where id > 0", + "create view v2 as select ch from v1 where id > 0", + }, + }, + { + name: "valid cascaded views, column aliases in WHERE", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id as counter, c as ch from t1 where id > 0", + "create view v2 as select ch from v1 where counter > 0", + }, + }, + { + name: "valid cascaded views, aliased expression", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id+1 as counter, c as ch from t1 where id > 0", + "create view v2 as select ch from v1 where counter > 0", + }, + }, + { + name: "valid cascaded views, non aliased expression", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id+1, c as ch from t1 where id > 0", + "create view v2 as select ch from v1 where `id + 1` > 0", + }, + }, + { + name: "cascaded views, invalid column aliases", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id, c as ch from t1 where id > 0", + "create view v2 as select c from v1 where id > 0", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v2", Column: "c"}, + }, + { + name: "cascaded views, column not in view", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select id from t1 where c='x'", + "create view v2 as select c from v1 where id > 0", + }, + expectErr: &InvalidColumnReferencedInViewError{View: "v2", Column: "c"}, + }, + { + name: "complex cascade", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, n int, info int)", + "create view v1 as select id, c as ch from t1 where id > 0", + "create view v2 as select n as num, info from t2", + "create view v3 as select num, v1.id, ch from v1 join v2 using (id) where info > 5", + }, + }, + { + name: "valid star", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select * from t1 where id > 0", + }, + }, + { + name: "valid star, cascaded", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create view v1 as select t1.* from t1 where id > 0", + "create view v2 as select * from v1 where id > 0", + }, + }, + { + name: "valid star, two tables, cascaded", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, ts timestamp)", + "create view v1 as select t1.* from t1, t2 where t1.id > 0", + "create view v2 as select * from v1 where c > 0", + }, + }, + { + name: "valid two star, two tables, cascaded", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, ts timestamp)", + "create view v1 as select t1.*, t2.* from t1, t2 where t1.id > 0", + "create view v2 as select * from v1 where c > 0 and ts is not null", + }, + }, + { + name: "valid unqualified star, cascaded", + queries: []string{ + "create table t1(id int primary key, c char(5))", + "create table t2(id int primary key, ts timestamp)", + "create view v1 as select * from t1, t2 where t1.id > 0", + "create view v2 as select * from v1 where c > 0 and ts is not null", + }, + }, + } + for _, ts := range tt { + t.Run(ts.name, func(t *testing.T) { + schema, err := NewSchemaFromQueries(ts.queries) + if ts.expectErr == nil { + require.NoError(t, err) + require.NotNil(t, schema) + } else { + require.Equal(t, ts.expectErr, err, "received error: %v", err) + } + }) + } +}