Skip to content

Commit

Permalink
sql: allow check constraints on adding columns
Browse files Browse the repository at this point in the history
Add check constraints to the table descriptor immediately if any columns
are currently being added. Check validation does not get turned on until
all used columns are public.

Related to cockroachdb#29639 cockroachdb#26508.

Release note: None
  • Loading branch information
Erik Trinh committed Nov 12, 2018
1 parent a05f15b commit 63ac8cc
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 32 deletions.
7 changes: 7 additions & 0 deletions pkg/sql/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ func (n *alterTableNode) startExec(params runParams) error {
panic("constraint returned by GetConstraintInfo not found")
}
ck := n.tableDesc.Checks[idx]

if active, err := ck.IsActive(&n.tableDesc.TableDescriptor); err != nil {
return err
} else if !active {
return fmt.Errorf("constraint %q not active", t.Constraint)
}

if err := params.p.validateCheckExpr(
params.ctx, ck.Expr, &n.n.Table, n.tableDesc.TableDesc(),
); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -1428,8 +1428,8 @@ func replaceVars(
return nil, true, expr
}

col, err := desc.FindActiveColumnByName(string(c.ColumnName))
if err != nil {
col, dropped, err := desc.FindColumnByName(c.ColumnName)
if err != nil || dropped {
return fmt.Errorf("column %q not found for constraint %q",
c.ColumnName, expr.String()), false, nil
}
Expand Down Expand Up @@ -1478,7 +1478,7 @@ func MakeCheckConstraint(
sort.Sort(sqlbase.ColumnIDs(colIDs))

sourceInfo := sqlbase.NewSourceInfoForSingleTable(
tableName, sqlbase.ResultColumnsFromColDescs(desc.Columns),
tableName, sqlbase.ResultColumnsFromColDescs(desc.AllNonDropColumns()),
)
sources := sqlbase.MultiSourceInfo{sourceInfo}

Expand Down
17 changes: 17 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/schema_change_in_txn
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,20 @@ INSERT INTO customers (k) VALUES ('b')

statement ok
COMMIT;

subtest check_on_add_col

statement ok
BEGIN

statement ok
ALTER TABLE forlater ADD d INT

statement ok
ALTER TABLE forlater ADD CONSTRAINT d_0 CHECK (d > 0)

statement ok
COMMIT;

statement error pq: failed to satisfy CHECK constraint \(d > 0\)
INSERT INTO forlater (k, v, d) VALUES ('z', 'x', 0)
58 changes: 45 additions & 13 deletions pkg/sql/schema_changer.go
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ func (sc *SchemaChanger) runStateMachineAndBackfill(

// reverseMutations reverses the direction of all the mutations with the
// mutationID. This is called after hitting an irrecoverable error while
// applying a schema change. If a column being added is reversed and droped,
// applying a schema change. If a column being added is reversed and dropped,
// all new indexes referencing the column will also be dropped.
func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError error) error {
// Reverse the flow of the state machine.
Expand All @@ -1113,7 +1113,7 @@ func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError erro
_, err := sc.leaseMgr.Publish(ctx, sc.tableID, func(desc *sqlbase.TableDescriptor) error {
// Keep track of the column mutations being reversed so that indexes
// referencing them can be dropped.
columns := make(map[string]struct{})
columnIDs := make(map[sqlbase.ColumnID]struct{})
droppedMutations = nil

for i, mutation := range desc.Mutations {
Expand All @@ -1133,19 +1133,22 @@ func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError erro
}

log.Warningf(ctx, "reverse schema change mutation: %+v", mutation)
desc.Mutations[i], columns = sc.reverseMutation(mutation, false /*notStarted*/, columns)
desc.Mutations[i] = sc.reverseMutation(mutation, false /*notStarted*/, columnIDs)

desc.Mutations[i].Rollback = true
}

// Delete all mutations that reference any of the reversed columns
// by running a graph traversal of the mutations.
if len(columns) > 0 {
if len(columnIDs) > 0 {
var err error
droppedMutations, err = sc.deleteIndexMutationsWithReversedColumns(ctx, desc, columns)
droppedMutations, err = sc.deleteIndexMutationsWithReversedColumns(ctx, desc, columnIDs)
if err != nil {
return err
}
if err := deleteCheckConstraints(desc, columnIDs); err != nil {
return err
}
}

// Publish() will increment the version.
Expand Down Expand Up @@ -1283,12 +1286,41 @@ func (sc *SchemaChanger) createRollbackJob(
return nil, fmt.Errorf("no job found for table %d mutation %d", sc.tableID, sc.mutationID)
}

// deleteCheckConstraints deletes check constraints which reference
// any of the reversed columns.
func deleteCheckConstraints(
desc *sqlbase.TableDescriptor, columnIDs map[sqlbase.ColumnID]struct{},
) error {
validChecks := desc.Checks[:0]
for _, check := range desc.Checks {
drop := false
for colID := range columnIDs {
used, err := check.UsesColumn(desc, colID)
if err != nil {
return err
}

if used {
drop = true
break
}
}

if !drop {
validChecks = append(validChecks, check)
}
}

desc.Checks = validChecks
return nil
}

// deleteIndexMutationsWithReversedColumns deletes mutations with a
// different mutationID than the schema changer and with an index that
// references one of the reversed columns. Execute this as a breadth
// first search graph traversal.
func (sc *SchemaChanger) deleteIndexMutationsWithReversedColumns(
ctx context.Context, desc *sqlbase.TableDescriptor, columns map[string]struct{},
ctx context.Context, desc *sqlbase.TableDescriptor, columnIDs map[sqlbase.ColumnID]struct{},
) (map[sqlbase.MutationID]struct{}, error) {
dropMutations := make(map[sqlbase.MutationID]struct{})
// Run breadth first search traversal that reverses mutations
Expand All @@ -1297,8 +1329,8 @@ func (sc *SchemaChanger) deleteIndexMutationsWithReversedColumns(
for _, mutation := range desc.Mutations {
if mutation.MutationID != sc.mutationID {
if idx := mutation.GetIndex(); idx != nil {
for _, name := range idx.ColumnNames {
if _, ok := columns[name]; ok {
for _, id := range idx.ColumnIDs {
if _, ok := columnIDs[id]; ok {
// Such an index mutation has to be with direction ADD and
// in the DELETE_ONLY state. Live indexes referencing live
// columns cannot be deleted and thus never have direction
Expand Down Expand Up @@ -1328,7 +1360,7 @@ func (sc *SchemaChanger) deleteIndexMutationsWithReversedColumns(
// Reverse mutation. Update columns to reflect additional
// columns that have been purged. This mutation doesn't need
// a rollback because it was not started.
mutation, columns = sc.reverseMutation(mutation, true /*notStarted*/, columns)
mutation = sc.reverseMutation(mutation, true /*notStarted*/, columnIDs)
// Mark as complete because this mutation needs no backfill.
if err := desc.MakeMutationComplete(mutation); err != nil {
return nil, err
Expand All @@ -1347,14 +1379,14 @@ func (sc *SchemaChanger) deleteIndexMutationsWithReversedColumns(
// notStarted is set to true only if the schema change state machine
// was not started for the mutation.
func (sc *SchemaChanger) reverseMutation(
mutation sqlbase.DescriptorMutation, notStarted bool, columns map[string]struct{},
) (sqlbase.DescriptorMutation, map[string]struct{}) {
mutation sqlbase.DescriptorMutation, notStarted bool, columnIDs map[sqlbase.ColumnID]struct{},
) sqlbase.DescriptorMutation {
switch mutation.Direction {
case sqlbase.DescriptorMutation_ADD:
mutation.Direction = sqlbase.DescriptorMutation_DROP
// A column ADD being reversed gets placed in the map.
if col := mutation.GetColumn(); col != nil {
columns[col.Name] = struct{}{}
columnIDs[col.ID] = struct{}{}
}
if notStarted && mutation.State != sqlbase.DescriptorMutation_DELETE_ONLY {
panic(fmt.Sprintf("mutation in bad state: %+v", mutation))
Expand All @@ -1366,7 +1398,7 @@ func (sc *SchemaChanger) reverseMutation(
panic(fmt.Sprintf("mutation in bad state: %+v", mutation))
}
}
return mutation, columns
return mutation
}

// TestingSchemaChangerCollection is an exported (for testing) version of
Expand Down
16 changes: 10 additions & 6 deletions pkg/sql/scrub.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,16 @@ func createConstraintCheckOperations(
for _, constraint := range constraints {
switch constraint.Kind {
case sqlbase.ConstraintTypeCheck:
results = append(results, newSQLCheckConstraintCheckOperation(
tableName,
tableDesc,
constraint.CheckConstraint,
asOf,
))
if active, err := constraint.CheckConstraint.IsActive(tableDesc); err != nil {
return nil, err
} else if active {
results = append(results, newSQLCheckConstraintCheckOperation(
tableName,
tableDesc,
constraint.CheckConstraint,
asOf,
))
}
case sqlbase.ConstraintTypeFK:
results = append(results, newSQLForeignKeyCheckOperation(
tableName,
Expand Down
12 changes: 8 additions & 4 deletions pkg/sql/sqlbase/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ func (c *CheckHelper) Init(
*tn, ResultColumnsFromColDescs(tableDesc.Columns),
)

c.Exprs = make([]tree.TypedExpr, len(tableDesc.Checks))
exprStrings := make([]string, len(tableDesc.Checks))
for i, check := range tableDesc.Checks {
exprStrings[i] = check.Expr
var exprStrings []string
for _, check := range tableDesc.Checks {
if active, err := check.IsActive(tableDesc); err != nil {
return err
} else if active {
exprStrings = append(exprStrings, check.Expr)
}
}
c.Exprs = make([]tree.TypedExpr, len(exprStrings))
exprs, err := parser.ParseExprs(exprStrings)
if err != nil {
return err
Expand Down
29 changes: 23 additions & 6 deletions pkg/sql/sqlbase/structured.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,9 @@ func (desc *TableDescriptor) KeysPerRow(indexID IndexID) int {
return 1
}

// allNonDropColumns returns all the columns, including those being added
// AllNonDropColumns returns all the columns, including those being added
// in the mutations.
func (desc *TableDescriptor) allNonDropColumns() []ColumnDescriptor {
func (desc *TableDescriptor) AllNonDropColumns() []ColumnDescriptor {
cols := make([]ColumnDescriptor, 0, len(desc.Columns)+len(desc.Mutations))
cols = append(cols, desc.Columns...)
for _, m := range desc.Mutations {
Expand Down Expand Up @@ -1101,7 +1101,7 @@ func (desc *TableDescriptor) ValidateTable(st *cluster.Settings) error {

columnNames := make(map[string]ColumnID, len(desc.Columns))
columnIDs := make(map[ColumnID]string, len(desc.Columns))
for _, column := range desc.allNonDropColumns() {
for _, column := range desc.AllNonDropColumns() {
if err := validateName(column.Name, "column"); err != nil {
return err
}
Expand Down Expand Up @@ -1603,7 +1603,7 @@ func notIndexableError(cols []ColumnDescriptor, inverted bool) error {
func checkColumnsValidForIndex(tableDesc *TableDescriptor, indexColNames []string) error {
invalidColumns := make([]ColumnDescriptor, 0, len(indexColNames))
for _, indexCol := range indexColNames {
for _, col := range tableDesc.allNonDropColumns() {
for _, col := range tableDesc.AllNonDropColumns() {
if col.Name == indexCol {
if !columnTypeIsIndexable(col.Type) {
invalidColumns = append(invalidColumns, col)
Expand All @@ -1623,7 +1623,7 @@ func checkColumnsValidForInvertedIndex(tableDesc *TableDescriptor, indexColNames
}
invalidColumns := make([]ColumnDescriptor, 0, len(indexColNames))
for _, indexCol := range indexColNames {
for _, col := range tableDesc.allNonDropColumns() {
for _, col := range tableDesc.AllNonDropColumns() {
if col.Name == indexCol {
if !columnTypeIsInvertedIndexable(col.Type) {
invalidColumns = append(invalidColumns, col)
Expand Down Expand Up @@ -2387,6 +2387,23 @@ func (cc *TableDescriptor_CheckConstraint) UsesColumn(
return i < len(colsUsed) && colsUsed[i] == colID, nil
}

// IsActive returns whether the check constraint uses columns which are all
// active.
func (cc *TableDescriptor_CheckConstraint) IsActive(desc *TableDescriptor) (bool, error) {
cols, err := cc.ColumnsUsed(desc)
if err != nil {
return false, err
}

for _, col := range cols {
if _, err := desc.FindActiveColumnByID(col); err != nil {
return false, nil
}
}

return true, nil
}

// ForeignKeyReferenceActionValue allows the conversion between a
// tree.ReferenceAction and a ForeignKeyReference_Action.
var ForeignKeyReferenceActionValue = [...]ForeignKeyReference_Action{
Expand Down Expand Up @@ -2505,7 +2522,7 @@ func (desc *TableDescriptor) FindAllReferences() (map[ID]struct{}, error) {
return nil, err
}

for _, c := range desc.allNonDropColumns() {
for _, c := range desc.AllNonDropColumns() {
for _, id := range c.UsesSequenceIds {
refs[id] = struct{}{}
}
Expand Down

0 comments on commit 63ac8cc

Please sign in to comment.