Skip to content

Commit

Permalink
fix: concurrency issue in getDerivedPrimitive (#2480)
Browse files Browse the repository at this point in the history
There is a situation where getDerivedPrimitive can be called
concurrently, and two separate keys will be created. This fixes that
issue.

Also added a type-safe Mutex type.
  • Loading branch information
alecthomas authored Aug 22, 2024
1 parent 2a797e5 commit bc77ebb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
19 changes: 8 additions & 11 deletions internal/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"strings"
"sync"

"github.com/alecthomas/types/optional"
awsv1kms "github.com/aws/aws-sdk-go/service/kms"
Expand All @@ -17,6 +16,8 @@ import (
"github.com/tink-crypto/tink-go/v2/prf"
"github.com/tink-crypto/tink-go/v2/testing/fakekms"
"github.com/tink-crypto/tink-go/v2/tink"

"github.com/TBD54566975/ftl/internal/mutex"
)

// Encrypted is an interface for values that contain encrypted data.
Expand Down Expand Up @@ -100,8 +101,7 @@ type KMSEncryptor struct {
root keyset.Handle
kekAEAD tink.AEAD
encryptedKeyset []byte
cachedDerivedMu sync.RWMutex
cachedDerived map[SubKey]tink.AEAD
cachedDerived *mutex.Mutex[map[SubKey]tink.AEAD]
}

var _ DataEncryptor = &KMSEncryptor{}
Expand Down Expand Up @@ -185,7 +185,7 @@ func NewKMSEncryptorWithKMS(uri string, v1client *awsv1kms.KMS, encryptedKeyset
root: *handle,
kekAEAD: kekAEAD,
encryptedKeyset: encryptedKeyset,
cachedDerived: make(map[SubKey]tink.AEAD),
cachedDerived: mutex.New(map[SubKey]tink.AEAD{}),
}, nil
}

Expand All @@ -208,9 +208,9 @@ func deriveKeyset(root keyset.Handle, salt []byte) (*keyset.Handle, error) {
}

func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
k.cachedDerivedMu.RLock()
primitive, ok := k.cachedDerived[subKey]
k.cachedDerivedMu.RUnlock()
cachedDerived := k.cachedDerived.Lock()
defer k.cachedDerived.Unlock()
primitive, ok := cachedDerived[subKey]
if ok {
return primitive, nil
}
Expand All @@ -225,10 +225,7 @@ func (k *KMSEncryptor) getDerivedPrimitive(subKey SubKey) (tink.AEAD, error) {
return nil, fmt.Errorf("failed to create primitive: %w", err)
}

k.cachedDerivedMu.Lock()
k.cachedDerived[subKey] = primitive
k.cachedDerivedMu.Unlock()

cachedDerived[subKey] = primitive
return primitive, nil
}

Expand Down
33 changes: 33 additions & 0 deletions internal/mutex/mutex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package mutex

import "sync"

// Mutex is a simple mutex that can be used to protect a value.
//
// The zero value is safe to use if the zero value of T is safe to use.
//
// Example:
//
// var m mutex.Mutex[*string]
// s := m.Lock()
// defer m.Unlock()
// *s = "hello"
type Mutex[T any] struct {
m sync.Mutex
v T
}

func New[T any](v T) *Mutex[T] {
return &Mutex[T]{v: v}
}

// Lock the Mutex and return its protected value.
func (l *Mutex[T]) Lock() T {
l.m.Lock()
return l.v
}

// Unlock the Mutex. The value returned by Lock is no longer valid.
func (l *Mutex[T]) Unlock() {
l.m.Unlock()
}

0 comments on commit bc77ebb

Please sign in to comment.