diff --git a/pkg/connectorbuilder/connectorbuilder.go b/pkg/connectorbuilder/connectorbuilder.go index 29e02e8b..a245b13b 100644 --- a/pkg/connectorbuilder/connectorbuilder.go +++ b/pkg/connectorbuilder/connectorbuilder.go @@ -837,7 +837,7 @@ func (b *builderImpl) RotateCredential(ctx context.Context, request *v2.RotateCr return nil, fmt.Errorf("error: rotate credentials on resource failed: %w", err) } - pkem, err := crypto.NewEncryptionManager(request.GetCredentialOptions(), request.GetEncryptionConfigs()) + pkem, err := crypto.NewEncryptionManager(ctx, request.GetCredentialOptions(), request.GetEncryptionConfigs()) if err != nil { l.Error("error: creating encryption manager failed", zap.Error(err)) b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start)) @@ -889,7 +889,7 @@ func (b *builderImpl) CreateAccount(ctx context.Context, request *v2.CreateAccou return nil, fmt.Errorf("error: create account failed: %w", err) } - pkem, err := crypto.NewEncryptionManager(request.GetCredentialOptions(), request.GetEncryptionConfigs()) + pkem, err := crypto.NewEncryptionManager(ctx, request.GetCredentialOptions(), request.GetEncryptionConfigs()) if err != nil { l.Error("error: creating encryption manager failed", zap.Error(err)) b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start)) diff --git a/pkg/crypto/crypto.go b/pkg/crypto/crypto.go index 6d3d8e14..c9764049 100644 --- a/pkg/crypto/crypto.go +++ b/pkg/crypto/crypto.go @@ -15,36 +15,47 @@ type PlaintextCredential struct { } type EncryptionManager struct { - opts *v2.CredentialOptions - configs []*v2.EncryptionConfig + opts *v2.CredentialOptions + providerConfigsMap map[string]([]*v2.EncryptionConfig) } // FIXME(morgabra) Be tolerant of failures here and return the encryptions that succeeded. We've likely already // done things to generate the credentials we want to encrypt, so we should still return the created objects // even if your encryption provider is misconfigured. func (pkem *EncryptionManager) Encrypt(ctx context.Context, cred *v2.PlaintextData) ([]*v2.EncryptedData, error) { - encryptedDatas := make([]*v2.EncryptedData, 0, len(pkem.configs)) + encryptedDatas := make([]*v2.EncryptedData, 0) - for _, config := range pkem.configs { - provider, err := providers.GetEncryptionProviderForConfig(ctx, config) + for providerName, configs := range pkem.providerConfigsMap { + provider, err := providers.GetEncryptionProvider(providerName) if err != nil { return nil, err } - encryptedData, err := provider.Encrypt(ctx, config, cred) + encryptedData, err := provider.Encrypt(ctx, configs, cred) if err != nil { return nil, err } - encryptedDatas = append(encryptedDatas, encryptedData) + encryptedDatas = append(encryptedDatas, encryptedData...) } return encryptedDatas, nil } -func NewEncryptionManager(co *v2.CredentialOptions, ec []*v2.EncryptionConfig) (*EncryptionManager, error) { +// MJP creating the providerMap means parsing the configs and failing early instead of in Encrypt. +func NewEncryptionManager(ctx context.Context, co *v2.CredentialOptions, ec []*v2.EncryptionConfig) (*EncryptionManager, error) { + // Group the encryption configs by provider + providerMap := make(map[string]([]*v2.EncryptionConfig)) + for _, config := range ec { + providerName, err := providers.GetEncryptionProviderName(ctx, config) + if err != nil { + return nil, err + } + providerMap[providerName] = append(providerMap[providerName], config) + } + em := &EncryptionManager{ - opts: co, - configs: ec, + opts: co, + providerConfigsMap: providerMap, } return em, nil } diff --git a/pkg/crypto/crypto_test.go b/pkg/crypto/crypto_test.go index 73df9ef5..7da76ec8 100644 --- a/pkg/crypto/crypto_test.go +++ b/pkg/crypto/crypto_test.go @@ -37,6 +37,76 @@ func marshalJWK(t *testing.T, privKey interface{}) (*v2.EncryptionConfig, []byte return config, privJWKBytes } +func TestMultiRecipientEncrypton(t *testing.T) { + ctx := context.Background() + provider, err := providers.GetEncryptionProvider(jwk.EncryptionProviderJwk) + require.NoError(t, err) + + config, key1, err := provider.GenerateKey(ctx) + require.NoError(t, err) + config2, key2, err := provider.GenerateKey(ctx) + require.NoError(t, err) + config3, key3, err := provider.GenerateKey(ctx) + require.NoError(t, err) + + // try with an RSA key as well + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + config4, key4 := marshalJWK(t, privKey) + + plainText := &v2.PlaintextData{ + Name: "password", + Description: "this is the password", + Schema: "", + Bytes: []byte("hunter2"), + } + + cipherTexts, err := provider.Encrypt(ctx, []*v2.EncryptionConfig{config, config2, config3, config4}, plainText) + require.NoError(t, err) + require.Len(t, cipherTexts, 1) + + cipherText := cipherTexts[0] + require.Equal(t, plainText.Name, cipherText.Name) + require.Equal(t, plainText.Description, cipherText.Description) + require.Equal(t, plainText.Schema, cipherText.Schema) + require.NotEqual(t, plainText.Bytes, cipherText.EncryptedBytes) + require.Greater(t, len(cipherText.EncryptedBytes), len(plainText.Bytes)) + + decryptedText, err := provider.Decrypt(ctx, cipherText, key1) + require.NoError(t, err) + require.Equal(t, plainText.Name, decryptedText.Name) + require.Equal(t, plainText.Description, decryptedText.Description) + require.Equal(t, plainText.Schema, decryptedText.Schema) + require.Equal(t, plainText.Bytes, decryptedText.Bytes) + + decryptedText, err = provider.Decrypt(ctx, cipherText, key2) + require.NoError(t, err) + require.Equal(t, plainText.Name, decryptedText.Name) + require.Equal(t, plainText.Description, decryptedText.Description) + require.Equal(t, plainText.Schema, decryptedText.Schema) + require.Equal(t, plainText.Bytes, decryptedText.Bytes) + + decryptedText, err = provider.Decrypt(ctx, cipherText, key3) + require.NoError(t, err) + require.Equal(t, plainText.Name, decryptedText.Name) + require.Equal(t, plainText.Description, decryptedText.Description) + require.Equal(t, plainText.Schema, decryptedText.Schema) + require.Equal(t, plainText.Bytes, decryptedText.Bytes) + + decryptedText, err = provider.Decrypt(ctx, cipherText, key4) + require.NoError(t, err) + require.Equal(t, plainText.Name, decryptedText.Name) + require.Equal(t, plainText.Description, decryptedText.Description) + require.Equal(t, plainText.Schema, decryptedText.Schema) + require.Equal(t, plainText.Bytes, decryptedText.Bytes) + + // but some random new key shouldn't work + _, badKey, err := provider.GenerateKey(ctx) + require.NoError(t, err) + _, err = provider.Decrypt(ctx, cipherText, badKey) + require.Error(t, err) +} + func testEncryptionProvider(t *testing.T, ctx context.Context, config *v2.EncryptionConfig, privKey []byte) { provider, err := providers.GetEncryptionProvider(jwk.EncryptionProviderJwk) require.NoError(t, err) @@ -47,8 +117,10 @@ func testEncryptionProvider(t *testing.T, ctx context.Context, config *v2.Encryp Schema: "", Bytes: []byte("hunter2"), } - cipherText, err := provider.Encrypt(ctx, config, plainText) + cipherTexts, err := provider.Encrypt(ctx, []*v2.EncryptionConfig{config}, plainText) require.NoError(t, err) + require.Len(t, cipherTexts, 1) + cipherText := cipherTexts[0] require.Equal(t, plainText.Name, cipherText.Name) require.Equal(t, plainText.Description, cipherText.Description) @@ -138,7 +210,7 @@ func TestEncryptionProviderJWKSymmetric(t *testing.T) { Schema: "", Bytes: []byte("hunter2"), } - cipherText, err := provider.Encrypt(ctx, config, plainText) + cipherText, err := provider.Encrypt(ctx, []*v2.EncryptionConfig{config}, plainText) require.ErrorIs(t, err, jwk.JWKUnsupportedKeyTypeError) require.Nil(t, cipherText) } diff --git a/pkg/crypto/providers/jwk/ed25519.go b/pkg/crypto/providers/jwk/ed25519.go index e06f38da..5dfff24e 100644 --- a/pkg/crypto/providers/jwk/ed25519.go +++ b/pkg/crypto/providers/jwk/ed25519.go @@ -4,11 +4,12 @@ import ( "crypto/ed25519" "fmt" + "filippo.io/age" "filippo.io/age/agessh" "golang.org/x/crypto/ssh" ) -func EncryptED25519(pubKey ed25519.PublicKey, plaintext []byte) ([]byte, error) { +func CreateED25519Recipient(pubKey ed25519.PublicKey) (*agessh.Ed25519Recipient, error) { sshPubKey, err := ssh.NewPublicKey(pubKey) if err != nil { return nil, fmt.Errorf("jwk-ed25519: failed to convert public key to ssh format: %w", err) @@ -18,8 +19,16 @@ func EncryptED25519(pubKey ed25519.PublicKey, plaintext []byte) ([]byte, error) if err != nil { return nil, fmt.Errorf("jwk-ed25519: failed to create recipient: %w", err) } + return recipient, nil +} + +func EncryptED25519(pubKey ed25519.PublicKey, plaintext []byte) ([]byte, error) { + recipient, err := CreateED25519Recipient(pubKey) + if err != nil { + return nil, err + } - ciphertext, err := ageEncrypt(recipient, plaintext) + ciphertext, err := ageEncrypt([]age.Recipient{recipient}, plaintext) if err != nil { return nil, fmt.Errorf("jwk-ed25519: %w", err) } diff --git a/pkg/crypto/providers/jwk/jwk.go b/pkg/crypto/providers/jwk/jwk.go index a8b8b95b..7158e213 100644 --- a/pkg/crypto/providers/jwk/jwk.go +++ b/pkg/crypto/providers/jwk/jwk.go @@ -72,49 +72,83 @@ func (j *JWKEncryptionProvider) GenerateKey(ctx context.Context) (*v2.Encryption }, privKeyJWKBytes, nil } -func (j *JWKEncryptionProvider) Encrypt(ctx context.Context, conf *v2.EncryptionConfig, plainText *v2.PlaintextData) (*v2.EncryptedData, error) { - jwk, err := unmarshalJWK(conf.GetJwkPublicKeyConfig().GetPubKey()) - if err != nil { - return nil, err - } +func (j *JWKEncryptionProvider) Encrypt(ctx context.Context, configs []*v2.EncryptionConfig, plainText *v2.PlaintextData) ([]*v2.EncryptedData, error) { + recipients := make([]age.Recipient, 0, len(configs)) + recipientThumbs := make([]string, 0, len(configs)) + encrypted := make([]*v2.EncryptedData, 0, len(configs)) - var ciphertext []byte - switch pubKey := jwk.Public().Key.(type) { - case ed25519.PublicKey: - ciphertext, err = EncryptED25519(pubKey, plainText.Bytes) - if err != nil { - return nil, err - } - case *ecdsa.PublicKey: - ciphertext, err = EncryptECDSA(pubKey, plainText.Bytes) + for _, config := range configs { + jwk, err := unmarshalJWK(config.GetJwkPublicKeyConfig().GetPubKey()) if err != nil { return nil, err } - case *rsa.PublicKey: - ciphertext, err = EncryptRSA(pubKey, plainText.Bytes) - if err != nil { - return nil, err - } - default: - return nil, JWKUnsupportedKeyTypeError - } - tp, err := thumbprint(jwk) - if err != nil { - return nil, err + switch pubKey := jwk.Public().Key.(type) { + case ed25519.PublicKey: + tp, err := thumbprint(jwk) + if err != nil { + return nil, err + } + recipientThumbs = append(recipientThumbs, tp) + recipient, err := CreateED25519Recipient(pubKey) + if err != nil { + return nil, err + } + recipients = append(recipients, recipient) + case *rsa.PublicKey: + tp, err := thumbprint(jwk) + if err != nil { + return nil, err + } + recipientThumbs = append(recipientThumbs, tp) + recipient, err := CreateRSARecipient(pubKey) + if err != nil { + return nil, err + } + recipients = append(recipients, recipient) + case *ecdsa.PublicKey: + tp, err := thumbprint(jwk) + if err != nil { + return nil, err + } + ciphertext, err := EncryptECDSA(pubKey, plainText.Bytes) + if err != nil { + return nil, err + } + encCipherText := base64.StdEncoding.EncodeToString(ciphertext) + encrypted = append(encrypted, &v2.EncryptedData{ + Provider: EncryptionProviderJwk, + KeyId: tp, // MJP remove me once we've depricated fully + Name: plainText.Name, + Description: plainText.Description, + Schema: plainText.Schema, + EncryptedBytes: []byte(encCipherText), + KeyIds: []string{tp}, + }) + + default: + return nil, JWKUnsupportedKeyTypeError + } } - encCipherText := base64.StdEncoding.EncodeToString(ciphertext) - - return &v2.EncryptedData{ - Provider: EncryptionProviderJwk, - KeyId: tp, - Name: plainText.Name, - Description: plainText.Description, - Schema: plainText.Schema, - EncryptedBytes: []byte(encCipherText), - KeyIds: []string{tp}, - }, nil + if len(recipients) > 0 { + ciphertext, err := ageEncrypt(recipients, plainText.Bytes) + if err != nil { + return nil, fmt.Errorf("jwk: %w", err) + } + encCipherText := base64.StdEncoding.EncodeToString(ciphertext) + + encrypted = append(encrypted, &v2.EncryptedData{ + Provider: EncryptionProviderJwk, + KeyId: recipientThumbs[0], // MJP remove me once we've depricated fully + Name: plainText.Name, + Description: plainText.Description, + Schema: plainText.Schema, + EncryptedBytes: []byte(encCipherText), + KeyIds: recipientThumbs, + }) + } + return encrypted, nil } func (j *JWKEncryptionProvider) Decrypt(ctx context.Context, cipherText *v2.EncryptedData, privateKey []byte) (*v2.PlaintextData, error) { @@ -169,9 +203,9 @@ func thumbprint(jwk *jose.JSONWebKey) (string, error) { return hex.EncodeToString(tp), nil } -func ageEncrypt(r age.Recipient, plaintext []byte) ([]byte, error) { +func ageEncrypt(r []age.Recipient, plaintext []byte) ([]byte, error) { ciphertext := &bytes.Buffer{} - w, err := age.Encrypt(ciphertext, r) + w, err := age.Encrypt(ciphertext, r...) if err != nil { return nil, fmt.Errorf("age: failed to encrypt: %w", err) } diff --git a/pkg/crypto/providers/jwk/rsa.go b/pkg/crypto/providers/jwk/rsa.go index f5256e05..92711200 100644 --- a/pkg/crypto/providers/jwk/rsa.go +++ b/pkg/crypto/providers/jwk/rsa.go @@ -4,11 +4,12 @@ import ( "crypto/rsa" "fmt" + "filippo.io/age" "filippo.io/age/agessh" "golang.org/x/crypto/ssh" ) -func EncryptRSA(pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { +func CreateRSARecipient(pubKey *rsa.PublicKey) (*agessh.RSARecipient, error) { sshPubKey, err := ssh.NewPublicKey(pubKey) if err != nil { return nil, fmt.Errorf("jwk-rsa: failed to convert public key to ssh format: %w", err) @@ -19,7 +20,16 @@ func EncryptRSA(pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { return nil, fmt.Errorf("jwk-rsa: failed to create recipient: %w", err) } - ciphertext, err := ageEncrypt(recipient, plaintext) + return recipient, nil +} + +func EncryptRSA(pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + recipient, err := CreateRSARecipient(pubKey) + if err != nil { + return nil, err + } + + ciphertext, err := ageEncrypt([]age.Recipient{recipient}, plaintext) if err != nil { return nil, fmt.Errorf("jwk-rsa: %w", err) } diff --git a/pkg/crypto/providers/registry.go b/pkg/crypto/providers/registry.go index 399be670..9a0268e0 100644 --- a/pkg/crypto/providers/registry.go +++ b/pkg/crypto/providers/registry.go @@ -12,7 +12,7 @@ import ( var EncryptionProviderNotRegisteredError = fmt.Errorf("crypto/providers: encryption provider not registered") type EncryptionProvider interface { - Encrypt(ctx context.Context, conf *v2.EncryptionConfig, plainText *v2.PlaintextData) (*v2.EncryptedData, error) + Encrypt(ctx context.Context, conf []*v2.EncryptionConfig, plainText *v2.PlaintextData) ([]*v2.EncryptedData, error) Decrypt(ctx context.Context, cipherText *v2.EncryptedData, privateKey []byte) (*v2.PlaintextData, error) GenerateKey(ctx context.Context) (*v2.EncryptionConfig, []byte, error) @@ -34,11 +34,7 @@ func GetEncryptionProvider(name string) (EncryptionProvider, error) { return provider, nil } -// GetEncryptionProviderForConfig returns the encryption provider for the given config. -// If the config specifies a provider, we will fetch it directly by name and return an error if it's not found. -// If the config contains a non-nil well-known configuration (like JWKPublicKeyConfig), we will return the provider for that by name. -// If we can't find a provider, we return an EncryptionProviderNotRegisteredError. -func GetEncryptionProviderForConfig(ctx context.Context, conf *v2.EncryptionConfig) (EncryptionProvider, error) { +func GetEncryptionProviderName(ctx context.Context, conf *v2.EncryptionConfig) (string, error) { providerName := normalizeProviderName(conf.GetProvider()) // We weren't given an explicit provider, so we can try to infer one based on the config. @@ -49,9 +45,17 @@ func GetEncryptionProviderForConfig(ctx context.Context, conf *v2.EncryptionConf } } - // If we don't have a provider by now, bail. - if providerName == "" { - return nil, EncryptionProviderNotRegisteredError + return providerName, nil +} + +// GetEncryptionProviderForConfig returns the encryption provider for the given config. +// If the config specifies a provider, we will fetch it directly by name and return an error if it's not found. +// If the config contains a non-nil well-known configuration (like JWKPublicKeyConfig), we will return the provider for that by name. +// If we can't find a provider, we return an EncryptionProviderNotRegisteredError. +func GetEncryptionProviderForConfig(ctx context.Context, conf *v2.EncryptionConfig) (EncryptionProvider, error) { + providerName, err := GetEncryptionProviderName(ctx, conf) + if err != nil { + return nil, err } return GetEncryptionProvider(providerName)