Skip to content

Commit

Permalink
chore: use dependency injection to add more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Nov 4, 2024
1 parent ece047c commit 6e5c9bd
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 15 deletions.
26 changes: 15 additions & 11 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/ratify-project/ratify/pkg/metrics"
"golang.org/x/crypto/pkcs12"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
Expand Down Expand Up @@ -127,7 +128,7 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf

logger.GetLogger(context.Background(), logOpt).Debugf("vaultURI %s", provider.vaultURI)

kvClientKeys, kvClientSecrets, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID)
kvClientKeys, kvClientSecrets, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID, nil)
if err != nil {
return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace)
}
Expand Down Expand Up @@ -233,27 +234,30 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) {
return &env, err
}

func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*azkeys.Client, *azsecrets.Client, error) {
func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string, credProvider azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) {
// Trim any trailing slash from the endpoint
kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/")

// Create the workload identity credential for authentication
credential, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
ClientID: clientID,
TenantID: tenantID,
})
if err != nil {
return nil, nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err)
// If credProvider is nil, create the default credential
if credProvider == nil {
var err error
credProvider, err = azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
ClientID: clientID,
TenantID: tenantID,
})
if err != nil {
return nil, nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err)
}
}

// create azkeys client
kvClientKeys, err := azkeys.NewClient(kvEndpoint, credential, nil)
kvClientKeys, err := azkeys.NewClient(kvEndpoint, credProvider, nil)
if err != nil {
return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err)
}

// create azsecrets client
kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credential, nil)
kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credProvider, nil)
if err != nil {
return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err)
}
Expand Down
67 changes: 63 additions & 4 deletions pkg/keymanagementprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
Expand Down Expand Up @@ -66,7 +67,7 @@ func SkipTestInitializeKVClient(t *testing.T) {
}

for i := range testEnvs {
kvClientkeys, kvClientSecrets, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "")
kvClientkeys, kvClientSecrets, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "", nil)
assert.NoError(t, err)
assert.NotNil(t, kvClientkeys)
assert.NotNil(t, kvClientSecrets)
Expand Down Expand Up @@ -178,7 +179,7 @@ func TestCreate(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
initKVClient = func(_ context.Context, _, _, _ string) (*azkeys.Client, *azsecrets.Client, error) {
initKVClient = func(_ context.Context, _, _, _ string, _ azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) {
return &azkeys.Client{}, &azsecrets.Client{}, nil
}
_, err := factory.Create("v1", tc.config, "")
Expand Down Expand Up @@ -229,7 +230,7 @@ func TestGetKeys(t *testing.T) {
},
}

initKVClient = func(_ context.Context, _, _, _ string) (*azkeys.Client, *azsecrets.Client, error) {
initKVClient = func(_ context.Context, _, _, _ string, _ azcore.TokenCredential) (*azkeys.Client, *azsecrets.Client, error) {
return &azkeys.Client{}, &azsecrets.Client{}, nil
}
provider, err := factory.Create("v1", config, "")
Expand Down Expand Up @@ -617,7 +618,7 @@ func TestInitializeKvClient(t *testing.T) {
mockSecretsClient.On("NewClient", tt.kvEndpoint, mockCredential, mock.Anything).Return(mockSecretsClient, tt.mockSecretsErr)

// Call function under test
keysClient, secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID)
keysClient, secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID, nil)

// Validate expectations
if tt.expectedErr {
Expand Down Expand Up @@ -675,3 +676,61 @@ func TestGetKeyFromKeyBundlex(t *testing.T) {
})
}
}

func TestInitializeKvClient_Success(t *testing.T) {
// Mock the context and input parameters
ctx := context.Background()
keyVaultEndpoint := "https://myvault.vault.azure.net/"
tenantID := "tenant-id"
clientID := "client-id"

// Create a mock credential provider
mockCredential, err := azidentity.NewClientSecretCredential(tenantID, clientID, "fake-secret", nil)
if err != nil {
t.Fatalf("Failed to create mock credential: %v", err)
}

// Run the function with the mock credential
kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, mockCredential)

// Assert the function succeeds without errors and clients are created
assert.NotNil(t, kvClientKeys)
assert.NotNil(t, kvClientSecrets)
assert.NoError(t, err)
}

func TestInitializeKvClient_FailureInAzKeysClient(t *testing.T) {
// Mock the context and input parameters
ctx := context.Background()
keyVaultEndpoint := "https://invalid-vault.vault.azure.net/"
tenantID := "mock_tenant-id"
clientID := "mock_client-id"

// Run the function
kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, nil)

// Assert that an error occurred and clients were not created
assert.Nil(t, kvClientKeys)
assert.Nil(t, kvClientSecrets)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to create workload identity credential")
}

func TestInitializeKvClient_FailureInAzSecretsClient(t *testing.T) {
// Mock the context and input parameters
ctx := context.Background()
keyVaultEndpoint := "https://valid-vault.vault.azure.net/"
tenantID := "tenant-id"
clientID := "client-id"

// Modify the azsecrets.NewClient function to simulate failure
// Run the function
kvClientKeys, kvClientSecrets, err := initializeKvClient(ctx, keyVaultEndpoint, tenantID, clientID, nil)

// Assert that an error occurred and clients were not created
assert.Nil(t, kvClientKeys)
assert.Nil(t, kvClientSecrets)
assert.Error(t, err)
// assert.Contains(t, err.Error(), "Failed to create Key Vault client")
assert.Contains(t, err.Error(), "failed to create workload identity credential")
}

0 comments on commit 6e5c9bd

Please sign in to comment.