From c678b0e3e6319d1e260969a60eea18bfbd04646d Mon Sep 17 00:00:00 2001 From: Andrew Mason Date: Wed, 3 Mar 2021 22:47:08 -0500 Subject: [PATCH] Add testutil helper to manage the complexity of recursively calling WithTestServer Signed-off-by: Andrew Mason --- go/vt/vtadmin/api_test.go | 113 +++++++++--------- go/vt/vtctl/grpcvtctldserver/testutil/util.go | 37 ++++++ 2 files changed, 94 insertions(+), 56 deletions(-) diff --git a/go/vt/vtadmin/api_test.go b/go/vt/vtadmin/api_test.go index f92554f6777..ef1d935dc8e 100644 --- a/go/vt/vtadmin/api_test.go +++ b/go/vt/vtadmin/api_test.go @@ -367,7 +367,12 @@ func TestGetKeyspaces(t *testing.T) { }, } + ctx := context.Background() + for _, tt := range tests { + // Note that these test cases were written prior to the existence of + // WithTestServers, so they are all written with the assumption that + // there are exactly 2 clusters. topos := []*topo.Server{ memorytopo.NewServer("c0_cell1"), memorytopo.NewServer("c1_cell1"), @@ -375,32 +380,31 @@ func TestGetKeyspaces(t *testing.T) { for cdx, cks := range tt.clusterKeyspaces { for _, ks := range cks { - testutil.AddKeyspace(context.Background(), t, topos[cdx], ks) + testutil.AddKeyspace(ctx, t, topos[cdx], ks) } } for cdx, css := range tt.clusterShards { - testutil.AddShards(context.Background(), t, topos[cdx], css...) + testutil.AddShards(ctx, t, topos[cdx], css...) } - // Setting up WithTestServer in a generic, recursive way is... unpleasant, - // so all tests are set-up and run in the context of these two clusters. - testutil.WithTestServer(t, grpcvtctldserver.NewVtctldServer(topos[0]), func(t *testing.T, cluster0Client vtctldclient.VtctldClient) { - testutil.WithTestServer(t, grpcvtctldserver.NewVtctldServer(topos[1]), func(t *testing.T, cluster1Client vtctldclient.VtctldClient) { - clusterClients := []vtctldclient.VtctldClient{cluster0Client, cluster1Client} + servers := []vtctlservicepb.VtctldServer{ + grpcvtctldserver.NewVtctldServer(topos[0]), + grpcvtctldserver.NewVtctldServer(topos[1]), + } - clusters := []*cluster.Cluster{ - vtadmintestutil.BuildCluster(0, clusterClients[0], nil, nil), - vtadmintestutil.BuildCluster(1, clusterClients[1], nil, nil), - } + testutil.WithTestServers(t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { + clusters := []*cluster.Cluster{ + vtadmintestutil.BuildCluster(0, clients[0], nil, nil), + vtadmintestutil.BuildCluster(1, clients[1], nil, nil), + } - api := NewAPI(clusters, grpcserver.Options{}, http.Options{}) - resp, err := api.GetKeyspaces(context.Background(), tt.req) - require.NoError(t, err) + api := NewAPI(clusters, grpcserver.Options{}, http.Options{}) + resp, err := api.GetKeyspaces(ctx, tt.req) + require.NoError(t, err) - vtadmintestutil.AssertKeyspaceSlicesEqual(t, tt.expected.Keyspaces, resp.Keyspaces) - }) - }) + vtadmintestutil.AssertKeyspaceSlicesEqual(t, tt.expected.Keyspaces, resp.Keyspaces) + }, servers...) } } @@ -831,7 +835,12 @@ func TestGetSchemas(t *testing.T) { }, } + ctx := context.Background() + for _, tt := range tests { + // Note that these test cases were written prior to the existence of + // WithTestServers, so they are all written with the assumption that + // there are exactly 2 clusters. tt := tt t.Run(tt.name, func(t *testing.T) { @@ -858,51 +867,43 @@ func TestGetSchemas(t *testing.T) { }), } - // Setting up WithTestServer in a generic, recursive way is... unpleasant, - // so all tests are set-up and run in the context of these two clusters. - testutil.WithTestServer(t, vtctlds[0], func(t *testing.T, cluster0Client vtctldclient.VtctldClient) { - testutil.WithTestServer(t, vtctlds[1], func(t *testing.T, cluster1Client vtctldclient.VtctldClient) { - // Put 'em in a slice so we can look them up by index - clusterClients := []vtctldclient.VtctldClient{cluster0Client, cluster1Client} - - // Build the clusters - clusters := make([]*cluster.Cluster, len(topos)) - for cdx, toposerver := range topos { - // Handle when a test doesn't define any tablets for a given cluster. - var cts []*vtadminpb.Tablet - if cdx < len(tt.clusterTablets) { - cts = tt.clusterTablets[cdx] - } + testutil.WithTestServers(t, func(t *testing.T, clients ...vtctldclient.VtctldClient) { + clusters := make([]*cluster.Cluster, len(topos)) + for cdx, toposerver := range topos { + // Handle when a test doesn't define any tablets for a given cluster. + var cts []*vtadminpb.Tablet + if cdx < len(tt.clusterTablets) { + cts = tt.clusterTablets[cdx] + } - for _, tablet := range cts { - // AddTablet also adds the keyspace + shard for us. - testutil.AddTablet(context.Background(), t, toposerver, tablet.Tablet, nil) - - // Adds each SchemaDefinition to the fake TabletManagerClient, or nil - // if there are no schemas for that tablet. (All tablet aliases must - // exist in the map. Otherwise, TabletManagerClient will return an error when - // looking up the schema with tablet alias that doesn't exist.) - alias := topoproto.TabletAliasString(tablet.Tablet.Alias) - tmc.GetSchemaResults[alias] = struct { - Schema *tabletmanagerdatapb.SchemaDefinition - Error error - }{ - Schema: tt.tabletSchemas[alias], - Error: nil, - } + for _, tablet := range cts { + // AddTablet also adds the keyspace + shard for us. + testutil.AddTablet(ctx, t, toposerver, tablet.Tablet, nil) + + // Adds each SchemaDefinition to the fake TabletManagerClient, or nil + // if there are no schemas for that tablet. (All tablet aliases must + // exist in the map. Otherwise, TabletManagerClient will return an error when + // looking up the schema with tablet alias that doesn't exist.) + alias := topoproto.TabletAliasString(tablet.Tablet.Alias) + tmc.GetSchemaResults[alias] = struct { + Schema *tabletmanagerdatapb.SchemaDefinition + Error error + }{ + Schema: tt.tabletSchemas[alias], + Error: nil, } - - clusters[cdx] = vtadmintestutil.BuildCluster(cdx, clusterClients[cdx], cts, nil) } - api := NewAPI(clusters, grpcserver.Options{}, http.Options{}) + clusters[cdx] = vtadmintestutil.BuildCluster(cdx, clients[cdx], cts, nil) + } + + api := NewAPI(clusters, grpcserver.Options{}, http.Options{}) - resp, err := api.GetSchemas(context.Background(), tt.req) - require.NoError(t, err) + resp, err := api.GetSchemas(ctx, tt.req) + require.NoError(t, err) - vtadmintestutil.AssertSchemaSlicesEqual(t, tt.expected.Schemas, resp.Schemas, tt.name) - }) - }) + vtadmintestutil.AssertSchemaSlicesEqual(t, tt.expected.Schemas, resp.Schemas, tt.name) + }, vtctlds...) }) } } diff --git a/go/vt/vtctl/grpcvtctldserver/testutil/util.go b/go/vt/vtctl/grpcvtctldserver/testutil/util.go index ba1b2d09f67..795a7547c43 100644 --- a/go/vt/vtctl/grpcvtctldserver/testutil/util.go +++ b/go/vt/vtctl/grpcvtctldserver/testutil/util.go @@ -62,6 +62,43 @@ func WithTestServer( test(t, client) } +// WithTestServers creates N gRPC servers listening locally with the given RPC +// implementations, and then runs the test func with N clients created, where +// clients[i] points at servers[i]. +func WithTestServers( + t *testing.T, + test func(t *testing.T, clients ...vtctldclient.VtctldClient), + servers ...vtctlservicepb.VtctldServer, +) { + // Declare our recursive helper function so it can refer to itself. + var withTestServers func(t *testing.T, servers ...vtctlservicepb.VtctldServer) + + // Preallocate a slice of clients we're eventually going to call the test + // function with. + clients := make([]vtctldclient.VtctldClient, 0, len(servers)) + + withTestServers = func(t *testing.T, servers ...vtctlservicepb.VtctldServer) { + if len(servers) == 0 { + // We've started up all the test servers and accumulated clients for + // each of them (or there were no test servers to start, and we've + // accumulated no clients), so finally we run the test and stop + // recursing. + test(t, clients...) + + return + } + + // Start up a test server for the head of our server slice, accumulate + // the resulting client, and recurse on the tail of our server slice. + WithTestServer(t, servers[0], func(t *testing.T, client vtctldclient.VtctldClient) { + clients = append(clients, client) + withTestServers(t, servers[1:]...) + }) + } + + withTestServers(t, servers...) +} + // AddKeyspace adds a keyspace to a topology, failing a test if that keyspace // could not be added. It shallow copies the proto struct to prevent XXX_ fields // from changing in the marshalling.