From 6e5c9bd540f48b0db8d38cc703330563a3375c72 Mon Sep 17 00:00:00 2001 From: Shahram Kalantari Date: Mon, 4 Nov 2024 13:41:57 +1000 Subject: [PATCH] chore: use dependency injection to add more unit tests Signed-off-by: Shahram Kalantari --- .../azurekeyvault/provider.go | 26 ++++--- .../azurekeyvault/provider_test.go | 67 +++++++++++++++++-- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/pkg/keymanagementprovider/azurekeyvault/provider.go b/pkg/keymanagementprovider/azurekeyvault/provider.go index e2cc41378..2e033d167 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider.go @@ -38,6 +38,7 @@ import ( "github.com/ratify-project/ratify/pkg/metrics" "golang.org/x/crypto/pkcs12" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" @@ -127,7 +128,7 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf logger.GetLogger(context.Background(), logOpt).Debugf("vaultURI %s", provider.vaultURI) - kvClientKeys, kvClientSecrets, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID) + kvClientKeys, kvClientSecrets, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID, nil) if err != nil { return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace) } @@ -233,27 +234,30 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) { return &env, err } -func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*azkeys.Client, *azsecrets.Client, error) { +func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string, credProvider azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) { // Trim any trailing slash from the endpoint kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/") - // Create the workload identity credential for authentication - credential, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ - ClientID: clientID, - TenantID: tenantID, - }) - if err != nil { - return nil, nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err) + // If credProvider is nil, create the default credential + if credProvider == nil { + var err error + credProvider, err = azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ + ClientID: clientID, + TenantID: tenantID, + }) + if err != nil { + return nil, nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err) + } } // create azkeys client - kvClientKeys, err := azkeys.NewClient(kvEndpoint, credential, nil) + kvClientKeys, err := azkeys.NewClient(kvEndpoint, credProvider, nil) if err != nil { return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err) } // create azsecrets client - kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credential, nil) + kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credProvider, nil) if err != nil { return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err) } diff --git a/pkg/keymanagementprovider/azurekeyvault/provider_test.go b/pkg/keymanagementprovider/azurekeyvault/provider_test.go index 0fdc7976f..0e12d04ec 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider_test.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" @@ -66,7 +67,7 @@ func SkipTestInitializeKVClient(t *testing.T) { } for i := range testEnvs { - kvClientkeys, kvClientSecrets, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "") + kvClientkeys, kvClientSecrets, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "", nil) assert.NoError(t, err) assert.NotNil(t, kvClientkeys) assert.NotNil(t, kvClientSecrets) @@ -178,7 +179,7 @@ func TestCreate(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - initKVClient = func(_ context.Context, _, _, _ string) (*azkeys.Client, *azsecrets.Client, error) { + initKVClient = func(_ context.Context, _, _, _ string, _ azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) { return &azkeys.Client{}, &azsecrets.Client{}, nil } _, err := factory.Create("v1", tc.config, "") @@ -229,7 +230,7 @@ func TestGetKeys(t *testing.T) { }, } - initKVClient = func(_ context.Context, _, _, _ string) (*azkeys.Client, *azsecrets.Client, error) { + initKVClient = func(_ context.Context, _, _, _ string, _ azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) { return &azkeys.Client{}, &azsecrets.Client{}, nil } provider, err := factory.Create("v1", config, "") @@ -617,7 +618,7 @@ func TestInitializeKvClient(t *testing.T) { mockSecretsClient.On("NewClient", tt.kvEndpoint, mockCredential, mock.Anything).Return(mockSecretsClient, tt.mockSecretsErr) // Call function under test - keysClient, secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID) + keysClient, secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID, nil) // Validate expectations if tt.expectedErr { @@ -675,3 +676,61 @@ func TestGetKeyFromKeyBundlex(t *testing.T) { }) } } + +func TestInitializeKvClient_Success(t *testing.T) { + // Mock the context and input parameters + ctx := context.Background() + keyVaultEndpoint := "https://myvault.vault.azure.net/" + tenantID := "tenant-id" + clientID := "client-id" + + // Create a mock credential provider + mockCredential, err := azidentity.NewClientSecretCredential(tenantID, clientID, "fake-secret", nil) + if err != nil { + t.Fatalf("Failed to create mock credential: %v", err) + } + + // Run the function with the mock credential + kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, mockCredential) + + // Assert the function succeeds without errors and clients are created + assert.NotNil(t, kvClientKeys) + assert.NotNil(t, kvClientSecrets) + assert.NoError(t, err) +} + +func TestInitializeKvClient_FailureInAzKeysClient(t *testing.T) { + // Mock the context and input parameters + ctx := context.Background() + keyVaultEndpoint := "https://invalid-vault.vault.azure.net/" + tenantID := "mock_tenant-id" + clientID := "mock_client-id" + + // Run the function + kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, nil) + + // Assert that an error occurred and clients were not created + assert.Nil(t, kvClientKeys) + assert.Nil(t, kvClientSecrets) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create workload identity credential") +} + +func TestInitializeKvClient_FailureInAzSecretsClient(t *testing.T) { + // Mock the context and input parameters + ctx := context.Background() + keyVaultEndpoint := "https://valid-vault.vault.azure.net/" + tenantID := "tenant-id" + clientID := "client-id" + + // Modify the azsecrets.NewClient function to simulate failure + // Run the function + kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, nil) + + // Assert that an error occurred and clients were not created + assert.Nil(t, kvClientKeys) + assert.Nil(t, kvClientSecrets) + assert.Error(t, err) + // assert.Contains(t, err.Error(), "Failed to create Key Vault client") + assert.Contains(t, err.Error(), "failed to create workload identity credential") +}