Skip to content

Commit

Permalink
Unit tests for move tables and copy progress
Browse files Browse the repository at this point in the history
Signed-off-by: Rohit Nayak <[email protected]>
  • Loading branch information
rohit-nayak-ps committed Dec 26, 2020
1 parent 6d8f21c commit 3fce881
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 42 deletions.
10 changes: 10 additions & 0 deletions go/vt/wrangler/fake_dbclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package wrangler
import (
"fmt"
"regexp"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/sqltypes"
)
Expand All @@ -46,6 +48,7 @@ type dbResult struct {

func (dbrs *dbResults) next(query string) (*sqltypes.Result, error) {
if dbrs.exhausted() {
log.Infof(fmt.Sprintf("Unexpected query >%s<", query))
return nil, fmt.Errorf("code executed this query, but the test did not expect it: %s", query)
}
i := dbrs.index
Expand Down Expand Up @@ -143,6 +146,13 @@ func (dc *fakeDBClient) ExecuteFetch(query string, maxrows int) (qr *sqltypes.Re
if result := dc.invariants[query]; result != nil {
return result, nil
}
for q, result := range dc.invariants { //supports allowing just a prefix of an expected query
if strings.Contains(query, q) {
return result, nil
}
}

log.Infof("Missing query: >%s<" + query)
return nil, fmt.Errorf("unexpected query: %s", query)
}

Expand Down
1 change: 0 additions & 1 deletion go/vt/wrangler/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ func (wr *Wrangler) getCellsWithTableReadsSwitched(ctx context.Context, targetKe
return nil, nil, err
}
rules := srvVSchema.RoutingRules.Rules
log.Infof("Rules for srvVSchema for cell %s are %+v", cell, rules)
found := false
switched := false
for _, rule := range rules {
Expand Down
3 changes: 3 additions & 0 deletions go/vt/wrangler/traffic_switcher_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing"
"time"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/mysql/fakesqldb"

"golang.org/x/net/context"
Expand Down Expand Up @@ -361,6 +363,7 @@ func (tme *testMigraterEnv) createDBClients(ctx context.Context, t *testing.T) {
master.TM.VREngine.Open(ctx)
}
for _, master := range tme.targetMasters {
log.Infof("Adding as targetMaster %s", master.Tablet.Alias)
dbclient := newFakeDBClient()
tme.dbTargetClients = append(tme.dbTargetClients, dbclient)
dbClientFactory := func() binlogplayer.DBClient { return dbclient }
Expand Down
51 changes: 25 additions & 26 deletions go/vt/wrangler/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package wrangler
import (
"context"
"fmt"
"sort"
"strings"
"time"

Expand All @@ -14,17 +15,6 @@ import (
"vitess.io/vitess/go/vt/log"
)

/*
TODO
* expand e2e for testing all possible transitions
(Switch/Reverse Replica/Rdonly)
* Unit Tests (run coverage first and identify)
(CurrentState())
* dry run
*/

// VReplicationWorkflowType specifies whether workflow is MoveTables or Reshard
type VReplicationWorkflowType int

Expand Down Expand Up @@ -168,8 +158,11 @@ func (vrw *VReplicationWorkflow) stateAsString(ws *workflowState) string {
// Start initiates a workflow
func (vrw *VReplicationWorkflow) Start() error {
var err error
if vrw.Exists() {
return fmt.Errorf("workflow has already been started")
if !vrw.Exists() {
return fmt.Errorf("workflow now found")
}
if vrw.CachedState() != WorkflowStateNotStarted {
return fmt.Errorf("workflow has already been started, state is %s", vrw.CachedState())
}
switch vrw.workflowType {
case MoveTablesWorkflow:
Expand Down Expand Up @@ -375,29 +368,34 @@ func (vrw *VReplicationWorkflow) GetCopyProgress() (*CopyProgress, error) {
}
qr := sqltypes.Proto3ToResult(p3qr)
for i := 0; i < len(p3qr.Rows); i++ {
tables[qr.Rows[0][0].ToString()] = true
tables[qr.Rows[i][0].ToString()] = true
}
sourcesi, err := vrw.wr.ts.GetShard(ctx, bls.Keyspace, bls.Shard)
if err != nil {
return nil, err
}
sourceMasters[sourcesi.MasterAlias] = true
found := false
for existingSource := range sourceMasters {
if existingSource.Uid == sourcesi.MasterAlias.Uid {
found = true
}
}
if !found {
sourceMasters[sourcesi.MasterAlias] = true
}
}
}
if len(tables) == 0 {
return nil, nil
}
tableList := ""
var tableList []string
targetRowCounts := make(map[string]int64)
sourceRowCounts := make(map[string]int64)
targetTableSizes := make(map[string]int64)
sourceTableSizes := make(map[string]int64)

for table := range tables {
if tableList != "" {
tableList += ","
}
tableList += encodeString(table)
tableList = append(tableList, encodeString(table))
targetRowCounts[table] = 0
sourceRowCounts[table] = 0
targetTableSizes[table] = 0
Expand All @@ -411,12 +409,12 @@ func (vrw *VReplicationWorkflow) GetCopyProgress() (*CopyProgress, error) {
}
qr := sqltypes.Proto3ToResult(p3qr)
for i := 0; i < len(qr.Rows); i++ {
table := qr.Rows[0][0].ToString()
rowCount, err := evalengine.ToInt64(qr.Rows[0][1])
table := qr.Rows[i][0].ToString()
rowCount, err := evalengine.ToInt64(qr.Rows[i][1])
if err != nil {
return err
}
tableSize, err := evalengine.ToInt64(qr.Rows[0][2])
tableSize, err := evalengine.ToInt64(qr.Rows[i][2])
if err != nil {
return err
}
Expand All @@ -441,16 +439,17 @@ func (vrw *VReplicationWorkflow) GetCopyProgress() (*CopyProgress, error) {
if sourceDbName == "" || targetDbName == "" {
return nil, fmt.Errorf("workflow %s.%s is incorrectly configured", vrw.ws.TargetKeyspace, vrw.ws.Workflow)
}

query := fmt.Sprintf(getRowCountQuery, encodeString(targetDbName), tableList)
sort.Strings(tableList) // sort list for repeatability for mocking in tests
tablesStr := strings.Join(tableList, ",")
query := fmt.Sprintf(getRowCountQuery, encodeString(targetDbName), tablesStr)
for _, target := range vrw.ts.targets {
tablet := target.master.Tablet
if err := getTableMetrics(tablet, query, &targetRowCounts, &targetTableSizes); err != nil {
return nil, err
}
}

query = fmt.Sprintf(getRowCountQuery, encodeString(sourceDbName), tableList)
query = fmt.Sprintf(getRowCountQuery, encodeString(sourceDbName), tablesStr)
for source := range sourceMasters {
ti, err := vrw.wr.ts.GetTablet(ctx, source)
tablet := ti.Tablet
Expand Down
164 changes: 149 additions & 15 deletions go/vt/wrangler/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,32 @@ limitations under the License.

package wrangler

//FIXME: update test for recent changes
/*
import (
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/proto/topodata"
)

func getMoveTablesWorkflow(t *testing.T, cells, tabletTypes string) *VReplicationWorkflow {
mtp := &VReplicationWorkflowParams{
p := &VReplicationWorkflowParams{
Workflow: "wf1",
SourceKeyspace: "sourceks",
TargetKeyspace: "targetks",
Tables: "customer,corder",
Cells: cells,
TabletTypes: tabletTypes,
}
wf, _ := newWorkflow("wf1", "MoveTables")
mtwf := &VReplicationWorkflow{
ctx: context.Background(),
wf: wf,
wr: nil,
params: mtp,
ts: nil,
ws: nil,
workflowType: MoveTablesWorkflow,
ctx: context.Background(),
wr: nil,
params: p,
ts: nil,
ws: nil,
}
return mtwf
}
Expand All @@ -52,9 +51,9 @@ func TestReshardingWorkflowErrorsAndMisc(t *testing.T) {
require.False(t, mtwf.Exists())
mtwf.ws = &workflowState{}
require.True(t, mtwf.Exists())
require.Errorf(t, mtwf.Complete(), errWorkflowNotFullySwitched)
require.Errorf(t, mtwf.Complete(), ErrWorkflowNotFullySwitched)
mtwf.ws.WritesSwitched = true
require.Errorf(t, mtwf.Abort(), errWorkflowPartiallySwitched)
require.Errorf(t, mtwf.Abort(), ErrWorkflowPartiallySwitched)

require.ElementsMatch(t, mtwf.getCellsAsArray(), []string{"cell1", "cell2"})
require.ElementsMatch(t, mtwf.getTabletTypes(), []topodata.TabletType{topodata.TabletType_REPLICA, topodata.TabletType_RDONLY})
Expand All @@ -65,7 +64,8 @@ func TestReshardingWorkflowErrorsAndMisc(t *testing.T) {
require.False(t, hasMaster)

mtwf.params.TabletTypes = "replica,rdonly,master"
require.ElementsMatch(t, mtwf.getTabletTypes(), []topodata.TabletType{topodata.TabletType_REPLICA, topodata.TabletType_RDONLY, topodata.TabletType_MASTER})
require.ElementsMatch(t, mtwf.getTabletTypes(),
[]topodata.TabletType{topodata.TabletType_REPLICA, topodata.TabletType_RDONLY, topodata.TabletType_MASTER})

hasReplica, hasRdonly, hasMaster, err = mtwf.parseTabletTypes()
require.NoError(t, err)
Expand All @@ -74,6 +74,140 @@ func TestReshardingWorkflowErrorsAndMisc(t *testing.T) {
require.True(t, hasMaster)
}

func TestReshardingWorkflowCurrentState(t *testing.T) {
func TestCopyProgress(t *testing.T) {
var err error
var wf *VReplicationWorkflow
ctx := context.Background()
p := &VReplicationWorkflowParams{
Workflow: "test",
SourceKeyspace: "ks1",
TargetKeyspace: "ks2",
Tables: "t1,t2",
Cells: "cell1,cell2",
TabletTypes: "replica,rdonly,master",
Timeout: DefaultActionTimeout,
}
tme := newTestTableMigrater(ctx, t)
defer tme.stopTablets(t)
wf, err = tme.wr.NewVReplicationWorkflow(ctx, MoveTablesWorkflow, p)
require.NoError(t, err)
require.NotNil(t, wf)
require.Equal(t, WorkflowStateNotSwitched, wf.CurrentState())

expectCopyProgressQueries(t, tme)

cp, err2 := wf.GetCopyProgress()
require.NoError(t, err2)
log.Infof("CopyProgress is %+v,%+v", (*cp)["t1"], (*cp)["t2"])

require.Equal(t, int64(800), (*cp)["t1"].SourceRowCount)
require.Equal(t, int64(200), (*cp)["t1"].TargetRowCount)
require.Equal(t, int64(4000), (*cp)["t1"].SourceTableSize)
require.Equal(t, int64(2000), (*cp)["t1"].TargetTableSize)

require.Equal(t, int64(2000), (*cp)["t2"].SourceRowCount)
require.Equal(t, int64(400), (*cp)["t2"].TargetRowCount)
require.Equal(t, int64(4000), (*cp)["t2"].SourceTableSize)
require.Equal(t, int64(1000), (*cp)["t2"].TargetTableSize)
}

func expectCopyProgressQueries(t *testing.T, tme *testMigraterEnv) {
db := tme.tmeDB
query := "select table_name from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 1"
rows := []string{"t1", "t2"}
result := sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table_name",
"varchar"),
rows...)
db.AddQuery(query, result)
query = "select table_name from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 2"
db.AddQuery(query, result)

query = "select table_name, table_rows, data_length from information_schema.tables where table_schema = 'vt_ks2' and table_name in ('t1','t2')"
result = sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table_name|table_rows|data_length",
"varchar|int64|int64"),
"t1|100|1000",
"t2|200|500")
db.AddQuery(query, result)

query = "select table_name, table_rows, data_length from information_schema.tables where table_schema = 'vt_ks1' and table_name in ('t1','t2')"
result = sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"table_name|table_rows|data_length",
"varchar|int64|int64"),
"t1|400|2000",
"t2|1000|2000")
db.AddQuery(query, result)

}
func TestMoveTablesV2(t *testing.T) {
ctx := context.Background()
p := &VReplicationWorkflowParams{
Workflow: "test",
SourceKeyspace: "ks1",
TargetKeyspace: "ks2",
Tables: "t1,t2",
Cells: "cell1,cell2",
TabletTypes: "replica,rdonly,master",
Timeout: DefaultActionTimeout,
}
tme := newTestTableMigrater(ctx, t)
defer tme.stopTablets(t)
wf, err := tme.wr.NewVReplicationWorkflow(ctx, MoveTablesWorkflow, p)
require.NoError(t, err)
require.NotNil(t, wf)
require.Equal(t, WorkflowStateNotSwitched, wf.CurrentState())
tme.expectNoPreviousJournals()
expectMoveTablesQueries(t, tme)
tme.expectNoPreviousJournals()
require.NoError(t, wf.SwitchTraffic(DirectionForward))
require.Equal(t, WorkflowStateAllSwitched, wf.CurrentState())
require.NoError(t, wf.Complete())
}

func expectMoveTablesQueries(t *testing.T, tme *testMigraterEnv) {
var query string
//var result *sqltypes.Result
noResult := &sqltypes.Result{}
for _, dbclient := range tme.dbTargetClients {
query = "update _vt.vreplication set state = 'Running', message = '' where id in (1)"
dbclient.addInvariant(query, noResult)
dbclient.addInvariant("select id from _vt.vreplication where db_name = 'vt_ks2' and workflow = 'test'", resultid1)
dbclient.addInvariant("select * from _vt.vreplication where id = 1", runningResult(1))
dbclient.addInvariant("select * from _vt.vreplication where id = 2", runningResult(2))
query = "update _vt.vreplication set message='Picked source tablet: cell:\"cell1\" uid:10 ' where id=1"
dbclient.addInvariant(query, noResult)
dbclient.addInvariant("select id from _vt.vreplication where id = 1", resultid1)
dbclient.addInvariant("select id from _vt.vreplication where id = 2", resultid2)
dbclient.addInvariant("update _vt.vreplication set state = 'Stopped', message = 'stopped for cutover' where id in (1)", noResult)
dbclient.addInvariant("update _vt.vreplication set state = 'Stopped', message = 'stopped for cutover' where id in (2)", noResult)
dbclient.addInvariant("insert into _vt.vreplication (workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name)", &sqltypes.Result{InsertID: uint64(1)})
dbclient.addInvariant("update _vt.vreplication set message = 'FROZEN'", noResult)
dbclient.addInvariant("select 1 from _vt.vreplication where db_name='vt_ks2' and workflow='test' and message!='FROZEN'", noResult)
dbclient.addInvariant("delete from _vt.vreplication where id in (1)", noResult)
dbclient.addInvariant("delete from _vt.copy_state where vrepl_id in (1)", noResult)

//
}

for _, dbclient := range tme.dbSourceClients {
dbclient.addInvariant("select id from _vt.vreplication where db_name = 'vt_ks1' and workflow = 'test_reverse'", resultid1)
dbclient.addInvariant("delete from _vt.vreplication where id in (1)", noResult)
dbclient.addInvariant("delete from _vt.copy_state where vrepl_id in (1)", noResult)
dbclient.addInvariant("insert into _vt.vreplication (workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name)", &sqltypes.Result{InsertID: uint64(1)})
dbclient.addInvariant("select * from _vt.vreplication where id = 1", runningResult(1))
dbclient.addInvariant("select * from _vt.vreplication where id = 2", runningResult(2))
dbclient.addInvariant("insert into _vt.resharding_journal", noResult)
}
state := sqltypes.MakeTestResult(sqltypes.MakeTestFields(
"pos|state|message",
"varchar|varchar|varchar"),
"MariaDB/5-456-892|Running",
)
tme.dbTargetClients[0].addQuery("select pos, state, message from _vt.vreplication where id=1", state, nil)
tme.dbTargetClients[0].addQuery("select pos, state, message from _vt.vreplication where id=2", state, nil)
tme.dbTargetClients[1].addQuery("select pos, state, message from _vt.vreplication where id=1", state, nil)
tme.dbTargetClients[1].addQuery("select pos, state, message from _vt.vreplication where id=2", state, nil)
tme.tmeDB.AddQueryPattern("drop table vt_ks1.t1", &sqltypes.Result{})
tme.tmeDB.AddQueryPattern("drop table vt_ks1.t2", &sqltypes.Result{})
}
*/

0 comments on commit 3fce881

Please sign in to comment.