Skip to content

Commit

Permalink
test: add tests to akv provider (ratify-project#1729)
Browse files Browse the repository at this point in the history
  • Loading branch information
binbin-li authored Aug 21, 2024
1 parent 0b6aa67 commit 7500f96
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
3 changes: 2 additions & 1 deletion pkg/certificateprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

kv "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/ratify-project/ratify/internal/version"
"github.com/ratify-project/ratify/pkg/certificateprovider/azurekeyvault/types"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -105,7 +106,7 @@ func SkipTestInitializeKVClient(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, kvBaseClient)
assert.NotNil(t, kvBaseClient.Authorizer)
assert.Contains(t, kvBaseClient.UserAgent, "ratify")
assert.Contains(t, kvBaseClient.UserAgent, version.UserAgent)
}
}

Expand Down
11 changes: 6 additions & 5 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/go-jose/go-jose/v3"
re "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/internal/logger"
"github.com/ratify-project/ratify/internal/version"
"github.com/ratify-project/ratify/pkg/keymanagementprovider"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/config"
Expand Down Expand Up @@ -122,7 +123,7 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf

logger.GetLogger(context.Background(), logOpt).Debugf("vaultURI %s", provider.vaultURI)

kvClient, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID)
kvClient, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID, version.UserAgent)
if err != nil {
return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace)
}
Expand Down Expand Up @@ -225,18 +226,18 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) {
return &env, err
}

func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*kv.BaseClient, error) {
func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID, userAgent string) (*kv.BaseClient, error) {
kvClient := kv.New()
kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/")

err := kvClient.AddToUserAgent("ratify")
err := kvClient.AddToUserAgent(userAgent)
if err != nil {
return nil, re.ErrorCodeConfigInvalid.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to add user agent to keyvault client", re.HideStackTrace)
return nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to add user agent to keyvault client.").WithRemediation(re.AKVLink).WithError(err)
}

kvClient.Authorizer, err = getAuthorizerForWorkloadIdentity(ctx, tenantID, clientID, kvEndpoint)
if err != nil {
return nil, re.ErrorCodeAuthDenied.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to get authorizer for keyvault client", re.HideStackTrace)
return nil, re.ErrorCodeAuthDenied.WithDetail("failed to get authorizer for keyvault client").WithRemediation(re.AKVLink).WithError(err)
}
return &kvClient, nil
}
Expand Down
44 changes: 40 additions & 4 deletions pkg/keymanagementprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

kv "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/ratify-project/ratify/internal/version"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/config"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -62,11 +63,11 @@ func SkipTestInitializeKVClient(t *testing.T) {
}

for i := range testEnvs {
kvBaseClient, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "")
kvBaseClient, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "", version.UserAgent)
assert.NoError(t, err)
assert.NotNil(t, kvBaseClient)
assert.NotNil(t, kvBaseClient.Authorizer)
assert.Contains(t, kvBaseClient.UserAgent, "ratify")
assert.Contains(t, kvBaseClient.UserAgent, version.UserAgent)
}
}

Expand Down Expand Up @@ -173,7 +174,7 @@ func TestCreate(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
initKVClient = func(_ context.Context, _, _, _ string) (*kv.BaseClient, error) {
initKVClient = func(_ context.Context, _, _, _, _ string) (*kv.BaseClient, error) {
return &kv.BaseClient{}, nil
}
_, err := factory.Create("v1", tc.config, "")
Expand Down Expand Up @@ -224,7 +225,7 @@ func TestGetKeys(t *testing.T) {
},
}

initKVClient = func(_ context.Context, _, _, _ string) (*kv.BaseClient, error) {
initKVClient = func(_ context.Context, _, _, _, _ string) (*kv.BaseClient, error) {
return &kv.BaseClient{}, nil
}
provider, err := factory.Create("v1", config, "")
Expand Down Expand Up @@ -506,3 +507,38 @@ func TestValidate(t *testing.T) {
})
}
}

func TestInitializeKvClient(t *testing.T) {
tests := []struct {
name string
kvEndpoint string
userAgent string
tenantID string
clientID string
expectedErr bool
}{
{
name: "Empty user agent",
kvEndpoint: "https://test.vault.azure.net",
userAgent: "",
expectedErr: true,
},
{
name: "Auth failure",
kvEndpoint: "https://test.vault.azure.net",
userAgent: version.UserAgent,
tenantID: "testTenantID",
clientID: "testClientID",
expectedErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID, tt.userAgent)
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got: %v", tt.expectedErr, err)
}
})
}
}

0 comments on commit 7500f96

Please sign in to comment.