Skip to content

Commit

Permalink
chore: move common code to helper.go
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 14, 2024
1 parent b2494ec commit a40441b
Showing 6 changed files with 196 additions and 91 deletions.
34 changes: 17 additions & 17 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
@@ -32,20 +32,34 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
)

type azureManagedIdentityProviderFactory struct{}

// ManagedIdentityTokenGetter defines an interface for getting a managed identity token.
type ManagedIdentityTokenGetter interface {
GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error)
}

// DefaultManagedIdentityTokenGetterImpl is the default implementation of AADAccessTokenGetter.
// DefaultManagedIdentityTokenGetterImpl is the default implementation of getManagedIdentityToken.
type DefaultManagedIdentityTokenGetterImpl struct{}

func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
return getManagedIdentityToken(ctx, clientID)
}

func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
id := azidentity.ClientID(clientID)
opts := azidentity.ManagedIdentityCredentialOptions{ID: id}
cred, err := azidentity.NewManagedIdentityCredential(&opts)
if err != nil {
return azcore.AccessToken{}, err
}
scopes := []string{AADResource}
if cred != nil {
return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
}
return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken")
}

type azureManagedIdentityProviderFactory struct{}

type MIAuthProvider struct {
identityToken azcore.AccessToken
clientID string
@@ -185,17 +199,3 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider

return authConfig, nil
}

func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
id := azidentity.ClientID(clientID)
opts := azidentity.ManagedIdentityCredentialOptions{ID: id}
cred, err := azidentity.NewManagedIdentityCredential(&opts)
if err != nil {
return azcore.AccessToken{}, err
}
scopes := []string{AADResource}
if cred != nil {
return cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
}
return azcore.AccessToken{}, re.ErrorCodeConfigInvalid.WithComponentType(re.AuthProvider).WithDetail("config is nil pointer for GetServicePrincipalToken")
}
4 changes: 2 additions & 2 deletions pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
@@ -146,7 +146,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {

// Define token values
expiredToken := azcore.AccessToken{Token: "expired_token", ExpiresOn: time.Now().Add(-10 * time.Minute)}
newTokenString := "refreshed_token"
newTokenString := "refreshed"
newAADToken := azcore.AccessToken{Token: "new_token", ExpiresOn: time.Now().Add(10 * time.Minute)}
refreshToken := azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: &newTokenString},
@@ -174,7 +174,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {

// Validate success and token refresh
assert.NoError(t, err)
assert.Equal(t, "refreshed_token", authConfig.Password)
assert.Equal(t, "refreshed", authConfig.Password)
}

// Test failed token refresh
46 changes: 15 additions & 31 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
@@ -25,61 +25,45 @@ import (
re "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/internal/logger"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
"github.com/ratify-project/ratify/pkg/utils/azureauth"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

// AuthClientFactory defines an interface for creating an authentication client.
type AuthClientFactory interface {
CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
type RegistryHostGetter interface {
GetRegistryHost(artifact string) (string, error)
}

// AADAccessTokenGetter defines an interface for getting an AAD access token.
type AADAccessTokenGetter interface {
GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error)
}

// MetricsReporter defines an interface for reporting metrics.
type MetricsReporter interface {
ReportMetrics(ctx context.Context, duration int64, artifactHostName string)
}

// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type DefaultAuthClientFactoryImpl struct{}

func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return DefaultAuthClientFactory(serverURL, options)
}

// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type DefaultRegistryHostGetterImpl struct{}

func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)
// return artifactHost, nil // Replace with actual logic
}

// DefaultAADAccessTokenGetterImpl is the default implementation of AADAccessTokenGetter.
type DefaultAADAccessTokenGetterImpl struct{}

func (g *DefaultAADAccessTokenGetterImpl) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return DefaultGetAADAccessToken(ctx, tenantID, clientID, resource)
}

func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource)
}

// MetricsReporter defines an interface for reporting metrics.
type MetricsReporter interface {
ReportMetrics(ctx context.Context, duration int64, artifactHostName string)
}

// DefaultMetricsReporterImpl is the default implementation of MetricsReporter.
type DefaultMetricsReporterImpl struct{}

func (r *DefaultMetricsReporterImpl) ReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
DefaultReportMetrics(ctx, duration, artifactHostName)
}

func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName)
}

type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name

type WIAuthProvider struct {
aadToken confidential.AuthResult
tenantID string
30 changes: 0 additions & 30 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
@@ -31,26 +31,6 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

// MockAuthClientFactory for creating AuthClient
type MockAuthClientFactory struct {
mock.Mock
}

func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
args := m.Called(serverURL, options)
return args.Get(0).(AuthClient), args.Error(1)
}

// MockRegistryHostGetter for retrieving registry host
type MockRegistryHostGetter struct {
mock.Mock
}

func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) {
args := m.Called(artifact)
return args.String(0), args.Error(1)
}

// MockAADAccessTokenGetter for retrieving AAD access token
type MockAADAccessTokenGetter struct {
mock.Mock
@@ -70,16 +50,6 @@ func (m *MockMetricsReporter) ReportMetrics(ctx context.Context, duration int64,
m.Called(ctx, duration, artifactHostName)
}

// MockAuthClient for the Azure auth client
type MockAuthClient struct {
mock.Mock
}

func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}

// Test for successful Provide function
func TestWIAuthProvider_Provide_Success(t *testing.T) {
// Mock all dependencies
37 changes: 26 additions & 11 deletions pkg/common/oras/authprovider/azure/helper.go
Original file line number Diff line number Diff line change
@@ -19,11 +19,21 @@ import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/ratify-project/ratify/internal/logger"
"github.com/ratify-project/ratify/pkg/utils/azureauth"
provider "github.com/ratify-project/ratify/pkg/common/oras/authprovider"
)

// AuthClientFactory defines an interface for creating an authentication client.
type AuthClientFactory interface {
CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error)
}

// DefaultAuthClientFactoryImpl is the default implementation of AuthClientFactory.
type DefaultAuthClientFactoryImpl struct{}

func (f *DefaultAuthClientFactoryImpl) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
return DefaultAuthClientFactory(serverURL, options)
}

func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
client, err := azcontainerregistry.NewAuthenticationClient(serverURL, options)
if err != nil {
@@ -32,14 +42,6 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut
return &AuthenticationClientWrapper{client: client}, nil
}

func DefaultGetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
return azureauth.GetAADAccessToken(ctx, tenantID, clientID, resource)
}

func DefaultReportMetrics(ctx context.Context, duration int64, artifactHostName string) {
logger.GetLogger(ctx, logOpt).Infof("Metrics Report: Duration=%dms, Host=%s", duration, artifactHostName)
}

type AuthenticationClientWrapper struct {
client *azcontainerregistry.AuthenticationClient
}
@@ -51,3 +53,16 @@ func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(c
type AuthClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
type RegistryHostGetter interface {
GetRegistryHost(artifact string) (string, error)
}

// DefaultRegistryHostGetterImpl is the default implementation of RegistryHostGetter.
type DefaultRegistryHostGetterImpl struct{}

func (g *DefaultRegistryHostGetterImpl) GetRegistryHost(artifact string) (string, error) {
// Implement the logic to get the registry host
return provider.GetRegistryHostName(artifact)
}
136 changes: 136 additions & 0 deletions pkg/common/oras/authprovider/azure/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
Copyright The Ratify Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package azure

import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/stretchr/testify/mock"
)

// MockAuthClient is a mock implementation of AuthClient.
type MockAuthClient struct {
mock.Mock
}

// Mock method for ExchangeAADAccessTokenForACRRefreshToken
func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}

// MockAuthClientFactory is a mock implementation of AuthClientFactory.
type MockAuthClientFactory struct {
mock.Mock
}

// Mock method for CreateAuthClient
func (m *MockAuthClientFactory) CreateAuthClient(serverURL string, options *azcontainerregistry.AuthenticationClientOptions) (AuthClient, error) {
args := m.Called(serverURL, options)
return args.Get(0).(AuthClient), args.Error(1)
}

// MockRegistryHostGetter is a mock implementation of RegistryHostGetter.
type MockRegistryHostGetter struct {
mock.Mock
}

// Mock method for GetRegistryHost
func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error) {
args := m.Called(artifact)
return args.String(0), args.Error(1)
}

// // TestDefaultAuthClientFactoryImpl tests the default factory implementation.
// func TestDefaultAuthClientFactoryImpl(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)
// mockAuthClient := new(MockAuthClient)

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.NoError(t, err)
// assert.NotNil(t, client)
// }

// // TestDefaultAuthClientFactory_Error tests error handling during client creation.
// func TestDefaultAuthClientFactory_Error(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}
// expectedError := errors.New("failed to create client")

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.Error(t, err)
// assert.Nil(t, client)
// assert.Equal(t, expectedError, err)
// }

// // TestGetRegistryHost tests the GetRegistryHost function.
// func TestGetRegistryHost(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)

// artifact := "test/artifact"
// expectedHost := "example.azurecr.io"

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil)

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.NoError(t, err)
// assert.Equal(t, expectedHost, host)
// }

// // TestGetRegistryHost_Error tests error handling in GetRegistryHost.
// func TestGetRegistryHost_Error(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)

// artifact := "test/artifact"
// expectedError := errors.New("failed to get registry host")

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError)

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.Error(t, err)
// assert.Empty(t, host)
// assert.Equal(t, expectedError, err)
// }

0 comments on commit a40441b

Please sign in to comment.