Skip to content

Commit

Permalink
chore: add more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Nov 1, 2024
1 parent 628cfc3 commit 568e2f6
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 35 deletions.
4 changes: 0 additions & 4 deletions pkg/certificateprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/ratify-project/ratify/pkg/certificateprovider"
"github.com/ratify-project/ratify/pkg/certificateprovider/azurekeyvault/types"
"github.com/ratify-project/ratify/pkg/metrics"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/pkcs12"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
Expand Down Expand Up @@ -213,8 +212,6 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) {
func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*azsecrets.Client, error) {
// Trim any trailing slash from the endpoint
kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/")
logger.GetLogger(ctx, logOpt).Infof("kvEndpoint: '%s'", kvEndpoint)
logrus.WithContext(ctx).Infof("kvEndpoint: '%s'", kvEndpoint)

// Create the workload identity credential for authentication
credential, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
Expand All @@ -224,7 +221,6 @@ func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientI
if err != nil {
return nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err)
}
logger.GetLogger(ctx, logOpt).Infof("credential created successfully")

// create azsecrets client
kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credential, nil)
Expand Down
97 changes: 84 additions & 13 deletions pkg/certificateprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ package azurekeyvault
// Source: https://github.com/Azure/secrets-store-csi-driver-provider-azure/tree/release-1.4/pkg/provider
import (
"context"
"errors"
"reflect"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/ratify-project/ratify/pkg/certificateprovider/azurekeyvault/types"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestParseAzureEnvironment(t *testing.T) {
Expand Down Expand Up @@ -92,20 +94,89 @@ func TestFormatKeyVaultCertificate(t *testing.T) {
}
}

func SkipTestInitializeKVClient(t *testing.T) {
testEnvs := []azure.Environment{
azure.PublicCloud,
azure.GermanCloud,
azure.ChinaCloud,
azure.USGovernmentCloud,
// Mock clients
type MockAzSecretsClient struct {
mock.Mock
}

type MockWorkloadIdentityCredential struct {
mock.Mock
}

// Mock functions
func (m *MockWorkloadIdentityCredential) NewWorkloadIdentityCredential(options *azidentity.WorkloadIdentityCredentialOptions) (*MockWorkloadIdentityCredential, error) {
args := m.Called(options)
return args.Get(0).(*MockWorkloadIdentityCredential), args.Error(1)
}

func (m *MockAzSecretsClient) NewClient(endpoint string, credential *azidentity.WorkloadIdentityCredential, options *azsecrets.ClientOptions) (*azsecrets.Client, error) {
args := m.Called(endpoint, credential, options)
return args.Get(0).(*azsecrets.Client), args.Error(1)
}

func TestInitializeKvClient(t *testing.T) {
mockCredential := new(MockWorkloadIdentityCredential)
mockSecretsClient := new(MockAzSecretsClient)

tests := []struct {
name string
kvEndpoint string
userAgent string
tenantID string
clientID string
mockCredentialErr error
mockSecretsErr error
expectedErr bool
}{
{
name: "Empty user agent",
kvEndpoint: "https://test.vault.azure.net",
userAgent: "",
expectedErr: true,
},
{
name: "Auth failure",
kvEndpoint: "https://test.vault.azure.net",
tenantID: "testTenantID",
clientID: "testClientID",
expectedErr: true,
},
{
name: "credential creation error",
kvEndpoint: "https://test-keyvault.vault.azure.net",
tenantID: "test-tenant-id",
clientID: "test-client-id",
mockCredentialErr: errors.New("failed to create workload identity credential"),
expectedErr: true,
},
{
name: "azsecrets client creation error",
kvEndpoint: "https://test-keyvault.vault.azure.net",
tenantID: "test-tenant-id",
clientID: "test-client-id",
mockSecretsErr: errors.New("failed to create azsecrets client"),
expectedErr: true,
},
}

for i := range testEnvs {
kvClientSecrets, err := initializeKvClient(context.TODO(), testEnvs[i].KeyVaultEndpoint, "", "")
assert.NoError(t, err)
assert.NotNil(t, kvClientSecrets)
// assert.NotNil(t, kvBaseClient.Authorizer)
// assert.Contains(t, kvClientSecrets.endpoint, testEnvs[i].KeyVaultEndpoint)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up mocks
mockCredential.On("NewWorkloadIdentityCredential", mock.Anything).Return(mockCredential, tt.mockCredentialErr)
mockSecretsClient.On("NewClient", tt.kvEndpoint, mockCredential, mock.Anything).Return(mockSecretsClient, tt.mockSecretsErr)

// Call function under test
secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID)

// Validate expectations
if tt.expectedErr {
assert.Error(t, err)
assert.Nil(t, secretsClient)
} else {
assert.NoError(t, err)
assert.NotNil(t, secretsClient)
}
})
}
}

Expand Down
8 changes: 1 addition & 7 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"github.com/ratify-project/ratify/pkg/keymanagementprovider/config"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/factory"
"github.com/ratify-project/ratify/pkg/metrics"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/pkcs12"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
Expand Down Expand Up @@ -237,8 +236,6 @@ func parseAzureEnvironment(cloudName string) (*azure.Environment, error) {
func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientID string) (*azkeys.Client, *azsecrets.Client, error) {
// Trim any trailing slash from the endpoint
kvEndpoint := strings.TrimSuffix(keyVaultEndpoint, "/")
logger.GetLogger(ctx, logOpt).Infof("kvEndpoint: '%s'", kvEndpoint)
logrus.WithContext(ctx).Infof("kvEndpoint: '%s'", kvEndpoint)

// Create the workload identity credential for authentication
credential, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
Expand All @@ -248,23 +245,20 @@ func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientI
if err != nil {
return nil, nil, re.ErrorCodeAuthDenied.WithDetail("failed to create workload identity credential").WithRemediation(re.AKVLink).WithError(err)
}
logger.GetLogger(ctx, logOpt).Infof("credential created successfully")
logrus.WithContext(ctx).Infof("credential created successfully")

// create azkeys client
kvClientKeys, err := azkeys.NewClient(kvEndpoint, credential, nil)
if err != nil {
return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err)
}
logger.GetLogger(ctx, logOpt).Infof("azkeys kvclient created successfully")
logrus.WithContext(ctx).Infof("azkeys kvclient created successfully")

// create azsecrets client
kvClientSecrets, err := azsecrets.NewClient(kvEndpoint, credential, nil)
if err != nil {
return nil, nil, re.ErrorCodeConfigInvalid.WithDetail("Failed to create Key Vault client").WithRemediation(re.AKVLink).WithError(err)
}
logger.GetLogger(ctx, logOpt).Infof("azsecrets kvclient created successfully")
logrus.WithContext(ctx).Infof("azsecrets kvclient created successfully")

return kvClientKeys, kvClientSecrets, nil
}
Expand Down
97 changes: 86 additions & 11 deletions pkg/keymanagementprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ package azurekeyvault
import (
"context"
"crypto"
"errors"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/ratify-project/ratify/internal/version"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/azurekeyvault/types"
"github.com/ratify-project/ratify/pkg/keymanagementprovider/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// TestParseAzureEnvironment tests the parseAzureEnvironment function
Expand Down Expand Up @@ -523,14 +525,50 @@ func TestValidate(t *testing.T) {
}
}

// Mock clients
type MockAzKeysClient struct {
mock.Mock
}

type MockAzSecretsClient struct {
mock.Mock
}

type MockWorkloadIdentityCredential struct {
mock.Mock
}

// Mock functions
func (m *MockWorkloadIdentityCredential) NewWorkloadIdentityCredential(options *azidentity.WorkloadIdentityCredentialOptions) (*MockWorkloadIdentityCredential, error) {
args := m.Called(options)
return args.Get(0).(*MockWorkloadIdentityCredential), args.Error(1)
}

func (m *MockAzKeysClient) NewClient(endpoint string, credential *azidentity.WorkloadIdentityCredential, options *azkeys.ClientOptions) (*azkeys.Client, error) {
args := m.Called(endpoint, credential, options)
return args.Get(0).(*azkeys.Client), args.Error(1)
}

func (m *MockAzSecretsClient) NewClient(endpoint string, credential *azidentity.WorkloadIdentityCredential, options *azsecrets.ClientOptions) (*azsecrets.Client, error) {
args := m.Called(endpoint, credential, options)
return args.Get(0).(*azsecrets.Client), args.Error(1)
}

func TestInitializeKvClient(t *testing.T) {
mockCredential := new(MockWorkloadIdentityCredential)
mockKeysClient := new(MockAzKeysClient)
mockSecretsClient := new(MockAzSecretsClient)

tests := []struct {
name string
kvEndpoint string
userAgent string
tenantID string
clientID string
expectedErr bool
name string
kvEndpoint string
userAgent string
tenantID string
clientID string
mockCredentialErr error
mockKeysErr error
mockSecretsErr error
expectedErr bool
}{
{
name: "Empty user agent",
Expand All @@ -541,18 +579,55 @@ func TestInitializeKvClient(t *testing.T) {
{
name: "Auth failure",
kvEndpoint: "https://test.vault.azure.net",
userAgent: version.UserAgent,
tenantID: "testTenantID",
clientID: "testClientID",
expectedErr: true,
},
{
name: "credential creation error",
kvEndpoint: "https://test-keyvault.vault.azure.net",
tenantID: "test-tenant-id",
clientID: "test-client-id",
mockCredentialErr: errors.New("failed to create workload identity credential"),
expectedErr: true,
},
{
name: "azkeys client creation error",
kvEndpoint: "https://test-keyvault.vault.azure.net",
tenantID: "test-tenant-id",
clientID: "test-client-id",
mockKeysErr: errors.New("failed to create azkeys client"),
expectedErr: true,
},
{
name: "azsecrets client creation error",
kvEndpoint: "https://test-keyvault.vault.azure.net",
tenantID: "test-tenant-id",
clientID: "test-client-id",
mockSecretsErr: errors.New("failed to create azsecrets client"),
expectedErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID)
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got: %v", tt.expectedErr, err)
// Set up mocks
mockCredential.On("NewWorkloadIdentityCredential", mock.Anything).Return(mockCredential, tt.mockCredentialErr)
mockKeysClient.On("NewClient", tt.kvEndpoint, mockCredential, mock.Anything).Return(mockKeysClient, tt.mockKeysErr)
mockSecretsClient.On("NewClient", tt.kvEndpoint, mockCredential, mock.Anything).Return(mockSecretsClient, tt.mockSecretsErr)

// Call function under test
keysClient, secretsClient, err := initializeKvClient(context.Background(), tt.kvEndpoint, tt.tenantID, tt.clientID)

// Validate expectations
if tt.expectedErr {
assert.Error(t, err)
assert.Nil(t, keysClient)
assert.Nil(t, secretsClient)
} else {
assert.NoError(t, err)
assert.NotNil(t, keysClient)
assert.NotNil(t, secretsClient)
}
})
}
Expand Down

0 comments on commit 568e2f6

Please sign in to comment.