Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: encryption refactor #2429

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
"github.com/TBD54566975/ftl/frontend"
cf "github.com/TBD54566975/ftl/internal/configuration"
"github.com/TBD54566975/ftl/internal/cors"
"github.com/TBD54566975/ftl/internal/encryption"
ftlhttp "github.com/TBD54566975/ftl/internal/http"
"github.com/TBD54566975/ftl/internal/log"
ftlmaps "github.com/TBD54566975/ftl/internal/maps"
Expand Down Expand Up @@ -229,7 +230,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
config.ControllerTimeout = time.Second * 5
}

db, err := dal.New(ctx, conn, optional.Ptr[string](config.KMSURI))
db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/cronjobs/cronjobs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"time"

"github.com/alecthomas/assert/v2"
"github.com/alecthomas/types/optional"
"github.com/benbjohnson/clock"

db "github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdb "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
"github.com/TBD54566975/ftl/internal/encryption"
in "github.com/TBD54566975/ftl/internal/integration"
"github.com/TBD54566975/ftl/internal/log"
)
Expand All @@ -28,7 +28,7 @@ func TestServiceWithRealDal(t *testing.T) {

conn := sqltest.OpenForTesting(ctx, t)
dal := db.New(conn)
parentDAL, err := parentdb.New(ctx, conn, optional.None[string]())
parentDAL, err := parentdb.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

// Using a real clock because real db queries use db clock
Expand Down
3 changes: 2 additions & 1 deletion backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/slices"
Expand All @@ -37,7 +38,7 @@ func TestServiceWithMockDal(t *testing.T) {
attemptCountMap: map[string]int{},
}
conn := sqltest.OpenForTesting(ctx, t)
parentDAL, err := db.New(ctx, conn, optional.None[string]())
parentDAL, err := db.New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

testServiceWithDal(ctx, t, mockDal, parentDAL, clk)
Expand Down
6 changes: 3 additions & 3 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ import (
"context"
"testing"

"github.com/alecthomas/types/optional"
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/alecthomas/assert/v2"
)

func TestNoCallToAcquire(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,17 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err
return reservation.Commit(ctx)
}

func New(ctx context.Context, conn *stdsql.DB, kmsURL optional.Option[string]) (*DAL, error) {
func New(ctx context.Context, conn *stdsql.DB, encryptionBuilder encryption.Builder) (*DAL, error) {
d := &DAL{
db: sql.NewDB(conn),
DeploymentChanges: pubsub.New[DeploymentNotification](),
kmsURL: kmsURL,
}

if err := d.setupEncryptor(ctx); err != nil {
encryptor, err := encryptionBuilder.Build(ctx, d)
if err != nil {
return nil, fmt.Errorf("failed to setup encryptor: %w", err)
}
d.encryptor = encryptor

return d, nil
}
Expand Down
7 changes: 4 additions & 3 deletions backend/controller/dal/dal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dalerrs "github.com/TBD54566975/ftl/backend/dal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/sha256"
Expand All @@ -26,7 +27,7 @@ import (
func TestDAL(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)
assert.NotZero(t, dal)
var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down Expand Up @@ -373,7 +374,7 @@ func TestDAL(t *testing.T) {
func TestCreateArtefactConflict(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

idch := make(chan sha256.SHA256, 2)
Expand Down Expand Up @@ -450,7 +451,7 @@ func assertEventsEqual(t *testing.T, expected, actual []TimelineEvent) {
func TestDeleteOldEvents(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100)
Expand Down
50 changes: 16 additions & 34 deletions backend/controller/dal/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,49 +58,31 @@ func (d *DAL) decryptJSON(encrypted encryption.Encrypted, v any) error { //nolin
return nil
}

// setupEncryptor sets up the encryptor for the DAL.
// It will either create a key or load the existing one.
// If the KMS URL is not set, it will use a NoOpEncryptor which does not encrypt anything.
func (d *DAL) setupEncryptor(ctx context.Context) (err error) {
func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) (encryptedKey []byte, err error) {
logger := log.FromContext(ctx)
tx, err := d.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.CommitOrRollback(ctx, &err)

url, ok := d.kmsURL.Get()
if !ok {
logger.Infof("KMS URL not set, encryption not enabled")
d.encryptor = encryption.NewNoOpEncryptor()
return nil
}
encryptedKey, err = tx.db.GetOnlyEncryptionKey(ctx)
if err != nil && dal.IsNotFound(err) {
logger.Debugf("No encryption key found, generating a new one")
key, err := generateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}

encryptedKey, err := tx.db.GetOnlyEncryptionKey(ctx)
if err != nil {
if dal.IsNotFound(err) {
logger.Infof("No encryption key found, generating a new one")
encryptor, err := encryption.NewKMSEncryptorGenerateKey(url, nil)
if err != nil {
return fmt.Errorf("failed to create encryptor for generation: %w", err)
}
d.encryptor = encryptor

if err = tx.db.CreateOnlyEncryptionKey(ctx, encryptor.GetEncryptedKeyset()); err != nil {
return fmt.Errorf("failed to create only encryption key: %w", err)
}

return nil
if err = tx.db.CreateOnlyEncryptionKey(ctx, key); err != nil {
return nil, fmt.Errorf("failed to save the encryption key: %w", err)
}
return fmt.Errorf("failed to get only encryption key: %w", err)
}

logger.Debugf("Encryption key found, using it")
encryptor, err := encryption.NewKMSEncryptorWithKMS(url, nil, encryptedKey)
if err != nil {
return fmt.Errorf("failed to create encryptor with encrypted key: %w", err)
return key, nil
} else if err != nil {
return nil, fmt.Errorf("failed to load the encryption key from the db: %w", err)
}
d.encryptor = encryptor

return nil
logger.Debugf("Encryption key found, using it")
return encryptedKey, nil
}
4 changes: 2 additions & 2 deletions backend/controller/dal/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dal

import (
"context"
"github.com/alecthomas/types/optional"
"testing"
"time"

Expand All @@ -12,13 +11,14 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

func TestSendFSMEvent(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
Expand Down
5 changes: 3 additions & 2 deletions backend/controller/dal/lease_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/sql"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
)

Expand All @@ -36,7 +37,7 @@ func TestLease(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

// TTL is too short, expect an error
Expand Down Expand Up @@ -71,7 +72,7 @@ func TestExpireLeases(t *testing.T) {
}
ctx := log.ContextWithNewDefaultLogger(context.Background())
conn := sqltest.OpenForTesting(ctx, t)
dal, err := New(ctx, conn, optional.None[string]())
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

leasei, _, err := dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]())
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/sql/sqltest/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func OpenForTesting(ctx context.Context, t testing.TB) *sql.DB {
t.Helper()
// Acquire lock for this DB.
lockPath := filepath.Join(os.TempDir(), "ftl-db-test.lock")
release, err := flock.Acquire(ctx, lockPath, 20*time.Second)
release, err := flock.Acquire(ctx, lockPath, 30*time.Second)
assert.NoError(t, err)
t.Cleanup(func() { _ = release() }) //nolint:errcheck

Expand Down
6 changes: 5 additions & 1 deletion cmd/ftl-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
_ "github.com/TBD54566975/ftl/internal/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota.
cf "github.com/TBD54566975/ftl/internal/configuration"
cfdal "github.com/TBD54566975/ftl/internal/configuration/dal"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/observability"
)
Expand Down Expand Up @@ -55,7 +56,10 @@ func main() {
// The FTL controller currently only supports DB as a configuration provider/resolver.
conn, err := sql.Open("pgx", cli.ControllerConfig.DSN)
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, optional.Some[string](*cli.ControllerConfig.KMSURI))

encryptionBuilder := encryption.NewBuilder().WithKMSURI(optional.Ptr(cli.ControllerConfig.KMSURI))
kctx.FatalIfErrorf(err)
dal, err := dal.New(ctx, conn, encryptionBuilder)
kctx.FatalIfErrorf(err)

configDal, err := cfdal.New(ctx, conn)
Expand Down
73 changes: 57 additions & 16 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package encryption

import (
"bytes"
"context"
"fmt"
"strings"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
"github.com/tink-crypto/tink-go-awskms/integration/awskms"
"github.com/tink-crypto/tink-go/v2/aead"
Expand All @@ -23,26 +25,72 @@ type Encrypted interface {
Set(data []byte)
}

type KeyStoreProvider interface {
// EnsureKey asks a provider to check for an encrypted key.
// If not available, call the generateKey function to create a new key.
// The provider should handle transactions around checking and setting the key, to prevent race conditions.
EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) ([]byte, error)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice.

}

// Builder constructs a DataEncryptor when used with a provider.
// Use a chain of With* methods to configure the builder.
type Builder struct {
kmsURI optional.Option[string]
}

func NewBuilder() Builder {
return Builder{
kmsURI: optional.None[string](),
}
}

// WithKMSURI sets the URI for the KMS key to use. Omitting this call or using None will create a NoOpEncryptor.
func (b Builder) WithKMSURI(kmsURI optional.Option[string]) Builder {
b.kmsURI = kmsURI
return b
}

func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncryptor, error) {
kmsURI, ok := b.kmsURI.Get()
if !ok {
return NewNoOpEncryptor(), nil
}

key, err := provider.EnsureKey(ctx, func() ([]byte, error) {
return newKey(kmsURI, nil)
})
if err != nil {
return nil, fmt.Errorf("failed to ensure key from provider: %w", err)
}

encryptor, err := NewKMSEncryptorWithKMS(kmsURI, nil, key)
if err != nil {
return nil, fmt.Errorf("failed to create KMS encryptor: %w", err)
}

return encryptor, nil
}

type DataEncryptor interface {
Encrypt(cleartext []byte, dest Encrypted) error
Decrypt(encrypted Encrypted) ([]byte, error)
}

// NoOpEncryptorNext does not encrypt and just passes the input as is.
type NoOpEncryptorNext struct{}
// NoOpEncryptor does not encrypt and just passes the input as is.
type NoOpEncryptor struct{}

func NewNoOpEncryptor() NoOpEncryptorNext {
return NoOpEncryptorNext{}
func NewNoOpEncryptor() NoOpEncryptor {
return NoOpEncryptor{}
}

var _ DataEncryptor = NoOpEncryptorNext{}
var _ DataEncryptor = NoOpEncryptor{}

func (n NoOpEncryptorNext) Encrypt(cleartext []byte, dest Encrypted) error {
func (n NoOpEncryptor) Encrypt(cleartext []byte, dest Encrypted) error {
dest.Set(cleartext)
return nil
}

func (n NoOpEncryptorNext) Decrypt(encrypted Encrypted) ([]byte, error) {
func (n NoOpEncryptor) Decrypt(encrypted Encrypted) ([]byte, error) {
return encrypted.Bytes(), nil
}

Expand Down Expand Up @@ -87,7 +135,7 @@ func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) {
return kekAEAD, nil
}

func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncryptor, error) {
func newKey(uri string, v1client *awsv1kms.KMS) ([]byte, error) {
kekAEAD, err := newClientWithAEAD(uri, v1client)
if err != nil {
return nil, fmt.Errorf("failed to create KMS client: %w", err)
Expand Down Expand Up @@ -116,14 +164,7 @@ func NewKMSEncryptorGenerateKey(uri string, v1client *awsv1kms.KMS) (*KMSEncrypt
if err != nil {
return nil, fmt.Errorf("failed to encrypt DEK: %w", err)
}
encryptedKeyset := buf.Bytes()

return &KMSEncryptor{
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: make(map[SubKey]tink.AEAD),
}, nil
return buf.Bytes(), nil
}

func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset []byte) (*KMSEncryptor, error) {
Expand Down
Loading
Loading