diff --git a/builtin/credential/approle/path_tidy_user_id.go b/builtin/credential/approle/path_tidy_user_id.go index 590cb7284d41..6137e0594fbe 100644 --- a/builtin/credential/approle/path_tidy_user_id.go +++ b/builtin/credential/approle/path_tidy_user_id.go @@ -26,141 +26,153 @@ func pathTidySecretID(b *backend) *framework.Path { } // tidySecretID is used to delete entries in the whitelist that are expired. -func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error { - grabbed := atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) - } else { - return fmt.Errorf("SecretID tidy operation already running") +func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - var result error + go func() { + defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) - tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error { - roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse) - if err != nil { - return err - } + var result error - // List all the accessors and add them all to a map - accessorHashes, err := s.List(ctx, accessorIDPrefixToUse) - if err != nil { - return err - } - accessorMap := make(map[string]bool, len(accessorHashes)) - for _, accessorHash := range accessorHashes { - accessorMap[accessorHash] = true - } + // Don't cancel when the original client request goes away + ctx = context.Background() - secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error { - lock := b.secretIDLock(secretIDHMAC) - lock.Lock() - defer lock.Unlock() + logger := b.Logger().Named("tidy") - entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC) - secretIDEntry, err := s.Get(ctx, entryIndex) + tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error { + roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse) if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err) + return err } - if secretIDEntry == nil { - result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC)) - return nil + // List all the accessors and add them all to a map + accessorHashes, err := s.List(ctx, accessorIDPrefixToUse) + if err != nil { + return err } - - if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 { - return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC) + accessorMap := make(map[string]bool, len(accessorHashes)) + for _, accessorHash := range accessorHashes { + accessorMap[accessorHash] = true } - var result secretIDStorageEntry - if err := secretIDEntry.DecodeJSON(&result); err != nil { - return err - } + secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error { + lock := b.secretIDLock(secretIDHMAC) + lock.Lock() + defer lock.Unlock() - // If a secret ID entry does not have a corresponding accessor - // entry, revoke the secret ID immediately - accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) - if err != nil { - return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err) - } - if accessorEntry == nil { - if err := s.Delete(ctx, entryIndex); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err) + entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC) + secretIDEntry, err := s.Get(ctx, entryIndex) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err) + } + + if secretIDEntry == nil { + result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC)) + return nil + } + + if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 { + return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC) + } + + var result secretIDStorageEntry + if err := secretIDEntry.DecodeJSON(&result); err != nil { + return err } - return nil - } - // ExpirationTime not being set indicates non-expiring SecretIDs - if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) { - // Clean up the accessor of the secret ID first - err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) + // If a secret ID entry does not have a corresponding accessor + // entry, revoke the secret ID immediately + accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) if err != nil { - return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err) + return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err) + } + if accessorEntry == nil { + if err := s.Delete(ctx, entryIndex); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err) + } + return nil + } + + // ExpirationTime not being set indicates non-expiring SecretIDs + if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) { + // Clean up the accessor of the secret ID first + err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) + if err != nil { + return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err) + } + + if err := s.Delete(ctx, entryIndex); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err) + } + + return nil } - if err := s.Delete(ctx, entryIndex); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err) + // At this point, the secret ID is not expired and is valid. Delete + // the corresponding accessor from the accessorMap. This will leave + // only the dangling accessors in the map which can then be cleaned + // up later. + salt, err := b.Salt(ctx) + if err != nil { + return err } + delete(accessorMap, salt.SaltID(result.SecretIDAccessor)) return nil } - // At this point, the secret ID is not expired and is valid. Delete - // the corresponding accessor from the accessorMap. This will leave - // only the dangling accessors in the map which can then be cleaned - // up later. - salt, err := b.Salt(ctx) - if err != nil { - return err + for _, roleNameHMAC := range roleNameHMACs { + secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) + if err != nil { + return err + } + for _, secretIDHMAC := range secretIDHMACs { + err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse) + if err != nil { + return err + } + } } - delete(accessorMap, salt.SaltID(result.SecretIDAccessor)) - return nil - } - - for _, roleNameHMAC := range roleNameHMACs { - secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) - if err != nil { - return err - } - for _, secretIDHMAC := range secretIDHMACs { - err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse) + // Accessor indexes were not getting cleaned up until 0.9.3. This is a fix + // to clean up the dangling accessor entries. + for accessorHash, _ := range accessorMap { + // Ideally, locking should be performed here. But for that, accessors + // are required in plaintext, which are not available. Hence performing + // a racy cleanup. + err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash) if err != nil { return err } } - } - // Accessor indexes were not getting cleaned up until 0.9.3. This is a fix - // to clean up the dangling accessor entries. - for accessorHash, _ := range accessorMap { - // Ideally, locking should be performed here. But for that, accessors - // are required in plaintext, which are not available. Hence performing - // a racy cleanup. - err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash) - if err != nil { - return err - } + return nil } - return nil - } - - err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix) - if err != nil { - return err - } - err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix) - if err != nil { - return err - } + err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix) + if err != nil { + logger.Error("error tidying global secret IDs", "error", err) + return + } + err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix) + if err != nil { + logger.Error("error tidying local secret IDs", "error", err) + return + } + }() - return result + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidySecretIDUpdate is used to delete the expired SecretID entries func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidySecretID(ctx, req.Storage) + return b.tidySecretID(ctx, req.Storage) } const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries." diff --git a/builtin/credential/approle/path_tidy_user_id_test.go b/builtin/credential/approle/path_tidy_user_id_test.go index b52b711356f4..f4e8b8e91da6 100644 --- a/builtin/credential/approle/path_tidy_user_id_test.go +++ b/builtin/credential/approle/path_tidy_user_id_test.go @@ -3,6 +3,7 @@ package approle import ( "context" "testing" + "time" "github.com/hashicorp/vault/logical" ) @@ -64,11 +65,14 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) { t.Fatalf("bad: len(accessorHashes); expect 3, got %d", len(accessorHashes)) } - err = b.tidySecretID(context.Background(), storage) + _, err = b.tidySecretID(context.Background(), storage) if err != nil { t.Fatal(err) } + // It runs async so we give it a bit of time to run + time.Sleep(10 * time.Second) + accessorHashes, err = storage.List(context.Background(), "accessor/") if err != nil { t.Fatal(err) diff --git a/builtin/credential/aws/path_tidy_identity_whitelist.go b/builtin/credential/aws/path_tidy_identity_whitelist.go index f1abe2308614..2ff035b8d9af 100644 --- a/builtin/credential/aws/path_tidy_identity_whitelist.go +++ b/builtin/credential/aws/path_tidy_identity_whitelist.go @@ -33,53 +33,72 @@ expiration, before it is removed from the backend storage.`, } // tidyWhitelistIdentity is used to delete entries in the whitelist that are expired. -func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0) - } else { - return fmt.Errorf("identity whitelist tidy operation already running") +func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - bufferDuration := time.Duration(safety_buffer) * time.Second + go func() { + defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0) - identities, err := s.List(ctx, "whitelist/identity/") - if err != nil { - return err - } + // Don't cancel when the original client request goes away + ctx = context.Background() - for _, instanceID := range identities { - identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID) - if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err) - } + logger := b.Logger().Named("wltidy") - if identityEntry == nil { - return fmt.Errorf("identity entry for instanceID %q is nil", instanceID) - } + bufferDuration := time.Duration(safety_buffer) * time.Second - if identityEntry.Value == nil || len(identityEntry.Value) == 0 { - return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID) - } + doTidy := func() error { + identities, err := s.List(ctx, "whitelist/identity/") + if err != nil { + return err + } + + for _, instanceID := range identities { + identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err) + } + + if identityEntry == nil { + return fmt.Errorf("identity entry for instanceID %q is nil", instanceID) + } + + if identityEntry.Value == nil || len(identityEntry.Value) == 0 { + return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID) + } + + var result whitelistIdentity + if err := identityEntry.DecodeJSON(&result); err != nil { + return err + } + + if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { + if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err) + } + } + } - var result whitelistIdentity - if err := identityEntry.DecodeJSON(&result); err != nil { - return err + return nil } - if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err) - } + if err := doTidy(); err != nil { + logger.Error("error running whitelist tidy", "error", err) + return } - } + }() - return nil + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired. func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int)) + return b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int)) } const pathTidyIdentityWhitelistSyn = ` diff --git a/builtin/credential/aws/path_tidy_roletag_blacklist.go b/builtin/credential/aws/path_tidy_roletag_blacklist.go index a29837110d2b..4eaafc22df4b 100644 --- a/builtin/credential/aws/path_tidy_roletag_blacklist.go +++ b/builtin/credential/aws/path_tidy_roletag_blacklist.go @@ -33,52 +33,72 @@ expiration, before it is removed from the backend storage.`, } // tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist. -func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0) - } else { - return fmt.Errorf("roletag blacklist tidy operation already running") +func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - bufferDuration := time.Duration(safety_buffer) * time.Second - tags, err := s.List(ctx, "blacklist/roletag/") - if err != nil { - return err - } + go func() { + defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0) - for _, tag := range tags { - tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag) - if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err) - } + // Don't cancel when the original client request goes away + ctx = context.Background() - if tagEntry == nil { - return fmt.Errorf("tag entry for tag %q is nil", tag) - } + logger := b.Logger().Named("bltidy") - if tagEntry.Value == nil || len(tagEntry.Value) == 0 { - return fmt.Errorf("found entry for tag %q but actual tag is empty", tag) - } + bufferDuration := time.Duration(safety_buffer) * time.Second - var result roleTagBlacklistEntry - if err := tagEntry.DecodeJSON(&result); err != nil { - return err - } + doTidy := func() error { + tags, err := s.List(ctx, "blacklist/roletag/") + if err != nil { + return err + } - if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err) + for _, tag := range tags { + tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err) + } + + if tagEntry == nil { + return fmt.Errorf("tag entry for tag %q is nil", tag) + } + + if tagEntry.Value == nil || len(tagEntry.Value) == 0 { + return fmt.Errorf("found entry for tag %q but actual tag is empty", tag) + } + + var result roleTagBlacklistEntry + if err := tagEntry.DecodeJSON(&result); err != nil { + return err + } + + if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { + if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err) + } + } } + + return nil } - } - return nil + if err := doTidy(); err != nil { + logger.Error("error running blacklist tidy", "error", err) + return + } + }() + + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist. func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int)) + return b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int)) } const pathTidyRoletagBlacklistSyn = ` diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index de8a3517c9f3..60e943acb8a3 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -12,7 +12,7 @@ import ( // Factory creates a new backend implementing the logical.Backend interface func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() + b := Backend(conf) if err := b.Setup(ctx, conf); err != nil { return nil, err } @@ -20,7 +20,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, } // Backend returns a new Backend framework struct -func Backend() *backend { +func Backend(conf *logical.BackendConfig) *backend { var b backend b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), @@ -85,6 +85,8 @@ func Backend() *backend { } b.crlLifetime = time.Hour * 72 + b.tidyCASGuard = new(uint32) + b.storage = conf.StorageView return &b } @@ -92,8 +94,10 @@ func Backend() *backend { type backend struct { *framework.Backend + storage logical.Storage crlLifetime time.Duration revokeStorageLock sync.RWMutex + tidyCASGuard *uint32 } const backendHelp = ` diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 3b80d5972545..9be7b55dd3a4 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -134,64 +134,6 @@ func TestPKI_RequireCN(t *testing.T) { } } -// Performs basic tests on CA functionality -// Uses the RSA CA key -func TestBackend_RSAKey(t *testing.T) { - initTest.Do(setCerts) - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - b, err := Factory(context.Background(), &logical.BackendConfig{ - Logger: nil, - System: &logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - }, - }) - if err != nil { - t.Fatalf("Unable to create backend: %s", err) - } - - testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, - } - - intdata := map[string]interface{}{} - reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...) - - logicaltest.Test(t, testCase) -} - -// Performs basic tests on CA functionality -// Uses the EC CA key -func TestBackend_ECKey(t *testing.T) { - initTest.Do(setCerts) - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - b, err := Factory(context.Background(), &logical.BackendConfig{ - Logger: nil, - System: &logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - }, - }) - if err != nil { - t.Fatalf("Unable to create backend: %s", err) - } - - testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, - } - - intdata := map[string]interface{}{} - reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...) - - logicaltest.Test(t, testCase) -} - func TestBackend_CSRValues(t *testing.T) { initTest.Do(setCerts) defaultLeaseTTLVal := time.Hour * 24 @@ -805,685 +747,6 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s return ret } -// Generates steps to test out CA configuration -- certificates + CRL expiry, -// and ensure that the certificates are readable after storing them -func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { - setSerialUnderTest := func(req *logical.Request) error { - req.Path = serialUnderTest - return nil - } - - ret := []logicaltest.TestStep{ - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": strings.Join([]string{caKey, caCert}, "\n"), - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/crl", - Data: map[string]interface{}{ - "expiry": "16h", - }, - }, - - // Ensure we can fetch it back via unauthenticated means, in various formats - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "cert/ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - if resp.Data["certificate"].(string) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca/pem", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - if !reflect.DeepEqual(rawBytes, []byte(caCert)) { - return fmt.Errorf("CA certificate:\n%#v\ndoes not match original:\n%#v\n", rawBytes, []byte(caCert)) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: rawBytes, - }))) - if pemBytes != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "config/crl", - Check: func(resp *logical.Response) error { - if resp.Data["expiry"].(string) != "16h" { - return fmt.Errorf("CRL lifetimes do not match (got %s)", resp.Data["expiry"].(string)) - } - return nil - }, - }, - - // Ensure that both parts of the PEM bundle are required - // Here, just the cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": caCert, - }, - ErrorOk: true, - }, - - // Here, just the key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": caKey, - }, - ErrorOk: true, - }, - - // Ensure we can fetch it back via unauthenticated means, in various formats - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "cert/ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - if resp.Data["certificate"].(string) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca/pem", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - if string(rawBytes) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", string(rawBytes), caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: rawBytes, - }))) - if pemBytes != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - // Test a bunch of generation stuff - logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: "root", - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/generate/exported", - Data: map[string]interface{}{ - "common_name": "Root Cert", - "ttl": "180h", - }, - Check: func(resp *logical.Response) error { - intdata["root"] = resp.Data["certificate"].(string) - intdata["rootkey"] = resp.Data["private_key"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["root"].(string), intdata["rootkey"].(string)}, "\n") - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/generate/exported", - Data: map[string]interface{}{ - "common_name": "intermediate.cert.com", - }, - Check: func(resp *logical.Response) error { - intdata["intermediatecsr"] = resp.Data["csr"].(string) - intdata["intermediatekey"] = resp.Data["private_key"].(string) - return nil - }, - }, - - // Re-load the root key in so we can sign it - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - delete(reqdata, "ttl") - reqdata["csr"] = intdata["intermediatecsr"].(string) - reqdata["common_name"] = "intermediate.cert.com" - reqdata["ttl"] = "10s" - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/sign-intermediate", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "csr") - delete(reqdata, "common_name") - delete(reqdata, "ttl") - intdata["intermediatecert"] = resp.Data["certificate"].(string) - reqdata["serial_number"] = resp.Data["serial_number"].(string) - reqdata["rsa_int_serial_number"] = resp.Data["serial_number"].(string) - reqdata["certificate"] = resp.Data["certificate"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n") - return nil - }, - }, - - // First load in this way to populate the private key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - return nil - }, - }, - - // Now test setting the intermediate, signed CA cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/set-signed", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "certificate") - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // We expect to find a zero revocation time - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) != 0 { - return fmt.Errorf("expected a zero revocation time") - } - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "revoke", - Data: reqdata, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 1 { - t.Fatalf("length of revoked list not 1; %d", len(revokedList)) - } - revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":") - if revokedString != reqdata["serial_number"].(string) { - t.Fatalf("got serial %s, expecting %s", revokedString, reqdata["serial_number"].(string)) - } - delete(reqdata, "serial_number") - return nil - }, - }, - - // Do it all again, with EC keys and DER format - logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: "root", - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/generate/exported", - Data: map[string]interface{}{ - "common_name": "Root Cert", - "ttl": "180h", - "key_type": "ec", - "key_bits": 384, - "format": "der", - }, - Check: func(resp *logical.Response) error { - certBytes, _ := base64.StdEncoding.DecodeString(resp.Data["certificate"].(string)) - certPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }))) - keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string)) - keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: keyBytes, - }))) - intdata["root"] = certPem - intdata["rootkey"] = keyPem - reqdata["pem_bundle"] = strings.Join([]string{certPem, keyPem}, "\n") - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/generate/exported", - Data: map[string]interface{}{ - "format": "der", - "key_type": "ec", - "key_bits": 384, - "common_name": "intermediate.cert.com", - }, - Check: func(resp *logical.Response) error { - csrBytes, _ := base64.StdEncoding.DecodeString(resp.Data["csr"].(string)) - csrPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }))) - keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string)) - keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: keyBytes, - }))) - intdata["intermediatecsr"] = csrPem - intdata["intermediatekey"] = keyPem - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - delete(reqdata, "ttl") - reqdata["csr"] = intdata["intermediatecsr"].(string) - reqdata["common_name"] = "intermediate.cert.com" - reqdata["ttl"] = "10s" - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/sign-intermediate", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "csr") - delete(reqdata, "common_name") - delete(reqdata, "ttl") - intdata["intermediatecert"] = resp.Data["certificate"].(string) - reqdata["serial_number"] = resp.Data["serial_number"].(string) - reqdata["ec_int_serial_number"] = resp.Data["serial_number"].(string) - reqdata["certificate"] = resp.Data["certificate"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n") - return nil - }, - }, - - // First load in this way to populate the private key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - return nil - }, - }, - - // Now test setting the intermediate, signed CA cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/set-signed", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "certificate") - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - // We expect to find a zero revocation time - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) != 0 { - return fmt.Errorf("expected a zero revocation time") - } - - return nil - }, - }, - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "revoke", - Data: reqdata, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 2 { - t.Fatalf("length of revoked list not 2; %d", len(revokedList)) - } - found := false - for _, revEntry := range revokedList { - revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":") - if revokedString == reqdata["serial_number"].(string) { - found = true - } - } - if !found { - t.Fatalf("did not find %s in CRL", reqdata["serial_number"].(string)) - } - delete(reqdata, "serial_number") - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Make sure both serial numbers we expect to find are found - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) == 0 { - return fmt.Errorf("expected a non-zero revocation time") - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) == 0 { - return fmt.Errorf("expected a non-zero revocation time") - } - - // Give time for the certificates to pass the safety buffer - t.Logf("Sleeping for 15 seconds to allow safety buffer time to pass before testing tidying") - time.Sleep(15 * time.Second) - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // This shouldn't do anything since the safety buffer is too long - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "3h", - "tidy_cert_store": true, - "tidy_revocation_list": true, - }, - }, - - // We still expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Both should appear in the CRL - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 2 { - t.Fatalf("length of revoked list not 2; %d", len(revokedList)) - } - foundRsa := false - foundEc := false - for _, revEntry := range revokedList { - revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":") - if revokedString == reqdata["rsa_int_serial_number"].(string) { - foundRsa = true - } - if revokedString == reqdata["ec_int_serial_number"].(string) { - foundEc = true - } - } - if !foundRsa || !foundEc { - t.Fatalf("did not find an expected entry in CRL") - } - - return nil - }, - }, - - // This shouldn't do anything since the boolean values default to false - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "1s", - }, - }, - - // We still expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // This should remove the values since the safety buffer is short - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "1s", - "tidy_cert_store": true, - "tidy_revocation_list": true, - }, - }, - - // We do *not* expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil { - return fmt.Errorf("expected no response") - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil { - return fmt.Errorf("expected no response") - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Both should be gone from the CRL - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 0 { - t.Fatalf("length of revoked list not 0; %d", len(revokedList)) - } - - return nil - }, - }, - } - - return ret -} - // Generates steps to test out various role permutations func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { roleVals := roleEntry{ @@ -2140,7 +1403,7 @@ func TestBackend_PathFetchCertList(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) @@ -2267,7 +1530,7 @@ func TestBackend_SignVerbatim(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) @@ -2823,7 +2086,7 @@ func TestBackend_SignSelfIssued(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) diff --git a/builtin/logical/pki/ca_test.go b/builtin/logical/pki/ca_test.go new file mode 100644 index 000000000000..82bac5af0c03 --- /dev/null +++ b/builtin/logical/pki/ca_test.go @@ -0,0 +1,570 @@ +package pki + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "math/big" + mathrand "math/rand" + "strings" + "testing" + "time" + + "github.com/go-test/deep" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/certutil" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +func TestBackend_CA_Steps(t *testing.T) { + var b *backend + + factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + be, err := Factory(ctx, conf) + if err == nil { + b = be.(*backend) + } + return be, err + } + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "pki": factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + // Set RSA/EC CA certificates + var rsaCAKey, rsaCACert, ecCAKey, ecCACert string + { + cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + marshaledKey, err := x509.MarshalECPrivateKey(cak) + if err != nil { + panic(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) + if err != nil { + panic(err) + } + subjKeyID, err := certutil.GetSubjKeyID(cak) + if err != nil { + panic(err) + } + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "root.localhost", + }, + SubjectKeyId: subjKeyID, + DNSNames: []string{"root.localhost"}, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak) + if err != nil { + panic(err) + } + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) + + rak, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + marshaledKey = x509.MarshalPKCS1PrivateKey(rak) + keyPEMBlock = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: marshaledKey, + } + rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) + if err != nil { + panic(err) + } + subjKeyID, err = certutil.GetSubjKeyID(rak) + if err != nil { + panic(err) + } + caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak) + if err != nil { + panic(err) + } + caCertPEMBlock = &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) + } + + // Setup backends + var rsaRoot, rsaInt, ecRoot, ecInt *backend + { + if err := client.Sys().Mount("rsaroot", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + rsaRoot = b + + if err := client.Sys().Mount("rsaint", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + rsaInt = b + + if err := client.Sys().Mount("ecroot", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + ecRoot = b + + if err := client.Sys().Mount("ecint", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + ecInt = b + } + + t.Run("teststeps", func(t *testing.T) { + t.Run("rsa", func(t *testing.T) { + t.Parallel() + subClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + subClient.SetToken(client.Token()) + runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey) + }) + t.Run("ec", func(t *testing.T) { + t.Parallel() + subClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + subClient.SetToken(client.Token()) + runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey) + }) + }) +} + +func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) { + // Load CA cert/key in and ensure we can fetch it back in various formats, + // unauthenticated + { + // Attempt import but only provide one the cert + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": caCert, + }) + if err == nil { + t.Fatal("expected error") + } + } + + // Same but with only the key + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": caKey, + }) + if err == nil { + t.Fatal("expected error") + } + } + + // Import CA bundle + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": strings.Join([]string{caKey, caCert}, "\n"), + }) + if err != nil { + t.Fatal(err) + } + } + + prevToken := client.Token() + client.SetToken("") + + // cert/ca path + { + resp, err := client.Logical().Read(rootName + "cert/ca") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil { + t.Fatal(diff) + } + } + // ca/pem path (raw string) + { + req := &logical.Request{ + Path: "ca/pem", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil { + t.Fatal(diff) + } + if resp.Data["http_content_type"].(string) != "application/pkix-cert" { + t.Fatal("wrong content type") + } + } + + // ca (raw DER bytes) + { + req := &logical.Request{ + Path: "ca", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + rawBytes := resp.Data["http_raw_body"].([]byte) + pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: rawBytes, + }))) + if diff := deep.Equal(pemBytes, caCert); diff != nil { + t.Fatal(diff) + } + if resp.Data["http_content_type"].(string) != "application/pkix-cert" { + t.Fatal("wrong content type") + } + } + + client.SetToken(prevToken) + } + + // Configure an expiry on the CRL and verify what comes back + { + // Set CRL config + { + _, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{ + "expiry": "16h", + }) + if err != nil { + t.Fatal(err) + } + } + + // Verify it + { + resp, err := client.Logical().Read(rootName + "config/crl") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["expiry"].(string) != "16h" { + t.Fatal("expected a 16 hour expiry") + } + } + } + + // Test generating a root, an intermediate, signing it, setting signed, and + // revoking it + + // We'll need this later + var intSerialNumber string + { + // First, delete the existing CA info + { + _, err := client.Logical().Delete(rootName + "root") + if err != nil { + t.Fatal(err) + } + } + + var rootPEM, rootKey, rootPEMBundle string + // Test exported root generation + { + resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{ + "common_name": "Root Cert", + "ttl": "180h", + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + rootPEM = resp.Data["certificate"].(string) + rootKey = resp.Data["private_key"].(string) + rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n") + // This is really here to keep the use checker happy + if rootPEMBundle == "" { + t.Fatal("bad root pem bundle") + } + } + + var intPEM, intCSR, intKey string + // Test exported intermediate CSR generation + { + resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{ + "common_name": "intermediate.cert.com", + "ttl": "180h", + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + intCSR = resp.Data["csr"].(string) + intKey = resp.Data["private_key"].(string) + // This is really here to keep the use checker happy + if intCSR == "" || intKey == "" { + t.Fatal("int csr or key empty") + } + } + + // Test signing + { + resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{ + "common_name": "intermediate.cert.com", + "ttl": "10s", + "csr": intCSR, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + intPEM = resp.Data["certificate"].(string) + intSerialNumber = resp.Data["serial_number"].(string) + } + + // Test setting signed + { + resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{ + "certificate": intPEM, + }) + if err != nil { + t.Fatal(err) + } + if resp != nil { + t.Fatal("expected nil response") + } + } + + // Verify we can find it via the root + { + resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["revocation_time"].(json.Number).String() != "0" { + t.Fatal("expected a zero revocation time") + } + } + + // Revoke the intermediate + { + resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{ + "serial_number": intSerialNumber, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + } + } + + verifyRevocation := func(t *testing.T, serial string, shouldFind bool) { + // Verify it is now revoked + { + resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) + if err != nil { + t.Fatal(err) + } + switch shouldFind { + case true: + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["revocation_time"].(json.Number).String() == "0" { + t.Fatal("expected a non-zero revocation time") + } + default: + if resp != nil { + t.Fatalf("expected nil response, got %#v", *resp) + } + } + } + + // Fetch the CRL and make sure it shows up + { + req := &logical.Request{ + Path: "crl", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + crlBytes := resp.Data["http_raw_body"].([]byte) + certList, err := x509.ParseCRL(crlBytes) + if err != nil { + t.Fatal(err) + } + switch shouldFind { + case true: + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 1 { + t.Fatalf("bad length of revoked list: %d", len(revokedList)) + } + revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":") + if revokedString != intSerialNumber { + t.Fatalf("bad revoked serial: %s", revokedString) + } + default: + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 0 { + t.Fatalf("bad length of revoked list: %d", len(revokedList)) + } + } + } + } + + // Validate current state of revoked certificates + verifyRevocation(t, intSerialNumber, true) + + // Give time for the safety buffer to pass before tidying + time.Sleep(10 * time.Second) + + // Test tidying + { + // Run with a high safety buffer, nothing should happen + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "3h", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, true) + } + + // Run with both values set false, nothing should happen + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "1s", + "tidy_cert_store": false, + "tidy_revocation_list": false, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, true) + } + + // Run with a short safety buffer and both set to true, both should be cleared + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "1s", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, false) + } + } +} diff --git a/builtin/logical/pki/path_roles_test.go b/builtin/logical/pki/path_roles_test.go index f3927781aa23..6f915eeca742 100644 --- a/builtin/logical/pki/path_roles_test.go +++ b/builtin/logical/pki/path_roles_test.go @@ -15,7 +15,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { config.StorageView = &logical.InmemStorage{} var err error - b := Backend() + b := Backend(config) err = b.Setup(context.Background(), config) if err != nil { t.Fatal(err) diff --git a/builtin/logical/pki/path_tidy.go b/builtin/logical/pki/path_tidy.go index 9b0a86df9cf3..4d9ea992db00 100644 --- a/builtin/logical/pki/path_tidy.go +++ b/builtin/logical/pki/path_tidy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "fmt" + "sync/atomic" "time" "github.com/hashicorp/errwrap" @@ -59,116 +60,134 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr bufferDuration := time.Duration(safetyBuffer) * time.Second - var resp *logical.Response + if !atomic.CompareAndSwapUint32(b.tidyCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil + } - if tidyCertStore { - serials, err := req.Storage.List(ctx, "certs/") - if err != nil { - return nil, errwrap.Wrapf("error fetching list of certs: {{err}}", err) - } + // Tests using framework will screw up the storage so make a locally + // scoped req to hold a reference + req = &logical.Request{ + Storage: req.Storage, + } - for _, serial := range serials { - certEntry, err := req.Storage.Get(ctx, "certs/"+serial) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err) - } + go func() { + defer atomic.StoreUint32(b.tidyCASGuard, 0) - if certEntry == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err) - } - } + // Don't cancel when the original client request goes away + ctx = context.Background() - if certEntry.Value == nil || len(certEntry.Value) == 0 { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err) - } - } + logger := b.Logger().Named("tidy") - cert, err := x509.ParseCertificate(certEntry.Value) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err) - } + doTidy := func() error { + if tidyCertStore { + serials, err := req.Storage.List(ctx, "certs/") + if err != nil { + return errwrap.Wrapf("error fetching list of certs: {{err}}", err) + } - if time.Now().After(cert.NotAfter.Add(bufferDuration)) { - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err) + for _, serial := range serials { + certEntry, err := req.Storage.Get(ctx, "certs/"+serial) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err) + } + + if certEntry == nil { + logger.Warn("certificate entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err) + } + } + + if certEntry.Value == nil || len(certEntry.Value) == 0 { + logger.Warn("certificate entry has no value; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err) + } + } + + cert, err := x509.ParseCertificate(certEntry.Value) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err) + } + + if time.Now().After(cert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err) + } + } } } - } - } - if tidyRevocationList { - b.revokeStorageLock.Lock() - defer b.revokeStorageLock.Unlock() + if tidyRevocationList { + b.revokeStorageLock.Lock() + defer b.revokeStorageLock.Unlock() - tidiedRevoked := false + tidiedRevoked := false - revokedSerials, err := req.Storage.List(ctx, "revoked/") - if err != nil { - return nil, errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err) - } - - var revInfo revocationInfo - for _, serial := range revokedSerials { - revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err) - } - - if revokedEntry == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err) + revokedSerials, err := req.Storage.List(ctx, "revoked/") + if err != nil { + return errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err) } - } - if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s has nil value; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err) + var revInfo revocationInfo + for _, serial := range revokedSerials { + revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err) + } + + if revokedEntry == nil { + logger.Warn("revoked entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err) + } + } + + if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { + logger.Warn("revoked entry has nil value; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err) + } + } + + err = revokedEntry.DecodeJSON(&revInfo) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err) + } + + revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err) + } + + if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err) + } + tidiedRevoked = true + } } - } - err = revokedEntry.DecodeJSON(&revInfo) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err) - } - - revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err) - } - - if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) { - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err) + if tidiedRevoked { + if err := buildCRL(ctx, b, req); err != nil { + return err + } } - tidiedRevoked = true } + + return nil } - if tidiedRevoked { - if err := buildCRL(ctx, b, req); err != nil { - return nil, err - } + if err := doTidy(); err != nil { + logger.Error("error running tidy", "error", err) + return } - } + }() + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") return resp, nil } diff --git a/command/server.go b/command/server.go index f3e2fbd22aba..98001f3e9cb6 100644 --- a/command/server.go +++ b/command/server.go @@ -935,7 +935,10 @@ CLUSTER_SYNTHESIS_COMPLETE: } server := &http.Server{ - Handler: handler, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, } go server.Serve(ln.Listener) } diff --git a/helper/proxyutil/proxyutil.go b/helper/proxyutil/proxyutil.go index 875e74831c94..f18ec1d14905 100644 --- a/helper/proxyutil/proxyutil.go +++ b/helper/proxyutil/proxyutil.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "sync" + "time" proxyproto "github.com/armon/go-proxyproto" "github.com/hashicorp/errwrap" @@ -41,12 +42,14 @@ func WrapInProxyProto(listener net.Listener, config *ProxyProtoConfig) (net.List switch config.Behavior { case "use_always": newLn = &proxyproto.Listener{ - Listener: listener, + Listener: listener, + ProxyHeaderTimeout: 10 * time.Second, } case "allow_authorized", "deny_unauthorized": newLn = &proxyproto.Listener{ - Listener: listener, + Listener: listener, + ProxyHeaderTimeout: 10 * time.Second, SourceCheck: func(addr net.Addr) (bool, error) { config.RLock() defer config.RUnlock() diff --git a/physical/dynamodb/dynamodb_test.go b/physical/dynamodb/dynamodb_test.go index 70e69802d553..a831a29150fd 100644 --- a/physical/dynamodb/dynamodb_test.go +++ b/physical/dynamodb/dynamodb_test.go @@ -6,10 +6,10 @@ import ( "math/rand" "net/http" "os" - "reflect" "testing" "time" + "github.com/go-test/deep" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/logging" "github.com/hashicorp/vault/physical" @@ -106,8 +106,8 @@ func TestDynamoDBBackend(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if !reflect.DeepEqual(inputEntry, entry) { - t.Fatalf("exp: %#v, act: %#v", inputEntry, entry) + if diff := deep.Equal(inputEntry, entry); diff != nil { + t.Fatal(diff) } }) } @@ -285,7 +285,7 @@ func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress stri t.Fatalf("Failed to connect to docker: %s", err) } - resource, err := pool.Run("deangiberson/aws-dynamodb-local", "latest", []string{}) + resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{}) if err != nil { t.Fatalf("Could not start local DynamoDB: %s", err) } diff --git a/vault/expiration.go b/vault/expiration.go index 5e42c3fb3f2f..7924a401704a 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -191,15 +191,17 @@ func (m *ExpirationManager) Tidy() error { var tidyErrors *multierror.Error + logger := m.logger.Named("tidy") + if !atomic.CompareAndSwapInt32(m.tidyLock, 0, 1) { - m.logger.Warn("tidy operation on leases is already in progress") - return fmt.Errorf("tidy operation on leases is already in progress") + logger.Warn("tidy operation on leases is already in progress") + return nil } defer atomic.CompareAndSwapInt32(m.tidyLock, 1, 0) - m.logger.Info("beginning tidy operation on leases") - defer m.logger.Info("finished tidy operation on leases") + logger.Info("beginning tidy operation on leases") + defer logger.Info("finished tidy operation on leases") // Create a cache to keep track of looked up tokens tokenCache := make(map[string]bool) @@ -208,7 +210,7 @@ func (m *ExpirationManager) Tidy() error { tidyFunc := func(leaseID string) { countLease++ if countLease%500 == 0 { - m.logger.Info("tidying leases", "progress", countLease) + logger.Info("tidying leases", "progress", countLease) } le, err := m.loadEntry(leaseID) @@ -225,7 +227,7 @@ func (m *ExpirationManager) Tidy() error { var isValid, ok bool revokeLease := false if le.ClientToken == "" { - m.logger.Debug("revoking lease which has an empty token", "lease_id", leaseID) + logger.Debug("revoking lease which has an empty token", "lease_id", leaseID) revokeLease = true deletedCountEmptyToken++ goto REVOKE_CHECK @@ -249,7 +251,7 @@ func (m *ExpirationManager) Tidy() error { } if te == nil { - m.logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID) + logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID) revokeLease = true deletedCountInvalidToken++ tokenCache[le.ClientToken] = false @@ -262,7 +264,7 @@ func (m *ExpirationManager) Tidy() error { return } - m.logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID) + logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID) revokeLease = true deletedCountInvalidToken++ goto REVOKE_CHECK @@ -285,10 +287,10 @@ func (m *ExpirationManager) Tidy() error { return err } - m.logger.Info("number of leases scanned", "count", countLease) - m.logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken) - m.logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken) - m.logger.Info("number of leases successfully revoked", "count", revokedCount) + logger.Info("number of leases scanned", "count", countLease) + logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken) + logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken) + logger.Info("number of leases successfully revoked", "count", revokedCount) return tidyErrors.ErrorOrNil() } diff --git a/vault/expiration_test.go b/vault/expiration_test.go index e0b93ec43bc8..b4997e6a08e9 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -1,6 +1,7 @@ package vault import ( + "bytes" "context" "fmt" "reflect" @@ -38,6 +39,14 @@ func TestExpiration_Tidy(t *testing.T) { var err error exp := mockExpiration(t) + + // We use this later for tidy testing where we need to check the output + logOut := new(bytes.Buffer) + logger := log.New(&log.LoggerOptions{ + Output: logOut, + }) + exp.logger = logger + if err := exp.Restore(nil); err != nil { t.Fatal(err) } @@ -212,9 +221,11 @@ func TestExpiration_Tidy(t *testing.T) { } } - if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") && - !(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") { - t.Fatalf("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2) + if err1 != nil || err2 != nil { + t.Fatalf("got an error: err1: %v; err2: %v", err1, err2) + } + if !strings.Contains(logOut.String(), "tidy operation on leases is already in progress") { + t.Fatalf("expected to see a warning saying operation in progress, output is %s", logOut.String()) } root, err := exp.tokenStore.rootToken(context.Background()) diff --git a/vault/logical_system.go b/vault/logical_system.go index 821d534ef4c5..bd909c1fd07d 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1182,12 +1182,17 @@ func (b *SystemBackend) handleCORSDelete(ctx context.Context, req *logical.Reque } func (b *SystemBackend) handleTidyLeases(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - err := b.Core.expiration.Tidy() - if err != nil { - b.Backend.Logger().Error("failed to tidy leases", "error", err) - return handleErrorNoReadOnlyForward(err) - } - return nil, err + go func() { + err := b.Core.expiration.Tidy() + if err != nil { + b.Backend.Logger().Error("failed to tidy leases", "error", err) + return + } + }() + + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } func (b *SystemBackend) invalidate(ctx context.Context, key string) { diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index c4c0e73cd977..3cbd22f33a63 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -76,7 +76,11 @@ func (c *Core) startForwarding(ctx context.Context) error { // duties. Doing it this way instead of listening via the server and gRPC // allows us to re-use the same port via ALPN. We can just tell the server // to serve a given conn and which handler to use. - fws := &http2.Server{} + fws := &http2.Server{ + // Our forwarding connections heartbeat regularly so anything else we + // want to go away/get cleaned up pretty rapidly + IdleTimeout: 5 * HeartbeatInterval, + } // Shutdown coordination logic shutdown := new(uint32) @@ -147,6 +151,20 @@ func (c *Core) startForwarding(ctx context.Context) error { // Type assert to TLS connection and handshake to populate the // connection state tlsConn := conn.(*tls.Conn) + + // Set a deadline for the handshake. This will cause clients + // that don't successfully auth to be kicked out quickly. + // Cluster connections should be reliable so being marginally + // aggressive here is fine. + err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second)) + if err != nil { + if c.logger.IsDebug() { + c.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + err = tlsConn.Handshake() if err != nil { if c.logger.IsDebug() { @@ -156,6 +174,16 @@ func (c *Core) startForwarding(ctx context.Context) error { continue } + // Now, set it back to unlimited + err = tlsConn.SetDeadline(time.Time{}) + if err != nil { + if c.logger.IsDebug() { + c.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + switch tlsConn.ConnectionState().NegotiatedProtocol { case requestForwardingALPN: if !ha { diff --git a/vault/token_store.go b/vault/token_store.go index ea1c9704e04f..9526c72c0713 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -130,7 +130,7 @@ type TokenStore struct { saltLock sync.RWMutex salt *salt.Salt - tidyLock *int32 + tidyLock *uint32 identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error) } @@ -150,7 +150,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi tokensPendingDeletion: &sync.Map{}, saltLock: sync.RWMutex{}, identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies, - tidyLock: new(int32), + tidyLock: new(uint32), } if c.policyStore != nil { @@ -1290,204 +1290,224 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor // handleTidy handles the cleaning up of leaked accessor storage entries and // cleaning up of leases that are associated to tokens that are expired. func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - var tidyErrors *multierror.Error - - if !atomic.CompareAndSwapInt32(ts.tidyLock, 0, 1) { - ts.logger.Warn("tidy operation on tokens is already in progress") - return nil, fmt.Errorf("tidy operation on tokens is already in progress") + if !atomic.CompareAndSwapUint32(ts.tidyLock, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - defer atomic.CompareAndSwapInt32(ts.tidyLock, 1, 0) + go func() { + defer atomic.StoreUint32(ts.tidyLock, 0) - ts.logger.Info("beginning tidy operation on tokens") - defer ts.logger.Info("finished tidy operation on tokens") + // Don't cancel when the original client request goes away + ctx = context.Background() - // List out all the accessors - saltedAccessorList, err := ts.view.List(ctx, accessorPrefix) - if err != nil { - return nil, errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err) - } + logger := ts.logger.Named("tidy") - // First, clean up secondary index entries that are no longer valid - parentList, err := ts.view.List(ctx, parentPrefix) - if err != nil { - return nil, errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err) - } + var tidyErrors *multierror.Error - var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64 + doTidy := func() error { - // Scan through the secondary index entries; if there is an entry - // with the token's salt ID at the end, remove it - for _, parent := range parentList { - countParentEntries++ + ts.logger.Info("beginning tidy operation on tokens") + defer ts.logger.Info("finished tidy operation on tokens") - // Get the children - children, err := ts.view.List(ctx, parentPrefix+parent) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err)) - continue - } - - // First check if the salt ID of the parent exists, and if not mark this so - // that deletion of children later with this loop below applies to all - // children - originalChildrenCount := int64(len(children)) - exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true) - if exists == nil { - ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent) - } - - var deletedChildrenCount int64 - for _, child := range children { - countParentList++ - if countParentList%500 == 0 { - ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList) + // List out all the accessors + saltedAccessorList, err := ts.view.List(ctx, accessorPrefix) + if err != nil { + return errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err) } - // Look up tainted entries so we can be sure that if this isn't - // found, it doesn't exist. Doing the following without locking - // since appropriate locks cannot be held with salted token IDs. - // Also perform deletion if the parent doesn't exist any more. - te, _ := ts.lookupSalted(ctx, child, true) - // If the child entry is not nil, but the parent doesn't exist, then turn - // that child token into an orphan token. Theres no deletion in this case. - if te != nil && exists == nil { - lock := locksutil.LockForKey(ts.tokenLocks, te.ID) - lock.Lock() - - te.Parent = "" - err = ts.store(ctx, te) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err)) - } - lock.Unlock() - continue + // First, clean up secondary index entries that are no longer valid + parentList, err := ts.view.List(ctx, parentPrefix) + if err != nil { + return errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err) } - // Otherwise, if the entry doesn't exist, or if the parent doesn't exist go - // on with the delete on the secondary index - if te == nil || exists == nil { - index := parentPrefix + parent + child - ts.logger.Debug("deleting invalid secondary index", "index", index) - err = ts.view.Delete(ctx, index) + + var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64 + + // Scan through the secondary index entries; if there is an entry + // with the token's salt ID at the end, remove it + for _, parent := range parentList { + countParentEntries++ + + // Get the children + children, err := ts.view.List(ctx, parentPrefix+parent) if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err)) + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err)) continue } - deletedChildrenCount++ - } - } - // Add current children deleted count to the total count - deletedCountParentList += deletedChildrenCount - // N.B.: We don't call delete on the parent prefix since physical.Backend.Delete - // implementations should be in charge of deleting empty prefixes. - // If we deleted all the children, then add that to our deleted parent entries count. - if originalChildrenCount == deletedChildrenCount { - deletedCountParentEntries++ - } - } - - var countAccessorList, - deletedCountAccessorEmptyToken, - deletedCountAccessorInvalidToken, - deletedCountInvalidTokenInAccessor int64 - - // For each of the accessor, see if the token ID associated with it is - // a valid one. If not, delete the leases associated with that token - // and delete the accessor as well. - for _, saltedAccessor := range saltedAccessorList { - countAccessorList++ - if countAccessorList%500 == 0 { - ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList) - } - accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err)) - continue - } + // First check if the salt ID of the parent exists, and if not mark this so + // that deletion of children later with this loop below applies to all + // children + originalChildrenCount := int64(len(children)) + exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true) + if exists == nil { + ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent) + } - // A valid accessor storage entry should always have a token ID - // in it. If not, it is an invalid accessor entry and needs to - // be deleted. - if accessorEntry.TokenID == "" { - index := accessorPrefix + saltedAccessor - // If deletion of accessor fails, move on to the next - // item since this is just a best-effort operation - err = ts.view.Delete(ctx, index) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err)) - continue + var deletedChildrenCount int64 + for _, child := range children { + countParentList++ + if countParentList%500 == 0 { + ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList) + } + + // Look up tainted entries so we can be sure that if this isn't + // found, it doesn't exist. Doing the following without locking + // since appropriate locks cannot be held with salted token IDs. + // Also perform deletion if the parent doesn't exist any more. + te, _ := ts.lookupSalted(ctx, child, true) + // If the child entry is not nil, but the parent doesn't exist, then turn + // that child token into an orphan token. Theres no deletion in this case. + if te != nil && exists == nil { + lock := locksutil.LockForKey(ts.tokenLocks, te.ID) + lock.Lock() + + te.Parent = "" + err = ts.store(ctx, te) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err)) + } + lock.Unlock() + continue + } + // Otherwise, if the entry doesn't exist, or if the parent doesn't exist go + // on with the delete on the secondary index + if te == nil || exists == nil { + index := parentPrefix + parent + child + ts.logger.Debug("deleting invalid secondary index", "index", index) + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err)) + continue + } + deletedChildrenCount++ + } + } + // Add current children deleted count to the total count + deletedCountParentList += deletedChildrenCount + // N.B.: We don't call delete on the parent prefix since physical.Backend.Delete + // implementations should be in charge of deleting empty prefixes. + // If we deleted all the children, then add that to our deleted parent entries count. + if originalChildrenCount == deletedChildrenCount { + deletedCountParentEntries++ + } } - deletedCountAccessorEmptyToken++ - } - lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID) - lock.RLock() + var countAccessorList, + deletedCountAccessorEmptyToken, + deletedCountAccessorInvalidToken, + deletedCountInvalidTokenInAccessor int64 + + // For each of the accessor, see if the token ID associated with it is + // a valid one. If not, delete the leases associated with that token + // and delete the accessor as well. + for _, saltedAccessor := range saltedAccessorList { + countAccessorList++ + if countAccessorList%500 == 0 { + ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList) + } - // Look up tainted variants so we only find entries that truly don't - // exist - saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err)) - lock.RUnlock() - continue - } - te, err := ts.lookupSalted(ctx, saltedID, true) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err)) - lock.RUnlock() - continue - } + accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err)) + continue + } - lock.RUnlock() + // A valid accessor storage entry should always have a token ID + // in it. If not, it is an invalid accessor entry and needs to + // be deleted. + if accessorEntry.TokenID == "" { + index := accessorPrefix + saltedAccessor + // If deletion of accessor fails, move on to the next + // item since this is just a best-effort operation + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err)) + continue + } + deletedCountAccessorEmptyToken++ + } - // If token entry is not found assume that the token is not valid any - // more and conclude that accessor, leases, and secondary index entries - // for this token should not exist as well. - if te == nil { - ts.logger.Info("deleting token with nil entry", "salted_token", saltedID) + lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID) + lock.RLock() - // RevokeByToken expects a '*logical.TokenEntry'. For the - // purposes of tidying, it is sufficient if the token - // entry only has ID set. - tokenEntry := &logical.TokenEntry{ - ID: accessorEntry.TokenID, - } + // Look up tainted variants so we only find entries that truly don't + // exist + saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err)) + lock.RUnlock() + continue + } + te, err := ts.lookupSalted(ctx, saltedID, true) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err)) + lock.RUnlock() + continue + } - // Attempt to revoke the token. This will also revoke - // the leases associated with the token. - err := ts.expiration.RevokeByToken(tokenEntry) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err)) - continue + lock.RUnlock() + + // If token entry is not found assume that the token is not valid any + // more and conclude that accessor, leases, and secondary index entries + // for this token should not exist as well. + if te == nil { + ts.logger.Info("deleting token with nil entry", "salted_token", saltedID) + + // RevokeByToken expects a '*logical.TokenEntry'. For the + // purposes of tidying, it is sufficient if the token + // entry only has ID set. + tokenEntry := &logical.TokenEntry{ + ID: accessorEntry.TokenID, + } + + // Attempt to revoke the token. This will also revoke + // the leases associated with the token. + err := ts.expiration.RevokeByToken(tokenEntry) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err)) + continue + } + deletedCountInvalidTokenInAccessor++ + + index := accessorPrefix + saltedAccessor + + // If deletion of accessor fails, move on to the next item since + // this is just a best-effort operation. We do this last so that on + // next run if something above failed we still have the accessor + // entry to try again. + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err)) + continue + } + deletedCountAccessorInvalidToken++ + } } - deletedCountInvalidTokenInAccessor++ - index := accessorPrefix + saltedAccessor + ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries) + ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries) + ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList) + ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList) + ts.logger.Info("number of accessors scanned", "count", countAccessorList) + ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken) + ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor) + ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken) - // If deletion of accessor fails, move on to the next item since - // this is just a best-effort operation. We do this last so that on - // next run if something above failed we still have the accessor - // entry to try again. - err = ts.view.Delete(ctx, index) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err)) - continue - } - deletedCountAccessorInvalidToken++ + return tidyErrors.ErrorOrNil() } - } - ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries) - ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries) - ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList) - ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList) - ts.logger.Info("number of accessors scanned", "count", countAccessorList) - ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken) - ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor) - ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken) + if err := doTidy(); err != nil { + logger.Error("error running tidy", "error", err) + return + } + }() - return nil, tidyErrors.ErrorOrNil() + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // handleUpdateLookupAccessor handles the auth/token/lookup-accessor path for returning diff --git a/vault/token_store_test.go b/vault/token_store_test.go index c20c36b0c066..9acff94c2040 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -3777,6 +3777,9 @@ func TestTokenStore_HandleTidyCase1(t *testing.T) { t.Fatalf("err:%v resp:%v", err, resp) } + // Tidy runs async so give it time + time.Sleep(10 * time.Second) + // Tidy should have removed all the dangling accessor entries resp, err = ts.HandleRequest(context.Background(), accessorListReq) if err != nil || (resp != nil && resp.IsError()) { @@ -3909,6 +3912,9 @@ func TestTokenStore_HandleTidy_parentCleanup(t *testing.T) { t.Fatalf("err:%v resp:%v", err, resp) } + // Tidy runs async so give it time + time.Sleep(10 * time.Second) + // Tidy should have removed all the dangling accessor entries resp, err = ts.HandleRequest(context.Background(), accessorListReq) if err != nil || (resp != nil && resp.IsError()) {