From 229e3d5a84cf68cd79972e2fd957ebc0fd058433 Mon Sep 17 00:00:00 2001 From: gak Date: Tue, 20 Aug 2024 08:11:34 +1000 Subject: [PATCH] refactor: encryption refactor (#2429) Fixes #2346 --- backend/controller/controller.go | 3 +- .../cronjobs/cronjobs_integration_test.go | 4 +- backend/controller/cronjobs/cronjobs_test.go | 3 +- backend/controller/dal/async_calls_test.go | 6 +- backend/controller/dal/dal.go | 7 +- backend/controller/dal/dal_test.go | 7 +- backend/controller/dal/encryption.go | 50 ++++--------- backend/controller/dal/fsm_test.go | 4 +- backend/controller/dal/lease_test.go | 5 +- backend/controller/sql/sqltest/testing.go | 2 +- cmd/ftl-controller/main.go | 6 +- internal/encryption/encryption.go | 73 +++++++++++++++---- internal/encryption/encryption_test.go | 7 +- internal/encryption/integration_test.go | 5 +- 14 files changed, 110 insertions(+), 72 deletions(-) diff --git a/backend/controller/controller.go b/backend/controller/controller.go index f2d5e94fd8..ddd8f39ccf 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -55,6 +55,7 @@ import ( "github.com/TBD54566975/ftl/frontend" cf "github.com/TBD54566975/ftl/internal/configuration" "github.com/TBD54566975/ftl/internal/cors" + "github.com/TBD54566975/ftl/internal/encryption" ftlhttp "github.com/TBD54566975/ftl/internal/http" "github.com/TBD54566975/ftl/internal/log" ftlmaps "github.com/TBD54566975/ftl/internal/maps" @@ -229,7 +230,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling config.ControllerTimeout = time.Second * 5 } - db, err := dal.New(ctx, conn, optional.Ptr[string](config.KMSURI)) + db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI))) if err != nil { return nil, fmt.Errorf("failed to create DAL: %w", err) } diff --git a/backend/controller/cronjobs/cronjobs_integration_test.go b/backend/controller/cronjobs/cronjobs_integration_test.go index 33deb53f53..34ed96fc75 100644 --- a/backend/controller/cronjobs/cronjobs_integration_test.go +++ b/backend/controller/cronjobs/cronjobs_integration_test.go @@ -10,12 +10,12 @@ import ( "time" "github.com/alecthomas/assert/v2" - "github.com/alecthomas/types/optional" "github.com/benbjohnson/clock" db "github.com/TBD54566975/ftl/backend/controller/cronjobs/dal" parentdb "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" + "github.com/TBD54566975/ftl/internal/encryption" in "github.com/TBD54566975/ftl/internal/integration" "github.com/TBD54566975/ftl/internal/log" ) @@ -28,7 +28,7 @@ func TestServiceWithRealDal(t *testing.T) { conn := sqltest.OpenForTesting(ctx, t) dal := db.New(conn) - parentDAL, err := parentdb.New(ctx, conn, optional.None[string]()) + parentDAL, err := parentdb.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) // Using a real clock because real db queries use db clock diff --git a/backend/controller/cronjobs/cronjobs_test.go b/backend/controller/cronjobs/cronjobs_test.go index 715476c932..531617779d 100644 --- a/backend/controller/cronjobs/cronjobs_test.go +++ b/backend/controller/cronjobs/cronjobs_test.go @@ -16,6 +16,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/slices" @@ -37,7 +38,7 @@ func TestServiceWithMockDal(t *testing.T) { attemptCountMap: map[string]int{}, } conn := sqltest.OpenForTesting(ctx, t) - parentDAL, err := db.New(ctx, conn, optional.None[string]()) + parentDAL, err := db.New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) testServiceWithDal(ctx, t, mockDal, parentDAL, clk) diff --git a/backend/controller/dal/async_calls_test.go b/backend/controller/dal/async_calls_test.go index 965a3d6a24..c1ea520cc8 100644 --- a/backend/controller/dal/async_calls_test.go +++ b/backend/controller/dal/async_calls_test.go @@ -4,18 +4,18 @@ import ( "context" "testing" - "github.com/alecthomas/types/optional" + "github.com/alecthomas/assert/v2" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" - "github.com/alecthomas/assert/v2" ) func TestNoCallToAcquire(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) _, err = dal.AcquireAsyncCall(ctx) diff --git a/backend/controller/dal/dal.go b/backend/controller/dal/dal.go index 0b47b57c4a..2101d0ba7c 100644 --- a/backend/controller/dal/dal.go +++ b/backend/controller/dal/dal.go @@ -210,16 +210,17 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err return reservation.Commit(ctx) } -func New(ctx context.Context, conn *stdsql.DB, kmsURL optional.Option[string]) (*DAL, error) { +func New(ctx context.Context, conn *stdsql.DB, encryptionBuilder encryption.Builder) (*DAL, error) { d := &DAL{ db: sql.NewDB(conn), DeploymentChanges: pubsub.New[DeploymentNotification](), - kmsURL: kmsURL, } - if err := d.setupEncryptor(ctx); err != nil { + encryptor, err := encryptionBuilder.Build(ctx, d) + if err != nil { return nil, fmt.Errorf("failed to setup encryptor: %w", err) } + d.encryptor = encryptor return d, nil } diff --git a/backend/controller/dal/dal_test.go b/backend/controller/dal/dal_test.go index 1954dff3bc..10def9f50b 100644 --- a/backend/controller/dal/dal_test.go +++ b/backend/controller/dal/dal_test.go @@ -17,6 +17,7 @@ import ( dalerrs "github.com/TBD54566975/ftl/backend/dal" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/sha256" @@ -26,7 +27,7 @@ import ( func TestDAL(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) assert.NotZero(t, dal) var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100) @@ -373,7 +374,7 @@ func TestDAL(t *testing.T) { func TestCreateArtefactConflict(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) idch := make(chan sha256.SHA256, 2) @@ -450,7 +451,7 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) { func TestDeleteOldEvents(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100) diff --git a/backend/controller/dal/encryption.go b/backend/controller/dal/encryption.go index af0fd59f89..7106aebbd9 100644 --- a/backend/controller/dal/encryption.go +++ b/backend/controller/dal/encryption.go @@ -58,49 +58,31 @@ func (d *DAL) decryptJSON(encrypted encryption.Encrypted, v any) error { //nolin return nil } -// setupEncryptor sets up the encryptor for the DAL. -// It will either create a key or load the existing one. -// If the KMS URL is not set, it will use a NoOpEncryptor which does not encrypt anything. -func (d *DAL) setupEncryptor(ctx context.Context) (err error) { +func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) (encryptedKey []byte, err error) { logger := log.FromContext(ctx) tx, err := d.Begin(ctx) if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return nil, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.CommitOrRollback(ctx, &err) - url, ok := d.kmsURL.Get() - if !ok { - logger.Infof("KMS URL not set, encryption not enabled") - d.encryptor = encryption.NewNoOpEncryptor() - return nil - } + encryptedKey, err = tx.db.GetOnlyEncryptionKey(ctx) + if err != nil && dal.IsNotFound(err) { + logger.Debugf("No encryption key found, generating a new one") + key, err := generateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate key: %w", err) + } - encryptedKey, err := tx.db.GetOnlyEncryptionKey(ctx) - if err != nil { - if dal.IsNotFound(err) { - logger.Infof("No encryption key found, generating a new one") - encryptor, err := encryption.NewKMSEncryptorGenerateKey(url, nil) - if err != nil { - return fmt.Errorf("failed to create encryptor for generation: %w", err) - } - d.encryptor = encryptor - - if err = tx.db.CreateOnlyEncryptionKey(ctx, encryptor.GetEncryptedKeyset()); err != nil { - return fmt.Errorf("failed to create only encryption key: %w", err) - } - - return nil + if err = tx.db.CreateOnlyEncryptionKey(ctx, key); err != nil { + return nil, fmt.Errorf("failed to save the encryption key: %w", err) } - return fmt.Errorf("failed to get only encryption key: %w", err) - } - logger.Debugf("Encryption key found, using it") - encryptor, err := encryption.NewKMSEncryptorWithKMS(url, nil, encryptedKey) - if err != nil { - return fmt.Errorf("failed to create encryptor with encrypted key: %w", err) + return key, nil + } else if err != nil { + return nil, fmt.Errorf("failed to load the encryption key from the db: %w", err) } - d.encryptor = encryptor - return nil + logger.Debugf("Encryption key found, using it") + return encryptedKey, nil } diff --git a/backend/controller/dal/fsm_test.go b/backend/controller/dal/fsm_test.go index abdb0166cc..957a3c3f4e 100644 --- a/backend/controller/dal/fsm_test.go +++ b/backend/controller/dal/fsm_test.go @@ -2,7 +2,6 @@ package dal import ( "context" - "github.com/alecthomas/types/optional" "testing" "time" @@ -12,13 +11,14 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" dalerrs "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" ) func TestSendFSMEvent(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) _, err = dal.AcquireAsyncCall(ctx) diff --git a/backend/controller/dal/lease_test.go b/backend/controller/dal/lease_test.go index 9e2370d72f..34b58e58a3 100644 --- a/backend/controller/dal/lease_test.go +++ b/backend/controller/dal/lease_test.go @@ -14,6 +14,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" ) @@ -36,7 +37,7 @@ func TestLease(t *testing.T) { } ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) // TTL is too short, expect an error @@ -71,7 +72,7 @@ func TestExpireLeases(t *testing.T) { } ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn, optional.None[string]()) + dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) leasei, _, err := dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]()) diff --git a/backend/controller/sql/sqltest/testing.go b/backend/controller/sql/sqltest/testing.go index a2affb00fa..e2ba5f449b 100644 --- a/backend/controller/sql/sqltest/testing.go +++ b/backend/controller/sql/sqltest/testing.go @@ -20,7 +20,7 @@ func OpenForTesting(ctx context.Context, t testing.TB) *sql.DB { t.Helper() // Acquire lock for this DB. lockPath := filepath.Join(os.TempDir(), "ftl-db-test.lock") - release, err := flock.Acquire(ctx, lockPath, 20*time.Second) + release, err := flock.Acquire(ctx, lockPath, 30*time.Second) assert.NoError(t, err) t.Cleanup(func() { _ = release() }) //nolint:errcheck diff --git a/cmd/ftl-controller/main.go b/cmd/ftl-controller/main.go index 2faf92ffb4..4ab85995fb 100644 --- a/cmd/ftl-controller/main.go +++ b/cmd/ftl-controller/main.go @@ -20,6 +20,7 @@ import ( _ "github.com/TBD54566975/ftl/internal/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota. cf "github.com/TBD54566975/ftl/internal/configuration" cfdal "github.com/TBD54566975/ftl/internal/configuration/dal" + "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/observability" ) @@ -55,7 +56,10 @@ func main() { // The FTL controller currently only supports DB as a configuration provider/resolver. conn, err := sql.Open("pgx", cli.ControllerConfig.DSN) kctx.FatalIfErrorf(err) - dal, err := dal.New(ctx, conn, optional.Some[string](*cli.ControllerConfig.KMSURI)) + + encryptionBuilder := encryption.NewBuilder().WithKMSURI(optional.Ptr(cli.ControllerConfig.KMSURI)) + kctx.FatalIfErrorf(err) + dal, err := dal.New(ctx, conn, encryptionBuilder) kctx.FatalIfErrorf(err) configDal, err := cfdal.New(ctx, conn) diff --git a/internal/encryption/encryption.go b/internal/encryption/encryption.go index a9b17c224f..5f0dd11c3b 100644 --- a/internal/encryption/encryption.go +++ b/internal/encryption/encryption.go @@ -2,9 +2,11 @@ package encryption import ( "bytes" + "context" "fmt" "strings" + "github.com/alecthomas/types/optional" awsv1kms "github.com/aws/aws-sdk-go/service/kms" "github.com/tink-crypto/tink-go-awskms/integration/awskms" "github.com/tink-crypto/tink-go/v2/aead" @@ -23,26 +25,72 @@ type Encrypted interface { Set(data []byte) } +type KeyStoreProvider interface { + // EnsureKey asks a provider to check for an encrypted key. + // If not available, call the generateKey function to create a new key. + // The provider should handle transactions around checking and setting the key, to prevent race conditions. + EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error) +} + +// Builder constructs a DataEncryptor when used with a provider. +// Use a chain of With* methods to configure the builder. +type Builder struct { + kmsURI optional.Option[string] +} + +func NewBuilder() Builder { + return Builder{ + kmsURI: optional.None[string](), + } +} + +// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor. +func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder { + b.kmsURI = kmsURI + return b +} + +func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) { + kmsURI, ok := b.kmsURI.Get() + if !ok { + return NewNoOpEncryptor(), nil + } + + key, err := provider.EnsureKey(ctx, func() ([]byte, error) { + return newKey(kmsURI, nil) + }) + if err != nil { + return nil, fmt.Errorf("failed to ensure key from provider: %w", err) + } + + encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key) + if err != nil { + return nil, fmt.Errorf("failed to create KMS encryptor: %w", err) + } + + return encryptor, nil +} + type DataEncryptor interface { Encrypt(cleartext []byte, dest Encrypted) error Decrypt(encrypted Encrypted) ([]byte, error) } -// NoOpEncryptorNext does not encrypt and just passes the input as is. -type NoOpEncryptorNext struct{} +// NoOpEncryptor does not encrypt and just passes the input as is. +type NoOpEncryptor struct{} -func NewNoOpEncryptor() NoOpEncryptorNext { - return NoOpEncryptorNext{} +func NewNoOpEncryptor() NoOpEncryptor { + return NoOpEncryptor{} } -var _ DataEncryptor = NoOpEncryptorNext{} +var _ DataEncryptor = NoOpEncryptor{} -func (n NoOpEncryptorNext) Encrypt(cleartext []byte, dest Encrypted) error { +func (n NoOpEncryptor) Encrypt(cleartext []byte, dest Encrypted) error { dest.Set(cleartext) return nil } -func (n NoOpEncryptorNext) Decrypt(encrypted Encrypted) ([]byte, error) { +func (n NoOpEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) { return encrypted.Bytes(), nil } @@ -87,7 +135,7 @@ func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) { return kekAEAD, nil } -func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncryptor, error) { +func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) { kekAEAD, err := newClientWithAEAD(uri, v1client) if err != nil { return nil, fmt.Errorf("failed to create KMS client: %w", err) @@ -116,14 +164,7 @@ func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncrypt if err != nil { return nil, fmt.Errorf("failed to encrypt DEK: %w", err) } - encryptedKeyset := buf.Bytes() - - return &KMSEncryptor{ - root: *handle, - kekAEAD: kekAEAD, - encryptedKeyset: encryptedKeyset, - cachedDerived: make(map[SubKey]tink.AEAD), - }, nil + return buf.Bytes(), nil } func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) { diff --git a/internal/encryption/encryption_test.go b/internal/encryption/encryption_test.go index d8b1fd1b8f..23eb8b094b 100644 --- a/internal/encryption/encryption_test.go +++ b/internal/encryption/encryption_test.go @@ -7,7 +7,7 @@ import ( ) func TestNoOpEncryptor(t *testing.T) { - encryptor := NoOpEncryptorNext{} + encryptor := NoOpEncryptor{} var encrypted EncryptedTimelineColumn err := encryptor.Encrypt([]byte("hunter2"), &encrypted) @@ -23,7 +23,10 @@ func TestNoOpEncryptor(t *testing.T) { func TestKMSEncryptorFakeKMS(t *testing.T) { uri := "fake-kms://CKbvh_ILElQKSAowdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUuY3J5cHRvLnRpbmsuQWVzR2NtS2V5EhIaEE6tD2yE5AWYOirhmkY-r3sYARABGKbvh_ILIAE" - encryptor, err := NewKMSEncryptorGenerateKey(uri, nil) + key, err := newKey(uri, nil) + assert.NoError(t, err) + + encryptor, err := NewKMSEncryptorWithKMS(uri, nil, key) assert.NoError(t, err) var encrypted EncryptedTimelineColumn diff --git a/internal/encryption/integration_test.go b/internal/encryption/integration_test.go index 29e7c037c7..5c19af3665 100644 --- a/internal/encryption/integration_test.go +++ b/internal/encryption/integration_test.go @@ -146,7 +146,10 @@ func TestKMSEncryptorLocalstack(t *testing.T) { Region: awsv1.String("us-west-2"), }) - encryptor, err := NewKMSEncryptorGenerateKey(uri, v1client) + key, err := newKey(uri, v1client) + assert.NoError(t, err) + + encryptor, err := NewKMSEncryptorWithKMS(uri, v1client, key) assert.NoError(t, err) var encrypted EncryptedTimelineColumn