diff --git a/pkg/certificateprovider/azurekeyvault/provider_test.go b/pkg/certificateprovider/azurekeyvault/provider_test.go index 001c9b75b..f11f31eed 100644 --- a/pkg/certificateprovider/azurekeyvault/provider_test.go +++ b/pkg/certificateprovider/azurekeyvault/provider_test.go @@ -26,6 +26,7 @@ import ( kv "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/Azure/go-autorest/autorest/azure" + "github.com/ratify-project/ratify/internal/version" "github.com/ratify-project/ratify/pkg/certificateprovider/azurekeyvault/types" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -105,7 +106,7 @@ func SkipTestInitializeKVClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, kvBaseClient) assert.NotNil(t, kvBaseClient.Authorizer) - assert.Contains(t, kvBaseClient.UserAgent, "ratify") + assert.Contains(t, kvBaseClient.UserAgent, version.UserAgent) } } diff --git a/pkg/keymanagementprovider/azurekeyvault/provider.go b/pkg/keymanagementprovider/azurekeyvault/provider.go index 2cbe8752b..da885af90 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider.go @@ -31,6 +31,7 @@ import ( "github.com/go-jose/go-jose/v3" re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" + "github.com/ratify-project/ratify/internal/version" "github.com/ratify-project/ratify/pkg/keymanagementprovider" "github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types" "github.com/ratify-project/ratify/pkg/keymanagementprovider/config" @@ -122,7 +123,7 @@ func (f *akvKMProviderFactory) Create(_ string, keyManagementProviderConfig conf logger.GetLogger(context.Background(), logOpt).Debugf("vaultURI %s", provider.vaultURI) - kvClient, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID) + kvClient, err := initKVClient(context.Background(), provider.cloudEnv.KeyVaultEndpoint, provider.tenantID, provider.clientID, version.UserAgent) if err != nil { return nil, re.ErrorCodePluginInitFailure.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to create keyvault client", re.HideStackTrace) } @@ -225,18 +226,18 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) { return &env, err } -func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*kv.BaseClient, error) { +func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID, userAgent string) (*kv.BaseClient, error) { kvClient := kv.New() kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/") - err := kvClient.AddToUserAgent("ratify") + err := kvClient.AddToUserAgent(userAgent) if err != nil { - return nil, re.ErrorCodeConfigInvalid.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to add user agent to keyvault client", re.HideStackTrace) + return nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to add user agent to keyvault client.").WithRemediation(re.AKVLink).WithError(err) } kvClient.Authorizer, err = getAuthorizerForWorkloadIdentity(ctx, tenantID, clientID, kvEndpoint) if err != nil { - return nil, re.ErrorCodeAuthDenied.NewError(re.KeyManagementProvider, ProviderName, re.AKVLink, err, "failed to get authorizer for keyvault client", re.HideStackTrace) + return nil, re.ErrorCodeAuthDenied.WithDetail("failed to get authorizer for keyvault client").WithRemediation(re.AKVLink).WithError(err) } return &kvClient, nil } diff --git a/pkg/keymanagementprovider/azurekeyvault/provider_test.go b/pkg/keymanagementprovider/azurekeyvault/provider_test.go index ce95d24a7..676a43892 100644 --- a/pkg/keymanagementprovider/azurekeyvault/provider_test.go +++ b/pkg/keymanagementprovider/azurekeyvault/provider_test.go @@ -26,6 +26,7 @@ import ( kv "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/Azure/go-autorest/autorest/azure" + "github.com/ratify-project/ratify/internal/version" "github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types" "github.com/ratify-project/ratify/pkg/keymanagementprovider/config" "github.com/stretchr/testify/assert" @@ -62,11 +63,11 @@ func SkipTestInitializeKVClient(t *testing.T) { } for i := range testEnvs { - kvBaseClient, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "") + kvBaseClient, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "", version.UserAgent) assert.NoError(t, err) assert.NotNil(t, kvBaseClient) assert.NotNil(t, kvBaseClient.Authorizer) - assert.Contains(t, kvBaseClient.UserAgent, "ratify") + assert.Contains(t, kvBaseClient.UserAgent, version.UserAgent) } } @@ -173,7 +174,7 @@ func TestCreate(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - initKVClient = func(_ context.Context, _, _, _ string) (*kv.BaseClient, error) { + initKVClient = func(_ context.Context, _, _, _, _ string) (*kv.BaseClient, error) { return &kv.BaseClient{}, nil } _, err := factory.Create("v1", tc.config, "") @@ -224,7 +225,7 @@ func TestGetKeys(t *testing.T) { }, } - initKVClient = func(_ context.Context, _, _, _ string) (*kv.BaseClient, error) { + initKVClient = func(_ context.Context, _, _, _, _ string) (*kv.BaseClient, error) { return &kv.BaseClient{}, nil } provider, err := factory.Create("v1", config, "") @@ -506,3 +507,38 @@ func TestValidate(t *testing.T) { }) } } + +func TestInitializeKvClient(t *testing.T) { + tests := []struct { + name string + kvEndpoint string + userAgent string + tenantID string + clientID string + expectedErr bool + }{ + { + name: "Empty user agent", + kvEndpoint: "https://test.vault.azure.net", + userAgent: "", + expectedErr: true, + }, + { + name: "Auth failure", + kvEndpoint: "https://test.vault.azure.net", + userAgent: version.UserAgent, + tenantID: "testTenantID", + clientID: "testClientID", + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID, tt.userAgent) + if tt.expectedErr != (err != nil) { + t.Fatalf("expected error: %v, got: %v", tt.expectedErr, err) + } + }) + } +}