diff --git a/changelog/25219.txt b/changelog/25219.txt new file mode 100644 index 000000000000..bf6ee22794c3 --- /dev/null +++ b/changelog/25219.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Plugin Workload Identity**: Vault can generate identity tokens for plugins to use in workload identity federation auth flows. +``` \ No newline at end of file diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 2feae0a9a400..ee460c0dd517 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -458,10 +458,19 @@ func (d dynamicSystemView) ClusterID(ctx context.Context) (string, error) { return clusterInfo.ID, nil } -func (d dynamicSystemView) GenerateIdentityToken(_ context.Context, _ *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) { - // TODO: implement plugin identity token generation using identity store +func (d dynamicSystemView) GenerateIdentityToken(ctx context.Context, req *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) { + storage := d.core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + if storage == nil { + return nil, fmt.Errorf("failed to find storage entry for identity mount") + } + + token, ttl, err := d.core.IdentityStore().generatePluginIdentityToken(ctx, storage, d.mountEntry, req.Audience, req.TTL) + if err != nil { + return nil, fmt.Errorf("failed to generate plugin identity token: %w", err) + } + return &pluginutil.IdentityTokenResponse{ - Token: "unimplemented", - TTL: time.Duration(0), + Token: pluginutil.IdentityToken(token), + TTL: ttl, }, nil } diff --git a/vault/identity_store.go b/vault/identity_store.go index c57bc9ed2dd0..d936e86c721d 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -64,6 +64,7 @@ func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendCo groupUpdater: core, tokenStorer: core, entityCreator: core, + mountLister: core, mfaBackend: core.loginMFABackend, } diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index f650b26b6edf..f58c3a7b0dd7 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -34,6 +34,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/patrickmn/go-cache" "golang.org/x/crypto/ed25519" + "golang.org/x/exp/maps" ) type oidcConfig struct { @@ -126,23 +127,12 @@ type oidcCache struct { c *cache.Cache } -var errNilNamespace = errors.New("nil namespace in oidc cache request") - -const ( - issuerPath = "identity/oidc" - oidcTokensPrefix = "oidc_tokens/" - namedKeyCachePrefix = "namedKeys/" - oidcConfigStorageKey = oidcTokensPrefix + "config/" - namedKeyConfigPath = oidcTokensPrefix + "named_keys/" - publicKeysConfigPath = oidcTokensPrefix + "public_keys/" - roleConfigPath = oidcTokensPrefix + "roles/" +var ( + errNilNamespace = errors.New("nil namespace in oidc cache request") - // Identity tokens have a base issuer and plugin issuer - baseIdentityTokenIssuer = "" - pluginIdentityTokenIssuer = "plugins" -) + // pseudo-namespace for cache items that don't belong to any real namespace. + noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"} -var ( reservedClaims = []string{ "iat", "aud", "exp", "iss", "sub", "namespace", "nonce", @@ -159,8 +149,24 @@ var ( } ) -// pseudo-namespace for cache items that don't belong to any real namespace. -var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"} +const ( + issuerPath = "identity/oidc" + oidcTokensPrefix = "oidc_tokens/" + namedKeyCachePrefix = "namedKeys/" + oidcConfigStorageKey = oidcTokensPrefix + "config/" + namedKeyConfigPath = oidcTokensPrefix + "named_keys/" + publicKeysConfigPath = oidcTokensPrefix + "public_keys/" + roleConfigPath = oidcTokensPrefix + "roles/" + + // Identity tokens have a base issuer and plugin issuer + baseIdentityTokenIssuer = "" + pluginIdentityTokenIssuer = "plugins" + + pluginTokenSubjectPrefix = "plugin-identity" + pluginTokenPrivateClaimKey = "vaultproject.io" + secretTableValue = "secret" + deleteKeyErrorFmt = "unable to delete key %q because it is currently referenced by these %s: %s" +) // optionalChildIssuerRegex is a regex for optionally accepting a field in an // API request as a single path segment. Adapted from framework.OptionalParamRegex @@ -784,6 +790,56 @@ func (i *IdentityStore) roleNamesReferencingTargetKeyName(ctx context.Context, r return names, nil } +// listMounts returns all mount entries in the namespace. +// Returns an error if the namespace is nil. +func (i *IdentityStore) listMounts(ns *namespace.Namespace) ([]*MountEntry, error) { + if ns == nil { + return nil, errors.New("namespace must not be nil") + } + + secretMounts, err := i.mountLister.ListMounts() + if err != nil { + return nil, err + } + authMounts, err := i.mountLister.ListAuths() + if err != nil { + return nil, err + } + + var allMounts []*MountEntry + for _, mount := range append(authMounts, secretMounts...) { + if mount.NamespaceID == ns.ID { + allMounts = append(allMounts, mount) + } + } + + return allMounts, nil +} + +// mountsReferencingKey returns a sorted list of all mount entry paths referencing +// the key in the namespace. Returns an error if the namespace is nil. +func (i *IdentityStore) mountsReferencingKey(ns *namespace.Namespace, key string) ([]string, error) { + if ns == nil { + return nil, errors.New("namespace must not be nil") + } + + allMounts, err := i.listMounts(ns) + if err != nil { + return nil, err + } + + pathsWithKey := make(map[string]struct{}) + for _, mount := range allMounts { + if mount.Config.IdentityTokenKey == key { + pathsWithKey[mount.Path] = struct{}{} + } + } + + paths := maps.Keys(pathsWithKey) + sort.Strings(paths) + return paths, nil +} + // handleOIDCDeleteKey is used to delete a key func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { ns, err := namespace.FromContext(ctx) @@ -807,8 +863,8 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ } if len(roleNames) > 0 { - errorMessage := fmt.Sprintf("unable to delete key %q because it is currently referenced by these roles: %s", - targetKeyName, strings.Join(roleNames, ", ")) + errorMessage := fmt.Sprintf(deleteKeyErrorFmt, + targetKeyName, "roles", strings.Join(roleNames, ", ")) i.oidcLock.Unlock() return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest } @@ -820,8 +876,20 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ } if len(clientNames) > 0 { - errorMessage := fmt.Sprintf("unable to delete key %q because it is currently referenced by these clients: %s", - targetKeyName, strings.Join(clientNames, ", ")) + errorMessage := fmt.Sprintf(deleteKeyErrorFmt, + targetKeyName, "clients", strings.Join(clientNames, ", ")) + i.oidcLock.Unlock() + return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest + } + + mounts, err := i.mountsReferencingKey(ns, targetKeyName) + if err != nil { + i.oidcLock.Unlock() + return nil, err + } + if len(mounts) > 0 { + errorMessage := fmt.Sprintf(deleteKeyErrorFmt, + targetKeyName, "mounts", strings.Join(mounts, ", ")) i.oidcLock.Unlock() return logical.ErrorResponse(errorMessage), logical.ErrInvalidRequest } @@ -1028,6 +1096,99 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical. return retResp, nil } +func (i *IdentityStore) generatePluginIdentityToken(ctx context.Context, storage logical.Storage, me *MountEntry, audience string, ttl time.Duration) (string, time.Duration, error) { + ns, err := namespace.FromContext(ctx) + if err != nil { + return "", 0, err + } + + if me == nil { + i.Logger().Error("unexpected nil mount entry when generating plugin identity token") + return "", 0, errors.New("mount entry must not be nil") + } + + key := defaultKeyName + if me.Config.IdentityTokenKey != "" { + key = me.Config.IdentityTokenKey + } + if ttl == 0 { + ttl = time.Hour + } + namedKey, err := i.getNamedKey(ctx, storage, key) + if err != nil { + return "", 0, err + } + if namedKey == nil { + return "", 0, fmt.Errorf("key %q not found", key) + } + + // Validate that the role is allowed to sign with its key (the key could have been updated) + if !strutil.StrListContains(namedKey.AllowedClientIDs, "*") && !strutil.StrListContains(namedKey.AllowedClientIDs, audience) { + return "", 0, fmt.Errorf("the key %q does not list %q as an allowed audience", key, audience) + } + + config, err := i.getOIDCConfig(ctx, storage) + if err != nil { + return "", 0, err + } + + // Cap the TTL to the key's verification TTL. This is the maximum amount of + // time the key will remain in the JWKS after it's been rotated. + if ttl > namedKey.VerificationTTL { + ttl = namedKey.VerificationTTL + } + + // Tokens for plugins have a distinct issuer from Vault's identity token issuer + issuer, err := config.fullIssuer(pluginIdentityTokenIssuer) + if err != nil { + return "", 0, err + } + + // The subject uniquely identifies the plugin + subject := fmt.Sprintf("%s:%s:%s:%s", pluginTokenSubjectPrefix, ns.ID, + translateTableClaim(me.Table), me.Accessor) + + now := time.Now() + claims := map[string]any{ + "iss": issuer, + "sub": subject, + "aud": []string{audience}, + "nbf": now.Unix(), + "iat": now.Unix(), + "exp": now.Add(ttl).Unix(), + pluginTokenPrivateClaimKey: map[string]any{ + "namespace_id": ns.ID, + "namespace_path": ns.Path, + "class": translateTableClaim(me.Table), + "plugin": me.Type, + "version": me.RunningVersion, + "path": me.Path, + "accessor": me.Accessor, + "local": me.Local, + }, + } + payload, err := json.Marshal(claims) + if err != nil { + return "", 0, err + } + + signedToken, err := namedKey.signPayload(payload) + if err != nil { + return "", 0, fmt.Errorf("error signing plugin identity token: %w", err) + } + + return signedToken, ttl, nil +} + +func translateTableClaim(table string) string { + switch table { + case mountTableType: + return secretTableValue + default: + return table + } +} + func (i *IdentityStore) getNamedKey(ctx context.Context, s logical.Storage, name string) (*namedKey, error) { ns, err := namespace.FromContext(ctx) if err != nil { @@ -1804,14 +1965,16 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag return nil, err } - // only return keys that are associated with a role + // Only return keys that are associated with a role or plugin mount + // by collecting and de-duplicating keys and key IDs for each + keyNames := make(map[string]struct{}) + keyIDs := make(map[string]struct{}) + + // First collect the set of unique key names roleNames, err := s.List(ctx, roleConfigPath) if err != nil { return nil, err } - - // collect and deduplicate the key IDs for all roles - keyIDs := make(map[string]struct{}) for _, roleName := range roleNames { role, err := i.getOIDCRole(ctx, s, roleName) if err != nil { @@ -1821,13 +1984,30 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag continue } - roleKeyIDs, err := i.keyIDsByName(ctx, s, role.Key) + keyNames[role.Key] = struct{}{} + } + mounts, err := i.listMounts(ns) + if err != nil { + return nil, err + } + for _, me := range mounts { + key := defaultKeyName + if me.Config.IdentityTokenKey != "" { + key = me.Config.IdentityTokenKey + } + + keyNames[key] = struct{}{} + } + + // Second collect the set of unique key IDs for each key name + for name := range keyNames { + ids, err := i.keyIDsByName(ctx, s, name) if err != nil { return nil, err } - for _, keyID := range roleKeyIDs { - keyIDs[keyID] = struct{}{} + for _, id := range ids { + keyIDs[id] = struct{}{} } } diff --git a/vault/identity_store_oidc_provider.go b/vault/identity_store_oidc_provider.go index ae2b41c561ce..3882ab99cad4 100644 --- a/vault/identity_store_oidc_provider.go +++ b/vault/identity_store_oidc_provider.go @@ -2629,10 +2629,6 @@ func (i *IdentityStore) lazyGenerateDefaultKey(ctx context.Context, storage logi return err } - if err := i.oidcCache.Delete(ns, namedKeyCachePrefix+defaultKeyName); err != nil { - return err - } - entry, err := logical.StorageEntryJSON(namedKeyConfigPath+defaultKeyName, defaultKey) if err != nil { return err @@ -2640,6 +2636,10 @@ func (i *IdentityStore) lazyGenerateDefaultKey(ctx context.Context, storage logi if err := storage.Put(ctx, entry); err != nil { return err } + + if err := i.oidcCache.Flush(ns); err != nil { + return err + } } return nil diff --git a/vault/identity_store_oidc_provider_test.go b/vault/identity_store_oidc_provider_test.go index 7fcc8fa48a16..2296c8aa7cc5 100644 --- a/vault/identity_store_oidc_provider_test.go +++ b/vault/identity_store_oidc_provider_test.go @@ -1176,7 +1176,8 @@ func setupOIDCCommon(t *testing.T, c *Core, s logical.Storage) (string, string, ctx := namespace.RootContext(nil) // Create a key - resp, err := c.identityStore.HandleRequest(ctx, testKeyReq(s, []string{"*"}, "RS256")) + resp, err := c.identityStore.HandleRequest(ctx, testKeyReq(s, "test-key", + []string{"*"}, "RS256")) expectSuccess(t, resp, err) // Create an entity @@ -1359,10 +1360,10 @@ func testEntityReq(s logical.Storage) *logical.Request { } } -func testKeyReq(s logical.Storage, allowedClientIDs []string, alg string) *logical.Request { +func testKeyReq(s logical.Storage, name string, allowedClientIDs []string, alg string) *logical.Request { return &logical.Request{ Storage: s, - Path: "oidc/key/test-key", + Path: fmt.Sprintf("oidc/key/%s", name), Operation: logical.CreateOperation, Data: map[string]interface{}{ "allowed_client_ids": allowedClientIDs, diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 457d1f1c9910..834c9f83bc58 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -5,6 +5,7 @@ package vault import ( "context" + "crypto" "encoding/json" "fmt" "regexp" @@ -16,7 +17,9 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" "github.com/go-test/deep" + capjwt "github.com/hashicorp/cap/jwt" "github.com/hashicorp/go-hclog" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" @@ -390,8 +393,60 @@ func TestOIDC_Path_OIDCRole(t *testing.T) { expectStrings(t, respListRoleAfterDelete.Data["keys"].([]string), expectedStrings) } -// TestOIDC_Path_OIDCKeyKey tests CRUD operations for keys -func TestOIDC_Path_OIDCKeyKey(t *testing.T) { +// TestOIDC_DeleteKeyWithMountReference ensures that keys cannot be deleted +// if they're referenced by mounts for plugin identity tokens. +func TestOIDC_DeleteKeyWithMountReference(t *testing.T) { + ctx := namespace.RootContext(nil) + core, _, _ := TestCoreUnsealed(t) + core.credentialBackends["userpass"] = credUserpass.Factory + idStorage := core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + require.NotNil(t, idStorage) + + tests := []struct { + name string + mountPrefix string + mountType string + keyName string + }{ + { + name: "delete key referenced by auth mount does not succeed", + mountPrefix: "auth/", + mountType: "userpass/", + keyName: "test-key-1", + }, + { + name: "delete key referenced by secret mount does not succeed", + mountPrefix: "mounts/", + mountType: "kv/", + keyName: "test-key-2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := core.identityStore.HandleRequest(ctx, testKeyReq(idStorage, tt.keyName, + []string{"*"}, "RS256")) + expectSuccess(t, resp, err) + + createMountEntryWithKey(t, ctx, core.systemBackend, tt.mountPrefix, tt.mountType, tt.keyName) + require.NoError(t, err) + require.Nil(t, resp) + + // Deleting the key must not succeed + resp, err = core.identityStore.HandleRequest(ctx, &logical.Request{ + Path: fmt.Sprintf("oidc/key/%s", tt.keyName), + Operation: logical.DeleteOperation, + Storage: idStorage, + }) + expectError(t, resp, err) + require.Equal(t, fmt.Sprintf(deleteKeyErrorFmt, tt.keyName, "mounts", tt.mountType), + resp.Error().Error()) + }) + } +} + +// TestOIDC_Path_CRUDKey tests CRUD operations for keys +func TestOIDC_Path_CRUDKey(t *testing.T) { c, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) storage := &logical.InmemStorage{} @@ -461,7 +516,6 @@ func TestOIDC_Path_OIDCKeyKey(t *testing.T) { Storage: storage, }) expectSuccess(t, resp, err) - // fmt.Printf("resp is:\n%#v", resp) // Delete test-key -- should fail because test-role depends on test-key resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ @@ -558,8 +612,8 @@ func TestOIDC_Path_OIDCKey_InvalidTokenTTL(t *testing.T) { expectError(t, resp, err) } -// TestOIDC_Path_OIDCKey tests the List operation for keys -func TestOIDC_Path_OIDCKey(t *testing.T) { +// TestOIDC_Path_ListKey tests the List operation for keys +func TestOIDC_Path_ListKey(t *testing.T) { c, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) storage := &logical.InmemStorage{} @@ -1844,3 +1898,177 @@ func Test_optionalChildIssuerRegex(t *testing.T) { }) } } + +// TestIdentityStore_generatePluginIdentityToken tests generation of plugin identity +// tokens by verifying signatures and validating claims. +func TestIdentityStore_generatePluginIdentityToken(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + core.credentialBackends["userpass"] = credUserpass.Factory + identityStore := core.IdentityStore() + identityStore.redirectAddr = "http://localhost:8200" + ctx := namespace.RootContext(nil) + storage := core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + require.NotNil(t, storage) + + // Create a key + testKey := "test-key" + testAudience := "allowed-audience" + resp, err := core.identityStore.HandleRequest(ctx, testKeyReq(storage, testKey, + []string{testAudience}, "RS256")) + expectSuccess(t, resp, err) + + // Enable a secret mount using the test key + createMountEntryWithKey(t, ctx, core.systemBackend, "mounts/", "kv/", testKey) + expectSuccess(t, resp, err) + secretMountEntry := core.router.MatchingMountEntry(ctx, "kv/") + require.NotNil(t, secretMountEntry) + + // Enable an auth mount using the default key + createMountEntryWithKey(t, ctx, core.systemBackend, "auth/", "userpass/", defaultKeyName) + expectSuccess(t, resp, err) + authMountEntry := core.router.MatchingMountEntry(ctx, "auth/userpass/") + require.NotNil(t, authMountEntry) + + tests := []struct { + name string + ctx context.Context + mountEntry *MountEntry + audience string + ttl time.Duration + wantErr bool + }{ + { + name: "expect error with nil context", + ctx: nil, + wantErr: true, + }, + { + name: "expect error with nil mount entry", + ctx: ctx, + mountEntry: nil, + wantErr: true, + }, + { + name: "expect error with key that doesn't exist", + ctx: ctx, + mountEntry: &MountEntry{ + Config: MountConfig{ + IdentityTokenKey: "does-not-exist", + }, + }, + wantErr: true, + }, + { + name: "expect error with audience that's not allowed by the key", + ctx: ctx, + mountEntry: secretMountEntry, + audience: "not-allowed-audience", + wantErr: true, + }, + { + name: "expect valid identity token with secret mount using test key", + ctx: ctx, + mountEntry: secretMountEntry, + audience: testAudience, + }, + { + name: "expect valid identity token with auth mount using default key", + ctx: ctx, + mountEntry: authMountEntry, + audience: testAudience, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, _, err := identityStore.generatePluginIdentityToken(tt.ctx, storage, tt.mountEntry, + tt.audience, tt.ttl) + if tt.wantErr { + require.Error(t, err) + require.Empty(t, token) + return + } + + require.NoError(t, err) + require.NotEmpty(t, token) + + // Verify the signature and claims of the token + key, err := identityStore.getNamedKey(ctx, storage, tt.mountEntry.Config.IdentityTokenKey) + require.NoError(t, err) + keySet, err := capjwt.NewStaticKeySet([]crypto.PublicKey{key.SigningKey.Public()}) + require.NoError(t, err) + + validator, err := capjwt.NewValidator(keySet) + require.NoError(t, err) + expected := capjwt.Expected{ + Issuer: fmt.Sprintf("%s/v1/identity/oidc/plugins", identityStore.redirectAddr), + Subject: fmt.Sprintf("%s:%s:%s:%s", pluginTokenSubjectPrefix, namespace.RootNamespace.ID, + translateTableClaim(tt.mountEntry.Table), tt.mountEntry.Accessor), + Audiences: []string{tt.audience}, + SigningAlgorithms: []capjwt.Alg{capjwt.RS256}, + } + + claims, err := validator.Validate(ctx, token, expected) + require.NoError(t, err) + require.Contains(t, claims, pluginTokenPrivateClaimKey) + require.IsType(t, map[string]interface{}{}, claims[pluginTokenPrivateClaimKey]) + + vaultSubClaims := claims[pluginTokenPrivateClaimKey].(map[string]interface{}) + require.Equal(t, namespace.RootNamespace.ID, vaultSubClaims["namespace_id"]) + require.Equal(t, namespace.RootNamespace.Path, vaultSubClaims["namespace_path"]) + require.Equal(t, translateTableClaim(tt.mountEntry.Table), vaultSubClaims["class"]) + require.Equal(t, tt.mountEntry.Type, vaultSubClaims["plugin"]) + require.Equal(t, tt.mountEntry.RunningVersion, vaultSubClaims["version"]) + require.Equal(t, tt.mountEntry.Path, vaultSubClaims["path"]) + require.Equal(t, tt.mountEntry.Accessor, vaultSubClaims["accessor"]) + require.Equal(t, tt.mountEntry.Local, vaultSubClaims["local"]) + }) + } +} + +func createMountEntryWithKey(t *testing.T, ctx context.Context, sys *SystemBackend, mountPrefix, mountType, key string) { + t.Helper() + + resp, err := sys.HandleRequest(ctx, &logical.Request{ + Path: mountPrefix + mountType, + Operation: logical.UpdateOperation, + Storage: new(logical.InmemStorage), + Data: map[string]interface{}{ + "type": strings.TrimSuffix(mountType, "/"), + "config": map[string]interface{}{ + "identity_token_key": key, + }, + }, + }) + expectSuccess(t, resp, err) +} + +// Test_translateTableClaim tests that we convert mount entry table +// values to expected claim values. +func Test_translateTableClaim(t *testing.T) { + tests := []struct { + name string + table string + want string + }{ + { + name: "given mounts table returns secret", + table: mountTableType, + want: secretTableValue, + }, + { + name: "given auth table returns auth", + table: "auth", + want: "auth", + }, + { + name: "given any value returns itself", + table: "other", + want: "other", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, translateTableClaim(tt.table), "translateTableClaim(%v)", tt.table) + }) + } +} diff --git a/vault/identity_store_structs.go b/vault/identity_store_structs.go index bba2060c01c2..d60427b9f6c5 100644 --- a/vault/identity_store_structs.go +++ b/vault/identity_store_structs.go @@ -102,6 +102,7 @@ type IdentityStore struct { groupUpdater GroupUpdater tokenStorer TokenStorer entityCreator EntityCreator + mountLister MountLister mfaBackend *LoginMFABackend } @@ -153,3 +154,10 @@ type EntityCreator interface { } var _ EntityCreator = &Core{} + +type MountLister interface { + ListMounts() ([]*MountEntry, error) + ListAuths() ([]*MountEntry, error) +} + +var _ MountLister = &Core{} diff --git a/vault/logical_system.go b/vault/logical_system.go index 889825ae54c0..7812861227e4 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1619,15 +1619,16 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d config.DelegatedAuthAccessors = apiConfig.DelegatedAuthAccessors } - if apiConfig.IdentityTokenKey != "" { - storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) - if storage == nil { - return nil, errors.New("failed to find identity storage") - } + storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + if storage == nil { + return nil, errors.New("failed to find identity storage") + } - identityStore := b.Core.IdentityStore() - identityStore.oidcLock.RLock() - defer identityStore.oidcLock.RUnlock() + // Ensure that the mount's identity token key exists + identityStore := b.Core.IdentityStore() + identityStore.oidcLock.Lock() + defer identityStore.oidcLock.Unlock() + if apiConfig.IdentityTokenKey != "" { k, err := identityStore.getNamedKey(ctx, storage, apiConfig.IdentityTokenKey) if err != nil { return nil, fmt.Errorf("failed getting key %q: %w", apiConfig.IdentityTokenKey, err) @@ -1639,6 +1640,15 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d config.IdentityTokenKey = apiConfig.IdentityTokenKey } + // Don't lazily generate the default OIDC key for KV mounts. A default KV mount + // is enabled in dev and test servers. We don't want to pay the cost of key + // generation for that KV mount in all tests. + if config.usingOIDCDefaultKey() && logicalType != mountTypeKV { + if err := identityStore.lazyGenerateDefaultKey(ctx, storage); err != nil { + return nil, fmt.Errorf("failed to generate default key: %w", err) + } + } + // Create the mount entry me := &MountEntry{ Table: mountTableType, @@ -2431,15 +2441,16 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, if rawVal, ok := data.GetOk("identity_token_key"); ok { identityTokenKey := rawVal.(string) - if identityTokenKey != "" { - storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) - if storage == nil { - return nil, errors.New("failed to find identity storage") - } + storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + if storage == nil { + return nil, errors.New("failed to find identity storage") + } - identityStore := b.Core.IdentityStore() - identityStore.oidcLock.RLock() - defer identityStore.oidcLock.RUnlock() + // Ensure that the mount's identity token key exists + identityStore := b.Core.IdentityStore() + identityStore.oidcLock.Lock() + defer identityStore.oidcLock.Unlock() + if identityTokenKey != "" { k, err := identityStore.getNamedKey(ctx, storage, identityTokenKey) if err != nil { return nil, fmt.Errorf("failed getting key %q: %w", identityTokenKey, err) @@ -2452,6 +2463,13 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, oldVal := mountEntry.Config.IdentityTokenKey mountEntry.Config.IdentityTokenKey = identityTokenKey + if mountEntry.Config.usingOIDCDefaultKey() { + if err := identityStore.lazyGenerateDefaultKey(ctx, storage); err != nil { + mountEntry.Config.IdentityTokenKey = oldVal + return nil, fmt.Errorf("failed to generate default key: %w", err) + } + } + // Update the mount table var err error switch { @@ -3227,15 +3245,16 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque config.AllowedManagedKeys = apiConfig.AllowedManagedKeys } - if apiConfig.IdentityTokenKey != "" { - storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) - if storage == nil { - return nil, errors.New("failed to find identity storage") - } + storage := b.Core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity) + if storage == nil { + return nil, errors.New("failed to find identity storage") + } - identityStore := b.Core.IdentityStore() - identityStore.oidcLock.RLock() - defer identityStore.oidcLock.RUnlock() + // Ensure that the mount's identity token key exists + identityStore := b.Core.IdentityStore() + identityStore.oidcLock.Lock() + defer identityStore.oidcLock.Unlock() + if apiConfig.IdentityTokenKey != "" { k, err := identityStore.getNamedKey(ctx, storage, apiConfig.IdentityTokenKey) if err != nil { return nil, fmt.Errorf("failed getting key %q: %w", apiConfig.IdentityTokenKey, err) @@ -3246,6 +3265,11 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque config.IdentityTokenKey = apiConfig.IdentityTokenKey } + if config.usingOIDCDefaultKey() { + if err := identityStore.lazyGenerateDefaultKey(ctx, storage); err != nil { + return nil, fmt.Errorf("failed to generate default key: %w", err) + } + } // Create the mount entry me := &MountEntry{ diff --git a/vault/mount.go b/vault/mount.go index f9caf4f46139..bc6193692969 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -375,6 +375,10 @@ type MountConfig struct { PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"` } +func (c *MountConfig) usingOIDCDefaultKey() bool { + return c.IdentityTokenKey == "" || c.IdentityTokenKey == defaultKeyName +} + type UserLockoutConfig struct { LockoutThreshold uint64 `json:"lockout_threshold,omitempty" structs:"lockout_threshold" mapstructure:"lockout_threshold"` LockoutDuration time.Duration `json:"lockout_duration,omitempty" structs:"lockout_duration" mapstructure:"lockout_duration"`