diff --git a/go/vt/wrangler/traffic_switcher.go b/go/vt/wrangler/traffic_switcher.go
index e3f45283ee2..4ba9b119052 100644
--- a/go/vt/wrangler/traffic_switcher.go
+++ b/go/vt/wrangler/traffic_switcher.go
@@ -28,6 +28,7 @@ import (
 
 	"vitess.io/vitess/go/json2"
 	"vitess.io/vitess/go/sqlescape"
+	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/vt/binlog/binlogplayer"
 	"vitess.io/vitess/go/vt/concurrency"
 	"vitess.io/vitess/go/vt/discovery"
@@ -56,11 +57,31 @@ const (
 	renameTableTemplate = "_%.59s_old" // limit table name to 64 characters
 
 	sqlDeleteWorkflow = "delete from _vt.vreplication where db_name = %s and workflow = %s"
+
+	sqlGetMaxSequenceVal = "select max(%a) as maxval from %a.%a"
+	sqlInitSequenceTable = "insert into %a.%a (id, next_id, cache) values (0, %d, 1000) on duplicate key update next_id = if(next_id < %d, %d, next_id)"
 )
 
 // accessType specifies the type of access for a shard (allow/disallow writes).
 type accessType int
 
+// sequenceMetadata pairs a workflow table that has a sequence-backed
+// auto-increment column with the sequence table that backs it.
+type sequenceMetadata struct {
+	// Name of the backing sequence table (the using table's AutoIncrement.Sequence).
+	backingTableName string
+	// Unsharded keyspace whose vschema defines the backing table.
+	backingTableKeyspace string
+	// Database name for the backing table's keyspace ("vt_" + keyspace for now).
+	backingTableDBName string
+	// Name of the workflow table that draws values from the sequence.
+	usingTableName string
+	// Database name for the target keyspace ("vt_" + keyspace for now).
+	usingTableDBName string
+	// The using table's vschema definition, including its AutoIncrement column.
+	usingTableDefinition *vschemapb.Table
+}
+
 const (
 	allowWrites = accessType(iota)
 	disallowWrites
@@ -611,6 +632,12 @@ func (wr *Wrangler) SwitchWrites(ctx context.Context, targetKeyspace, workflowNa
 		ts.Logger().Errorf("createJournals failed: %v", err)
 		return 0, nil, err
 	}
+	// Initialize any target sequences before allowing new writes.
+	if err := ts.initializeTargetSequenceTables(ctx); err != nil {
+		werr := vterrors.Wrapf(err, "initializeTargetSequenceTables failed")
+		ts.Logger().Error(werr)
+		return 0, nil, werr
+	}
 	if err := sw.allowTargetWrites(ctx); err != nil {
 		ts.Logger().Errorf("allowTargetWrites failed: %v", err)
 		return 0, nil, err
@@ -1886,7 +1913,7 @@ func (ts *trafficSwitcher) isSequenceParticipating(ctx context.Context) (bool, e
 	if err != nil {
 		return false, err
 	}
-	if vschema == nil || vschema.Tables == nil {
+	if vschema == nil || vschema.Tables == nil || len(vschema.Tables) == 0 {
 		return false, nil
 	}
 	sequenceFound := false
@@ -1903,6 +1930,196 @@ func (ts *trafficSwitcher) isSequenceParticipating(ctx context.Context) (bool, e
 	return sequenceFound, nil
 }
 
+// initializeTargetSequenceTables makes sure that every sequence backing an
+// auto-increment column of a table in this workflow will next generate a
+// value greater than any value already stored in that table on the target
+// shards. It returns nil without doing anything when no workflow table uses
+// a sequence, and it leaves alone any sequence whose backing table is found
+// in the target keyspace itself.
+func (ts *trafficSwitcher) initializeTargetSequenceTables(ctx context.Context) error {
+	vschema, err := ts.TopoServer().GetVSchema(ctx, ts.targetKeyspace)
+	if err != nil {
+		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get vschema for target keyspace %s: %v",
+			ts.targetKeyspace, err)
+	}
+	if vschema == nil || len(vschema.Tables) == 0 { // Nothing to do
+		return nil
+	}
+
+	// Index the sequence metadata both by using table and by backing table.
+	// Several using tables may share one backing table, so the latter maps
+	// to a slice.
+	sequencesByUsingTable := make(map[string]*sequenceMetadata)
+	sequencesByBackingTable := make(map[string][]*sequenceMetadata)
+	for _, table := range ts.Tables() {
+		vs := vschema.Tables[table]
+		if vs == nil || vs.AutoIncrement == nil || vs.AutoIncrement.Sequence == "" {
+			continue
+		}
+		// TODO: get and set usingTableDBName properly to deal with db_name_overrides
+		sm := &sequenceMetadata{
+			usingTableName:       table,
+			usingTableDefinition: vs,
+			usingTableDBName:     "vt_" + ts.targetKeyspace,
+			backingTableName:     vs.AutoIncrement.Sequence,
+		}
+		sequencesByUsingTable[table] = sm
+		sequencesByBackingTable[sm.backingTableName] = append(sequencesByBackingTable[sm.backingTableName], sm)
+	}
+	if len(sequencesByUsingTable) == 0 { // Nothing to do
+		return nil
+	}
+
+	// Locate the backing sequence tables, which are defined in the vschema
+	// of an unsharded keyspace.
+	keyspaces, err := ts.TopoServer().GetKeyspaces(ctx)
+	if err != nil {
+		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get keyspaces: %v", err)
+	}
+	for _, keyspace := range keyspaces {
+		ksVSchema, err := ts.TopoServer().GetVSchema(ctx, keyspace)
+		if err != nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get vschema for keyspace %s: %v",
+				keyspace, err)
+		}
+		if ksVSchema == nil || ksVSchema.Sharded || len(ksVSchema.Tables) == 0 {
+			continue
+		}
+		for tableName, tableDef := range ksVSchema.Tables {
+			if tableDef == nil || tableDef.Type != vindexes.TypeSequence {
+				continue
+			}
+			users, ok := sequencesByBackingTable[tableName]
+			if !ok {
+				continue
+			}
+			// A backing table found in the target keyspace is not
+			// initialized, so forget it along with every table using it.
+			// The using-table map is keyed by the using table's name, not
+			// by the backing table's name.
+			if keyspace == ts.targetKeyspace {
+				for _, sm := range users {
+					delete(sequencesByUsingTable, sm.usingTableName)
+				}
+				delete(sequencesByBackingTable, tableName)
+				continue
+			}
+			for _, sm := range users {
+				sm.backingTableKeyspace = keyspace
+				// TODO: get and set this properly in order to deal with db_name_overrides
+				sm.backingTableDBName = "vt_" + keyspace
+			}
+		}
+	}
+	// Every remaining using table must have had its backing table located.
+	for _, sm := range sequencesByUsingTable {
+		if sm.backingTableKeyspace == "" {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to locate the backing sequence table %s used by table %s",
+				sm.backingTableName, sm.usingTableName)
+		}
+	}
+
+	// Bump each backing sequence table so that the next value it generates
+	// is greater than the largest value in the using table across all of
+	// the target shards. The upsert only ever raises next_id, so running it
+	// once per using table is safe when a backing table is shared.
+	backingTablesByKeyspace := make(map[string]map[string]struct{})
+	for _, sm := range sequencesByUsingTable {
+		var (
+			maxMu sync.Mutex
+			maxID int64
+		)
+		err := ts.ForAllTargets(func(target *workflow.MigrationTarget) error {
+			query := sqlparser.BuildParsedQuery(sqlGetMaxSequenceVal,
+				sqlescape.EscapeID(sm.usingTableDefinition.AutoIncrement.Column),
+				sqlescape.EscapeID(sm.usingTableDBName),
+				sqlescape.EscapeID(sm.usingTableName),
+			)
+			qr, err := ts.wr.ExecuteFetchAsApp(ctx, target.GetPrimary().GetAlias(), true, query.Query, 1)
+			if err != nil {
+				return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get max used value for target table %s in order to initialize the backing sequence table %s: %v",
+					sm.usingTableName, sm.backingTableName, err)
+			}
+			if qr == nil || len(qr.Rows) != 1 {
+				return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected result getting max used value for target table %s on shard %s",
+					sm.usingTableName, target.GetShard().ShardName())
+			}
+			val := sqltypes.Proto3ToResult(qr).Rows[0][0]
+			if val.IsNull() { // max() is NULL for an empty table; 0 remains the max
+				return nil
+			}
+			shardMax, err := val.ToInt64()
+			if err != nil {
+				return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get max used value for target table %s in order to initialize the backing sequence table %s: %v",
+					sm.usingTableName, sm.backingTableName, err)
+			}
+			maxMu.Lock()
+			defer maxMu.Unlock()
+			if shardMax > maxID {
+				maxID = shardMax
+			}
+			return nil
+		})
+		if err != nil {
+			return err
+		}
+		nextVal := maxID + 1
+
+		query := sqlparser.BuildParsedQuery(sqlInitSequenceTable,
+			sqlescape.EscapeID(sm.backingTableDBName),
+			sqlescape.EscapeID(sm.backingTableName),
+			nextVal,
+			nextVal,
+			nextVal,
+		)
+		// Execute this on the primary tablet of the keyspace housing the
+		// backing table.
+		sequenceShard, err := ts.TopoServer().GetOnlyShard(ctx, sm.backingTableKeyspace)
+		if err != nil || sequenceShard == nil || sequenceShard.PrimaryAlias == nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get the primary tablet for keyspace %s: %v",
+				sm.backingTableKeyspace, err)
+		}
+		if _, err := ts.wr.ExecuteFetchAsApp(ctx, sequenceShard.PrimaryAlias, true, query.Query, 1); err != nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to initialize the sequence table %s.%s: %v",
+				sm.backingTableDBName, sm.backingTableName, err)
+		}
+		if backingTablesByKeyspace[sm.backingTableKeyspace] == nil {
+			backingTablesByKeyspace[sm.backingTableKeyspace] = make(map[string]struct{})
+		}
+		backingTablesByKeyspace[sm.backingTableKeyspace][sm.backingTableName] = struct{}{}
+	}
+
+	// Now force the primary tablets managing the sequences to refresh their
+	// caches. The reset is keyed by the backing sequence table names, which
+	// are the tables that exist on those tablets.
+	for keyspace, backingTables := range backingTablesByKeyspace {
+		si, err := ts.TopoServer().GetOnlyShard(ctx, keyspace)
+		if err != nil || si == nil || si.PrimaryAlias == nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get shard for keyspace %s: %v",
+				keyspace, err)
+		}
+		ts.Logger().Infof("Resetting sequence caches for shard %s.%s on tablet %s",
+			si.Keyspace(), si.ShardName(), si.PrimaryAlias)
+		ti, err := ts.TopoServer().GetTablet(ctx, si.PrimaryAlias)
+		if err != nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get primary tablet for keyspace %s: %v",
+				keyspace, err)
+		}
+		tableNames := make([]string, 0, len(backingTables))
+		for name := range backingTables {
+			tableNames = append(tableNames, name)
+		}
+		sort.Strings(tableNames)
+		if err := ts.TabletManagerClient().ResetSequences(ctx, ti.Tablet, tableNames); err != nil {
+			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to reset sequence caches for shard %s.%s on tablet %s: %v",
+				si.Keyspace(), si.ShardName(), si.PrimaryAlias, err)
+		}
+	}
+
+	// We completed the work w/o errors.
+	return nil
+}
+
 func (ts *trafficSwitcher) mustResetSequences(ctx context.Context) (bool, error) {
 	switch ts.workflowType {
 	case binlogdatapb.VReplicationWorkflowType_Migrate,