Skip to content

Commit

Permalink
chore: refactor azureidentity.go
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 14, 2024
1 parent dcc499c commit b2494ec
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 93 deletions.
29 changes: 21 additions & 8 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,26 @@ import (
)

type azureManagedIdentityProviderFactory struct{}

// ManagedIdentityTokenGetter defines an interface for getting a managed identity token.
type ManagedIdentityTokenGetter interface {
GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error)
}

// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter.
type DefaultManagedIdentityTokenGetterImpl struct{}

func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
return getManagedIdentityToken(ctx, clientID)
}

type MIAuthProvider struct {
identityToken azcore.AccessToken
clientID string
tenantID string
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
getRegistryHost func(artifact string) (string, error)
getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error)
authClientFactory AuthClientFactory
getRegistryHost RegistryHostGetter
getManagedIdentityToken ManagedIdentityTokenGetter
}

type azureManagedIdentityAuthProviderConf struct {
Expand Down Expand Up @@ -92,8 +105,8 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
identityToken: token,
clientID: client,
tenantID: tenant,
authClientFactory: DefaultAuthClientFactory,
getManagedIdentityToken: getManagedIdentityToken,
authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation
getManagedIdentityToken: &DefaultManagedIdentityTokenGetterImpl{}, // Concrete implementation
}, nil
}

Expand Down Expand Up @@ -123,14 +136,14 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider
}

// parse the artifact reference string to extract the registry host name
artifactHostName, err := d.getRegistryHost(artifact)
artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) {
newToken, err := d.getManagedIdentityToken(ctx, d.clientID)
newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace)
}
Expand All @@ -143,7 +156,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider

// TODO: Consider adding authentication client options for multicloud scenarios
var options *azcontainerregistry.AuthenticationClientOptions
client, err := d.authClientFactory(serverURL, options)
client, err := d.authClientFactory.CreateAuthClient(serverURL, options)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace)
}
Expand Down
207 changes: 122 additions & 85 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
ratifyerrors "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

type MockGetManagedIdentityToken struct {
// Mock types for external dependencies
type MockManagedIdentityTokenGetter struct {
mock.Mock
}

func (m *MockGetManagedIdentityToken) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
// Mock ManagedIdentityTokenGetter.GetManagedIdentityToken
func (m *MockManagedIdentityTokenGetter) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
args := m.Called(ctx, clientID)
return args.Get(0).(azcore.AccessToken), args.Error(1)
}
Expand Down Expand Up @@ -103,104 +105,139 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) {
}
}

func TestMIProvide_Success(t *testing.T) {
const registryHost = "myregistry.azurecr.io"
mockClient := new(MockAuthClient)
expectedRefreshToken := "mocked_refresh_token"
mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken},
}, nil)
// Test for invalid configuration when tenant ID is missing
func TestAzureManagedIdentityProviderFactory_Create_NoTenantID(t *testing.T) {
t.Setenv("AZURE_TENANT_ID", "")

provider := &MIAuthProvider{
identityToken: azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return mockClient, nil
},
getRegistryHost: func(_ string) (string, error) {
return registryHost, nil
},
getManagedIdentityToken: func(_ context.Context, _ string) (azcore.AccessToken, error) {
return azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
}, nil
},
}
// Initialize factory
factory := &azureManagedIdentityProviderFactory{}

authConfig, err := provider.Provide(context.Background(), "artifact")
// Attempt to create MIAuthProvider with empty configuration
_, err := factory.Create(map[string]interface{}{})

assert.NoError(t, err)
// Assert that getManagedIdentityToken was not called
mockClient.AssertNotCalled(t, "getManagedIdentityToken", mock.Anything, mock.Anything)
// Assert that the returned refresh token matches the expected one
assert.Equal(t, expectedRefreshToken, authConfig.Password)
// Validate the error
assert.Error(t, err)
assert.Contains(t, err.Error(), "AZURE_TENANT_ID environment variable is empty")
}

func TestMIProvide_RefreshAAD(t *testing.T) {
const registryHost = "myregistry.azurecr.io"
// Arrange
mockClient := new(MockAuthClient)
// Test for missing client ID
func TestAzureManagedIdentityProviderFactory_Create_NoClientID(t *testing.T) {
t.Setenv("AZURE_TENANT_ID", "tenantID")
t.Setenv("AZURE_CLIENT_ID", "")

// Create a mock function for getManagedIdentityToken
mockGetManagedIdentityToken := new(MockGetManagedIdentityToken)
// Initialize factory
factory := &azureManagedIdentityProviderFactory{}

provider := &MIAuthProvider{
identityToken: azcore.AccessToken{
Token: "mockToken",
ExpiresOn: time.Now(), // Expired token to force a refresh
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return mockClient, nil
},
getRegistryHost: func(_ string) (string, error) {
return registryHost, nil
},
getManagedIdentityToken: mockGetManagedIdentityToken.GetManagedIdentityToken, // Use the mock
}
// Attempt to create MIAuthProvider with empty client ID
_, err := factory.Create(map[string]interface{}{})

mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)},
}, nil)
// Validate the error
assert.Error(t, err)
assert.Contains(t, err.Error(), "AZURE_CLIENT_ID environment variable is empty")
}

// Set up the expectation for the mocked method
mockGetManagedIdentityToken.On("GetManagedIdentityToken", mock.Anything, "mockClientID").
Return(azcore.AccessToken{
Token: "newMockToken",
ExpiresOn: time.Now().Add(time.Hour),
}, nil)
// Test successful token refresh
func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
// Mock dependencies
mockAuthClientFactory := new(MockAuthClientFactory)
mockRegistryHostGetter := new(MockRegistryHostGetter)
mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter)
mockAuthClient := new(MockAuthClient)

// Define token values
expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)}
newTokenString := "refreshed_token"
newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)}
refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString},
}

ctx := context.TODO()
artifact := "testArtifact"
// Setup mock expectations
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil)

// Initialize provider with expired token
provider := MIAuthProvider{
identityToken: expiredToken,
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

// Act
_, err := provider.Provide(ctx, artifact)
// Call Provide method
ctx := context.Background()
authConfig, err := provider.Provide(ctx, "artifact_name")

// Assert
// Validate success and token refresh
assert.NoError(t, err)
mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called
assert.Equal(t, "refreshed_token", authConfig.Password)
}

func TestMIProvide_Failure_InvalidHostName(t *testing.T) {
provider := &MIAuthProvider{
tenantID: "test_tenant",
clientID: "test_client",
identityToken: azcore.AccessToken{
Token: "test_token",
},
getRegistryHost: func(_ string) (string, error) {
return "", errors.New("invalid hostname")
},
// Test failed token refresh
func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
// Mock dependencies
mockAuthClientFactory := new(MockAuthClientFactory)
mockRegistryHostGetter := new(MockRegistryHostGetter)
mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter)

// Define token values
expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)}

// Setup mock expectations
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(azcore.AccessToken{}, errors.New("token refresh failed"))

// Initialize provider with expired token
provider := MIAuthProvider{
identityToken: expiredToken,
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

// Call Provide method
ctx := context.Background()
_, err := provider.Provide(ctx, "artifact_name")

// Validate failure
assert.Error(t, err)
assert.Contains(t, err.Error(), "could not refresh azure managed identity token")
}

// Test for invalid hostname retrieval
func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) {
// Mock dependencies
mockAuthClientFactory := new(MockAuthClientFactory)
mockRegistryHostGetter := new(MockRegistryHostGetter)
mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter)

// Define valid token
validToken := azcore.AccessToken{Token: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)}

// Setup mock expectations for invalid hostname
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname"))

// Initialize provider with valid token
provider := MIAuthProvider{
identityToken: validToken,
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

_, err := provider.Provide(context.Background(), "artifact")
// Call Provide method
ctx := context.Background()
_, err := provider.Provide(ctx, "artifact_name")

// Validate failure
assert.Error(t, err)
assert.Contains(t, err.Error(), "HOST_NAME_INVALID")
}

0 comments on commit b2494ec

Please sign in to comment.