Skip to content

Commit

Permalink
chore: more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <[email protected]>
  • Loading branch information
shahramk64 committed Oct 1, 2024
1 parent 88639d1 commit a96530e
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ type mockAuthClient struct {
mock.Mock
}

type mockAzureAuth struct {
mock.Mock
}

func (m *mockAzureAuth) GetAADAccessToken(ctx context.Context, tenantID, clientID, resource string) (confidential.AuthResult, error) {
args := m.Called(ctx, tenantID, clientID, resource)
return args.Get(0).(confidential.AuthResult), args.Error(1)
}

func (m *mockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
Expand Down Expand Up @@ -177,10 +186,54 @@ func TestProvide_Success(t *testing.T) {
authConfig, err := provider.Provide(context.Background(), "artifact")

assert.NoError(t, err)
// Assert that GetAADAccessToken was not called
mockClient.AssertNotCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything)
// Assert that the returned refresh token matches the expected one
assert.Equal(t, expectedRefreshToken, authConfig.Password)
}

func TestProvide_RefreshAAD(t *testing.T) {
// Arrange
mockAzureAuth := new(mockAzureAuth)
mockClient := new(mockAuthClient)

provider := &WIAuthProvider{
aadToken: confidential.AuthResult{
AccessToken: "mockToken",
ExpiresOn: time.Now(),
},
tenantID: "mockTenantID",
clientID: "mockClientID",
authClientFactory: func(_ string, _ *azcontainerregistry.AuthenticationClientOptions) (authClient, error) {
return mockClient, nil
},
getRegistryHost: func(_ string) (string, error) {
return "myregistry.azurecr.io", nil
},
getAADAccessToken: mockAzureAuth.GetAADAccessToken,
reportMetrics: func(_ context.Context, _ int64, _ string) {},
}

mockAzureAuth.On("GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(confidential.AuthResult{AccessToken: "newAccessToken", ExpiresOn: time.Now().Add(time.Hour)}, nil)

mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "myregistry.azurecr.io", mock.Anything).
Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{
ACRRefreshToken: azcontainerregistry.ACRRefreshToken{RefreshToken: new(string)},
}, nil)

ctx := context.TODO()
artifact := "testArtifact"

// Act
_, err := provider.Provide(ctx, artifact)

assert.NoError(t, err)
// Assert that GetAADAccessToken was not called
mockAzureAuth.AssertCalled(t, "GetAADAccessToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything)

}

func TestProvide_Failure_InvalidHostName(t *testing.T) {
provider := &WIAuthProvider{
getRegistryHost: func(_ string) (string, error) {
Expand Down

0 comments on commit a96530e

Please sign in to comment.