From 0afa6a4cf6bc4ca24af10e8ca8bc8f4092f2afaf Mon Sep 17 00:00:00 2001 From: Dmitry Shmulevich Date: Wed, 5 Mar 2025 19:56:55 -0800 Subject: [PATCH] update unit tests for GCP (#85) Signed-off-by: Dmitry Shmulevich --- pkg/providers/gcp/instance_topology.go | 9 +- pkg/providers/gcp/provider.go | 11 +- pkg/providers/gcp/provider_sim.go | 138 ++++++++++++++----------- pkg/providers/gcp/provider_sim_test.go | 97 ++++++++++++----- pkg/providers/providers_sim.go | 4 + 5 files changed, 166 insertions(+), 93 deletions(-) diff --git a/pkg/providers/gcp/instance_topology.go b/pkg/providers/gcp/instance_topology.go index e1a908c..776b3f0 100644 --- a/pkg/providers/gcp/instance_topology.go +++ b/pkg/providers/gcp/instance_topology.go @@ -31,7 +31,7 @@ import ( ) func (p *baseProvider) generateInstanceTopology(ctx context.Context, pageSize *int, cis []topology.ComputeInstances) (*topology.ClusterTopology, error) { - client, err := p.clientFactory() + client, err := p.clientFactory(pageSize) if err != nil { return nil, fmt.Errorf("failed to get client: %v", err) } @@ -43,9 +43,8 @@ func (p *baseProvider) generateInstanceTopology(ctx context.Context, pageSize *i topo := topology.NewClusterTopology() - maxRes := castPageSize(pageSize) for _, ci := range cis { - if err := p.generateRegionInstanceTopology(ctx, client, projectID, maxRes, topo, &ci); err != nil { + if err := p.generateRegionInstanceTopology(ctx, client, projectID, topo, &ci); err != nil { return nil, fmt.Errorf("failed to get instance topology: %v", err) } } @@ -53,13 +52,13 @@ func (p *baseProvider) generateInstanceTopology(ctx context.Context, pageSize *i return topo, nil } -func (p *baseProvider) generateRegionInstanceTopology(ctx context.Context, client Client, projectID string, maxRes *uint32, topo *topology.ClusterTopology, ci *topology.ComputeInstances) error { +func (p *baseProvider) generateRegionInstanceTopology(ctx context.Context, client Client, projectID string, topo *topology.ClusterTopology, ci *topology.ComputeInstances) error { klog.InfoS("Getting instance topology", "region", ci.Region, "project", projectID) req := computepb.ListInstancesRequest{ Project: projectID, Zone: ci.Region, - MaxResults: maxRes, + MaxResults: client.PageSize(), PageToken: nil, } diff --git a/pkg/providers/gcp/provider.go b/pkg/providers/gcp/provider.go index 20ae097..319bda1 100644 --- a/pkg/providers/gcp/provider.go +++ b/pkg/providers/gcp/provider.go @@ -37,7 +37,7 @@ type baseProvider struct { clientFactory ClientFactory } -type ClientFactory func() (Client, error) +type ClientFactory func(pageSize *int) (Client, error) type InstanceIterator interface { Next() (*computepb.Instance, error) @@ -46,10 +46,16 @@ type InstanceIterator interface { type Client interface { ProjectID(ctx context.Context) (string, error) Instances(ctx context.Context, req *computepb.ListInstancesRequest, opts ...gax.CallOption) (InstanceIterator, string) + PageSize() *uint32 } type gcpClient struct { instanceClient *compute_v1.InstancesClient + pageSize *uint32 +} + +func (c *gcpClient) PageSize() *uint32 { + return c.pageSize } func (c *gcpClient) ProjectID(ctx context.Context) (string, error) { @@ -68,7 +74,7 @@ func NamedLoader() (string, providers.Loader) { } func Loader(ctx context.Context, config providers.Config) (providers.Provider, error) { - clientFactory := func() (Client, error) { + clientFactory := func(pageSize *int) (Client, error) { instanceClient, err := compute_v1.NewInstancesRESTClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get instances client: %s", err.Error()) @@ -76,6 +82,7 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, e return &gcpClient{ instanceClient: instanceClient, + pageSize: castPageSize(pageSize), }, nil } diff --git a/pkg/providers/gcp/provider_sim.go b/pkg/providers/gcp/provider_sim.go index 165d176..b968a6e 100644 --- a/pkg/providers/gcp/provider_sim.go +++ b/pkg/providers/gcp/provider_sim.go @@ -22,6 +22,7 @@ import ( "strconv" computepb "cloud.google.com/go/compute/apiv1/computepb" + "github.com/agrea/ptr" gax "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" @@ -32,23 +33,29 @@ import ( const ( NAME_SIM = "gcp-sim" + + errNoce = iota + errClientFactory + errProjectID + errInstances ) type simClient struct { - model *models.Model - pages []*simInstanceIter + model *models.Model + pageSize *uint32 + instanceIDs []string + apiErr int } type simInstanceIter struct { instances []*computepb.Instance indx int - next bool - err bool + err error } func (iter *simInstanceIter) Next() (*computepb.Instance, error) { - if iter.err { - return nil, fmt.Errorf("iterator error") + if iter.err != nil { + return nil, iter.err } if iter.indx >= len(iter.instances) { @@ -60,62 +67,59 @@ func (iter *simInstanceIter) Next() (*computepb.Instance, error) { return ret, nil } -func newSimClient(model *models.Model) (*simClient, error) { - // divide nodes into 2 pages - n := len(model.Nodes) - nodeNames := make([]string, 0, n) - for name := range model.Nodes { - nodeNames = append(nodeNames, name) - } - mid := n / 2 - pages := make([]*simInstanceIter, 2) - - for i, pair := range []struct{ from, to int }{ - {from: 0, to: mid}, - {from: mid + 1, to: n - 1}, - } { - if pair.from > pair.to { - pages[i] = &simInstanceIter{} - } else { - instances := make([]*computepb.Instance, 0, pair.to-pair.from+1) - for j := pair.from; j <= pair.to; j++ { - node := model.Nodes[nodeNames[j]] - physicalHost := fmt.Sprintf("/%s/%s/%s", node.NetLayers[1], node.NetLayers[0], node.Name) - instanceID, err := strconv.ParseUint(node.Name, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid instance ID %q; must be numerical", node.Name) - } - instance := &computepb.Instance{ - Id: &instanceID, - Name: &node.Name, - ResourceStatus: &computepb.ResourceStatus{ - PhysicalHost: &physicalHost, - }, - } - instances = append(instances, instance) - } - pages[i] = &simInstanceIter{instances: instances} - } - } - - pages[0].next = true - - return &simClient{ - model: model, - pages: pages, - }, nil +func (c *simClient) PageSize() *uint32 { + return c.pageSize } func (c *simClient) ProjectID(ctx context.Context) (string, error) { + if c.apiErr == errProjectID { + return "", providers.APIError + } + return "", nil } func (c *simClient) Instances(ctx context.Context, req *computepb.ListInstancesRequest, opts ...gax.CallOption) (InstanceIterator, string) { - if req.PageToken == nil { - return c.pages[0], "next" - } else { - return c.pages[1], "" + if c.apiErr == errInstances { + return &simInstanceIter{err: providers.APIError}, "" + } + + var indx int + from := getPage(req.PageToken) + iter := &simInstanceIter{instances: make([]*computepb.Instance, 0)} + + for indx = from; indx < from+int(*c.pageSize); indx++ { + node := c.model.Nodes[c.instanceIDs[indx]] + physicalHost := fmt.Sprintf("/%s/%s/%s", node.NetLayers[1], node.NetLayers[0], node.Name) + instanceID, err := strconv.ParseUint(node.Name, 10, 64) + if err != nil { + return &simInstanceIter{err: fmt.Errorf("invalid instance ID %q; must be numerical", node.Name)}, "" + } + instance := &computepb.Instance{ + Id: &instanceID, + Name: &node.Name, + ResourceStatus: &computepb.ResourceStatus{ + PhysicalHost: &physicalHost, + }, + } + iter.instances = append(iter.instances, instance) } + + var token string + if indx < len(c.instanceIDs) { + token = fmt.Sprintf("%d", indx) + } + + return iter, token +} + +func getPage(page *string) int { + if page == nil { + return 0 + } + + val, _ := strconv.ParseInt(*page, 10, 32) + return int(val) } func NamedLoaderSim() (string, providers.Loader) { @@ -133,13 +137,27 @@ func LoaderSim(ctx context.Context, cfg providers.Config) (providers.Provider, e return nil, fmt.Errorf("failed to load model file for simulation: %v", err) } - client, err := newSimClient(model) - if err != nil { - return nil, fmt.Errorf("failed to create simulation client: %v", err) + instanceIDs := make([]string, 0, len(model.Nodes)) + for _, node := range model.Nodes { + instanceIDs = append(instanceIDs, node.Name) } - clientFactory := func() (Client, error) { - return client, nil + clientFactory := func(pageSize *int) (Client, error) { + if p.APIError == errClientFactory { + return nil, providers.APIError + } + + limit := castPageSize(pageSize) + if limit == nil { + limit = ptr.Uint32(uint32(len(instanceIDs))) + } + + return &simClient{ + model: model, + pageSize: limit, + instanceIDs: instanceIDs, + apiErr: p.APIError, + }, nil } return NewSim(clientFactory), nil @@ -158,7 +176,7 @@ func NewSim(clientFactory ClientFactory) *simProvider { // Engine support func (p *simProvider) GetComputeInstances(ctx context.Context) ([]topology.ComputeInstances, error) { - client, _ := p.clientFactory() + client, _ := p.clientFactory(nil) return client.(*simClient).model.Instances, nil } diff --git a/pkg/providers/gcp/provider_sim_test.go b/pkg/providers/gcp/provider_sim_test.go index 8347bd1..407113c 100644 --- a/pkg/providers/gcp/provider_sim_test.go +++ b/pkg/providers/gcp/provider_sim_test.go @@ -24,13 +24,14 @@ import ( "github.com/NVIDIA/topograph/pkg/engines/slurm" "github.com/NVIDIA/topograph/pkg/providers" "github.com/NVIDIA/topograph/pkg/topology" + "github.com/agrea/ptr" "github.com/stretchr/testify/require" ) const ( ignoreErrMsg = "_IGNORE_" - singleCluster = ` + nodeModel = ` switches: - name: core switches: [spine] @@ -45,7 +46,7 @@ capacity_blocks: nodes: [11] ` - mediumCluster = ` + clusterModel = ` switches: - name: core switches: [spine] @@ -73,8 +74,9 @@ func TestProviderSim(t *testing.T) { testCases := []struct { name string model string + pageSize *int instances []topology.ComputeInstances - apiErr bool + apiErr int topology string err string }{ @@ -83,8 +85,49 @@ func TestProviderSim(t *testing.T) { model: `bad: model: error:`, err: ignoreErrMsg, }, + + { + name: "Case 3: no ComputeInstances", + model: clusterModel, + }, + { + name: "Case X.1: ClientFactory API error", + model: nodeModel, + instances: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"11": "node11"}, + }, + }, + apiErr: errClientFactory, + err: "failed to get client: API error", + }, + { + name: "Case X.2: ProjectID API error", + model: nodeModel, + instances: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"11": "node11"}, + }, + }, + apiErr: errProjectID, + err: "failed to get project ID: API error", + }, + { + name: "Case X.3: Instances API error", + model: nodeModel, + instances: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"11": "node11"}, + }, + }, + apiErr: errInstances, + err: "failed to get instance topology: API error", + }, { - name: "Case 2: unsupported instance ID", + name: "Case X.4: unsupported instance ID", model: ` switches: - name: core @@ -99,15 +142,17 @@ capacity_blocks: nvlink: nvl1 nodes: [n11] `, - err: `failed to create simulation client: invalid instance ID "n11"; must be numerical`, - }, - { - name: "Case 3: no ComputeInstances", - model: mediumCluster, + instances: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"11": "node11"}, + }, + }, + err: `failed to get instance topology: invalid instance ID "n11"; must be numerical`, }, { name: "Case 4: single node", - model: singleCluster, + model: nodeModel, instances: []topology.ComputeInstances{ { Region: "region", @@ -120,20 +165,23 @@ SwitchName=tor Nodes=node11 `, }, { - name: "Case 5: page iterator error", - model: singleCluster, + name: "Case 5: valid input, no pagination", + model: clusterModel, instances: []topology.ComputeInstances{ { Region: "region", - Instances: map[string]string{"11": "node11"}, + Instances: map[string]string{"11": "node11", "12": "node12", "21": "node21", "22": "node22"}, }, }, - apiErr: true, - err: "failed to get instance topology: iterator error", + topology: `SwitchName=spine Switches=tor[1-2] +SwitchName=tor1 Nodes=node[11-12] +SwitchName=tor2 Nodes=node[21-22] +`, }, { - name: "Case 6: valid input", - model: mediumCluster, + name: "Case 6: valid input, pagination", + model: clusterModel, + pageSize: ptr.Int(2), instances: []topology.ComputeInstances{ { Region: "region", @@ -160,9 +208,12 @@ SwitchName=tor2 Nodes=node[21-22] require.NoError(t, err) cfg := providers.Config{ - Params: map[string]any{"model_path": f.Name()}, + Params: map[string]any{ + "model_path": f.Name(), + "api_error": tc.apiErr, + }, } - sim, err := LoaderSim(ctx, cfg) + provider, err := LoaderSim(ctx, cfg) if err != nil { if len(tc.err) == 0 { require.NoError(t, err) @@ -171,14 +222,8 @@ SwitchName=tor2 Nodes=node[21-22] } return } - provider := sim.(*simProvider) - - if tc.apiErr { - cl, _ := provider.clientFactory() - cl.(*simClient).pages[0].err = true - } - topo, err := provider.GenerateTopologyConfig(ctx, nil, tc.instances) + topo, err := provider.GenerateTopologyConfig(ctx, tc.pageSize, tc.instances) if len(tc.err) != 0 { require.EqualError(t, err, tc.err) } else { diff --git a/pkg/providers/providers_sim.go b/pkg/providers/providers_sim.go index 8ae0479..614d449 100644 --- a/pkg/providers/providers_sim.go +++ b/pkg/providers/providers_sim.go @@ -17,13 +17,17 @@ package providers import ( + "errors" "fmt" "github.com/NVIDIA/topograph/internal/config" ) +var APIError = errors.New("API error") + type SimulationParams struct { ModelPath string `mapstructure:"model_path"` + APIError int `mapstructure:"api_error"` } func GetSimulationParams(params map[string]any) (*SimulationParams, error) {