Skip to content

Commit

Permalink
chore: address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 21, 2024
1 parent 69f1266 commit cbbf124
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 63 deletions.
4 changes: 2 additions & 2 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type MIAuthProvider struct {
clientID string
tenantID string
authClientFactory AuthClientFactory
getRegistryHost RegistryHostGetter
registryHostGetter RegistryHostGetter
getManagedIdentityToken ManagedIdentityTokenGetter
}

Expand Down Expand Up @@ -150,7 +150,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider
}

// parse the artifact reference string to extract the registry host name
artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact)
artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

Expand Down Expand Up @@ -198,7 +198,7 @@ func TestMIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

Expand Down Expand Up @@ -230,7 +230,7 @@ func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) {
clientID: "clientID",
tenantID: "tenantID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
registryHostGetter: mockRegistryHostGetter,
getManagedIdentityToken: mockManagedIdentityTokenGetter,
}

Expand Down
30 changes: 15 additions & 15 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ func defaultReportMetrics(ctx context.Context, duration int64, artifactHostName
type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name

type WIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
clientID string
authClientFactory AuthClientFactory
getRegistryHost RegistryHostGetter
getAADAccessToken AADAccessTokenGetter
reportMetrics MetricsReporter
aadToken confidential.AuthResult
tenantID string
clientID string
authClientFactory AuthClientFactory
registryHostGetter RegistryHostGetter
getAADAccessToken AADAccessTokenGetter
reportMetrics MetricsReporter
}

type azureWIAuthProviderConf struct {
Expand Down Expand Up @@ -120,13 +120,13 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
}

return &WIAuthProvider{
aadToken: token,
tenantID: tenant,
clientID: clientID,
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
getRegistryHost: &defaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &defaultMetricsReporterImpl{},
aadToken: token,
tenantID: tenant,
clientID: clientID,
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &defaultMetricsReporterImpl{},
}, nil
}

Expand All @@ -152,7 +152,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
}

// parse the artifact reference string to extract the registry host name
artifactHostName, err := d.getRegistryHost.GetRegistryHost(artifact)
artifactHostName, err := d.registryHostGetter.GetRegistryHost(artifact)
if err != nil {
return provider.AuthConfig{}, re.ErrorCodeHostNameInvalid.WithComponentType(re.AuthProvider)
}
Expand Down
84 changes: 42 additions & 42 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) {

// Create WIAuthProvider
provider := WIAuthProvider{
aadToken: initialToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: initialToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down Expand Up @@ -119,13 +119,13 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) {

// Create WIAuthProvider with expired token
provider := WIAuthProvider{
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down Expand Up @@ -154,13 +154,13 @@ func TestWIAuthProvider_Provide_AADTokenFailure(t *testing.T) {

// Create WIAuthProvider with expired token
provider := WIAuthProvider{
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down Expand Up @@ -231,13 +231,13 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) {

// Create WIAuthProvider with expired token
provider := WIAuthProvider{
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down Expand Up @@ -266,13 +266,13 @@ func TestWIAuthProvider_Provide_TokenRefreshFailure(t *testing.T) {

// Create WIAuthProvider with expired token
provider := WIAuthProvider{
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: expiredToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down Expand Up @@ -314,13 +314,13 @@ func TestWIAuthProvider_Provide_InvalidHostName(t *testing.T) {

// Create WIAuthProvider with valid token
provider := WIAuthProvider{
aadToken: validToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
getRegistryHost: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
aadToken: validToken,
tenantID: "tenantID",
clientID: "clientID",
authClientFactory: mockAuthClientFactory,
registryHostGetter: mockRegistryHostGetter,
getAADAccessToken: mockAADAccessTokenGetter,
reportMetrics: mockMetricsReporter,
}

// Call Provide method
Expand Down
11 changes: 10 additions & 1 deletion pkg/common/oras/authprovider/azure/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ type AuthClientFactory interface {
// defaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type defaultAuthClientFactoryImpl struct{}

// creates an AuthClient using the default factory implementation.
// Return an AuthClient and an error if the client creation fails.
func (f *defaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return defaultAuthClientFactory(serverURL, options)
}

// Define a helper function that creates an instance of AuthenticationClientWrapper.
func defaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
Expand All @@ -49,14 +52,19 @@ type AuthenticationClientInterface interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

// Define the wrapper for AuthenticationClientInterface
type AuthenticationClientWrapper struct {
client AuthenticationClientInterface
}

// A wrapper method that calls the underlying AuthenticationClientInterface's method.
// Exchanges an AAD access token for an ACR refresh token.
func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options)
}

// define the interface for authentication operations.
// It includes the method for exchanging an AAD access token for an ACR refresh token.
type AuthClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}
Expand All @@ -69,7 +77,8 @@ type RegistryHostGetter interface {
// defaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type defaultRegistryHostGetterImpl struct{}

// Retrieves the registry host name for a given artifact.
// It utilizes the provider's GetRegistryHostName function to perform the lookup.
func (g *defaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)
}

0 comments on commit cbbf124

Please sign in to comment.