Skip to content

Commit

Permalink
chore: refactor azureworkloadidentity
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 14, 2024
1 parent 811a574 commit dcc499c
Show file tree
Hide file tree
Showing 2 changed files with 382 additions and 114 deletions.
76 changes: 63 additions & 13 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,69 @@ import (
re "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/internal/logger"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/ratify-project/ratify/pkg/utils/azureauth"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

// AuthClientFactory defines an interface for creating an authentication client.
type AuthClientFactory interface {
CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
type RegistryHostGetter interface {
GetRegistryHost(artifact string) (string, error)
}

// AADAccessTokenGetter defines an interface for getting an AAD access token.
type AADAccessTokenGetter interface {
GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
}

// MetricsReporter defines an interface for reporting metrics.
type MetricsReporter interface {
ReportMetrics(ctx context.Context, duration int64, artifactHostName string)
}

// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type DefaultAuthClientFactoryImpl struct{}

func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return DefaultAuthClientFactory(serverURL, options)
}

// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type DefaultRegistryHostGetterImpl struct{}

func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)
// return artifactHost, nil // Replace with actual logic
}

// DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter.
type DefaultAADAccessTokenGetterImpl struct{}

func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource)
}

// DefaultMetricsReporterImpl is the default implementation of MetricsReporter.
type DefaultMetricsReporterImpl struct{}

func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
DefaultReportMetrics(ctx, duration, artifactHostName)
}

type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name
type WIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
clientID string
authClientFactory func(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
getRegistryHost func(artifact string) (string, error)
getAADAccessToken func(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
reportMetrics func(ctx context.Context, duration int64, artifactHostName string)
authClientFactory AuthClientFactory
getRegistryHost RegistryHostGetter
getAADAccessToken AADAccessTokenGetter
reportMetrics MetricsReporter
}

type azureWIAuthProviderConf struct {
Expand Down Expand Up @@ -81,7 +130,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
}

// retrieve an AAD Access token
token, err := azureauth.GetAADAccessToken(context.Background(), tenant, clientID, AADResource)
token, err := DefaultGetAADAccessToken(context.Background(), tenant, clientID, AADResource)
if err != nil {
return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "", re.HideStackTrace)
}
Expand All @@ -90,9 +139,10 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
aadToken: token,
tenantID: tenant,
clientID: clientID,
authClientFactory: DefaultAuthClientFactory,
getAADAccessToken: DefaultGetAADAccessToken,
reportMetrics: DefaultReportMetrics,
authClientFactory: &DefaultAuthClientFactoryImpl{}, // Concrete implementation
getRegistryHost: &DefaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &DefaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &DefaultMetricsReporterImpl{},
}, nil
}

Expand All @@ -118,14 +168,14 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
}

// parse the artifact reference string to extract the registry host name
artifactHostName, err := d.getRegistryHost(artifact)
artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}

// need to refresh AAD token if it's expired
if time.Now().Add(time.Minute * 5).After(d.aadToken.ExpiresOn) {
newToken, err := d.getAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
newToken, err := d.getAADAccessToken.GetAADAccessToken(ctx, d.tenantID, d.clientID, AADResource)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace)
}
Expand All @@ -138,7 +188,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider

// TODO: Consider adding authentication client options for multicloud scenarios
var options *azcontainerregistry.AuthenticationClientOptions
client, err := d.authClientFactory(serverURL, options)
client, err := d.authClientFactory.CreateAuthClient(serverURL, options)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, err, "failed to create authentication client for container registry", re.HideStackTrace)
}
Expand All @@ -158,7 +208,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
}
rt := response.ACRRefreshToken

d.reportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName)
d.reportMetrics.ReportMetrics(ctx, time.Since(startTime).Milliseconds(), artifactHostName)

refreshTokenExpiry := getACRExpiryIfEarlier(d.aadToken.ExpiresOn)
authConfig := provider.AuthConfig{
Expand Down
Loading

0 comments on commit dcc499c

Please sign in to comment.