Skip to content

Commit

Permalink
chore: more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 3, 2024
1 parent 88639d1 commit 51d30ec
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 36 deletions.
60 changes: 40 additions & 20 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,28 @@ import (
)

type azureManagedIdentityProviderFactory struct{}
type azureManagedIdentityAuthProvider struct {
identityToken azcore.AccessToken
clientID string
tenantID string
type MIAuthProvider struct {
identityToken azcore.AccessToken
clientID string
tenantID string
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
getRegistryHost func(artifact string) (string, error)
getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error)
}

// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests
func NewAzureMIAuthProvider() *MIAuthProvider {
return &MIAuthProvider{
authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
return nil, err

Check warning on line 51 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L49-L51

Added lines #L49 - L51 were not covered by tests
}
return &AuthenticationClientWrapper{client: client}, nil

Check warning on line 53 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L53

Added line #L53 was not covered by tests
},
getRegistryHost: provider.GetRegistryHostName,
getManagedIdentityToken: getManagedIdentityToken,
}
}

type azureManagedIdentityAuthProviderConf struct {
Expand All @@ -53,7 +71,7 @@ func init() {
provider.Register(azureManagedIdentityAuthProviderName, &azureManagedIdentityProviderFactory{})
}

// Create returns an azureManagedIdentityAuthProvider
// Create returns an MIAuthProvider
func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider.AuthProviderConfig) (provider.AuthProvider, error) {
conf := azureManagedIdentityAuthProviderConf{}
authProviderConfigBytes, err := json.Marshal(authProviderConfig)
Expand Down Expand Up @@ -85,15 +103,15 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace)
}

return &azureManagedIdentityAuthProvider{
return &MIAuthProvider{

Check warning on line 106 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L106

Added line #L106 was not covered by tests
identityToken: token,
clientID: client,
tenantID: tenant,
}, nil
}

// Enabled checks for non empty tenant ID and AAD access token
func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool {
func (d *MIAuthProvider) Enabled(_ context.Context) bool {
if d.clientID == "" {
return false
}
Expand All @@ -112,55 +130,57 @@ func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool {
// Provide returns the credentials for a specified artifact.
// Uses Managed Identity to retrieve an AAD access token which can be
// exchanged for a valid ACR refresh token for login.
func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) {
func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) {
if !d.Enabled(ctx) {
return provider.AuthConfig{}, fmt.Errorf("azure managed identity provider is not properly enabled")
}

// parse the artifact reference string to extract the registry host name
artifactHostName, err := provider.GetRegistryHostName(artifact)
artifactHostName, err := d.getRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, err
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)

Check warning on line 141 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L141

Added line #L141 was not covered by tests
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) {
newToken, err := getManagedIdentityToken(ctx, d.clientID)
newToken, err := d.getManagedIdentityToken(ctx, d.clientID)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace)
}
d.identityToken = newToken
logger.GetLogger(ctx, logOpt).Info("successfully refreshed azure managed identity token")
}

// add protocol to generate complete URI
serverURL := "https://" + artifactHostName

// create registry client and exchange AAD token for registry refresh token
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options})
// TODO: Consider adding authentication client options for multicloud scenarios
var options *azcontainerregistry.AuthenticationClientOptions
client, err := d.authClientFactory(serverURL, options)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace)

Check warning on line 161 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L161

Added line #L161 was not covered by tests
}
// refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL)
rt, err := client.ExchangeAADAccessTokenForACRRefreshToken(
context.Background(),

response, err := client.ExchangeAADAccessTokenForACRRefreshToken(
ctx,
"access_token",
artifactHostName,
&azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &d.identityToken.Token,
Tenant: &d.tenantID,
},
)
// rt, err := refreshTokenClient.GetFromExchange(ctx, "access_token", artifactHostName, d.tenantID, "", d.identityToken.Token)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "failed to get refresh token for container registry by azure managed identity token", re.HideStackTrace)
}
rt := response.ACRRefreshToken

expiresOn := getACRExpiryIfEarlier(d.identityToken.ExpiresOn)

refreshTokenExpiry := getACRExpiryIfEarlier(d.identityToken.ExpiresOn)
authConfig := provider.AuthConfig{
Username: dockerTokenLoginUsernameGUID,
Password: *rt.RefreshToken,
Provider: d,
ExpiresOn: expiresOn,
ExpiresOn: refreshTokenExpiry,
}

return authConfig, nil
Expand Down
156 changes: 155 additions & 1 deletion pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,28 @@ import (
"errors"
"os"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
ratifyerrors "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type MockGetManagedIdentityToken struct {
mock.Mock
}

func (m *MockGetManagedIdentityToken) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
args := m.Called(ctx, clientID)
return args.Get(0).(azcore.AccessToken), args.Error(1)
}

// Verifies that Enabled checks if tenantID is empty or AAD token is empty
func TestAzureMSIEnabled_ExpectedResults(t *testing.T) {
azAuthProvider := azureManagedIdentityAuthProvider{
azAuthProvider := MIAuthProvider{
tenantID: "test_tenant",
clientID: "test_client",
identityToken: azcore.AccessToken{
Expand Down Expand Up @@ -89,3 +102,144 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) {
t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err)
}
}

func TestNewAzureMIAuthProvider_AuthenticationClientError(t *testing.T) {
// Create a new mock client factory
mockFactory := new(MockAuthClientFactory)

// Setup mock to return an error
mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything).
Return(nil, errors.New("failed to create authentication client"))

// Create a new WIAuthProvider instance
provider := NewAzureMIAuthProvider()
provider.authClientFactory = mockFactory.NewAuthenticationClient

// Call authClientFactory to test error handling
_, err := provider.authClientFactory("https://myregistry.azurecr.io", nil)

// Assert that an error is returned
assert.Error(t, err)
assert.Equal(t, "failed to create authentication client", err.Error())

// Verify that the mock was called
mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything)
}

func TestNewAzureMIAuthProvider_Success(t *testing.T) {
// Create a new mock client factory
mockFactory := new(MockAuthClientFactory)

// Create a mock auth client to return from the factory
mockAuthClient := new(MockAuthClient)

// Setup mock to return a successful auth client
mockFactory.On("NewAuthenticationClient", mock.Anything, mock.Anything).
Return(mockAuthClient, nil)

// Create a new WIAuthProvider instance
provider := NewAzureMIAuthProvider()

// Replace authClientFactory with the mock factory
provider.authClientFactory = mockFactory.NewAuthenticationClient

// Call authClientFactory to test successful return
client, err := provider.authClientFactory("https://myregistry.azurecr.io", nil)

// Assert that the client is returned without an error
assert.NoError(t, err)
assert.NotNil(t, client)

// Assert that the returned client is of the expected type
_, ok := client.(*MockAuthClient)
assert.True(t, ok, "expected client to be of type *MockAuthClient")

// Verify that the mock was called
mockFactory.AssertCalled(t, "NewAuthenticationClient", "https://myregistry.azurecr.io", mock.Anything)
}

func TestMIProvide_Success(t *testing.T) {
const registryHost = "myregistry.azurecr.io"
mockClient := new(MockAuthClient)
expectedRefreshToken := "mocked_refresh_token"
mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken},
}, nil)

provider := &MIAuthProvider{
identityToken: azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return mockClient, nil
},
getRegistryHost: func(_ string) (string, error) {
return registryHost, nil
},
getManagedIdentityToken: func(_ context.Context, _ string) (azcore.AccessToken, error) {
return azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
}, nil
},
}

authConfig, err := provider.Provide(context.Background(), "artifact")

assert.NoError(t, err)
// Assert that getManagedIdentityToken was not called
mockClient.AssertNotCalled(t, "getManagedIdentityToken", mock.Anything, mock.Anything)
// Assert that the returned refresh token matches the expected one
assert.Equal(t, expectedRefreshToken, authConfig.Password)
}

func TestMIProvide_RefreshAAD(t *testing.T) {
const registryHost = "myregistry.azurecr.io"
// Arrange
mockClient := new(MockAuthClient)

// Create a mock function for getManagedIdentityToken
mockGetManagedIdentityToken := new(MockGetManagedIdentityToken)

provider := &MIAuthProvider{
identityToken: azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now(), // Expired token to force a refresh
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return mockClient, nil
},
getRegistryHost: func(_ string) (string, error) {
return registryHost, nil
},
getManagedIdentityToken: mockGetManagedIdentityToken.GetManagedIdentityToken, // Use the mock
}

mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)},
}, nil)

// Set up the expectation for the mocked method
mockGetManagedIdentityToken.On("GetManagedIdentityToken", mock.Anything, "mockClientID").
Return(azcore.AccessToken{
Token: "newMockToken",
ExpiresOn: time.Now().Add(time.Hour),
}, nil)

ctx := context.TODO()
artifact := "testArtifact"

// Act
_, err := provider.Provide(ctx, artifact)

// Assert
assert.NoError(t, err)
mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called
}
14 changes: 7 additions & 7 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,33 @@ type WIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
clientID string
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error)
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
getRegistryHost func(artifact string) (string, error)
getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
reportMetrics func(ctx context.Context, duration int64, artifactHostName string)
}

type authenticationClientWrapper struct {
type AuthenticationClientWrapper struct {
client *azcontainerregistry.AuthenticationClient
}

func (w *authenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options)

Check warning on line 50 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L49-L50

Added lines #L49 - L50 were not covered by tests
}

type authClient interface {
type AuthClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

// NewAzureWIAuthProvider is defined to enable mocking of some of the function in unit tests
func NewAzureWIAuthProvider() *WIAuthProvider {
return &WIAuthProvider{
authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) {
authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
return nil, err

Check warning on line 63 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L61-L63

Added lines #L61 - L63 were not covered by tests
}
return &authenticationClientWrapper{client: client}, nil
return &AuthenticationClientWrapper{client: client}, nil

Check warning on line 65 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L65

Added line #L65 was not covered by tests
},
getRegistryHost: provider.GetRegistryHostName,
getAADAccessToken: azureauth.GetAADAccessToken,
Expand Down Expand Up @@ -161,7 +162,6 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
// add protocol to generate complete URI
serverURL := "https://" + artifactHostName

// create registry client and exchange AAD token for registry refresh token
// TODO: Consider adding authentication client options for multicloud scenarios
var options *azcontainerregistry.AuthenticationClientOptions
client, err := d.authClientFactory(serverURL, options)
Expand Down
Loading

0 comments on commit 51d30ec

Please sign in to comment.