diff --git a/br/pkg/lightning/backend/backend.go b/br/pkg/lightning/backend/backend.go
index f8b3e79132aa9..a08ba27146eb0 100644
--- a/br/pkg/lightning/backend/backend.go
+++ b/br/pkg/lightning/backend/backend.go
@@ -125,16 +125,43 @@ type CheckCtx struct {
 	DBMetas []*mydump.MDDatabaseMeta
 }
 
+// TargetInfoGetter defines the interfaces to get target information.
+type TargetInfoGetter interface {
+	// FetchRemoteTableModels obtains the models of all tables given the schema
+	// name. The returned table info does not need to be precise if the encoder
+	// does not require them, but must at least fill in the following fields for
+	// TablesFromMeta to succeed:
+	// - Name
+	// - State (must be model.StatePublic)
+	// - ID
+	// - Columns
+	//  * Name
+	//  * State (must be model.StatePublic)
+	//  * Offset (must be 0, 1, 2, ...)
+	// - PKIsHandle (true = do not generate _tidb_rowid)
+	FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error)
+
+	// CheckRequirements performs the check whether the backend satisfies the version requirements
+	CheckRequirements(ctx context.Context, checkCtx *CheckCtx) error
+}
+
+// EncodingBuilder consists of operations to handle encoding source data into backend row formats.
+type EncodingBuilder interface {
+	// NewEncoder creates an encoder of a TiDB table.
+	NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error)
+	// MakeEmptyRows creates an empty collection of encoded rows.
+	MakeEmptyRows() kv.Rows
+}
+
 // AbstractBackend is the abstract interface behind Backend.
 // Implementations of this interface must be goroutine safe: you can share an
 // instance and execute any method anywhere.
 type AbstractBackend interface {
+	EncodingBuilder
+	TargetInfoGetter
 	// Close the connection to the backend.
 	Close()
 
-	// MakeEmptyRows creates an empty collection of encoded rows.
-	MakeEmptyRows() kv.Rows
-
 	// RetryImportDelay returns the duration to sleep when retrying an import
 	RetryImportDelay() time.Duration
 
@@ -142,9 +169,6 @@ type AbstractBackend interface {
 	// performed for this backend. Post-processing includes checksum and analyze.
 	ShouldPostProcess() bool
 
-	// NewEncoder creates an encoder of a TiDB table.
-	NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error)
-
 	OpenEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error
 
 	CloseEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error
@@ -156,24 +180,6 @@ type AbstractBackend interface {
 
 	CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error
 
-	// CheckRequirements performs the check whether the backend satisfies the
-	// version requirements
-	CheckRequirements(ctx context.Context, checkCtx *CheckCtx) error
-
-	// FetchRemoteTableModels obtains the models of all tables given the schema
-	// name. The returned table info does not need to be precise if the encoder,
-	// is not requiring them, but must at least fill in the following fields for
-	// TablesFromMeta to succeed:
-	// - Name
-	// - State (must be model.StatePublic)
-	// - ID
-	// - Columns
-	//  * Name
-	//  * State (must be model.StatePublic)
-	//  * Offset (must be 0, 1, 2, ...)
-	// - PKIsHandle (true = do not generate _tidb_rowid)
-	FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error)
-
 	// FlushEngine ensures all KV pairs written to an open engine has been
 	// synchronized, such that kill-9'ing Lightning afterwards and resuming from
 	// checkpoint can recover the exact same content.
diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go
index d5d9e9bea7b5c..56f4cd0987c62 100644
--- a/br/pkg/lightning/backend/local/local.go
+++ b/br/pkg/lightning/backend/local/local.go
@@ -201,6 +201,139 @@ type Range struct {
 	end []byte
 }
 
+type encodingBuilder struct {
+	metrics *metric.Metrics
+}
+
+// NewEncodingBuilder creates an EncodingBuilder with the local backend implementation.
+func NewEncodingBuilder(ctx context.Context) backend.EncodingBuilder {
+	result := new(encodingBuilder)
+	if m, ok := metric.FromContext(ctx); ok {
+		result.metrics = m
+	}
+	return result
+}
+
+// NewEncoder creates a KV encoder.
+// It implements the `backend.EncodingBuilder` interface.
+func (b *encodingBuilder) NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error) {
+	return kv.NewTableKVEncoder(tbl, options, b.metrics, log.FromContext(ctx))
+}
+
+// MakeEmptyRows creates an empty collection of KV rows.
+// It implements the `backend.EncodingBuilder` interface.
+func (b *encodingBuilder) MakeEmptyRows() kv.Rows {
+	return kv.MakeRowsFromKvPairs(nil)
+}
+
+type targetInfoGetter struct {
+	tls          *common.TLS
+	targetDBGlue glue.Glue
+	pdAddr       string
+}
+
+// NewTargetInfoGetter creates a TargetInfoGetter with the local backend implementation.
+func NewTargetInfoGetter(tls *common.TLS, g glue.Glue, pdAddr string) backend.TargetInfoGetter {
+	return &targetInfoGetter{
+		tls:          tls,
+		targetDBGlue: g,
+		pdAddr:       pdAddr,
+	}
+}
+
+// FetchRemoteTableModels obtains the models of all tables given the schema name.
+// It implements the `TargetInfoGetter` interface.
+func (g *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
+	return tikv.FetchRemoteTableModelsFromTLS(ctx, g.tls, schemaName)
+}
+
+// CheckRequirements performs the check whether the backend satisfies the version requirements.
+// It implements the `TargetInfoGetter` interface.
+func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *backend.CheckCtx) error { + // TODO: support lightning via SQL + db, _ := g.targetDBGlue.GetDB() + versionStr, err := version.FetchVersion(ctx, db) + if err != nil { + return errors.Trace(err) + } + if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { + return err + } + if err := tikv.CheckPDVersion(ctx, g.tls, g.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil { + return err + } + if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil { + return err + } + + serverInfo := version.ParseServerInfo(versionStr) + return checkTiFlashVersion(ctx, g.targetDBGlue, checkCtx, *serverInfo.ServerVersion) +} + +func checkTiDBVersion(_ context.Context, versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { + return version.CheckTiDBVersion(versionStr, requiredMinVersion, requiredMaxVersion) +} + +var tiFlashReplicaQuery = "SELECT TABLE_SCHEMA, TABLE_NAME FROM information_schema.TIFLASH_REPLICA WHERE REPLICA_COUNT > 0;" + +type tblName struct { + schema string + name string +} + +type tblNames []tblName + +func (t tblNames) String() string { + var b strings.Builder + b.WriteByte('[') + for i, n := range t { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(common.UniqueTable(n.schema, n.name)) + } + b.WriteByte(']') + return b.String() +} + +// check TiFlash replicas. +// local backend doesn't support TiFlash before tidb v4.0.5 +func checkTiFlashVersion(ctx context.Context, g glue.Glue, checkCtx *backend.CheckCtx, tidbVersion semver.Version) error { + if tidbVersion.Compare(tiFlashMinVersion) >= 0 { + return nil + } + + res, err := g.GetSQLExecutor().QueryStringsWithLog(ctx, tiFlashReplicaQuery, "fetch tiflash replica info", log.FromContext(ctx)) + if err != nil { + return errors.Annotate(err, "fetch tiflash replica info failed") + } + + tiFlashTablesMap := make(map[tblName]struct{}, len(res)) + for _, tblInfo := range res { + name := tblName{schema: tblInfo[0], name: tblInfo[1]} + tiFlashTablesMap[name] = struct{}{} + } + + tiFlashTables := make(tblNames, 0) + for _, dbMeta := range checkCtx.DBMetas { + for _, tblMeta := range dbMeta.Tables { + if len(tblMeta.DataFiles) == 0 { + continue + } + name := tblName{schema: tblMeta.DB, name: tblMeta.Name} + if _, ok := tiFlashTablesMap[name]; ok { + tiFlashTables = append(tiFlashTables, name) + } + } + } + + if len(tiFlashTables) > 0 { + helpInfo := "Please either upgrade TiDB to version >= 4.0.5 or add TiFlash replica after load data." + return errors.Errorf("lightning local backend doesn't support TiFlash in this TiDB version. conflict tables: %s. 
"+helpInfo, tiFlashTables) + } + return nil +} + type local struct { engines sync.Map // sync version of map[uuid.UUID]*Engine @@ -236,6 +369,9 @@ type local struct { metrics *metric.Metrics writeLimiter StoreWriteLimiter logger log.Logger + + encBuilder backend.EncodingBuilder + targetInfoGetter backend.TargetInfoGetter } func openDuplicateDB(storeDir string) (*pebble.DB, error) { @@ -344,6 +480,8 @@ func NewLocalBackend( bufferPool: membuf.NewPool(membuf.WithAllocator(manual.Allocator{})), writeLimiter: writeLimiter, logger: log.FromContext(ctx), + encBuilder: NewEncodingBuilder(ctx), + targetInfoGetter: NewTargetInfoGetter(tls, g, cfg.TiDB.PdAddr), } if m, ok := metric.FromContext(ctx); ok { local.metrics = m @@ -1652,100 +1790,19 @@ func (local *local) CleanupEngine(ctx context.Context, engineUUID uuid.UUID) err } func (local *local) CheckRequirements(ctx context.Context, checkCtx *backend.CheckCtx) error { - // TODO: support lightning via SQL - db, _ := local.g.GetDB() - versionStr, err := version.FetchVersion(ctx, db) - if err != nil { - return errors.Trace(err) - } - if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { - return err - } - if err := tikv.CheckPDVersion(ctx, local.tls, local.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil { - return err - } - if err := tikv.CheckTiKVVersion(ctx, local.tls, local.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil { - return err - } - - serverInfo := version.ParseServerInfo(versionStr) - return checkTiFlashVersion(ctx, local.g, checkCtx, *serverInfo.ServerVersion) -} - -func checkTiDBVersion(_ context.Context, versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { - return version.CheckTiDBVersion(versionStr, requiredMinVersion, requiredMaxVersion) -} - -var tiFlashReplicaQuery = "SELECT TABLE_SCHEMA, TABLE_NAME FROM information_schema.TIFLASH_REPLICA WHERE REPLICA_COUNT > 0;" - -type tblName struct { - schema string - name string -} - -type tblNames []tblName - -func (t tblNames) String() string { - var b strings.Builder - b.WriteByte('[') - for i, n := range t { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(common.UniqueTable(n.schema, n.name)) - } - b.WriteByte(']') - return b.String() -} - -// check TiFlash replicas. -// local backend doesn't support TiFlash before tidb v4.0.5 -func checkTiFlashVersion(ctx context.Context, g glue.Glue, checkCtx *backend.CheckCtx, tidbVersion semver.Version) error { - if tidbVersion.Compare(tiFlashMinVersion) >= 0 { - return nil - } - - res, err := g.GetSQLExecutor().QueryStringsWithLog(ctx, tiFlashReplicaQuery, "fetch tiflash replica info", log.FromContext(ctx)) - if err != nil { - return errors.Annotate(err, "fetch tiflash replica info failed") - } - - tiFlashTablesMap := make(map[tblName]struct{}, len(res)) - for _, tblInfo := range res { - name := tblName{schema: tblInfo[0], name: tblInfo[1]} - tiFlashTablesMap[name] = struct{}{} - } - - tiFlashTables := make(tblNames, 0) - for _, dbMeta := range checkCtx.DBMetas { - for _, tblMeta := range dbMeta.Tables { - if len(tblMeta.DataFiles) == 0 { - continue - } - name := tblName{schema: tblMeta.DB, name: tblMeta.Name} - if _, ok := tiFlashTablesMap[name]; ok { - tiFlashTables = append(tiFlashTables, name) - } - } - } - - if len(tiFlashTables) > 0 { - helpInfo := "Please either upgrade TiDB to version >= 4.0.5 or add TiFlash replica after load data." - return errors.Errorf("lightning local backend doesn't support TiFlash in this TiDB version. 
conflict tables: %s. "+helpInfo, tiFlashTables)
-	}
-	return nil
+	return local.targetInfoGetter.CheckRequirements(ctx, checkCtx)
 }
 
 func (local *local) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
-	return tikv.FetchRemoteTableModelsFromTLS(ctx, local.tls, schemaName)
+	return local.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName)
 }
 
 func (local *local) MakeEmptyRows() kv.Rows {
-	return kv.MakeRowsFromKvPairs(nil)
+	return local.encBuilder.MakeEmptyRows()
 }
 
 func (local *local) NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error) {
-	return kv.NewTableKVEncoder(tbl, options, local.metrics, log.FromContext(ctx))
+	return local.encBuilder.NewEncoder(ctx, tbl, options)
 }
 
 func engineSSTDir(storeDir string, engineUUID uuid.UUID) string {
diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go
index 1a9d100d39bd5..826a14bfeb4a9 100644
--- a/br/pkg/lightning/backend/tidb/tidb.go
+++ b/br/pkg/lightning/backend/tidb/tidb.go
@@ -90,10 +90,163 @@ type tidbEncoder struct {
 	columnCnt int
 }
 
+type encodingBuilder struct{}
+
+// NewEncodingBuilder creates an EncodingBuilder with the TiDB backend implementation.
+func NewEncodingBuilder() backend.EncodingBuilder {
+	return new(encodingBuilder)
+}
+
+// NewEncoder creates a KV encoder.
+// It implements the `backend.EncodingBuilder` interface.
+func (b *encodingBuilder) NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error) {
+	se := kv.NewSession(options, log.FromContext(ctx))
+	if options.SQLMode.HasStrictMode() {
+		se.GetSessionVars().SkipUTF8Check = false
+		se.GetSessionVars().SkipASCIICheck = false
+	}
+
+	return &tidbEncoder{mode: options.SQLMode, tbl: tbl, se: se}, nil
+}
+
+// MakeEmptyRows creates an empty collection of KV rows.
+// It implements the `backend.EncodingBuilder` interface.
+func (b *encodingBuilder) MakeEmptyRows() kv.Rows {
+	return tidbRows(nil)
+}
+
+type targetInfoGetter struct {
+	db *sql.DB
+}
+
+// NewTargetInfoGetter creates a TargetInfoGetter with the TiDB backend implementation.
+func NewTargetInfoGetter(db *sql.DB) backend.TargetInfoGetter {
+	return &targetInfoGetter{
+		db: db,
+	}
+}
+
+// FetchRemoteTableModels obtains the models of all tables given the schema name.
+// It implements the `backend.TargetInfoGetter` interface.
+// TODO: refactor
+func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
+	var err error
+	tables := []*model.TableInfo{}
+	s := common.SQLWithRetry{
+		DB:     b.db,
+		Logger: log.FromContext(ctx),
+	}
+
+	err = s.Transact(ctx, "fetch table columns", func(c context.Context, tx *sql.Tx) error {
+		var versionStr string
+		if versionStr, err = version.FetchVersion(ctx, tx); err != nil {
+			return err
+		}
+		serverInfo := version.ParseServerInfo(versionStr)
+
+		rows, e := tx.Query(`
+			SELECT table_name, column_name, column_type, generation_expression, extra
+			FROM information_schema.columns
+			WHERE table_schema = ?
+ ORDER BY table_name, ordinal_position; + `, schemaName) + if e != nil { + return e + } + defer rows.Close() + + var ( + curTableName string + curColOffset int + curTable *model.TableInfo + ) + for rows.Next() { + var tableName, columnName, columnType, generationExpr, columnExtra string + if e := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); e != nil { + return e + } + if tableName != curTableName { + curTable = &model.TableInfo{ + Name: model.NewCIStr(tableName), + State: model.StatePublic, + PKIsHandle: true, + } + tables = append(tables, curTable) + curTableName = tableName + curColOffset = 0 + } + + // see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191 + var flag uint + if strings.HasSuffix(columnType, "unsigned") { + flag |= mysql.UnsignedFlag + } + if strings.Contains(columnExtra, "auto_increment") { + flag |= mysql.AutoIncrementFlag + } + + ft := types.FieldType{} + ft.SetFlag(flag) + curTable.Columns = append(curTable.Columns, &model.ColumnInfo{ + Name: model.NewCIStr(columnName), + Offset: curColOffset, + State: model.StatePublic, + FieldType: ft, + GeneratedExprString: generationExpr, + }) + curColOffset++ + } + if err := rows.Err(); err != nil { + return err + } + // shard_row_id/auto random is only available after tidb v4.0.0 + // `show table next_row_id` is also not available before tidb v4.0.0 + if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 { + return nil + } + + // init auto id column for each table + for _, tbl := range tables { + tblName := common.UniqueTable(schemaName, tbl.Name.O) + autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName) + if err != nil { + return errors.Trace(err) + } + for _, info := range autoIDInfos { + for _, col := range tbl.Columns { + if col.Name.O == info.Column { + switch info.Type { + case "AUTO_INCREMENT": + col.AddFlag(mysql.AutoIncrementFlag) + case "AUTO_RANDOM": + col.AddFlag(mysql.PriKeyFlag) + tbl.PKIsHandle = true + // set a stub here, since we don't really need the real value + tbl.AutoRandomBits = 1 + } + } + } + } + + } + return nil + }) + return tables, err +} + +// CheckRequirements performs the check whether the backend satisfies the version requirements. +// It implements the `backend.TargetInfoGetter` interface. +func (b *targetInfoGetter) CheckRequirements(ctx context.Context, _ *backend.CheckCtx) error { + log.FromContext(ctx).Info("skipping check requirements for tidb backend") + return nil +} + type tidbBackend struct { - db *sql.DB - onDuplicate string - errorMgr *errormanager.ErrorManager + db *sql.DB + onDuplicate string + errorMgr *errormanager.ErrorManager + encBuilder backend.EncodingBuilder + targetInfoGetter backend.TargetInfoGetter } // NewTiDBBackend creates a new TiDB backend using the given database. 
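
// A minimal usage sketch, not part of the patch: with the encoding and
// target-info concerns extracted above, a caller can consume them without
// constructing a full backend.Backend. It assumes a *sql.DB `db`, a
// table.Table `tbl`, and a context `ctx` are already in scope; the SQL mode
// shown is only an example value.
//
//	encBuilder := tidb.NewEncodingBuilder()
//	infoGetter := tidb.NewTargetInfoGetter(db)
//	if _, err := infoGetter.FetchRemoteTableModels(ctx, "test"); err != nil {
//		return err
//	}
//	encoder, err := encBuilder.NewEncoder(ctx, tbl, &kv.SessionOptions{
//		SQLMode: mysql.ModeStrictAllTables,
//	})
//	if err != nil {
//		return err
//	}
//	rows := encBuilder.MakeEmptyRows() // empty tidbRows to collect encoded output
//	_, _ = encoder, rows
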
@@ -107,7 +260,13 @@ func NewTiDBBackend(ctx context.Context, db *sql.DB, onDuplicate string, errorMg log.FromContext(ctx).Warn("unsupported action on duplicate, overwrite with `replace`") onDuplicate = config.ReplaceOnDup } - return backend.MakeBackend(&tidbBackend{db: db, onDuplicate: onDuplicate, errorMgr: errorMgr}) + return backend.MakeBackend(&tidbBackend{ + db: db, + onDuplicate: onDuplicate, + errorMgr: errorMgr, + encBuilder: NewEncodingBuilder(), + targetInfoGetter: NewTargetInfoGetter(db), + }) } func (row tidbRow) Size() uint64 { @@ -375,7 +534,7 @@ func (be *tidbBackend) Close() { } func (be *tidbBackend) MakeEmptyRows() kv.Rows { - return tidbRows(nil) + return be.encBuilder.MakeEmptyRows() } func (be *tidbBackend) RetryImportDelay() time.Duration { @@ -394,18 +553,11 @@ func (be *tidbBackend) ShouldPostProcess() bool { } func (be *tidbBackend) CheckRequirements(ctx context.Context, _ *backend.CheckCtx) error { - log.FromContext(ctx).Info("skipping check requirements for tidb backend") - return nil + return be.targetInfoGetter.CheckRequirements(ctx, nil) } func (be *tidbBackend) NewEncoder(ctx context.Context, tbl table.Table, options *kv.SessionOptions) (kv.Encoder, error) { - se := kv.NewSession(options, log.FromContext(ctx)) - if options.SQLMode.HasStrictMode() { - se.GetSessionVars().SkipUTF8Check = false - se.GetSessionVars().SkipASCIICheck = false - } - - return &tidbEncoder{mode: options.SQLMode, tbl: tbl, se: se}, nil + return be.encBuilder.NewEncoder(ctx, tbl, options) } func (be *tidbBackend) OpenEngine(context.Context, *backend.EngineConfig, uuid.UUID) error { @@ -583,108 +735,8 @@ func (be *tidbBackend) execStmts(ctx context.Context, stmtTasks []stmtTask, tabl return nil } -//nolint:nakedret // TODO: refactor -func (be *tidbBackend) FetchRemoteTableModels(ctx context.Context, schemaName string) (tables []*model.TableInfo, err error) { - s := common.SQLWithRetry{ - DB: be.db, - Logger: log.FromContext(ctx), - } - - err = s.Transact(ctx, "fetch table columns", func(c context.Context, tx *sql.Tx) error { - var versionStr string - if versionStr, err = version.FetchVersion(ctx, tx); err != nil { - return err - } - serverInfo := version.ParseServerInfo(versionStr) - - rows, e := tx.Query(` - SELECT table_name, column_name, column_type, generation_expression, extra - FROM information_schema.columns - WHERE table_schema = ? 
-			ORDER BY table_name, ordinal_position;
-	`, schemaName)
-		if e != nil {
-			return e
-		}
-		defer rows.Close()
-
-		var (
-			curTableName string
-			curColOffset int
-			curTable     *model.TableInfo
-		)
-		for rows.Next() {
-			var tableName, columnName, columnType, generationExpr, columnExtra string
-			if e := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); e != nil {
-				return e
-			}
-			if tableName != curTableName {
-				curTable = &model.TableInfo{
-					Name:       model.NewCIStr(tableName),
-					State:      model.StatePublic,
-					PKIsHandle: true,
-				}
-				tables = append(tables, curTable)
-				curTableName = tableName
-				curColOffset = 0
-			}
-
-			// see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191
-			var flag uint
-			if strings.HasSuffix(columnType, "unsigned") {
-				flag |= mysql.UnsignedFlag
-			}
-			if strings.Contains(columnExtra, "auto_increment") {
-				flag |= mysql.AutoIncrementFlag
-			}
-
-			ft := types.FieldType{}
-			ft.SetFlag(flag)
-			curTable.Columns = append(curTable.Columns, &model.ColumnInfo{
-				Name:                model.NewCIStr(columnName),
-				Offset:              curColOffset,
-				State:               model.StatePublic,
-				FieldType:           ft,
-				GeneratedExprString: generationExpr,
-			})
-			curColOffset++
-		}
-		if err := rows.Err(); err != nil {
-			return err
-		}
-		// shard_row_id/auto random is only available after tidb v4.0.0
-		// `show table next_row_id` is also not available before tidb v4.0.0
-		if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 {
-			return nil
-		}
-
-		// init auto id column for each table
-		for _, tbl := range tables {
-			tblName := common.UniqueTable(schemaName, tbl.Name.O)
-			autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName)
-			if err != nil {
-				return errors.Trace(err)
-			}
-			for _, info := range autoIDInfos {
-				for _, col := range tbl.Columns {
-					if col.Name.O == info.Column {
-						switch info.Type {
-						case "AUTO_INCREMENT":
-							col.AddFlag(mysql.AutoIncrementFlag)
-						case "AUTO_RANDOM":
-							col.AddFlag(mysql.PriKeyFlag)
-							tbl.PKIsHandle = true
-							// set a stub here, since we don't really need the real value
-							tbl.AutoRandomBits = 1
-						}
-					}
-				}
-			}
-
-		}
-		return nil
-	})
-	return
+func (be *tidbBackend) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
+	return be.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName)
 }
 
 func (be *tidbBackend) EngineFileSizes() []backend.EngineFileSize {
diff --git a/br/pkg/lightning/mydump/loader.go b/br/pkg/lightning/mydump/loader.go
index 40091c61b2d03..a16ad88de76c2 100644
--- a/br/pkg/lightning/mydump/loader.go
+++ b/br/pkg/lightning/mydump/loader.go
@@ -38,6 +38,13 @@ type MDDatabaseMeta struct {
 	charSet string
 }
 
+// NewMDDatabaseMeta creates a Mydumper database meta with the specified character set.
+func NewMDDatabaseMeta(charSet string) *MDDatabaseMeta {
+	return &MDDatabaseMeta{
+		charSet: charSet,
+	}
+}
+
 func (m *MDDatabaseMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) string {
 	if m.SchemaFile.FileMeta.Path != "" {
 		schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet)
@@ -73,6 +80,13 @@ type SourceFileMeta struct {
 	FileSize int64
 }
 
+// NewMDTableMeta creates a Mydumper table meta with the specified character set.
+func NewMDTableMeta(charSet string) *MDTableMeta { + return &MDTableMeta{ + charSet: charSet, + } +} + func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) (string, error) { schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet) if err != nil { diff --git a/br/pkg/lightning/restore/check_info.go b/br/pkg/lightning/restore/check_info.go index 2be105a157fac..6927a48d09488 100644 --- a/br/pkg/lightning/restore/check_info.go +++ b/br/pkg/lightning/restore/check_info.go @@ -15,11 +15,8 @@ package restore import ( - "bytes" "context" - "database/sql" "fmt" - "io" "path/filepath" "reflect" "strconv" @@ -28,24 +25,18 @@ import ( "github.com/docker/go-units" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/br/pkg/lightning/backend" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/br/pkg/lightning/errormanager" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/lightning/mydump" - "github.com/pingcap/tidb/br/pkg/lightning/verification" "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/store/pdtypes" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/mathutil" "go.uber.org/zap" @@ -76,21 +67,20 @@ func (rc *Controller) isSourceInLocal() bool { } func (rc *Controller) getReplicaCount(ctx context.Context) (uint64, error) { - result := &pdtypes.ReplicationConfig{} - err := rc.tls.WithHost(rc.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result) + replConfig, err := rc.preInfoGetter.GetReplicationConfig(ctx) if err != nil { return 0, errors.Trace(err) } - return result.MaxReplicas, nil + return replConfig.MaxReplicas, nil } func (rc *Controller) getClusterAvail(ctx context.Context) (uint64, error) { - result := &pdtypes.StoresInfo{} - if err := rc.tls.WithHost(rc.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil { + storeInfo, err := rc.preInfoGetter.GetStorageInfo(ctx) + if err != nil { return 0, errors.Trace(err) } clusterAvail := uint64(0) - for _, store := range result.Stores { + for _, store := range storeInfo.Stores { clusterAvail += uint64(store.Status.Available) } return clusterAvail, nil @@ -171,10 +161,8 @@ func (rc *Controller) ClusterIsAvailable(ctx context.Context) { defer func() { rc.checkTemplate.Collect(Critical, passed, message) }() - checkCtx := &backend.CheckCtx{ - DBMetas: rc.dbMetas, - } - if err := rc.backend.CheckRequirements(ctx, checkCtx); err != nil { + checkCtx := WithPreInfoGetterDBMetas(ctx, rc.dbMetas) + if err := rc.preInfoGetter.CheckVersionRequirements(checkCtx); err != nil { err = common.NormalizeError(err) passed = false message = fmt.Sprintf("cluster available check failed: %s", err.Error()) @@ -196,22 +184,24 @@ func (rc *Controller) checkEmptyRegion(ctx context.Context) error { defer func() { rc.checkTemplate.Collect(Critical, passed, message) }() - storeInfo := &pdtypes.StoresInfo{} - err := 
rc.tls.WithHost(rc.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, storeInfo) + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) + if err != nil { + return errors.Trace(err) + } + storeInfo, err := rc.preInfoGetter.GetStorageInfo(ctx) if err != nil { return errors.Trace(err) } if len(storeInfo.Stores) <= 1 { return nil } - - var result pdtypes.RegionsInfo - if err := rc.tls.WithHost(rc.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil { + emptyRegionsInfo, err := rc.preInfoGetter.GetEmptyRegionsInfo(ctx) + if err != nil { return errors.Trace(err) } regions := make(map[uint64]int) stores := make(map[uint64]*pdtypes.StoreInfo) - for _, region := range result.Regions { + for _, region := range emptyRegionsInfo.Regions { for _, peer := range region.Peers { regions[peer.StoreId]++ } @@ -221,7 +211,7 @@ func (rc *Controller) checkEmptyRegion(ctx context.Context) error { } tableCount := 0 for _, db := range rc.dbMetas { - info, ok := rc.dbInfos[db.Name] + info, ok := dbInfos[db.Name] if !ok { continue } @@ -273,13 +263,12 @@ func (rc *Controller) checkRegionDistribution(ctx context.Context) error { rc.checkTemplate.Collect(Critical, passed, message) }() - result := &pdtypes.StoresInfo{} - err := rc.tls.WithHost(rc.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result) + storesInfo, err := rc.preInfoGetter.GetStorageInfo(ctx) if err != nil { return errors.Trace(err) } - stores := make([]*pdtypes.StoreInfo, 0, len(result.Stores)) - for _, store := range result.Stores { + stores := make([]*pdtypes.StoreInfo, 0, len(storesInfo.Stores)) + for _, store := range storesInfo.Stores { if metapb.StoreState(metapb.StoreState_value[store.Store.StateName]) != metapb.StoreState_Up { continue } @@ -296,9 +285,14 @@ func (rc *Controller) checkRegionDistribution(ctx context.Context) error { }) minStore := stores[0] maxStore := stores[len(stores)-1] + + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) + if err != nil { + return errors.Trace(err) + } tableCount := 0 for _, db := range rc.dbMetas { - info, ok := rc.dbInfos[db.Name] + info, ok := dbInfos[db.Name] if !ok { continue } @@ -396,62 +390,12 @@ func (rc *Controller) HasLargeCSV(dbMetas []*mydump.MDDatabaseMeta) { } } -func (rc *Controller) estimateSourceData(ctx context.Context) (int64, error) { - sourceSize := int64(0) - originSource := int64(0) - bigTableCount := 0 - tableCount := 0 - unSortedTableCount := 0 - errMgr := errormanager.New(nil, rc.cfg, log.FromContext(ctx)) - for _, db := range rc.dbMetas { - info, ok := rc.dbInfos[db.Name] - if !ok { - continue - } - for _, tbl := range db.Tables { - originSource += tbl.TotalSize - tableInfo, ok := info.Tables[tbl.Name] - if ok { - // Do not sample small table because there may a large number of small table and it will take a long - // time to sample data for all of them. 
- if rc.cfg.TikvImporter.Backend == config.BackendTiDB || tbl.TotalSize < int64(config.SplitRegionSize) { - sourceSize += tbl.TotalSize - tbl.IndexRatio = 1.0 - tbl.IsRowOrdered = false - } else { - if err := rc.sampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core, errMgr); err != nil { - return sourceSize, errors.Trace(err) - } - - if tbl.IndexRatio > 0 { - sourceSize += int64(float64(tbl.TotalSize) * tbl.IndexRatio) - } else { - // if sample data failed due to max-error, fallback to use source size - sourceSize += tbl.TotalSize - } - - if tbl.TotalSize > int64(config.DefaultBatchSize)*2 { - bigTableCount += 1 - if !tbl.IsRowOrdered { - unSortedTableCount += 1 - } - } - } - tableCount += 1 - } - } - } - - if rc.status != nil { - rc.status.TotalFileSize.Store(originSource) - } - // Do not import with too large concurrency because these data may be all unsorted. - if bigTableCount > 0 && unSortedTableCount > 0 { - if rc.cfg.App.TableConcurrency > rc.cfg.App.IndexConcurrency { - rc.cfg.App.TableConcurrency = rc.cfg.App.IndexConcurrency - } +func (rc *Controller) estimateSourceData(ctx context.Context) (int64, int64, bool, error) { + result, err := rc.preInfoGetter.EstimateSourceDataSize(ctx) + if err != nil { + return 0, 0, false, errors.Trace(err) } - return sourceSize, nil + return result.SizeWithIndex, result.SizeWithoutIndex, result.HasUnsortedBigTables, nil } // localResource checks the local node has enough resources for this import when local backend enabled; @@ -540,7 +484,8 @@ func (rc *Controller) CheckpointIsValid(ctx context.Context, tableInfo *mydump.M return msgs, false } - dbInfo, ok := rc.dbInfos[tableInfo.DB] + dbInfos, _ := rc.preInfoGetter.GetAllTableStructures(ctx) + dbInfo, ok := dbInfos[tableInfo.DB] if ok { t, ok := dbInfo.Tables[tableInfo.Name] if ok { @@ -573,7 +518,7 @@ func (rc *Controller) CheckpointIsValid(ctx context.Context, tableInfo *mydump.M log.FromContext(ctx).Debug("no valid checkpoint detected", zap.String("table", uniqueName)) return nil, false } - info := rc.dbInfos[tableInfo.DB].Tables[tableInfo.Name] + info := dbInfos[tableInfo.DB].Tables[tableInfo.Name] if info != nil { permFromTiDB, err := parseColumnPermutations(info.Core, columns, nil, log.FromContext(ctx)) if err != nil { @@ -595,48 +540,15 @@ func hasDefault(col *model.ColumnInfo) bool { } func (rc *Controller) readFirstRow(ctx context.Context, dataFileMeta mydump.SourceFileMeta) (cols []string, row []types.Datum, err error) { - var reader storage.ReadSeekCloser - if dataFileMeta.Type == mydump.SourceTypeParquet { - reader, err = mydump.OpenParquetReader(ctx, rc.store, dataFileMeta.Path, dataFileMeta.FileSize) - } else { - reader, err = rc.store.Open(ctx, dataFileMeta.Path) - } + cols, rows, err := rc.preInfoGetter.ReadFirstNRowsByFileMeta(ctx, dataFileMeta, 1) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, err } - - var parser mydump.Parser - blockBufSize := int64(rc.cfg.Mydumper.ReadBlockSize) - switch dataFileMeta.Type { - case mydump.SourceTypeCSV: - hasHeader := rc.cfg.Mydumper.CSV.Header - // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. 
- charsetConvertor, err := mydump.NewCharsetConvertor(rc.cfg.Mydumper.DataCharacterSet, rc.cfg.Mydumper.DataInvalidCharReplace) - if err != nil { - return nil, nil, errors.Trace(err) - } - parser, err = mydump.NewCSVParser(ctx, &rc.cfg.Mydumper.CSV, reader, blockBufSize, rc.ioWorkers, hasHeader, charsetConvertor) - if err != nil { - return nil, nil, errors.Trace(err) - } - case mydump.SourceTypeSQL: - parser = mydump.NewChunkParser(ctx, rc.cfg.TiDB.SQLMode, reader, blockBufSize, rc.ioWorkers) - case mydump.SourceTypeParquet: - parser, err = mydump.NewParquetParser(ctx, rc.store, reader, dataFileMeta.Path) - if err != nil { - return nil, nil, errors.Trace(err) - } - default: - panic(fmt.Sprintf("unknown file type '%s'", dataFileMeta.Type)) - } - //nolint: errcheck - defer parser.Close() - - err = parser.ReadRow() - if err != nil && errors.Cause(err) != io.EOF { - return nil, nil, errors.Trace(err) + if len(rows) > 0 { + return cols, rows[0], nil + } else { + return cols, []types.Datum{}, nil } - return parser.Columns(), parser.LastRow().Row, nil } // SchemaIsValid checks the import file and cluster schema is match. @@ -647,7 +559,11 @@ func (rc *Controller) SchemaIsValid(ctx context.Context, tableInfo *mydump.MDTab } msgs := make([]string, 0) - info, ok := rc.dbInfos[tableInfo.DB].Tables[tableInfo.Name] + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) + if err != nil { + return nil, errors.Trace(err) + } + info, ok := dbInfos[tableInfo.DB].Tables[tableInfo.Name] if !ok { msgs = append(msgs, fmt.Sprintf("TiDB schema `%s`.`%s` doesn't exists,"+ "please give a schema file in source dir or create table manually", tableInfo.DB, tableInfo.Name)) @@ -770,6 +686,10 @@ func (rc *Controller) checkCSVHeader(ctx context.Context, dbMetas []*mydump.MDDa csvCount int hasUniqueIdx bool ) + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) + if err != nil { + return errors.Trace(err) + } // only check one table source files for better performance. The checked table is chosen based on following two factor: // 1. contains at least 1 csv source file, 2 is preferable // 2. table schema contains primary key or unique key @@ -794,7 +714,7 @@ outer: continue } - info := rc.dbInfos[tblMeta.DB].Tables[tblMeta.Name] + info := dbInfos[tblMeta.DB].Tables[tblMeta.Name] for _, idx := range info.Core.Indices { if idx.Primary || idx.Unique { tableHasUniqueIdx = true @@ -855,7 +775,7 @@ outer: // check if some fields are unique and not ignored // if at least one field appears in a unique key, we can sure there is something wrong, // they should be either the header line or the data is duplicated. 
- tableInfo := rc.dbInfos[tableMeta.DB].Tables[tableMeta.Name] + tableInfo := dbInfos[tableMeta.DB].Tables[tableMeta.Name] tableFields := make(map[string]struct{}) uniqueIdxFields := make(map[string]struct{}) ignoreColumns, err := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(tableMeta.DB, tableMeta.Name, rc.cfg.Mydumper.CaseSensitive) @@ -936,166 +856,10 @@ func checkFieldCompatibility( return true } -func (rc *Controller) sampleDataFromTable( - ctx context.Context, - dbName string, - tableMeta *mydump.MDTableMeta, - tableInfo *model.TableInfo, - errMgr *errormanager.ErrorManager, -) error { - if len(tableMeta.DataFiles) == 0 { - return nil - } - sampleFile := tableMeta.DataFiles[0].FileMeta - var reader storage.ReadSeekCloser - var err error - if sampleFile.Type == mydump.SourceTypeParquet { - reader, err = mydump.OpenParquetReader(ctx, rc.store, sampleFile.Path, sampleFile.FileSize) - } else { - reader, err = rc.store.Open(ctx, sampleFile.Path) - } - if err != nil { - return errors.Trace(err) - } - idAlloc := kv.NewPanickingAllocators(0) - tbl, err := tables.TableFromMeta(idAlloc, tableInfo) - if err != nil { - return errors.Trace(err) - } - kvEncoder, err := rc.backend.NewEncoder(ctx, tbl, &kv.SessionOptions{ - SQLMode: rc.cfg.TiDB.SQLMode, - Timestamp: 0, - SysVars: rc.sysVars, - AutoRandomSeed: 0, - }) - if err != nil { - return errors.Trace(err) - } - blockBufSize := int64(rc.cfg.Mydumper.ReadBlockSize) - - var parser mydump.Parser - switch tableMeta.DataFiles[0].FileMeta.Type { - case mydump.SourceTypeCSV: - hasHeader := rc.cfg.Mydumper.CSV.Header - // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. - charsetConvertor, err := mydump.NewCharsetConvertor(rc.cfg.Mydumper.DataCharacterSet, rc.cfg.Mydumper.DataInvalidCharReplace) - if err != nil { - return errors.Trace(err) - } - parser, err = mydump.NewCSVParser(ctx, &rc.cfg.Mydumper.CSV, reader, blockBufSize, rc.ioWorkers, hasHeader, charsetConvertor) - if err != nil { - return errors.Trace(err) - } - case mydump.SourceTypeSQL: - parser = mydump.NewChunkParser(ctx, rc.cfg.TiDB.SQLMode, reader, blockBufSize, rc.ioWorkers) - case mydump.SourceTypeParquet: - parser, err = mydump.NewParquetParser(ctx, rc.store, reader, sampleFile.Path) - if err != nil { - return errors.Trace(err) - } - default: - panic(fmt.Sprintf("file '%s' with unknown source type '%s'", sampleFile.Path, sampleFile.Type.String())) - } - //nolint: errcheck - defer parser.Close() - logTask := log.FromContext(ctx).With(zap.String("table", tableMeta.Name)).Begin(zap.InfoLevel, "sample file") - igCols, err := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbName, tableMeta.Name, rc.cfg.Mydumper.CaseSensitive) - if err != nil { - return errors.Trace(err) - } - - initializedColumns := false - var columnPermutation []int - var kvSize uint64 = 0 - var rowSize uint64 = 0 - rowCount := 0 - dataKVs := rc.backend.MakeEmptyRows() - indexKVs := rc.backend.MakeEmptyRows() - lastKey := make([]byte, 0) - tableMeta.IsRowOrdered = true - tableMeta.IndexRatio = 1.0 -outloop: - for { - offset, _ := parser.Pos() - err = parser.ReadRow() - columnNames := parser.Columns() - - switch errors.Cause(err) { - case nil: - if !initializedColumns { - if len(columnPermutation) == 0 { - columnPermutation, err = createColumnPermutation( - columnNames, - igCols.ColumnsMap(), - tableInfo, - log.FromContext(ctx)) - if err != nil { - return errors.Trace(err) - } - } - initializedColumns = true - } - case io.EOF: - break outloop - default: - err = errors.Annotatef(err, "in 
file offset %d", offset) - return errors.Trace(err) - } - lastRow := parser.LastRow() - rowCount += 1 - - var dataChecksum, indexChecksum verification.KVChecksum - kvs, encodeErr := kvEncoder.Encode(logTask.Logger, lastRow.Row, lastRow.RowID, columnPermutation, sampleFile.Path, offset) - if encodeErr != nil { - encodeErr = errMgr.RecordTypeError(ctx, log.FromContext(ctx), tableInfo.Name.O, sampleFile.Path, offset, - "" /* use a empty string here because we don't actually record */, encodeErr) - if encodeErr != nil { - return errors.Annotatef(encodeErr, "in file at offset %d", offset) - } - if rowCount < maxSampleRowCount { - continue - } else { - break - } - } - if tableMeta.IsRowOrdered { - kvs.ClassifyAndAppend(&dataKVs, &dataChecksum, &indexKVs, &indexChecksum) - for _, kv := range kv.KvPairsFromRows(dataKVs) { - if len(lastKey) == 0 { - lastKey = kv.Key - } else if bytes.Compare(lastKey, kv.Key) > 0 { - tableMeta.IsRowOrdered = false - break - } - } - dataKVs = dataKVs.Clear() - indexKVs = indexKVs.Clear() - } - kvSize += kvs.Size() - rowSize += uint64(lastRow.Length) - parser.RecycleRow(lastRow) - - failpoint.Inject("mock-kv-size", func(val failpoint.Value) { - kvSize += uint64(val.(int)) - }) - if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { - break - } - } - - if rowSize > 0 && kvSize > rowSize { - tableMeta.IndexRatio = float64(kvSize) / float64(rowSize) - } - log.FromContext(ctx).Info("Sample source data", zap.String("table", tableMeta.Name), zap.Float64("IndexRatio", tableMeta.IndexRatio), zap.Bool("IsSourceOrder", tableMeta.IsRowOrdered)) - return nil -} - func (rc *Controller) checkTableEmpty(ctx context.Context) error { if rc.cfg.TikvImporter.Backend == config.BackendTiDB || rc.cfg.TikvImporter.IncrementalImport { return nil } - db, _ := rc.tidbGlue.GetDB() - tableCount := 0 for _, db := range rc.dbMetas { tableCount += len(db.Tables) @@ -1104,15 +868,20 @@ func (rc *Controller) checkTableEmpty(ctx context.Context) error { var lock sync.Mutex tableNames := make([]string, 0) concurrency := mathutil.Min(tableCount, rc.cfg.App.RegionConcurrency) - ch := make(chan string, concurrency) + type tableNameComponents struct { + DBName string + TableName string + } + ch := make(chan tableNameComponents, concurrency) eg, gCtx := errgroup.WithContext(ctx) for i := 0; i < concurrency; i++ { eg.Go(func() error { - for tblName := range ch { + for tblNameComp := range ch { + fullTableName := common.UniqueTable(tblNameComp.DBName, tblNameComp.TableName) // skip tables that have checkpoint if rc.cfg.Checkpoint.Enable { - _, err := rc.checkpointsDB.Get(gCtx, tblName) + _, err := rc.checkpointsDB.Get(gCtx, fullTableName) switch { case err == nil: continue @@ -1122,13 +891,13 @@ func (rc *Controller) checkTableEmpty(ctx context.Context) error { } } - hasData, err1 := tableContainsData(gCtx, db, tblName) + isEmptyPtr, err1 := rc.preInfoGetter.IsTableEmpty(gCtx, tblNameComp.DBName, tblNameComp.TableName) if err1 != nil { return err1 } - if hasData { + if !(*isEmptyPtr) { lock.Lock() - tableNames = append(tableNames, tblName) + tableNames = append(tableNames, fullTableName) lock.Unlock() } } @@ -1139,7 +908,7 @@ loop: for _, db := range rc.dbMetas { for _, tbl := range db.Tables { select { - case ch <- common.UniqueTable(tbl.DB, tbl.Name): + case ch <- tableNameComponents{tbl.DB, tbl.Name}: case <-gCtx.Done(): break loop } @@ -1162,25 +931,3 @@ loop: } return nil } - -func tableContainsData(ctx context.Context, db utils.DBExecutor, tableName string) (bool, error) { - 
failpoint.Inject("CheckTableEmptyFailed", func() { - failpoint.Return(false, errors.New("mock error")) - }) - query := "select 1 from " + tableName + " limit 1" - exec := common.SQLWithRetry{ - DB: db, - Logger: log.FromContext(ctx), - } - var dump int - err := exec.QueryRow(ctx, "check table empty", query, &dump) - - switch { - case errors.ErrorEqual(err, sql.ErrNoRows): - return false, nil - case err != nil: - return false, errors.Trace(err) - default: - return true, nil - } -} diff --git a/br/pkg/lightning/restore/check_info_test.go b/br/pkg/lightning/restore/check_info_test.go index abdfcf232f0a9..2da71e4c84b90 100644 --- a/br/pkg/lightning/restore/check_info_test.go +++ b/br/pkg/lightning/restore/check_info_test.go @@ -336,10 +336,18 @@ func TestCheckCSVHeader(t *testing.T) { }, }, } + + ioWorkers := worker.NewPool(context.Background(), 1, "io") + preInfoGetter := &PreRestoreInfoGetterImpl{ + cfg: cfg, + srcStorage: mockStore, + ioWorkers: ioWorkers, + } rc := &Controller{ - cfg: cfg, - store: mockStore, - ioWorkers: worker.NewPool(context.Background(), 1, "io"), + cfg: cfg, + store: mockStore, + ioWorkers: ioWorkers, + preInfoGetter: preInfoGetter, } p := parser.New() @@ -398,7 +406,7 @@ func TestCheckCSVHeader(t *testing.T) { }) } - err := rc.checkCSVHeader(ctx, dbMetas) + err := rc.checkCSVHeader(WithPreInfoGetterTableStructuresCache(ctx, rc.dbInfos), dbMetas) require.NoError(t, err) if ca.level != passed { require.Equal(t, 1, rc.checkTemplate.FailedCount(ca.level)) @@ -436,10 +444,20 @@ func TestCheckTableEmpty(t *testing.T) { }, } + targetInfoGetter := &TargetInfoGetterImpl{ + cfg: cfg, + } + preInfoGetter := &PreRestoreInfoGetterImpl{ + cfg: cfg, + dbMetas: dbMetas, + targetInfoGetter: targetInfoGetter, + } + rc := &Controller{ cfg: cfg, dbMetas: dbMetas, checkpointsDB: checkpoints.NewNullCheckpointsDB(), + preInfoGetter: preInfoGetter, } ctx := context.Background() @@ -459,12 +477,12 @@ func TestCheckTableEmpty(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.MatchExpectationsInOrder(false) - rc.tidbGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) - mock.ExpectQuery("select 1 from `test1`.`tbl1` limit 1"). + targetInfoGetter.targetDBGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl1` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) - mock.ExpectQuery("select 1 from `test1`.`tbl2` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl2` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) - mock.ExpectQuery("select 1 from `test2`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test2`.`tbl1` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) // not error, need not to init check template err = rc.checkTableEmpty(ctx) @@ -474,16 +492,16 @@ func TestCheckTableEmpty(t *testing.T) { // single table contains data db, mock, err = sqlmock.New() require.NoError(t, err) - rc.tidbGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) + targetInfoGetter.targetDBGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) mock.MatchExpectationsInOrder(false) // test auto retry retryable error - mock.ExpectQuery("select 1 from `test1`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl1` LIMIT 1"). WillReturnError(&gmysql.MySQLError{Number: errno.ErrPDServerTimeout}) - mock.ExpectQuery("select 1 from `test1`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl1` LIMIT 1"). 
WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) - mock.ExpectQuery("select 1 from `test1`.`tbl2` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl2` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) - mock.ExpectQuery("select 1 from `test2`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test2`.`tbl1` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).AddRow(1)) rc.checkTemplate = NewSimpleTemplate() err = rc.checkTableEmpty(ctx) @@ -497,13 +515,13 @@ func TestCheckTableEmpty(t *testing.T) { // multi tables contains data db, mock, err = sqlmock.New() require.NoError(t, err) - rc.tidbGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) + targetInfoGetter.targetDBGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) mock.MatchExpectationsInOrder(false) - mock.ExpectQuery("select 1 from `test1`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl1` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).AddRow(1)) - mock.ExpectQuery("select 1 from `test1`.`tbl2` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl2` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) - mock.ExpectQuery("select 1 from `test2`.`tbl1` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test2`.`tbl1` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).AddRow(1)) rc.checkTemplate = NewSimpleTemplate() err = rc.checkTableEmpty(ctx) @@ -540,9 +558,9 @@ func TestCheckTableEmpty(t *testing.T) { require.NoError(t, err) db, mock, err = sqlmock.New() require.NoError(t, err) - rc.tidbGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) + targetInfoGetter.targetDBGlue = glue.NewExternalTiDBGlue(db, mysql.ModeNone) // only need to check the one that is not in checkpoint - mock.ExpectQuery("select 1 from `test1`.`tbl2` limit 1"). + mock.ExpectQuery("SELECT 1 FROM `test1`.`tbl2` LIMIT 1"). WillReturnRows(sqlmock.NewRows([]string{""}).RowError(0, sql.ErrNoRows)) err = rc.checkTableEmpty(ctx) require.NoError(t, err) diff --git a/br/pkg/lightning/restore/get_pre_info.go b/br/pkg/lightning/restore/get_pre_info.go new file mode 100644 index 0000000000000..f76792942caa9 --- /dev/null +++ b/br/pkg/lightning/restore/get_pre_info.go @@ -0,0 +1,807 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package restore
+
+import (
+	"bytes"
+	"context"
+	"database/sql"
+	"fmt"
+	"io"
+	"strings"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
+	"github.com/pingcap/tidb/br/pkg/lightning/backend"
+	"github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
+	"github.com/pingcap/tidb/br/pkg/lightning/backend/local"
+	"github.com/pingcap/tidb/br/pkg/lightning/backend/tidb"
+	"github.com/pingcap/tidb/br/pkg/lightning/checkpoints"
+	"github.com/pingcap/tidb/br/pkg/lightning/common"
+	"github.com/pingcap/tidb/br/pkg/lightning/config"
+	"github.com/pingcap/tidb/br/pkg/lightning/errormanager"
+	"github.com/pingcap/tidb/br/pkg/lightning/glue"
+	"github.com/pingcap/tidb/br/pkg/lightning/log"
+	"github.com/pingcap/tidb/br/pkg/lightning/mydump"
+	"github.com/pingcap/tidb/br/pkg/lightning/verification"
+	"github.com/pingcap/tidb/br/pkg/lightning/worker"
+	"github.com/pingcap/tidb/br/pkg/storage"
+	"github.com/pingcap/tidb/ddl"
+	"github.com/pingcap/tidb/meta/autoid"
+	"github.com/pingcap/tidb/parser"
+	"github.com/pingcap/tidb/parser/ast"
+	"github.com/pingcap/tidb/parser/model"
+	"github.com/pingcap/tidb/parser/mysql"
+	_ "github.com/pingcap/tidb/planner/core" // to set up expression.EvalAstExpr. Otherwise we cannot parse the default value
+	"github.com/pingcap/tidb/store/pdtypes"
+	"github.com/pingcap/tidb/table/tables"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/mock"
+	"go.uber.org/zap"
+	"golang.org/x/exp/maps"
+)
+
+// EstimateSourceDataSizeResult is the object for the estimated data size result.
+type EstimateSourceDataSizeResult struct {
+	// SizeWithIndex is the size with the index.
+	SizeWithIndex int64
+	// SizeWithoutIndex is the size without the index.
+	SizeWithoutIndex int64
+	// HasUnsortedBigTables indicates whether the source data has unsorted big tables or not.
+	HasUnsortedBigTables bool
+}
+
+// PreRestoreInfoGetter defines the operations to get information from sources and the target.
+// This information is used in the preparation of the import (e.g. the prechecks).
+type PreRestoreInfoGetter interface {
+	TargetInfoGetter
+	// GetAllTableStructures gets all the table structures with the information from both the source and the target.
+	GetAllTableStructures(ctx context.Context) (map[string]*checkpoints.TidbDBInfo, error)
+	// ReadFirstNRowsByTableName reads the first N rows of data of an importing source table.
+	ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) (cols []string, rows [][]types.Datum, err error)
+	// ReadFirstNRowsByFileMeta reads the first N rows of a data file.
+	ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) (cols []string, rows [][]types.Datum, err error)
+	// EstimateSourceDataSize estimates the data size to generate during the import, as well as some other sub-information.
+	// It will return:
+	// * the estimated data size to generate during the import,
+	//   which might include some extra index data to generate besides the source file data
+	// * the total data size of all the source files,
+	// * whether there are some unsorted big tables
+	EstimateSourceDataSize(ctx context.Context) (*EstimateSourceDataSizeResult, error)
+}
+
+// TargetInfoGetter defines the operations to get information from the target.
+type TargetInfoGetter interface {
+	// FetchRemoteTableModels fetches the table structures from the remote target.
+	FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error)
+	// CheckVersionRequirements performs the check whether the target satisfies the version requirements.
+	CheckVersionRequirements(ctx context.Context) error
+	// IsTableEmpty checks whether the specified table on the target DB contains data or not.
+	IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error)
+	// GetTargetSysVariablesForImport gets some important system variables for importing on the target.
+	GetTargetSysVariablesForImport(ctx context.Context) map[string]string
+	// GetReplicationConfig gets the replication config on the target.
+	GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error)
+	// GetStorageInfo gets the storage information on the target.
+	GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error)
+	// GetEmptyRegionsInfo gets the region information of all the empty regions on the target.
+	GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error)
+}
+
+type preInfoGetterKey string
+
+const (
+	preInfoGetterKeyDBMetas                  preInfoGetterKey = "PRE_INFO_GETTER/DB_METAS"
+	preInfoGetterKeyTableStructsCache        preInfoGetterKey = "PRE_INFO_GETTER/TABLE_STRUCTS_CACHE"
+	preInfoGetterKeySysVarsCache             preInfoGetterKey = "PRE_INFO_GETTER/SYS_VARS_CACHE"
+	preInfoGetterKeyEstimatedSourceSizeCache preInfoGetterKey = "PRE_INFO_GETTER/ESTIMATED_SOURCE_SIZE_CACHE"
+)
+
+// WithPreInfoGetterDBMetas returns a new context carrying the given Mydumper database metas.
+func WithPreInfoGetterDBMetas(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta) context.Context {
+	return context.WithValue(ctx, preInfoGetterKeyDBMetas, dbMetas)
+}
+
+// WithPreInfoGetterTableStructuresCache returns a new context carrying the cached table structures.
+func WithPreInfoGetterTableStructuresCache(ctx context.Context, dbInfos map[string]*checkpoints.TidbDBInfo) context.Context {
+	return context.WithValue(ctx, preInfoGetterKeyTableStructsCache, dbInfos)
+}
+
+// WithPreInfoGetterSysVarsCache returns a new context carrying the cached system variables.
+func WithPreInfoGetterSysVarsCache(ctx context.Context, sysVars map[string]string) context.Context {
+	return context.WithValue(ctx, preInfoGetterKeySysVarsCache, sysVars)
+}
+
+// WithPreInfoGetterEstimatedSrcSizeCache returns a new context carrying the cached estimated source data size.
+func WithPreInfoGetterEstimatedSrcSizeCache(ctx context.Context, sizeResult *EstimateSourceDataSizeResult) context.Context {
+	return context.WithValue(ctx, preInfoGetterKeyEstimatedSourceSizeCache, sizeResult)
+}
+
+// TargetInfoGetterImpl implements the operations to get information from the target.
+type TargetInfoGetterImpl struct {
+	cfg          *config.Config
+	targetDBGlue glue.Glue
+	tls          *common.TLS
+	backend      backend.TargetInfoGetter
+}
+
+// NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object.
+func NewTargetInfoGetterImpl(
+	cfg *config.Config,
+	targetDB *sql.DB,
+) (*TargetInfoGetterImpl, error) {
+	targetDBGlue := glue.NewExternalTiDBGlue(targetDB, cfg.TiDB.SQLMode)
+	tls, err := cfg.ToTLS()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	var backendTargetInfoGetter backend.TargetInfoGetter
+	switch cfg.TikvImporter.Backend {
+	case config.BackendTiDB:
+		backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB)
+	case config.BackendLocal:
+		backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDBGlue, cfg.TiDB.PdAddr)
+	default:
+		return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend)
+	}
+	return &TargetInfoGetterImpl{
+		cfg:          cfg,
+		targetDBGlue: targetDBGlue,
+		tls:          tls,
+		backend:      backendTargetInfoGetter,
+	}, nil
+}
+
+// FetchRemoteTableModels fetches the table structures from the remote target.
+// It implements the TargetInfoGetter interface. +func (g *TargetInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + if !g.targetDBGlue.OwnsSQLExecutor() { + return g.targetDBGlue.GetTables(ctx, schemaName) + } + return g.backend.FetchRemoteTableModels(ctx, schemaName) +} + +// CheckVersionRequirements performs the check whether the target satisfies the version requirements. +// It implements the TargetInfoGetter interface. +// Mydump database metas are retrieved from the context. +func (g *TargetInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { + var dbMetas []*mydump.MDDatabaseMeta + dbmetasVal := ctx.Value(preInfoGetterKeyDBMetas) + if dbmetasVal != nil { + if m, ok := dbmetasVal.([]*mydump.MDDatabaseMeta); ok { + dbMetas = m + } + } + return g.backend.CheckRequirements(ctx, &backend.CheckCtx{ + DBMetas: dbMetas, + }) +} + +// IsTableEmpty checks whether the specified table on the target DB contains data or not. +// It implements the TargetInfoGetter interface. +// It tries to select the row count from the target DB. +func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { + var result bool + failpoint.Inject("CheckTableEmptyFailed", func() { + failpoint.Return(nil, errors.New("mock error")) + }) + db, err := g.targetDBGlue.GetDB() + if err != nil { + return nil, errors.Trace(err) + } + exec := common.SQLWithRetry{ + DB: db, + Logger: log.FromContext(ctx), + } + var dump int + err = exec.QueryRow(ctx, "check table empty", + fmt.Sprintf("SELECT 1 FROM %s LIMIT 1", common.UniqueTable(schemaName, tableName)), + &dump, + ) + + switch { + case errors.ErrorEqual(err, sql.ErrNoRows): + result = true + case err != nil: + return nil, errors.Trace(err) + default: + result = false + } + return &result, nil +} + +// GetTargetSysVariablesForImport gets some important system variables for importing on the target. +// It implements the TargetInfoGetter interface. +// It uses the SQL to fetch sys variables from the target. +func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context) map[string]string { + sysVars := ObtainImportantVariables(ctx, g.targetDBGlue.GetSQLExecutor(), !isTiDBBackend(g.cfg)) + // override by manually set vars + maps.Copy(sysVars, g.cfg.TiDB.Vars) + return sysVars +} + +// GetReplicationConfig gets the replication config on the target. +// It implements the TargetInfoGetter interface. +// It uses the PD interface through TLS to get the information. +func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) { + result := new(pdtypes.ReplicationConfig) + if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result); err != nil { + return nil, errors.Trace(err) + } + return result, nil +} + +// GetStorageInfo gets the storage information on the target. +// It implements the TargetInfoGetter interface. +// It uses the PD interface through TLS to get the information. +func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) { + result := new(pdtypes.StoresInfo) + if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil { + return nil, errors.Trace(err) + } + return result, nil +} + +// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. +// It implements the TargetInfoGetter interface. +// It uses the PD interface through TLS to get the information. 
+func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) { + result := new(pdtypes.RegionsInfo) + if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil { + return nil, errors.Trace(err) + } + return result, nil +} + +// PreRestoreInfoGetterImpl implements the operations to get information used in importing preparation. +type PreRestoreInfoGetterImpl struct { + cfg *config.Config + srcStorage storage.ExternalStorage + ioWorkers *worker.Pool + encBuilder backend.EncodingBuilder + targetInfoGetter TargetInfoGetter + + dbMetas []*mydump.MDDatabaseMeta + mdDBMetaMap map[string]*mydump.MDDatabaseMeta + mdDBTableMetaMap map[string]map[string]*mydump.MDTableMeta +} + +// NewPreRestoreInfoGetter creates a PreRestoreInfoGetterImpl object. +func NewPreRestoreInfoGetter( + cfg *config.Config, + dbMetas []*mydump.MDDatabaseMeta, + srcStorage storage.ExternalStorage, + targetInfoGetter TargetInfoGetter, + ioWorkers *worker.Pool, + encBuilder backend.EncodingBuilder, +) (*PreRestoreInfoGetterImpl, error) { + if ioWorkers == nil { + ioWorkers = worker.NewPool(context.Background(), cfg.App.IOConcurrency, "pre_info_getter_io") + } + if encBuilder == nil { + switch cfg.TikvImporter.Backend { + case config.BackendTiDB: + encBuilder = tidb.NewEncodingBuilder() + case config.BackendLocal: + encBuilder = local.NewEncodingBuilder(context.Background()) + default: + return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) + } + } + + result := &PreRestoreInfoGetterImpl{ + cfg: cfg, + dbMetas: dbMetas, + srcStorage: srcStorage, + ioWorkers: ioWorkers, + encBuilder: encBuilder, + targetInfoGetter: targetInfoGetter, + } + result.Init() + return result, nil +} + +// Init initializes some internal data and states for PreRestoreInfoGetterImpl. +func (p *PreRestoreInfoGetterImpl) Init() { + mdDBMetaMap := make(map[string]*mydump.MDDatabaseMeta) + mdDBTableMetaMap := make(map[string]map[string]*mydump.MDTableMeta) + for _, dbMeta := range p.dbMetas { + dbName := dbMeta.Name + mdDBMetaMap[dbName] = dbMeta + mdTableMetaMap, ok := mdDBTableMetaMap[dbName] + if !ok { + mdTableMetaMap = make(map[string]*mydump.MDTableMeta) + mdDBTableMetaMap[dbName] = mdTableMetaMap + } + for _, tblMeta := range dbMeta.Tables { + tblName := tblMeta.Name + mdTableMetaMap[tblName] = tblMeta + } + } + p.mdDBMetaMap = mdDBMetaMap + p.mdDBTableMetaMap = mdDBTableMetaMap +} + +// GetAllTableStructures gets all the table structures with the information from both the source and the target. +// It implements the PreRestoreInfoGetter interface. +// It has a caching mechanism: the table structures will be obtained from the source only once. 
+func (p *PreRestoreInfoGetterImpl) GetAllTableStructures(ctx context.Context) (map[string]*checkpoints.TidbDBInfo, error) { + var ( + dbInfos map[string]*checkpoints.TidbDBInfo + err error + ) + dbInfosVal := ctx.Value(preInfoGetterKeyTableStructsCache) + if dbInfosVal != nil { + if v, ok := dbInfosVal.(map[string]*checkpoints.TidbDBInfo); ok { + dbInfos = v + } + } + if dbInfos != nil { + return dbInfos, nil + } + dbInfos, err = LoadSchemaInfo(ctx, p.dbMetas, func(ctx context.Context, dbName string) ([]*model.TableInfo, error) { + return p.getTableStructuresByFileMeta(ctx, p.mdDBMetaMap[dbName]) + }) + if err != nil { + return nil, errors.Trace(err) + } + return dbInfos, nil +} + +func (p *PreRestoreInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta) ([]*model.TableInfo, error) { + dbName := dbSrcFileMeta.Name + currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) + if err != nil { + return nil, errors.Trace(err) + } + currentTableInfosMap := make(map[string]*model.TableInfo) + for _, tblInfo := range currentTableInfosFromDB { + currentTableInfosMap[tblInfo.Name.L] = tblInfo + } + resultInfos := make([]*model.TableInfo, len(dbSrcFileMeta.Tables)) + for i, tableFileMeta := range dbSrcFileMeta.Tables { + if curTblInfo, ok := currentTableInfosMap[strings.ToLower(tableFileMeta.Name)]; ok { + resultInfos[i] = curTblInfo + continue + } + createTblSQL, err := tableFileMeta.GetSchema(ctx, p.srcStorage) + if err != nil { + return nil, errors.Annotatef(err, "get create table statement from schema file error: %s", tableFileMeta.Name) + } + theTableInfo, err := newTableInfo(createTblSQL, 0) + if err != nil { + errMsg := "generate table info from SQL error" + log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL), zap.String("table_name", tableFileMeta.Name)) + return nil, errors.Annotatef(err, "%s: %s", errMsg, tableFileMeta.Name) + } + resultInfos[i] = theTableInfo + } + return resultInfos, nil +} + +func newTableInfo(createTblSQL string, tableID int64) (*model.TableInfo, error) { + parser := parser.New() + astNode, err := parser.ParseOneStmt(createTblSQL, "", "") + if err != nil { + errMsg := "parse sql statement error" + log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL)) + return nil, errors.Trace(err) + } + sctx := mock.NewContext() + createTableStmt, ok := astNode.(*ast.CreateTableStmt) + if !ok { + return nil, errors.New("cannot convert the parsed SQL statement into a CREATE TABLE statement") + } + info, err := ddl.MockTableInfo(sctx, createTableStmt, tableID) + if err != nil { + return nil, errors.Trace(err) + } + // set the auto_random bits if AUTO_RANDOM is specified + setAutoRandomBits(info, createTableStmt.Cols) + info.State = model.StatePublic + return info, nil +}
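newTableInfo derives a model.TableInfo purely from the CREATE TABLE text via the SQL parser and ddl.MockTableInfo, so no live TiDB instance is required. A sketch of the expected behavior (the DDL text is hypothetical; the AUTO_RANDOM handling follows setAutoRandomBits below):

    // Build a TableInfo from DDL text only; the table ID is caller-chosen.
    info, err := newTableInfo(
        "CREATE TABLE db1.t1 (id BIGINT PRIMARY KEY AUTO_RANDOM(3), v VARCHAR(16))", 1)
    if err != nil {
        // malformed statements (or invalid default values) surface here
        return
    }
    // info.State == model.StatePublic, info.AutoRandomBits == 3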
+ +func setAutoRandomBits(tblInfo *model.TableInfo, colDefs []*ast.ColumnDef) { + if !tblInfo.PKIsHandle { + return + } + pkColName := tblInfo.GetPkName() + for _, colDef := range colDefs { + if colDef.Name.Name.L != pkColName.L || colDef.Tp.GetType() != mysql.TypeLonglong { + continue + } + // potential AUTO_RANDOM candidate column, examine the options + hasAutoRandom := false + canSetAutoRandom := true + var autoRandomBits int + for _, option := range colDef.Options { + if option.Tp == ast.ColumnOptionAutoRandom { + hasAutoRandom = true + autoRandomBits = option.AutoRandomBitLength + switch { + case autoRandomBits == types.UnspecifiedLength: + autoRandomBits = autoid.DefaultAutoRandomBits + case autoRandomBits <= 0 || autoRandomBits > autoid.MaxAutoRandomBits: + canSetAutoRandom = false + } + } + if option.Tp == ast.ColumnOptionAutoIncrement { + canSetAutoRandom = false + } + if option.Tp == ast.ColumnOptionDefaultValue { + canSetAutoRandom = false + } + } + if hasAutoRandom && canSetAutoRandom { + tblInfo.AutoRandomBits = uint64(autoRandomBits) + } + } +} + +// ReadFirstNRowsByTableName reads the first N rows of data of a source table to be imported. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) ([]string, [][]types.Datum, error) { + mdTableMetaMap, ok := p.mdDBTableMetaMap[schemaName] + if !ok { + return nil, nil, errors.Errorf("cannot find the schema: %s", schemaName) + } + mdTableMeta, ok := mdTableMetaMap[tableName] + if !ok { + return nil, nil, errors.Errorf("cannot find the table: %s.%s", schemaName, tableName) + } + if len(mdTableMeta.DataFiles) <= 0 { + return nil, [][]types.Datum{}, nil + } + return p.ReadFirstNRowsByFileMeta(ctx, mdTableMeta.DataFiles[0].FileMeta, n) +} + +// ReadFirstNRowsByFileMeta reads the first N rows of a data file. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) ([]string, [][]types.Datum, error) { + var ( + reader storage.ReadSeekCloser + err error + ) + if dataFileMeta.Type == mydump.SourceTypeParquet { + reader, err = mydump.OpenParquetReader(ctx, p.srcStorage, dataFileMeta.Path, dataFileMeta.FileSize) + } else { + reader, err = p.srcStorage.Open(ctx, dataFileMeta.Path) + } + if err != nil { + return nil, nil, errors.Trace(err) + } + + var parser mydump.Parser + blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) + switch dataFileMeta.Type { + case mydump.SourceTypeCSV: + hasHeader := p.cfg.Mydumper.CSV.Header + // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. + charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) + if err != nil { + return nil, nil, errors.Trace(err) + } + parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) + if err != nil { + return nil, nil, errors.Trace(err) + } + case mydump.SourceTypeSQL: + parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) + case mydump.SourceTypeParquet: + parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, dataFileMeta.Path) + if err != nil { + return nil, nil, errors.Trace(err) + } + default: + panic(fmt.Sprintf("unknown file type '%s'", dataFileMeta.Type)) + } + //nolint: errcheck + defer parser.Close() + + rows := [][]types.Datum{} + for i := 0; i < n; i++ { + err := parser.ReadRow() + if err != nil { + if errors.Cause(err) != io.EOF { + return nil, nil, errors.Trace(err) + } else { + break + } + } + rows = append(rows, parser.LastRow().Row) + } + return parser.Columns(), rows, nil +}
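ReadFirstNRowsByFileMeta chooses a parser by source type (CSV, SQL dump, or Parquet) and replays at most N rows, which the sampling and preview logic builds on. A usage sketch (the schema and table names are hypothetical):

    // Preview the first two rows of a source table.
    cols, rows, err := p.ReadFirstNRowsByTableName(ctx, "db01", "tbl01", 2)
    if err != nil {
        return
    }
    // cols holds the column names, e.g. from a CSV header line;
    // rows holds at most 2 datum slices, fewer if EOF is reached earlier.
    _, _ = cols, rows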
+ +// EstimateSourceDataSize estimates the data size to be generated during the import, as well as some other related information. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) EstimateSourceDataSize(ctx context.Context) (*EstimateSourceDataSizeResult, error) { + var result *EstimateSourceDataSizeResult + resultVal := ctx.Value(preInfoGetterKeyEstimatedSourceSizeCache) + if resultVal != nil { + if v, ok := resultVal.(*EstimateSourceDataSizeResult); ok { + result = v + } + } + if result != nil { + return result, nil + } + sizeWithIndex := int64(0) + sourceTotalSize := int64(0) + tableCount := 0 + unSortedBigTableCount := 0 + errMgr := errormanager.New(nil, p.cfg, log.FromContext(ctx)) + dbInfos, err := p.GetAllTableStructures(ctx) + if err != nil { + return nil, errors.Trace(err) + } + sysVars := p.GetTargetSysVariablesForImport(ctx) + for _, db := range p.dbMetas { + info, ok := dbInfos[db.Name] + if !ok { + continue + } + for _, tbl := range db.Tables { + sourceTotalSize += tbl.TotalSize + tableInfo, ok := info.Tables[tbl.Name] + if ok { + // Do not sample small tables because there may be a large number of them, + // and it would take a long time to sample data for all of them. + if isTiDBBackend(p.cfg) || tbl.TotalSize < int64(config.SplitRegionSize) { + sizeWithIndex += tbl.TotalSize + tbl.IndexRatio = 1.0 + tbl.IsRowOrdered = false + } else { + sampledIndexRatio, isRowOrderedFromSample, err := p.sampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core, errMgr, sysVars) + if err != nil { + return nil, errors.Trace(err) + } + tbl.IndexRatio = sampledIndexRatio + tbl.IsRowOrdered = isRowOrderedFromSample + + if tbl.IndexRatio > 0 { + sizeWithIndex += int64(float64(tbl.TotalSize) * tbl.IndexRatio) + } else { + // if sampling failed due to max-error, fall back to using the source size + sizeWithIndex += tbl.TotalSize + } + + if tbl.TotalSize > int64(config.DefaultBatchSize)*2 && !tbl.IsRowOrdered { + unSortedBigTableCount++ + } + } + tableCount += 1 + } + } + } + + result = &EstimateSourceDataSizeResult{ + SizeWithIndex: sizeWithIndex, + SizeWithoutIndex: sourceTotalSize, + HasUnsortedBigTables: (unSortedBigTableCount > 0), + } + return result, nil +}
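EstimateSourceDataSize folds the per-table source sizes and the sampled index ratios into one result that later drives concurrency and resource checks. A consumption sketch mirroring how preCheckRequirements uses it further down in this patch:

    r, err := p.EstimateSourceDataSize(ctx)
    if err != nil {
        return
    }
    if r.HasUnsortedBigTables {
        // large unsorted tables present: cap table concurrency at index concurrency
    }
    estimatedKVBytes := r.SizeWithIndex  // encoded size including indices
    rawSourceBytes := r.SizeWithoutIndex // plain source data size
    _, _ = estimatedKVBytes, rawSourceBytes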
+ +// sampleDataFromTable samples the source data file to get the extra data ratio for the index. +// It returns: +// * the extra data ratio with the index size accounted for +// * whether the sample data is ordered by row key +func (p *PreRestoreInfoGetterImpl) sampleDataFromTable( + ctx context.Context, + dbName string, + tableMeta *mydump.MDTableMeta, + tableInfo *model.TableInfo, + errMgr *errormanager.ErrorManager, + sysVars map[string]string, +) (float64, bool, error) { + resultIndexRatio := 1.0 + isRowOrdered := false + if len(tableMeta.DataFiles) == 0 { + return resultIndexRatio, isRowOrdered, nil + } + sampleFile := tableMeta.DataFiles[0].FileMeta + var reader storage.ReadSeekCloser + var err error + if sampleFile.Type == mydump.SourceTypeParquet { + reader, err = mydump.OpenParquetReader(ctx, p.srcStorage, sampleFile.Path, sampleFile.FileSize) + } else { + reader, err = p.srcStorage.Open(ctx, sampleFile.Path) + } + if err != nil { + return 0.0, false, errors.Trace(err) + } + idAlloc := kv.NewPanickingAllocators(0) + tbl, err := tables.TableFromMeta(idAlloc, tableInfo) + if err != nil { + return 0.0, false, errors.Trace(err) + } + kvEncoder, err := p.encBuilder.NewEncoder(ctx, tbl, &kv.SessionOptions{ + SQLMode: p.cfg.TiDB.SQLMode, + Timestamp: 0, + SysVars: sysVars, + AutoRandomSeed: 0, + }) + if err != nil { + return 0.0, false, errors.Trace(err) + } + blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) + + var parser mydump.Parser + switch tableMeta.DataFiles[0].FileMeta.Type { + case mydump.SourceTypeCSV: + hasHeader := p.cfg.Mydumper.CSV.Header + // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. + charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) + if err != nil { + return 0.0, false, errors.Trace(err) + } + parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) + if err != nil { + return 0.0, false, errors.Trace(err) + } + case mydump.SourceTypeSQL: + parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) + case mydump.SourceTypeParquet: + parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, sampleFile.Path) + if err != nil { + return 0.0, false, errors.Trace(err) + } + default: + panic(fmt.Sprintf("file '%s' with unknown source type '%s'", sampleFile.Path, sampleFile.Type.String())) + } + //nolint: errcheck + defer parser.Close() + logTask := log.FromContext(ctx).With(zap.String("table", tableMeta.Name)).Begin(zap.InfoLevel, "sample file") + igCols, err := p.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbName, tableMeta.Name, p.cfg.Mydumper.CaseSensitive) + if err != nil { + return 0.0, false, errors.Trace(err) + } + + initializedColumns := false + var columnPermutation []int + var kvSize uint64 = 0 + var rowSize uint64 = 0 + rowCount := 0 + dataKVs := p.encBuilder.MakeEmptyRows() + indexKVs := p.encBuilder.MakeEmptyRows() + lastKey := make([]byte, 0) + isRowOrdered = true +outloop: + for { + offset, _ := parser.Pos() + err = parser.ReadRow() + columnNames := parser.Columns() + + switch errors.Cause(err) { + case nil: + if !initializedColumns { + if len(columnPermutation) == 0 { + columnPermutation, err = createColumnPermutation( + columnNames, + igCols.ColumnsMap(), + tableInfo, + log.FromContext(ctx)) + if err != nil { + return 0.0, false, errors.Trace(err) + } + } + initializedColumns = true + } + case io.EOF: + break outloop + default: + err = errors.Annotatef(err, "in file offset %d", offset) + return 0.0, false, errors.Trace(err) + } + lastRow := parser.LastRow() + rowCount += 1 + + var dataChecksum, indexChecksum verification.KVChecksum + kvs, encodeErr := kvEncoder.Encode(logTask.Logger, lastRow.Row, lastRow.RowID, columnPermutation, sampleFile.Path, offset) + if encodeErr != nil { + encodeErr = errMgr.RecordTypeError(ctx, log.FromContext(ctx), tableInfo.Name.O, sampleFile.Path, offset, + "" /* use an empty string here because we don't actually record */, encodeErr) + if encodeErr != nil { + return 0.0, false, errors.Annotatef(encodeErr, "in file at offset %d", offset) + } + if rowCount < maxSampleRowCount { + continue + } else { + break + } + } + if isRowOrdered { + kvs.ClassifyAndAppend(&dataKVs, &dataChecksum, &indexKVs, &indexChecksum) + for _, kv := range kv.KvPairsFromRows(dataKVs) { + if len(lastKey) == 0 { + lastKey = kv.Key + } else if bytes.Compare(lastKey, kv.Key) > 0 { + isRowOrdered = false + break + } + } + dataKVs = dataKVs.Clear() + indexKVs = indexKVs.Clear() + } + kvSize += kvs.Size() + rowSize += uint64(lastRow.Length) + parser.RecycleRow(lastRow) + + failpoint.Inject("mock-kv-size", func(val failpoint.Value) { + kvSize += uint64(val.(int)) + }) + if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { + break + } + } + + if rowSize > 0 && kvSize > rowSize { + resultIndexRatio = float64(kvSize) / float64(rowSize) + } + log.FromContext(ctx).Info("Sample source data", zap.String("table", tableMeta.Name), zap.Float64("IndexRatio", resultIndexRatio), zap.Bool("IsSourceOrder", isRowOrdered)) + return resultIndexRatio, isRowOrdered, nil +}
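sampleDataFromTable computes the index ratio by encoding sampled rows and comparing the encoded KV size against the raw row size, while also checking that the encoded keys stay monotonically increasing. A sketch of how the two return values are meant to be read (all names besides the method itself are hypothetical):

    ratio, ordered, err := p.sampleDataFromTable(ctx, "db01", tblMeta, tblInfo, errMgr, sysVars)
    if err != nil {
        return
    }
    // A ratio of 1.5 means the encoded KVs (data plus index) take ~1.5x the source
    // bytes, matching how EstimateSourceDataSize scales tbl.TotalSize by tbl.IndexRatio.
    estimatedKVBytes := int64(float64(tblMeta.TotalSize) * ratio)
    // ordered reports whether the source rows are already sorted by row key.
    _, _ = estimatedKVBytes, ordered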
+ +// GetReplicationConfig gets the replication config on the target. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) { + return p.targetInfoGetter.GetReplicationConfig(ctx) +} + +// GetStorageInfo gets the storage information on the target. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) { + return p.targetInfoGetter.GetStorageInfo(ctx) +} + +// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) { + return p.targetInfoGetter.GetEmptyRegionsInfo(ctx) +} + +// IsTableEmpty checks whether the specified table on the target DB contains data or not. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { + return p.targetInfoGetter.IsTableEmpty(ctx, schemaName, tableName) +} + +// FetchRemoteTableModels fetches the table structures from the remote target. +// It implements the PreRestoreInfoGetter interface. +func (p *PreRestoreInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + return p.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName) +} + +// CheckVersionRequirements performs the check whether the target satisfies the version requirements. +// It implements the PreRestoreInfoGetter interface. +// Mydump database metas are retrieved from the context. +func (p *PreRestoreInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { + return p.targetInfoGetter.CheckVersionRequirements(ctx) +} + +// GetTargetSysVariablesForImport gets some important system variables for importing on the target. +// It implements the PreRestoreInfoGetter interface. +// It has a caching mechanism: it first checks the cache in the context, and falls back to fetching from the target. +func (p *PreRestoreInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context) map[string]string { + var sysVars map[string]string + sysVarsVal := ctx.Value(preInfoGetterKeySysVarsCache) + if sysVarsVal != nil { + if v, ok := sysVarsVal.(map[string]string); ok { + sysVars = v + } + } + if sysVars != nil { + return sysVars + } + return p.targetInfoGetter.GetTargetSysVariablesForImport(ctx) +} diff --git a/br/pkg/lightning/restore/get_pre_info_test.go b/br/pkg/lightning/restore/get_pre_info_test.go new file mode 100644 index 0000000000000..94bc4ec58c41e --- /dev/null +++ b/br/pkg/lightning/restore/get_pre_info_test.go @@ -0,0 +1,497 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package restore + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/restore/mock" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/types" + "github.com/stretchr/testify/require" +) + +type colDef struct { + ColName string + Def string + TypeStr string +} + +type tableDef []*colDef + +func tableDefsToMockDataMap(dbTableDefs map[string]map[string]tableDef) map[string]*mock.MockDBSourceData { + dbMockDataMap := make(map[string]*mock.MockDBSourceData) + for dbName, tblDefMap := range dbTableDefs { + tblMockDataMap := make(map[string]*mock.MockTableSourceData) + for tblName, colDefs := range tblDefMap { + colDefStrs := make([]string, len(colDefs)) + for i, colDef := range colDefs { + colDefStrs[i] = fmt.Sprintf("%s %s", colDef.ColName, colDef.Def) + } + createSQL := fmt.Sprintf("CREATE TABLE %s.%s (%s);", dbName, tblName, strings.Join(colDefStrs, ", ")) + tblMockDataMap[tblName] = &mock.MockTableSourceData{ + DBName: dbName, + TableName: tblName, + SchemaFile: &mock.MockSourceFile{ + FileName: fmt.Sprintf("/%s/%s/%s.schema.sql", dbName, tblName, tblName), + Data: []byte(createSQL), + }, + } + } + dbMockDataMap[dbName] = &mock.MockDBSourceData{ + Name: dbName, + Tables: tblMockDataMap, + } + } + return dbMockDataMap +} + +func TestGetPreInfoGenerateTableInfo(t *testing.T) { + schemaName := "db1" + tblName := "tbl1" + createTblSQL := fmt.Sprintf("create table `%s`.`%s` (a varchar(16) not null, b varchar(8) default 'DEFA')", schemaName, tblName) + tblInfo, err := newTableInfo(createTblSQL, 1) + require.Nil(t, err) + t.Logf("%+v", tblInfo) + require.Equal(t, model.NewCIStr(tblName), tblInfo.Name) + require.Equal(t, len(tblInfo.Columns), 2) + require.Equal(t, model.NewCIStr("a"), tblInfo.Columns[0].Name) + require.Nil(t, tblInfo.Columns[0].DefaultValue) + require.False(t, hasDefault(tblInfo.Columns[0])) + require.Equal(t, model.NewCIStr("b"), tblInfo.Columns[1].Name) + require.NotNil(t, tblInfo.Columns[1].DefaultValue) + + createTblSQL = fmt.Sprintf("create table `%s`.`%s` (a varchar(16), b varchar(8) default 'DEFAULT_BBBBB')", schemaName, tblName) // default value exceeds the length + tblInfo, err = newTableInfo(createTblSQL, 2) + require.NotNil(t, err) +} + +func TestGetPreInfoHasDefault(t *testing.T) { + subCases := []struct { + ColDef string + ExpectHasDefault bool + }{ + { + ColDef: "varchar(16)", + ExpectHasDefault: true, + }, + { + ColDef: "varchar(16) NOT NULL", + ExpectHasDefault: false, + }, + { + ColDef: "INTEGER PRIMARY KEY", + ExpectHasDefault: false, + }, + { + ColDef: "INTEGER AUTO_INCREMENT", + ExpectHasDefault: true, + }, + { + ColDef: "INTEGER PRIMARY KEY AUTO_INCREMENT", + ExpectHasDefault: true, + }, + { + ColDef: "BIGINT PRIMARY KEY AUTO_RANDOM", + ExpectHasDefault: false, + }, + } + for _, subCase := range subCases { + createTblSQL := fmt.Sprintf("create table `db1`.`tbl1` (a %s)", subCase.ColDef) + tblInfo, err := newTableInfo(createTblSQL, 1) + require.Nil(t, err) + require.Equal(t, subCase.ExpectHasDefault, hasDefault(tblInfo.Columns[0]), subCase.ColDef) + } +} + +func TestGetPreInfoAutoRandomBits(t *testing.T) { + subCases := []struct { + ColDef string + ExpectAutoRandomBits uint64 + }{ + { + ColDef: "varchar(16)", + ExpectAutoRandomBits: 0, + }, + { + ColDef: "varchar(16) AUTO_RANDOM", + ExpectAutoRandomBits: 0, + }, + { + ColDef: "INTEGER PRIMARY KEY AUTO_RANDOM", + ExpectAutoRandomBits: 0, + }, + { + ColDef: 
"BIGINT PRIMARY KEY AUTO_RANDOM AUTO_INCREMENT", + ExpectAutoRandomBits: 0, + }, + { + ColDef: "BIGINT PRIMARY KEY AUTO_RANDOM(3)", + ExpectAutoRandomBits: 3, + }, + { + ColDef: "BIGINT PRIMARY KEY AUTO_RANDOM", + ExpectAutoRandomBits: 5, + }, + { + ColDef: "BIGINT PRIMARY KEY AUTO_RANDOM(20)", + ExpectAutoRandomBits: 0, + }, + { + ColDef: "BIGINT PRIMARY KEY AUTO_RANDOM(0)", + ExpectAutoRandomBits: 0, + }, + { + ColDef: "BIGINT AUTO_RANDOM", + ExpectAutoRandomBits: 0, + }, + } + for _, subCase := range subCases { + createTblSQL := fmt.Sprintf("create table `db1`.`tbl1` (a %s)", subCase.ColDef) + tblInfo, err := newTableInfo(createTblSQL, 1) + require.Nil(t, err) + require.Equal(t, subCase.ExpectAutoRandomBits, tblInfo.AutoRandomBits, subCase.ColDef) + } +} + +func TestGetPreInfoGetAllTableStructures(t *testing.T) { + dbTableDefs := map[string]map[string]tableDef{ + "db01": { + "tbl01": { + &colDef{ + ColName: "id", + Def: "INTEGER PRIMARY KEY AUTO_INCREMENT", + TypeStr: "int", + }, + &colDef{ + ColName: "strval", + Def: "VARCHAR(64)", + TypeStr: "varchar", + }, + }, + "tbl02": { + &colDef{ + ColName: "id", + Def: "INTEGER PRIMARY KEY AUTO_INCREMENT", + TypeStr: "int", + }, + &colDef{ + ColName: "val", + Def: "VARCHAR(64)", + TypeStr: "varchar", + }, + }, + }, + "db02": { + "tbl01": { + &colDef{ + ColName: "id", + Def: "INTEGER PRIMARY KEY AUTO_INCREMENT", + TypeStr: "int", + }, + &colDef{ + ColName: "strval", + Def: "VARCHAR(64)", + TypeStr: "varchar", + }, + }, + }, + } + testMockDataMap := tableDefsToMockDataMap(dbTableDefs) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mockSrc, err := mock.NewMockImportSource(testMockDataMap) + require.Nil(t, err) + + mockTarget := mock.NewMockTargetInfo() + + cfg := config.NewConfig() + cfg.TikvImporter.Backend = config.BackendLocal + ig, err := NewPreRestoreInfoGetter(cfg, mockSrc.GetAllDBFileMetas(), mockSrc.GetStorage(), mockTarget, nil, nil) + require.NoError(t, err) + tblStructMap, err := ig.GetAllTableStructures(ctx) + require.Nil(t, err) + require.Equal(t, len(dbTableDefs), len(tblStructMap), "compare db count") + for dbName, dbInfo := range tblStructMap { + tblDefMap, ok := dbTableDefs[dbName] + require.Truef(t, ok, "check db exists in db definitions: %s", dbName) + require.Equalf(t, len(tblDefMap), len(dbInfo.Tables), "compare table count: %s", dbName) + for tblName, tblStruct := range dbInfo.Tables { + tblDef, ok := tblDefMap[tblName] + require.Truef(t, ok, "check table exists in table definitions: %s.%s", dbName, tblName) + require.Equalf(t, len(tblDef), len(tblStruct.Core.Columns), "compare columns count: %s.%s", dbName, tblName) + for i, colDef := range tblStruct.Core.Columns { + expectColDef := tblDef[i] + require.Equalf(t, strings.ToLower(expectColDef.ColName), colDef.Name.L, "check column name: %s.%s", dbName, tblName) + require.Truef(t, strings.Contains(colDef.FieldType.String(), strings.ToLower(expectColDef.TypeStr)), "check column type: %s.%s", dbName, tblName) + } + } + } +} + +func TestGetPreInfoReadFirstRow(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + const testCSVData01 string = `ival,sval +111,"aaa" +222,"bbb" +` + const testSQLData01 string = `INSERT INTO db01.tbl01 (ival, sval) VALUES (333, 'ccc'); +INSERT INTO db01.tbl01 (ival, sval) VALUES (444, 'ddd');` + testDataInfos := []struct { + FileName string + Data string + FirstN int + CSVConfig *config.CSVConfig + ExpectFirstRowDatums [][]types.Datum + ExpectColumns []string + }{ + { + FileName: 
"/db01/tbl01/data.001.csv", + Data: testCSVData01, + FirstN: 1, + ExpectFirstRowDatums: [][]types.Datum{ + { + types.NewStringDatum("111"), + types.NewStringDatum("aaa"), + }, + }, + ExpectColumns: []string{"ival", "sval"}, + }, + { + FileName: "/db01/tbl01/data.002.csv", + Data: testCSVData01, + FirstN: 2, + ExpectFirstRowDatums: [][]types.Datum{ + { + types.NewStringDatum("111"), + types.NewStringDatum("aaa"), + }, + { + types.NewStringDatum("222"), + types.NewStringDatum("bbb"), + }, + }, + ExpectColumns: []string{"ival", "sval"}, + }, + { + FileName: "/db01/tbl01/data.001.sql", + Data: testSQLData01, + FirstN: 1, + ExpectFirstRowDatums: [][]types.Datum{ + { + types.NewUintDatum(333), + types.NewStringDatum("ccc"), + }, + }, + ExpectColumns: []string{"ival", "sval"}, + }, + { + FileName: "/db01/tbl01/data.003.csv", + Data: "", + FirstN: 1, + ExpectFirstRowDatums: [][]types.Datum{}, + ExpectColumns: nil, + }, + { + FileName: "/db01/tbl01/data.004.csv", + Data: "ival,sval", + FirstN: 1, + ExpectFirstRowDatums: [][]types.Datum{}, + ExpectColumns: []string{"ival", "sval"}, + }, + } + tblMockSourceData := &mock.MockTableSourceData{ + DBName: "db01", + TableName: "tbl01", + SchemaFile: &mock.MockSourceFile{ + FileName: "/db01/tbl01/tbl01.schema.sql", + Data: []byte("CREATE TABLE db01.tbl01(id INTEGER PRIMARY KEY AUTO_INCREMENT, ival INTEGER, sval VARCHAR(64));"), + }, + DataFiles: []*mock.MockSourceFile{}, + } + for _, testInfo := range testDataInfos { + tblMockSourceData.DataFiles = append(tblMockSourceData.DataFiles, &mock.MockSourceFile{ + FileName: testInfo.FileName, + Data: []byte(testInfo.Data), + }) + } + mockDataMap := map[string]*mock.MockDBSourceData{ + "db01": { + Name: "db01", + Tables: map[string]*mock.MockTableSourceData{ + "tbl01": tblMockSourceData, + }, + }, + } + mockSrc, err := mock.NewMockImportSource(mockDataMap) + require.Nil(t, err) + mockTarget := mock.NewMockTargetInfo() + cfg := config.NewConfig() + cfg.TikvImporter.Backend = config.BackendLocal + ig, err := NewPreRestoreInfoGetter(cfg, mockSrc.GetAllDBFileMetas(), mockSrc.GetStorage(), mockTarget, nil, nil) + require.NoError(t, err) + + cfg.Mydumper.CSV.Header = true + tblMeta := mockSrc.GetDBMetaMap()["db01"].Tables[0] + for i, dataFile := range tblMeta.DataFiles { + theDataInfo := testDataInfos[i] + cols, rowDatums, err := ig.ReadFirstNRowsByFileMeta(ctx, dataFile.FileMeta, theDataInfo.FirstN) + require.Nil(t, err) + t.Logf("%v, %v", cols, rowDatums) + require.Equal(t, theDataInfo.ExpectColumns, cols) + require.Equal(t, theDataInfo.ExpectFirstRowDatums, rowDatums) + } + + theDataInfo := testDataInfos[0] + cols, rowDatums, err := ig.ReadFirstNRowsByTableName(ctx, "db01", "tbl01", theDataInfo.FirstN) + require.NoError(t, err) + require.Equal(t, theDataInfo.ExpectColumns, cols) + require.Equal(t, theDataInfo.ExpectFirstRowDatums, rowDatums) +} + +func TestGetPreInfoSampleSource(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dataFileName := "/db01/tbl01/tbl01.data.001.csv" + mockDataMap := map[string]*mock.MockDBSourceData{ + "db01": { + Name: "db01", + Tables: map[string]*mock.MockTableSourceData{ + "tbl01": { + DBName: "db01", + TableName: "tbl01", + SchemaFile: &mock.MockSourceFile{ + FileName: "/db01/tbl01/tbl01.schema.sql", + Data: []byte("CREATE TABLE db01.tbl01 (id INTEGER PRIMARY KEY AUTO_INCREMENT, ival INTEGER, sval VARCHAR(64));"), + }, + DataFiles: []*mock.MockSourceFile{ + { + FileName: dataFileName, + Data: []byte(nil), + }, + }, + }, + }, + }, + } + mockSrc, 
err := mock.NewMockImportSource(mockDataMap) + require.Nil(t, err) + mockTarget := mock.NewMockTargetInfo() + cfg := config.NewConfig() + cfg.TikvImporter.Backend = config.BackendLocal + ig, err := NewPreRestoreInfoGetter(cfg, mockSrc.GetAllDBFileMetas(), mockSrc.GetStorage(), mockTarget, nil, nil) + require.NoError(t, err) + + mdDBMeta := mockSrc.GetAllDBFileMetas()[0] + mdTblMeta := mdDBMeta.Tables[0] + dbInfos, err := ig.GetAllTableStructures(ctx) + require.NoError(t, err) + + subTests := []struct { + Data []byte + ExpectIsOrdered bool + }{ + { + Data: []byte(`id,ival,sval +1,111,"aaa" +2,222,"bbb" +`, + ), + ExpectIsOrdered: true, + }, + { + Data: []byte(`sval,ival,id +"aaa",111,1 +"bbb",222,2 +`, + ), + ExpectIsOrdered: true, + }, + { + Data: []byte(`id,ival,sval +2,222,"bbb" +1,111,"aaa" +`, + ), + ExpectIsOrdered: false, + }, + { + Data: []byte(`sval,ival,id +"aaa",111,2 +"bbb",222,1 +`, + ), + ExpectIsOrdered: false, + }, + } + for _, subTest := range subTests { + require.NoError(t, mockSrc.GetStorage().WriteFile(ctx, dataFileName, subTest.Data)) + sampledIndexRatio, isRowOrderedFromSample, err := ig.sampleDataFromTable(ctx, "db01", mdTblMeta, dbInfos["db01"].Tables["tbl01"].Core, nil, defaultImportantVariables) + require.NoError(t, err) + t.Logf("%v, %v", sampledIndexRatio, isRowOrderedFromSample) + require.Greater(t, sampledIndexRatio, 1.0) + require.Equal(t, subTest.ExpectIsOrdered, isRowOrderedFromSample) + } +} + +func TestGetPreInfoEstimateSourceSize(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dataFileName := "/db01/tbl01/tbl01.data.001.csv" + testData := []byte(`id,ival,sval +1,111,"aaa" +2,222,"bbb" +`, + ) + mockDataMap := map[string]*mock.MockDBSourceData{ + "db01": { + Name: "db01", + Tables: map[string]*mock.MockTableSourceData{ + "tbl01": { + DBName: "db01", + TableName: "tbl01", + SchemaFile: &mock.MockSourceFile{ + FileName: "/db01/tbl01/tbl01.schema.sql", + Data: []byte("CREATE TABLE db01.tbl01 (id INTEGER PRIMARY KEY AUTO_INCREMENT, ival INTEGER, sval VARCHAR(64));"), + }, + DataFiles: []*mock.MockSourceFile{ + { + FileName: dataFileName, + Data: testData, + }, + }, + }, + }, + }, + } + mockSrc, err := mock.NewMockImportSource(mockDataMap) + require.Nil(t, err) + mockTarget := mock.NewMockTargetInfo() + cfg := config.NewConfig() + cfg.TikvImporter.Backend = config.BackendLocal + ig, err := NewPreRestoreInfoGetter(cfg, mockSrc.GetAllDBFileMetas(), mockSrc.GetStorage(), mockTarget, nil, nil) + require.NoError(t, err) + + sizeResult, err := ig.EstimateSourceDataSize(ctx) + require.NoError(t, err) + t.Logf("estimate size: %v, file size: %v, has unsorted table: %v\n", sizeResult.SizeWithIndex, sizeResult.SizeWithoutIndex, sizeResult.HasUnsortedBigTables) + require.GreaterOrEqual(t, sizeResult.SizeWithIndex, sizeResult.SizeWithoutIndex) + require.Equal(t, int64(len(testData)), sizeResult.SizeWithoutIndex) + require.False(t, sizeResult.HasUnsortedBigTables) +} diff --git a/br/pkg/lightning/restore/mock/mock.go b/br/pkg/lightning/restore/mock/mock.go new file mode 100644 index 0000000000000..100372b594620 --- /dev/null +++ b/br/pkg/lightning/restore/mock/mock.go @@ -0,0 +1,283 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/mydump" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/store/pdtypes" + "github.com/pingcap/tidb/util/filter" +) + +// MockSourceFile defines a mock source file. +type MockSourceFile struct { + FileName string + Data []byte + TotalSize int +} + +// MockTableSourceData defines a mock source information for a table. +type MockTableSourceData struct { + DBName string + TableName string + SchemaFile *MockSourceFile + DataFiles []*MockSourceFile +} + +// MockDBSourceData defines a mock source information for a database. +type MockDBSourceData struct { + Name string + Tables map[string]*MockTableSourceData +} + +// MockImportSource defines a mock import source +type MockImportSource struct { + dbSrcDataMap map[string]*MockDBSourceData + dbFileMetaMap map[string]*mydump.MDDatabaseMeta + srcStorage storage.ExternalStorage +} + +// NewMockImportSource creates a MockImportSource object. +func NewMockImportSource(dbSrcDataMap map[string]*MockDBSourceData) (*MockImportSource, error) { + ctx := context.Background() + dbFileMetaMap := make(map[string]*mydump.MDDatabaseMeta) + mapStore := storage.NewMemStorage() + for dbName, dbData := range dbSrcDataMap { + dbFileInfo := mydump.FileInfo{ + TableName: filter.Table{ + Schema: dbName, + }, + FileMeta: mydump.SourceFileMeta{Type: mydump.SourceTypeSchemaSchema}, + } + dbMeta := mydump.NewMDDatabaseMeta("binary") + dbMeta.Name = dbName + dbMeta.SchemaFile = dbFileInfo + dbMeta.Tables = []*mydump.MDTableMeta{} + for tblName, tblData := range dbData.Tables { + tblMeta := mydump.NewMDTableMeta("binary") + tblMeta.DB = dbName + tblMeta.Name = tblName + tblMeta.SchemaFile = mydump.FileInfo{ + TableName: filter.Table{ + Schema: dbName, + Name: tblName, + }, + FileMeta: mydump.SourceFileMeta{ + Path: tblData.SchemaFile.FileName, + Type: mydump.SourceTypeTableSchema, + }, + } + tblMeta.DataFiles = []mydump.FileInfo{} + if err := mapStore.WriteFile(ctx, tblData.SchemaFile.FileName, tblData.SchemaFile.Data); err != nil { + return nil, errors.Trace(err) + } + totalFileSize := 0 + for _, tblDataFile := range tblData.DataFiles { + fileSize := tblDataFile.TotalSize + if fileSize == 0 { + fileSize = len(tblDataFile.Data) + } + totalFileSize += fileSize + fileInfo := mydump.FileInfo{ + TableName: filter.Table{ + Schema: dbName, + Name: tblName, + }, + FileMeta: mydump.SourceFileMeta{ + Path: tblDataFile.FileName, + FileSize: int64(fileSize), + }, + } + switch { + case strings.HasSuffix(tblDataFile.FileName, ".csv"): + fileInfo.FileMeta.Type = mydump.SourceTypeCSV + case strings.HasSuffix(tblDataFile.FileName, ".sql"): + fileInfo.FileMeta.Type = mydump.SourceTypeSQL + default: + return nil, errors.Errorf("unsupported file type: %s", tblDataFile.FileName) + } + tblMeta.DataFiles = append(tblMeta.DataFiles, fileInfo) + if err := mapStore.WriteFile(ctx, tblDataFile.FileName, tblDataFile.Data); err != nil { + return nil, errors.Trace(err) + } + } + tblMeta.TotalSize = int64(totalFileSize) + 
dbMeta.Tables = append(dbMeta.Tables, tblMeta) + } + dbFileMetaMap[dbName] = dbMeta + } + return &MockImportSource{ + dbSrcDataMap: dbSrcDataMap, + dbFileMetaMap: dbFileMetaMap, + srcStorage: mapStore, + }, nil +} + +// GetStorage gets the External Storage object on the mock source. +func (m *MockImportSource) GetStorage() storage.ExternalStorage { + return m.srcStorage +} + +// GetDBMetaMap gets the Mydumper database metadata map on the mock source. +func (m *MockImportSource) GetDBMetaMap() map[string]*mydump.MDDatabaseMeta { + return m.dbFileMetaMap +} + +// GetAllDBFileMetas gets all the Mydumper database metadata on the mock source. +func (m *MockImportSource) GetAllDBFileMetas() []*mydump.MDDatabaseMeta { + result := make([]*mydump.MDDatabaseMeta, len(m.dbFileMetaMap)) + i := 0 + for _, dbMeta := range m.dbFileMetaMap { + result[i] = dbMeta + i++ + } + return result +} + +// StorageInfo defines the storage information for a mock target. +type StorageInfo struct { + TotalSize uint64 + UsedSize uint64 + AvailableSize uint64 +} + +// MockTableInfo defines mock table structure information for a mock target. +type MockTableInfo struct { + RowCount int + TableModel *model.TableInfo +} + +// MockTargetInfo defines mock target information. +type MockTargetInfo struct { + MaxReplicasPerRegion int + EmptyRegionCount int + StorageInfos []StorageInfo + sysVarMap map[string]string + dbTblInfoMap map[string]map[string]*MockTableInfo +} + +// NewMockTargetInfo creates a MockTargetInfo object. +func NewMockTargetInfo() *MockTargetInfo { + return &MockTargetInfo{ + StorageInfos: []StorageInfo{}, + sysVarMap: make(map[string]string), + dbTblInfoMap: make(map[string]map[string]*MockTableInfo), + } +} + +// SetSysVar sets the system variables of the mock target. +func (t *MockTargetInfo) SetSysVar(key string, value string) { + t.sysVarMap[key] = value +} + +// SetTableInfo sets the table structure information of the mock target. +func (t *MockTargetInfo) SetTableInfo(schemaName string, tableName string, tblInfo *MockTableInfo) { + if _, ok := t.dbTblInfoMap[schemaName]; !ok { + t.dbTblInfoMap[schemaName] = make(map[string]*MockTableInfo) + } + t.dbTblInfoMap[schemaName][tableName] = tblInfo +} + +// FetchRemoteTableModels fetches the table structures from the remote target. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + resultInfos := []*model.TableInfo{} + tblMap, ok := t.dbTblInfoMap[schemaName] + if !ok { + return resultInfos, nil + } + for _, tblInfo := range tblMap { + resultInfos = append(resultInfos, tblInfo.TableModel) + } + return resultInfos, nil +} + +// GetTargetSysVariablesForImport gets some important system variables for importing on the target. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) GetTargetSysVariablesForImport(ctx context.Context) map[string]string { + result := make(map[string]string) + for k, v := range t.sysVarMap { + result[k] = v + } + return result +} + +// GetReplicationConfig gets the replication config on the target. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) { + return &pdtypes.ReplicationConfig{ + MaxReplicas: uint64(t.MaxReplicasPerRegion), + }, nil +}
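MockTargetInfo is configured imperatively and then stands in wherever a real target would be queried. A construction sketch with assumed values (mock_test.go below exercises the same surface):

    ti := NewMockTargetInfo()
    ti.MaxReplicasPerRegion = 3
    ti.EmptyRegionCount = 5
    ti.SetSysVar("tidb_row_format_version", "2") // hypothetical variable
    ti.SetTableInfo("db01", "tbl01", &MockTableInfo{RowCount: 0})
    // IsTableEmpty(ctx, "db01", "tbl01") now reports true.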
+ +// GetStorageInfo gets the storage information on the target. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) { + resultStoreInfos := make([]*pdtypes.StoreInfo, len(t.StorageInfos)) + for i, storeInfo := range t.StorageInfos { + resultStoreInfos[i] = &pdtypes.StoreInfo{ + Status: &pdtypes.StoreStatus{ + Capacity: pdtypes.ByteSize(storeInfo.TotalSize), + Available: pdtypes.ByteSize(storeInfo.AvailableSize), + UsedSize: pdtypes.ByteSize(storeInfo.UsedSize), + }, + } + } + return &pdtypes.StoresInfo{ + Count: len(resultStoreInfos), + Stores: resultStoreInfos, + }, nil +} + +// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) { + regions := make([]pdtypes.RegionInfo, t.EmptyRegionCount) + for i := 0; i < t.EmptyRegionCount; i++ { + regions[i] = pdtypes.RegionInfo{} + } + return &pdtypes.RegionsInfo{ + Count: t.EmptyRegionCount, + Regions: regions, + }, nil +} + +// IsTableEmpty checks whether the specified table on the target DB contains data or not. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { + var result bool + tblInfoMap, ok := t.dbTblInfoMap[schemaName] + if !ok { + result = true + return &result, nil + } + tblInfo, ok := tblInfoMap[tableName] + if !ok { + result = true + return &result, nil + } + result = (tblInfo.RowCount == 0) + return &result, nil +} + +// CheckVersionRequirements performs the check whether the target satisfies the version requirements. +// It implements the TargetInfoGetter interface. +func (t *MockTargetInfo) CheckVersionRequirements(ctx context.Context) error { + return nil +} diff --git a/br/pkg/lightning/restore/mock/mock_test.go b/br/pkg/lightning/restore/mock/mock_test.go new file mode 100644 index 0000000000000..9c5b2b0cad6e0 --- /dev/null +++ b/br/pkg/lightning/restore/mock/mock_test.go @@ -0,0 +1,200 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package mock + +import ( + "bytes" + "context" + "testing" + + "github.com/pingcap/tidb/br/pkg/lightning/restore" + "github.com/pingcap/tidb/parser/model" + "github.com/stretchr/testify/require" +) + +func TestMockImportSourceBasic(t *testing.T) { + mockDataMap := map[string]*MockDBSourceData{ + "db01": { + Name: "db01", + Tables: map[string]*MockTableSourceData{ + "tbl01": { + DBName: "db01", + TableName: "tbl01", + SchemaFile: &MockSourceFile{ + FileName: "/db01/tbl01/tbl01.schema.sql", + Data: []byte("CREATE TABLE db01.tbl01(id INTEGER PRIMARY KEY AUTO_INCREMENT, strval VARCHAR(64))"), + }, + }, + "tbl02": { + DBName: "db01", + TableName: "tbl02", + SchemaFile: &MockSourceFile{ + FileName: "/db01/tbl02/tbl02.schema.sql", + Data: []byte("CREATE TABLE db01.tbl02(id INTEGER PRIMARY KEY AUTO_INCREMENT, val VARCHAR(64))"), + }, + DataFiles: []*MockSourceFile{ + { + FileName: "/db01/tbl02/tbl02.data.csv", + Data: []byte("val\naaa\nbbb"), + }, + { + FileName: "/db01/tbl02/tbl02.data.sql", + Data: []byte("INSERT INTO db01.tbl02 (val) VALUES ('ccc');"), + }, + }, + }, + }, + }, + "db02": { + Name: "db02", + Tables: map[string]*MockTableSourceData{ + "tbl01": { + DBName: "db02", + TableName: "tbl01", + SchemaFile: &MockSourceFile{ + FileName: "/db02/tbl01/tbl01.schema.sql", + Data: []byte("CREATE TABLE db02.tbl01(id INTEGER PRIMARY KEY AUTO_INCREMENT, strval VARCHAR(64))"), + }, + }, + }, + }, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mockEnv, err := NewMockImportSource(mockDataMap) + require.Nil(t, err) + dbFileMetas := mockEnv.GetAllDBFileMetas() + require.Equal(t, len(mockDataMap), len(dbFileMetas), "compare db count") + for _, dbFileMeta := range dbFileMetas { + dbMockData, ok := mockDataMap[dbFileMeta.Name] + require.Truef(t, ok, "get mock data by DB: %s", dbFileMeta.Name) + require.Equalf(t, len(dbMockData.Tables), len(dbFileMeta.Tables), "compare table count: %s", dbFileMeta.Name) + for _, tblFileMeta := range dbFileMeta.Tables { + tblMockData, ok := dbMockData.Tables[tblFileMeta.Name] + require.Truef(t, ok, "get mock data by Table: %s.%s", dbFileMeta.Name, tblFileMeta.Name) + schemaFileMeta := tblFileMeta.SchemaFile + mockSchemaFile := tblMockData.SchemaFile + fileData, err := mockEnv.srcStorage.ReadFile(ctx, schemaFileMeta.FileMeta.Path) + require.Nilf(t, err, "read schema file: %s.%s", dbFileMeta.Name, tblFileMeta.Name) + require.Truef(t, bytes.Equal(mockSchemaFile.Data, fileData), "compare schema file: %s.%s", dbFileMeta.Name, tblFileMeta.Name) + require.Equalf(t, len(tblMockData.DataFiles), len(tblFileMeta.DataFiles), "compare data file count: %s.%s", dbFileMeta.Name, tblFileMeta.Name) + for i, dataFileMeta := range tblFileMeta.DataFiles { + mockDataFile := tblMockData.DataFiles[i] + fileData, err := mockEnv.srcStorage.ReadFile(ctx, dataFileMeta.FileMeta.Path) + require.Nilf(t, err, "read data file: %s.%s: %s", dbFileMeta.Name, tblFileMeta.Name, dataFileMeta.FileMeta.Path) + require.Truef(t, bytes.Equal(mockDataFile.Data, fileData), "compare data file: %s.%s: %s", dbFileMeta.Name, tblFileMeta.Name, dataFileMeta.FileMeta.Path) + } + } + } +} + +func TestMockTargetInfoBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ti := NewMockTargetInfo() + var _ restore.TargetInfoGetter = ti + const replicaCount = 3 + const emptyRegionCount = 5 + const s01TotalSize uint64 = 10 << 30 + const s01UsedSize uint64 = 7<<30 + 500<<20 + const s02TotalSize uint64 = 50 << 30 + const s02UsedSize 
uint64 = 35<<30 + 700<<20 + + ti.SetSysVar("aaa", "111") + ti.SetSysVar("bbb", "222") + sysVars := ti.GetTargetSysVariablesForImport(ctx) + v, ok := sysVars["aaa"] + require.True(t, ok) + require.Equal(t, "111", v) + v, ok = sysVars["bbb"] + require.True(t, ok) + require.Equal(t, "222", v) + + ti.MaxReplicasPerRegion = replicaCount + rcfg, err := ti.GetReplicationConfig(ctx) + require.NoError(t, err) + require.Equal(t, uint64(replicaCount), rcfg.MaxReplicas) + + ti.StorageInfos = append(ti.StorageInfos, + StorageInfo{ + TotalSize: s01TotalSize, + UsedSize: s01UsedSize, + AvailableSize: s01TotalSize - s01UsedSize, + }, + StorageInfo{ + TotalSize: s02TotalSize, + UsedSize: s02UsedSize, + AvailableSize: s02TotalSize - s02UsedSize, + }, + ) + si, err := ti.GetStorageInfo(ctx) + require.NoError(t, err) + require.Equal(t, 2, si.Count) + store := si.Stores[0] + require.Equal(t, s01TotalSize, uint64(store.Status.Capacity)) + require.Equal(t, s01UsedSize, uint64(store.Status.UsedSize)) + store = si.Stores[1] + require.Equal(t, s02TotalSize, uint64(store.Status.Capacity)) + require.Equal(t, s02UsedSize, uint64(store.Status.UsedSize)) + + ti.EmptyRegionCount = emptyRegionCount + ri, err := ti.GetEmptyRegionsInfo(ctx) + require.NoError(t, err) + require.Equal(t, emptyRegionCount, ri.Count) + require.Equal(t, emptyRegionCount, len(ri.Regions)) + + ti.SetTableInfo("testdb", "testtbl1", + &MockTableInfo{ + TableModel: &model.TableInfo{ + ID: 1, + Name: model.NewCIStr("testtbl1"), + Columns: []*model.ColumnInfo{ + { + ID: 1, + Name: model.NewCIStr("c_1"), + Offset: 0, + }, + { + ID: 2, + Name: model.NewCIStr("c_2"), + Offset: 1, + }, + }, + }, + }, + ) + ti.SetTableInfo("testdb", "testtbl2", + &MockTableInfo{ + RowCount: 100, + }, + ) + tblInfos, err := ti.FetchRemoteTableModels(ctx, "testdb") + require.NoError(t, err) + require.Equal(t, 2, len(tblInfos)) + for _, tblInfo := range tblInfos { + if tblInfo == nil { + continue + } + require.Equal(t, 2, len(tblInfo.Columns)) + } + + isEmptyPtr, err := ti.IsTableEmpty(ctx, "testdb", "testtbl1") + require.NoError(t, err) + require.NotNil(t, isEmptyPtr) + require.True(t, *isEmptyPtr) + isEmptyPtr, err = ti.IsTableEmpty(ctx, "testdb", "testtbl2") + require.NoError(t, err) + require.NotNil(t, isEmptyPtr) + require.False(t, *isEmptyPtr) +} diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go index a570b57ce0abf..0de2061e4a4f6 100644 --- a/br/pkg/lightning/restore/restore.go +++ b/br/pkg/lightning/restore/restore.go @@ -63,7 +63,6 @@ import ( "go.uber.org/atomic" "go.uber.org/multierr" "go.uber.org/zap" - "golang.org/x/exp/maps" ) const ( @@ -224,6 +223,8 @@ type Controller struct { diskQuotaState atomic.Int32 compactState atomic.Int32 status *LightningStatus + + preInfoGetter PreRestoreInfoGetter } type LightningStatus struct { @@ -365,6 +366,24 @@ func NewRestoreControllerWithPauser( default: metaBuilder = noopMetaMgrBuilder{} } + ioWorkers := worker.NewPool(ctx, cfg.App.IOConcurrency, "io") + targetInfoGetter := &TargetInfoGetterImpl{ + cfg: cfg, + targetDBGlue: p.Glue, + tls: tls, + backend: backend, + } + preInfoGetter, err := NewPreRestoreInfoGetter( + cfg, + p.DBMetas, + p.DumpFileStorage, + targetInfoGetter, + ioWorkers, + backend, + ) + if err != nil { + return nil, errors.Trace(err) + } rc := &Controller{ taskCtx: ctx, @@ -373,7 +392,7 @@ func NewRestoreControllerWithPauser( tableWorkers: nil, indexWorkers: nil, regionWorkers: worker.NewPool(ctx, cfg.App.RegionConcurrency, "region"), - ioWorkers: worker.NewPool(ctx, 
cfg.App.IOConcurrency, "io"), + ioWorkers: ioWorkers, checksumWorks: worker.NewPool(ctx, cfg.TiDB.ChecksumTableConcurrency, "checksum"), pauser: p.Pauser, backend: backend, @@ -393,6 +412,8 @@ func NewRestoreControllerWithPauser( errorMgr: errorMgr, status: p.Status, taskMgr: nil, + + preInfoGetter: preInfoGetter, } return rc, nil @@ -713,22 +734,18 @@ func (rc *Controller) restoreSchema(ctx context.Context) error { for i := 0; i < concurrency; i++ { go worker.doJob() } - getTableFunc := rc.backend.FetchRemoteTableModels - if !rc.tidbGlue.OwnsSQLExecutor() { - getTableFunc = rc.tidbGlue.GetTables - } - err := worker.makeJobs(rc.dbMetas, getTableFunc) + err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteTableModels) logTask.End(zap.ErrorLevel, err) if err != nil { return err } - dbInfos, err := LoadSchemaInfo(ctx, rc.dbMetas, getTableFunc) + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) if err != nil { return errors.Trace(err) } // For local backend, we need DBInfo.ID to operate the global autoid allocator. - if rc.isLocalBackend() { + if isLocalBackend(rc.cfg) { dbs, err := tikv.FetchRemoteDBModelsFromTLS(ctx, rc.tls) if err != nil { return errors.Trace(err) @@ -742,11 +759,7 @@ func (rc *Controller) restoreSchema(ctx context.Context) error { } } rc.dbInfos = dbInfos - - sysVars := ObtainImportantVariables(ctx, rc.tidbGlue.GetSQLExecutor(), !rc.isTiDBBackend()) - // override by manually set vars - maps.Copy(sysVars, rc.cfg.TiDB.Vars) - rc.sysVars = sysVars + rc.sysVars = rc.preInfoGetter.GetTargetSysVariablesForImport(ctx) return nil } @@ -1398,7 +1411,7 @@ func (rc *Controller) restoreTables(ctx context.Context) (finalErr error) { postProgress := func() error { return nil } var kvStore tidbkv.Storage - if rc.isLocalBackend() { + if isLocalBackend(rc.cfg) { var ( restoreFn pdutil.UndoFunc err error @@ -1654,7 +1667,7 @@ func (tr *TableRestore) restoreTable( versionInfo := version.ParseServerInfo(versionStr) // "show table next_row_id" is only available after tidb v4.0.0 - if versionInfo.ServerVersion.Major >= 4 && rc.isLocalBackend() { + if versionInfo.ServerVersion.Major >= 4 && isLocalBackend(rc.cfg) { // first, insert a new-line into meta table if err = metaMgr.InitTableMeta(ctx); err != nil { return false, err @@ -1764,7 +1777,7 @@ func (rc *Controller) switchToNormalMode(ctx context.Context) { func (rc *Controller) switchTiKVMode(ctx context.Context, mode sstpb.SwitchMode) { // // tidb backend don't need to switch tikv to import mode - if rc.isTiDBBackend() { + if isTiDBBackend(rc.cfg) { return } @@ -1883,7 +1896,7 @@ func (rc *Controller) enforceDiskQuota(ctx context.Context) { func (rc *Controller) setGlobalVariables(ctx context.Context) error { // skip for tidb backend to be compatible with MySQL - if rc.isTiDBBackend() { + if isTiDBBackend(rc.cfg) { return nil } // set new collation flag base on tidb config @@ -1932,12 +1945,12 @@ func (rc *Controller) cleanCheckpoints(ctx context.Context) error { return nil } -func (rc *Controller) isLocalBackend() bool { - return rc.cfg.TikvImporter.Backend == config.BackendLocal +func isLocalBackend(cfg *config.Config) bool { + return cfg.TikvImporter.Backend == config.BackendLocal } -func (rc *Controller) isTiDBBackend() bool { - return rc.cfg.TikvImporter.Backend == config.BackendTiDB +func isTiDBBackend(cfg *config.Config) bool { + return cfg.TikvImporter.Backend == config.BackendTiDB } // preCheckRequirements checks @@ -1947,6 +1960,8 @@ func (rc *Controller) isTiDBBackend() bool { // 4. 
@@ -1947,6 +1960,8 @@ func (rc *Controller) isTiDBBackend() bool {
 // 4. Lightning configuration
 // before restore tables start.
 func (rc *Controller) preCheckRequirements(ctx context.Context) error {
+	ctx = WithPreInfoGetterSysVarsCache(ctx, rc.sysVars)
+	ctx = WithPreInfoGetterTableStructuresCache(ctx, rc.dbInfos)
 	if err := rc.DataCheck(ctx); err != nil {
 		return errors.Trace(err)
 	}
@@ -1968,11 +1983,22 @@ func (rc *Controller) preCheckRequirements(ctx context.Context) error {
 		// We still need to sample source data even if this task has existed, because we need to judge whether the
 		// source is in order as row key to decide how to sort local data.
-		source, err := rc.estimateSourceData(ctx)
+		estimatedSizeResult, err := rc.preInfoGetter.EstimateSourceDataSize(ctx)
 		if err != nil {
 			return common.ErrCheckDataSource.Wrap(err).GenWithStackByArgs()
 		}
-		if rc.isLocalBackend() {
+		estimatedDataSizeWithIndex := estimatedSizeResult.SizeWithIndex
+
+		// Do not import with too large a concurrency because the data may be all unsorted.
+		if estimatedSizeResult.HasUnsortedBigTables {
+			if rc.cfg.App.TableConcurrency > rc.cfg.App.IndexConcurrency {
+				rc.cfg.App.TableConcurrency = rc.cfg.App.IndexConcurrency
+			}
+		}
+		if rc.status != nil {
+			rc.status.TotalFileSize.Store(estimatedSizeResult.SizeWithoutIndex)
+		}
+		if isLocalBackend(rc.cfg) {
 			pdController, err := pdutil.NewPdController(ctx, rc.cfg.TiDB.PdAddr,
 				rc.tls.TLSConfig(), rc.tls.ToPDSecurityOption())
 			if err != nil {
@@ -1986,7 +2012,7 @@ func (rc *Controller) preCheckRequirements(ctx context.Context) error {
 				return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs()
 			}
 			if !taskExist {
-				if err = rc.taskMgr.InitTask(ctx, source); err != nil {
+				if err = rc.taskMgr.InitTask(ctx, estimatedDataSizeWithIndex); err != nil {
 					return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs()
 				}
 			}
@@ -2002,11 +2028,11 @@ func (rc *Controller) preCheckRequirements(ctx context.Context) error {
 				needCheck = taskCheckpoints == nil
 			}
 			if needCheck {
-				err = rc.localResource(ctx, source)
+				err = rc.localResource(ctx, estimatedDataSizeWithIndex)
 				if err != nil {
 					return common.ErrCheckLocalResource.Wrap(err).GenWithStackByArgs()
 				}
-				if err := rc.clusterResource(ctx, source); err != nil {
+				if err := rc.clusterResource(ctx, estimatedDataSizeWithIndex); err != nil {
 					if err1 := rc.taskMgr.CleanupTask(ctx); err1 != nil {
 						log.FromContext(ctx).Warn("cleanup task failed", zap.Error(err1))
 						return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs()
@@ -2335,7 +2361,7 @@ func (cr *chunkRestore) deliverLoop(
 			// can safely update current checkpoint.

 			failpoint.Inject("LocalBackendSaveCheckpoint", func() {
-				if !rc.isLocalBackend() && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) {
+				if !isLocalBackend(rc.cfg) && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) {
 					// No need to save checkpoint if nothing was delivered.
 					saveCheckpoint(rc, t, engineID, cr.chunk)
 				}
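For reference, preCheckRequirements above reads exactly three fields from the new estimate result. A sketch of a result shape consistent with those reads; the real type definition is not part of this diff, so anything beyond the three field names is an assumption:

// Sketch only: the field set is inferred from the preCheckRequirements hunk
// above, not copied from the actual definition.
type EstimateSourceDataSizeResult struct {
	SizeWithIndex        int64 // total estimated size; fed to InitTask, localResource and clusterResource
	SizeWithoutIndex     int64 // data-only estimate; stored into rc.status.TotalFileSize
	HasUnsortedBigTables bool  // when true, TableConcurrency is capped at IndexConcurrency
}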
diff --git a/br/pkg/lightning/restore/restore_schema_test.go b/br/pkg/lightning/restore/restore_schema_test.go
index d7a585f0c10a3..a0f962abec6d2 100644
--- a/br/pkg/lightning/restore/restore_schema_test.go
+++ b/br/pkg/lightning/restore/restore_schema_test.go
@@ -42,10 +42,12 @@ import (

 type restoreSchemaSuite struct {
 	suite.Suite
-	ctx        context.Context
-	rc         *Controller
-	controller *gomock.Controller
-	tableInfos []*model.TableInfo
+	ctx              context.Context
+	rc               *Controller
+	controller       *gomock.Controller
+	tableInfos       []*model.TableInfo
+	infoGetter       *PreRestoreInfoGetterImpl
+	targetInfoGetter *TargetInfoGetterImpl
 }

 func TestRestoreSchemaSuite(t *testing.T) {
@@ -103,14 +105,29 @@ func (s *restoreSchemaSuite) SetupSuite() {
 	config.Mydumper.CharacterSet = "utf8mb4"
 	config.App.RegionConcurrency = 8
 	mydumpLoader, err := mydump.NewMyDumpLoaderWithStore(ctx, config, store)
-	require.NoError(s.T(), err)
+	s.Require().NoError(err)
+
+	dbMetas := mydumpLoader.GetDatabases()
+	targetInfoGetter := &TargetInfoGetterImpl{
+		cfg: config,
+	}
+	preInfoGetter := &PreRestoreInfoGetterImpl{
+		cfg:              config,
+		srcStorage:       store,
+		targetInfoGetter: targetInfoGetter,
+		dbMetas:          dbMetas,
+	}
+	preInfoGetter.Init()
 	s.rc = &Controller{
 		checkTemplate: NewSimpleTemplate(),
 		cfg:           config,
 		store:         store,
-		dbMetas:       mydumpLoader.GetDatabases(),
+		dbMetas:       dbMetas,
 		checkpointsDB: &checkpoints.NullCheckpointsDB{},
+		preInfoGetter: preInfoGetter,
 	}
+	s.infoGetter = preInfoGetter
+	s.targetInfoGetter = targetInfoGetter
 }

 //nolint:interfacer // change test case signature might cause Check failed to find this test case?
@@ -122,7 +139,9 @@ func (s *restoreSchemaSuite) SetupTest() {
 		AnyTimes().
 		Return(s.tableInfos, nil)
 	mockBackend.EXPECT().Close()
-	s.rc.backend = backend.MakeBackend(mockBackend)
+	theBackend := backend.MakeBackend(mockBackend)
+	s.rc.backend = theBackend
+	s.targetInfoGetter.backend = theBackend

 	mockDB, sqlMock, err := sqlmock.New()
 	require.NoError(s.T(), err)
@@ -140,6 +159,7 @@ func (s *restoreSchemaSuite) SetupTest() {
 		GetParser().
 		AnyTimes().
 		Return(parser)
+	s.targetInfoGetter.targetDBGlue = mockTiDBGlue
 	s.rc.tidbGlue = mockTiDBGlue
 }

@@ -152,6 +172,7 @@ func (s *restoreSchemaSuite) TearDownTest() {
 		AnyTimes().
 		Return(exec)
 	s.rc.tidbGlue = mockTiDBGlue
+	s.targetInfoGetter.targetDBGlue = mockTiDBGlue

 	s.rc.Close()
 	s.controller.Finish()
@@ -213,6 +234,7 @@ func (s *restoreSchemaSuite) TestRestoreSchemaFailed() {
 		AnyTimes().
 		Return(parser)
 	s.rc.tidbGlue = mockTiDBGlue
+	s.targetInfoGetter.targetDBGlue = mockTiDBGlue
 	err = s.rc.restoreSchema(s.ctx)
 	require.Error(s.T(), err)
 	require.True(s.T(), errors.ErrorEqual(err, injectErr))
@@ -268,6 +290,7 @@ func (s *restoreSchemaSuite) TestRestoreSchemaContextCancel() {
 		AnyTimes().
 		Return(parser)
 	s.rc.tidbGlue = mockTiDBGlue
+	s.targetInfoGetter.targetDBGlue = mockTiDBGlue
 	err = s.rc.restoreSchema(childCtx)
 	cancel()
 	require.Error(s.T(), err)
diff --git a/br/pkg/lightning/restore/restore_test.go b/br/pkg/lightning/restore/restore_test.go
index e4d4420ead983..6c210b052d1dd 100644
--- a/br/pkg/lightning/restore/restore_test.go
+++ b/br/pkg/lightning/restore/restore_test.go
@@ -29,6 +29,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/lightning/errormanager"
 	"github.com/pingcap/tidb/br/pkg/lightning/glue"
 	"github.com/pingcap/tidb/br/pkg/lightning/log"
+	"github.com/pingcap/tidb/br/pkg/lightning/mydump"
 	"github.com/pingcap/tidb/br/pkg/version/build"
 	"github.com/pingcap/tidb/ddl"
 	"github.com/pingcap/tidb/parser"
@@ -213,6 +214,15 @@ func TestPreCheckFailed(t *testing.T) {
 	require.NoError(t, err)
 	g := glue.NewExternalTiDBGlue(db, mysql.ModeNone)

+	targetInfoGetter := &TargetInfoGetterImpl{
+		cfg:          cfg,
+		targetDBGlue: g,
+	}
+	preInfoGetter := &PreRestoreInfoGetterImpl{
+		cfg:              cfg,
+		targetInfoGetter: targetInfoGetter,
+		dbMetas:          make([]*mydump.MDDatabaseMeta, 0),
+	}
 	ctl := &Controller{
 		cfg:           cfg,
 		saveCpCh:      make(chan saveCp),
@@ -221,6 +231,7 @@ func TestPreCheckFailed(t *testing.T) {
 		checkTemplate: NewSimpleTemplate(),
 		tidbGlue:      g,
 		errorMgr:      errormanager.New(nil, cfg, log.L()),
+		preInfoGetter: preInfoGetter,
 	}

 	mock.ExpectBegin()
diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go
index adad30e5eaa42..b32c5e82b7345 100644
--- a/br/pkg/lightning/restore/table_restore.go
+++ b/br/pkg/lightning/restore/table_restore.go
@@ -619,14 +619,14 @@ func (tr *TableRestore) restoreEngine(
 	// in local mode, this check-point make no sense, because we don't do flush now,
 	// so there may be data lose if exit at here. So we don't write this checkpoint
 	// here like other mode.
-	if !rc.isLocalBackend() {
+	if !isLocalBackend(rc.cfg) {
 		if saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, engineID, err, checkpoints.CheckpointStatusAllWritten); saveCpErr != nil {
 			return nil, errors.Trace(firstErr(err, saveCpErr))
 		}
 	}
 	if err != nil {
 		// if process is canceled, we should flush all chunk checkpoints for local backend
-		if rc.isLocalBackend() && common.IsContextCanceledError(err) {
+		if isLocalBackend(rc.cfg) && common.IsContextCanceledError(err) {
 			// ctx is canceled, so to avoid Close engine failed, we use `context.Background()` here
 			if _, err2 := dataEngine.Close(context.Background(), dataEngineCfg); err2 != nil {
 				log.FromContext(ctx).Warn("flush all chunk checkpoints failed before manually exits", zap.Error(err2))
@@ -642,7 +642,7 @@ func (tr *TableRestore) restoreEngine(
 	closedDataEngine, err := dataEngine.Close(ctx, dataEngineCfg)
 	// For local backend, if checkpoint is enabled, we must flush index engine to avoid data loss.
 	// this flush action impact up to 10% of the performance, so we only do it if necessary.
-	if err == nil && rc.cfg.Checkpoint.Enable && rc.isLocalBackend() {
+	if err == nil && rc.cfg.Checkpoint.Enable && isLocalBackend(rc.cfg) {
 		if err = indexEngine.Flush(ctx); err != nil {
 			return nil, errors.Trace(err)
 		}
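preCheckRequirements seeds the context through WithPreInfoGetterSysVarsCache and WithPreInfoGetterTableStructuresCache, and the tests below do the same before invoking the prechecks. A minimal sketch of that context-cache pattern, assuming an unexported key type; the key and lowercase helper names here are hypothetical stand-ins for the real implementation:

package restore

import (
	"context"

	"github.com/pingcap/tidb/br/pkg/lightning/checkpoints"
)

// tableStructuresCacheKey is a hypothetical unexported context key.
type tableStructuresCacheKey struct{}

// withTableStructuresCache stores precomputed table structures on the context
// so a later precheck can read them back instead of fetching them remotely.
func withTableStructuresCache(ctx context.Context, dbInfos map[string]*checkpoints.TidbDBInfo) context.Context {
	return context.WithValue(ctx, tableStructuresCacheKey{}, dbInfos)
}

// tableStructuresFromCache returns the cached copy, if any.
func tableStructuresFromCache(ctx context.Context) (map[string]*checkpoints.TidbDBInfo, bool) {
	dbInfos, ok := ctx.Value(tableStructuresCacheKey{}).(map[string]*checkpoints.TidbDBInfo)
	return dbInfos, ok
}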
diff --git a/br/pkg/lightning/restore/table_restore_test.go b/br/pkg/lightning/restore/table_restore_test.go
index fb3a82adaf69a..de58ea4a39d43 100644
--- a/br/pkg/lightning/restore/table_restore_test.go
+++ b/br/pkg/lightning/restore/table_restore_test.go
@@ -47,6 +47,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/lightning/log"
 	"github.com/pingcap/tidb/br/pkg/lightning/metric"
 	"github.com/pingcap/tidb/br/pkg/lightning/mydump"
+	restoremock "github.com/pingcap/tidb/br/pkg/lightning/restore/mock"
 	"github.com/pingcap/tidb/br/pkg/lightning/verification"
 	"github.com/pingcap/tidb/br/pkg/lightning/web"
 	"github.com/pingcap/tidb/br/pkg/lightning/worker"
@@ -909,19 +910,34 @@ func (s *tableRestoreSuite) TestTableRestoreMetrics() {
 	cpDB := checkpoints.NewNullCheckpointsDB()
 	g := mock.NewMockGlue(controller)

-	rc := &Controller{
-		cfg: cfg,
-		dbMetas: []*mydump.MDDatabaseMeta{
-			{
-				Name:   s.tableInfo.DB,
-				Tables: []*mydump.MDTableMeta{s.tableMeta},
-			},
-		},
-		dbInfos: map[string]*checkpoints.TidbDBInfo{
-			s.tableInfo.DB: s.dbInfo,
+	dbMetas := []*mydump.MDDatabaseMeta{
+		{
+			Name:   s.tableInfo.DB,
+			Tables: []*mydump.MDTableMeta{s.tableMeta},
 		},
+	}
+	ioWorkers := worker.NewPool(ctx, 5, "io")
+	targetInfoGetter := &TargetInfoGetterImpl{
+		cfg:          cfg,
+		targetDBGlue: g,
+	}
+	preInfoGetter := &PreRestoreInfoGetterImpl{
+		cfg:              cfg,
+		dbMetas:          dbMetas,
+		targetInfoGetter: targetInfoGetter,
+		srcStorage:       s.store,
+		ioWorkers:        ioWorkers,
+	}
+	preInfoGetter.Init()
+	dbInfos := map[string]*checkpoints.TidbDBInfo{
+		s.tableInfo.DB: s.dbInfo,
+	}
+	rc := &Controller{
+		cfg:           cfg,
+		dbMetas:       dbMetas,
+		dbInfos:       dbInfos,
 		tableWorkers:  worker.NewPool(ctx, 6, "table"),
-		ioWorkers:     worker.NewPool(ctx, 5, "io"),
+		ioWorkers:     ioWorkers,
 		indexWorkers:  worker.NewPool(ctx, 2, "index"),
 		regionWorkers: worker.NewPool(ctx, 10, "region"),
 		checksumWorks: worker.NewPool(ctx, 2, "region"),
@@ -937,6 +953,7 @@ func (s *tableRestoreSuite) TestTableRestoreMetrics() {
 		metaMgrBuilder: noopMetaMgrBuilder{},
 		errorMgr:       errormanager.New(nil, cfg, log.L()),
 		taskMgr:        noopTaskMetaMgr{},
+		preInfoGetter:  preInfoGetter,
 	}
 	go func() {
 		for scp := range chptCh {
@@ -1097,7 +1114,22 @@ func (s *tableRestoreSuite) TestCheckClusterResource() {
 			url := strings.TrimPrefix(server.URL, "https://")
 			cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}}

-			rc := &Controller{cfg: cfg, tls: tls, store: mockStore, checkTemplate: template}
+			targetInfoGetter := &TargetInfoGetterImpl{
+				cfg: cfg,
+				tls: tls,
+			}
+			preInfoGetter := &PreRestoreInfoGetterImpl{
+				cfg:              cfg,
+				targetInfoGetter: targetInfoGetter,
+				srcStorage:       mockStore,
+			}
+			rc := &Controller{
+				cfg:           cfg,
+				tls:           tls,
+				store:         mockStore,
+				checkTemplate: template,
+				preInfoGetter: preInfoGetter,
+			}
 			var sourceSize int64
 			err = rc.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error {
 				sourceSize += size
@@ -1224,9 +1256,27 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() {
 			url := strings.TrimPrefix(server.URL, "https://")
 			cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}}

-			rc := &Controller{cfg: cfg, tls: tls, taskMgr: mockTaskMetaMgr{}, checkTemplate: template}
-			err := rc.checkClusterRegion(context.Background())
+			targetInfoGetter := &TargetInfoGetterImpl{
+				cfg: cfg,
+				tls: tls,
+			}
+			preInfoGetter := &PreRestoreInfoGetterImpl{
+				cfg:              cfg,
+				targetInfoGetter: targetInfoGetter,
+				dbMetas:          []*mydump.MDDatabaseMeta{},
+			}
+			rc := &Controller{
+				cfg:           cfg,
+				tls:           tls,
+				taskMgr:       mockTaskMetaMgr{},
+				checkTemplate: template,
+				preInfoGetter: preInfoGetter,
+				dbInfos:       make(map[string]*checkpoints.TidbDBInfo),
+			}
+
+			ctx := WithPreInfoGetterTableStructuresCache(context.Background(), rc.dbInfos)
+			err := rc.checkClusterRegion(ctx)
 			require.NoError(s.T(), err)
 			require.Equal(s.T(), ca.expectErrorCnt, template.FailedCount(Critical))
 			require.Equal(s.T(), ca.expectResult, template.Success())
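TestEstimate below adapts to estimateSourceData's widened signature: only the first return value and the error are asserted. A hedged sketch of a caller; the meaning of the two discarded returns is assumed (they appear to parallel the estimate result's other outputs) and is not stated by this diff:

// estimatedSizeWithIndex is a hypothetical wrapper keeping only the first
// return value, mirroring how the tests below consume the call.
func estimatedSizeWithIndex(ctx context.Context, rc *Controller) (int64, error) {
	sizeWithIndex, _, _, err := rc.estimateSourceData(ctx)
	return sizeWithIndex, err
}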
@@ -1333,32 +1383,48 @@ func (s *tableRestoreSuite) TestEstimate() {
 	s.cfg.TikvImporter.Backend = config.BackendLocal
 	template := NewSimpleTemplate()

+	dbMetas := []*mydump.MDDatabaseMeta{
+		{
+			Name:   "db1",
+			Tables: []*mydump.MDTableMeta{s.tableMeta},
+		},
+	}
+	dbInfos := map[string]*checkpoints.TidbDBInfo{
+		"db1": s.dbInfo,
+	}
+	ioWorkers := worker.NewPool(context.Background(), 1, "io")
+	mockTarget := restoremock.NewMockTargetInfo()
+
+	preInfoGetter := &PreRestoreInfoGetterImpl{
+		cfg:              s.cfg,
+		srcStorage:       s.store,
+		encBuilder:       importer,
+		ioWorkers:        ioWorkers,
+		dbMetas:          dbMetas,
+		targetInfoGetter: mockTarget,
+	}
+	preInfoGetter.Init()
 	rc := &Controller{
 		cfg:           s.cfg,
 		checkTemplate: template,
 		store:         s.store,
 		backend:       importer,
-		dbMetas: []*mydump.MDDatabaseMeta{
-			{
-				Name:   "db1",
-				Tables: []*mydump.MDTableMeta{s.tableMeta},
-			},
-		},
-		dbInfos: map[string]*checkpoints.TidbDBInfo{
-			"db1": s.dbInfo,
-		},
-		ioWorkers: worker.NewPool(context.Background(), 1, "io"),
+		dbMetas:       dbMetas,
+		dbInfos:       dbInfos,
+		ioWorkers:     ioWorkers,
+		preInfoGetter: preInfoGetter,
 	}
-	source, err := rc.estimateSourceData(ctx)
+	ctx = WithPreInfoGetterTableStructuresCache(ctx, dbInfos)
+	source, _, _, err := rc.estimateSourceData(ctx)
 	// Because this file is small than region split size so we does not sample it.
 	require.NoError(s.T(), err)
 	require.Equal(s.T(), s.tableMeta.TotalSize, source)
 	s.tableMeta.TotalSize = int64(config.SplitRegionSize)
-	source, err = rc.estimateSourceData(ctx)
+	source, _, _, err = rc.estimateSourceData(ctx)
 	require.NoError(s.T(), err)
 	require.Greater(s.T(), source, s.tableMeta.TotalSize)
 	rc.cfg.TikvImporter.Backend = config.BackendTiDB
-	source, err = rc.estimateSourceData(ctx)
+	source, _, _, err = rc.estimateSourceData(ctx)
 	require.NoError(s.T(), err)
 	require.Equal(s.T(), s.tableMeta.TotalSize, source)
 }
@@ -1758,13 +1824,21 @@ func (s *tableRestoreSuite) TestSchemaIsValid() {
 				IgnoreColumns: ca.ignoreColumns,
 			},
 		}
+		ioWorkers := worker.NewPool(context.Background(), 1, "io")
+		preInfoGetter := &PreRestoreInfoGetterImpl{
+			cfg:        cfg,
+			srcStorage: mockStore,
+			ioWorkers:  ioWorkers,
+		}
 		rc := &Controller{
 			cfg:           cfg,
 			checkTemplate: template,
 			store:         mockStore,
 			dbInfos:       ca.dbInfos,
-			ioWorkers:     worker.NewPool(context.Background(), 1, "io"),
+			ioWorkers:     ioWorkers,
+			preInfoGetter: preInfoGetter,
 		}
+		ctx = WithPreInfoGetterTableStructuresCache(ctx, ca.dbInfos)
 		msgs, err := rc.SchemaIsValid(ctx, ca.tableMeta)
 		require.NoError(s.T(), err)
 		require.Len(s.T(), msgs, ca.MsgNum)
@@ -1804,36 +1878,45 @@ func (s *tableRestoreSuite) TestGBKEncodedSchemaIsValid() {
 	err = mockStore.WriteFile(ctx, csvFile, []byte(csvContent))
 	require.NoError(s.T(), err)

-	rc := &Controller{
-		cfg:           cfg,
-		checkTemplate: NewSimpleTemplate(),
-		store:         mockStore,
-		dbInfos: map[string]*checkpoints.TidbDBInfo{
-			"db1": {
-				Name: "db1",
-				Tables: map[string]*checkpoints.TidbTableInfo{
-					"gbk_table": {
-						ID:   1,
-						DB:   "db1",
-						Name: "gbk_table",
-						Core: &model.TableInfo{
-							Columns: []*model.ColumnInfo{
-								{
-									Name:      model.NewCIStr("colA"),
-									FieldType: types.NewFieldTypeBuilder().SetType(0).SetFlag(1).Build(),
-								},
-								{
-									Name:      model.NewCIStr("colB"),
-									FieldType: types.NewFieldTypeBuilder().SetType(0).SetFlag(1).Build(),
+	dbInfos := map[string]*checkpoints.TidbDBInfo{
+		"db1": {
+			Name: "db1",
+			Tables: map[string]*checkpoints.TidbTableInfo{
+				"gbk_table": {
+					ID:   1,
+					DB:   "db1",
+					Name: "gbk_table",
+					Core: &model.TableInfo{
+						Columns: []*model.ColumnInfo{
+							{
+								Name:      model.NewCIStr("colA"),
+								FieldType: types.NewFieldTypeBuilder().SetType(0).SetFlag(1).Build(),
+							},
+							{
+								Name:      model.NewCIStr("colB"),
+								FieldType: types.NewFieldTypeBuilder().SetType(0).SetFlag(1).Build(),
 							},
 						},
 					},
 				},
 			},
 		},
-		ioWorkers: worker.NewPool(ctx, 1, "io"),
 	}
+	ioWorkers := worker.NewPool(ctx, 1, "io")
+	preInfoGetter := &PreRestoreInfoGetterImpl{
+		cfg:        cfg,
+		srcStorage: mockStore,
+		ioWorkers:  ioWorkers,
+	}
+	rc := &Controller{
+		cfg:           cfg,
+		checkTemplate: NewSimpleTemplate(),
+		store:         mockStore,
+		dbInfos:       dbInfos,
+		ioWorkers:     ioWorkers,
+		preInfoGetter: preInfoGetter,
+	}
+	ctx = WithPreInfoGetterTableStructuresCache(ctx, dbInfos)
 	msgs, err := rc.SchemaIsValid(ctx, &mydump.MDTableMeta{
 		DB:   "db1",
 		Name: "gbk_table",