Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schemadiff: validate views' referenced columns #12147

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ func TestDiffSchemas(t *testing.T) {
// validate schema1 unaffected by Apply
assert.Equal(t, schema1SQL, schema1.ToSQL())

require.NotNil(t, schema2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check useful? We already check earlier that there's no error returned when schema2 is created?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

appliedDiff, err := schema2.Diff(applied, hints)
require.NoError(t, err)
assert.Empty(t, appliedDiff)
Expand Down
28 changes: 28 additions & 0 deletions go/vt/schemadiff/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,31 @@ type ViewDependencyUnresolvedError struct {
func (e *ViewDependencyUnresolvedError) Error() string {
return fmt.Sprintf("view %s has unresolved/loop dependencies", sqlescape.EscapeID(e.View))
}

// InvalidColumnReferencedInViewError describes a view that references a table
// or column which cannot be resolved. Which fields are populated determines the
// message produced by Error.
type InvalidColumnReferencedInViewError struct {
// View is the name of the view holding the invalid reference.
View string
// Table is the referenced table name; empty for unqualified column references.
Table string
// Column is the referenced column name; empty when the table itself is missing.
Column string
// NonUnique marks an unqualified column found in more than one referenced table.
NonUnique bool
}

// Error implements the error interface. The wording is selected by which
// fields are set, checked in this order: no Column (missing table), a Table
// (missing qualified column), NonUnique (ambiguous unqualified column), and
// otherwise a missing unqualified column. All identifiers are escaped.
func (e *InvalidColumnReferencedInViewError) Error() string {
switch {
case e.Column == "":
return fmt.Sprintf("view %s references non-existing table %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table))
mattlord marked this conversation as resolved.
Show resolved Hide resolved
case e.Table != "":
return fmt.Sprintf("view %s references non existing column %s.%s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table), sqlescape.EscapeID(e.Column))
mattlord marked this conversation as resolved.
Show resolved Hide resolved
case e.NonUnique:
return fmt.Sprintf("view %s references unqualified but non unique column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column))
default:
return fmt.Sprintf("view %s references unqualified but non existing column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column))
}
}

// EntityNotFoundError reports that a schema entity could not be found by name.
type EntityNotFoundError struct {
// Name is the name of the entity that was looked up.
Name string
}

// Error implements the error interface, naming the missing entity with its
// identifier escaped.
func (e *EntityNotFoundError) Error() string {
	escapedName := sqlescape.EscapeID(e.Name)
	return fmt.Sprintf("entity %s not found", escapedName)
}
264 changes: 264 additions & 0 deletions go/vt/schemadiff/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ import (
"sort"
"strings"

"go.uber.org/multierr"

"vitess.io/vitess/go/vt/sqlparser"
)

// tablesColumnsMap maps an entity (table or view) name to the set of column
// names that entity exposes. ValidateViewReferences stores the column names
// in their lowercased form.
type tablesColumnsMap map[string]map[string]struct{}

// Schema represents a database schema, which may contain entities such as tables and views.
// Schema is not in itself an Entity, since it is more of a collection of entities.
type Schema struct {
Expand Down Expand Up @@ -309,6 +313,11 @@ func (s *Schema) normalize() error {
}
}

// Validate views' referenced columns: do these columns actually exist in referenced tables/views?
if err := s.ValidateViewReferences(); err != nil {
return err
}

// Validate table definitions
for _, t := range s.tables {
if err := t.validate(); err != nil {
Expand Down Expand Up @@ -750,3 +759,258 @@ func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) {
}
return dup, nil
}

// TODO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this TODO refer to?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hehe, it references the fact that it is already done... removing...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


func (s *Schema) ValidateViewReferences() error {
var allerrors error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually use errs as the name for this kind of thing to collect errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

availableColumns := tablesColumnsMap{}

for _, e := range s.Entities() {
entityColumns, err := s.getEntityColumnNames(e.Name(), availableColumns)
if err != nil {
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this add to the error list instead? And then we continue through the rest of the entities? If this is a totally unexpected error, maybe good to add a comment about that and why it's safe to return immediately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}
availableColumns[e.Name()] = map[string]struct{}{}
for _, col := range entityColumns {
availableColumns[e.Name()][col.Lowered()] = struct{}{}
}
}

// Add dual table with no explicit columns for dual style expressions in views.
availableColumns["dual"] = map[string]struct{}{}

for _, view := range s.Views() {
// First gather all referenced tables and table aliases
tableAliases := map[string]string{}
tableReferences := map[string]struct{}{}
err := gatherTableInformationForView(view, availableColumns, tableReferences, tableAliases)
allerrors = multierr.Append(allerrors, err)

// Now we can walk the view again and check each column expression
// to see if there's an existing column referenced.
err = gatherColumnReferenceInformationForView(view, availableColumns, tableReferences, tableAliases)
allerrors = multierr.Append(allerrors, err)
}
return allerrors
}

// gatherTableInformationForView walks the view's SELECT statement and collects
// the tables it reads from.
//
// Every aliased table expression whose table name is present in
// availableColumns is added to tableReferences, and, when it carries an alias,
// the alias is mapped to the table name in tableAliases. A table name that is
// absent from availableColumns yields an InvalidColumnReferencedInViewError,
// reported at most once per name for this view.
//
// The returned error is the combination of all such errors, or nil if there
// were none.
func gatherTableInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error {
	var errs error
	reportedTables := map[string]struct{}{}

	visit := func(node sqlparser.SQLNode) (bool, error) {
		tableExpr, isTableExpr := node.(*sqlparser.AliasedTableExpr)
		if !isTableExpr {
			return true, nil
		}
		tableName := sqlparser.GetTableName(tableExpr.Expr).String()
		if tableName == "" {
			return true, nil
		}

		if _, known := availableColumns[tableName]; known {
			tableReferences[tableName] = struct{}{}
			if alias := tableExpr.As.String(); alias != "" {
				tableAliases[alias] = tableName
			}
			return true, nil
		}

		// Unknown table: report it, but only the first time it shows up in this view.
		if _, alreadyReported := reportedTables[tableName]; !alreadyReported {
			reportedTables[tableName] = struct{}{}
			errs = multierr.Append(errs, &InvalidColumnReferencedInViewError{
				View:  view.Name(),
				Table: tableName,
			})
		}
		return true, nil
	}

	if err := sqlparser.Walk(visit, view.Select); err != nil {
		// An error from the walk itself takes precedence over any reference
		// errors collected so far.
		return err
	}
	return errs
}

func gatherColumnReferenceInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error {
var allerrors error
qualifiedColumnErrors := make(map[string]map[string]struct{})
unqualifiedColumnErrors := make(map[string]struct{})

err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *sqlparser.ColName:
if node.Qualifier.IsEmpty() {
err := verifyUnqualifiedColumn(view, availableColumns, tableReferences, node.Name, unqualifiedColumnErrors)
allerrors = multierr.Append(allerrors, err)
} else {
err := verifyQualifiedColumn(view, availableColumns, tableAliases, node, qualifiedColumnErrors)
allerrors = multierr.Append(allerrors, err)
}
}
return true, nil
}, view.Select)
if err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI - sqlparser.Walk will only return the error you return from inside the visitor function, and since you don't return any errors from that, no errors will ever make it out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay ah, great! Thank you. I'll keep the check as it is, for safety, but good to know!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay Right, but the errors are gathered here in errs? And those are at the end returned if the walker itself doesn't error (which is not expected).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbussink yeah, that was my point. Not really necessary to catch the returned error from sqlparser.Walk since that is not how we are dealing with the errors. When I know it will never return an error, I usually write:

	_ = sqlparser.Walk(...)

OTOH - it's probably good defensive programming to do as @shlomi-noach is doing here and catching and checking the error anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay Right, I guess as someone not knowing too much about the details of how Walk is implemented, I wouldn't know that it can't return an error and would still write it defensively 😄.

// parsing error. Forget about any view dependency issues we may have found. This is way more important
return err
}
return allerrors
}

// verifyUnqualifiedColumn checks a column reference that carries no table
// qualifier. Such a reference is valid only when exactly one of the tables in
// tableReferences has a column of that name (compared by lowercased name).
//
// It returns an InvalidColumnReferencedInViewError when the column is found
// in more than one table (NonUnique set) or in none. Each column name is
// reported at most once, tracked through unqualifiedColumnErrors. If a
// referenced table is missing from availableColumns the check is abandoned
// and nil is returned, since missing tables are reported elsewhere.
func verifyUnqualifiedColumn(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, nodeName sqlparser.IdentifierCI, unqualifiedColumnErrors map[string]struct{}) error {
	columnKey := nodeName.Lowered()

	// reportOnce builds the error for this column unless one has already been
	// emitted for the same column name.
	reportOnce := func(nonUnique bool) error {
		if _, alreadyReported := unqualifiedColumnErrors[columnKey]; alreadyReported {
			return nil
		}
		unqualifiedColumnErrors[columnKey] = struct{}{}
		return &InvalidColumnReferencedInViewError{
			View:      view.Name(),
			Column:    nodeName.String(),
			NonUnique: nonUnique,
		}
	}

	matches := 0
	for table := range tableReferences {
		tableColumns, tableKnown := availableColumns[table]
		if !tableKnown {
			// The missing table reference was handled earlier; nothing to add here.
			return nil
		}
		if _, present := tableColumns[columnKey]; !present {
			continue
		}
		matches++
		if matches > 1 {
			// A second table exposes the same column name: the reference is ambiguous.
			return reportOnce(true)
		}
	}

	if matches == 1 {
		// Exactly one table provides the column.
		return nil
	}
	return reportOnce(false)
}

// verifyQualifiedColumn checks a column reference that carries a table
// qualifier. The qualifier is first resolved through tableAliases, then the
// column is looked up (by lowercased name) in that table's entry in
// availableColumns.
//
// It returns nil when the column exists or when the table is unknown, and an
// InvalidColumnReferencedInViewError otherwise. columnErrors records the
// (table, column) pairs already reported so each is returned only once.
func verifyQualifiedColumn(
view *CreateViewEntity,
availableColumns tablesColumnsMap,
tableAliases map[string]string, node *sqlparser.ColName,
columnErrors map[string]map[string]struct{},
) error {
tableName := node.Qualifier.Name.String()
// If the qualifier is an alias, switch to the table name it stands for.
if aliased, ok := tableAliases[tableName]; ok {
tableName = aliased
}
cols, ok := availableColumns[tableName]
if !ok {
// Already dealt with missing tables earlier on, we don't have
// any error to add here.
return nil
}
_, ok = cols[node.Name.Lowered()]
if ok {
// Found the column in the table, all good.
return nil
}

// Lazily create the per-table set of already-reported columns.
if _, ok := columnErrors[tableName]; !ok {
columnErrors[tableName] = make(map[string]struct{})
}

// This table/column pair was already reported; do not repeat it.
if _, ok := columnErrors[tableName][node.Name.Lowered()]; ok {
return nil
dbussink marked this conversation as resolved.
Show resolved Hide resolved
}
columnErrors[tableName][node.Name.Lowered()] = struct{}{}
return &InvalidColumnReferencedInViewError{
View: view.Name(),
Table: tableName,
Column: node.Name.String(),
}
}

// getEntityColumnNames returns the names of the columns exposed by the entity
// with the given name, which may be either a table or a view.
//
// availableColumns is only used for views, where it is handed on to
// getViewColumnNames.
//
// If no entity of that name exists and the name is "dual" (compared
// case-insensitively), it returns (nil, nil): DUAL is not a real entity but is
// allowed. For any other unknown name, or for an entity that is neither a
// table nor a view, it returns an *EntityNotFoundError.
func (s *Schema) getEntityColumnNames(entityName string, availableColumns tablesColumnsMap) (
	columnNames []*sqlparser.IdentifierCI,
	err error,
) {
	entity := s.Entity(entityName)
	if entity == nil {
		if strings.EqualFold(entityName, "dual") {
			// this is fine. DUAL does not exist but is allowed
			return nil, nil
		}
		return nil, &EntityNotFoundError{Name: entityName}
	}
	// The entity is either a table or a view
	switch entity := entity.(type) {
	case *CreateTableEntity:
		return s.getTableColumnNames(entity), nil
	case *CreateViewEntity:
		return s.getViewColumnNames(entity, availableColumns)
	default:
		return nil, &EntityNotFoundError{Name: entityName}
	}
}

// getTableColumnNames lists the column names of the given table, in the order
// the columns appear in its table spec. Each returned pointer refers to the
// Name field of the ranged column value.
func (s *Schema) getTableColumnNames(t *CreateTableEntity) (columnNames []*sqlparser.IdentifierCI) {
	for _, column := range t.TableSpec.Columns {
		namePtr := &column.Name
		columnNames = append(columnNames, namePtr)
	}
	return columnNames
}

// getViewColumnNames collects the column names produced by a view by walking
// the select expressions of its SELECT statement.
//
//   - A qualified star (`t.*`) contributes every column recorded for `t` in
//     availableColumns.
//   - An unqualified star (`*`) contributes the recorded columns of every
//     entity returned by getViewDependentTableNames for the view.
//   - An aliased expression contributes its alias, or, when it has no alias,
//     the textual form of the expression.
//
// Columns coming from a star are appended in map iteration order, so their
// relative order is not stable. An error from getViewDependentTableNames stops
// the walk and is returned with a nil slice.
func (s *Schema) getViewColumnNames(v *CreateViewEntity, availableColumns tablesColumnsMap) (
	columnNames []*sqlparser.IdentifierCI,
	err error,
) {
	// addColumnsOf appends every known column of the named entity.
	addColumnsOf := func(entityName string) {
		for colName := range availableColumns[entityName] {
			ident := sqlparser.NewIdentifierCI(colName)
			columnNames = append(columnNames, &ident)
		}
	}

	visit := func(node sqlparser.SQLNode) (bool, error) {
		switch expr := node.(type) {
		case *sqlparser.StarExpr:
			if qualifier := expr.TableName.Name.String(); qualifier != "" {
				addColumnsOf(qualifier)
				return true, nil
			}
			dependentNames, err := getViewDependentTableNames(v.CreateView)
			if err != nil {
				return false, err
			}
			// add all columns from all referenced tables and views
			for _, entityName := range dependentNames {
				addColumnsOf(entityName)
			}
		case *sqlparser.AliasedExpr:
			if expr.As.String() == "" {
				ident := sqlparser.NewIdentifierCI(sqlparser.String(expr.Expr))
				columnNames = append(columnNames, &ident)
			} else {
				columnNames = append(columnNames, &expr.As)
			}
		}
		return true, nil
	}

	if err := sqlparser.Walk(visit, v.Select.GetColumns()); err != nil {
		return nil, err
	}
	return columnNames, nil
}
Loading