diff --git a/go/vt/wrangler/traffic_switcher_env_test.go b/go/vt/wrangler/traffic_switcher_env_test.go index b71dab625a1..2c8531b4e9a 100644 --- a/go/vt/wrangler/traffic_switcher_env_test.go +++ b/go/vt/wrangler/traffic_switcher_env_test.go @@ -242,10 +242,11 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe tme.wr = New(logutil.NewConsoleLogger(), tme.ts, tmclient.NewTabletManagerClient()) tme.sourceShards = sourceShards tme.targetShards = targetShards + tme.tmeDB = fakesqldb.New(t) tabletID := 10 for _, shard := range sourceShards { - tme.sourceMasters = append(tme.sourceMasters, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_MASTER, nil, TabletKeyspaceShard(t, "ks", shard))) + tme.sourceMasters = append(tme.sourceMasters, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_MASTER, tme.tmeDB, TabletKeyspaceShard(t, "ks", shard))) tabletID += 10 _, sourceKeyRange, err := topo.ValidateShardName(shard) @@ -261,7 +262,7 @@ func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targe } for _, shard := range targetShards { - tme.targetMasters = append(tme.targetMasters, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_MASTER, nil, TabletKeyspaceShard(t, "ks", shard))) + tme.targetMasters = append(tme.targetMasters, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_MASTER, tme.tmeDB, TabletKeyspaceShard(t, "ks", shard))) tabletID += 10 _, targetKeyRange, err := topo.ValidateShardName(shard) diff --git a/go/vt/wrangler/workflow_test.go b/go/vt/wrangler/workflow_test.go index 5c507dac41a..c8084cddb5b 100644 --- a/go/vt/wrangler/workflow_test.go +++ b/go/vt/wrangler/workflow_test.go @@ -415,7 +415,6 @@ func expectReshardQueries(t *testing.T, tme *testShardMigraterEnv) { dbclient.addInvariant("select * from _vt.vreplication where id = 1", runningResult(1)) dbclient.addInvariant("select * from _vt.vreplication where id = 2", runningResult(2)) dbclient.addInvariant("insert into _vt.resharding_journal", noResult) - } targetQueries := []string{ @@ -442,8 +441,10 @@ func expectReshardQueries(t *testing.T, tme *testShardMigraterEnv) { dbclient.addInvariant("update _vt.vreplication set message = 'FROZEN'", noResult) dbclient.addInvariant("delete from _vt.vreplication where id in (1)", noResult) dbclient.addInvariant("delete from _vt.copy_state where vrepl_id in (1)", noResult) - } + tme.tmeDB.AddQuery("select 1 from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 1", noResult) + tme.tmeDB.AddQuery("select 1 from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 2", noResult) + } func expectMoveTablesQueries(t *testing.T, tme *testMigraterEnv) { @@ -474,7 +475,6 @@ func expectMoveTablesQueries(t *testing.T, tme *testMigraterEnv) { "int64|varchar|varchar|varchar|varchar"), ""), ) - //select pos, state, message from _vt.vreplication where id=1 } for _, dbclient := range tme.dbSourceClients { @@ -517,4 +517,7 @@ func expectMoveTablesQueries(t *testing.T, tme *testMigraterEnv) { tme.tmeDB.AddQuery("drop table vt_ks2.t1", noResult) tme.tmeDB.AddQuery("drop table vt_ks2.t2", noResult) tme.tmeDB.AddQuery("update _vt.vreplication set message='Picked source tablet: cell:\"cell1\" uid:10 ' where id=1", noResult) + tme.tmeDB.AddQuery("select 1 from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 1", noResult) + tme.tmeDB.AddQuery("select 1 from _vt.copy_state cs, _vt.vreplication vr where vr.id = cs.vrepl_id and vr.id = 2", noResult) + }