diff --git a/api/internal/testutil/server.go b/api/internal/testutil/server.go index b6abd6dad77..39294b5472f 100644 --- a/api/internal/testutil/server.go +++ b/api/internal/testutil/server.go @@ -225,7 +225,7 @@ func NewTestServer(t testing.T, cb ServerConfigCallback) *TestServer { // Wait for the server to be ready if nomadConfig.Server.Enabled && nomadConfig.Server.BootstrapExpect != 0 { - server.waitForLeader() + server.waitForServers() } else { server.waitForAPI() } @@ -296,21 +296,29 @@ func (s *TestServer) waitForAPI() { ) } -// waitForLeader waits for the Nomad server's HTTP API to become -// available, and then waits for a known leader and an index of -// 1 or more to be observed to confirm leader election is done. -func (s *TestServer) waitForLeader() { +// waitForServers waits for the Nomad server's HTTP API to become available, +// and then waits for the keyring to be intialized. This implies a leader has +// been elected and Raft writes have occurred. +func (s *TestServer) waitForServers() { f := func() error { - // Query the API and check the status code - // Using this endpoint as it is does not have restricted access - resp, err := s.HTTPClient.Get(s.url("/v1/status/leader")) + resp, err := s.HTTPClient.Get(s.url("/.well-known/jwks.json")) if err != nil { - return fmt.Errorf("failed to get leader: %w", err) + return fmt.Errorf("failed to contact leader: %w", err) } defer func() { _ = resp.Body.Close() }() if err = s.requireOK(resp); err != nil { return fmt.Errorf("leader response is not ok: %w", err) } + + jwks := struct { + Keys []interface{} `json:"keys"` + }{} + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + return fmt.Errorf("error decoding jwks response: %w", err) + } + if len(jwks.Keys) == 0 { + return fmt.Errorf("no keys found") + } return nil } must.Wait(s.t, diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 580a8683f10..c63ef3f5327 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -32,7 +32,7 @@ const ( jwtAuthConfigTemplate = ` { "jwks_url": "<<.JWKSURL>>", - "jwt_supported_algs": ["EdDSA"], + "jwt_supported_algs": ["RS256", "EdDSA"], "default_role": "nomad-workloads" } ` diff --git a/client/widmgr/mock.go b/client/widmgr/mock.go index 73ebf7c5309..90469bb1792 100644 --- a/client/widmgr/mock.go +++ b/client/widmgr/mock.go @@ -4,7 +4,8 @@ package widmgr import ( - "crypto/ed25519" + "crypto/rand" + "crypto/rsa" "fmt" "slices" "time" @@ -22,13 +23,13 @@ type MockWIDSigner struct { // SignIdentities will use it to find expirations or reject invalid identity // names wids map[string]*structs.WorkloadIdentity - key ed25519.PrivateKey + key *rsa.PrivateKey keyID string mockNow time.Time // allows moving the clock } func NewMockWIDSigner(wids []*structs.WorkloadIdentity) *MockWIDSigner { - _, privKey, err := ed25519.GenerateKey(nil) + privKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(err) } @@ -63,7 +64,7 @@ func (m *MockWIDSigner) JSONWebKeySet() *jose.JSONWebKeySet { jwk := jose.JSONWebKey{ Key: m.key.Public(), KeyID: m.keyID, - Algorithm: "EdDSA", + Algorithm: "RS256", Use: "sig", } return &jose.JSONWebKeySet{ @@ -95,7 +96,7 @@ func (m *MockWIDSigner) SignIdentities(minIndex uint64, req []*structs.WorkloadI } } opts := (&jose.SignerOptions{}).WithHeader("kid", m.keyID).WithType("JWT") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: m.key}, opts) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: m.key}, opts) if err != nil { return nil, fmt.Errorf("error creating signer: %w", err) } diff --git a/command/agent/http_test.go b/command/agent/http_test.go index 03dd6b3a108..528c08f6d84 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -1560,7 +1560,7 @@ func benchmarkJsonEncoding(b *testing.B, handle *codec.JsonHandle) { func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) { s := makeHTTPServer(t, cb) defer s.Shutdown() - testutil.WaitForLeader(t, s.Agent.RPC) + testutil.WaitForKeyring(t, s.Agent.RPC, s.Config.Region) f(s) } @@ -1572,7 +1572,7 @@ func httpACLTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) { } }) defer s.Shutdown() - testutil.WaitForLeader(t, s.Agent.RPC) + testutil.WaitForKeyring(t, s.Agent.RPC, s.Config.Region) f(s) } diff --git a/command/agent/testagent.go b/command/agent/testagent.go index 9f378b737c6..ec8dd5525e1 100644 --- a/command/agent/testagent.go +++ b/command/agent/testagent.go @@ -192,15 +192,7 @@ RETRY: failed := false if a.Config.NomadConfig.BootstrapExpect == 1 && a.Config.Server.Enabled { - testutil.WaitForResult(func() (bool, error) { - args := &structs.GenericRequest{} - var leader string - err := a.RPC("Status.Leader", args, &leader) - return leader != "", err - }, func(err error) { - a.T.Logf("failed to find leader: %v", err) - failed = true - }) + testutil.WaitForKeyring(a.T, a.RPC, a.Config.Region) } else { testutil.WaitForResult(func() (bool, error) { req, _ := http.NewRequest(http.MethodGet, "/v1/agent/self", nil) diff --git a/command/recommendation_dismiss_test.go b/command/recommendation_dismiss_test.go index 78365a4ecff..9393ddfafb3 100644 --- a/command/recommendation_dismiss_test.go +++ b/command/recommendation_dismiss_test.go @@ -96,6 +96,7 @@ func TestRecommendationDismissCommand_Run(t *testing.T) { } func TestRecommendationDismissCommand_AutocompleteArgs(t *testing.T) { + ci.Parallel(t) srv, client, url := testServer(t, false, nil) defer srv.Shutdown() @@ -113,7 +114,6 @@ func TestRecommendationDismissCommand_AutocompleteArgs(t *testing.T) { } func testRecommendationAutocompleteCommand(t *testing.T, client *api.Client, srv *agent.TestAgent, cmd *RecommendationAutocompleteCommand) { - ci.Parallel(t) require := require.New(t) // Register a test job to write a recommendation against. diff --git a/command/var_list_test.go b/command/var_list_test.go index b4525070c19..ce18d99e256 100644 --- a/command/var_list_test.go +++ b/command/var_list_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/nomad/api" "github.com/hashicorp/nomad/ci" "github.com/mitchellh/cli" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -130,7 +131,7 @@ func TestVarListCommand_Online(t *testing.T) { nsList := []string{api.DefaultNamespace, "ns1"} pathList := []string{"a/b/c", "a/b/c/d", "z/y", "z/y/x"} - variables := setupTestVariables(client, nsList, pathList) + variables := setupTestVariables(t, client, nsList, pathList) testTmpl := `{{ range $i, $e := . }}{{if ne $i 0}}{{print "•"}}{{end}}{{printf "%v\t%v" .Namespace .Path}}{{end}}` @@ -349,14 +350,14 @@ type testSVNamespacePath struct { Path string } -func setupTestVariables(c *api.Client, nsList, pathList []string) SVMSlice { +func setupTestVariables(t *testing.T, c *api.Client, nsList, pathList []string) SVMSlice { out := make(SVMSlice, 0, len(nsList)*len(pathList)) for _, ns := range nsList { c.Namespaces().Register(&api.Namespace{Name: ns}, nil) for _, p := range pathList { - setupTestVariable(c, ns, p, &out) + must.NoError(t, setupTestVariable(c, ns, p, &out)) } } @@ -369,8 +370,12 @@ func setupTestVariable(c *api.Client, ns, p string, out *SVMSlice) error { Path: p, Items: map[string]string{"k": "v"}} v, _, err := c.Variables().Create(testVar, &api.WriteOptions{Namespace: ns}) + if err != nil { + return err + } + *out = append(*out, *v.Metadata()) - return err + return nil } type NSPather interface { diff --git a/e2e/vaultcompat/cluster_setup_test.go b/e2e/vaultcompat/cluster_setup_test.go index cb2fdd4b2d9..0f1246a3124 100644 --- a/e2e/vaultcompat/cluster_setup_test.go +++ b/e2e/vaultcompat/cluster_setup_test.go @@ -25,7 +25,7 @@ var roleLegacy = map[string]interface{}{ func authConfigJWT(jwksURL string) map[string]any { return map[string]any{ "jwks_url": jwksURL, - "jwt_supported_algs": []string{"EdDSA"}, + "jwt_supported_algs": []string{"RS256", "EdDSA"}, "default_role": "nomad-workloads", } } diff --git a/nomad/auth/auth_test.go b/nomad/auth/auth_test.go index cfa0d3c6304..22ab2c727c6 100644 --- a/nomad/auth/auth_test.go +++ b/nomad/auth/auth_test.go @@ -4,7 +4,8 @@ package auth import ( - "crypto/ed25519" + "crypto/rand" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "errors" @@ -20,7 +21,6 @@ import ( "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/ci" - "github.com/hashicorp/nomad/helper/crypto" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/uuid" @@ -1217,20 +1217,23 @@ func (ctx *testContext) Certificate() *x509.Certificate { } type testEncrypter struct { - key ed25519.PrivateKey + key *rsa.PrivateKey } func newTestEncrypter() *testEncrypter { - buf, _ := crypto.Bytes(32) + k, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } return &testEncrypter{ - key: ed25519.NewKeyFromSeed(buf), + key: k, } } func (te *testEncrypter) signClaim(claims *structs.IdentityClaims) (string, error) { opts := (&jose.SignerOptions{}).WithType("JWT") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: te.key}, opts) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: te.key}, opts) if err != nil { return "", err } diff --git a/nomad/client_agent_endpoint_test.go b/nomad/client_agent_endpoint_test.go index ae0451a4f1d..d6b4595fd16 100644 --- a/nomad/client_agent_endpoint_test.go +++ b/nomad/client_agent_endpoint_test.go @@ -156,13 +156,20 @@ func TestMonitor_Monitor_RemoteServer(t *testing.T) { servers := []*Server{s1, s2} var nonLeader *Server var leader *Server - for _, s := range servers { - if !s.IsLeader() { - nonLeader = s - } else { - leader = s + testutil.WaitForResult(func() (bool, error) { + for _, s := range servers { + if !s.IsLeader() { + nonLeader = s + } else { + leader = s + } } - } + + return nonLeader != nil && leader != nil && nonLeader != leader, fmt.Errorf( + "leader=%p follower=%p", nonLeader, leader) + }, func(err error) { + t.Fatalf("timed out trying to find a leader: %v", err) + }) cases := []struct { desc string diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index ef82abcb446..9a151db378c 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -2444,7 +2444,7 @@ func TestCoreScheduler_CSIBadState_ClaimGC(t *testing.T) { } } return true - }, time.Second*5, 10*time.Millisecond, "invalid claims should be marked for GC") + }, 10*time.Second, 50*time.Millisecond, "invalid claims should be marked for GC") } @@ -2454,7 +2454,7 @@ func TestCoreScheduler_RootKeyGC(t *testing.T) { srv, cleanup := TestServer(t, nil) defer cleanup() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") // reset the time table srv.fsm.timetable.table = make([]TimeTableEntry, 1, 10) @@ -2559,7 +2559,7 @@ func TestCoreScheduler_VariablesRekey(t *testing.T) { srv, cleanup := TestServer(t, nil) defer cleanup() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") store := srv.fsm.State() key0, err := store.GetActiveRootKeyMeta(nil) diff --git a/nomad/encrypter.go b/nomad/encrypter.go index 8ad265d9ac2..6d397cd3619 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -8,6 +8,8 @@ import ( "crypto/aes" "crypto/cipher" "crypto/ed25519" + "crypto/rsa" + "crypto/x509" "encoding/json" "fmt" "io/fs" @@ -51,10 +53,16 @@ type Encrypter struct { lock sync.RWMutex } +// keyset contains the key material for variable encryption and workload +// identity signing. As keysets are rotated they are identified by the RootKey +// KeyID although the public key IDs are published with a type prefix to +// disambiguate which signing algorithm to use. type keyset struct { - rootKey *structs.RootKey - cipher cipher.AEAD - privateKey ed25519.PrivateKey + rootKey *structs.RootKey + cipher cipher.AEAD + eddsaPrivateKey ed25519.PrivateKey + rsaPrivateKey *rsa.PrivateKey + rsaPKCS1PublicKey []byte // PKCS #1 DER encoded public key for JWKS } // NewEncrypter loads or creates a new local keystore and returns an @@ -102,7 +110,7 @@ func (e *Encrypter) loadKeystore() error { key, err := e.loadKeyFromStore(path) if err != nil { - return fmt.Errorf("could not load key file %s from keystore: %v", path, err) + return fmt.Errorf("could not load key file %s from keystore: %w", path, err) } if key.Meta.KeyID != id { return fmt.Errorf("root key ID %s must match key file %s", key.Meta.KeyID, path) @@ -110,7 +118,7 @@ func (e *Encrypter) loadKeystore() error { err = e.AddKey(key) if err != nil { - return fmt.Errorf("could not add key file %s to keystore: %v", path, err) + return fmt.Errorf("could not add key file %s to keystore: %w", path, err) } return nil }) @@ -199,9 +207,20 @@ func (e *Encrypter) SignClaims(claims *structs.IdentityClaims) (string, string, } opts := (&jose.SignerOptions{}).WithHeader("kid", keyset.rootKey.Meta.KeyID).WithType("JWT") - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: keyset.privateKey}, opts) - if err != nil { - return "", "", err + + var sig jose.Signer + if keyset.rsaPrivateKey != nil { + // If an RSA key has been created prefer it as it is more widely compatible + sig, err = jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: keyset.rsaPrivateKey}, opts) + if err != nil { + return "", "", err + } + } else { + // No RSA key has been created, fallback to ed25519 which always exists + sig, err = jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: keyset.eddsaPrivateKey}, opts) + if err != nil { + return "", "", err + } } raw, err := jwt.Signed(sig).Claims(claims).CompactSerialize() @@ -292,15 +311,29 @@ func (e *Encrypter) addCipher(rootKey *structs.RootKey) error { return fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm) } - privateKey := ed25519.NewKeyFromSeed(rootKey.Key) + ed25519Key := ed25519.NewKeyFromSeed(rootKey.Key) + + ks := keyset{ + rootKey: rootKey, + cipher: aead, + eddsaPrivateKey: ed25519Key, + } + + // Unmarshal RSAKey for Workload Identity JWT signing if one exists. Prior to + // 1.7 only the ed25519 key was used. + if len(rootKey.RSAKey) > 0 { + rsaKey, err := x509.ParsePKCS1PrivateKey(rootKey.RSAKey) + if err != nil { + return fmt.Errorf("error parsing rsa key: %w", err) + } + + ks.rsaPrivateKey = rsaKey + ks.rsaPKCS1PublicKey = x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey) + } e.lock.Lock() defer e.lock.Unlock() - e.keyring[rootKey.Meta.KeyID] = &keyset{ - rootKey: rootKey, - cipher: aead, - privateKey: privateKey, - } + e.keyring[rootKey.Meta.KeyID] = &ks return nil } @@ -356,23 +389,32 @@ func (e *Encrypter) saveKeyToStore(rootKey *structs.RootKey) error { kek, err := crypto.Bytes(32) if err != nil { - return fmt.Errorf("failed to generate key wrapper key: %v", err) + return fmt.Errorf("failed to generate key wrapper key: %w", err) } wrapper, err := e.newKMSWrapper(rootKey.Meta.KeyID, kek) if err != nil { - return fmt.Errorf("failed to create encryption wrapper: %v", err) + return fmt.Errorf("failed to create encryption wrapper: %w", err) } - blob, err := wrapper.Encrypt(e.srv.shutdownCtx, rootKey.Key) + rootBlob, err := wrapper.Encrypt(e.srv.shutdownCtx, rootKey.Key) if err != nil { - return fmt.Errorf("failed to encrypt root key: %v", err) + return fmt.Errorf("failed to encrypt root key: %w", err) } kekWrapper := &structs.KeyEncryptionKeyWrapper{ Meta: rootKey.Meta, - EncryptedDataEncryptionKey: blob.Ciphertext, + EncryptedDataEncryptionKey: rootBlob.Ciphertext, KeyEncryptionKey: kek, } + // Only keysets created after 1.7.0 will contain an RSA key. + if len(rootKey.RSAKey) > 0 { + rsaBlob, err := wrapper.Encrypt(e.srv.shutdownCtx, rootKey.RSAKey) + if err != nil { + return fmt.Errorf("failed to encrypt rsa key: %w", err) + } + kekWrapper.EncryptedRSAKey = rsaBlob.Ciphertext + } + buf, err := json.Marshal(kekWrapper) if err != nil { return err @@ -408,18 +450,32 @@ func (e *Encrypter) loadKeyFromStore(path string) (*structs.RootKey, error) { // sure we wrap them with as much context as possible wrapper, err := e.newKMSWrapper(meta.KeyID, kekWrapper.KeyEncryptionKey) if err != nil { - return nil, fmt.Errorf("unable to create key wrapper cipher: %v", err) + return nil, fmt.Errorf("unable to create key wrapper cipher: %w", err) } key, err := wrapper.Decrypt(e.srv.shutdownCtx, &kms.BlobInfo{ Ciphertext: kekWrapper.EncryptedDataEncryptionKey, }) if err != nil { - return nil, fmt.Errorf("unable to decrypt wrapped root key: %v", err) + return nil, fmt.Errorf("unable to decrypt wrapped root key: %w", err) + } + + // Decrypt RSAKey for Workload Identity JWT signing if one exists. Prior to + // 1.7 an ed25519 key derived from the root key was used instead of an RSA + // key. + var rsaKey []byte + if len(kekWrapper.EncryptedRSAKey) > 0 { + rsaKey, err = wrapper.Decrypt(e.srv.shutdownCtx, &kms.BlobInfo{ + Ciphertext: kekWrapper.EncryptedRSAKey, + }) + if err != nil { + return nil, fmt.Errorf("unable to decrypt wrapped rsa key: %w", err) + } } return &structs.RootKey{ - Meta: meta, - Key: key, + Meta: meta, + Key: key, + RSAKey: rsaKey, }, nil } @@ -434,13 +490,21 @@ func (e *Encrypter) GetPublicKey(keyID string) (*structs.KeyringPublicKey, error return nil, err } - return &structs.KeyringPublicKey{ - KeyID: ks.rootKey.Meta.KeyID, - PublicKey: ks.privateKey.Public().(ed25519.PublicKey), - Algorithm: structs.PubKeyAlgEdDSA, + pubKey := &structs.KeyringPublicKey{ + KeyID: keyID, Use: structs.PubKeyUseSig, CreateTime: ks.rootKey.Meta.CreateTime, - }, nil + } + + if ks.rsaPrivateKey != nil { + pubKey.PublicKey = ks.rsaPKCS1PublicKey + pubKey.Algorithm = structs.PubKeyAlgRS256 + } else { + pubKey.PublicKey = ks.eddsaPrivateKey.Public().(ed25519.PublicKey) + pubKey.Algorithm = structs.PubKeyAlgEdDSA + } + + return pubKey, nil } // newKMSWrapper returns a go-kms-wrapping interface the caller can use to diff --git a/nomad/encrypter_test.go b/nomad/encrypter_test.go index 71fb5589740..1f18509bb08 100644 --- a/nomad/encrypter_test.go +++ b/nomad/encrypter_test.go @@ -14,11 +14,13 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -90,7 +92,7 @@ func TestEncrypter_Restore(t *testing.T) { c.DataDir = tmpDir }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) nodeID := srv.GetConfig().NodeID @@ -133,7 +135,7 @@ func TestEncrypter_Restore(t *testing.T) { c.DataDir = tmpDir }) defer shutdown2() - testutil.WaitForLeader(t, srv2.RPC) + testutil.WaitForKeyring(t, srv2.RPC, "global") codec = rpcClient(t, srv2) // Verify we've restored all the keys from the old keystore @@ -189,9 +191,9 @@ func TestEncrypter_KeyringReplication(t *testing.T) { TestJoin(t, srv1, srv2) TestJoin(t, srv1, srv3) - testutil.WaitForLeader(t, srv1.RPC) - testutil.WaitForLeader(t, srv2.RPC) - testutil.WaitForLeader(t, srv3.RPC) + testutil.WaitForKeyring(t, srv1.RPC, "global") + testutil.WaitForKeyring(t, srv2.RPC, "global") + testutil.WaitForKeyring(t, srv3.RPC, "global") servers := []*Server{srv1, srv2, srv3} var leader *Server @@ -369,7 +371,7 @@ func TestEncrypter_EncryptDecrypt(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") e := srv.encrypter @@ -389,7 +391,7 @@ func TestEncrypter_SignVerify(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") alloc := mock.Alloc() claims := structs.NewIdentityClaims(alloc.Job, alloc, wiHandle, alloc.LookupTask("web").Identity, time.Now()) @@ -422,7 +424,7 @@ func TestEncrypter_SignVerify_Issuer(t *testing.T) { c.OIDCIssuer = testIssuer }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") alloc := mock.Alloc() claims := structs.NewIdentityClaims(alloc.Job, alloc, wiHandle, alloc.LookupTask("web").Identity, time.Now()) @@ -447,7 +449,7 @@ func TestEncrypter_SignVerify_AlgNone(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") alloc := mock.Alloc() claims := structs.NewIdentityClaims(alloc.Job, alloc, wiHandle, alloc.LookupTask("web").Identity, time.Now()) @@ -489,3 +491,123 @@ func TestEncrypter_SignVerify_AlgNone(t *testing.T) { must.ErrorContains(t, err, "invalid signature") must.Nil(t, got) } + +// TestEncrypter_Upgrade17 simulates upgrading from 1.6 -> 1.7 does not break +// old (ed25519) or new (rsa) signing keys. +func TestEncrypter_Upgrade17(t *testing.T) { + + ci.Parallel(t) + + srv, shutdown := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 + }) + t.Cleanup(shutdown) + testutil.WaitForKeyring(t, srv.RPC, "global") + codec := rpcClient(t, srv) + + // Fake life as a 1.6 server by writing only ed25519 keys + oldRootKey, err := structs.NewRootKey(structs.EncryptionAlgorithmAES256GCM) + must.NoError(t, err) + + oldRootKey.Meta.SetActive() + + // Remove RSAKey to mimic 1.6 + oldRootKey.RSAKey = nil + + // Add to keyring + must.NoError(t, srv.encrypter.AddKey(oldRootKey)) + + // Write metadata to Raft + wr := structs.WriteRequest{ + Namespace: "default", + Region: "global", + } + req := structs.KeyringUpdateRootKeyMetaRequest{ + RootKeyMeta: oldRootKey.Meta, + WriteRequest: wr, + } + _, _, err = srv.raftApply(structs.RootKeyMetaUpsertRequestType, req) + must.NoError(t, err) + + // Create a 1.6 style workload identity + claims := &structs.IdentityClaims{ + Namespace: "default", + JobID: "fakejob", + AllocationID: uuid.Generate(), + TaskName: "faketask", + } + + // Sign the claims and assert they were signed with EdDSA (the 1.6 signing + // algorithm) + oldRawJWT, oldKeyID, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + must.Eq(t, oldRootKey.Meta.KeyID, oldKeyID) + + oldJWT, err := jwt.ParseSigned(oldRawJWT) + must.NoError(t, err) + + foundKeyID := false + foundAlg := false + for _, h := range oldJWT.Headers { + if h.KeyID != "" { + // Should only have one key id header + must.False(t, foundKeyID) + foundKeyID = true + must.Eq(t, oldKeyID, h.KeyID) + } + + if h.Algorithm != "" { + // Should only have one alg header + must.False(t, foundAlg) + foundAlg = true + must.Eq(t, structs.PubKeyAlgEdDSA, h.Algorithm) + } + } + must.True(t, foundKeyID) + must.True(t, foundAlg) + + _, err = srv.encrypter.VerifyClaim(oldRawJWT) + must.NoError(t, err) + + // !! Mimic an upgrade by rotating to get a new RSA key !! + rotateReq := &structs.KeyringRotateRootKeyRequest{ + WriteRequest: wr, + } + var rotateResp structs.KeyringRotateRootKeyResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, "Keyring.Rotate", rotateReq, &rotateResp)) + + newRawJWT, newKeyID, err := srv.encrypter.SignClaims(claims) + must.NoError(t, err) + must.NotEq(t, oldRootKey.Meta.KeyID, newKeyID) + must.Eq(t, rotateResp.Key.KeyID, newKeyID) + + newJWT, err := jwt.ParseSigned(newRawJWT) + must.NoError(t, err) + + foundKeyID = false + foundAlg = false + for _, h := range newJWT.Headers { + if h.KeyID != "" { + // Should only have one key id header + must.False(t, foundKeyID) + foundKeyID = true + must.Eq(t, newKeyID, h.KeyID) + } + + if h.Algorithm != "" { + // Should only have one alg header + must.False(t, foundAlg) + foundAlg = true + must.Eq(t, structs.PubKeyAlgRS256, h.Algorithm) + } + } + must.True(t, foundKeyID) + must.True(t, foundAlg) + + _, err = srv.encrypter.VerifyClaim(newRawJWT) + must.NoError(t, err) + + // Ensure that verifying the old JWT still works + _, err = srv.encrypter.VerifyClaim(oldRawJWT) + must.NoError(t, err) +} diff --git a/nomad/keyring_endpoint_test.go b/nomad/keyring_endpoint_test.go index 62b172f30b7..0262b4b1af1 100644 --- a/nomad/keyring_endpoint_test.go +++ b/nomad/keyring_endpoint_test.go @@ -4,7 +4,6 @@ package nomad import ( - "sync" "testing" "time" @@ -25,7 +24,7 @@ func TestKeyringEndpoint_CRUD(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) // Upsert a new key @@ -65,12 +64,11 @@ func TestKeyringEndpoint_CRUD(t *testing.T) { // that Get queries don't need ACL tokens in the test server // because they always pass the mTLS check - var wg sync.WaitGroup - wg.Add(1) + errCh := make(chan error, 1) var listResp structs.KeyringListRootKeyMetaResponse go func() { - defer wg.Done() + defer close(errCh) codec := rpcClient(t, srv) // not safe to share across goroutines listReq := &structs.KeyringListRootKeyMetaRequest{ QueryOptions: structs.QueryOptions{ @@ -79,8 +77,7 @@ func TestKeyringEndpoint_CRUD(t *testing.T) { AuthToken: rootToken.SecretID, }, } - err = msgpackrpc.CallWithCodec(codec, "Keyring.List", listReq, &listResp) - require.NoError(t, err) + errCh <- msgpackrpc.CallWithCodec(codec, "Keyring.List", listReq, &listResp) }() updateReq.RootKey.Meta.CreateTime = time.Now().UTC().UnixNano() @@ -89,7 +86,7 @@ func TestKeyringEndpoint_CRUD(t *testing.T) { require.NotEqual(t, uint64(0), updateResp.Index) // wait for the blocking query to complete and check the response - wg.Wait() + require.NoError(t, <-errCh) require.Equal(t, listResp.Index, updateResp.Index) require.Len(t, listResp.Keys, 2) // bootstrap + new one @@ -138,7 +135,7 @@ func TestKeyringEndpoint_InvalidUpdates(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) // Setup an existing key @@ -228,7 +225,7 @@ func TestKeyringEndpoint_Rotate(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) // Setup an existing key @@ -312,7 +309,7 @@ func TestKeyringEndpoint_ListPublic(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) // Assert 1 key exists and normal fields are set diff --git a/nomad/leader.go b/nomad/leader.go index f85bfc8ce5a..6946c7ffe69 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -2763,7 +2763,7 @@ func (s *Server) initializeKeyring(stopCh <-chan struct{}) { err = s.encrypter.AddKey(rootKey) if err != nil { - logger.Error("could not add initial key to keyring: %v", err) + logger.Error("could not add initial key to keyring", "error", err) return } @@ -2771,7 +2771,7 @@ func (s *Server) initializeKeyring(stopCh <-chan struct{}) { structs.KeyringUpdateRootKeyMetaRequest{ RootKeyMeta: rootKey.Meta, }); err != nil { - logger.Error("could not initialize keyring: %v", err) + logger.Error("could not initialize keyring", "error", err) return } diff --git a/nomad/locks_test.go b/nomad/locks_test.go index ffef5e1c560..6d149fc2cfa 100644 --- a/nomad/locks_test.go +++ b/nomad/locks_test.go @@ -54,7 +54,7 @@ func TestServer_invalidateVariableLock(t *testing.T) { testServer, testServerCleanup := TestServer(t, nil) defer testServerCleanup() - testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, "global") // Generate a variable that includes a lock entry and upsert this into our // state. @@ -87,7 +87,7 @@ func TestServer_createVariableLockTimer(t *testing.T) { testServer, testServerCleanup := TestServer(t, nil) defer testServerCleanup() - testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, "global") // Generate a variable that includes a lock entry and upsert this into our // state. @@ -135,7 +135,7 @@ func TestServer_createAndRenewVariableLockTimer(t *testing.T) { testServer, testServerCleanup := TestServer(t, nil) defer testServerCleanup() - testutil.WaitForLeader(t, testServer.RPC) + testutil.WaitForKeyring(t, testServer.RPC, "global") // Generate a variable that includes a lock entry and upsert this into our // state. diff --git a/nomad/node_pool_endpoint_test.go b/nomad/node_pool_endpoint_test.go index d404a584cda..f3b0efc6a88 100644 --- a/nomad/node_pool_endpoint_test.go +++ b/nomad/node_pool_endpoint_test.go @@ -1170,8 +1170,8 @@ func TestNodePoolEndpoint_DeleteNodePools_NonLocal(t *testing.T) { }) defer cleanupS2() TestJoin(t, s1, s2) - testutil.WaitForLeader(t, s1.RPC) - testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForKeyring(t, s1.RPC, s1.config.Region) + testutil.WaitForKeyring(t, s2.RPC, s2.config.Region) // Write a node pool to the authoritative region np1 := mock.NodePool() diff --git a/nomad/operator_endpoint_test.go b/nomad/operator_endpoint_test.go index 02d5c932f6b..fdb829b2313 100644 --- a/nomad/operator_endpoint_test.go +++ b/nomad/operator_endpoint_test.go @@ -533,12 +533,14 @@ func TestOperator_TransferLeadershipToServerID_ACL(t *testing.T) { var tgtID raft.ServerID // Find the first non-leader server in the list. + s1.peerLock.Lock() for _, sp := range s1.localPeers { tgtID = raft.ServerID(sp.ID) if tgtID != ldrID { break } } + s1.peerLock.Unlock() // Create ACL token invalidToken := mock.CreatePolicyAndToken(t, state, 1001, "test-invalid", mock.NodePolicy(acl.PolicyWrite)) diff --git a/nomad/service_registration_endpoint_test.go b/nomad/service_registration_endpoint_test.go index 1a87d136677..175cc83932a 100644 --- a/nomad/service_registration_endpoint_test.go +++ b/nomad/service_registration_endpoint_test.go @@ -34,7 +34,7 @@ func TestServiceRegistration_Upsert(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations and ensure // they are in the same namespace. @@ -78,7 +78,7 @@ func TestServiceRegistration_Upsert(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations and ensure // they are in the same namespace. @@ -116,7 +116,7 @@ func TestServiceRegistration_Upsert(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations and ensure // they are in the same namespace. @@ -159,7 +159,7 @@ func TestServiceRegistration_Upsert(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations and ensure // they are in the same namespace. @@ -217,7 +217,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Attempt to delete a service registration that does not // exist. @@ -244,7 +244,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -272,7 +272,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -301,7 +301,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -331,7 +331,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -365,7 +365,7 @@ func TestServiceRegistration_DeleteByID(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -420,7 +420,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -465,7 +465,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -503,7 +503,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Test a request without setting an ACL token. serviceRegReq := &structs.ServiceRegistrationListRequest{ @@ -526,7 +526,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -553,7 +553,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -580,7 +580,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -625,7 +625,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate and upsert some service registrations. services := mock.ServiceRegistrations() @@ -662,7 +662,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Create a policy and grab the token which has the read-job // capability on the platform namespace. @@ -704,7 +704,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Create a namespace as this is needed when using an ACL like // we do in this test. @@ -757,7 +757,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Create a namespace as this is needed when using an ACL like // we do in this test. @@ -810,7 +810,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Create a namespace as this is needed when using an ACL like // we do in this test. @@ -868,7 +868,7 @@ func TestServiceRegistration_List(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Create a namespace as this is needed when using an ACL like // we do in this test. @@ -963,7 +963,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, _ *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate mock services then upsert them individually using different indexes. services := mock.ServiceRegistrations() @@ -1026,7 +1026,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate mock services then upsert them individually using different indexes. services := mock.ServiceRegistrations() @@ -1107,7 +1107,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate mock services then upsert them individually using different indexes. services := mock.ServiceRegistrations() @@ -1159,7 +1159,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, token *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate mock services then upsert them individually using different indexes. services := mock.ServiceRegistrations() @@ -1204,7 +1204,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, _ *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // Generate mock services then upsert them individually using different indexes. services := mock.ServiceRegistrations() @@ -1305,7 +1305,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, _ *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // insert 3 instances of service s1 nodeID, jobID, allocID := "node_id", "job_id", "alloc_id" @@ -1383,7 +1383,7 @@ func TestServiceRegistration_GetService(t *testing.T) { }, testFn: func(t *testing.T, s *Server, _ *structs.ACLToken) { codec := rpcClient(t, s) - testutil.WaitForLeader(t, s.RPC) + testutil.WaitForKeyring(t, s.RPC, "global") // insert 2 instances of service s1 nodeID, jobID, allocID := "node_id", "job_id", "alloc_id" diff --git a/nomad/structs/keyring.go b/nomad/structs/keyring.go index 2a1ecfdc0fd..103068716d8 100644 --- a/nomad/structs/keyring.go +++ b/nomad/structs/keyring.go @@ -5,10 +5,14 @@ package structs import ( "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "fmt" "net/url" "time" + "github.com/go-jose/go-jose/v3" "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/crypto" "github.com/hashicorp/nomad/helper/uuid" @@ -17,7 +21,11 @@ import ( const ( // PubKeyAlgEdDSA is the JWA (JSON Web Algorithm) for ed25519 public keys // used for signatures. - PubKeyAlgEdDSA = "EdDSA" + PubKeyAlgEdDSA = string(jose.EdDSA) + + // PubKeyAlgRS256 is the JWA for RSA public keys used for signatures. Support + // is required by AWS OIDC IAM Provider. + PubKeyAlgRS256 = string(jose.RS256) // PubKeyUseSig is the JWK (JSON Web Key) "use" parameter value for // signatures. @@ -31,6 +39,11 @@ const ( type RootKey struct { Meta *RootKeyMeta Key []byte // serialized to keystore as base64 blob + + // RSAKey is the private key used to sign workload identity JWTs with the + // RS256 algorithm. It is stored in its PKCS #1, ASN.1 DER form. See + // x509.MarshalPKCS1PrivateKey for details. + RSAKey []byte } // NewRootKey returns a new root key and its metadata. @@ -46,11 +59,19 @@ func NewRootKey(algorithm EncryptionAlgorithm) (*RootKey, error) { case EncryptionAlgorithmAES256GCM: key, err := crypto.Bytes(32) if err != nil { - return nil, fmt.Errorf("failed to generate key: %v", err) + return nil, fmt.Errorf("failed to generate root key: %w", err) } rootKey.Key = key } + // Generate RSA key for signing workload identity JWTs with RS256. + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate rsa key: %w", err) + } + + rootKey.RSAKey = x509.MarshalPKCS1PrivateKey(rsaPrivateKey) + return rootKey, nil } @@ -176,6 +197,7 @@ func (rkm *RootKeyMeta) Validate() error { type KeyEncryptionKeyWrapper struct { Meta *RootKeyMeta EncryptedDataEncryptionKey []byte `json:"DEK"` + EncryptedRSAKey []byte `json:"RSAKey"` KeyEncryptionKey []byte `json:"KEK"` } @@ -303,9 +325,19 @@ type KeyringPublicKey struct { // claims) inspect pubKey's concrete type. func (pubKey *KeyringPublicKey) GetPublicKey() (any, error) { switch alg := pubKey.Algorithm; alg { + case PubKeyAlgEdDSA: // Convert public key bytes to an ed25519 public key return ed25519.PublicKey(pubKey.PublicKey), nil + + case PubKeyAlgRS256: + // PEM -> rsa.PublickKey + rsaPubKey, err := x509.ParsePKCS1PublicKey(pubKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("error parsing %s public key: %w", alg, err) + } + return rsaPubKey, nil + default: return nil, fmt.Errorf("unknown algorithm: %q", alg) } @@ -345,9 +377,14 @@ func NewOIDCDiscoveryConfig(issuer string) (*OIDCDiscoveryConfig, error) { } disc := &OIDCDiscoveryConfig{ - Issuer: issuer, - JWKS: jwksURL, - IDTokenAlgs: []string{PubKeyAlgEdDSA}, + Issuer: issuer, + JWKS: jwksURL, + + // RS256 is required by the OIDC spec and some third parties such as AWS's + // IAM OIDC Identity Provider. Prior to v1.7 Nomad default to EdDSA so + // advertise support for backward compatibility. + IDTokenAlgs: []string{PubKeyAlgRS256, PubKeyAlgEdDSA}, + ResponseTypes: []string{"code"}, Subjects: []string{"public"}, } diff --git a/nomad/variables_endpoint_test.go b/nomad/variables_endpoint_test.go index 435329eb310..b3dc05c075d 100644 --- a/nomad/variables_endpoint_test.go +++ b/nomad/variables_endpoint_test.go @@ -31,7 +31,7 @@ func TestVariablesEndpoint_GetVariable_Blocking(t *testing.T) { state := s1.fsm.State() codec := rpcClient(t, s1) - testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForKeyring(t, s1.RPC, "global") // First create an unrelated variable. delay := 100 * time.Millisecond @@ -175,7 +175,7 @@ func TestVariablesEndpoint_Apply_ACL(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) state := srv.fsm.State() @@ -455,7 +455,7 @@ func TestVariablesEndpoint_auth(t *testing.T) { }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") const ns = "nondefault-namespace" @@ -838,7 +838,7 @@ func TestVariablesEndpoint_ListFiltering(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) ns := "nondefault-namespace" @@ -942,7 +942,7 @@ func TestVariablesEndpoint_ComplexACLPolicies(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) idx := uint64(1000) @@ -1062,7 +1062,7 @@ func TestVariablesEndpoint_Apply_LockAcquireUpsertAndRelease(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) mockVar := mock.Variable() @@ -1473,7 +1473,7 @@ func TestVariablesEndpoint_Read_Lock_ACL(t *testing.T) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer shutdown() - testutil.WaitForLeader(t, srv.RPC) + testutil.WaitForKeyring(t, srv.RPC, "global") codec := rpcClient(t, srv) state := srv.fsm.State() @@ -1532,6 +1532,7 @@ func TestVariablesEndpoint_Read_Lock_ACL(t *testing.T) { readResp := new(structs.VariablesReadResponse) must.NoError(t, msgpackrpc.CallWithCodec(codec, "Variables.Read", req, &readResp)) + must.NotNil(t, readResp.Data) must.NotNil(t, readResp.Data.VariableMetadata.Lock) must.Eq(t, latest.LockID(), readResp.Data.LockID()) }) diff --git a/testutil/server.go b/testutil/server.go index 57b860ad0b6..2ff8f6738e5 100644 --- a/testutil/server.go +++ b/testutil/server.go @@ -242,7 +242,7 @@ func NewTestServer(t testing.T, cb ServerConfigCallback) *TestServer { // Wait for the server to be ready if nomadConfig.Server.Enabled && nomadConfig.Server.BootstrapExpect != 0 { - server.waitForLeader() + server.waitForServers() } else { server.waitForAPI() } @@ -338,14 +338,14 @@ func (s *TestServer) waitForAPI() { }) } -// waitForLeader waits for the Nomad server's HTTP API to become -// available, and then waits for a known leader and an index of -// 1 or more to be observed to confirm leader election is done. -func (s *TestServer) waitForLeader() { +// waitForServers waits for the Nomad server's HTTP API to become available, +// and then waits for the keyring to be intialized. This implies a leader has +// been elected and Raft writes have occurred. +func (s *TestServer) waitForServers() { WaitForResult(func() (bool, error) { // Query the API and check the status code // Using this endpoint as it is does not have restricted access - resp, err := s.HTTPClient.Get(s.url("/v1/status/leader")) + resp, err := s.HTTPClient.Get(s.url("/.well-known/jwks.json")) if err != nil { return false, err } @@ -354,6 +354,16 @@ func (s *TestServer) waitForLeader() { return false, err } + jwks := struct { + Keys []interface{} `json:"keys"` + }{} + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + return false, fmt.Errorf("error decoding jwks response: %w", err) + } + if len(jwks.Keys) == 0 { + return false, fmt.Errorf("no keys found") + } + return true, nil }, func(err error) { defer s.Stop() diff --git a/testutil/wait.go b/testutil/wait.go index 3466f497b39..300acc11761 100644 --- a/testutil/wait.go +++ b/testutil/wait.go @@ -169,6 +169,24 @@ func WaitForLeaders(t testing.TB, rpcs ...rpcFn) string { return leader } +// WaitForKeyring blocks until the keyring is initialized. +func WaitForKeyring(t testing.TB, rpc rpcFn, region string) { + t.Helper() + args := structs.GenericRequest{ + QueryOptions: structs.QueryOptions{ + Namespace: "default", + Region: region, + }, + } + reply := structs.KeyringListPublicResponse{} + WaitForResult(func() (bool, error) { + err := rpc("Keyring.ListPublic", &args, &reply) + return len(reply.PublicKeys) > 0, err + }, func(err error) { + t.Fatalf("timed out waiting for keyring to initialize: %v", err) + }) +} + // WaitForClient blocks until the client can be found func WaitForClient(t testing.TB, rpc rpcFn, nodeID string, region string) { t.Helper()