Skip to content

Commit

Permalink
refactor: encryption refactor (block#2429)
Browse files Browse the repository at this point in the history
  • Loading branch information
gak authored Aug 19, 2024
1 parent f75a5f0 commit 229e3d5
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 72 deletions.
3 changes: 2 additions & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/TBD54566975/ftl/frontend"
cf "github.com/TBD54566975/ftl/internal/configuration"
"github.com/TBD54566975/ftl/internal/cors"
"github.com/TBD54566975/ftl/internal/encryption"
ftlhttp "github.com/TBD54566975/ftl/internal/http"
"github.com/TBD54566975/ftl/internal/log"
ftlmaps "github.com/TBD54566975/ftl/internal/maps"
Expand Down Expand Up @@ -229,7 +230,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
config.ControllerTimeout = time.Second * 5
}

db, err := dal.New(ctx, conn, optional.Ptr[string](config.KMSURI))
db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/cronjobs/cronjobs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"time"

"github.com/alecthomas/assert/v2"
"github.com/alecthomas/types/optional"
"github.com/benbjohnson/clock"

db "github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdb "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/internal/encryption"
in "github.com/TBD54566975/ftl/internal/integration"
"github.com/TBD54566975/ftl/internal/log"
)
Expand All @@ -28,7 +28,7 @@ func TestServiceWithRealDal(t *testing.T) {

conn := sqltest.OpenForTesting(ctx, t)
dal := db.New(conn)
parentDAL, err := parentdb.New(ctx, conn, optional.None[string]())
parentDAL, err := parentdb.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

// Using a real clock because real db queries use db clock
Expand Down
3 changes: 2 additions & 1 deletion backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/slices"
Expand All @@ -37,7 +38,7 @@ func TestServiceWithMockDal(t *testing.T) {
attemptCountMap: map[string]int{},
}
conn := sqltest.OpenForTesting(ctx, t)
parentDAL, err := db.New(ctx, conn, optional.None[string]())
parentDAL, err := db.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

testServiceWithDal(ctx, t, mockDal, parentDAL, clk)
Expand Down
6 changes: 3 additions & 3 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ import (
"context"
"testing"

"github.com/alecthomas/types/optional"
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/alecthomas/assert/v2"
)

func TestNoCallToAcquire(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,17 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err
return reservation.Commit(ctx)
}

func New(ctx context.Context, conn *stdsql.DB, kmsURL optional.Option[string]) (*DAL, error) {
func New(ctx context.Context, conn *stdsql.DB, encryptionBuilder encryption.Builder) (*DAL, error) {
d := &DAL{
db: sql.NewDB(conn),
DeploymentChanges: pubsub.New[DeploymentNotification](),
kmsURL: kmsURL,
}

if err := d.setupEncryptor(ctx); err != nil {
encryptor, err := encryptionBuilder.Build(ctx, d)
if err != nil {
return nil, fmt.Errorf("failed to setup encryptor: %w", err)
}
d.encryptor = encryptor

return d, nil
}
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dalerrs "github.com/TBD54566975/ftl/backend/dal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/sha256"
Expand All @@ -26,7 +27,7 @@ import (
func TestDAL(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)
assert.NotZero(t, dal)
var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down Expand Up @@ -373,7 +374,7 @@ func TestDAL(t *testing.T) {
func TestCreateArtefactConflict(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

idch := make(chan sha256.SHA256, 2)
Expand Down Expand Up @@ -450,7 +451,7 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) {
func TestDeleteOldEvents(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down
50 changes: 16 additions & 34 deletions backend/controller/dal/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,49 +58,31 @@ func (d *DAL) decryptJSON(encrypted encryption.Encrypted, v any) error { //nolin
return nil
}

// setupEncryptor sets up the encryptor for the DAL.
// It will either create a key or load the existing one.
// If the KMS URL is not set, it will use a NoOpEncryptor which does not encrypt anything.
func (d *DAL) setupEncryptor(ctx context.Context) (err error) {
func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) (encryptedKey []byte, err error) {
logger := log.FromContext(ctx)
tx, err := d.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.CommitOrRollback(ctx, &err)

url, ok := d.kmsURL.Get()
if !ok {
logger.Infof("KMS URL not set, encryption not enabled")
d.encryptor = encryption.NewNoOpEncryptor()
return nil
}
encryptedKey, err = tx.db.GetOnlyEncryptionKey(ctx)
if err != nil && dal.IsNotFound(err) {
logger.Debugf("No encryption key found, generating a new one")
key, err := generateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}

encryptedKey, err := tx.db.GetOnlyEncryptionKey(ctx)
if err != nil {
if dal.IsNotFound(err) {
logger.Infof("No encryption key found, generating a new one")
encryptor, err := encryption.NewKMSEncryptorGenerateKey(url, nil)
if err != nil {
return fmt.Errorf("failed to create encryptor for generation: %w", err)
}
d.encryptor = encryptor

if err = tx.db.CreateOnlyEncryptionKey(ctx, encryptor.GetEncryptedKeyset()); err != nil {
return fmt.Errorf("failed to create only encryption key: %w", err)
}

return nil
if err = tx.db.CreateOnlyEncryptionKey(ctx, key); err != nil {
return nil, fmt.Errorf("failed to save the encryption key: %w", err)
}
return fmt.Errorf("failed to get only encryption key: %w", err)
}

logger.Debugf("Encryption key found, using it")
encryptor, err := encryption.NewKMSEncryptorWithKMS(url, nil, encryptedKey)
if err != nil {
return fmt.Errorf("failed to create encryptor with encrypted key: %w", err)
return key, nil
} else if err != nil {
return nil, fmt.Errorf("failed to load the encryption key from the db: %w", err)
}
d.encryptor = encryptor

return nil
logger.Debugf("Encryption key found, using it")
return encryptedKey, nil
}
4 changes: 2 additions & 2 deletions backend/controller/dal/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dal

import (
"context"
"github.com/alecthomas/types/optional"
"testing"
"time"

Expand All @@ -12,13 +11,14 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

func TestSendFSMEvent(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
5 changes: 3 additions & 2 deletions backend/controller/dal/lease_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

Expand All @@ -36,7 +37,7 @@ func TestLease(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

// TTL is too short, expect an error
Expand Down Expand Up @@ -71,7 +72,7 @@ func TestExpireLeases(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

leasei, _, err := dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]())
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/sql/sqltest/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func OpenForTesting(ctx context.Context, t testing.TB) *sql.DB {
t.Helper()
// Acquire lock for this DB.
lockPath := filepath.Join(os.TempDir(), "ftl-db-test.lock")
release, err := flock.Acquire(ctx, lockPath, 20*time.Second)
release, err := flock.Acquire(ctx, lockPath, 30*time.Second)
assert.NoError(t, err)
t.Cleanup(func() { _ = release() }) //nolint:errcheck

Expand Down
6 changes: 5 additions & 1 deletion cmd/ftl-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
_ "github.com/TBD54566975/ftl/internal/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota.
cf "github.com/TBD54566975/ftl/internal/configuration"
cfdal "github.com/TBD54566975/ftl/internal/configuration/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/observability"
)
Expand Down Expand Up @@ -55,7 +56,10 @@ func main() {
// The FTL controller currently only supports DB as a configuration provider/resolver.
conn, err := sql.Open("pgx", cli.ControllerConfig.DSN)
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, optional.Some[string](*cli.ControllerConfig.KMSURI))

encryptionBuilder := encryption.NewBuilder().WithKMSURI(optional.Ptr(cli.ControllerConfig.KMSURI))
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, encryptionBuilder)
kctx.FatalIfErrorf(err)

configDal, err := cfdal.New(ctx, conn)
Expand Down
73 changes: 57 additions & 16 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package encryption

import (
"bytes"
"context"
"fmt"
"strings"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
"github.com/tink-crypto/tink-go-awskms/integration/awskms"
"github.com/tink-crypto/tink-go/v2/aead"
Expand All @@ -23,26 +25,72 @@ type Encrypted interface {
Set(data []byte)
}

type KeyStoreProvider interface {
// EnsureKey asks a provider to check for an encrypted key.
// If not available, call the generateKey function to create a new key.
// The provider should handle transactions around checking and setting the key, to prevent race conditions.
EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error)
}

// Builder constructs a DataEncryptor when used with a provider.
// Use a chain of With* methods to configure the builder.
type Builder struct {
kmsURI optional.Option[string]
}

func NewBuilder() Builder {
return Builder{
kmsURI: optional.None[string](),
}
}

// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor.
func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder {
b.kmsURI = kmsURI
return b
}

func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) {
kmsURI, ok := b.kmsURI.Get()
if !ok {
return NewNoOpEncryptor(), nil
}

key, err := provider.EnsureKey(ctx, func() ([]byte, error) {
return newKey(kmsURI, nil)
})
if err != nil {
return nil, fmt.Errorf("failed to ensure key from provider: %w", err)
}

encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key)
if err != nil {
return nil, fmt.Errorf("failed to create KMS encryptor: %w", err)
}

return encryptor, nil
}

type DataEncryptor interface {
Encrypt(cleartext []byte, dest Encrypted) error
Decrypt(encrypted Encrypted) ([]byte, error)
}

// NoOpEncryptorNext does not encrypt and just passes the input as is.
type NoOpEncryptorNext struct{}
// NoOpEncryptor does not encrypt and just passes the input as is.
type NoOpEncryptor struct{}

func NewNoOpEncryptor() NoOpEncryptorNext {
return NoOpEncryptorNext{}
func NewNoOpEncryptor() NoOpEncryptor {
return NoOpEncryptor{}
}

var _ DataEncryptor = NoOpEncryptorNext{}
var _ DataEncryptor = NoOpEncryptor{}

func (n NoOpEncryptorNext) Encrypt(cleartext []byte, dest Encrypted) error {
func (n NoOpEncryptor) Encrypt(cleartext []byte, dest Encrypted) error {
dest.Set(cleartext)
return nil
}

func (n NoOpEncryptorNext) Decrypt(encrypted Encrypted) ([]byte, error) {
func (n NoOpEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) {
return encrypted.Bytes(), nil
}

Expand Down Expand Up @@ -87,7 +135,7 @@ func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) {
return kekAEAD, nil
}

func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncryptor, error) {
func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) {
kekAEAD, err := newClientWithAEAD(uri, v1client)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
Expand Down Expand Up @@ -116,14 +164,7 @@ func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncrypt
if err != nil {
return nil, fmt.Errorf("failed to encrypt DEK: %w", err)
}
encryptedKeyset := buf.Bytes()

return &KMSEncryptor{
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: make(map[SubKey]tink.AEAD),
}, nil
return buf.Bytes(), nil
}

func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) {
Expand Down
Loading

0 comments on commit 229e3d5

Please sign in to comment.