Skip to content

Commit

Permalink
fix: handle nil pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
duffney committed Dec 3, 2024
1 parent a3828c9 commit fc8a17c
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 11 deletions.
40 changes: 29 additions & 11 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (c *keyKVClientImpl) GetKey(ctx context.Context, keyName string, keyVersion
return c.Client.GetKey(ctx, keyName, keyVersion, nil)
}

func (c *keyKVClientImpl) NewListKeyVersionsPager(name string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
return c.Client.NewListKeyVersionsPager(name, options)
func (c *keyKVClientImpl) NewListKeyVersionsPager(keyName string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
return c.Client.NewListKeyVersionsPager(keyName, options)
}

// GetSecret retrieves a secret from the keyvault
Expand Down Expand Up @@ -216,6 +216,10 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
return nil, nil, fmt.Errorf("failed to get certificate versions for objectName:%s, error: %w", keyVaultCert.Name, err)
}
for _, cert := range pager.Value {
if cert.ID == nil || cert.Attributes == nil || cert.Attributes.Created == nil {
logger.GetLogger(ctx, logOpt).Warnf("certificate %s, version %s, found invalid certificate object, id, attributes or created time must not be nil", keyVaultCert.Name, cert.ID.Version())
continue
}
versionInfo := VersionInfo{
Version: cert.ID.Version(),
Created: *cert.Attributes.Created,
Expand All @@ -224,9 +228,10 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
}
}

// sort the version history by created time
// Pager results are not sorted by created time, so we sort them here
// in ascending order (oldest to newest)
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)
sortedVersionHistory := GetVersions(versionHistory)

if len(sortedVersionHistory) == 0 {
logger.GetLogger(ctx, logOpt).Infof("no versions found for certificate %s", keyVaultCert.Name)
Expand Down Expand Up @@ -260,6 +265,10 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
}

secretBundle := secretReponse.SecretBundle
if secretBundle.Attributes == nil || secretBundle.Attributes.Enabled == nil {
logger.GetLogger(ctx, logOpt).Warnf("certificate %s, version %s, found invalid secret bundle, attributes or attribute.enabled not be nil", keyVaultCert.Name, version)
continue
}
isEnabled := *secretBundle.Attributes.Enabled

certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
Expand Down Expand Up @@ -295,6 +304,10 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.
return nil, nil, fmt.Errorf("failed to get key versions for objectName:%s, error: %w", keyVaultKey.Name, err)
}
for _, key := range pager.Value {
if key.KID == nil || key.Attributes == nil || key.Attributes.Created == nil {
logger.GetLogger(ctx, logOpt).Warnf("key %s, version %s, found invalid key object, id, attributes or created time must not be nil", keyVaultKey.Name, key.KID.Version())
continue
}
versionInfo := VersionInfo{
Version: key.KID.Version(),
Created: *key.Attributes.Created,
Expand All @@ -303,9 +316,10 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.
}
}

// sort the version history by created time
// Pager results are not sorted by created time, so we sort them here
// in ascending order (oldest to newest)
sortVersionHistory(versionHistory)
sortedVersionHistory := GetSortedVersions(versionHistory)
sortedVersionHistory := GetVersions(versionHistory)

// if no versions found, log and continue
if len(sortedVersionHistory) == 0 {
Expand All @@ -330,6 +344,10 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, version, err)
}
keyBundle := keyResponse.KeyBundle
if keyBundle.Attributes == nil || keyBundle.Attributes.Enabled == nil {
logger.GetLogger(ctx, logOpt).Warnf("key %s, version %s, found invalid key bundle, attributes or attribute.enabled not be nil", keyVaultKey.Name, version)
continue
}
isEnabled := *keyBundle.Attributes.Enabled

if !isEnabled {
Expand Down Expand Up @@ -569,14 +587,14 @@ func isSecretDisabledError(err error) bool {
}

// sortVersionHistory sorts the version history by created time
func sortVersionHistory(versionHistory []VersionInfo) {
sort.Slice(versionHistory, func(i, j int) bool {
return versionHistory[i].Created.Before(versionHistory[j].Created)
func sortVersionHistory(sortedVersionHistory []VersionInfo) {
sort.Slice(sortedVersionHistory, func(i, j int) bool {
return sortedVersionHistory[i].Created.Before(sortedVersionHistory[j].Created)
})
}

// GetSortedVersions returns the sorted versions of the object
func GetSortedVersions(versionHistory []VersionInfo) []string {
// GetVersions returns the sorted versions of the object
func GetVersions(versionHistory []VersionInfo) []string {
sortedVersions := make([]string, 0, len(versionHistory))
for _, version := range versionHistory {
sortedVersions = append(sortedVersions, version.Version)
Expand Down
165 changes: 165 additions & 0 deletions pkg/keymanagementprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,104 @@ func TestGetCertificates(t *testing.T) {
mockCertificateKVClient *MockCertificateKVClient
expectedErr bool
}{
{
name: "versionHistoryLimit is 0",
versionHistoryLimit: 0,
mockCertificateKVClient: &MockCertificateKVClient{},
mockSecretKVClient: &MockSecretKVClient{
GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) {
return azsecrets.GetSecretResponse{
SecretBundle: azsecrets.SecretBundle{
ID: &secretID,
Kid: stringPtr("https://testkv.vault.azure.net/keys/key1"),
ContentType: stringPtr("application/x-pem-file"),
Attributes: &azsecrets.SecretAttributes{
Enabled: boolPtr(true),
},
},
}, nil
},
},
expectedErr: true,
},
{
name: "Certificate enabled with multiple versions nil attributes",
versionHistoryLimit: 2,
mockCertificateKVClient: &MockCertificateKVClient{
GetCertificateFunc: func(_ context.Context, _ string, _ string) (azcertificates.GetCertificateResponse, error) {
return azcertificates.GetCertificateResponse{
CertificateBundle: azcertificates.CertificateBundle{
ID: &certID,
KID: stringPtr("https://testkv.vault.azure.net/keys/key1"),
},
}, nil
},
NewListCertificateVersionsPagerFunc: func(_ string, _ *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse] {
pageCounter := 0
return runtime.NewPager(runtime.PagingHandler[azcertificates.ListCertificateVersionsResponse]{
More: func(resp azcertificates.ListCertificateVersionsResponse) bool {
return resp.CertificateListResult.NextLink != nil
},
Fetcher: func(_ context.Context, _ *azcertificates.ListCertificateVersionsResponse) (azcertificates.ListCertificateVersionsResponse, error) {
var resp azcertificates.ListCertificateVersionsResponse

if pageCounter == 0 {
resp = azcertificates.ListCertificateVersionsResponse{
CertificateListResult: azcertificates.CertificateListResult{
NextLink: stringPtr("https://testkv.vault.azure.net/certificates/cert1/versions?api-version=7.2"),
Value: []*azcertificates.CertificateItem{
{
ID: &certID,
Attributes: &azcertificates.CertificateAttributes{
Enabled: boolPtr(true),
Created: nil,
},
},
},
},
}
}

if pageCounter == 1 {
resp = azcertificates.ListCertificateVersionsResponse{
CertificateListResult: azcertificates.CertificateListResult{
NextLink: nil,
Value: []*azcertificates.CertificateItem{
{
ID: &certIDLatest,
Attributes: &azcertificates.CertificateAttributes{
Enabled: boolPtr(true),
Created: &certIDLatestCreated,
},
},
},
},
}
}

pageCounter++
return resp, nil
},
})
},
},
mockSecretKVClient: &MockSecretKVClient{
GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) {
return azsecrets.GetSecretResponse{
SecretBundle: azsecrets.SecretBundle{
ID: &secretID,
Kid: stringPtr("https://testkv.vault.azure.net/keys/key1"),
ContentType: stringPtr("application/x-pem-file"),
Attributes: &azsecrets.SecretAttributes{
Enabled: boolPtr(true),
},
Value: stringPtr("-----BEGIN CERTIFICATE-----\nMIIC8TCCAdmgAwIBAgIUaNrwbhs/I1ecqUYdzD2xuAVNdmowDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzdaFw0yNDA2MjAwMTIyMzdaMBkxFzAVBgNVBAMMDnJhdGlm\neS5kZWZhdWx0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtskG1BUt\n4Fw2lbm53KbwZb1hnLmWdwRotZyznhhk/yrUDcq3uF6klwpk/E2IKfUKIo6doHSk\nXaEZXR68UtXygvA4wdg7xZ6kKpXy0gu+RxGE6CGtDHTyDDzITu+NBjo21ZSsyGpQ\nJeIKftUCHdwdygKf0CdJx8A29GBRpHGCmJadmt7tTzOnYjmbuPVLeqJo/Ex9qXcG\nZbxoxnxr5NCocFeKx+EbLo+k/KjdFB2PKnhgzxAaMMMP6eXPr8l5AlzkC83EmPvN\ntveuaBbamdlFkD+53TZeZlxt3GIdq93Iw/UpbQ/pvhbrztMT+UVEkm15sShfX8Xn\nL2st5A4n0V+66QIDAQABoyAwHjAMBgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIH\ngDANBgkqhkiG9w0BAQsFAAOCAQEAGpOqozyfDSBjoTepsRroxxcZ4sq65gw45Bme\nm36BS6FG0WHIg3cMy6KIIBefTDSKrPkKNTtuF25AeGn9jM+26cnfDM78ZH0+Lnn7\n7hs0MA64WMPQaWs9/+89aM9NADV9vp2zdG4xMi6B7DruvKWyhJaNoRqK/qP6LdSQ\nw8M+21sAHvXgrRkQtJlVOzVhgwt36NOb1hzRlQiZB+nhv2Wbw7fbtAaADk3JAumf\nvM+YdPS1KfAFaYefm4yFd+9/C0KOkHico3LTbELO5hG0Mo/EYvtjM+Fljb42EweF\n3nAx1GSPe5Tn8p3h6RyJW5HIKozEKyfDuLS0ccB/nqT3oNjcTw==\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\nMIIDRTCCAi2gAwIBAgIUcC33VfaMhOnsl7avNTRVQozoVtUwDQYJKoZIhvcNAQEL\nBQAwKjEPMA0GA1UECgwGUmF0aWZ5MRcwFQYDVQQDDA5SYXRpZnkgUm9vdCBDQTAe\nFw0yMzA2MjEwMTIyMzZaFw0yMzA2MjIwMTIyMzZaMCoxDzANBgNVBAoMBlJhdGlm\neTEXMBUGA1UEAwwOUmF0aWZ5IFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB\nDwAwggEKAoIBAQDDFhDnyPrVDZaeRu6Tbg1a/iTwus+IuX+h8aKhKS1yHz4EF/Lz\nxCy7lNSQ9srGMMVumWuNom/ydIphff6PejZM1jFKPU6OQR/0JX5epcVIjbKa562T\nDguUxJ+h5V3EIyM4RqOWQ2g/xZo86x5TzyNJXiVdHHRvmDvUNwPpMeDjr/EHVAni\n5YQObxkJRiiZ7XOa5zz3YztVm8sSZAwPWroY1HIfvtP+KHpiNDIKSymmuJkH4SEr\nJn++iqN8na18a9DFBPTTrLPe3CxATGrMfosCMZ6LP3iFLLc/FaSpwcnugWdewsUK\nYs+sUY7jFWR7x7/1nyFWyRrQviM4f4TY+K7NAgMBAAGjYzBhMB0GA1UdDgQWBBQH\nYePW7QPP2p1utr3r6gqzEkKs+DAfBgNVHSMEGDAWgBQHYePW7QPP2p1utr3r6gqz\nEkKs+DAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwICBDANBgkqhkiG9w0B\nAQsFAAOCAQEAjKp4vx3bFaKVhAbQeTsDjWJgmXLK2vLgt74MiUwSF6t0wehlfszE\nIcJagGJsvs5wKFf91bnwiqwPjmpse/thPNBAxh1uEoh81tOklv0BN790vsVpq3t+\ncnUvWPiCZdRlAiGGFtRmKk3Keq4sM6UdiUki9s+wnxypHVb4wIpVxu5R271Lnp5I\n+rb2EQ48iblt4XZPczf/5QJdTgbItjBNbuO8WVPOqUIhCiFuAQziLtNUq3p81dHO\nQ2BPgmaitCpIUYHVYighLauBGCH8xOFzj4a4KbOxKdxyJTd0La/vRCKaUtJX67Lc\nfQYVR9HXQZ0YlmwPcmIG5v7wBfcW34NUvA==\n-----END CERTIFICATE-----\n"),
},
}, nil
},
},
expectedErr: false,
},
{
name: "Certificate enabled with multiple versions",
versionHistoryLimit: 3,
Expand Down Expand Up @@ -400,6 +498,7 @@ func TestGetCertificates(t *testing.T) {
},
{
name: "GetSecret error",
versionHistoryLimit: 1,
mockCertificateKVClient: &MockCertificateKVClient{},
mockSecretKVClient: &MockSecretKVClient{
GetSecretFunc: func(_ context.Context, _ string, _ string) (azsecrets.GetSecretResponse, error) {
Expand Down Expand Up @@ -572,6 +671,72 @@ func TestGetKeys(t *testing.T) {
},
expectedErr: false,
},
{
name: "Key enabled with multiple versions with nil attributes",
versionHistoryLimit: 2,
mockKeyKVClient: &MockKeyKVClient{
GetKeyFunc: func(_ context.Context, _ string, _ string) (azkeys.GetKeyResponse, error) {
return azkeys.GetKeyResponse{
KeyBundle: azkeys.KeyBundle{
Key: &azkeys.JSONWebKey{
KID: &keyID,
Kty: &keyTY,
N: []byte("n"),
E: []byte("e"),
},
Attributes: &azkeys.KeyAttributes{
Enabled: boolPtr(true),
},
},
}, nil
},
NewListKeyVersionsPagerFunc: func(_ string, _ *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
pageCounter := 0
return runtime.NewPager(runtime.PagingHandler[azkeys.ListKeyVersionsResponse]{
More: func(resp azkeys.ListKeyVersionsResponse) bool {
return resp.KeyListResult.NextLink != nil
},
Fetcher: func(_ context.Context, _ *azkeys.ListKeyVersionsResponse) (azkeys.ListKeyVersionsResponse, error) {
var resp azkeys.ListKeyVersionsResponse
if pageCounter == 0 {
resp = azkeys.ListKeyVersionsResponse{
KeyListResult: azkeys.KeyListResult{
NextLink: stringPtr("https://testkv.vault.azure.net/keys/key1/versions?api-version=7.2"),
Value: []*azkeys.KeyItem{
{
KID: &keyID,
Attributes: &azkeys.KeyAttributes{
Created: nil,
},
},
},
},
}
}

if pageCounter == 1 {
resp = azkeys.ListKeyVersionsResponse{
KeyListResult: azkeys.KeyListResult{
Value: []*azkeys.KeyItem{
{
KID: &keyIDLatest,
Attributes: &azkeys.KeyAttributes{
Created: &keyCreatedLatest,
},
},
},
},
}
}

pageCounter++
return resp, nil
},
})
},
},
expectedErr: false,
},
{
name: "No versions returned by pager",
mockKeyKVClient: &MockKeyKVClient{
Expand Down

0 comments on commit fc8a17c

Please sign in to comment.