Skip to content

Commit

Permalink
feat: nversion support for azkv
Browse files Browse the repository at this point in the history
Signed-off-by: Josh Duffney <[email protected]>
  • Loading branch information
duffney committed Nov 27, 2024
1 parent aae1aa6 commit 7fbc3e2
Show file tree
Hide file tree
Showing 3 changed files with 407 additions and 128 deletions.
206 changes: 153 additions & 53 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"
"time"
Expand All @@ -43,16 +44,18 @@ import (
"golang.org/x/crypto/pkcs12"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
)

const (
ProviderName string = "azurekeyvault"
PKCS12ContentType string = "application/x-pkcs12"
PEMContentType string = "application/x-pem-file"
ProviderName string = "azurekeyvault"
PKCS12ContentType string = "application/x-pkcs12"
PEMContentType string = "application/x-pem-file"
versionHistoryLimitDefault int = 1
)

var logOpt = logger.Option{
Expand Down Expand Up @@ -88,6 +91,8 @@ type akvKMProviderFactory struct{}
type keyKVClient interface {
// GetKey retrieves a key from the keyvault
GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error)
// NewListKeyVersionsPager retrieves a pager for listing key versions
NewListKeyVersionsPager(name string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse]
}
type secretKVClient interface {
// GetSecret retrieves a secret from the keyvault
Expand All @@ -96,6 +101,8 @@ type secretKVClient interface {
type certificateKVClient interface {
// GetCertificate retrieves a certificate from the keyvault
GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error)
// NewListCertificateVersionsPager creates a new instance of the ListCertificateVersionsPager
NewListCertificateVersionsPager(certificateName string, options *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse]
}

type keyKVClientImpl struct {
Expand All @@ -113,11 +120,20 @@ func (c *certificateKVClientImpl) GetCertificate(ctx context.Context, certificat
return c.Client.GetCertificate(ctx, certificateName, certificateVersion, nil)
}

// NewListCertificateVersionsPager retrieves a pager for listing certificate versions
func (c *certificateKVClientImpl) NewListCertificateVersionsPager(certificateName string, options *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse] {
return c.Client.NewListCertificateVersionsPager(certificateName, options)
}

// GetKey retrieves a key from the keyvault
func (c *keyKVClientImpl) GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) {
return c.Client.GetKey(ctx, keyName, keyVersion, nil)
}

func (c *keyKVClientImpl) NewListKeyVersionsPager(name string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
return c.Client.NewListKeyVersionsPager(name, options)
}

// GetSecret retrieves a secret from the keyvault
func (c *secretKVClientImpl) GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) {
return c.Client.GetSecret(ctx, secretName, secretVersion, nil)
Expand Down Expand Up @@ -186,39 +202,68 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, certVersion %v, vaultURI: %v", keyVaultCert.Name, keyVaultCert.Version, s.vaultURI)

startTime := time.Now()
secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version)
if err != nil {
if isSecretDisabledError(err) {
// if secret is disabled, get the version of the certificate for status
certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificate objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}
certBundle := certResponse.CertificateBundle
keyVaultCert.Version = getObjectVersion(*certBundle.KID)
isEnabled := *certBundle.Attributes.Enabled
lastRefreshed := startTime.Format(time.RFC3339)
certProperty := getStatusProperty(keyVaultCert.Name, keyVaultCert.Version, lastRefreshed, isEnabled)
certsStatus = append(certsStatus, certProperty)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled}
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultCert.VersionHistoryLimit == 0 {
keyVaultCert.VersionHistoryLimit = versionHistoryLimitDefault
}

var versionHistory []struct {
Version string
Created time.Time
}

certVersionPager := s.certificateKVClient.NewListCertificateVersionsPager(keyVaultCert.Name, nil)
for certVersionPager.More() {
pager, err := certVersionPager.NextPage(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificate versions for objectName:%s, error: %w", keyVaultCert.Name, err)
}
for _, cert := range pager.Value {
versionHistory = append(versionHistory, struct {
Version string
Created time.Time
}{Version: cert.ID.Version(), Created: *cert.Attributes.Created})
}
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}

secretBundle := secretResponse.SecretBundle
isEnabled := *secretBundle.Attributes.Enabled
// sort the version history by created time
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)

certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificates from secret bundle:%w", err)
// if versionHistoryLimit is greater than the number of versions, set it to the number of versions
if keyVaultCert.VersionHistoryLimit > len(sortedVersionHistory) {
keyVaultCert.VersionHistoryLimit = len(sortedVersionHistory)

Check warning on line 235 in pkg/keymanagementprovider/azurekeyvault/provider.go

View check run for this annotation

Codecov / codecov/patch

pkg/keymanagementprovider/azurekeyvault/provider.go#L235

Added line #L235 was not covered by tests
}

metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultCert.Name)
certsStatus = append(certsStatus, certProperty...)
certMapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled}
certsMap[certMapKey] = certResult
// get the latest version of the certificate up to the limit
for _, version := range sortedVersionHistory[len(sortedVersionHistory)-keyVaultCert.VersionHistoryLimit:] {
secretReponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, version)
if err != nil {
if isSecretDisabledError(err) {
isEnabled := false
lastRefreshed := startTime.Format(time.RFC3339)
certProperty := getStatusProperty(keyVaultCert.Name, version, lastRefreshed, isEnabled)
certsStatus = append(certsStatus, certProperty)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: version, Enabled: isEnabled}
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
}
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, version, err)
}

secretBundle := secretReponse.SecretBundle
isEnabled := *secretBundle.Attributes.Enabled

certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificates from secret bundle:%w", err)
}

metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultCert.Name)
certsStatus = append(certsStatus, certProperty...)
certMapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: version, Enabled: isEnabled}
certsMap[certMapKey] = certResult
}
}
return certsMap, getStatusMap(certsStatus, types.CertificatesStatus), nil
}
Expand All @@ -233,33 +278,66 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.

// fetch the key object from Key Vault
startTime := time.Now()
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, keyVaultKey.Version, err)
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultKey.VersionHistoryLimit == 0 {
keyVaultKey.VersionHistoryLimit = versionHistoryLimitDefault
}
keyBundle := keyResponse.KeyBundle
isEnabled := *keyBundle.Attributes.Enabled
// if version is set as "" in the config, use the version from the key bundle
keyVaultKey.Version = getObjectVersion(string(*keyBundle.Key.KID))

if !isEnabled {
startTime := time.Now()
lastRefreshed := startTime.Format(time.RFC3339)
properties := getStatusProperty(keyVaultKey.Name, keyVaultKey.Version, lastRefreshed, isEnabled)
keysStatus = append(keysStatus, properties)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: keyVaultKey.Version, Enabled: isEnabled}
keymanagementprovider.DeleteKeyFromMap(s.resource, mapKey)
continue

var versionHistory []struct {
Version string
Created time.Time
}

publicKey, err := getKeyFromKeyBundle(keyBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key from key bundle:%w", err)
keyVersionPager := s.keyKVClient.NewListKeyVersionsPager(keyVaultKey.Name, nil)
for keyVersionPager.More() {
pager, err := keyVersionPager.NextPage(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key versions for objectName:%s, error: %w", keyVaultKey.Name, err)
}
for _, key := range pager.Value {
versionHistory = append(versionHistory, struct {
Version string
Created time.Time
}{Version: key.KID.Version(), Created: *key.Attributes.Created})
}
}

// sort the version history by created time
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)

// if versionHistoryLimit is greater than the number of versions, set it to the number of versions
if keyVaultKey.VersionHistoryLimit > len(sortedVersionHistory) {
keyVaultKey.VersionHistoryLimit = len(sortedVersionHistory)
}

Check warning on line 312 in pkg/keymanagementprovider/azurekeyvault/provider.go

View check run for this annotation

Codecov / codecov/patch

pkg/keymanagementprovider/azurekeyvault/provider.go#L311-L312

Added lines #L311 - L312 were not covered by tests

// get the latest version of the certificate up to the limit
for _, version := range sortedVersionHistory[len(sortedVersionHistory)-keyVaultKey.VersionHistoryLimit:] {
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, version, err)
}
keyBundle := keyResponse.KeyBundle
isEnabled := *keyBundle.Attributes.Enabled

if !isEnabled {
lastRefresh := time.Now().Format(time.RFC3339)
keyProperties := getStatusProperty(keyVaultKey.Name, version, lastRefresh, isEnabled)
keysStatus = append(keysStatus, keyProperties)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: version, Enabled: isEnabled}
keymanagementprovider.DeleteKeyFromMap(s.resource, mapKey)
continue
}

publicKey, err := getKeyFromKeyBundle(keyBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key from key bundle:%w", err)
}
keysMap[keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: version, Enabled: isEnabled}] = publicKey
metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultKey.Name)
keyProperties := getStatusProperty(keyVaultKey.Name, version, time.Now().Format(time.RFC3339), isEnabled)
keysStatus = append(keysStatus, keyProperties)
}
keysMap[keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: keyVaultKey.Version, Enabled: isEnabled}] = publicKey
metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultKey.Name)
properties := getStatusProperty(keyVaultKey.Name, keyVaultKey.Version, time.Now().Format(time.RFC3339), isEnabled)
keysStatus = append(keysStatus, properties)
}

return keysMap, getStatusMap(keysStatus, types.KeysStatus), nil
Expand Down Expand Up @@ -478,6 +556,28 @@ func isSecretDisabledError(err error) bool {
return false
}

// sortVersionHistory sorts the version history by created time
func sortVersionHistory(versionHistory []struct {
Version string
Created time.Time
}) {
sort.Slice(versionHistory, func(i, j int) bool {
return versionHistory[i].Created.Before(versionHistory[j].Created)
})
}

// GetSortedVersions returns the sorted versions of the object
func GetSortedVersions(versionHistory []struct {
Version string
Created time.Time
}) []string {
sortedVersions := make([]string, 0, len(versionHistory))
for _, version := range versionHistory {
sortedVersions = append(sortedVersions, version.Version)
}
return sortedVersions
}

// validate checks vaultURI, tenantID, clientID are set and all certificates/keys have a name
func (s *akvKMProvider) validate() error {
if s.vaultURI == "" {
Expand Down
Loading

0 comments on commit 7fbc3e2

Please sign in to comment.