diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index f9f4adb67d..8f89a0f9d3 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -33,13 +33,26 @@ import ( ) type azureManagedIdentityProviderFactory struct{} + +// ManagedIdentityTokenGetter defines an interface for getting a managed identity token. +type ManagedIdentityTokenGetter interface { + GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) +} + +// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter. +type DefaultManagedIdentityTokenGetterImpl struct{} + +func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + return getManagedIdentityToken(ctx, clientID) +} + type MIAuthProvider struct { identityToken azcore.AccessToken clientID string tenantID string - authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) - getRegistryHost func(artifact string) (string, error) - getManagedIdentityToken func(ctx context.Context, clientID string) (azcore.AccessToken, error) + authClientFactory AuthClientFactory + getRegistryHost RegistryHostGetter + getManagedIdentityToken ManagedIdentityTokenGetter } type azureManagedIdentityAuthProviderConf struct { @@ -92,8 +105,8 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider identityToken: token, clientID: client, tenantID: tenant, - authClientFactory: DefaultAuthClientFactory, - getManagedIdentityToken: getManagedIdentityToken, + authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation + getManagedIdentityToken: &DefaultManagedIdentityTokenGetterImpl{}, // Concrete implementation }, nil } @@ -123,14 +136,14 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider } // parse the artifact reference string to extract the registry host name - artifactHostName, err := d.getRegistryHost(artifact) + artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) { - newToken, err := d.getManagedIdentityToken(ctx, d.clientID) + newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace) } @@ -143,7 +156,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider // TODO: Consider adding authentication client options for multicloud scenarios var options *azcontainerregistry.AuthenticationClientOptions - client, err := d.authClientFactory(serverURL, options) + client, err := d.authClientFactory.CreateAuthClient(serverURL, options) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry by azure managed identity token", re.HideStackTrace) } diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 11fb48f5f7..8cb08b419f 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -23,18 +23,20 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) -type MockGetManagedIdentityToken struct { +// Mock types for external dependencies +type MockManagedIdentityTokenGetter struct { mock.Mock } -func (m *MockGetManagedIdentityToken) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { +// Mock ManagedIdentityTokenGetter.GetManagedIdentityToken +func (m *MockManagedIdentityTokenGetter) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { args := m.Called(ctx, clientID) return args.Get(0).(azcore.AccessToken), args.Error(1) } @@ -103,104 +105,139 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { } } -func TestMIProvide_Success(t *testing.T) { - const registryHost = "myregistry.azurecr.io" - mockClient := new(MockAuthClient) - expectedRefreshToken := "mocked_refresh_token" - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken}, - }, nil) +// Test for invalid configuration when tenant ID is missing +func TestAzureManagedIdentityProviderFactory_Create_NoTenantID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "") - provider := &MIAuthProvider{ - identityToken: azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return registryHost, nil - }, - getManagedIdentityToken: func(_ context.Context, _ string) (azcore.AccessToken, error) { - return azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, nil - }, - } + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} - authConfig, err := provider.Provide(context.Background(), "artifact") + // Attempt to create MIAuthProvider with empty configuration + _, err := factory.Create(map[string]interface{}{}) - assert.NoError(t, err) - // Assert that getManagedIdentityToken was not called - mockClient.AssertNotCalled(t, "getManagedIdentityToken", mock.Anything, mock.Anything) - // Assert that the returned refresh token matches the expected one - assert.Equal(t, expectedRefreshToken, authConfig.Password) + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_TENANT_ID environment variable is empty") } -func TestMIProvide_RefreshAAD(t *testing.T) { - const registryHost = "myregistry.azurecr.io" - // Arrange - mockClient := new(MockAuthClient) +// Test for missing client ID +func TestAzureManagedIdentityProviderFactory_Create_NoClientID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tenantID") + t.Setenv("AZURE_CLIENT_ID", "") - // Create a mock function for getManagedIdentityToken - mockGetManagedIdentityToken := new(MockGetManagedIdentityToken) + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} - provider := &MIAuthProvider{ - identityToken: azcore.AccessToken{ - Token: "mockToken", - ExpiresOn: time.Now(), // Expired token to force a refresh - }, - tenantID: "mockTenantID", - clientID: "mockClientID", - authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { - return mockClient, nil - }, - getRegistryHost: func(_ string) (string, error) { - return registryHost, nil - }, - getManagedIdentityToken: mockGetManagedIdentityToken.GetManagedIdentityToken, // Use the mock - } + // Attempt to create MIAuthProvider with empty client ID + _, err := factory.Create(map[string]interface{}{}) - mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", registryHost, mock.Anything). - Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ - ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)}, - }, nil) + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_CLIENT_ID environment variable is empty") +} - // Set up the expectation for the mocked method - mockGetManagedIdentityToken.On("GetManagedIdentityToken", mock.Anything, "mockClientID"). - Return(azcore.AccessToken{ - Token: "newMockToken", - ExpiresOn: time.Now().Add(time.Hour), - }, nil) +// Test successful token refresh +func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + mockAuthClient := new(MockAuthClient) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + newTokenString := "refreshed_token" + newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString}, + } - ctx := context.TODO() - artifact := "testArtifact" + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } - // Act - _, err := provider.Provide(ctx, artifact) + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") - // Assert + // Validate success and token refresh assert.NoError(t, err) - mockGetManagedIdentityToken.AssertCalled(t, "GetManagedIdentityToken", mock.Anything, "mockClientID") // Assert that getManagedIdentityToken was called + assert.Equal(t, "refreshed_token", authConfig.Password) } -func TestMIProvide_Failure_InvalidHostName(t *testing.T) { - provider := &MIAuthProvider{ - tenantID: "test_tenant", - clientID: "test_client", - identityToken: azcore.AccessToken{ - Token: "test_token", - }, - getRegistryHost: func(_ string) (string, error) { - return "", errors.New("invalid hostname") - }, +// Test failed token refresh +func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(azcore.AccessToken{}, errors.New("token refresh failed")) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh azure managed identity token") +} + +// Test for invalid hostname retrieval +func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define valid token + validToken := azcore.AccessToken{Token: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + + // Setup mock expectations for invalid hostname + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname")) + + // Initialize provider with valid token + provider := MIAuthProvider{ + identityToken: validToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + getRegistryHost: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, } - _, err := provider.Provide(context.Background(), "artifact") + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure assert.Error(t, err) + assert.Contains(t, err.Error(), "HOST_NAME_INVALID") }