Skip to content

Commit

Permalink
update rotation stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
hooksie1 committed Aug 7, 2024
1 parent a1f39f2 commit 11e4c8f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 62 deletions.
17 changes: 10 additions & 7 deletions service/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,17 @@ func (a *AppContext) Rotate(currentKey string) ([]byte, error) {

updated, err := a.rotateKey(kvs)
if err != nil {
errs := a.rollbackKey(updated)
fmt.Println(errs)
return nil, nil
return nil, a.rollbackKey(updated)
}

databaseKey = newKey

return []byte(newKey), nil
}

func (a *AppContext) rollbackKey(kvs []rotatedKV) []error {
func (a *AppContext) rollbackKey(kvs []rotatedKV) error {
var failedKeys []string
logger := a.logger.WithContext(map[string]string{"rotation_step": "rollback"})
var errs []error
for _, v := range kvs {
if v.rotated == true {
logger.Infof("rolling back secret: %s", v.subject)
Expand All @@ -76,20 +74,25 @@ func (a *AppContext) rollbackKey(kvs []rotatedKV) []error {
data, err := a.getRecord(record, v.newKey)
if err != nil {
logger.Errorf("error in getting secret %s: %v", v.subject, err)
failedKeys = append(failedKeys, v.subject)
continue
}

record.SetValue(string(data))

if err := a.addRecord(record); err != nil {
errs = append(errs, err)
failedKeys = append(failedKeys, v.subject)
logger.Errorf("error rolling back encryption key on secret %s: %v", v.subject, err)
continue
}
}
}

return errs
if len(failedKeys) > 0 {
return fmt.Errorf("error rolling back keys: %v", failedKeys)
}

return nil
}

func (a *AppContext) rotateKey(kvs []rotatedKV) ([]rotatedKV, error) {
Expand Down
119 changes: 64 additions & 55 deletions service/rotate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func shutdownJSServerAndRemoveStorage(t *testing.T, s *server.Server) {
s.WaitForShutdown()
}

func setupEncryptedVals(t *testing.T, server *server.Server) ([]byte, AppContext) {
func setupEncryptedVals(t *testing.T, server *server.Server, vals map[string]string) ([]byte, AppContext) {

// nats connection
nc, err := nats.Connect(server.ClientURL())
Expand Down Expand Up @@ -80,7 +80,7 @@ func setupEncryptedVals(t *testing.T, server *server.Server) ([]byte, AppContext
t.Fatal(err)
}

for k, v := range testVals {
for k, v := range vals {
record := NewJSRecord().SetEncryptionKey(key).SetBucket(piggyBucket).SetKey(k).SetValue(v)
if err := app.addRecord(record); err != nil {
t.Error(err)
Expand All @@ -90,60 +90,69 @@ func setupEncryptedVals(t *testing.T, server *server.Server) ([]byte, AppContext
return key, app
}

func TestRotate(t *testing.T) {
func TestRotation(t *testing.T) {
// reset key
databaseKey = nil
server := NewServer(t)
defer shutdownJSServerAndRemoveStorage(t, server)

key, app := setupEncryptedVals(t, server)

_, err := app.Rotate(toBase64(key))
if err != nil {
t.Fatal(err)
tt := []struct {
name string
vals map[string]string
expected map[string]string
rollback bool
err bool
}{
{
name: "normal rotation",
rollback: false,
vals: testVals,
expected: testVals,
err: false,
},
{
name: "rollback with error",
rollback: true,
vals: testVals,
expected: map[string]string{
"piggybank.secrets.secret1": "thesecret",
"piggybank.secrets.secret2": "other secret",
"piggybank.secrets.secret3": "",
},
err: true,
},
}

for _, v := range tt {
t.Run(v.name, func(t *testing.T) {
databaseKey = nil
server := NewServer(t)
defer shutdownJSServerAndRemoveStorage(t, server)

key, app := setupEncryptedVals(t, server, v.vals)

// Change one key with bad data to cause rollback
if v.rollback {
record := NewJSRecord().SetEncryptionKey(generateKey()).SetBucket(piggyBucket).SetKey("piggybank.secrets.secret3").SetValue(string("other secret"))
if err := app.addRecord(record); err != nil {
t.Error(err)
}
}

_, err := app.Rotate(toBase64(key))
if err != nil && v.err != false {
t.Fatal(err)
}

for sub, val := range v.vals {
record := NewJSRecord().SetBucket(piggyBucket).SetKey(sub)
decrypted, err := app.getRecord(record, databaseKey)
if err != nil && v.err != true {
t.Error(err)
}

if string(decrypted) != v.expected[sub] {
t.Errorf("expected %s but got %s", val, string(decrypted))
}
}

})
}

for k, v := range testVals {
record := NewJSRecord().SetBucket(piggyBucket).SetKey(k)
decrypted, err := app.getRecord(record, databaseKey)
if err != nil {
t.Error(err)
}

if string(decrypted) != v {
t.Errorf("expected %s but got %s", v, string(decrypted))
}
}
}

func TestRollback(t *testing.T) {
// reset key
databaseKey = nil
server := NewServer(t)
defer shutdownJSServerAndRemoveStorage(t, server)

key, app := setupEncryptedVals(t, server)

// Change one key with bad data to cause rollback
record := NewJSRecord().SetEncryptionKey(generateKey()).SetBucket(piggyBucket).SetKey("piggybank.secrets.secret3").SetValue(string("other secret"))
if err := app.addRecord(record); err != nil {
t.Error(err)
}

_, err := app.Rotate(toBase64(key))
if err != nil {
t.Fatal(err)
}

for k, v := range testVals {
record := NewJSRecord().SetBucket(piggyBucket).SetKey(k)
decrypted, err := app.getRecord(record, databaseKey)
if err != nil {
t.Error(err)
}

if string(decrypted) != v {
t.Errorf("expected %s but got %s", v, string(decrypted))
}
}
}

0 comments on commit 11e4c8f

Please sign in to comment.