diff --git a/internal/controlplane/handlers_oauth_test.go b/internal/controlplane/handlers_oauth_test.go index 2602677d81..94b7e30610 100644 --- a/internal/controlplane/handlers_oauth_test.go +++ b/internal/controlplane/handlers_oauth_test.go @@ -33,6 +33,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/lestrrat-go/jwx/v2/jwt/openid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -430,6 +431,8 @@ func TestProviderCallback(t *testing.T) { Valid: true, String: encryptedUrlString.EncodedData, } + serialized, err := encryptedUrlString.Serialize() + require.NoError(t, err) tx := sql.Tx{} store.EXPECT().BeginTransaction().Return(&tx, nil) @@ -441,7 +444,11 @@ func TestProviderCallback(t *testing.T) { db.GetProjectIDBySessionStateRow{ ProjectID: projectID, RedirectUrl: encryptedUrl, - RemoteUser: tc.remoteUser, + EncryptedRedirect: pqtype.NullRawMessage{ + RawMessage: serialized, + Valid: true, + }, + RemoteUser: tc.remoteUser, }, nil) if tc.existingProvider { diff --git a/internal/controlplane/handlers_providers_test.go b/internal/controlplane/handlers_providers_test.go index 6734af515d..302417df63 100644 --- a/internal/controlplane/handlers_providers_test.go +++ b/internal/controlplane/handlers_providers_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/uuid" "github.com/lestrrat-go/jwx/v2/jwt/openid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -31,6 +32,8 @@ import ( "github.com/stacklok/minder/internal/auth" "github.com/stacklok/minder/internal/authz/mock" serverconfig "github.com/stacklok/minder/internal/config/server" + "github.com/stacklok/minder/internal/crypto" + "github.com/stacklok/minder/internal/crypto/algorithms" mockcrypto "github.com/stacklok/minder/internal/crypto/mock" "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" @@ -87,7 +90,7 @@ func TestDeleteProvider(t *testing.T) { mockStore.EXPECT(). GetAccessTokenByProjectID(gomock.Any(), gomock.Any()). Return(db.ProviderAccessToken{ - EncryptedToken: "encryptedToken", + EncryptedAccessToken: generateSecret(t), }, nil).AnyTimes() mockStore.EXPECT().DeleteProvider(gomock.Any(), db.DeleteProviderParams{ ID: providerID, @@ -206,7 +209,7 @@ func TestDeleteProviderByID(t *testing.T) { mockStore.EXPECT(). GetAccessTokenByProjectID(gomock.Any(), gomock.Any()). Return(db.ProviderAccessToken{ - EncryptedToken: "encryptedToken", + EncryptedAccessToken: generateSecret(t), }, nil).AnyTimes() mockStore.EXPECT(). GetProviderWebhooks(gomock.Any(), gomock.Eq(providerID)). @@ -267,3 +270,22 @@ func TestDeleteProviderByID(t *testing.T) { assert.NoError(t, err) assert.Equal(t, providerID.String(), resp.Id) } + +func generateSecret(t *testing.T) pqtype.NullRawMessage { + t.Helper() + + data := crypto.EncryptedData{ + Algorithm: algorithms.Aes256Cfb, + // randomly generated + EncodedData: "dnS6VFiMYrfnbeP6eixmBw==", + KeyVersion: "", + } + + serialized, err := data.Serialize() + require.NoError(t, err) + + return pqtype.NullRawMessage{ + RawMessage: serialized, + Valid: true, + } +} diff --git a/internal/controlplane/handlers_repositories_test.go b/internal/controlplane/handlers_repositories_test.go index f12ac59eae..b054e1ddd1 100644 --- a/internal/controlplane/handlers_repositories_test.go +++ b/internal/controlplane/handlers_repositories_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -549,7 +550,10 @@ func createServer( store.EXPECT(). GetAccessTokenByProjectID(gomock.Any(), gomock.Any()). Return(db.ProviderAccessToken{ - EncryptedToken: "encryptedToken", + EncryptedAccessToken: pqtype.NullRawMessage{ + Valid: true, + RawMessage: make(json.RawMessage, 16), + }, }, nil).AnyTimes() store.EXPECT(). ListRepositoriesByProjectID(gomock.Any(), gomock.Any()). diff --git a/internal/crypto/algorithms/aes256cfb.go b/internal/crypto/algorithms/aes256cfb.go index 8075ebaee1..8a026e9fde 100644 --- a/internal/crypto/algorithms/aes256cfb.go +++ b/internal/crypto/algorithms/aes256cfb.go @@ -29,32 +29,36 @@ import ( // AES256CFBAlgorithm implements the AES-256-CFB algorithm type AES256CFBAlgorithm struct{} +// Our current implementation of AES-256-CFB uses a fixed salt. +// Since we are planning to move to AES-256-GCM, leave this hardcoded here. +var legacySalt = []byte("somesalt") + // Encrypt encrypts a row of data. -func (a *AES256CFBAlgorithm) Encrypt(data []byte, key []byte, salt []byte) ([]byte, error) { - if len(data) > maxSize { +func (a *AES256CFBAlgorithm) Encrypt(plaintext []byte, key []byte) ([]byte, error) { + if len(plaintext) > maxSize { return nil, status.Errorf(codes.InvalidArgument, "data is too large (>32MB)") } - block, err := aes.NewCipher(a.deriveKey(key, salt)) + block, err := aes.NewCipher(a.deriveKey(key)) if err != nil { return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err) } // The IV needs to be unique, but not secure. Therefore, it's common to include it at the beginning of the ciphertext. - ciphertext := make([]byte, aes.BlockSize+len(data)) + ciphertext := make([]byte, aes.BlockSize+len(plaintext)) iv := ciphertext[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { return nil, status.Errorf(codes.Unknown, "failed to read random bytes: %s", err) } stream := cipher.NewCFBEncrypter(block, iv) - stream.XORKeyStream(ciphertext[aes.BlockSize:], data) + stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext) return ciphertext, nil } // Decrypt decrypts a row of data. -func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte, salt []byte) ([]byte, error) { - block, err := aes.NewCipher(a.deriveKey(key, salt)) +func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte) ([]byte, error) { + block, err := aes.NewCipher(a.deriveKey(key)) if err != nil { return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err) } @@ -74,6 +78,6 @@ func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte, salt []byte) } // Function to derive a key from a passphrase using Argon2 -func (_ *AES256CFBAlgorithm) deriveKey(key []byte, salt []byte) []byte { - return argon2.IDKey(key, salt, 1, 64*1024, 4, 32) +func (_ *AES256CFBAlgorithm) deriveKey(key []byte) []byte { + return argon2.IDKey(key, legacySalt, 1, 64*1024, 4, 32) } diff --git a/internal/crypto/algorithms/algorithm.go b/internal/crypto/algorithms/algorithm.go index f376807de7..127f748121 100644 --- a/internal/crypto/algorithms/algorithm.go +++ b/internal/crypto/algorithms/algorithm.go @@ -23,8 +23,8 @@ import ( // EncryptionAlgorithm represents a crypto algorithm used by the Engine type EncryptionAlgorithm interface { - Encrypt(data []byte, key []byte, salt []byte) ([]byte, error) - Decrypt(data []byte, key []byte, salt []byte) ([]byte, error) + Encrypt(plaintext []byte, key []byte) ([]byte, error) + Decrypt(ciphertext []byte, key []byte) ([]byte, error) } // Type is an enum of supported encryption algorithms diff --git a/internal/crypto/engine.go b/internal/crypto/engine.go index 721b64fe92..7240b7b142 100644 --- a/internal/crypto/engine.go +++ b/internal/crypto/engine.go @@ -45,8 +45,6 @@ type Engine interface { } var ( - // TODO: get rid of this when we allow per-secret salting. - legacySalt = []byte("somesalt") // ErrDecrypt is returned when we cannot decrypt a secret. ErrDecrypt = errors.New("unable to decrypt") // ErrEncrypt is returned when we cannot encrypt a secret. @@ -170,7 +168,7 @@ func (e *engine) encrypt(data []byte) (EncryptedData, error) { return EncryptedData{}, fmt.Errorf("unable to find preferred key with ID: %s", e.defaultKeyID) } - encrypted, err := algorithm.Encrypt(data, key, legacySalt) + encrypted, err := algorithm.Encrypt(data, key) if err != nil { return EncryptedData{}, errors.Join(ErrEncrypt, err) } @@ -180,7 +178,6 @@ func (e *engine) encrypt(data []byte) (EncryptedData, error) { return EncryptedData{ Algorithm: e.defaultAlgorithm, EncodedData: encoded, - Salt: legacySalt, KeyVersion: e.defaultKeyID, }, nil } @@ -190,10 +187,6 @@ func (e *engine) decrypt(data EncryptedData) ([]byte, error) { return nil, errors.New("cannot decrypt empty data") } - if len(data.Salt) == 0 { - return nil, errors.New("cannot decrypt data with empty salt") - } - algorithm, ok := e.supportedAlgorithms[data.Algorithm] if !ok { return nil, fmt.Errorf("%w: %s", algorithms.ErrUnknownAlgorithm, e.defaultAlgorithm) @@ -218,7 +211,7 @@ func (e *engine) decrypt(data EncryptedData) ([]byte, error) { } // decrypt the data - result, err := algorithm.Decrypt(encrypted, key, data.Salt) + result, err := algorithm.Decrypt(encrypted, key) if err != nil { return nil, errors.Join(ErrDecrypt, err) } diff --git a/internal/crypto/engine_test.go b/internal/crypto/engine_test.go index 316c57c7ca..dc665d9c07 100644 --- a/internal/crypto/engine_test.go +++ b/internal/crypto/engine_test.go @@ -167,20 +167,6 @@ func TestDecryptEmpty(t *testing.T) { require.ErrorContains(t, err, "cannot decrypt empty data") } -func TestDecryptEmptySalt(t *testing.T) { - t.Parallel() - - engine, err := NewEngineFromConfig(config) - require.NoError(t, err) - encryptedToken := EncryptedData{ - EncodedData: "abc", - Salt: nil, - } - - _, err = engine.DecryptString(encryptedToken) - require.ErrorContains(t, err, "cannot decrypt data with empty salt") -} - func TestDecryptBadAlgorithm(t *testing.T) { t.Parallel() @@ -189,7 +175,6 @@ func TestDecryptBadAlgorithm(t *testing.T) { encryptedToken := EncryptedData{ Algorithm: "I'm a little teapot", EncodedData: "abc", - Salt: legacySalt, KeyVersion: "", } require.NoError(t, err) @@ -207,7 +192,6 @@ func TestDecryptBadEncoding(t *testing.T) { Algorithm: algorithms.Aes256Cfb, // Unicode snowman is _not_ a valid base64 character EncodedData: "☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃", - Salt: legacySalt, KeyVersion: "", } require.NoError(t, err) @@ -225,7 +209,6 @@ func TestDecryptFailedDecryption(t *testing.T) { Algorithm: algorithms.Aes256Cfb, // too small of a value - will trigger the ciphertext length check EncodedData: "abcdef0123456789", - Salt: legacySalt, KeyVersion: "", } require.NoError(t, err) diff --git a/internal/crypto/models.go b/internal/crypto/models.go index f54dfb82b5..38468b151c 100644 --- a/internal/crypto/models.go +++ b/internal/crypto/models.go @@ -28,8 +28,6 @@ type EncryptedData struct { Algorithm algorithms.Type // The encrypted data represented as a base64 encoded string. EncodedData string - // The salt used in the encryption. - Salt []byte // An identifier which specifies the key used. // Used to handle multiple keys during key rotation. KeyVersion string @@ -48,7 +46,6 @@ func NewBackwardsCompatibleEncryptedData(encryptedData string) EncryptedData { return EncryptedData{ Algorithm: algorithms.Aes256Cfb, EncodedData: encryptedData, - Salt: legacySalt, KeyVersion: "", } } diff --git a/internal/db/provider_access_tokens_test.go b/internal/db/provider_access_tokens_test.go index bee6a29a33..6d670b9fc5 100644 --- a/internal/db/provider_access_tokens_test.go +++ b/internal/db/provider_access_tokens_test.go @@ -20,7 +20,11 @@ import ( "database/sql" "testing" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" + + "github.com/stacklok/minder/internal/crypto" + "github.com/stacklok/minder/internal/crypto/algorithms" ) func TestUpsertProviderAccessToken(t *testing.T) { @@ -30,11 +34,15 @@ func TestUpsertProviderAccessToken(t *testing.T) { project := createRandomProject(t, org.ID) prov := createRandomProvider(t, project.ID) + secret := createSecret(t, "abc") + serialized := serializeSecret(t, secret) + tok, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{ - ProjectID: project.ID, - Provider: prov.Name, - EncryptedToken: "abc", - OwnerFilter: sql.NullString{}, + ProjectID: project.ID, + Provider: prov.Name, + EncryptedToken: "abc", + EncryptedAccessToken: serialized, + OwnerFilter: sql.NullString{}, }) require.NoError(t, err) @@ -45,6 +53,7 @@ func TestUpsertProviderAccessToken(t *testing.T) { require.Equal(t, project.ID, tok.ProjectID) require.Equal(t, prov.Name, tok.Provider) require.Equal(t, "abc", tok.EncryptedToken) + require.Equal(t, secret, deserializeSecret(t, tok.EncryptedAccessToken)) require.Equal(t, sql.NullString{}, tok.OwnerFilter) tokUpdate, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{ @@ -63,3 +72,32 @@ func TestUpsertProviderAccessToken(t *testing.T) { require.Equal(t, tok.CreatedAt, tokUpdate.CreatedAt) require.NotEqual(t, tok.UpdatedAt, tokUpdate.UpdatedAt) } + +func createSecret(t *testing.T, encryptedData string) crypto.EncryptedData { + t.Helper() + + return crypto.EncryptedData{ + Algorithm: algorithms.Aes256Cfb, + EncodedData: encryptedData, + KeyVersion: "12345", + } +} + +func serializeSecret(t *testing.T, data crypto.EncryptedData) pqtype.NullRawMessage { + t.Helper() + + serialized, err := data.Serialize() + require.NoError(t, err) + return pqtype.NullRawMessage{ + RawMessage: serialized, + Valid: true, + } +} + +func deserializeSecret(t *testing.T, data pqtype.NullRawMessage) crypto.EncryptedData { + t.Helper() + + result, err := crypto.DeserializeEncryptedData(data.RawMessage) + require.NoError(t, err) + return result +}