diff --git a/pkg/sql/backfill.go b/pkg/sql/backfill.go index f833c61229d1..38994cd92d5e 100644 --- a/pkg/sql/backfill.go +++ b/pkg/sql/backfill.go @@ -248,31 +248,39 @@ func (sc *SchemaChanger) AddConstraints( fksByBackrefTable[c.ForeignKey.Table] = append(fksByBackrefTable[c.ForeignKey.Table], c) } } + tableIDsToUpdate := make([]sqlbase.ID, 0, len(fksByBackrefTable)+1) + tableIDsToUpdate = append(tableIDsToUpdate, sc.tableID) + for id := range fksByBackrefTable { + tableIDsToUpdate = append(tableIDsToUpdate, id) + } - // Create map of update closures for the table and all other tables with backreferences - updates := make(map[sqlbase.ID]func(descriptor *sqlbase.MutableTableDescriptor) error) - updates[sc.tableID] = func(desc *sqlbase.MutableTableDescriptor) error { + // Create update closure for the table and all other tables with backreferences + update := func(descs map[sqlbase.ID]*sqlbase.MutableTableDescriptor) error { + scTable, ok := descs[sc.tableID] + if !ok { + return errors.AssertionFailedf("required table with ID %d not provided to update closure", sc.tableID) + } for i := range constraints { - added := &constraints[i] - switch added.ConstraintType { + constraint := &constraints[i] + switch constraint.ConstraintType { case sqlbase.ConstraintToUpdate_CHECK, sqlbase.ConstraintToUpdate_NOT_NULL: found := false - for _, c := range desc.Checks { - if c.Name == added.Name { + for _, c := range scTable.Checks { + if c.Name == constraint.Name { log.VEventf( ctx, 2, "backfiller tried to add constraint %+v but found existing constraint %+v, presumably due to a retry", - added, c, + constraint, c, ) found = true break } } if !found { - desc.Checks = append(desc.Checks, &constraints[i].Check) + scTable.Checks = append(scTable.Checks, &constraints[i].Check) } case sqlbase.ConstraintToUpdate_FOREIGN_KEY: - idx, err := desc.FindIndexByID(added.ForeignKeyIndex) + idx, err := scTable.FindIndexByID(constraint.ForeignKeyIndex) if err != nil { return err } @@ -281,39 +289,29 @@ func (sc *SchemaChanger) AddConstraints( log.VEventf( ctx, 2, "backfiller tried to add constraint %+v but found existing constraint %+v, presumably due to a retry", - added, idx.ForeignKey, + constraint, idx.ForeignKey, ) } } else { - idx.ForeignKey = added.ForeignKey - // If there are any backreferences to be added to the same table, add them here - if added.ForeignKey.Table == sc.tableID { - backref := sqlbase.ForeignKeyReference{Table: sc.tableID, Index: added.ForeignKeyIndex} - idx, err := desc.FindIndexByID(added.ForeignKey.Index) - if err != nil { - return err - } - idx.ReferencedBy = append(idx.ReferencedBy, backref) + idx.ForeignKey = constraint.ForeignKey + // Add backreference on the referenced table (which could be the same table) + backref := sqlbase.ForeignKeyReference{Table: sc.tableID, Index: constraint.ForeignKeyIndex} + backrefTable, ok := descs[constraint.ForeignKey.Table] + if !ok { + return errors.AssertionFailedf("required table with ID %d not provided to update closure", sc.tableID) + } + backrefIdx, err := backrefTable.FindIndexByID(constraint.ForeignKey.Index) + if err != nil { + return err } + backrefIdx.ReferencedBy = append(backrefIdx.ReferencedBy, backref) } } } return nil } - for id := range fksByBackrefTable { - updates[id] = func(desc *sqlbase.MutableTableDescriptor) error { - for _, c := range fksByBackrefTable[id] { - backref := sqlbase.ForeignKeyReference{Table: sc.tableID, Index: c.ForeignKeyIndex} - idx, err := desc.FindIndexByID(c.ForeignKey.Index) - if err != nil { - return err - } - idx.ReferencedBy = append(idx.ReferencedBy, backref) - } - return nil - } - } - if _, err := sc.leaseMgr.PublishMultiple(ctx, updates, nil); err != nil { + + if _, err := sc.leaseMgr.PublishMultiple(ctx, tableIDsToUpdate, update, nil); err != nil { return err } if err := sc.waitToUpdateLeases(ctx, sc.tableID); err != nil { @@ -342,7 +340,7 @@ func (sc *SchemaChanger) validateConstraints( return err } - if fn := sc.testingKnobs.RunBeforeChecksValidation; fn != nil { + if fn := sc.testingKnobs.RunBeforeConstraintValidation; fn != nil { if err := fn(); err != nil { return err } diff --git a/pkg/sql/lease.go b/pkg/sql/lease.go index 15bf0f6dfb0d..20e8226e56c1 100644 --- a/pkg/sql/lease.go +++ b/pkg/sql/lease.go @@ -339,16 +339,18 @@ var errDidntUpdateDescriptor = errors.New("didn't update the table descriptor") // time by first waiting for all nodes to be on the current (pre-update) version // of the table desc. // -// The update closure for each table ID is called after the wait, and it -// provides the new version of the descriptor to be written. +// The update closure for all tables is called after the wait. The argument to +// the closure is a map of the table descriptors with the IDs given in tableIDs, +// and the closure mutates those descriptors. // -// The closures may be called multiple times if retries occur; make sure they do +// The closure may be called multiple times if retries occur; make sure it does // not have side effects. // // Returns the updated versions of the descriptors. func (s LeaseStore) PublishMultiple( ctx context.Context, - updates map[sqlbase.ID]func(descriptor *sqlbase.MutableTableDescriptor) error, + tableIDs []sqlbase.ID, + update func(map[sqlbase.ID]*sqlbase.MutableTableDescriptor) error, logEvent func(*client.Txn) error, ) (map[sqlbase.ID]*sqlbase.ImmutableTableDescriptor, error) { errLeaseVersionChanged := errors.New("lease version changed") @@ -357,7 +359,7 @@ func (s LeaseStore) PublishMultiple( // Wait until there are no unexpired leases on the previous versions // of the tables. expectedVersions := make(map[sqlbase.ID]sqlbase.DescriptorVersion) - for id := range updates { + for _, id := range tableIDs { expected, err := s.WaitForOneVersion(ctx, id, base.DefaultRetryOptions()) if err != nil { return nil, err @@ -369,45 +371,50 @@ func (s LeaseStore) PublishMultiple( // There should be only one version of the descriptor, but it's // a race now to update to the next version. err := s.db.Txn(ctx, func(ctx context.Context, txn *client.Txn) error { - for id, update := range updates { + versions := make(map[sqlbase.ID]sqlbase.DescriptorVersion) + descsToUpdate := make(map[sqlbase.ID]*sqlbase.MutableTableDescriptor) + for _, id := range tableIDs { // Re-read the current versions of the table descriptor, this time // transactionally. var err error - tableDesc, err := sqlbase.GetMutableTableDescFromID(ctx, txn, id) + descsToUpdate[id], err = sqlbase.GetMutableTableDescFromID(ctx, txn, id) if err != nil { return err } - if expectedVersions[id] != tableDesc.Version { + if expectedVersions[id] != descsToUpdate[id].Version { // The version changed out from under us. Someone else must be // performing a schema change operation. if log.V(3) { - log.Infof(ctx, "publish (version changed): %d != %d", expectedVersions[id], tableDesc.Version) + log.Infof(ctx, "publish (version changed): %d != %d", expectedVersions[id], descsToUpdate[id].Version) } return errLeaseVersionChanged } - // Run the update closure. - version := tableDesc.Version - if err := update(tableDesc); err != nil { - return err - } - if version != tableDesc.Version { + versions[id] = descsToUpdate[id].Version + } + + // Run the update closure. + if err := update(descsToUpdate); err != nil { + return err + } + for _, id := range tableIDs { + if versions[id] != descsToUpdate[id].Version { return errors.Errorf("updated version to: %d, expected: %d", - tableDesc.Version, version) + descsToUpdate[id].Version, versions[id]) } - if err := tableDesc.MaybeIncrementVersion(ctx, txn); err != nil { + if err := descsToUpdate[id].MaybeIncrementVersion(ctx, txn); err != nil { return err } - if err := tableDesc.ValidateTable(s.settings); err != nil { + if err := descsToUpdate[id].ValidateTable(s.settings); err != nil { return err } - tableDescs[id] = tableDesc + tableDescs[id] = descsToUpdate[id] } - // Write the updated descriptor. + // Write the updated descriptors. if err := txn.SetSystemConfigTrigger(); err != nil { return err } @@ -470,10 +477,16 @@ func (s LeaseStore) Publish( update func(*sqlbase.MutableTableDescriptor) error, logEvent func(*client.Txn) error, ) (*sqlbase.ImmutableTableDescriptor, error) { - updates := make(map[sqlbase.ID]func(descriptor *sqlbase.MutableTableDescriptor) error) - updates[tableID] = update + tableIDs := []sqlbase.ID{tableID} + updates := func(descs map[sqlbase.ID]*sqlbase.MutableTableDescriptor) error { + desc, ok := descs[tableID] + if !ok { + return errors.AssertionFailedf("required table with ID %d not provided to update closure", tableID) + } + return update(desc) + } - results, err := s.PublishMultiple(ctx, updates, logEvent) + results, err := s.PublishMultiple(ctx, tableIDs, updates, logEvent) if err != nil { return nil, err } diff --git a/pkg/sql/schema_changer.go b/pkg/sql/schema_changer.go index 999a734fa91d..869263681257 100644 --- a/pkg/sql/schema_changer.go +++ b/pkg/sql/schema_changer.go @@ -1277,21 +1277,25 @@ func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError erro if err != nil { return err } + tableIDsToUpdate := make([]sqlbase.ID, 0, len(fksByBackrefTable)+1) + tableIDsToUpdate = append(tableIDsToUpdate, sc.tableID) + for id := range fksByBackrefTable { + tableIDsToUpdate = append(tableIDsToUpdate, id) + } - // Create map of update closures for the table and all other tables with backreferences - updates := make(map[sqlbase.ID]func(descriptor *sqlbase.MutableTableDescriptor) error) - // All the mutations dropped by the reversal of the schema change. - // This is created by traversing the mutations list like a graph - // where the indexes refer columns. Whenever a column schema change - // is reversed, any index mutation referencing it is also reversed. + // Create update closure for the table and all other tables with backreferences var droppedMutations map[sqlbase.MutationID]struct{} - updates[sc.tableID] = func(desc *sqlbase.MutableTableDescriptor) error { + update := func(descs map[sqlbase.ID]*sqlbase.MutableTableDescriptor) error { + scDesc, ok := descs[sc.tableID] + if !ok { + return errors.AssertionFailedf("required table with ID %d not provided to update closure", sc.tableID) + } // Keep track of the column mutations being reversed so that indexes // referencing them can be dropped. columns := make(map[string]struct{}) droppedMutations = nil - for i, mutation := range desc.Mutations { + for i, mutation := range scDesc.Mutations { if mutation.MutationID != sc.mutationID { // Only reverse the first set of mutations if they have the // mutation ID we're looking for. @@ -1308,35 +1312,36 @@ func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError erro } log.Warningf(ctx, "reverse schema change mutation: %+v", mutation) - desc.Mutations[i], columns = sc.reverseMutation(mutation, false /*notStarted*/, columns) + scDesc.Mutations[i], columns = sc.reverseMutation(mutation, false /*notStarted*/, columns) // If the mutation is for validating a constraint that is being added, // drop the constraint because validation has failed if constraint := mutation.GetConstraint(); constraint != nil && mutation.Direction == sqlbase.DescriptorMutation_ADD { log.Warningf(ctx, "dropping constraint %+v", constraint) - if err := sc.maybeDropValidatingConstraint(ctx, desc, constraint); err != nil { + if err := sc.maybeDropValidatingConstraint(ctx, scDesc, constraint); err != nil { return err } // Get the foreign key backreferences to remove, and remove them immediately if they're on the same table if constraint.ConstraintType == sqlbase.ConstraintToUpdate_FOREIGN_KEY { fk := &constraint.ForeignKey - if fk.Table == desc.ID { - if err := removeFKBackReferenceFromTable(desc, fk.Index, desc.ID, constraint.ForeignKeyIndex); err != nil { - return err - } + backrefTable, ok := descs[fk.Table] + if !ok { + return errors.AssertionFailedf("required table with ID %d not provided to update closure", sc.tableID) + } + if err := removeFKBackReferenceFromTable(backrefTable, fk.Index, sc.tableID, constraint.ForeignKeyIndex); err != nil { + return err } } } - - desc.Mutations[i].Rollback = true + scDesc.Mutations[i].Rollback = true } // Delete all mutations that reference any of the reversed columns // by running a graph traversal of the mutations. if len(columns) > 0 { var err error - droppedMutations, err = sc.deleteIndexMutationsWithReversedColumns(ctx, desc, columns) + droppedMutations, err = sc.deleteIndexMutationsWithReversedColumns(ctx, scDesc, columns) if err != nil { return err } @@ -1345,18 +1350,8 @@ func (sc *SchemaChanger) reverseMutations(ctx context.Context, causingError erro // PublishMultiple() will increment the version. return nil } - for id := range fksByBackrefTable { - updates[id] = func(desc *sqlbase.MutableTableDescriptor) error { - for _, c := range fksByBackrefTable[id] { - if err := removeFKBackReferenceFromTable(desc, c.ForeignKey.Index, sc.tableID, c.ForeignKeyIndex); err != nil { - return err - } - } - return nil - } - } - _, err = sc.leaseMgr.PublishMultiple(ctx, updates, func(txn *client.Txn) error { + _, err = sc.leaseMgr.PublishMultiple(ctx, tableIDsToUpdate, update, func(txn *client.Txn) error { // Read the table descriptor from the store. The Version of the // descriptor has already been incremented in the transaction and // this descriptor can be modified without incrementing the version. @@ -1671,9 +1666,9 @@ type SchemaChangerTestingKnobs struct { // after setting the job status to validating. RunBeforeIndexValidation func() error - // RunBeforeChecksValidation is called just before starting the checks validation, + // RunBeforeConstraintValidation is called just before starting the checks validation, // after setting the job status to validating. - RunBeforeChecksValidation func() error + RunBeforeConstraintValidation func() error // OldNamesDrainedNotification is called during a schema change, // after all leases on the version of the descriptor with the old diff --git a/pkg/sql/schema_changer_test.go b/pkg/sql/schema_changer_test.go index 8cca477c2bf2..3a37ad0ff0f2 100644 --- a/pkg/sql/schema_changer_test.go +++ b/pkg/sql/schema_changer_test.go @@ -4826,11 +4826,11 @@ CREATE TABLE t.test (k INT PRIMARY KEY, v INT); func TestSchemaChangeJobRunningStatusValidation(t *testing.T) { defer leaktest.AfterTest(t)() params, _ := tests.CreateTestServerParams() - var runBeforeChecksValidation func() error + var runBeforeConstraintValidation func() error params.Knobs = base.TestingKnobs{ SQLSchemaChanger: &sql.SchemaChangerTestingKnobs{ - RunBeforeChecksValidation: func() error { - return runBeforeChecksValidation() + RunBeforeConstraintValidation: func() error { + return runBeforeConstraintValidation() }, }, // Disable backfill migrations, we still need the jobs table migration. @@ -4851,7 +4851,7 @@ INSERT INTO t.test (k, v) VALUES (1, 99), (2, 100); tableDesc := sqlbase.GetTableDescriptor(kvDB, "t", "test") sqlRun := sqlutils.MakeSQLRunner(sqlDB) - runBeforeChecksValidation = func() error { + runBeforeConstraintValidation = func() error { return jobutils.VerifyRunningSystemJob(t, sqlRun, 0, jobspb.TypeSchemaChange, sql.RunningStatusValidation, jobs.Record{ Username: security.RootUser, Description: "ALTER TABLE t.public.test ADD COLUMN a INT8 AS (v - 1) STORED, ADD CHECK ((a < v) AND (a IS NOT NULL))", @@ -4867,3 +4867,59 @@ INSERT INTO t.test (k, v) VALUES (1, 99), (2, 100); t.Fatal(err) } } + +// TestFKReferencesAddedOnlyOnceOnRetry verifies that if ALTER TABLE ADD FOREIGN +// KEY is retried, both the FK reference and backreference (on another table) +// are only added once. This is addressed by #38377. +func TestFKReferencesAddedOnlyOnceOnRetry(t *testing.T) { + defer leaktest.AfterTest(t)() + params, _ := tests.CreateTestServerParams() + var runBeforeConstraintValidation func() error + errorReturned := false + params.Knobs = base.TestingKnobs{ + SQLSchemaChanger: &sql.SchemaChangerTestingKnobs{ + RunBeforeConstraintValidation: func() error { + return runBeforeConstraintValidation() + }, + }, + // Disable backfill migrations, we still need the jobs table migration. + SQLMigrationManager: &sqlmigrations.MigrationManagerTestingKnobs{ + DisableBackfillMigrations: true, + }, + } + s, sqlDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(context.TODO()) + if _, err := sqlDB.Exec(` +CREATE DATABASE t; +CREATE TABLE t.test (k INT PRIMARY KEY, v INT); +CREATE TABLE t.test2 (k INT, INDEX (k)); +`); err != nil { + t.Fatal(err) + } + + // After FK forward references and backreferences are installed, and before + // the validation query is run, return an error so that the schema change + // has to be retried. The error is only returned on the first try. + runBeforeConstraintValidation = func() error { + if !errorReturned { + errorReturned = true + return context.DeadlineExceeded + + } + return nil + } + if _, err := sqlDB.Exec(` +ALTER TABLE t.test2 ADD FOREIGN KEY (k) REFERENCES t.test; +`); err != nil { + t.Fatal(err) + } + + // Table descriptor validation failures, resulting from, e.g., broken or + // duplicated backreferences, are returned by SHOW CONSTRAINTS. + if _, err := sqlDB.Query(`SHOW CONSTRAINTS FROM t.test`); err != nil { + t.Fatal(err) + } + if _, err := sqlDB.Query(`SHOW CONSTRAINTS FROM t.test2`); err != nil { + t.Fatal(err) + } +}