
restore: remove/add pd scheduler in restore
3pointer committed Dec 19, 2019

1 parent cd6a176 commit 96cdc05
Showing 5 changed files with 299 additions and 318 deletions.
23 changes: 11 additions & 12 deletions cmd/backup.go
@@ -21,8 +21,6 @@ const (
flagBackupRateLimit = "ratelimit"
flagBackupConcurrency = "concurrency"
flagBackupChecksum = "checksum"
flagBackupDB = "db"
flagBackupTable = "table"
)

func defineBackupFlags(flagSet *pflag.FlagSet) {
@@ -91,6 +89,8 @@ func runBackup(flagSet *pflag.FlagSet, cmdName, db, table string) error {
return err
}

defer summary.Summary(cmdName)

ranges, backupSchemas, err := backup.BuildBackupRangeAndSchema(
mgr.GetDomain(), mgr.GetTiKV(), backupTS, db, table)
if err != nil {
@@ -151,7 +151,6 @@ func runBackup(flagSet *pflag.FlagSet, cmdName, db, table string) error {
if err != nil {
return err
}
summary.Summary(cmdName)
return nil
}

@@ -205,7 +204,7 @@ func newDbBackupCommand() *cobra.Command {
Use: "db",
Short: "backup a database",
RunE: func(command *cobra.Command, _ []string) error {
db, err := command.Flags().GetString(flagBackupDB)
db, err := command.Flags().GetString(flagDatabase)
if err != nil {
return err
}
@@ -215,8 +214,8 @@ func newDbBackupCommand() *cobra.Command {
return runBackup(command.Flags(), "Database backup", db, "")
},
}
command.Flags().StringP(flagBackupDB, "", "", "backup a table in the specific db")
_ = command.MarkFlagRequired(flagBackupDB)
command.Flags().StringP(flagDatabase, "", "", "backup a table in the specific db")
_ = command.MarkFlagRequired(flagDatabase)

return command
}
@@ -227,14 +226,14 @@ func newTableBackupCommand() *cobra.Command {
Use: "table",
Short: "backup a table",
RunE: func(command *cobra.Command, _ []string) error {
db, err := command.Flags().GetString(flagBackupDB)
db, err := command.Flags().GetString(flagDatabase)
if err != nil {
return err
}
if len(db) == 0 {
return errors.Errorf("empty database name is not allowed")
}
table, err := command.Flags().GetString(flagBackupTable)
table, err := command.Flags().GetString(flagTable)
if err != nil {
return err
}
@@ -244,9 +243,9 @@ func newTableBackupCommand() *cobra.Command {
return runBackup(command.Flags(), "Table backup", db, table)
},
}
command.Flags().StringP(flagBackupDB, "", "", "backup a table in the specific db")
command.Flags().StringP(flagBackupTable, "t", "", "backup the specific table")
_ = command.MarkFlagRequired(flagBackupDB)
_ = command.MarkFlagRequired(flagBackupTable)
command.Flags().StringP(flagDatabase, "", "", "backup a table in the specific db")
command.Flags().StringP(flagTable, "t", "", "backup the specific table")
_ = command.MarkFlagRequired(flagDatabase)
_ = command.MarkFlagRequired(flagTable)
return command
}
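
A minimal sketch of the deferred-summary pattern the backup.go hunks above introduce. Only summary.Summary and the defer placement come from the diff; the two helper calls are hypothetical stand-ins for the real backup steps.

// Sketch: moving summary.Summary into a defer means the summary is printed
// on every exit path, including early returns on error, rather than only
// when the function reaches its final return.
func runBackupSketch(cmdName string) error {
	defer summary.Summary(cmdName) // runs even if a later step fails

	if err := buildRangesAndSchemas(); err != nil { // hypothetical helper
		return err // the deferred summary still runs here
	}
	return pushBackup() // hypothetical helper
}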
4 changes: 4 additions & 0 deletions cmd/cmd.go
@@ -50,6 +50,10 @@ const (
FlagStatusAddr = "status-addr"
// FlagSlowLogFile is the name of slow-log-file flag.
FlagSlowLogFile = "slow-log-file"

flagDatabase = "db"

flagTable = "table"
)

// AddFlags adds flags to the given cmd.
494 changes: 201 additions & 293 deletions cmd/restore.go
@@ -2,6 +2,7 @@ package cmd

import (
"context"
"fmt"
"strings"

"github.com/gogo/protobuf/proto"
@@ -13,12 +14,23 @@ import (
flag "github.com/spf13/pflag"
"go.uber.org/zap"

"github.com/pingcap/br/pkg/conn"
"github.com/pingcap/br/pkg/restore"
"github.com/pingcap/br/pkg/storage"
"github.com/pingcap/br/pkg/summary"
"github.com/pingcap/br/pkg/utils"
)

var schedulers = map[string]struct{}{
"balance-leader-scheduler": {},
"balance-hot-region-scheduler": {},
"balance-region-scheduler": {},

"shuffle-leader-scheduler": {},
"shuffle-region-scheduler": {},
"shuffle-hot-region-scheduler": {},
}

// NewRestoreCommand returns a restore subcommand
func NewRestoreCommand() *cobra.Command {
command := &cobra.Command{
@@ -56,108 +68,141 @@ func NewRestoreCommand() *cobra.Command {
return command
}

func newFullRestoreCommand() *cobra.Command {
command := &cobra.Command{
Use: "full",
Short: "restore all tables",
RunE: func(cmd *cobra.Command, _ []string) error {
pdAddr, err := cmd.Flags().GetString(FlagPD)
if err != nil {
return errors.Trace(err)
}
ctx, cancel := context.WithCancel(GetDefaultContext())
defer cancel()
func runRestore(flagSet *flag.FlagSet, cmdName, dbName, tableName string) error {
pdAddr, err := flagSet.GetString(FlagPD)
if err != nil {
return errors.Trace(err)
}
ctx, cancel := context.WithCancel(GetDefaultContext())
defer cancel()

mgr, err := GetDefaultMgr()
if err != nil {
return err
}
defer mgr.Close()
mgr, err := GetDefaultMgr()
if err != nil {
return err
}
defer mgr.Close()

client, err := restore.NewRestoreClient(
ctx, mgr.GetPDClient(), mgr.GetTiKV())
if err != nil {
return errors.Trace(err)
}
defer client.Close()
err = initRestoreClient(ctx, client, cmd.Flags())
if err != nil {
return errors.Trace(err)
}
client, err := restore.NewRestoreClient(
ctx, mgr.GetPDClient(), mgr.GetTiKV())
if err != nil {
return errors.Trace(err)
}
defer client.Close()
err = initRestoreClient(ctx, client, flagSet)
if err != nil {
return errors.Trace(err)
}

files := make([]*backup.File, 0)
tables := make([]*utils.Table, 0)
for _, db := range client.GetDatabases() {
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}
for _, table := range db.Tables {
files = append(files, table.Files...)
}
tables = append(tables, db.Tables...)
}
files := make([]*backup.File, 0)
tables := make([]*utils.Table, 0)

defer summary.Summary("Restore full")
defer summary.Summary(cmdName)

summary.CollectInt("restore files", len(files))
rewriteRules, newTables, err := client.CreateTables(mgr.GetDomain(), tables)
switch {
case len(dbName) == 0 && len(tableName) == 0:
// full restore
for _, db := range client.GetDatabases() {
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}
ranges, err := restore.ValidateFileRanges(files, rewriteRules)
if err != nil {
return err
for _, table := range db.Tables {
files = append(files, table.Files...)
}
summary.CollectInt("restore ranges", len(ranges))
tables = append(tables, db.Tables...)
}
case len(dbName) != 0 && len(tableName) == 0:
// database restore
db := client.GetDatabase(dbName)
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}
for _, table := range db.Tables {
files = append(files, table.Files...)
}
tables = db.Tables
case len(dbName) != 0 && len(tableName) != 0:
// table restore
db := client.GetDatabase(dbName)
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}
table := db.GetTable(tableName)
files = table.Files
tables = utils.Tables{table}
default:
return errors.New("must set db when table was set")
}

// Redirect to log if there is no log file to avoid unreadable output.
updateCh := utils.StartProgress(
ctx,
"Full Restore",
// Split/Scatter + Download/Ingest
int64(len(ranges)+len(files)),
!HasLogFile())
summary.CollectInt("restore files", len(files))
rewriteRules, newTables, err := client.CreateTables(mgr.GetDomain(), tables)
if err != nil {
return errors.Trace(err)
}
ranges, err := restore.ValidateFileRanges(files, rewriteRules)
if err != nil {
return err
}
summary.CollectInt("restore ranges", len(ranges))

err = restore.SplitRanges(ctx, client, ranges, rewriteRules, updateCh)
if err != nil {
log.Error("split regions failed", zap.Error(err))
return errors.Trace(err)
}
pdAddrs := strings.Split(pdAddr, ",")
err = client.ResetTS(pdAddrs)
if err != nil {
log.Error("reset pd TS failed", zap.Error(err))
return errors.Trace(err)
}
// Redirect to log if there is no log file to avoid unreadable output.
updateCh := utils.StartProgress(
ctx,
cmdName,
// Split/Scatter + Download/Ingest
int64(len(ranges)+len(files)),
!HasLogFile())

err = client.SwitchToImportModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
err = restore.SplitRanges(ctx, client, ranges, rewriteRules, updateCh)
if err != nil {
log.Error("split regions failed", zap.Error(err))
return errors.Trace(err)
}
pdAddrs := strings.Split(pdAddr, ",")
err = client.ResetTS(pdAddrs)
if err != nil {
log.Error("reset pd TS failed", zap.Error(err))
return errors.Trace(err)
}

err = client.RestoreAll(rewriteRules, updateCh)
if err != nil {
return errors.Trace(err)
}
removedSchedulers, err := RestorePrepareWork(ctx, client, mgr)
if err != nil {
return errors.Trace(err)
}

err = client.SwitchToNormalModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
// Restore has finished.
close(updateCh)

// Checksum
updateCh = utils.StartProgress(
ctx, "Checksum", int64(len(newTables)), !HasLogFile())
err = client.ValidateChecksum(
ctx, mgr.GetTiKV().GetClient(), tables, newTables, updateCh)
if err != nil {
return err
}
close(updateCh)
return nil
err = client.RestoreAll(rewriteRules, updateCh)
if err != nil {
return errors.Trace(err)
}

err = RestorePostWork(ctx, client, mgr, removedSchedulers)
if err != nil {
return errors.Trace(err)
}
// Restore has finished.
close(updateCh)

// Checksum
updateCh = utils.StartProgress(
ctx, "Checksum", int64(len(newTables)), !HasLogFile())
err = client.ValidateChecksum(
ctx, mgr.GetTiKV().GetClient(), tables, newTables, updateCh)
if err != nil {
return err
}
close(updateCh)

return nil
}

func newFullRestoreCommand() *cobra.Command {
command := &cobra.Command{
Use: "full",
Short: "restore all tables",
RunE: func(cmd *cobra.Command, _ []string) error {
return runRestore(cmd.Flags(), "Full Restore", "", "")
},
}
return command
@@ -168,113 +213,18 @@ func newDbRestoreCommand() *cobra.Command {
Use: "db",
Short: "restore tables in a database",
RunE: func(cmd *cobra.Command, _ []string) error {
pdAddr, err := cmd.Flags().GetString(FlagPD)
if err != nil {
return errors.Trace(err)
}
ctx, cancel := context.WithCancel(GetDefaultContext())
defer cancel()

mgr, err := GetDefaultMgr()
db, err := cmd.Flags().GetString(flagDatabase)
if err != nil {
return err
}
defer mgr.Close()

client, err := restore.NewRestoreClient(
ctx, mgr.GetPDClient(), mgr.GetTiKV())
if err != nil {
return errors.Trace(err)
if len(db) == 0 {
return errors.Errorf("empty database name is not allowed")
}
defer client.Close()
err = initRestoreClient(ctx, client, cmd.Flags())
if err != nil {
return errors.Trace(err)
}

dbName, err := cmd.Flags().GetString("db")
if err != nil {
return errors.Trace(err)
}
db := client.GetDatabase(dbName)
if db == nil {
return errors.New("not exists database")
}
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}

rewriteRules, newTables, err := client.CreateTables(mgr.GetDomain(), db.Tables)
if err != nil {
return errors.Trace(err)
}
files := make([]*backup.File, 0)
for _, table := range db.Tables {
files = append(files, table.Files...)
}

defer summary.Summary("Restore database")

summary.CollectInt("restore files", len(files))
ranges, err := restore.ValidateFileRanges(files, rewriteRules)
if err != nil {
return err
}
summary.CollectInt("restore ranges", len(ranges))
// Redirect to log if there is no log file to avoid unreadable output.
updateCh := utils.StartProgress(
ctx,
"Database Restore",
// Split/Scatter + Download/Ingest
int64(len(ranges)+len(files)),
!HasLogFile())

err = restore.SplitRanges(ctx, client, ranges, rewriteRules, updateCh)
if err != nil {
log.Error("split regions failed", zap.Error(err))
return errors.Trace(err)
}
pdAddrs := strings.Split(pdAddr, ",")
err = client.ResetTS(pdAddrs)
if err != nil {
log.Error("reset pd TS failed", zap.Error(err))
return errors.Trace(err)
}

err = client.SwitchToImportModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}

err = client.RestoreDatabase(
db, rewriteRules, updateCh)
if err != nil {
return errors.Trace(err)
}

err = client.SwitchToNormalModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
// Checksum
updateCh = utils.StartProgress(
ctx, "Checksum", int64(len(newTables)), !HasLogFile())
err = client.ValidateChecksum(
ctx, mgr.GetTiKV().GetClient(), db.Tables, newTables, updateCh)
if err != nil {
return err
}
close(updateCh)
return nil
return runRestore(cmd.Flags(), "Database Restore", db, "")
},
}
command.Flags().String("db", "", "database name")

if err := command.MarkFlagRequired("db"); err != nil {
panic(err)
}

command.Flags().String(flagDatabase, "", "database name")
_ = command.MarkFlagRequired(flagDatabase)
return command
}

@@ -283,122 +233,29 @@ func newTableRestoreCommand() *cobra.Command {
Use: "table",
Short: "restore a table",
RunE: func(cmd *cobra.Command, _ []string) error {
pdAddr, err := cmd.Flags().GetString(FlagPD)
if err != nil {
return errors.Trace(err)
}
ctx, cancel := context.WithCancel(GetDefaultContext())
defer cancel()

mgr, err := GetDefaultMgr()
db, err := cmd.Flags().GetString(flagDatabase)
if err != nil {
return err
}
defer mgr.Close()

client, err := restore.NewRestoreClient(
ctx, mgr.GetPDClient(), mgr.GetTiKV())
if err != nil {
return errors.Trace(err)
}
defer client.Close()
err = initRestoreClient(ctx, client, cmd.Flags())
if err != nil {
return errors.Trace(err)
}

dbName, err := cmd.Flags().GetString("db")
if err != nil {
return errors.Trace(err)
}
db := client.GetDatabase(dbName)
if db == nil {
return errors.New("not exists database")
if len(db) == 0 {
return errors.Errorf("empty database name is not allowed")
}
err = client.CreateDatabase(db.Schema)
if err != nil {
return errors.Trace(err)
}

tableName, err := cmd.Flags().GetString("table")
if err != nil {
return errors.Trace(err)
}
table := db.GetTable(tableName)
if table == nil {
return errors.New("not exists table")
}
// The rules here is raw key.
rewriteRules, newTables, err := client.CreateTables(mgr.GetDomain(), []*utils.Table{table})
if err != nil {
return errors.Trace(err)
}

defer summary.Summary("Restore table")

summary.CollectInt("restore files", len(table.Files))
ranges, err := restore.ValidateFileRanges(table.Files, rewriteRules)
table, err := cmd.Flags().GetString(flagTable)
if err != nil {
return err
}
summary.CollectInt("restore ranges", len(ranges))
// Redirect to log if there is no log file to avoid unreadable output.
updateCh := utils.StartProgress(
ctx,
"Table Restore",
// Split/Scatter + Download/Ingest
int64(len(ranges)+len(table.Files)),
!HasLogFile())

err = restore.SplitRanges(ctx, client, ranges, rewriteRules, updateCh)
if err != nil {
log.Error("split regions failed", zap.Error(err))
return errors.Trace(err)
}
pdAddrs := strings.Split(pdAddr, ",")
err = client.ResetTS(pdAddrs)
if err != nil {
log.Error("reset pd TS failed", zap.Error(err))
return errors.Trace(err)
}
err = client.SwitchToImportModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
err = client.RestoreTable(table, rewriteRules, updateCh)
if err != nil {
return errors.Trace(err)
}
err = client.SwitchToNormalModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
// Restore has finished.
close(updateCh)

// Checksum
updateCh = utils.StartProgress(
ctx, "Checksum", int64(len(newTables)), !HasLogFile())
err = client.ValidateChecksum(
ctx, mgr.GetTiKV().GetClient(), []*utils.Table{table}, newTables, updateCh)
if err != nil {
return err
if len(table) == 0 {
return errors.Errorf("empty table name is not allowed")
}
close(updateCh)
return nil
return runRestore(cmd.Flags(), "Table Restore", db, table)
},
}

command.Flags().String("db", "", "database name")
command.Flags().String("table", "", "table name")

if err := command.MarkFlagRequired("db"); err != nil {
panic(err)
}
if err := command.MarkFlagRequired("table"); err != nil {
panic(err)
}
command.Flags().String(flagDatabase, "", "database name")
command.Flags().String(flagTable, "", "table name")

_ = command.MarkFlagRequired(flagDatabase)
_ = command.MarkFlagRequired(flagTable)
return command
}

@@ -446,3 +303,54 @@ func initRestoreClient(ctx context.Context, client *restore.Client, flagSet *fla

return nil
}

// RestorePrepareWork execute some prepare work before restore
func RestorePrepareWork(ctx context.Context, client *restore.Client, mgr *conn.Mgr) ([]string, error) {
err := client.SwitchToImportModeIfOffline(ctx)
if err != nil {
return nil, errors.Trace(err)
}
existSchedulers, err := mgr.ListScheduler(ctx)
if err != nil {
return nil, errors.Trace(err)
}
needRemoveScheduler := make([]string, 0, len(existSchedulers))
for _, s := range existSchedulers {
if _, ok := schedulers[s]; ok {
needRemoveScheduler = append(needRemoveScheduler, s)
}
}
fmt.Println("need:", needRemoveScheduler)
return removePDLeaderScheduler(ctx, mgr, needRemoveScheduler)
}

func removePDLeaderScheduler(ctx context.Context, mgr *conn.Mgr, existSchedulers []string) ([]string, error) {
removedSchedulers := make([]string, 0, len(existSchedulers))
for _, scheduler := range existSchedulers {
err := mgr.RemoveScheduler(ctx, scheduler)
if err != nil {
return nil, err
}
removedSchedulers = append(removedSchedulers, scheduler)
}
return removedSchedulers, nil
}

// RestorePostWork execute some post work after restore
func RestorePostWork(ctx context.Context, client *restore.Client, mgr *conn.Mgr, removedSchedulers []string) error {
err := client.SwitchToNormalModeIfOffline(ctx)
if err != nil {
return errors.Trace(err)
}
return addPDLeaderScheduler(ctx, mgr, removedSchedulers)
}

func addPDLeaderScheduler(ctx context.Context, mgr *conn.Mgr, removedSchedulers []string) error {
for _, scheduler := range removedSchedulers {
_, err := mgr.AddScheduler(ctx, scheduler)
if err != nil {
return err
}
}
return nil
}
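
A condensed sketch of how the new helpers bracket a restore, following the call order in runRestore above. RestorePrepareWork, client.RestoreAll, and RestorePostWork are taken from the diff; the parameter types of rewriteRules and updateCh are assumptions for illustration.

// Sketch: schedulers that would reshuffle regions during SST ingestion are
// removed before the restore and re-added afterwards, and the cluster is
// switched to import mode and back if it is offline.
func restoreWithSchedulerGuard(
	ctx context.Context,
	client *restore.Client,
	mgr *conn.Mgr,
	rewriteRules *restore.RewriteRules, // assumed type
	updateCh chan<- struct{}, // assumed type
) error {
	removed, err := RestorePrepareWork(ctx, client, mgr) // import mode + remove schedulers
	if err != nil {
		return errors.Trace(err)
	}
	if err = client.RestoreAll(rewriteRules, updateCh); err != nil {
		return errors.Trace(err)
	}
	// Back to normal mode, and re-add exactly the schedulers removed above.
	return RestorePostWork(ctx, client, mgr, removed)
}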
91 changes: 80 additions & 11 deletions pkg/conn/conn.go
@@ -1,9 +1,11 @@
package conn

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
@@ -30,6 +32,7 @@ const (
dialTimeout = 5 * time.Second
clusterVersionPrefix = "pd/api/v1/config/cluster-version"
regionCountPrefix = "pd/api/v1/stats/region"
schdulerPrefix = "pd/api/v1/schedulers"
)

// Mgr manages connections to a TiDB cluster.
@@ -47,9 +50,12 @@ type Mgr struct {
}
}

type pdHTTPGet func(context.Context, string, string, *http.Client) ([]byte, error)
type pdHTTPRequest func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error)

func pdGet(ctx context.Context, addr string, prefix string, cli *http.Client) ([]byte, error) {
func pdRequest(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error) {
if addr != "" && !strings.HasPrefix("http", addr) {
addr = "http://" + addr
}
@@ -58,11 +64,15 @@ func pdGet(ctx context.Context, addr string, prefix string, cli *http.Client) ([
return nil, errors.Trace(err)
}
url := fmt.Sprintf("%s/%s", u, prefix)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
var (
req *http.Request
resp *http.Response
)
req, err = http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, errors.Trace(err)
}
resp, err := cli.Do(req)
resp, err = cli.Do(req)
if err != nil {
return nil, errors.Trace(err)
}
@@ -86,7 +96,7 @@ func NewMgr(ctx context.Context, pdAddrs string, storage tikv.Storage) (*Mgr, er
failure := errors.Errorf("pd address (%s) has wrong format", pdAddrs)
cli := &http.Client{Timeout: 30 * time.Second}
for _, addr := range addrs {
_, failure = pdGet(ctx, addr, clusterVersionPrefix, cli)
_, failure = pdRequest(ctx, addr, clusterVersionPrefix, cli, http.MethodGet, nil)
// TODO need check cluster version >= 3.1 when br release
if failure == nil {
break
@@ -152,13 +162,13 @@ func (mgr *Mgr) SetPDClient(pdClient pd.Client) {

// GetClusterVersion returns the current cluster version.
func (mgr *Mgr) GetClusterVersion(ctx context.Context) (string, error) {
return mgr.getClusterVersionWith(ctx, pdGet)
return mgr.getClusterVersionWith(ctx, pdRequest)
}

func (mgr *Mgr) getClusterVersionWith(ctx context.Context, get pdHTTPGet) (string, error) {
func (mgr *Mgr) getClusterVersionWith(ctx context.Context, get pdHTTPRequest) (string, error) {
var err error
for _, addr := range mgr.pdHTTP.addrs {
v, e := get(ctx, addr, clusterVersionPrefix, mgr.pdHTTP.cli)
v, e := get(ctx, addr, clusterVersionPrefix, mgr.pdHTTP.cli, http.MethodGet, nil)
if e != nil {
err = e
continue
@@ -171,19 +181,19 @@ func (mgr *Mgr) getClusterVersionWith(ctx context.Context, get pdHTTPGet) (strin

// GetRegionCount returns the region count in the specified range.
func (mgr *Mgr) GetRegionCount(ctx context.Context, startKey, endKey []byte) (int, error) {
return mgr.getRegionCountWith(ctx, pdGet, startKey, endKey)
return mgr.getRegionCountWith(ctx, pdRequest, startKey, endKey)
}

func (mgr *Mgr) getRegionCountWith(
ctx context.Context, get pdHTTPGet, startKey, endKey []byte,
ctx context.Context, get pdHTTPRequest, startKey, endKey []byte,
) (int, error) {
var err error
for _, addr := range mgr.pdHTTP.addrs {
query := fmt.Sprintf(
"%s?start_key=%s&end_key=%s",
regionCountPrefix,
url.QueryEscape(string(startKey)), url.QueryEscape(string(endKey)))
v, e := get(ctx, addr, query, mgr.pdHTTP.cli)
v, e := get(ctx, addr, query, mgr.pdHTTP.cli, http.MethodGet, nil)
if e != nil {
err = e
continue
@@ -266,6 +276,65 @@ func (mgr *Mgr) GetDomain() *domain.Domain {
return mgr.dom
}

// RemoveScheduler remove pd scheduler
func (mgr *Mgr) RemoveScheduler(ctx context.Context, scheduler string) error {
return mgr.removeSchedulerWith(ctx, scheduler, pdRequest)
}

func (mgr *Mgr) removeSchedulerWith(ctx context.Context, scheduler string, delete pdHTTPRequest) (err error) {
for _, addr := range mgr.pdHTTP.addrs {
prefix := fmt.Sprintf("%s/%s", schdulerPrefix, scheduler)
_, err = delete(ctx, addr, prefix, mgr.pdHTTP.cli, http.MethodDelete, nil)
if err != nil {
continue
}
return nil
}
return err
}

// AddScheduler add pd scheduler
func (mgr *Mgr) AddScheduler(ctx context.Context, scheduler string) (string, error) {
return mgr.addSchedulerWith(ctx, scheduler, pdRequest)
}

func (mgr *Mgr) addSchedulerWith(ctx context.Context, scheduler string, post pdHTTPRequest) (string, error) {
var err error
for _, addr := range mgr.pdHTTP.addrs {
body := bytes.NewBuffer([]byte(`{"name":"` + scheduler + `"}`))
v, e := post(ctx, addr, schdulerPrefix, mgr.pdHTTP.cli, http.MethodPost, body)
if e != nil {
err = e
continue
}
return string(v), nil
}
return "", err
}

// ListScheduler list all pd scheduler
func (mgr *Mgr) ListScheduler(ctx context.Context) ([]string, error) {
return mgr.listSchedulerWith(ctx, pdRequest)
}

func (mgr *Mgr) listSchedulerWith(ctx context.Context, get pdHTTPRequest) ([]string, error) {
var err error
for _, addr := range mgr.pdHTTP.addrs {
v, e := get(ctx, addr, schdulerPrefix, mgr.pdHTTP.cli, http.MethodGet, nil)
if e != nil {
err = e
continue
}
d := make([]string, 0)
err = json.Unmarshal(v, &d)
if err != nil {
return nil, err
}
return d, nil
}
return nil, err
}

// Close closes all client in Mgr.
func (mgr *Mgr) Close() {
mgr.grpcClis.mu.Lock()
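
For reference, a minimal standalone sketch (standard library only) of the PD scheduler endpoints that the new Mgr helpers wrap; the endpoint paths and the {"name": ...} POST body come from the diff above, while the PD address is an assumption.

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	pdAddr := "http://127.0.0.1:2379" // assumed local PD endpoint
	cli := &http.Client{}

	// ListScheduler: GET pd/api/v1/schedulers returns a JSON array of names.
	resp, err := cli.Get(pdAddr + "/pd/api/v1/schedulers")
	if err == nil {
		b, _ := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		fmt.Println("schedulers:", string(b))
	}

	// RemoveScheduler: DELETE pd/api/v1/schedulers/{name}.
	req, _ := http.NewRequest(http.MethodDelete,
		pdAddr+"/pd/api/v1/schedulers/balance-leader-scheduler", nil)
	if resp, err := cli.Do(req); err == nil {
		resp.Body.Close()
	}

	// AddScheduler: POST pd/api/v1/schedulers with {"name": "<scheduler>"}.
	body := bytes.NewBufferString(`{"name":"balance-leader-scheduler"}`)
	if resp, err := cli.Post(pdAddr+"/pd/api/v1/schedulers", "application/json", body); err == nil {
		resp.Body.Close()
	}
}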
5 changes: 3 additions & 2 deletions pkg/conn/conn_test.go
@@ -3,6 +3,7 @@ package conn
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"

@@ -34,7 +35,7 @@ func (s *testClientSuite) TearDownSuite(c *C) {

func (s *testClientSuite) TestPDHTTP(c *C) {
ctx := context.Background()
mock := func(context.Context, string, string, *http.Client) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
stats := statistics.RegionStats{Count: 6}
ret, err := json.Marshal(stats)
c.Assert(err, IsNil)
@@ -45,7 +46,7 @@ func (s *testClientSuite) TestPDHTTP(c *C) {
c.Assert(err, IsNil)
c.Assert(resp, Equals, 6)

mock = func(context.Context, string, string, *http.Client) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
return []byte(`test`), nil
}
respString, err := s.mgr.getClusterVersionWith(ctx, mock)
