Skip to content

Commit

Permalink
mod: combine versionHistory len and value checking
Browse files Browse the repository at this point in the history
  • Loading branch information
duffney committed Dec 3, 2024
1 parent 9f42567 commit a3828c9
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
39 changes: 25 additions & 14 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,6 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, certVersion %v, vaultURI: %v", keyVaultCert.Name, keyVaultCert.Version, s.vaultURI)

startTime := time.Now()
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultCert.VersionHistoryLimit == 0 {
keyVaultCert.VersionHistoryLimit = versionHistoryLimitDefault
}

versionHistory := []VersionInfo{}

certVersionPager := s.certificateKVClient.NewListCertificateVersionsPager(keyVaultCert.Name, nil)
Expand All @@ -233,7 +228,17 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)

// if versionHistoryLimit is greater than the number of versions, set it to the number of versions
if len(sortedVersionHistory) == 0 {
logger.GetLogger(ctx, logOpt).Infof("no versions found for certificate %s", keyVaultCert.Name)
continue
}

// if versionHistoryLimit isn't set default to 1 to avoid out of bounds error
if keyVaultCert.VersionHistoryLimit == 0 {
keyVaultCert.VersionHistoryLimit = versionHistoryLimitDefault
}

// Ensure that the versionHistoryLimit is not greater than the number of versions to avoid out of bounds error
if keyVaultCert.VersionHistoryLimit > len(sortedVersionHistory) {
keyVaultCert.VersionHistoryLimit = len(sortedVersionHistory)
}
Expand All @@ -251,7 +256,7 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
}
return nil, nil, fmt.Errorf("failed to get secret objectName: %s, objectVersion: %s, error: %w", keyVaultCert.Name, version, err)
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, version, err)
}

secretBundle := secretReponse.SecretBundle
Expand Down Expand Up @@ -281,18 +286,13 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.

// fetch the key object from Key Vault
startTime := time.Now()
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultKey.VersionHistoryLimit == 0 {
keyVaultKey.VersionHistoryLimit = versionHistoryLimitDefault
}

versionHistory := []VersionInfo{}

keyVersionPager := s.keyKVClient.NewListKeyVersionsPager(keyVaultKey.Name, nil)
for keyVersionPager.More() {
pager, err := keyVersionPager.NextPage(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key versions for objectName: %s, error: %w", keyVaultKey.Name, err)
return nil, nil, fmt.Errorf("failed to get key versions for objectName:%s, error: %w", keyVaultKey.Name, err)
}
for _, key := range pager.Value {
versionInfo := VersionInfo{
Expand All @@ -307,12 +307,23 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)

// if no versions found, log and continue
if len(sortedVersionHistory) == 0 {
logger.GetLogger(ctx, logOpt).Infof("no versions found for key %s", keyVaultKey.Name)
continue
}

// if versionHistoryLimit isn't set default to 1 to avoid out of bounds error
if keyVaultKey.VersionHistoryLimit == 0 {
keyVaultKey.VersionHistoryLimit = versionHistoryLimitDefault
}

// if versionHistoryLimit is greater than the number of versions, set it to the number of versions
if keyVaultKey.VersionHistoryLimit > len(sortedVersionHistory) {
keyVaultKey.VersionHistoryLimit = len(sortedVersionHistory)
}

// get the latest version of the key up to the limit
// get the latest version of the certificate up to the limit
for _, version := range sortedVersionHistory[len(sortedVersionHistory)-keyVaultKey.VersionHistoryLimit:] {
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, version)
if err != nil {
Expand Down
38 changes: 35 additions & 3 deletions pkg/keymanagementprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ func TestGetCertificates(t *testing.T) {
expectedErr bool
}{
{
name: "Certificate enabled with multiple versions",
name: "Certificate enabled with multiple versions",
versionHistoryLimit: 3,
mockCertificateKVClient: &MockCertificateKVClient{
GetCertificateFunc: func(_ context.Context, _ string, _ string) (azcertificates.GetCertificateResponse, error) {
return azcertificates.GetCertificateResponse{
Expand Down Expand Up @@ -381,9 +382,24 @@ func TestGetCertificates(t *testing.T) {
},
expectedErr: false,
},
{
name: "No versions returned by pager",
mockCertificateKVClient: &MockCertificateKVClient{
NewListCertificateVersionsPagerFunc: func(_ string, _ *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse] {
return runtime.NewPager(runtime.PagingHandler[azcertificates.ListCertificateVersionsResponse]{
More: func(_ azcertificates.ListCertificateVersionsResponse) bool {
return false
},
Fetcher: func(_ context.Context, _ *azcertificates.ListCertificateVersionsResponse) (azcertificates.ListCertificateVersionsResponse, error) {
return azcertificates.ListCertificateVersionsResponse{}, nil
},
})
},
},
expectedErr: false,
},
{
name: "GetSecret error",
versionHistoryLimit: 1,
mockCertificateKVClient: &MockCertificateKVClient{},
mockSecretKVClient: &MockSecretKVClient{
GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) {
Expand Down Expand Up @@ -492,7 +508,7 @@ func TestGetKeys(t *testing.T) {
}{
{
name: "Key enabled with multiple versions",
versionHistoryLimit: 2,
versionHistoryLimit: 3,
mockKeyKVClient: &MockKeyKVClient{
GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) {
return azkeys.GetKeyResponse{
Expand Down Expand Up @@ -556,6 +572,22 @@ func TestGetKeys(t *testing.T) {
},
expectedErr: false,
},
{
name: "No versions returned by pager",
mockKeyKVClient: &MockKeyKVClient{
NewListKeyVersionsPagerFunc: func(_ string, _ *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
return runtime.NewPager(runtime.PagingHandler[azkeys.ListKeyVersionsResponse]{
More: func(_ azkeys.ListKeyVersionsResponse) bool {
return false
},
Fetcher: func(_ context.Context, _ *azkeys.ListKeyVersionsResponse) (azkeys.ListKeyVersionsResponse, error) {
return azkeys.ListKeyVersionsResponse{}, nil
},
})
},
},
expectedErr: false,
},
{
name: "GetKey error",
mockKeyKVClient: &MockKeyKVClient{
Expand Down

0 comments on commit a3828c9

Please sign in to comment.