diff --git a/go/test/endtoend/vreplication/vstream_test.go b/go/test/endtoend/vreplication/vstream_test.go index d011b23dbec..70045e3c738 100644 --- a/go/test/endtoend/vreplication/vstream_test.go +++ b/go/test/endtoend/vreplication/vstream_test.go @@ -177,6 +177,7 @@ func testVStreamWithFailover(t *testing.T, failover bool) { const schemaUnsharded = ` create table customer_seq(id int, next_id bigint, cache bigint, primary key(id)) comment 'vitess_sequence'; +insert into customer_seq(id, next_id, cache) values(0, 1, 3); ` const vschemaUnsharded = ` { @@ -218,14 +219,18 @@ const vschemaSharded = ` func insertRow(keyspace, table string, id int) { vtgateConn.ExecuteFetch(fmt.Sprintf("use %s;", keyspace), 1000, false) vtgateConn.ExecuteFetch("begin", 1000, false) - vtgateConn.ExecuteFetch(fmt.Sprintf("insert into %s (cid, name) values (%d, '%s%d')", table, id+100, table, id), 1000, false) + _, err := vtgateConn.ExecuteFetch(fmt.Sprintf("insert into %s (name) values ('%s%d')", table, table, id), 1000, false) + if err != nil { + log.Infof("error inserting row %d: %v", id, err) + } vtgateConn.ExecuteFetch("commit", 1000, false) } type numEvents struct { - numRowEvents, numJournalEvents int64 - numLessThan80Events, numGreaterThan80Events int64 - numLessThan40Events, numGreaterThan40Events int64 + numRowEvents, numJournalEvents int64 + numLessThan80Events, numGreaterThan80Events int64 + numLessThan40Events, numGreaterThan40Events int64 + numShard0BeforeReshardEvents, numShard0AfterReshardEvents int64 } // tests the StopOnReshard flag @@ -376,6 +381,150 @@ func testVStreamStopOnReshardFlag(t *testing.T, stopOnReshard bool, baseTabletID return &ne } +// Validate that we can continue streaming from multiple keyspaces after first copying some tables and then resharding one of the keyspaces +// Ensure that there are no missing row events during the resharding process. +func testVStreamCopyMultiKeyspaceReshard(t *testing.T, baseTabletID int) numEvents { + defaultCellName := "zone1" + allCellNames = defaultCellName + allCells := []string{allCellNames} + vc = NewVitessCluster(t, "VStreamCopyMultiKeyspaceReshard", allCells, mainClusterConfig) + + require.NotNil(t, vc) + ogdr := defaultReplicas + defaultReplicas = 0 // because of CI resource constraints we can only run this test with primary tablets + defer func(dr int) { defaultReplicas = dr }(ogdr) + + defer vc.TearDown(t) + + defaultCell = vc.Cells[defaultCellName] + vc.AddKeyspace(t, []*Cell{defaultCell}, "unsharded", "0", vschemaUnsharded, schemaUnsharded, defaultReplicas, defaultRdonly, baseTabletID+100, nil) + vtgate = defaultCell.Vtgates[0] + require.NotNil(t, vtgate) + vtgate.WaitForStatusOfTabletInShard(fmt.Sprintf("%s.%s.primary", "unsharded", "0"), 1) + + vtgateConn = getConnection(t, vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateMySQLPort) + defer vtgateConn.Close() + verifyClusterHealth(t, vc) + + vc.AddKeyspace(t, []*Cell{defaultCell}, "sharded", "-80,80-", vschemaSharded, schemaSharded, defaultReplicas, defaultRdonly, baseTabletID+200, nil) + + ctx := context.Background() + vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort)) + if err != nil { + log.Fatal(err) + } + defer vstreamConn.Close() + vgtid := &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{{ + Keyspace: "/.*", + }}} + + filter := &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{{ + // We want to confirm that the following two tables are streamed. + // 1. the customer_seq in the unsharded keyspace + // 2. the customer table in the sharded keyspace + Match: "/customer.*/", + }}, + } + flags := &vtgatepb.VStreamFlags{} + done := false + + id := 1000 + // First goroutine that keeps inserting rows into the table being streamed until a minute after reshard + // We should keep getting events on the new shards + go func() { + for { + if done { + return + } + id++ + time.Sleep(1 * time.Second) + insertRow("sharded", "customer", id) + } + }() + // stream events from the VStream API + var ne numEvents + reshardDone := false + go func() { + var reader vtgateconn.VStreamReader + reader, err = vstreamConn.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, filter, flags) + require.NoError(t, err) + for { + evs, err := reader.Recv() + + switch err { + case nil: + for _, ev := range evs { + switch ev.Type { + case binlogdatapb.VEventType_ROW: + shard := ev.RowEvent.Shard + switch shard { + case "0": + if reshardDone { + ne.numShard0AfterReshardEvents++ + } else { + ne.numShard0BeforeReshardEvents++ + } + case "-80": + ne.numLessThan80Events++ + case "80-": + ne.numGreaterThan80Events++ + case "-40": + ne.numLessThan40Events++ + case "40-": + ne.numGreaterThan40Events++ + } + ne.numRowEvents++ + case binlogdatapb.VEventType_JOURNAL: + ne.numJournalEvents++ + } + } + case io.EOF: + log.Infof("Stream Ended") + done = true + default: + log.Errorf("Returned err %v", err) + done = true + } + if done { + return + } + } + }() + + ticker := time.NewTicker(1 * time.Second) + tickCount := 0 + for { + <-ticker.C + tickCount++ + switch tickCount { + case 1: + reshard(t, "sharded", "customer", "vstreamCopyMultiKeyspaceReshard", "-80,80-", "-40,40-", baseTabletID+400, nil, nil, nil, defaultCellName, 1) + reshardDone = true + case 60: + done = true + } + if done { + break + } + } + log.Infof("ne=%v", ne) + + // The number of row events streamed by the VStream API should match the number of rows inserted. + // This is important for sharded tables, where we need to ensure that no row events are missed during the resharding process. + // + // On the other hand, we don't verify the exact number of row events for the unsharded keyspace + // because the keyspace remains unsharded and the number of rows in the customer_seq table is always 1. + // We believe that checking the number of row events for the unsharded keyspace, which should always be greater than 0 before and after resharding, + // is sufficient to confirm that the resharding of one keyspace does not affect another keyspace, while keeping the test straightforward. + customerResult := execVtgateQuery(t, vtgateConn, "sharded", "select count(*) from customer") + insertedCustomerRows, err := evalengine.ToInt64(customerResult.Rows[0][0]) + require.NoError(t, err) + require.Equal(t, insertedCustomerRows, ne.numLessThan80Events+ne.numGreaterThan80Events+ne.numLessThan40Events+ne.numGreaterThan40Events) + return ne +} + func TestVStreamFailover(t *testing.T) { testVStreamWithFailover(t, true) } @@ -407,3 +556,15 @@ func TestVStreamWithKeyspacesToWatch(t *testing.T) { testVStreamWithFailover(t, false) } + +func TestVStreamCopyMultiKeyspaceReshard(t *testing.T) { + ne := testVStreamCopyMultiKeyspaceReshard(t, 3000) + require.Equal(t, int64(0), ne.numJournalEvents) + require.NotZero(t, ne.numRowEvents) + require.NotZero(t, ne.numShard0BeforeReshardEvents) + require.NotZero(t, ne.numShard0AfterReshardEvents) + require.NotZero(t, ne.numLessThan80Events) + require.NotZero(t, ne.numGreaterThan80Events) + require.NotZero(t, ne.numLessThan40Events) + require.NotZero(t, ne.numGreaterThan40Events) +} diff --git a/go/vt/vtgate/endtoend/main_test.go b/go/vt/vtgate/endtoend/main_test.go index 48872965cb9..1bf13ceadb5 100644 --- a/go/vt/vtgate/endtoend/main_test.go +++ b/go/vt/vtgate/endtoend/main_test.go @@ -50,6 +50,12 @@ create table t1_copy_basic( primary key(id1) ) Engine=InnoDB; +create table t1_copy_all( + id1 bigint, + id2 bigint, + primary key(id1) +) Engine=InnoDB; + create table t1_copy_resume( id1 bigint, id2 bigint, @@ -150,6 +156,12 @@ create table t1_sharded( Name: "hash", }}, }, + "t1_copy_all": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Column: "id1", + Name: "hash", + }}, + }, "t1_copy_resume": { ColumnVindexes: []*vschemapb.ColumnVindex{{ Column: "id1", @@ -217,6 +229,31 @@ create table t1_sharded( }, }, } + + schema2 = ` +create table t1_copy_all_ks2( + id1 bigint, + id2 bigint, + primary key(id1) +) Engine=InnoDB; +` + + vschema2 = &vschemapb.Keyspace{ + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{ + "hash": { + Type: "hash", + }, + }, + Tables: map[string]*vschemapb.Table{ + "t1_copy_all_ks2": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Column: "id1", + Name: "hash", + }}, + }, + }, + } ) func TestMain(m *testing.M) { @@ -225,14 +262,24 @@ func TestMain(m *testing.M) { exitCode := func() int { var cfg vttest.Config cfg.Topology = &vttestpb.VTTestTopology{ - Keyspaces: []*vttestpb.Keyspace{{ - Name: "ks", - Shards: []*vttestpb.Shard{{ - Name: "-80", - }, { - Name: "80-", - }}, - }}, + Keyspaces: []*vttestpb.Keyspace{ + { + Name: "ks", + Shards: []*vttestpb.Shard{{ + Name: "-80", + }, { + Name: "80-", + }}, + }, + { + Name: "ks2", + Shards: []*vttestpb.Shard{{ + Name: "-80", + }, { + Name: "80-", + }}, + }, + }, } if err := cfg.InitSchemas("ks", schema, vschema); err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) @@ -240,6 +287,11 @@ func TestMain(m *testing.M) { return 1 } defer os.RemoveAll(cfg.SchemaDir) + if err := cfg.InitSchemas("ks2", schema2, vschema2); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.RemoveAll(cfg.SchemaDir) + return 1 + } cluster = &vttest.LocalCluster{ Config: cfg, diff --git a/go/vt/vtgate/endtoend/misc_test.go b/go/vt/vtgate/endtoend/misc_test.go index 138b68d0aa3..aeeb1c122db 100644 --- a/go/vt/vtgate/endtoend/misc_test.go +++ b/go/vt/vtgate/endtoend/misc_test.go @@ -19,6 +19,7 @@ package endtoend import ( "context" "fmt" + osExec "os/exec" "testing" "github.com/stretchr/testify/assert" @@ -55,6 +56,16 @@ func TestCreateAndDropDatabase(t *testing.T) { require.NoError(t, err) defer conn.Close() + // cleanup the keyspace from the topology. + defer func() { + // the corresponding database needs to be created in advance. + // a subsequent DeleteKeyspace command returns the error of 'node doesn't exist' without it. + _ = exec(t, conn, "create database testitest") + + _, err := osExec.Command("vtctldclient", "--server", grpcAddress, "DeleteKeyspace", "--recursive", "--force", "testitest").CombinedOutput() + require.NoError(t, err) + }() + // run it 3 times. for count := 0; count < 3; count++ { t.Run(fmt.Sprintf("exec:%d", count), func(t *testing.T) { diff --git a/go/vt/vtgate/endtoend/row_count_test.go b/go/vt/vtgate/endtoend/row_count_test.go index 9ac200b33fa..5a29f6177a9 100644 --- a/go/vt/vtgate/endtoend/row_count_test.go +++ b/go/vt/vtgate/endtoend/row_count_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/utils" ) func TestRowCount(t *testing.T) { @@ -31,6 +32,7 @@ func TestRowCount(t *testing.T) { conn, err := mysql.Connect(ctx, &vtParams) require.NoError(t, err) defer conn.Close() + utils.Exec(t, conn, "use ks") type tc struct { query string expected int diff --git a/go/vt/vtgate/endtoend/vstream_test.go b/go/vt/vtgate/endtoend/vstream_test.go index 832799366b1..f2ba9af992b 100644 --- a/go/vt/vtgate/endtoend/vstream_test.go +++ b/go/vt/vtgate/endtoend/vstream_test.go @@ -234,12 +234,7 @@ func TestVStreamCopyBasic(t *testing.T) { printEvents(evs) // for debugging ci failures if len(evs) == numExpectedEvents { - // The arrival order of COPY_COMPLETED events with keyspace/shard is not constant. - // On the other hand, the last event should always be a fully COPY_COMPLETED event. - // That's why the sort.Slice doesn't have to handle the last element in completedEvs. - sort.Slice(completedEvs[:len(completedEvs)-1], func(i, j int) bool { - return completedEvs[i].GetShard() < completedEvs[j].GetShard() - }) + sortCopyCompletedEvents(completedEvs) for i, ev := range completedEvs { require.Regexp(t, expectedCompletedEvents[i], ev.String()) } @@ -258,6 +253,139 @@ func TestVStreamCopyBasic(t *testing.T) { } } +// TestVStreamCopyUnspecifiedShardGtid tests the case where the keyspace contains wildcards and/or the shard is not specified in the request. +// Verify that the Vstream API resolves the unspecified ShardGtid input to a list of all the matching keyspaces and all the shards in the topology. +// - If the keyspace contains wildcards and the shard is not specified, the copy operation should be performed on all shards of all matching keyspaces. +// - If the keyspace is specified and the shard is not specified, the copy operation should be performed on all shards of the specified keyspace. +func TestVStreamCopyUnspecifiedShardGtid(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := mysql.Connect(ctx, &vtParams) + if err != nil { + require.NoError(t, err) + } + defer conn.Close() + + _, err = conn.ExecuteFetch("insert into t1_copy_all(id1,id2) values(1,1), (2,2), (3,3), (4,4), (5,5), (6,6), (7,7), (8,8)", 1, false) + if err != nil { + require.NoError(t, err) + } + + _, err = conn.ExecuteFetch("insert into t1_copy_all_ks2(id1,id2) values(10,10), (20,20)", 1, false) + if err != nil { + require.NoError(t, err) + } + + filter := &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{{ + Match: "/t1_copy_all.*/", + }}, + } + flags := &vtgatepb.VStreamFlags{} + + // We have 2 shards in each keyspace. We assume the rows are + // evenly split across each shard. For each INSERT statement, which + // is a transaction and gets a global transaction identifier or GTID, we + // have 1 each of the following events: + // begin, field, position, lastpk, commit (5) + // For each row created in the INSERT statement -- 8 on ks1 and + // 2 on ks2 -- we have 1 row event between the begin and commit. + // When we have copied all rows for a table in the shard, the shard + // also gets events marking the transition from the copy phase to + // the streaming phase for that table with 1 each of the following: + // begin, vgtid, commit (3) + // As the copy phase completes for all tables on the shard, the shard + // gets 1 copy phase completed event. + // Lastly the stream has 1 final event to mark the final end to all + // copy phase operations in the vstream. + expectedKs1EventNum := 2 /* num shards */ * (9 /* begin/field/vgtid:pos/4 rowevents avg/vgitd: lastpk/commit) */ + 3 /* begin/vgtid/commit for completed table */ + 1 /* copy operation completed */) + expectedKs2EventNum := 2 /* num shards */ * (6 /* begin/field/vgtid:pos/1 rowevents avg/vgitd: lastpk/commit) */ + 3 /* begin/vgtid/commit for completed table */ + 1 /* copy operation completed */) + expectedFullyCopyCompletedNum := 1 + + cases := []struct { + name string + shardGtid *binlogdatapb.ShardGtid + expectedEventNum int + expectedCompletedEvents []string + }{ + { + name: "copy from all keyspaces", + shardGtid: &binlogdatapb.ShardGtid{ + Keyspace: "/.*", + }, + expectedEventNum: expectedKs1EventNum + expectedKs2EventNum + expectedFullyCopyCompletedNum, + expectedCompletedEvents: []string{ + `type:COPY_COMPLETED keyspace:"ks" shard:"-80"`, + `type:COPY_COMPLETED keyspace:"ks" shard:"80-"`, + `type:COPY_COMPLETED keyspace:"ks2" shard:"-80"`, + `type:COPY_COMPLETED keyspace:"ks2" shard:"80-"`, + `type:COPY_COMPLETED`, + }, + }, + { + name: "copy from all shards in one keyspace", + shardGtid: &binlogdatapb.ShardGtid{ + Keyspace: "ks", + }, + expectedEventNum: expectedKs1EventNum + expectedFullyCopyCompletedNum, + expectedCompletedEvents: []string{ + `type:COPY_COMPLETED keyspace:"ks" shard:"-80"`, + `type:COPY_COMPLETED keyspace:"ks" shard:"80-"`, + `type:COPY_COMPLETED`, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + gconn, conn, mconn, closeConnections := initialize(ctx, t) + defer closeConnections() + + var vgtid = &binlogdatapb.VGtid{} + vgtid.ShardGtids = []*binlogdatapb.ShardGtid{c.shardGtid} + reader, err := gconn.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, filter, flags) + _, _ = conn, mconn + if err != nil { + require.NoError(t, err) + } + require.NotNil(t, reader) + var evs []*binlogdatapb.VEvent + var completedEvs []*binlogdatapb.VEvent + for { + e, err := reader.Recv() + switch err { + case nil: + evs = append(evs, e...) + + for _, ev := range e { + if ev.Type == binlogdatapb.VEventType_COPY_COMPLETED { + completedEvs = append(completedEvs, ev) + } + } + + if len(evs) == c.expectedEventNum { + sortCopyCompletedEvents(completedEvs) + for i, ev := range completedEvs { + require.Equal(t, c.expectedCompletedEvents[i], ev.String()) + } + t.Logf("TestVStreamCopyUnspecifiedShardGtid was successful") + return + } else if c.expectedEventNum < len(evs) { + printEvents(evs) // for debugging ci failures + require.FailNow(t, "len(events)=%v are not expected\n", len(evs)) + } + case io.EOF: + log.Infof("stream ended\n") + cancel() + default: + log.Errorf("Returned err %v", err) + require.FailNow(t, "remote error: %v\n", err) + } + } + }) + } +} + func TestVStreamCopyResume(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -563,3 +691,19 @@ func (v VEventSorter) Less(i, j int) bool { } return valI < valJ } + +// The arrival order of COPY_COMPLETED events with keyspace/shard is not constant. +// On the other hand, the last event should always be a fully COPY_COMPLETED event. +// That's why the sort.Slice doesn't have to handle the last element in completedEvs. +func sortCopyCompletedEvents(completedEvs []*binlogdatapb.VEvent) { + sortVEventByKeyspaceAndShard(completedEvs[:len(completedEvs)-1]) +} + +func sortVEventByKeyspaceAndShard(evs []*binlogdatapb.VEvent) { + sort.Slice(evs, func(i, j int) bool { + if evs[i].Keyspace == evs[j].Keyspace { + return evs[i].Shard < evs[j].Shard + } + return evs[i].Keyspace < evs[j].Keyspace + }) +} diff --git a/go/vt/vtgate/vstream_manager.go b/go/vt/vtgate/vstream_manager.go index 6efe0fb5e7a..6c72d8a1126 100644 --- a/go/vt/vtgate/vstream_manager.go +++ b/go/vt/vtgate/vstream_manager.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "regexp" "strings" "sync" "time" @@ -179,31 +180,51 @@ func (vsm *vstreamManager) resolveParams(ctx context.Context, tabletType topodat return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "vgtid must have at least one value with a starting position") } // To fetch from all keyspaces, the input must contain a single ShardGtid - // that has an empty keyspace, and the Gtid must be "current". In the - // future, we'll allow the Gtid to be empty which will also support - // copying of existing data. - if len(vgtid.ShardGtids) == 1 && vgtid.ShardGtids[0].Keyspace == "" { - if vgtid.ShardGtids[0].Gtid != "current" { - return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "for an empty keyspace, the Gtid value must be 'current': %v", vgtid) - } - keyspaces, err := vsm.toposerv.GetSrvKeyspaceNames(ctx, vsm.cell, false) - if err != nil { - return nil, nil, nil, err - } - newvgtid := &binlogdatapb.VGtid{} - for _, keyspace := range keyspaces { - newvgtid.ShardGtids = append(newvgtid.ShardGtids, &binlogdatapb.ShardGtid{ - Keyspace: keyspace, - Gtid: "current", - }) + // that has an empty keyspace, and the Gtid must be "current". + // Or the input must contain a single ShardGtid that has keyspace wildcards. + if len(vgtid.ShardGtids) == 1 { + inputKeyspace := vgtid.ShardGtids[0].Keyspace + isEmpty := inputKeyspace == "" + isRegexp := strings.HasPrefix(inputKeyspace, "/") + if isEmpty || isRegexp { + newvgtid := &binlogdatapb.VGtid{} + keyspaces, err := vsm.toposerv.GetSrvKeyspaceNames(ctx, vsm.cell, false) + if err != nil { + return nil, nil, nil, err + } + + if isEmpty { + if vgtid.ShardGtids[0].Gtid != "current" { + return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "for an empty keyspace, the Gtid value must be 'current': %v", vgtid) + } + for _, keyspace := range keyspaces { + newvgtid.ShardGtids = append(newvgtid.ShardGtids, &binlogdatapb.ShardGtid{ + Keyspace: keyspace, + Gtid: "current", + }) + } + } else { + re, err := regexp.Compile(strings.Trim(inputKeyspace, "/")) + if err != nil { + return nil, nil, nil, err + } + for _, keyspace := range keyspaces { + if re.MatchString(keyspace) { + newvgtid.ShardGtids = append(newvgtid.ShardGtids, &binlogdatapb.ShardGtid{ + Keyspace: keyspace, + Gtid: vgtid.ShardGtids[0].Gtid, + }) + } + } + } + vgtid = newvgtid } - vgtid = newvgtid } newvgtid := &binlogdatapb.VGtid{} for _, sgtid := range vgtid.ShardGtids { if sgtid.Shard == "" { - if sgtid.Gtid != "current" { - return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "if shards are unspecified, the Gtid value must be 'current': %v", vgtid) + if sgtid.Gtid != "current" && sgtid.Gtid != "" { + return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "if shards are unspecified, the Gtid value must be 'current' or empty; got: %v", vgtid) } // TODO(sougou): this should work with the new Migrate workflow _, _, allShards, err := vsm.resolver.GetKeyspaceShards(ctx, sgtid.Keyspace, tabletType) diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 7136539510b..be94432a652 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -889,9 +889,44 @@ func TestResolveVStreamParams(t *testing.T) { input: &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ Keyspace: "TestVStream", + Gtid: "other", + }}, + }, + err: "if shards are unspecified, the Gtid value must be 'current' or empty", + }, { + // Verify that the function maps the input missing the shard to a list of all shards in the topology. + input: &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{{ + Keyspace: "TestVStream", + }}, + }, + output: &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{{ + Keyspace: "TestVStream", + Shard: "-20", + }, { + Keyspace: "TestVStream", + Shard: "20-40", + }, { + Keyspace: "TestVStream", + Shard: "40-60", + }, { + Keyspace: "TestVStream", + Shard: "60-80", + }, { + Keyspace: "TestVStream", + Shard: "80-a0", + }, { + Keyspace: "TestVStream", + Shard: "a0-c0", + }, { + Keyspace: "TestVStream", + Shard: "c0-e0", + }, { + Keyspace: "TestVStream", + Shard: "e0-", }}, }, - err: "if shards are unspecified, the Gtid value must be 'current'", }, { input: &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ @@ -983,17 +1018,49 @@ func TestResolveVStreamParams(t *testing.T) { assert.Equal(t, wantFilter, filter, tcase.input) require.False(t, flags.MinimizeSkew) } - // Special-case: empty keyspace because output is too big. - input := &binlogdatapb.VGtid{ - ShardGtids: []*binlogdatapb.ShardGtid{{ - Gtid: "current", - }}, + + // Special-case: empty keyspace or keyspace containing wildcards because output is too big. + // Verify that the function resolves input for multiple keyspaces into a list of all corresponding shards. + // Ensure that the number of shards returned is greater than the number of shards in a single keyspace named 'TestVStream.' + specialCases := []struct { + input *binlogdatapb.ShardGtid + }{ + { + input: &binlogdatapb.ShardGtid{ + Gtid: "current", + }, + }, + { + input: &binlogdatapb.ShardGtid{ + Keyspace: "/.*", + }, + }, + { + input: &binlogdatapb.ShardGtid{ + Keyspace: "/.*", + Gtid: "current", + }, + }, + { + input: &binlogdatapb.ShardGtid{ + Keyspace: "/Test.*", + }, + }, } - vgtid, _, _, err := vsm.resolveParams(context.Background(), topodatapb.TabletType_REPLICA, input, nil, nil) - require.NoError(t, err, input) - if got, want := len(vgtid.ShardGtids), 8; want >= got { - t.Errorf("len(vgtid.ShardGtids): %v, must be >%d", got, want) + for _, tcase := range specialCases { + input := &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{tcase.input}, + } + vgtid, _, _, err := vsm.resolveParams(context.Background(), topodatapb.TabletType_REPLICA, input, nil, nil) + require.NoError(t, err, tcase.input) + if got, expectTestVStreamShardNumber := len(vgtid.ShardGtids), 8; expectTestVStreamShardNumber >= got { + t.Errorf("len(vgtid.ShardGtids): %v, must be >%d", got, expectTestVStreamShardNumber) + } + for _, s := range vgtid.ShardGtids { + require.Equal(t, tcase.input.Gtid, s.Gtid) + } } + for _, minimizeSkew := range []bool{true, false} { t.Run(fmt.Sprintf("resolveParams MinimizeSkew %t", minimizeSkew), func(t *testing.T) { flags := &vtgatepb.VStreamFlags{MinimizeSkew: minimizeSkew} diff --git a/go/vt/vttest/local_cluster.go b/go/vt/vttest/local_cluster.go index 7dcbdc67afa..0d7bfd4a24f 100644 --- a/go/vt/vttest/local_cluster.go +++ b/go/vt/vttest/local_cluster.go @@ -156,20 +156,20 @@ type Config struct { // It then sets the right value for cfg.SchemaDir. // At the end of the test, the caller should os.RemoveAll(cfg.SchemaDir). func (cfg *Config) InitSchemas(keyspace, schema string, vschema *vschemapb.Keyspace) error { - if cfg.SchemaDir != "" { - return fmt.Errorf("SchemaDir is already set to %v", cfg.SchemaDir) - } - - // Create a base temporary directory. - tempSchemaDir, err := os.MkdirTemp("", "vttest") - if err != nil { - return err + schemaDir := cfg.SchemaDir + if schemaDir == "" { + // Create a base temporary directory. + tempSchemaDir, err := os.MkdirTemp("", "vttest") + if err != nil { + return err + } + schemaDir = tempSchemaDir } // Write the schema if set. if schema != "" { - ksDir := path.Join(tempSchemaDir, keyspace) - err = os.Mkdir(ksDir, os.ModeDir|0775) + ksDir := path.Join(schemaDir, keyspace) + err := os.Mkdir(ksDir, os.ModeDir|0775) if err != nil { return err } @@ -182,7 +182,7 @@ func (cfg *Config) InitSchemas(keyspace, schema string, vschema *vschemapb.Keysp // Write in the vschema if set. if vschema != nil { - vschemaFilePath := path.Join(tempSchemaDir, keyspace, "vschema.json") + vschemaFilePath := path.Join(schemaDir, keyspace, "vschema.json") vschemaJSON, err := json.Marshal(vschema) if err != nil { return err @@ -191,7 +191,7 @@ func (cfg *Config) InitSchemas(keyspace, schema string, vschema *vschemapb.Keysp return err } } - cfg.SchemaDir = tempSchemaDir + cfg.SchemaDir = schemaDir return nil }