Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove salt field from EncryptedData #3357

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwt/openid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -430,6 +431,8 @@ func TestProviderCallback(t *testing.T) {
Valid: true,
String: encryptedUrlString.EncodedData,
}
serialized, err := encryptedUrlString.Serialize()
require.NoError(t, err)

tx := sql.Tx{}
store.EXPECT().BeginTransaction().Return(&tx, nil)
Expand All @@ -441,7 +444,11 @@ func TestProviderCallback(t *testing.T) {
db.GetProjectIDBySessionStateRow{
ProjectID: projectID,
RedirectUrl: encryptedUrl,
RemoteUser: tc.remoteUser,
EncryptedRedirect: pqtype.NullRawMessage{
RawMessage: serialized,
Valid: true,
},
RemoteUser: tc.remoteUser,
}, nil)

if tc.existingProvider {
Expand Down
26 changes: 24 additions & 2 deletions internal/controlplane/handlers_providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwt/openid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand All @@ -31,6 +32,8 @@ import (
"github.com/stacklok/minder/internal/auth"
"github.com/stacklok/minder/internal/authz/mock"
serverconfig "github.com/stacklok/minder/internal/config/server"
"github.com/stacklok/minder/internal/crypto"
"github.com/stacklok/minder/internal/crypto/algorithms"
mockcrypto "github.com/stacklok/minder/internal/crypto/mock"
"github.com/stacklok/minder/internal/db"
"github.com/stacklok/minder/internal/engine"
Expand Down Expand Up @@ -87,7 +90,7 @@ func TestDeleteProvider(t *testing.T) {
mockStore.EXPECT().
GetAccessTokenByProjectID(gomock.Any(), gomock.Any()).
Return(db.ProviderAccessToken{
EncryptedToken: "encryptedToken",
EncryptedAccessToken: generateSecret(t),
}, nil).AnyTimes()
mockStore.EXPECT().DeleteProvider(gomock.Any(), db.DeleteProviderParams{
ID: providerID,
Expand Down Expand Up @@ -206,7 +209,7 @@ func TestDeleteProviderByID(t *testing.T) {
mockStore.EXPECT().
GetAccessTokenByProjectID(gomock.Any(), gomock.Any()).
Return(db.ProviderAccessToken{
EncryptedToken: "encryptedToken",
EncryptedAccessToken: generateSecret(t),
}, nil).AnyTimes()
mockStore.EXPECT().
GetProviderWebhooks(gomock.Any(), gomock.Eq(providerID)).
Expand Down Expand Up @@ -267,3 +270,22 @@ func TestDeleteProviderByID(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, providerID.String(), resp.Id)
}

func generateSecret(t *testing.T) pqtype.NullRawMessage {
t.Helper()

data := crypto.EncryptedData{
Algorithm: algorithms.Aes256Cfb,
// randomly generated
EncodedData: "dnS6VFiMYrfnbeP6eixmBw==",
KeyVersion: "",
}

serialized, err := data.Serialize()
require.NoError(t, err)

return pqtype.NullRawMessage{
RawMessage: serialized,
Valid: true,
}
}
6 changes: 5 additions & 1 deletion internal/controlplane/handlers_repositories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

Expand Down Expand Up @@ -549,7 +550,10 @@ func createServer(
store.EXPECT().
GetAccessTokenByProjectID(gomock.Any(), gomock.Any()).
Return(db.ProviderAccessToken{
EncryptedToken: "encryptedToken",
EncryptedAccessToken: pqtype.NullRawMessage{
Valid: true,
RawMessage: make(json.RawMessage, 16),
},
}, nil).AnyTimes()
store.EXPECT().
ListRepositoriesByProjectID(gomock.Any(), gomock.Any()).
Expand Down
22 changes: 13 additions & 9 deletions internal/crypto/algorithms/aes256cfb.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,36 @@ import (
// AES256CFBAlgorithm implements the AES-256-CFB algorithm
type AES256CFBAlgorithm struct{}

// Our current implementation of AES-256-CFB uses a fixed salt.
// Since we are planning to move to AES-256-GCM, leave this hardcoded here.
var legacySalt = []byte("somesalt")

// Encrypt encrypts a row of data.
func (a *AES256CFBAlgorithm) Encrypt(data []byte, key []byte, salt []byte) ([]byte, error) {
if len(data) > maxSize {
func (a *AES256CFBAlgorithm) Encrypt(plaintext []byte, key []byte) ([]byte, error) {
if len(plaintext) > maxSize {
return nil, status.Errorf(codes.InvalidArgument, "data is too large (>32MB)")
}
block, err := aes.NewCipher(a.deriveKey(key, salt))
block, err := aes.NewCipher(a.deriveKey(key))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}

// The IV needs to be unique, but not secure. Therefore, it's common to include it at the beginning of the ciphertext.
ciphertext := make([]byte, aes.BlockSize+len(data))
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, status.Errorf(codes.Unknown, "failed to read random bytes: %s", err)
}

stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], data)
stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

return ciphertext, nil
}

// Decrypt decrypts a row of data.
func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte, salt []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey(key, salt))
func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey(key))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}
Expand All @@ -74,6 +78,6 @@ func (a *AES256CFBAlgorithm) Decrypt(ciphertext []byte, key []byte, salt []byte)
}

// Function to derive a key from a passphrase using Argon2
func (_ *AES256CFBAlgorithm) deriveKey(key []byte, salt []byte) []byte {
return argon2.IDKey(key, salt, 1, 64*1024, 4, 32)
func (_ *AES256CFBAlgorithm) deriveKey(key []byte) []byte {
return argon2.IDKey(key, legacySalt, 1, 64*1024, 4, 32)
}
4 changes: 2 additions & 2 deletions internal/crypto/algorithms/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (

// EncryptionAlgorithm represents a crypto algorithm used by the Engine
type EncryptionAlgorithm interface {
Encrypt(data []byte, key []byte, salt []byte) ([]byte, error)
Decrypt(data []byte, key []byte, salt []byte) ([]byte, error)
Encrypt(plaintext []byte, key []byte) ([]byte, error)
Decrypt(ciphertext []byte, key []byte) ([]byte, error)
}

// Type is an enum of supported encryption algorithms
Expand Down
11 changes: 2 additions & 9 deletions internal/crypto/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ type Engine interface {
}

var (
// TODO: get rid of this when we allow per-secret salting.
legacySalt = []byte("somesalt")
// ErrDecrypt is returned when we cannot decrypt a secret.
ErrDecrypt = errors.New("unable to decrypt")
// ErrEncrypt is returned when we cannot encrypt a secret.
Expand Down Expand Up @@ -170,7 +168,7 @@ func (e *engine) encrypt(data []byte) (EncryptedData, error) {
return EncryptedData{}, fmt.Errorf("unable to find preferred key with ID: %s", e.defaultKeyID)
}

encrypted, err := algorithm.Encrypt(data, key, legacySalt)
encrypted, err := algorithm.Encrypt(data, key)
if err != nil {
return EncryptedData{}, errors.Join(ErrEncrypt, err)
}
Expand All @@ -180,7 +178,6 @@ func (e *engine) encrypt(data []byte) (EncryptedData, error) {
return EncryptedData{
Algorithm: e.defaultAlgorithm,
EncodedData: encoded,
Salt: legacySalt,
KeyVersion: e.defaultKeyID,
}, nil
}
Expand All @@ -190,10 +187,6 @@ func (e *engine) decrypt(data EncryptedData) ([]byte, error) {
return nil, errors.New("cannot decrypt empty data")
}

if len(data.Salt) == 0 {
return nil, errors.New("cannot decrypt data with empty salt")
}

algorithm, ok := e.supportedAlgorithms[data.Algorithm]
if !ok {
return nil, fmt.Errorf("%w: %s", algorithms.ErrUnknownAlgorithm, e.defaultAlgorithm)
Expand All @@ -218,7 +211,7 @@ func (e *engine) decrypt(data EncryptedData) ([]byte, error) {
}

// decrypt the data
result, err := algorithm.Decrypt(encrypted, key, data.Salt)
result, err := algorithm.Decrypt(encrypted, key)
if err != nil {
return nil, errors.Join(ErrDecrypt, err)
}
Expand Down
17 changes: 0 additions & 17 deletions internal/crypto/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,6 @@ func TestDecryptEmpty(t *testing.T) {
require.ErrorContains(t, err, "cannot decrypt empty data")
}

func TestDecryptEmptySalt(t *testing.T) {
t.Parallel()

engine, err := NewEngineFromConfig(config)
require.NoError(t, err)
encryptedToken := EncryptedData{
EncodedData: "abc",
Salt: nil,
}

_, err = engine.DecryptString(encryptedToken)
require.ErrorContains(t, err, "cannot decrypt data with empty salt")
}

func TestDecryptBadAlgorithm(t *testing.T) {
t.Parallel()

Expand All @@ -189,7 +175,6 @@ func TestDecryptBadAlgorithm(t *testing.T) {
encryptedToken := EncryptedData{
Algorithm: "I'm a little teapot",
EncodedData: "abc",
Salt: legacySalt,
KeyVersion: "",
}
require.NoError(t, err)
Expand All @@ -207,7 +192,6 @@ func TestDecryptBadEncoding(t *testing.T) {
Algorithm: algorithms.Aes256Cfb,
// Unicode snowman is _not_ a valid base64 character
EncodedData: "☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃",
Salt: legacySalt,
KeyVersion: "",
}
require.NoError(t, err)
Expand All @@ -225,7 +209,6 @@ func TestDecryptFailedDecryption(t *testing.T) {
Algorithm: algorithms.Aes256Cfb,
// too small of a value - will trigger the ciphertext length check
EncodedData: "abcdef0123456789",
Salt: legacySalt,
KeyVersion: "",
}
require.NoError(t, err)
Expand Down
3 changes: 0 additions & 3 deletions internal/crypto/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ type EncryptedData struct {
Algorithm algorithms.Type
// The encrypted data represented as a base64 encoded string.
EncodedData string
// The salt used in the encryption.
Salt []byte
// An identifier which specifies the key used.
// Used to handle multiple keys during key rotation.
KeyVersion string
Expand All @@ -48,7 +46,6 @@ func NewBackwardsCompatibleEncryptedData(encryptedData string) EncryptedData {
return EncryptedData{
Algorithm: algorithms.Aes256Cfb,
EncodedData: encryptedData,
Salt: legacySalt,
KeyVersion: "",
}
}
Expand Down
46 changes: 42 additions & 4 deletions internal/db/provider_access_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ import (
"database/sql"
"testing"

"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"

"github.com/stacklok/minder/internal/crypto"
"github.com/stacklok/minder/internal/crypto/algorithms"
)

func TestUpsertProviderAccessToken(t *testing.T) {
Expand All @@ -30,11 +34,15 @@ func TestUpsertProviderAccessToken(t *testing.T) {
project := createRandomProject(t, org.ID)
prov := createRandomProvider(t, project.ID)

secret := createSecret(t, "abc")
serialized := serializeSecret(t, secret)

tok, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "abc",
OwnerFilter: sql.NullString{},
ProjectID: project.ID,
Provider: prov.Name,
EncryptedToken: "abc",
EncryptedAccessToken: serialized,
OwnerFilter: sql.NullString{},
})

require.NoError(t, err)
Expand All @@ -45,6 +53,7 @@ func TestUpsertProviderAccessToken(t *testing.T) {
require.Equal(t, project.ID, tok.ProjectID)
require.Equal(t, prov.Name, tok.Provider)
require.Equal(t, "abc", tok.EncryptedToken)
require.Equal(t, secret, deserializeSecret(t, tok.EncryptedAccessToken))
require.Equal(t, sql.NullString{}, tok.OwnerFilter)

tokUpdate, err := testQueries.UpsertAccessToken(context.Background(), UpsertAccessTokenParams{
Expand All @@ -63,3 +72,32 @@ func TestUpsertProviderAccessToken(t *testing.T) {
require.Equal(t, tok.CreatedAt, tokUpdate.CreatedAt)
require.NotEqual(t, tok.UpdatedAt, tokUpdate.UpdatedAt)
}

func createSecret(t *testing.T, encryptedData string) crypto.EncryptedData {
t.Helper()

return crypto.EncryptedData{
Algorithm: algorithms.Aes256Cfb,
EncodedData: encryptedData,
KeyVersion: "12345",
}
}

func serializeSecret(t *testing.T, data crypto.EncryptedData) pqtype.NullRawMessage {
t.Helper()

serialized, err := data.Serialize()
require.NoError(t, err)
return pqtype.NullRawMessage{
RawMessage: serialized,
Valid: true,
}
}

func deserializeSecret(t *testing.T, data pqtype.NullRawMessage) crypto.EncryptedData {
t.Helper()

result, err := crypto.DeserializeEncryptedData(data.RawMessage)
require.NoError(t, err)
return result
}
Loading