Skip to content

Commit

Permalink
chore: refactor to enable mocking and add unit tests to azureworkload…
Browse files Browse the repository at this point in the history
…identity_test.go
  • Loading branch information
shahramk64 committed Sep 29, 2024
1 parent bb0107e commit 045b901
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 20 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ require (
github.com/sigstore/timestamp-authority v1.2.2 // indirect
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tchap/go-patricia/v2 v2.3.1 // indirect
github.com/thales-e-security/pool v0.0.2 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
Expand Down
65 changes: 45 additions & 20 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"os"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
re "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/internal/logger"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
Expand All @@ -33,9 +33,40 @@ import (

type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name
type azureWIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
clientID string
aadToken confidential.AuthResult
tenantID string
clientID string
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error)
getRegistryHost func(artifact string) (string, error)
getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
reportMetrics func(ctx context.Context, duration int64, artifactHostName string)
}

type authenticationClientWrapper struct {
client *azcontainerregistry.AuthenticationClient
}

func (w *authenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options)
}

type authClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

func NewAzureWIAuthProvider() *azureWIAuthProvider {
return &azureWIAuthProvider{
authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
return nil, err
}
return &authenticationClientWrapper{client: client}, nil
},
getRegistryHost: provider.GetRegistryHostName,
getAADAccessToken: azureauth.GetAADAccessToken,
reportMetrics: metrics.ReportACRExchangeDuration,
}
}

type azureWIAuthProviderConf struct {
Expand Down Expand Up @@ -103,54 +134,48 @@ func (d *azureWIAuthProvider) Enabled(_ context.Context) bool {
return true
}

// Provide returns the credentials for a specified artifact.
// Uses Azure Workload Identity to retrieve an AAD access token which can be
// exchanged for a valid ACR refresh token for login.
func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) {
if !d.Enabled(ctx) {
return provider.AuthConfig{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("azure workload identity auth provider is not properly enabled")
}
// parse the artifact reference string to extract the registry host name
artifactHostName, err := provider.GetRegistryHostName(artifact)

artifactHostName, err := d.getRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) {
newToken, err := azureauth.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
newToken, err := d.getAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace)
}
d.aadToken = newToken
logger.GetLogger(ctx, logOpt).Info("successfully refreshed AAD token")
}

// add protocol to generate complete URI
serverURL := "https://" + artifactHostName

// create registry client and exchange AAD token for registry refresh token
// TODO: Consider adding authentication client options for multicloud scenarios
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, nil) // &AuthenticationClientOptions{ClientOptions: options})
client, err := d.authClientFactory(serverURL, nil)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace)
}
// refreshTokenClient := azcontainerregistry.NewRefreshTokensClient(serverURL)

startTime := time.Now()
rt, err := client.ExchangeAADAccessTokenForACRRefreshToken(
context.Background(),
response, err := client.ExchangeAADAccessTokenForACRRefreshToken(
ctx,
"access_token",
artifactHostName,
&azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &d.aadToken.AccessToken,
Tenant: &d.tenantID,
},
)
// rt, err := refreshTokenClient.GetFromExchange(context.Background(), "access_token", artifactHostName, d.tenantID, "", d.aadToken.AccessToken)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to get refresh token for container registry", re.HideStackTrace)
}
metrics.ReportACRExchangeDuration(ctx, time.Since(startTime).Milliseconds(), artifactHostName)
rt := response.ACRRefreshToken

d.reportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName)

refreshTokenExpiry := getACRExpiryIfEarlier(d.aadToken.ExpiresOn)
authConfig := provider.AuthConfig{
Expand Down
60 changes: 60 additions & 0 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
ratifyerrors "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// Verifies that Enabled checks if tenantID is empty or AAD token is empty
Expand Down Expand Up @@ -131,3 +134,60 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) {
t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err)
}
}

type mockAuthClient struct {
mock.Mock
}

func (m *mockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}

func TestProvide_Success(t *testing.T) {
mockClient := new(mockAuthClient)
expectedRefreshToken := "mocked_refresh_token"
mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &expectedRefreshToken},
}, nil)

provider := &azureWIAuthProvider{
aadToken: confidential.AuthResult{
AccessToken: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (authClient, error) {
return mockClient, nil
},
getRegistryHost: func(artifact string) (string, error) {
return "myregistry.azurecr.io", nil
},
getAADAccessToken: func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return confidential.AuthResult{
AccessToken: "mockToken",
ExpiresOn: time.Now().Add(time.Hour),
}, nil
},
reportMetrics: func(ctx context.Context, duration int64, artifactHostName string) {},
}

authConfig, err := provider.Provide(context.Background(), "artifact")

assert.NoError(t, err)
// Assert that the returned refresh token matches the expected one
assert.Equal(t, expectedRefreshToken, authConfig.Password)
}

func TestProvide_Failure_InvalidHostName(t *testing.T) {
provider := &azureWIAuthProvider{
getRegistryHost: func(artifact string) (string, error) {
return "", errors.New("invalid hostname")
},
}

_, err := provider.Provide(context.Background(), "artifact")
assert.Error(t, err)
}

0 comments on commit 045b901

Please sign in to comment.