From 78f107052bc2941ea62e64feeb4182f6c67b912f Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Thu, 18 Apr 2024 15:51:17 -0500 Subject: [PATCH] update tests for assertion func changes --- azure_test.go | 58 +++++++++++++++++++++++---------------------- path_config_test.go | 2 +- path_login_test.go | 39 +++++++++++++++--------------- 3 files changed, 51 insertions(+), 48 deletions(-) diff --git a/azure_test.go b/azure_test.go index 1ba43260..f524b52d 100644 --- a/azure_test.go +++ b/azure_test.go @@ -16,7 +16,9 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/coreos/go-oidc" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault-plugin-auth-azure/client" + "github.com/hashicorp/vault/sdk/logical" ) // mockKeySet is used in tests to bypass signature validation and return only @@ -45,43 +47,43 @@ func newMockVerifier() client.TokenVerifier { } type mockComputeClient struct { - computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error) + computeClientFunc computeClientFunc } type mockVMSSClient struct { - vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) + vmssClientFunc vmssClientFunc } type mockMSIClient struct { - msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) - msiListFunc func(resourceGroup string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse + msiClientFunc msiClientFunc + msiListFunc msiListFunc } type mockResourceClient struct { - resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error) + resourceClientFunc resourceClientFunc } type mockProvidersClient struct { - providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error) + providersClientFunc providersClientFunc } -func (c *mockComputeClient) Get(_ context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) { +func (c *mockComputeClient) Get(ctx context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) { if c.computeClientFunc != nil { - return c.computeClientFunc(vmName) + return c.computeClientFunc(ctx, hclog.NewNullLogger(), nil, vmName) } return armcompute.VirtualMachinesClientGetResponse{}, nil } -func (c *mockVMSSClient) Get(_ context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { +func (c *mockVMSSClient) Get(ctx context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { if c.vmssClientFunc != nil { - return c.vmssClientFunc(vmssName) + return c.vmssClientFunc(ctx, hclog.NewNullLogger(), nil, vmssName) } return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil } -func (c *mockMSIClient) Get(_ context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { +func (c *mockMSIClient) Get(ctx context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { if c.msiClientFunc != nil { - return c.msiClientFunc(resourceName) + return c.msiClientFunc(ctx, hclog.NewNullLogger(), nil, resourceName) } return armmsi.UserAssignedIdentitiesClientGetResponse{}, nil } @@ -101,33 +103,33 @@ func (c *mockMSIClient) NewListByResourceGroupPager(resourceGroup string, _ *arm return nil } -func (c *mockResourceClient) GetByID(_ context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) { +func (c *mockResourceClient) GetByID(ctx context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) { if c.resourceClientFunc != nil { - return c.resourceClientFunc(resourceID) + return c.resourceClientFunc(ctx, hclog.NewNullLogger(), nil, resourceID) } return armresources.ClientGetByIDResponse{}, nil } -func (c *mockProvidersClient) Get(_ context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) { +func (c *mockProvidersClient) Get(ctx context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) { if c.providersClientFunc != nil { - return c.providersClientFunc(resourceID) + return c.providersClientFunc(ctx, hclog.NewNullLogger(), nil, resourceID) } return armresources.ProvidersClientGetResponse{}, nil } -type computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error) +type computeClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, vmName string) (armcompute.VirtualMachinesClientGetResponse, error) -type vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) +type vmssClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) -type msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) +type msiClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) type msiListFunc func(resoucename string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse -type msGraphClientFunc func() (client.MSGraphClient, error) +type msGraphClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView) (client.MSGraphClient, error) -type resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error) +type resourceClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, resourceID string) (armresources.ClientGetByIDResponse, error) -type providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error) +type providersClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (armresources.ProvidersClientGetResponse, error) type mockProvider struct { computeClientFunc @@ -153,36 +155,36 @@ func (*mockProvider) TokenVerifier() client.TokenVerifier { return newMockVerifier() } -func (p *mockProvider) ComputeClient(string) (client.ComputeClient, error) { +func (p *mockProvider) ComputeClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ComputeClient, error) { return &mockComputeClient{ computeClientFunc: p.computeClientFunc, }, nil } -func (p *mockProvider) VMSSClient(string) (client.VMSSClient, error) { +func (p *mockProvider) VMSSClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.VMSSClient, error) { return &mockVMSSClient{ vmssClientFunc: p.vmssClientFunc, }, nil } -func (p *mockProvider) MSIClient(string) (client.MSIClient, error) { +func (p *mockProvider) MSIClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.MSIClient, error) { return &mockMSIClient{ msiClientFunc: p.msiClientFunc, msiListFunc: p.msiListFunc, }, nil } -func (p *mockProvider) MSGraphClient() (client.MSGraphClient, error) { +func (p *mockProvider) MSGraphClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView) (client.MSGraphClient, error) { return nil, nil } -func (p *mockProvider) ResourceClient(string) (client.ResourceClient, error) { +func (p *mockProvider) ResourceClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ResourceClient, error) { return &mockResourceClient{ resourceClientFunc: p.resourceClientFunc, }, nil } -func (p *mockProvider) ProvidersClient(string) (client.ProvidersClient, error) { +func (p *mockProvider) ProvidersClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ProvidersClient, error) { return &mockProvidersClient{ providersClientFunc: p.providersClientFunc, }, nil diff --git a/path_config_test.go b/path_config_test.go index 6f1ebf9a..e98cdcb8 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -129,7 +129,7 @@ func TestConfig(t *testing.T) { "tenant_id": "foo", } - err = testConfigUpdate(t, b, s, configSubset) + _, err = testConfigUpdate(t, b, s, configSubset) if err != nil { t.Fatal(err) } diff --git a/path_login_test.go b/path_login_test.go index 8e2cb133..616e1ed0 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/coreos/go-oidc" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault-plugin-auth-azure/client" "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/logical" @@ -184,7 +185,7 @@ func TestLogin_ManagedIdentity(t *testing.T) { roleName := "test-role" // setup test response functions that mock the client GetByID response - nilIdentityRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) { + nilIdentityRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) { return armresources.ClientGetByIDResponse{}, nil } userAssignedRespFunc, systemAssignedRespFunc := getResourceByIDResponses(t, principalID) @@ -195,7 +196,7 @@ func TestLogin_ManagedIdentity(t *testing.T) { claims map[string]interface{} roleData map[string]interface{} loginData map[string]interface{} - clientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error) + clientFunc func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) expectError bool }{ "login happy path user-assigned managed identity": { @@ -945,8 +946,8 @@ func TestGetAPIVersionForResource(t *testing.T) { // the azure arm resource client responses. If principalID is an empty string // then no identity data will be set in the response. func getResourceByIDResponses(t *testing.T, principalID string) ( - func(_ string) (armresources.ClientGetByIDResponse, error), - func(_ string) (armresources.ClientGetByIDResponse, error), + func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error), + func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error), ) { t.Helper() u := armresources.ClientGetByIDResponse{ @@ -972,10 +973,10 @@ func getResourceByIDResponses(t *testing.T, principalID string) ( s.GenericResource.Identity.PrincipalID = &principalID } - userAssignedRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) { + userAssignedRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) { return u, nil } - systemAssignedRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) { + systemAssignedRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) { return s, nil } @@ -984,7 +985,7 @@ func getResourceByIDResponses(t *testing.T, principalID string) ( // getProvidersResponse is a test helper to get the function that returns // the azure arm resource providers client response. -func getProvidersResponse(t *testing.T, resourceID string) func(_ string) (armresources.ProvidersClientGetResponse, error) { +func getProvidersResponse(t *testing.T, resourceID string) func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ProvidersClientGetResponse, error) { t.Helper() resourceType, err := arm.ParseResourceType(resourceID) @@ -1008,7 +1009,7 @@ func getProvidersResponse(t *testing.T, resourceID string) func(_ string) (armre }, }, } - providersRespFunc := func(_ string) (armresources.ProvidersClientGetResponse, error) { + providersRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ProvidersClientGetResponse, error) { return u, nil } return providersRespFunc @@ -1036,14 +1037,14 @@ func testJWT(t *testing.T, payload map[string]interface{}) string { } func getTestBackendFunctions(withLocation bool) ( - func(_ string) (armcompute.VirtualMachinesClientGetResponse, error), - func(_ string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error), - func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error), + func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachinesClientGetResponse, error), + func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error), + func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error), ) { principalID := "123e4567-e89b-12d3-a456-426655440000" if !withLocation { - c := func(_ string) (armcompute.VirtualMachinesClientGetResponse, error) { + c := func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachinesClientGetResponse, error) { id := armcompute.VirtualMachineIdentity{ PrincipalID: &principalID, } @@ -1053,7 +1054,7 @@ func getTestBackendFunctions(withLocation bool) ( }, }, nil } - v := func(_ string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { + v := func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { id := armcompute.VirtualMachineScaleSetIdentity{ PrincipalID: &principalID, } @@ -1062,7 +1063,7 @@ func getTestBackendFunctions(withLocation bool) ( }}, nil } - m := func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { + m := func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { userAssignedIdentityProperties := armmsi.UserAssignedIdentityProperties{ PrincipalID: &principalID, } @@ -1075,7 +1076,7 @@ func getTestBackendFunctions(withLocation bool) ( } else { location := "loc" - c := func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error) { + c := func(_ context.Context, _ hclog.Logger, _ logical.SystemView, vmName string) (armcompute.VirtualMachinesClientGetResponse, error) { id := armcompute.VirtualMachineIdentity{ PrincipalID: &principalID, } @@ -1094,7 +1095,7 @@ func getTestBackendFunctions(withLocation bool) ( } return armcompute.VirtualMachinesClientGetResponse{}, nil } - v := func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { + v := func(_ context.Context, _ hclog.Logger, _ logical.SystemView, vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) { id := armcompute.VirtualMachineScaleSetIdentity{ PrincipalID: &principalID, } @@ -1114,7 +1115,7 @@ func getTestBackendFunctions(withLocation bool) ( return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil } - m := func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { + m := func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { userAssignedIdentityProperties := armmsi.UserAssignedIdentityProperties{ PrincipalID: &principalID, } @@ -1127,8 +1128,8 @@ func getTestBackendFunctions(withLocation bool) ( } } -func getTestMSGraphClient() func() (client.MSGraphClient, error) { - return func() (client.MSGraphClient, error) { +func getTestMSGraphClient() func(context.Context, hclog.Logger, logical.SystemView) (client.MSGraphClient, error) { + return func(context.Context, hclog.Logger, logical.SystemView) (client.MSGraphClient, error) { return nil, nil } }