Skip to content

Commit

Permalink
Fix bug with Vault CA provider where updating
Browse files Browse the repository at this point in the history
RootPKIPath but not IntermediatePKIPath would
not update leaf signing certs with the new root.
  • Loading branch information
Chris S. Kim committed Jul 12, 2023
1 parent f51a9d2 commit 89ceafc
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 118 deletions.
48 changes: 24 additions & 24 deletions agent/connect/ca/mock_Provider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion agent/connect/ca/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ type PrimaryProvider interface {
// Depending on the provider and its configuration, GenerateCAChain may return
// a single root certificate or a chain of certs. The provider should return an
// existing CA chain if one exists or generate a new one and return it.
GenerateCAChain() (CAChainResult, error)
GenerateCAChain() (string, error)

// SignIntermediate will validate the CSR to ensure the trust domain in the
// URI SAN matches the local one and that basic constraints for a CA
Expand Down
10 changes: 5 additions & 5 deletions agent/connect/ca/provider_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ func (a *AWSProvider) State() (map[string]string, error) {
}

// GenerateCAChain implements Provider
func (a *AWSProvider) GenerateCAChain() (CAChainResult, error) {
func (a *AWSProvider) GenerateCAChain() (string, error) {
if !a.isPrimary {
return CAChainResult{}, fmt.Errorf("provider is not the root certificate authority")
return "", fmt.Errorf("provider is not the root certificate authority")
}

if err := a.ensureCA(); err != nil {
return CAChainResult{}, err
return "", err
}

if a.rootPEM == "" {
return CAChainResult{}, fmt.Errorf("AWS CA provider not fully Initialized")
return "", fmt.Errorf("AWS CA provider not fully Initialized")
}
return CAChainResult{PEM: a.rootPEM}, nil
return a.rootPEM, nil
}

// ensureCA loads the CA resource to check it exists if configured by User or in
Expand Down
18 changes: 9 additions & 9 deletions agent/connect/ca/provider_consul.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,25 @@ func (c *ConsulProvider) State() (map[string]string, error) {
}

// GenerateCAChain initializes a new root certificate and private key if needed.
func (c *ConsulProvider) GenerateCAChain() (CAChainResult, error) {
func (c *ConsulProvider) GenerateCAChain() (string, error) {
providerState, err := c.getState()
if err != nil {
return CAChainResult{}, err
return "", err
}

if !c.isPrimary {
return CAChainResult{}, fmt.Errorf("provider is not the root certificate authority")
return "", fmt.Errorf("provider is not the root certificate authority")
}
if providerState.RootCert != "" {
return CAChainResult{PEM: providerState.RootCert}, nil
return providerState.RootCert, nil
}

// Generate a private key if needed
newState := *providerState
if c.config.PrivateKey == "" {
_, pk, err := connect.GeneratePrivateKeyWithConfig(c.config.PrivateKeyType, c.config.PrivateKeyBits)
if err != nil {
return CAChainResult{}, err
return "", err
}
newState.PrivateKey = pk
} else {
Expand All @@ -184,12 +184,12 @@ func (c *ConsulProvider) GenerateCAChain() (CAChainResult, error) {
if c.config.RootCert == "" {
nextSerial, err := c.incrementAndGetNextSerialNumber()
if err != nil {
return CAChainResult{}, fmt.Errorf("error computing next serial number: %v", err)
return "", fmt.Errorf("error computing next serial number: %v", err)
}

ca, err := c.generateCA(newState.PrivateKey, nextSerial, c.config.RootCertTTL)
if err != nil {
return CAChainResult{}, fmt.Errorf("error generating CA: %v", err)
return "", fmt.Errorf("error generating CA: %v", err)
}
newState.RootCert = ca
} else {
Expand All @@ -202,10 +202,10 @@ func (c *ConsulProvider) GenerateCAChain() (CAChainResult, error) {
ProviderState: &newState,
}
if _, err := c.Delegate.ApplyCARequest(args); err != nil {
return CAChainResult{}, err
return "", err
}

return CAChainResult{PEM: newState.RootCert}, nil
return newState.RootCert, nil
}

// GenerateIntermediateCSR creates a private key and generates a CSR
Expand Down
31 changes: 10 additions & 21 deletions agent/connect/ca/provider_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ func (v *VaultProvider) State() (map[string]string, error) {
}

// GenerateCAChain mounts and initializes a new root PKI backend if needed.
func (v *VaultProvider) GenerateCAChain() (CAChainResult, error) {
func (v *VaultProvider) GenerateCAChain() (string, error) {
if !v.isPrimary {
return CAChainResult{}, fmt.Errorf("provider is not the root certificate authority")
return "", fmt.Errorf("provider is not the root certificate authority")
}

// Set up the root PKI backend if necessary.
Expand All @@ -302,15 +302,15 @@ func (v *VaultProvider) GenerateCAChain() (CAChainResult, error) {
},
})
if err != nil {
return CAChainResult{}, fmt.Errorf("failed to mount root CA backend: %w", err)
return "", fmt.Errorf("failed to mount root CA backend: %w", err)
}

// We want to initialize afterwards
fallthrough
case ErrBackendNotInitialized:
uid, err := connect.CompactUID()
if err != nil {
return CAChainResult{}, err
return "", err
}
resp, err := v.writeNamespaced(v.config.RootPKINamespace, v.config.RootPKIPath+"root/generate/internal", map[string]interface{}{
"common_name": connect.CACN("vault", uid, v.clusterID, v.isPrimary),
Expand All @@ -319,23 +319,23 @@ func (v *VaultProvider) GenerateCAChain() (CAChainResult, error) {
"key_bits": v.config.PrivateKeyBits,
})
if err != nil {
return CAChainResult{}, fmt.Errorf("failed to initialize root CA: %w", err)
return "", fmt.Errorf("failed to initialize root CA: %w", err)
}
var ok bool
rootPEM, ok = resp.Data["certificate"].(string)
if !ok {
return CAChainResult{}, fmt.Errorf("unexpected response from Vault: %v", resp.Data["certificate"])
return "", fmt.Errorf("unexpected response from Vault: %v", resp.Data["certificate"])
}

default:
if err != nil {
return CAChainResult{}, fmt.Errorf("unexpected error while setting root PKI backend: %w", err)
return "", fmt.Errorf("unexpected error while setting root PKI backend: %w", err)
}
}

rootChain, err := v.getCAChain(v.config.RootPKINamespace, v.config.RootPKIPath)
if err != nil {
return CAChainResult{}, err
return "", err
}

// Workaround for a bug in the Vault PKI API.
Expand All @@ -344,18 +344,7 @@ func (v *VaultProvider) GenerateCAChain() (CAChainResult, error) {
rootChain = rootPEM
}

intermediate, err := v.ActiveLeafSigningCert()
if err != nil {
return CAChainResult{}, fmt.Errorf("error fetching active intermediate: %w", err)
}
if intermediate == "" {
intermediate, err = v.GenerateLeafSigningCert()
if err != nil {
return CAChainResult{}, fmt.Errorf("error generating intermediate: %w", err)
}
}

return CAChainResult{PEM: rootChain, IntermediatePEM: intermediate}, nil
return rootChain, nil
}

// GenerateIntermediateCSR creates a private key and generates a CSR
Expand Down Expand Up @@ -582,7 +571,7 @@ func (v *VaultProvider) getCAChain(namespace, path string) (string, error) {
return root, nil
}

// GenerateIntermediate mounts the configured intermediate PKI backend if
// GenerateLeafSigningCert mounts the configured intermediate PKI backend if
// necessary, then generates and signs a new CA CSR using the root PKI backend
// and updates the intermediate backend to use that new certificate.
func (v *VaultProvider) GenerateLeafSigningCert() (string, error) {
Expand Down
25 changes: 11 additions & 14 deletions agent/connect/ca/provider_vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
},
certFunc: func(provider *VaultProvider) (string, error) {
root, err := provider.GenerateCAChain()
return root.PEM, err
return root, err
},
backendPath: "pki-root/",
rootCaCreation: true,
Expand Down Expand Up @@ -497,9 +497,8 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
Service: "foo",
}

root, err := provider.GenerateCAChain()
rootPEM, err := provider.GenerateCAChain()
require.NoError(t, err)
rootPEM := root.PEM
assertCorrectKeyType(t, tc.KeyType, rootPEM)

intPEM, err := provider.ActiveLeafSigningCert()
Expand Down Expand Up @@ -600,9 +599,9 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
})

testutil.RunStep(t, "init", func(t *testing.T) {
root, err := provider1.GenerateCAChain()
rootPEM, err := provider1.GenerateCAChain()
require.NoError(t, err)
assertCorrectKeyType(t, tc.SigningKeyType, root.PEM)
assertCorrectKeyType(t, tc.SigningKeyType, rootPEM)

intPEM, err := provider1.ActiveLeafSigningCert()
require.NoError(t, err)
Expand All @@ -628,9 +627,9 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
})

testutil.RunStep(t, "swap", func(t *testing.T) {
root, err := provider2.GenerateCAChain()
rootPEM, err := provider2.GenerateCAChain()
require.NoError(t, err)
assertCorrectKeyType(t, tc.CSRKeyType, root.PEM)
assertCorrectKeyType(t, tc.CSRKeyType, rootPEM)

intPEM, err := provider2.ActiveLeafSigningCert()
require.NoError(t, err)
Expand Down Expand Up @@ -1147,14 +1146,14 @@ func TestVaultCAProvider_GenerateIntermediate(t *testing.T) {
// This test was created to ensure that our calls to Vault
// returns a new Intermediate certificate and further calls
// to ActiveLeafSigningCert return the same new cert.
new, err := provider.GenerateLeafSigningCert()
newLeaf, err := provider.GenerateLeafSigningCert()
require.NoError(t, err)

newActive, err := provider.ActiveLeafSigningCert()
require.NoError(t, err)

require.Equal(t, new, newActive)
require.NotEqual(t, orig, new)
require.Equal(t, newLeaf, newActive)
require.NotEqual(t, orig, newLeaf)
}

func TestVaultCAProvider_AutoTidyExpiredIssuers(t *testing.T) {
Expand Down Expand Up @@ -1240,9 +1239,8 @@ func TestVaultCAProvider_GenerateIntermediate_inSecondary(t *testing.T) {
// Sign the CSR with primaryProvider.
intermediatePEM, err := primaryProvider.SignIntermediate(csr)
require.NoError(t, err)
root, err := primaryProvider.GenerateCAChain()
rootPEM, err := primaryProvider.GenerateCAChain()
require.NoError(t, err)
rootPEM := root.PEM

// Give the new intermediate to provider to use.
require.NoError(t, provider.SetIntermediate(intermediatePEM, rootPEM, issuerID))
Expand All @@ -1261,9 +1259,8 @@ func TestVaultCAProvider_GenerateIntermediate_inSecondary(t *testing.T) {
// Sign the CSR with primaryProvider.
intermediatePEM, err := primaryProvider.SignIntermediate(csr)
require.NoError(t, err)
root, err := primaryProvider.GenerateCAChain()
rootPEM, err := primaryProvider.GenerateCAChain()
require.NoError(t, err)
rootPEM := root.PEM

// Give the new intermediate to provider to use.
require.NoError(t, provider.SetIntermediate(intermediatePEM, rootPEM, issuerID))
Expand Down
Loading

0 comments on commit 89ceafc

Please sign in to comment.