Skip to content

Commit

Permalink
refactor: rename things to make the next merge easier
Browse files Browse the repository at this point in the history
  • Loading branch information
gak committed Aug 19, 2024
1 parent 8233d17 commit c1c8333
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 62 deletions.
6 changes: 3 additions & 3 deletions backend/controller/dal/async_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error)
return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}

decryptedRequest, err := d.decrypt(encryption.AsyncSubkey{}, row.Request)
decryptedRequest, err := d.decrypt(encryption.AsyncSubKey, row.Request)
if err != nil {
return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
didScheduleAnotherCall = false
switch result := result.(type) {
case either.Left[[]byte, string]: // Successful response.
encryptedResult, err := d.encrypt(encryption.AsyncSubkey{}, result.Get())
encryptedResult, err := d.encrypt(encryption.AsyncSubKey, result.Get())
if err != nil {
return false, fmt.Errorf("failed to encrypt async call result: %w", err)
}
Expand Down Expand Up @@ -227,7 +227,7 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}
request, err := d.decrypt(encryption.AsyncSubkey{}, row.Request)
request, err := d.decrypt(encryption.AsyncSubKey, row.Request)
if err != nil {
return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey
return dalerrs.TranslatePGError(err)
}
}
payload, err := d.encryptJSON(encryption.TimelineSubkey{}, map[string]interface{}{
payload, err := d.encryptJSON(encryption.TimelineSubKey, map[string]interface{}{
"prev_min_replicas": deployment.MinReplicas,
"min_replicas": minReplicas,
})
Expand Down Expand Up @@ -783,7 +783,7 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl
}
}

payload, err := d.encryptJSON(encryption.TimelineSubkey{}, map[string]any{
payload, err := d.encryptJSON(encryption.TimelineSubKey, map[string]any{
"min_replicas": int32(minReplicas),
"replaced": replacedDeploymentKey,
})
Expand Down Expand Up @@ -1058,7 +1058,7 @@ func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error {
"error": log.Error,
"stack": log.Stack,
}
encryptedPayload, err := d.encryptJSON(encryption.TimelineSubkey{}, payload)
encryptedPayload, err := d.encryptJSON(encryption.TimelineSubKey, payload)
if err != nil {
return fmt.Errorf("failed to encrypt log payload: %w", err)
}
Expand Down Expand Up @@ -1138,7 +1138,7 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error {
if pr, ok := call.ParentRequestKey.Get(); ok {
parentRequestKey = optional.Some(pr.String())
}
payload, err := d.encryptJSON(encryption.TimelineSubkey{}, map[string]any{
payload, err := d.encryptJSON(encryption.TimelineSubKey, map[string]any{
"duration_ms": call.Duration.Milliseconds(),
"request": call.Request,
"response": call.Response,
Expand Down
22 changes: 11 additions & 11 deletions backend/controller/dal/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,45 @@ import (
"github.com/TBD54566975/ftl/internal/log"
)

func (d *DAL) encrypt(subkey encryption.Subkey, cleartext []byte) ([]byte, error) {
func (d *DAL) encrypt(subKey encryption.SubKey, cleartext []byte) ([]byte, error) {
if d.encryptor == nil {
return nil, fmt.Errorf("encryptor not set")
}

v, err := d.encryptor.Encrypt(subkey, cleartext)
v, err := d.encryptor.Encrypt(subKey, cleartext)
if err != nil {
return nil, fmt.Errorf("failed to encrypt binary with subkey %s: %w", subkey, err)
return nil, fmt.Errorf("failed to encrypt binary with subkey %s: %w", subKey, err)
}

return v, nil
}

func (d *DAL) decrypt(subkey encryption.Subkey, encrypted []byte) ([]byte, error) {
func (d *DAL) decrypt(subKey encryption.SubKey, encrypted []byte) ([]byte, error) {
if d.encryptor == nil {
return nil, fmt.Errorf("encryptor not set")
}

v, err := d.encryptor.Decrypt(subkey, encrypted)
v, err := d.encryptor.Decrypt(subKey, encrypted)
if err != nil {
return nil, fmt.Errorf("failed to decrypt binary with subkey %s: %w", subkey, err)
return nil, fmt.Errorf("failed to decrypt binary with subkey %s: %w", subKey, err)
}

return v, nil
}

func (d *DAL) encryptJSON(subkey encryption.Subkey, v any) ([]byte, error) {
func (d *DAL) encryptJSON(subKey encryption.SubKey, v any) ([]byte, error) {
serialized, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("failed to marshal JSON: %w", err)
}

return d.encrypt(subkey, serialized)
return d.encrypt(subKey, serialized)
}

func (d *DAL) decryptJSON(subkey encryption.Subkey, encrypted []byte, v any) error { //nolint:unparam
decrypted, err := d.decrypt(subkey, encrypted)
func (d *DAL) decryptJSON(subKey encryption.SubKey, encrypted []byte, v any) error { //nolint:unparam
decrypted, err := d.decrypt(subKey, encrypted)
if err != nil {
return fmt.Errorf("failed to decrypt json with subkey %s: %w", subkey, err)
return fmt.Errorf("failed to decrypt json with subkey %s: %w", subKey, err)
}

if err = json.Unmarshal(decrypted, v); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions backend/controller/dal/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func (d *DAL) transformRowsToTimelineEvents(deploymentKeys map[int64]model.Deplo
switch row.Type {
case sql.EventTypeLog:
var jsonPayload eventLogJSON
if err := d.decryptJSON(encryption.TimelineSubkey{}, row.Payload, &jsonPayload); err != nil {
if err := d.decryptJSON(encryption.TimelineSubKey, row.Payload, &jsonPayload); err != nil {
return nil, fmt.Errorf("failed to decrypt log event: %w", err)
}

Expand All @@ -371,7 +371,7 @@ func (d *DAL) transformRowsToTimelineEvents(deploymentKeys map[int64]model.Deplo

case sql.EventTypeCall:
var jsonPayload eventCallJSON
if err := d.decryptJSON(encryption.TimelineSubkey{}, row.Payload, &jsonPayload); err != nil {
if err := d.decryptJSON(encryption.TimelineSubKey, row.Payload, &jsonPayload); err != nil {
return nil, fmt.Errorf("failed to decrypt call event: %w", err)
}
var sourceVerb optional.Option[schema.Ref]
Expand All @@ -396,7 +396,7 @@ func (d *DAL) transformRowsToTimelineEvents(deploymentKeys map[int64]model.Deplo

case sql.EventTypeDeploymentCreated:
var jsonPayload eventDeploymentCreatedJSON
if err := d.decryptJSON(encryption.TimelineSubkey{}, row.Payload, &jsonPayload); err != nil {
if err := d.decryptJSON(encryption.TimelineSubKey, row.Payload, &jsonPayload); err != nil {
return nil, fmt.Errorf("failed to decrypt call event: %w", err)
}
out = append(out, &DeploymentCreatedEvent{
Expand All @@ -411,7 +411,7 @@ func (d *DAL) transformRowsToTimelineEvents(deploymentKeys map[int64]model.Deplo

case sql.EventTypeDeploymentUpdated:
var jsonPayload eventDeploymentUpdatedJSON
if err := d.decryptJSON(encryption.TimelineSubkey{}, row.Payload, &jsonPayload); err != nil {
if err := d.decryptJSON(encryption.TimelineSubKey, row.Payload, &jsonPayload); err != nil {
return nil, fmt.Errorf("failed to decrypt call event: %w", err)
}
out = append(out, &DeploymentUpdatedEvent{
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/dal/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, instanc
if encrypted {
encryptedRequest = request
} else {
encryptedRequest, err = d.encrypt(encryption.AsyncSubkey{}, request)
encryptedRequest, err = d.encrypt(encryption.AsyncSubKey, request)
if err != nil {
return fmt.Errorf("failed to encrypt FSM request: %w", err)
}
Expand Down Expand Up @@ -146,7 +146,7 @@ func (d *DAL) PopNextFSMEvent(ctx context.Context, fsm schema.RefKey, instanceKe
}

func (d *DAL) SetNextFSMEvent(ctx context.Context, fsm schema.RefKey, instanceKey string, nextState schema.RefKey, request json.RawMessage, requestType schema.Type) error {
encryptedRequest, err := d.encryptJSON(encryption.AsyncSubkey{}, request)
encryptedRequest, err := d.encryptJSON(encryption.AsyncSubKey, request)
if err != nil {
return fmt.Errorf("failed to encrypt FSM request: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/dal/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
)

func (d *DAL) PublishEventForTopic(ctx context.Context, module, topic, caller string, payload []byte) error {
encryptedPayload, err := d.encrypt(encryption.AsyncSubkey{}, payload)
encryptedPayload, err := d.encrypt(encryption.AsyncSubKey, payload)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/ftl-controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"context"
"database/sql"
"fmt"
"github.com/alecthomas/types/optional"
"os"
"strconv"
"time"

"github.com/alecthomas/kong"
"github.com/alecthomas/types/optional"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"

Expand Down
47 changes: 19 additions & 28 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,12 @@ import (
"github.com/tink-crypto/tink-go/v2/tink"
)

type Subkey interface {
Salt() string
}

type TimelineSubkey struct{}

func (t TimelineSubkey) Salt() string {
return "timeline"
}
type SubKey string

type AsyncSubkey struct{}

func (a AsyncSubkey) Salt() string {
return "async"
}
const (
TimelineSubKey SubKey = "timeline"
AsyncSubKey SubKey = "async"
)

type KeyStoreProvider interface {
// EnsureKey asks a provider to check for an encrypted key.
Expand Down Expand Up @@ -81,8 +72,8 @@ func (b Builder) Build(ctx context.Context, provider KeyStoreProvider) (DataEncr
}

type DataEncryptor interface {
Encrypt(subkey Subkey, cleartext []byte) ([]byte, error)
Decrypt(subkey Subkey, encrypted []byte) ([]byte, error)
Encrypt(subKey SubKey, cleartext []byte) ([]byte, error)
Decrypt(subKey SubKey, encrypted []byte) ([]byte, error)
}

// NoOpEncryptor does not encrypt and just passes the input as is.
Expand All @@ -92,11 +83,11 @@ func NewNoOpEncryptor() NoOpEncryptor {
return NoOpEncryptor{}
}

func (n NoOpEncryptor) Encrypt(_ Subkey, cleartext []byte) ([]byte, error) {
func (n NoOpEncryptor) Encrypt(_ SubKey, cleartext []byte) ([]byte, error) {
return cleartext, nil
}

func (n NoOpEncryptor) Decrypt(_ Subkey, encrypted []byte) ([]byte, error) {
func (n NoOpEncryptor) Decrypt(_ SubKey, encrypted []byte) ([]byte, error) {
return encrypted, nil
}

Expand All @@ -105,7 +96,7 @@ type KMSEncryptor struct {
root keyset.Handle
kekAEAD tink.AEAD
encryptedKeyset []byte
cachedDerived map[Subkey]tink.AEAD
cachedDerived map[SubKey]tink.AEAD
}

func newClientWithAEAD(uri string, kms *awsv1kms.KMS) (tink.AEAD, error) {
Expand Down Expand Up @@ -187,7 +178,7 @@ func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: make(map[Subkey]tink.AEAD),
cachedDerived: make(map[SubKey]tink.AEAD),
}, nil
}

Expand All @@ -209,12 +200,12 @@ func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) {
return derived, nil
}

func (k *KMSEncryptor) getDerivedPrimitive(subkey Subkey) (tink.AEAD, error) {
if primitive, ok := k.cachedDerived[subkey]; ok {
func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
if primitive, ok := k.cachedDerived[subKey]; ok {
return primitive, nil
}

derived, err := deriveKeyset(k.root, []byte(subkey.Salt()))
derived, err := deriveKeyset(k.root, []byte(subKey))
if err != nil {
return nil, fmt.Errorf("failed to derive keyset: %w", err)
}
Expand All @@ -224,12 +215,12 @@ func (k *KMSEncryptor) getDerivedPrimitive(subkey Subkey) (tink.AEAD, error) {
return nil, fmt.Errorf("failed to create primitive: %w", err)
}

k.cachedDerived[subkey] = primitive
k.cachedDerived[subKey] = primitive
return primitive, nil
}

func (k *KMSEncryptor) Encrypt(subkey Subkey, cleartext []byte) ([]byte, error) {
primitive, err := k.getDerivedPrimitive(subkey)
func (k *KMSEncryptor) Encrypt(subKey SubKey, cleartext []byte) ([]byte, error) {
primitive, err := k.getDerivedPrimitive(subKey)
if err != nil {
return nil, fmt.Errorf("failed to get derived primitive: %w", err)
}
Expand All @@ -242,8 +233,8 @@ func (k *KMSEncryptor) Encrypt(subkey Subkey, cleartext []byte) ([]byte, error)
return encrypted, nil
}

func (k *KMSEncryptor) Decrypt(subkey Subkey, encrypted []byte) ([]byte, error) {
primitive, err := k.getDerivedPrimitive(subkey)
func (k *KMSEncryptor) Decrypt(subKey SubKey, encrypted []byte) ([]byte, error) {
primitive, err := k.getDerivedPrimitive(subKey)
if err != nil {
return nil, fmt.Errorf("failed to get derived primitive: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions internal/encryption/encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
func TestNoOpEncryptor(t *testing.T) {
encryptor := NoOpEncryptor{}

encrypted, err := encryptor.Encrypt(TimelineSubkey{}, []byte("hunter2"))
encrypted, err := encryptor.Encrypt(TimelineSubKey, []byte("hunter2"))
assert.NoError(t, err)

decrypted, err := encryptor.Decrypt(TimelineSubkey{}, encrypted)
decrypted, err := encryptor.Decrypt(TimelineSubKey, encrypted)
assert.NoError(t, err)

assert.Equal(t, "hunter2", string(decrypted))
Expand All @@ -28,14 +28,14 @@ func TestKMSEncryptorFakeKMS(t *testing.T) {
encryptor, err := NewKMSEncryptorWithKMS(uri, nil, key)
assert.NoError(t, err)

encrypted, err := encryptor.Encrypt(TimelineSubkey{}, []byte("hunter2"))
encrypted, err := encryptor.Encrypt(TimelineSubKey, []byte("hunter2"))
assert.NoError(t, err)

decrypted, err := encryptor.Decrypt(TimelineSubkey{}, encrypted)
decrypted, err := encryptor.Decrypt(TimelineSubKey, encrypted)
assert.NoError(t, err)
assert.Equal(t, "hunter2", string(decrypted))

// Should fail to decrypt with the wrong subkey
_, err = encryptor.Decrypt(AsyncSubkey{}, encrypted)
_, err = encryptor.Decrypt(AsyncSubKey, encrypted)
assert.Error(t, err)
}
6 changes: 3 additions & 3 deletions internal/encryption/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ func TestKMSEncryptorLocalstack(t *testing.T) {
encryptor, err := NewKMSEncryptorWithKMS(uri, v1client, key)
assert.NoError(t, err)

encrypted, err := encryptor.Encrypt(TimelineSubkey{}, []byte("hunter2"))
encrypted, err := encryptor.Encrypt(TimelineSubKey, []byte("hunter2"))
assert.NoError(t, err)

decrypted, err := encryptor.Decrypt(TimelineSubkey{}, encrypted)
decrypted, err := encryptor.Decrypt(TimelineSubKey, encrypted)
assert.NoError(t, err)
assert.Equal(t, "hunter2", string(decrypted))

// Should fail to decrypt with the wrong subkey
_, err = encryptor.Decrypt(AsyncSubkey{}, encrypted)
_, err = encryptor.Decrypt(AsyncSubKey, encrypted)
assert.Error(t, err)
}

0 comments on commit c1c8333

Please sign in to comment.