diff --git a/go.mod b/go.mod index 27d1f11c1..be3607472 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,9 @@ retract ( require ( github.com/Azure/azure-sdk-for-go v68.0.0+incompatible - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 + github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2 github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 github.com/aws/aws-sdk-go-v2 v1.32.2 github.com/aws/aws-sdk-go-v2/config v1.27.43 @@ -118,6 +119,7 @@ require ( github.com/sigstore/timestamp-authority v1.2.2 // indirect github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect github.com/sourcegraph/conc v0.3.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tchap/go-patricia/v2 v2.3.1 // indirect github.com/thales-e-security/pool v0.0.2 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect @@ -130,7 +132,7 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/Azure/go-autorest/autorest v0.11.29 github.com/Azure/go-autorest/autorest/adal v0.9.24 // indirect @@ -237,7 +239,7 @@ require ( golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 // indirect golang.org/x/mod v0.20.0 // indirect - golang.org/x/net v0.28.0 // indirect + golang.org/x/net v0.29.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect diff --git a/go.sum b/go.sum index d94ab83cf..ecb26633c 100644 --- a/go.sum +++ b/go.sum @@ -18,12 +18,14 @@ github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/alibabacloudsdkgo github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/alibabacloudsdkgo/helper v0.2.0/go.mod h1:GgeIE+1be8Ivm7Sh4RgwI42aTtC9qrcj+Y9Y6CjJhJs= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 h1:U2rTu3Ef+7w9FHKIAXM6ZyqF3UOWJZ12zIm8zECAFfg= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 h1:jBQA3cKT4L2rWMpgE7Yt3Hwh2aUj8KXjIGLxjHeYNNo= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2 h1:wBx10efdJcl8FSewgc41kAW4AvHPgmJZmN7fpNxn8rc= +github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry v0.2.2/go.mod h1:zzmu18cpAinSbhC86oWd47nmgbb91Fl+Yac2PE8NdYk= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 h1:DRiANoJTiW6obBQe3SqZizkuV1PEgfiiGivmVocDy64= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0/go.mod h1:qLIye2hwb/ZouqhpSD9Zn3SJipvpEnz1Ywl3VUk9Y0s= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80= @@ -657,6 +659,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -823,8 +826,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= diff --git a/pkg/common/oras/authprovider/azure/azureidentity.go b/pkg/common/oras/authprovider/azure/azureidentity.go index 0a5a00e5c..d0369c4dc 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity.go +++ b/pkg/common/oras/authprovider/azure/azureidentity.go @@ -29,14 +29,44 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/services/preview/containerregistry/runtime/2019-08-15-preview/containerregistry" + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ) +// ManagedIdentityTokenGetter defines an interface for getting a managed identity token. +type ManagedIdentityTokenGetter interface { + GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) +} + +// defaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken. +type defaultManagedIdentityTokenGetterImpl struct{} + +func (g *defaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + return getManagedIdentityToken(ctx, clientID, azidentity.NewManagedIdentityCredential) +} + +func getManagedIdentityToken(ctx context.Context, clientID string, newCredentialFunc func(opts *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error)) (azcore.AccessToken, error) { + id := azidentity.ClientID(clientID) + opts := azidentity.ManagedIdentityCredentialOptions{ID: id} + cred, err := newCredentialFunc(&opts) + if err != nil { + return azcore.AccessToken{}, err + } + scopes := []string{AADResource} + if cred != nil { + return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) + } + return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken") +} + type azureManagedIdentityProviderFactory struct{} -type azureManagedIdentityAuthProvider struct { - identityToken azcore.AccessToken - clientID string - tenantID string + +type MIAuthProvider struct { + identityToken azcore.AccessToken + clientID string + tenantID string + authClientFactory AuthClientFactory + registryHostGetter RegistryHostGetter + getManagedIdentityToken ManagedIdentityTokenGetter } type azureManagedIdentityAuthProviderConf struct { @@ -53,7 +83,7 @@ func init() { provider.Register(azureManagedIdentityAuthProviderName, &azureManagedIdentityProviderFactory{}) } -// Create returns an azureManagedIdentityAuthProvider +// Create returns an MIAuthProvider func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider.AuthProviderConfig) (provider.AuthProvider, error) { conf := azureManagedIdentityAuthProviderConf{} authProviderConfigBytes, err := json.Marshal(authProviderConfig) @@ -80,20 +110,22 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider return nil, err } // retrieve an AAD Access token - token, err := getManagedIdentityToken(context.Background(), client) + token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace) } - return &azureManagedIdentityAuthProvider{ - identityToken: token, - clientID: client, - tenantID: tenant, + return &MIAuthProvider{ + identityToken: token, + clientID: client, + tenantID: tenant, + authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation + getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation }, nil } // Enabled checks for non empty tenant ID and AAD access token -func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool { +func (d *MIAuthProvider) Enabled(_ context.Context) bool { if d.clientID == "" { return false } @@ -112,57 +144,58 @@ func (d *azureManagedIdentityAuthProvider) Enabled(_ context.Context) bool { // Provide returns the credentials for a specified artifact. // Uses Managed Identity to retrieve an AAD access token which can be // exchanged for a valid ACR refresh token for login. -func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { +func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, fmt.Errorf("azure managed identity provider is not properly enabled") } + // parse the artifact reference string to extract the registry host name - artifactHostName, err := provider.GetRegistryHostName(artifact) + artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact) if err != nil { - return provider.AuthConfig{}, err + return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.identityToken.ExpiresOn) { - newToken, err := getManagedIdentityToken(ctx, d.clientID) + newToken, err := d.getManagedIdentityToken.GetManagedIdentityToken(ctx, d.clientID) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace) } d.identityToken = newToken logger.GetLogger(ctx, logOpt).Info("successfully refreshed azure managed identity token") } + // add protocol to generate complete URI serverURL := "https://" + artifactHostName - // create registry client and exchange AAD token for registry refresh token - refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) - rt, err := refreshTokenClient.GetFromExchange(ctx, "access_token", artifactHostName, d.tenantID, "", d.identityToken.Token) + // TODO: Consider adding authentication client options for multicloud scenarios + var options *azcontainerregistry.AuthenticationClientOptions + client, err := d.authClientFactory.CreateAuthClient(serverURL, options) + if err != nil { + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") + } + + response, err := client.ExchangeAADAccessTokenForACRRefreshToken( + ctx, + azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), + artifactHostName, + &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ + AccessToken: &d.identityToken.Token, + Tenant: &d.tenantID, + }, + ) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "failed to get refresh token for container registry by azure managed identity token", re.HideStackTrace) } + rt := response.ACRRefreshToken - expiresOn := getACRExpiryIfEarlier(d.identityToken.ExpiresOn) - + refreshTokenExpiry := getACRExpiryIfEarlier(d.identityToken.ExpiresOn) authConfig := provider.AuthConfig{ Username: dockerTokenLoginUsernameGUID, Password: *rt.RefreshToken, Provider: d, - ExpiresOn: expiresOn, + ExpiresOn: refreshTokenExpiry, } return authConfig, nil } - -func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { - id := azidentity.ClientID(clientID) - opts := azidentity.ManagedIdentityCredentialOptions{ID: id} - cred, err := azidentity.NewManagedIdentityCredential(&opts) - if err != nil { - return azcore.AccessToken{}, err - } - scopes := []string{AADResource} - if cred != nil { - return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) - } - return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken") -} diff --git a/pkg/common/oras/authprovider/azure/azureidentity_test.go b/pkg/common/oras/authprovider/azure/azureidentity_test.go index 472e704b9..8d466d3d1 100644 --- a/pkg/common/oras/authprovider/azure/azureidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureidentity_test.go @@ -20,15 +20,31 @@ import ( "errors" "os" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) +// Mock types for external dependencies +type MockManagedIdentityTokenGetter struct { + mock.Mock +} + +// Mock ManagedIdentityTokenGetter.GetManagedIdentityToken +func (m *MockManagedIdentityTokenGetter) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) { + args := m.Called(ctx, clientID) + return args.Get(0).(azcore.AccessToken), args.Error(1) +} + // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureMSIEnabled_ExpectedResults(t *testing.T) { - azAuthProvider := azureManagedIdentityAuthProvider{ + azAuthProvider := MIAuthProvider{ tenantID: "test_tenant", clientID: "test_client", identityToken: azcore.AccessToken{ @@ -89,3 +105,168 @@ func TestAzureMSIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) { t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err) } } + +// Test for invalid configuration when tenant ID is missing +func TestAzureManagedIdentityProviderFactory_Create_NoTenantID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "") + + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} + + // Attempt to create MIAuthProvider with empty configuration + _, err := factory.Create(map[string]interface{}{}) + + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_TENANT_ID environment variable is empty") +} + +// Test for missing client ID +func TestAzureManagedIdentityProviderFactory_Create_NoClientID(t *testing.T) { + t.Setenv("AZURE_TENANT_ID", "tenantID") + t.Setenv("AZURE_CLIENT_ID", "") + + // Initialize factory + factory := &azureManagedIdentityProviderFactory{} + + // Attempt to create MIAuthProvider with empty client ID + _, err := factory.Create(map[string]interface{}{}) + + // Validate the error + assert.Error(t, err) + assert.Contains(t, err.Error(), "AZURE_CLIENT_ID environment variable is empty") +} + +// Test successful token refresh +func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + mockAuthClient := new(MockAuthClient) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + newTokenString := "refreshed" + newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString}, + } + + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Validate success and token refresh + assert.NoError(t, err) + assert.Equal(t, "refreshed", authConfig.Password) +} + +// Test failed token refresh +func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define token values + expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Setup mock expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(azcore.AccessToken{}, errors.New("token refresh failed")) + + // Initialize provider with expired token + provider := MIAuthProvider{ + identityToken: expiredToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh azure managed identity token") +} + +// Test for invalid hostname retrieval +func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockManagedIdentityTokenGetter := new(MockManagedIdentityTokenGetter) + + // Define valid token + validToken := azcore.AccessToken{Token: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + + // Setup mock expectations for invalid hostname + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname")) + + // Initialize provider with valid token + provider := MIAuthProvider{ + identityToken: validToken, + clientID: "clientID", + tenantID: "tenantID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getManagedIdentityToken: mockManagedIdentityTokenGetter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Validate failure + assert.Error(t, err) + assert.Contains(t, err.Error(), "HOST_NAME_INVALID") +} + +// Unit tests +func TestGetManagedIdentityToken(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + expectedToken := azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)} + + mockGetter := new(MockManagedIdentityTokenGetter) + mockGetter.On("GetManagedIdentityToken", ctx, clientID).Return(expectedToken, nil) + + token, err := mockGetter.GetManagedIdentityToken(ctx, clientID) + assert.Nil(t, err) + assert.Equal(t, expectedToken, token) +} + +func TestGetManagedIdentityToken_Error(t *testing.T) { + ctx := context.Background() + clientID := "test-client-id" + + // Mock the newCredentialFunc to return an error + mockNewCredentialFunc := func(_ *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error) { + return nil, assert.AnError + } + + token, err := getManagedIdentityToken(ctx, clientID, mockNewCredentialFunc) + assert.NotNil(t, err) + assert.Equal(t, azcore.AccessToken{}, token) +} diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go index a40ce4436..31f45127d 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity.go @@ -21,21 +21,57 @@ import ( "os" "time" + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" re "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/internal/logger" provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" - "github.com/ratify-project/ratify/pkg/metrics" "github.com/ratify-project/ratify/pkg/utils/azureauth" - "github.com/Azure/azure-sdk-for-go/services/preview/containerregistry/runtime/2019-08-15-preview/containerregistry" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) +// AADAccessTokenGetter defines an interface for getting an AAD access token. +type AADAccessTokenGetter interface { + GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) +} + +// defaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter. +type defaultAADAccessTokenGetterImpl struct{} + +func (g *defaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return defaultGetAADAccessToken(ctx, tenantID, clientID, resource) +} + +func defaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource) +} + +// MetricsReporter defines an interface for reporting metrics. +type MetricsReporter interface { + ReportMetrics(ctx context.Context, duration int64, artifactHostName string) +} + +// defaultMetricsReporterImpl is the default implementation of MetricsReporter. +type defaultMetricsReporterImpl struct{} + +func (r *defaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + defaultReportMetrics(ctx, duration, artifactHostName) +} + +func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName) +} + type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name -type azureWIAuthProvider struct { - aadToken confidential.AuthResult - tenantID string - clientID string + +type WIAuthProvider struct { + aadToken confidential.AuthResult + tenantID string + clientID string + authClientFactory AuthClientFactory + registryHostGetter RegistryHostGetter + getAADAccessToken AADAccessTokenGetter + reportMetrics MetricsReporter } type azureWIAuthProviderConf struct { @@ -78,20 +114,24 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider } // retrieve an AAD Access token - token, err := azureauth.GetAADAccessToken(context.Background(), tenant, clientID, AADResource) + token, err := defaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource) if err != nil { return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "", re.HideStackTrace) } - return &azureWIAuthProvider{ - aadToken: token, - tenantID: tenant, - clientID: clientID, + return &WIAuthProvider{ + aadToken: token, + tenantID: tenant, + clientID: clientID, + authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation + registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation + getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation + reportMetrics: &defaultMetricsReporterImpl{}, }, nil } // Enabled checks for non empty tenant ID and AAD access token -func (d *azureWIAuthProvider) Enabled(_ context.Context) bool { +func (d *WIAuthProvider) Enabled(_ context.Context) bool { if d.tenantID == "" || d.clientID == "" { return false } @@ -106,19 +146,20 @@ func (d *azureWIAuthProvider) Enabled(_ context.Context) bool { // Provide returns the credentials for a specified artifact. // Uses Azure Workload Identity to retrieve an AAD access token which can be // exchanged for a valid ACR refresh token for login. -func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { +func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { if !d.Enabled(ctx) { return provider.AuthConfig{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("azure workload identity auth provider is not properly enabled") } + // parse the artifact reference string to extract the registry host name - artifactHostName, err := provider.GetRegistryHostName(artifact) + artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact) if err != nil { return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider) } // need to refresh AAD token if it's expired if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) { - newToken, err := azureauth.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) + newToken, err := d.getAADAccessToken.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace) } @@ -129,14 +170,29 @@ func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (pro // add protocol to generate complete URI serverURL := "https://" + artifactHostName - // create registry client and exchange AAD token for registry refresh token - refreshTokenClient := containerregistry.NewRefreshTokensClient(serverURL) + // TODO: Consider adding authentication client options for multicloud scenarios + var options *azcontainerregistry.AuthenticationClientOptions + client, err := d.authClientFactory.CreateAuthClient(serverURL, options) + if err != nil { + return provider.AuthConfig{}, re.ErrorCodeAuthDenied.WithError(err).WithDetail("failed to create authentication client for container registry by azure managed identity token") + } + startTime := time.Now() - rt, err := refreshTokenClient.GetFromExchange(context.Background(), "access_token", artifactHostName, d.tenantID, "", d.aadToken.AccessToken) + response, err := client.ExchangeAADAccessTokenForACRRefreshToken( + ctx, + azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), + artifactHostName, + &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{ + AccessToken: &d.aadToken.AccessToken, + Tenant: &d.tenantID, + }, + ) if err != nil { return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to get refresh token for container registry", re.HideStackTrace) } - metrics.ReportACRExchangeDuration(ctx, time.Since(startTime).Milliseconds(), artifactHostName) + rt := response.ACRRefreshToken + + d.reportMetrics.ReportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName) refreshTokenExpiry := getACRExpiryIfEarlier(d.aadToken.ExpiresOn) authConfig := provider.AuthConfig{ diff --git a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go index 3695ef65a..b2ffaa0cd 100644 --- a/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go +++ b/pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go @@ -22,14 +22,319 @@ import ( "testing" "time" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ratifyerrors "github.com/ratify-project/ratify/errors" "github.com/ratify-project/ratify/pkg/common/oras/authprovider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) +// MockAADAccessTokenGetter for retrieving AAD access token +type MockAADAccessTokenGetter struct { + mock.Mock +} + +func (m *MockAADAccessTokenGetter) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) { + args := m.Called(ctx, tenantID, clientID, resource) + return args.Get(0).(confidential.AuthResult), args.Error(1) +} + +// MockMetricsReporter for reporting metrics +type MockMetricsReporter struct { + mock.Mock +} + +func (m *MockMetricsReporter) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) { + m.Called(ctx, duration, artifactHostName) +} + +// Test for successful Provide function +func TestWIAuthProvider_Provide_Success(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock AAD token + initialToken := confidential.AuthResult{AccessToken: "initial_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshTokenString := "new_refresh_token" + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider + provider := WIAuthProvider{ + aadToken: initialToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "new_refresh_token", authConfig.Password) +} + +// Test for AAD token refresh logic +func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock expired AAD token, and refreshed token + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + newToken := confidential.AuthResult{AccessToken: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshTokenString := "refreshed_token" + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "refreshed_token", authConfig.Password) +} + +// Test for failure when GetAADAccessToken fails +func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) { + // Mock all dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock expired AAD token, and failure to refresh + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Set expectations for mocked functions + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(confidential.AuthResult{}, errors.New("token refresh failed")) + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh AAD token") +} + +// Test when tenant ID is missing from the environment +func TestAzureWIProviderFactory_Create_NoTenantID(t *testing.T) { + // Clear the tenant ID environment variable + t.Setenv("AZURE_TENANT_ID", "") + + // Initialize provider factory + factory := &AzureWIProviderFactory{} + + // Call Create with minimal configuration + _, err := factory.Create(map[string]interface{}{}) + + // Expect error related to missing tenant ID + assert.Error(t, err) + assert.Contains(t, err.Error(), "azure tenant id environment variable is empty") +} + +// Test when client ID is missing from the environment +func TestAzureWIProviderFactory_Create_NoClientID(t *testing.T) { + // Set tenant ID but leave client ID empty + t.Setenv("AZURE_TENANT_ID", "tenantID") + t.Setenv("AZURE_CLIENT_ID", "") + + // Initialize provider factory + factory := &AzureWIProviderFactory{} + + // Call Create with minimal configuration + _, err := factory.Create(map[string]interface{}{}) + + // Expect error related to missing client ID + assert.Error(t, err) + assert.Contains(t, err.Error(), "no client ID provided and AZURE_CLIENT_ID environment variable is empty") +} + +// Test for successful token refresh +func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + mockAuthClient := new(MockAuthClient) + + // Mock expired AAD token and refreshed token + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + refreshTokenString := "refreshed_token" + newToken := confidential.AuthResult{AccessToken: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{ + ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &refreshTokenString}, + } + + // Set expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil) + mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType(GrantTypeAccessToken), "example.azurecr.io", mock.Anything).Return(refreshToken, nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil) + mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return() + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + authConfig, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "refreshed_token", authConfig.Password) +} + +// Test when token refresh fails +func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock expired AAD token and failure to refresh + expiredToken := confidential.AuthResult{AccessToken: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)} + + // Set expectations + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil) + mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(confidential.AuthResult{}, errors.New("token refresh failed")) + + // Create WIAuthProvider with expired token + provider := WIAuthProvider{ + aadToken: expiredToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "could not refresh AAD token") +} + +// Test for handling empty AccessToken +func TestWIAuthProvider_Enabled_NoAccessToken(t *testing.T) { + // Create a provider with no AccessToken + provider := WIAuthProvider{ + tenantID: "tenantID", + clientID: "clientID", + aadToken: confidential.AuthResult{AccessToken: ""}, + } + + // Assert that provider is not enabled + enabled := provider.Enabled(context.Background()) + assert.False(t, enabled) +} + +// Test for invalid hostname retrieval +func TestWIAuthProvider_Provide_InvalidHostName(t *testing.T) { + // Mock dependencies + mockAuthClientFactory := new(MockAuthClientFactory) + mockRegistryHostGetter := new(MockRegistryHostGetter) + mockAADAccessTokenGetter := new(MockAADAccessTokenGetter) + mockMetricsReporter := new(MockMetricsReporter) + + // Mock valid AAD token + validToken := confidential.AuthResult{AccessToken: "valid_token", ExpiresOn: time.Now().Add(10 * time.Minute)} + + // Set expectations for an invalid hostname + mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("", errors.New("invalid hostname")) + + // Create WIAuthProvider with valid token + provider := WIAuthProvider{ + aadToken: validToken, + tenantID: "tenantID", + clientID: "clientID", + authClientFactory: mockAuthClientFactory, + registryHostGetter: mockRegistryHostGetter, + getAADAccessToken: mockAADAccessTokenGetter, + reportMetrics: mockMetricsReporter, + } + + // Call Provide method + ctx := context.Background() + _, err := provider.Provide(ctx, "artifact_name") + + // Assertions + assert.Error(t, err) + assert.Contains(t, err.Error(), "HOST_NAME_INVALID") +} + // Verifies that Enabled checks if tenantID is empty or AAD token is empty func TestAzureWIEnabled_ExpectedResults(t *testing.T) { - azAuthProvider := azureWIAuthProvider{ + azAuthProvider := WIAuthProvider{ tenantID: "test_tenant", clientID: "test_client", aadToken: confidential.AuthResult{ diff --git a/pkg/common/oras/authprovider/azure/helper.go b/pkg/common/oras/authprovider/azure/helper.go new file mode 100644 index 000000000..beafa4db0 --- /dev/null +++ b/pkg/common/oras/authprovider/azure/helper.go @@ -0,0 +1,84 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider" +) + +const GrantTypeAccessToken = "access_token" + +// AuthClientFactory defines an interface for creating an authentication client. +type AuthClientFactory interface { + CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) +} + +// defaultAuthClientFactoryImpl is the default implementation of AuthClientFactory. +type defaultAuthClientFactoryImpl struct{} + +// creates an AuthClient using the default factory implementation. +// Return an AuthClient and an error if the client creation fails. +func (f *defaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + return defaultAuthClientFactory(serverURL, options) +} + +// Define a helper function that creates an instance of AuthenticationClientWrapper. +func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options) + if err != nil { + return nil, err + } + return &AuthenticationClientWrapper{client: client}, nil +} + +// Define the interface for azcontainerregistry.AuthenticationClient methods used +type AuthenticationClientInterface interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} + +// Define the wrapper for AuthenticationClientInterface +type AuthenticationClientWrapper struct { + client AuthenticationClientInterface +} + +// A wrapper method that calls the underlying AuthenticationClientInterface's method. +// Exchanges an AAD access token for an ACR refresh token. +func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) +} + +// define the interface for authentication operations. +// It includes the method for exchanging an AAD access token for an ACR refresh token. +type AuthClient interface { + ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) +} + +// RegistryHostGetter defines an interface for getting the registry host. +type RegistryHostGetter interface { + GetRegistryHost(artifact string) (string, error) +} + +// defaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter. +type defaultRegistryHostGetterImpl struct{} + +// Retrieves the registry host name for a given artifact. +// It utilizes the provider's GetRegistryHostName function to perform the lookup. +func (g *defaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) { + return provider.GetRegistryHostName(artifact) +} diff --git a/pkg/common/oras/authprovider/azure/helper_test.go b/pkg/common/oras/authprovider/azure/helper_test.go new file mode 100644 index 000000000..49c811f0f --- /dev/null +++ b/pkg/common/oras/authprovider/azure/helper_test.go @@ -0,0 +1,100 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockAuthClient is a mock implementation of AuthClient. +type MockAuthClient struct { + mock.Mock +} + +// Mock method for ExchangeAADAccessTokenForACRRefreshToken +func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) { + args := m.Called(ctx, grantType, service, options) + return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1) +} + +// MockAuthClientFactory is a mock implementation of AuthClientFactory. +type MockAuthClientFactory struct { + mock.Mock +} + +// Mock method for CreateAuthClient +func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) { + args := m.Called(serverURL, options) + return args.Get(0).(AuthClient), args.Error(1) +} + +// MockRegistryHostGetter is a mock implementation of RegistryHostGetter. +type MockRegistryHostGetter struct { + mock.Mock +} + +// Mock method for GetRegistryHost +func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) { + args := m.Called(artifact) + return args.String(0), args.Error(1) +} + +func TestDefaultAuthClientFactoryImpl_CreateAuthClient(t *testing.T) { + factory := &defaultAuthClientFactoryImpl{} + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} + + client, err := factory.CreateAuthClient(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} + +func TestDefaultAuthClientFactory(t *testing.T) { + serverURL := "https://example.com" + options := &azcontainerregistry.AuthenticationClientOptions{} + + client, err := defaultAuthClientFactory(serverURL, options) + assert.Nil(t, err) + assert.NotNil(t, client) +} + +func TestDefaultRegistryHostGetterImpl_GetRegistryHost(t *testing.T) { + getter := &defaultRegistryHostGetterImpl{} + artifact := "example.azurecr.io/myArtifact" + + host, err := getter.GetRegistryHost(artifact) + assert.Nil(t, err) + assert.Equal(t, "example.azurecr.io", host) +} + +func TestAuthenticationClientWrapper_ExchangeAADAccessTokenForACRRefreshToken(t *testing.T) { + mockClient := new(MockAuthClient) + wrapper := &AuthenticationClientWrapper{client: mockClient} + ctx := context.Background() + grantType := azcontainerregistry.PostContentSchemaGrantType("grantType") + service := "service" + options := &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{} + + mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", ctx, grantType, service, options).Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{}, nil) + + _, err := wrapper.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options) + assert.Nil(t, err) +}