diff --git a/cmd/coordinator/run.go b/cmd/coordinator/run.go index dd169d5c..de1ce843 100644 --- a/cmd/coordinator/run.go +++ b/cmd/coordinator/run.go @@ -7,6 +7,7 @@ package main import ( + "context" "fmt" "os" "strings" @@ -106,7 +107,7 @@ func run(validator quote.Validator, issuer quote.Issuer, sealDir string, sealer if err != nil { log.Fatal("Cannot read startup manifest", zap.Error(err)) } - if _, err := clientServer.SetManifest(content); err != nil { + if _, err := clientServer.SetManifest(context.Background(), content); err != nil { log.Fatal("Cannot set startup manifest", zap.Error(err)) } } diff --git a/coordinator/clientapi/clientapi.go b/coordinator/clientapi/clientapi.go index 936d812b..28c2c849 100644 --- a/coordinator/clientapi/clientapi.go +++ b/coordinator/clientapi/clientapi.go @@ -8,6 +8,7 @@ package clientapi import ( + "context" "crypto/ecdsa" "crypto/rand" "crypto/sha256" @@ -35,32 +36,21 @@ import ( type core interface { Unlock() - RequireState(...state.State) error - AdvanceState(state.State, store.Transaction) error - GetState() (state.State, string, error) + RequireState(context.Context, ...state.State) error + AdvanceState(state.State, interface { + PutState(state.State) error + GetState() (state.State, error) + }) error + GetState(context.Context) (state.State, string, error) GenerateSecrets( - map[string]manifest.Secret, uuid.UUID, *x509.Certificate, *ecdsa.PrivateKey, + map[string]manifest.Secret, uuid.UUID, *x509.Certificate, *ecdsa.PrivateKey, *ecdsa.PrivateKey, ) (map[string]manifest.Secret, error) GetQuote() []byte GenerateQuote([]byte) error } -type storeWrapper interface { - GetIterator(prefix string) (wrapper.Iterator, error) - GetCertificate(string) (*x509.Certificate, error) - GetPrivateKey(string) (*ecdsa.PrivateKey, error) - GetSecret(string) (manifest.Secret, error) - GetUser(userName string) (*user.User, error) - GetManifest() (manifest.Manifest, error) - GetRawManifest() ([]byte, error) - GetManifestSignature() ([]byte, error) - GetUpdateLog() (string, error) - GetSecretMap() (map[string]manifest.Secret, error) - GetPackage(string) (quote.PackageProperties, error) -} - type transactionHandle interface { - BeginTransaction() (store.Transaction, error) + BeginTransaction(context.Context) (store.Transaction, error) SetEncryptionKey([]byte) error SetRecoveryData([]byte) LoadState() ([]byte, error) @@ -76,7 +66,6 @@ type updateLog interface { type ClientAPI struct { core core recovery recovery.Recovery - data storeWrapper txHandle transactionHandle updateLog updateLog @@ -84,7 +73,7 @@ type ClientAPI struct { } // New returns an initialized instance of the ClientAPI. -func New(store store.Store, recovery recovery.Recovery, core core, log *zap.Logger, +func New(txHandle transactionHandle, recovery recovery.Recovery, core core, log *zap.Logger, ) (*ClientAPI, error) { updateLog, err := updatelog.New() if err != nil { @@ -94,8 +83,7 @@ func New(store store.Store, recovery recovery.Recovery, core core, log *zap.Logg return &ClientAPI{ core: core, recovery: recovery, - data: wrapper.New(store), - txHandle: store, + txHandle: txHandle, updateLog: updateLog, log: log, }, nil @@ -105,10 +93,10 @@ func New(store store.Store, recovery recovery.Recovery, core core, log *zap.Logg // // Returns the remote attestation quote of its own certificate alongside this certificate, // which allows to verify the Coordinator's integrity and authentication for use of the ClientAPI. -func (a *ClientAPI) GetCertQuote() (cert string, certQuote []byte, err error) { +func (a *ClientAPI) GetCertQuote(ctx context.Context) (cert string, certQuote []byte, err error) { a.log.Info("GetCertQuote called") defer a.core.Unlock() - if err := a.core.RequireState(state.AcceptingManifest, state.AcceptingMarbles, state.Recovery); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingManifest, state.AcceptingMarbles, state.Recovery); err != nil { a.log.Error("GetCertQuote: Coordinator not in correct state", zap.Error(err)) return "", nil, err } @@ -118,7 +106,13 @@ func (a *ClientAPI) GetCertQuote() (cert string, certQuote []byte, err error) { } }() - rootCert, err := a.data.GetCertificate(constants.SKCoordinatorRootCert) + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return "", nil, err + } + defer rollback() + + rootCert, err := txdata.GetCertificate(constants.SKCoordinatorRootCert) if err != nil { return "", nil, fmt.Errorf("loading root certificate from store: %w", err) } @@ -126,7 +120,7 @@ func (a *ClientAPI) GetCertQuote() (cert string, certQuote []byte, err error) { return "", nil, errors.New("loaded nil root certificate from store") } - intermediateCert, err := a.data.GetCertificate(constants.SKCoordinatorIntermediateCert) + intermediateCert, err := txdata.GetCertificate(constants.SKCoordinatorIntermediateCert) if err != nil { return "", nil, fmt.Errorf("loading intermediate certificate from store: %w", err) } @@ -152,16 +146,23 @@ func (a *ClientAPI) GetCertQuote() (cert string, certQuote []byte, err error) { // GetManifestSignature returns the hash of the manifest. // // Returns ECDSA signature, SHA256 hash and byte encoded representation of the active manifest. -func (a *ClientAPI) GetManifestSignature() (manifestSignatureRootECDSA, manifestSignature, manifest []byte) { +func (a *ClientAPI) GetManifestSignature(ctx context.Context) (manifestSignatureRootECDSA, manifestSignature, manifest []byte) { a.log.Info("GetManifestSignature called") - rawManifest, err := a.data.GetRawManifest() + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + a.log.Error("GetManifestSignature failed: initializing store transaction", zap.Error(err)) + return nil, nil, nil + } + defer rollback() + + rawManifest, err := txdata.GetRawManifest() if err != nil { a.log.Error("GetManifestSignature failed: loading manifest from store", zap.Error(err)) return nil, nil, nil } hash := sha256.Sum256(rawManifest) - signature, err := a.data.GetManifestSignature() + signature, err := txdata.GetManifestSignature() if err != nil { a.log.Error("GetManifestSignature failed: loading manifest signature from store", zap.Error(err)) return nil, nil, nil @@ -172,11 +173,11 @@ func (a *ClientAPI) GetManifestSignature() (manifestSignatureRootECDSA, manifest } // GetSecrets allows a user to retrieve secrets from the Coordinator. -func (a *ClientAPI) GetSecrets(requestedSecrets []string, client *user.User) (map[string]manifest.Secret, error) { +func (a *ClientAPI) GetSecrets(ctx context.Context, requestedSecrets []string, client *user.User) (map[string]manifest.Secret, error) { a.log.Info("GetSecrets called", zap.Strings("secrets", requestedSecrets), zap.String("user", client.Name())) defer a.core.Unlock() // we can only return secrets if a manifest has already been set - if err := a.core.RequireState(state.AcceptingMarbles); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingMarbles); err != nil { a.log.Error("GetSecrets: Coordinator not in correct state", zap.Error(err)) return nil, err } @@ -190,9 +191,15 @@ func (a *ClientAPI) GetSecrets(requestedSecrets []string, client *user.User) (ma return nil, fmt.Errorf("user %s is not allowed to read one or more secrets of: %v", client.Name(), requestedSecrets) } + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return nil, err + } + defer rollback() + secrets := make(map[string]manifest.Secret) for _, requestedSecret := range requestedSecrets { - returnedSecret, err := a.data.GetSecret(requestedSecret) + returnedSecret, err := txdata.GetSecret(requestedSecret) if err != nil { a.log.Error("GetSecrets failed: loading secret from store", zap.String("secret", requestedSecret), zap.Error(err)) return nil, fmt.Errorf("loading secret %s from store: %w", requestedSecret, err) @@ -205,21 +212,27 @@ func (a *ClientAPI) GetSecrets(requestedSecrets []string, client *user.User) (ma } // GetStatus returns status information about the state of the Coordinator. -func (a *ClientAPI) GetStatus() (state.State, string, error) { +func (a *ClientAPI) GetStatus(ctx context.Context) (state.State, string, error) { a.log.Info("GetStatus called") - return a.core.GetState() + return a.core.GetState(ctx) } // GetUpdateLog returns the update history of the Coordinator. -func (a *ClientAPI) GetUpdateLog() (string, error) { +func (a *ClientAPI) GetUpdateLog(ctx context.Context) (string, error) { a.log.Info("GetUpdateLog called") defer a.core.Unlock() - if err := a.core.RequireState(state.AcceptingMarbles); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingMarbles); err != nil { a.log.Error("GetUpdateLog: Coordinator not in correct state", zap.Error(err)) return "", err } - updateLog, err := a.data.GetUpdateLog() + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return "", err + } + defer rollback() + + updateLog, err := txdata.GetUpdateLog() if err != nil { a.log.Error("GetUpdateLog failed: loading update log from store", zap.Error(err)) return "", fmt.Errorf("loading update log from store: %w", err) @@ -230,10 +243,10 @@ func (a *ClientAPI) GetUpdateLog() (string, error) { } // Recover sets an encryption key (ideally decrypted from the recovery data) and tries to unseal and load a saved state of the Coordinator. -func (a *ClientAPI) Recover(encryptionKey []byte) (keysLeft int, err error) { +func (a *ClientAPI) Recover(ctx context.Context, encryptionKey []byte) (keysLeft int, err error) { a.log.Info("Recover called") defer a.core.Unlock() - if err := a.core.RequireState(state.Recovery); err != nil { + if err := a.core.RequireState(ctx, state.Recovery); err != nil { a.log.Error("Recover: Coordinator not in correct state", zap.Error(err)) return -1, err } @@ -269,7 +282,13 @@ func (a *ClientAPI) Recover(encryptionKey []byte) (keysLeft int, err error) { a.log.Error("Could not retrieve recovery data from state. Recovery will be unavailable", zap.Error(err)) } - rootCert, err := a.data.GetCertificate(constants.SKCoordinatorRootCert) + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return -1, err + } + defer rollback() + + rootCert, err := txdata.GetCertificate(constants.SKCoordinatorRootCert) if err != nil { return -1, fmt.Errorf("loading root certificate from store: %w", err) } @@ -286,10 +305,10 @@ func (a *ClientAPI) Recover(encryptionKey []byte) (keysLeft int, err error) { // // rawManifest is the manifest of type Manifest in JSON format. // recoverySecretMap is a map of recovery secrets that can be used to recover the Coordinator. -func (a *ClientAPI) SetManifest(rawManifest []byte) (recoverySecretMap map[string][]byte, err error) { +func (a *ClientAPI) SetManifest(ctx context.Context, rawManifest []byte) (recoverySecretMap map[string][]byte, err error) { a.log.Info("SetManifest called") defer a.core.Unlock() - if err := a.core.RequireState(state.AcceptingManifest, state.Recovery); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingManifest, state.Recovery); err != nil { a.log.Error("SetManifest: Coordinator not in correct state", zap.Error(err)) return nil, err } @@ -307,27 +326,33 @@ func (a *ClientAPI) SetManifest(rawManifest []byte) (recoverySecretMap map[strin return nil, fmt.Errorf("checking manifest: %w", err) } - marbleRootCert, err := a.data.GetCertificate(constants.SKMarbleRootCert) + txdata, rollback, commit, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return nil, err + } + defer rollback() + + marbleRootCert, err := txdata.GetCertificate(constants.SKMarbleRootCert) if err != nil { return nil, fmt.Errorf("loading root certificate from store: %w", err) } - rootPrivK, err := a.data.GetPrivateKey(constants.SKCoordinatorRootKey) + rootPrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorRootKey) if err != nil { return nil, fmt.Errorf("loading root private key from store: %w", err) } - intermediatePrivK, err := a.data.GetPrivateKey(constants.SKCoordinatorIntermediateKey) + intermediatePrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorIntermediateKey) if err != nil { return nil, fmt.Errorf("loading intermediate private key from store: %w", err) } // Generate shared secrets specified in manifest - secrets, err := a.core.GenerateSecrets(mnf.Secrets, uuid.Nil, marbleRootCert, intermediatePrivK) + secrets, err := a.core.GenerateSecrets(mnf.Secrets, uuid.Nil, marbleRootCert, intermediatePrivK, rootPrivK) if err != nil { a.log.Error("Could not generate specified secrets for the given manifest.", zap.Error(err)) return nil, fmt.Errorf("generating secrets from manifest: %w", err) } // generate placeholders for private secrets specified in manifest - privSecrets, err := a.core.GenerateSecrets(mnf.Secrets, uuid.New(), marbleRootCert, intermediatePrivK) + privSecrets, err := a.core.GenerateSecrets(mnf.Secrets, uuid.New(), marbleRootCert, intermediatePrivK, rootPrivK) if err != nil { a.log.Error("Could not generate specified secrets for the given manifest.", zap.Error(err)) return nil, fmt.Errorf("generating placeholder secrets from manifest: %w", err) @@ -364,13 +389,6 @@ func (a *ClientAPI) SetManifest(rawManifest []byte) (recoverySecretMap map[strin return nil, fmt.Errorf("signing manifest: %w", err) } - tx, err := a.txHandle.BeginTransaction() - if err != nil { - return nil, fmt.Errorf("initializing store transaction: %w", err) - } - defer tx.Rollback() - txdata := wrapper.New(tx) - for secretName, secret := range privSecrets { secrets[secretName] = secret } @@ -434,11 +452,11 @@ func (a *ClientAPI) SetManifest(rawManifest []byte) (recoverySecretMap map[strin return nil, fmt.Errorf("saving update log to store: %w", err) } - if err := a.core.AdvanceState(state.AcceptingMarbles, tx); err != nil { + if err := a.core.AdvanceState(state.AcceptingMarbles, txdata); err != nil { return nil, fmt.Errorf("advancing state: %w", err) } a.txHandle.SetRecoveryData(recoveryData) - if err := tx.Commit(); err != nil { + if err := commit(ctx); err != nil { a.log.Error("sealing of state failed", zap.Error(err)) } @@ -447,11 +465,11 @@ func (a *ClientAPI) SetManifest(rawManifest []byte) (recoverySecretMap map[strin } // UpdateManifest allows to update certain package parameters of the original manifest, supplied via a JSON manifest. -func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) (err error) { +func (a *ClientAPI) UpdateManifest(ctx context.Context, rawUpdateManifest []byte, updater *user.User) (err error) { a.log.Info("UpdateManifest called") defer a.core.Unlock() // Only accept update manifest if we already have a manifest - if err := a.core.RequireState(state.AcceptingMarbles); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingMarbles); err != nil { a.log.Error("UpdateManifest: Coordinator not in correct state", zap.Error(err)) return err } @@ -476,9 +494,15 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) return fmt.Errorf("user %s is not allowed to update one or more packages of %v", updater.Name(), wantedPackages) } + txdata, rollback, commit, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return err + } + defer rollback() + currentPackages := make(map[string]quote.PackageProperties) for pkgName := range updateManifest.Packages { - pkg, err := a.data.GetPackage(pkgName) + pkg, err := txdata.GetPackage(pkgName) if err != nil { return fmt.Errorf("loading current package %q from store: %w", pkgName, err) } @@ -505,11 +529,11 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) } } - rootCert, err := a.data.GetCertificate(constants.SKCoordinatorRootCert) + rootCert, err := txdata.GetCertificate(constants.SKCoordinatorRootCert) if err != nil { return fmt.Errorf("loading root certificate from store: %w", err) } - rootPrivK, err := a.data.GetPrivateKey(constants.SKCoordinatorRootKey) + rootPrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorRootKey) if err != nil { return fmt.Errorf("loading root private key from store: %w", err) } @@ -527,7 +551,7 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) // Gather all shared certificate secrets we need to regenerate secretsToRegenerate := make(map[string]manifest.Secret) - secrets, err := a.data.GetSecretMap() + secrets, err := txdata.GetSecretMap() if err != nil { return fmt.Errorf("loading existing shared secrets from store: %w", err) } @@ -538,7 +562,7 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) } // Regenerate shared secrets specified in manifest - regeneratedSecrets, err := a.core.GenerateSecrets(secretsToRegenerate, uuid.Nil, marbleRootCert, intermediatePrivK) + regeneratedSecrets, err := a.core.GenerateSecrets(secretsToRegenerate, uuid.Nil, marbleRootCert, intermediatePrivK, rootPrivK) if err != nil { a.log.Error("Could not generate specified secrets for the given manifest.", zap.Error(err)) return fmt.Errorf("regenerating shared secrets for updated manifest: %w", err) @@ -556,13 +580,6 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) a.updateLog.Info("SecurityVersion increased", zap.String("user", updater.Name()), zap.String("package", pkgName), zap.Uint("new version", *pkg.SecurityVersion)) } - tx, err := a.txHandle.BeginTransaction() - if err != nil { - return fmt.Errorf("initializing store transaction: %w", err) - } - defer tx.Rollback() - txdata := wrapper.New(tx) - if err := txdata.PutCertificate(constants.SKCoordinatorIntermediateCert, intermediateCert); err != nil { return fmt.Errorf("saving new intermediate certificate to store: %w", err) } @@ -593,7 +610,7 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) a.log.Info("Please restart your Marbles to enforce the update.") a.txHandle.SetRecoveryData(currentRecoveryData) - if err := tx.Commit(); err != nil { + if err := commit(ctx); err != nil { return fmt.Errorf("updating manifest failed: committing store transaction: %w", err) } @@ -602,8 +619,14 @@ func (a *ClientAPI) UpdateManifest(rawUpdateManifest []byte, updater *user.User) } // VerifyUser checks if a given client certificate matches the admin certificates specified in the manifest. -func (a *ClientAPI) VerifyUser(clientCerts []*x509.Certificate) (*user.User, error) { - userIter, err := a.data.GetIterator(request.User) +func (a *ClientAPI) VerifyUser(ctx context.Context, clientCerts []*x509.Certificate) (*user.User, error) { + txdata, rollback, _, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return nil, err + } + defer rollback() + + userIter, err := txdata.GetIterator(request.User) if err != nil { return nil, fmt.Errorf("getting user iterator: %w", err) } @@ -615,7 +638,7 @@ func (a *ClientAPI) VerifyUser(clientCerts []*x509.Certificate) (*user.User, err if err != nil { return nil, fmt.Errorf("getting next user: %w", err) } - user, err := a.data.GetUser(name) + user, err := txdata.GetUser(name) if err != nil { return nil, fmt.Errorf("getting user %q: %w", name, err) } @@ -629,11 +652,11 @@ func (a *ClientAPI) VerifyUser(clientCerts []*x509.Certificate) (*user.User, err } // WriteSecrets allows a user to set certain user-defined secrets for the Coordinator. -func (a *ClientAPI) WriteSecrets(rawSecretManifest []byte, updater *user.User) (err error) { +func (a *ClientAPI) WriteSecrets(ctx context.Context, rawSecretManifest []byte, updater *user.User) (err error) { a.log.Info("WriteSecrets called", zap.String("user", updater.Name())) defer a.core.Unlock() // Only accept secrets if we already have a manifest - if err := a.core.RequireState(state.AcceptingMarbles); err != nil { + if err := a.core.RequireState(ctx, state.AcceptingMarbles); err != nil { a.log.Error("WriteSecrets: Coordinator not in correct state", zap.Error(err)) return err } @@ -649,8 +672,14 @@ func (a *ClientAPI) WriteSecrets(rawSecretManifest []byte, updater *user.User) ( return fmt.Errorf("unmarshaling secret manifest: %w", err) } + txdata, rollback, commit, err := wrapper.WrapTransaction(ctx, a.txHandle) + if err != nil { + return err + } + defer rollback() + // validate and parse new secrets - secretMeta, err := a.data.GetSecretMap() + secretMeta, err := txdata.GetSecretMap() if err != nil { return fmt.Errorf("loading existing secrets: %w", err) } @@ -674,7 +703,7 @@ func (a *ClientAPI) WriteSecrets(rawSecretManifest []byte, updater *user.User) ( for k, v := range newSecrets { secretMeta[k] = v } - mnf, err := a.data.GetManifest() + mnf, err := txdata.GetManifest() if err != nil { return fmt.Errorf("loading manifest: %w", err) } @@ -683,13 +712,6 @@ func (a *ClientAPI) WriteSecrets(rawSecretManifest []byte, updater *user.User) ( return fmt.Errorf("running manifest template dry run: %w", err) } - tx, err := a.txHandle.BeginTransaction() - if err != nil { - return fmt.Errorf("initializing store transaction: %w", err) - } - defer tx.Rollback() - txdata := wrapper.New(tx) - a.updateLog.Reset() for secretName, secret := range newSecrets { // verify user is allowed to set the secret @@ -705,5 +727,5 @@ func (a *ClientAPI) WriteSecrets(rawSecretManifest []byte, updater *user.User) ( return fmt.Errorf("saving update log to store: %w", err) } - return tx.Commit() + return commit(ctx) } diff --git a/coordinator/clientapi/clientapi_test.go b/coordinator/clientapi/clientapi_test.go index 949f5070..3007588d 100644 --- a/coordinator/clientapi/clientapi_test.go +++ b/coordinator/clientapi/clientapi_test.go @@ -8,6 +8,7 @@ package clientapi import ( "bytes" + "context" "crypto/ecdsa" "crypto/sha256" "crypto/x509" @@ -18,9 +19,15 @@ import ( "github.com/edgelesssys/marblerun/coordinator/constants" "github.com/edgelesssys/marblerun/coordinator/crypto" "github.com/edgelesssys/marblerun/coordinator/manifest" + "github.com/edgelesssys/marblerun/coordinator/seal" "github.com/edgelesssys/marblerun/coordinator/state" "github.com/edgelesssys/marblerun/coordinator/store" + "github.com/edgelesssys/marblerun/coordinator/store/request" + "github.com/edgelesssys/marblerun/coordinator/store/stdstore" "github.com/edgelesssys/marblerun/coordinator/store/wrapper" + "github.com/edgelesssys/marblerun/coordinator/store/wrapper/testutil" + "github.com/edgelesssys/marblerun/coordinator/user" + "github.com/edgelesssys/marblerun/test" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,55 +41,45 @@ func TestMain(m *testing.M) { func TestGetCertQuote(t *testing.T) { someErr := errors.New("failed") + // these are not actually root and intermediate certs + // but we don't care for this test + rootCert, intermediateCert := test.MustSetupTestCerts(test.RecoveryPrivateKey) + + prepareDefaultStore := func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutCertificate(constants.SKCoordinatorRootCert, rootCert)) + require.NoError(t, wrapper.New(s).PutCertificate(constants.SKCoordinatorIntermediateCert, intermediateCert)) + return s + } testCases := map[string]struct { - storeWrapper *stubStoreWrapper - core *fakeCore - wantErr bool + store store.Store + core *fakeCore + wantErr bool }{ "success state accepting Marbles": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: prepareDefaultStore(), core: &fakeCore{ state: state.AcceptingMarbles, quote: []byte("quote"), }, }, "success state accepting manifest": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: prepareDefaultStore(), core: &fakeCore{ state: state.AcceptingManifest, quote: []byte("quote"), }, }, "success state recovery": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: prepareDefaultStore(), core: &fakeCore{ state: state.Recovery, quote: []byte("quote"), }, }, "unsupported state": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: prepareDefaultStore(), core: &fakeCore{ state: state.Uninitialized, quote: []byte("quote"), @@ -90,50 +87,19 @@ func TestGetCertQuote(t *testing.T) { wantErr: true, }, "error getting state": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: prepareDefaultStore(), core: &fakeCore{ requireStateErr: someErr, quote: []byte("quote"), }, wantErr: true, }, - "empty root cert": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: nil, - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, - core: &fakeCore{ - state: state.AcceptingMarbles, - quote: []byte("quote"), - }, - wantErr: true, - }, - "empty intermediate cert": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - constants.SKCoordinatorIntermediateCert: nil, - }, - }, - core: &fakeCore{ - state: state.AcceptingMarbles, - quote: []byte("quote"), - }, - wantErr: true, - }, "root certificate not set": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorIntermediateCert: {Raw: []byte("intermediate")}, - }, - }, + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutCertificate(constants.SKCoordinatorIntermediateCert, intermediateCert)) + return s + }(), core: &fakeCore{ state: state.AcceptingMarbles, quote: []byte("quote"), @@ -141,11 +107,11 @@ func TestGetCertQuote(t *testing.T) { wantErr: true, }, "intermediate certificate not set": { - storeWrapper: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root")}, - }, - }, + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutCertificate(constants.SKCoordinatorRootCert, rootCert)) + return s + }(), core: &fakeCore{ state: state.AcceptingMarbles, quote: []byte("quote"), @@ -164,12 +130,18 @@ func TestGetCertQuote(t *testing.T) { defer log.Sync() api := &ClientAPI{ - core: tc.core, - data: tc.storeWrapper, - log: log, + core: tc.core, + txHandle: tc.store, + log: log, + } + + var intermediateCert, rootCert *x509.Certificate + if !tc.wantErr { + intermediateCert = testutil.GetCertificate(t, tc.store, constants.SKCoordinatorIntermediateCert) + rootCert = testutil.GetCertificate(t, tc.store, constants.SKCoordinatorRootCert) } - cert, quote, err := api.GetCertQuote() + cert, quote, err := api.GetCertQuote(context.Background()) if tc.wantErr { assert.Error(err) return @@ -177,38 +149,38 @@ func TestGetCertQuote(t *testing.T) { require.NoError(err) assert.Equal(tc.core.quote, quote) - intermediateCert := tc.storeWrapper.getCertificateList[constants.SKCoordinatorIntermediateCert] - rootCert := tc.storeWrapper.getCertificateList[constants.SKCoordinatorRootCert] assert.Equal(mustEncodeToPem(t, intermediateCert)+mustEncodeToPem(t, rootCert), cert) }) } } func TestGetManifestSignature(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - data *stubStoreWrapper + store store.Store wantErr bool }{ "success": { - data: &stubStoreWrapper{ - rawManifest: []byte("manifest"), - manifestSignature: []byte("signature"), - }, + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, s.Put(request.Manifest, []byte("manifest"))) + require.NoError(t, s.Put(request.ManifestSignature, []byte("signature"))) + return s + }(), }, "GetRawManifest fails": { - data: &stubStoreWrapper{ - getRawManifestErr: someErr, - manifestSignature: []byte("signature"), - }, + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, s.Put(request.ManifestSignature, []byte("signature"))) + return s + }(), wantErr: true, }, "GetManifestSignature fails": { - data: &stubStoreWrapper{ - rawManifest: []byte("manifest"), - getManifestSignatureErr: someErr, - }, + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, s.Put(request.Manifest, []byte("manifest"))) + return s + }(), wantErr: true, }, } @@ -223,31 +195,197 @@ func TestGetManifestSignature(t *testing.T) { defer log.Sync() api := &ClientAPI{ - data: tc.data, - log: log, + txHandle: tc.store, + log: log, } - signature, hash, manifest := api.GetManifestSignature() + var rawManifest, manifestSignature, manifestHash []byte + if !tc.wantErr { + rawManifest = testutil.GetRawManifest(t, tc.store) + manifestSignature = testutil.GetManifestSignature(t, tc.store) + h := sha256.Sum256(rawManifest) + manifestHash = h[:] + } + + signature, hash, manifest := api.GetManifestSignature(context.Background()) if tc.wantErr { assert.Nil(signature) assert.Nil(hash) assert.Nil(manifest) return } - assert.Equal(tc.data.rawManifest, manifest) - expectedHash := sha256.Sum256(tc.data.rawManifest) - assert.Equal(expectedHash[:], hash) - assert.Equal(tc.data.manifestSignature, signature) + assert.Equal(rawManifest, manifest) + assert.Equal(manifestHash, hash) + assert.Equal(manifestSignature, signature) }) } } func TestGetSecrets(t *testing.T) { - t.Log("WARNING: Missing unit Test for GetSecrets") + newUserWithPermissions := func(name string, secretNames ...string) *user.User { + u := user.NewUser(name, nil) + u.Assign(user.NewPermission(user.PermissionReadSecret, secretNames)) + return u + } + + testCases := map[string]struct { + store store.Store + core *fakeCore + request []string + user *user.User + wantErr bool + }{ + "success": { + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutSecret("secret1", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + require.NoError(t, wrapper.New(s).PutSecret("secret2", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + return s + }(), + core: &fakeCore{state: state.AcceptingMarbles}, + request: []string{ + "secret1", + "secret2", + }, + user: newUserWithPermissions("test", "secret1", "secret2"), + }, + "wrong state": { + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutSecret("secret1", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + require.NoError(t, wrapper.New(s).PutSecret("secret2", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + return s + }(), + core: &fakeCore{state: state.AcceptingManifest}, + request: []string{ + "secret1", + "secret2", + }, + user: newUserWithPermissions("test", "secret1", "secret2"), + wantErr: true, + }, + "user is missing permissions": { + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutSecret("secret1", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + require.NoError(t, wrapper.New(s).PutSecret("secret2", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + return s + }(), + core: &fakeCore{state: state.AcceptingMarbles}, + request: []string{ + "secret1", + "secret2", + }, + user: newUserWithPermissions("test", "secret2"), // only permission for secret2 + wantErr: true, + }, + "secret does not exist": { + store: func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutSecret("secret1", manifest.Secret{ + Type: manifest.SecretTypePlain, + Private: []byte("secret"), + })) + return s + }(), + core: &fakeCore{state: state.AcceptingMarbles}, + request: []string{ + "secret1", + "secret2", + }, + user: newUserWithPermissions("test", "secret1", "secret2"), + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + log, err := zap.NewDevelopment() + require.NoError(err) + defer log.Sync() + + api := &ClientAPI{ + txHandle: tc.store, + core: tc.core, + log: log, + } + + storedSecrets := testutil.GetSecretMap(t, tc.store) + + secrets, err := api.GetSecrets(context.Background(), tc.request, tc.user) + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(err) + for name, secret := range secrets { + assert.Equal(storedSecrets[name], secret) + } + }) + } } func TestGetStatus(t *testing.T) { - t.Log("WARNING: Missing unit Test for GetStatus") + testCases := map[string]struct { + core *fakeCore + wantErr bool + }{ + "success": { + core: &fakeCore{state: state.AcceptingManifest}, + }, + "error": { + core: &fakeCore{ + state: state.AcceptingManifest, + getStateErr: errors.New("failed"), + }, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + log, err := zap.NewDevelopment() + require.NoError(err) + defer log.Sync() + + api := &ClientAPI{ + core: tc.core, + log: log, + } + + status, _, err := api.GetStatus(context.Background()) + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(err) + assert.Equal(tc.core.state, status) + }) + } } func TestGetUpdateLog(t *testing.T) { @@ -256,33 +394,32 @@ func TestGetUpdateLog(t *testing.T) { func TestRecover(t *testing.T) { someErr := errors.New("failed") + _, rootCert := test.MustSetupTestCerts(test.RecoveryPrivateKey) + defaultStore := func() store.Store { + s := stdstore.New(&seal.MockSealer{}) + require.NoError(t, wrapper.New(s).PutCertificate(constants.SKCoordinatorRootCert, rootCert)) + return s + } testCases := map[string]struct { - data *stubStoreWrapper - store *stubStore + store *fakeStore recovery *stubRecovery core *fakeCore wantErr bool }{ "success": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{}, core: &fakeCore{ state: state.Recovery, }, }, "more than one key required": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{ recoverKeysLeft: 1, }, @@ -291,12 +428,9 @@ func TestRecover(t *testing.T) { }, }, "SetRecoveryData fails does not result in error": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{ setRecoveryDataErr: someErr, }, @@ -305,12 +439,9 @@ func TestRecover(t *testing.T) { }, }, "Coordinator not in recovery state": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{}, core: &fakeCore{ state: state.AcceptingManifest, @@ -318,12 +449,9 @@ func TestRecover(t *testing.T) { wantErr: true, }, "RecoverKey fails": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{ recoverKeyErr: someErr, }, @@ -333,12 +461,8 @@ func TestRecover(t *testing.T) { wantErr: true, }, "LoadState fails": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, - }, - store: &stubStore{ + store: &fakeStore{ + store: defaultStore(), loadStateErr: someErr, }, recovery: &stubRecovery{}, @@ -348,12 +472,8 @@ func TestRecover(t *testing.T) { wantErr: true, }, "SetEncryptionKey fails": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, - }, - store: &stubStore{ + store: &fakeStore{ + store: defaultStore(), setEncryptionKeyErr: someErr, }, recovery: &stubRecovery{}, @@ -363,10 +483,9 @@ func TestRecover(t *testing.T) { wantErr: true, }, "GetCertificate fails": { - data: &stubStoreWrapper{ - getCertificateErr: someErr, + store: &fakeStore{ + store: stdstore.New(&seal.MockSealer{}), }, - store: &stubStore{}, recovery: &stubRecovery{}, core: &fakeCore{ state: state.Recovery, @@ -374,12 +493,9 @@ func TestRecover(t *testing.T) { wantErr: true, }, "GenerateQuote fails": { - data: &stubStoreWrapper{ - getCertificateList: map[string]*x509.Certificate{ - constants.SKCoordinatorRootCert: {Raw: []byte("root cert")}, - }, + store: &fakeStore{ + store: defaultStore(), }, - store: &stubStore{}, recovery: &stubRecovery{}, core: &fakeCore{ state: state.Recovery, @@ -399,14 +515,13 @@ func TestRecover(t *testing.T) { defer log.Sync() api := &ClientAPI{ - data: tc.data, txHandle: tc.store, recovery: tc.recovery, core: tc.core, log: log, } - keysLeft, err := api.Recover([]byte("recoveryKey")) + keysLeft, err := api.Recover(context.Background(), []byte("recoveryKey")) if tc.wantErr { assert.Error(err) return @@ -459,7 +574,7 @@ func (c *fakeCore) Unlock() { c.unlockCalled = true } -func (c *fakeCore) RequireState(states ...state.State) error { +func (c *fakeCore) RequireState(_ context.Context, states ...state.State) error { if c.requireStateErr != nil { return c.requireStateErr } @@ -472,7 +587,11 @@ func (c *fakeCore) RequireState(states ...state.State) error { return errors.New("core is not in expected state") } -func (c *fakeCore) AdvanceState(newState state.State, _ store.Transaction) error { +func (c *fakeCore) AdvanceState(newState state.State, _ interface { + PutState(state.State) error + GetState() (state.State, error) +}, +) error { if c.advanceStateErr != nil { return c.advanceStateErr } @@ -484,11 +603,12 @@ func (c *fakeCore) AdvanceState(newState state.State, _ store.Transaction) error return nil } -func (c *fakeCore) GetState() (state.State, string, error) { +func (c *fakeCore) GetState(_ context.Context) (state.State, string, error) { return c.state, c.getStateMsg, c.getStateErr } -func (c *fakeCore) GenerateSecrets(newSecrets map[string]manifest.Secret, _ uuid.UUID, rootCert *x509.Certificate, privK *ecdsa.PrivateKey, +func (c *fakeCore) GenerateSecrets( + newSecrets map[string]manifest.Secret, _ uuid.UUID, rootCert *x509.Certificate, privK *ecdsa.PrivateKey, _ *ecdsa.PrivateKey, ) (map[string]manifest.Secret, error) { if c.generateSecretsErr != nil || c.generatedSecrets != nil { return c.generatedSecrets, c.generateSecretsErr @@ -534,53 +654,9 @@ func (c *fakeCore) GenerateQuote(quoteData []byte) error { return nil } -type stubStoreWrapper struct { - getCertificateList map[string]*x509.Certificate - getCertificateErr error - getPrivateKeyList map[string]*ecdsa.PrivateKey - getPrivateKeyErr error - rawManifest []byte - getRawManifestErr error - manifestSignature []byte - getManifestSignatureErr error - wrapper.Wrapper -} - -func (s *stubStoreWrapper) GetCertificate(certName string) (*x509.Certificate, error) { - if s.getCertificateErr != nil { - return nil, s.getCertificateErr - } - - cert, ok := s.getCertificateList[certName] - if !ok { - return nil, errors.New("certificate not found") - } - - return cert, nil -} - -func (s *stubStoreWrapper) GetPrivateKey(keyName string) (*ecdsa.PrivateKey, error) { - if s.getPrivateKeyErr != nil { - return nil, s.getPrivateKeyErr - } - - key, ok := s.getPrivateKeyList[keyName] - if !ok { - return nil, errors.New("private key not found") - } - - return key, nil -} - -func (s *stubStoreWrapper) GetRawManifest() ([]byte, error) { - return s.rawManifest, s.getRawManifestErr -} - -func (s *stubStoreWrapper) GetManifestSignature() ([]byte, error) { - return s.manifestSignature, s.getManifestSignatureErr -} - -type stubStore struct { +type fakeStore struct { + store store.Store + beginTransactionErr error recoveryData []byte encryptionKey []byte setEncryptionKeyErr error @@ -589,11 +665,14 @@ type stubStore struct { loadCalled bool } -func (s *stubStore) BeginTransaction() (store.Transaction, error) { - return nil, nil +func (s *fakeStore) BeginTransaction(ctx context.Context) (store.Transaction, error) { + if s.beginTransactionErr != nil { + return nil, s.beginTransactionErr + } + return s.store.BeginTransaction(ctx) } -func (s *stubStore) SetEncryptionKey(key []byte) error { +func (s *fakeStore) SetEncryptionKey(key []byte) error { if s.setEncryptionKeyErr != nil { return s.setEncryptionKeyErr } @@ -601,11 +680,11 @@ func (s *stubStore) SetEncryptionKey(key []byte) error { return nil } -func (s *stubStore) SetRecoveryData(recoveryData []byte) { +func (s *fakeStore) SetRecoveryData(recoveryData []byte) { s.recoveryData = recoveryData } -func (s *stubStore) LoadState() ([]byte, error) { +func (s *fakeStore) LoadState() ([]byte, error) { s.loadCalled = true return s.loadStateRes, s.loadStateErr } @@ -623,15 +702,15 @@ type stubRecovery struct { setRecoveryDataErr error } -func (s *stubRecovery) GenerateEncryptionKey(recoveryKeys map[string]string) ([]byte, error) { +func (s *stubRecovery) GenerateEncryptionKey(_ map[string]string) ([]byte, error) { return s.generateEncryptionKeyRes, s.generateEncryptionKeyErr } -func (s *stubRecovery) GenerateRecoveryData(recoveryKeys map[string]string) (map[string][]byte, []byte, error) { +func (s *stubRecovery) GenerateRecoveryData(_ map[string]string) (map[string][]byte, []byte, error) { return s.generateRecoveryDataRes, nil, s.generateRecoveryDataErr } -func (s *stubRecovery) RecoverKey(secret []byte) (int, []byte, error) { +func (s *stubRecovery) RecoverKey(_ []byte) (int, []byte, error) { return s.recoverKeysLeft, s.recoverKeyRes, s.recoverKeyErr } @@ -639,7 +718,7 @@ func (s *stubRecovery) GetRecoveryData() ([]byte, error) { return s.getRecoveryDataRes, s.getRecoveryDataErr } -func (s *stubRecovery) SetRecoveryData(data []byte) error { +func (s *stubRecovery) SetRecoveryData(_ []byte) error { return s.setRecoveryDataErr } diff --git a/coordinator/clientapi/legacy_test.go b/coordinator/clientapi/legacy_test.go index 929a696b..578e5ebb 100644 --- a/coordinator/clientapi/legacy_test.go +++ b/coordinator/clientapi/legacy_test.go @@ -7,6 +7,7 @@ package clientapi import ( + "context" "crypto/ecdsa" "crypto/sha256" "crypto/x509" @@ -34,25 +35,26 @@ import ( func TestSetManifest_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() rawManifest := []byte(test.ManifestJSON) var manifest manifest.Manifest require.NoError(json.Unmarshal(rawManifest, &manifest)) c, getter := setupAPI(t) - _, err := c.SetManifest(rawManifest) + _, err := c.SetManifest(ctx, rawManifest) assert.NoError(err, "SetManifest should succeed on first try") cManifest, err := getter.GetManifest() assert.NoError(err) assert.Equal(manifest, cManifest, "Manifest should be set correctly") - _, err = c.SetManifest(rawManifest) + _, err = c.SetManifest(ctx, rawManifest) assert.Error(err, "SetManifest should fail on the second try") cManifest, err = getter.GetManifest() assert.NoError(err) assert.Equal(manifest, cManifest, "Manifest should still be set correctly") - _, err = c.SetManifest(rawManifest[:len(rawManifest)-1]) + _, err = c.SetManifest(ctx, rawManifest[:len(rawManifest)-1]) assert.Error(err, "SetManifest should fail on broken json") cManifest, err = getter.GetManifest() assert.NoError(err) @@ -60,14 +62,14 @@ func TestSetManifest_Legacy(t *testing.T) { // use new core c, _ = setupAPI(t) - _, err = c.SetManifest(rawManifest[:len(rawManifest)-1]) + _, err = c.SetManifest(ctx, rawManifest[:len(rawManifest)-1]) assert.Error(err, "SetManifest should fail on broken json") c, getter = setupAPI(t) - _, err = c.SetManifest([]byte("")) + _, err = c.SetManifest(ctx, []byte("")) assert.Error(err, "empty string should not be accepted") - _, err = c.SetManifest(rawManifest) + _, err = c.SetManifest(ctx, rawManifest) assert.NoError(err, "SetManifest should succed after failed tries") cManifest, err = getter.GetManifest() assert.NoError(err) @@ -77,6 +79,7 @@ func TestSetManifest_Legacy(t *testing.T) { func TestSetManifestInvalid_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() newTestManifest := func() *manifest.Manifest { rawManifest := []byte(test.ManifestJSON) @@ -91,7 +94,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { modRawManifest, err := json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.NoError(err) marblePackage.Debug = false @@ -109,7 +112,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { } modRawManifest, err := json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.ErrorContains(err, "manifest does not contain marble package foo") // Try setting manifest with all values unset, no debug mode (this should fail) @@ -126,7 +129,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { manifest.Packages["backend"] = backendPackage modRawManifest, err = json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.ErrorContains(err, "manifest misses value for SignerID in package backend") // Enable debug mode, should work now @@ -139,7 +142,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { modRawManifest, err = json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.ErrorContains(err, "manifest misses value for ProductID in package backend") // Enable debug mode, should work now @@ -153,7 +156,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { modRawManifest, err = json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.ErrorContains(err, "manifest misses value for SecurityVersion in package backend") // Enable debug mode, should work now @@ -167,7 +170,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { modRawManifest, err = json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.NoError(err) // Reset & enable debug mode, should also work now @@ -181,7 +184,7 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { modRawManifest, err = json.Marshal(manifest) require.NoError(err) - _, err = a.SetManifest(modRawManifest) + _, err = a.SetManifest(ctx, modRawManifest) assert.ErrorContains(err, "manifest specifies both UniqueID *and* SignerID/ProductID/SecurityVersion in package backend") // Enable debug mode, should work now @@ -191,13 +194,14 @@ func TestSetManifestInvalid_Legacy(t *testing.T) { func TestGetManifestSignature_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() api, data := setupAPI(t) - _, err := api.SetManifest([]byte(test.ManifestJSON)) + _, err := api.SetManifest(ctx, []byte(test.ManifestJSON)) assert.NoError(err) - sigECDSA, hash, manifest := api.GetManifestSignature() + sigECDSA, hash, manifest := api.GetManifestSignature(ctx) expectedHash := sha256.Sum256([]byte(test.ManifestJSON)) assert.Equal(expectedHash[:], hash) @@ -211,32 +215,33 @@ func TestGetSecret_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) c, data := setupAPI(t) + ctx := context.Background() symmetricSecret := "symmetricKeyShared" certSecret := "certShared" - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) secret1, err := data.GetSecret(symmetricSecret) require.NoError(err) - secret2, err := c.data.GetSecret(certSecret) + secret2, err := data.GetSecret(certSecret) require.NoError(err) admin, err := data.GetUser("admin") require.NoError(err) // requested secrets should be the same - reqSecrets, err := c.GetSecrets([]string{symmetricSecret, certSecret}, admin) + reqSecrets, err := c.GetSecrets(ctx, []string{symmetricSecret, certSecret}, admin) require.NoError(err) assert.True(len(reqSecrets) == 2) assert.Equal(secret1, reqSecrets[symmetricSecret]) assert.Equal(secret2, reqSecrets[certSecret]) // request should fail if the user lacks permissions - _, err = c.GetSecrets([]string{symmetricSecret, "restrictedSecret"}, admin) + _, err = c.GetSecrets(ctx, []string{symmetricSecret, "restrictedSecret"}, admin) assert.Error(err) // requesting a secret should return an empty secret since it was not set - sec, err := c.GetSecrets([]string{"symmetricKeyUnset"}, admin) + sec, err := c.GetSecrets(ctx, []string{"symmetricKeyUnset"}, admin) require.NoError(err) assert.Empty(sec["symmetricKeyUnset"].Public) assert.Empty(sec["symmetricKeyUnset"].Private) @@ -246,17 +251,18 @@ func TestGetStatus_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) c, _ := setupAPI(t) + ctx := context.Background() // Server should be ready to accept a manifest after initializing a mock core - statusCode, status, err := c.GetStatus() + statusCode, status, err := c.GetStatus(ctx) assert.NoError(err, "GetStatus failed") assert.EqualValues(state.AcceptingManifest, statusCode, "We should be ready to accept a manifest now, but GetStatus tells us we don't.") assert.NotEmpty(status, "Status string was empty, but should not.") // Set a manifest, state should change - _, err = c.SetManifest([]byte(test.ManifestJSON)) + _, err = c.SetManifest(ctx, []byte(test.ManifestJSON)) require.NoError(err) - statusCode, status, err = c.GetStatus() + statusCode, status, err = c.GetStatus(ctx) assert.NoError(err, "GetStatus failed") assert.EqualValues(state.AcceptingMarbles, statusCode, "We should be ready to accept Marbles now, but GetStatus tells us we don't.") assert.NotEmpty(status, "Status string was empty, but should not.") @@ -265,13 +271,14 @@ func TestGetStatus_Legacy(t *testing.T) { func TestWriteSecrets_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() symmetricSecret := "symmetricKeyUnset" certSecret := "certUnset" c, data := setupAPI(t) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) admin, err := data.GetUser("admin") @@ -289,7 +296,7 @@ func TestWriteSecrets_Legacy(t *testing.T) { assert.Empty(sec.Private) // set a secret - err = c.WriteSecrets([]byte(test.UserSecrets), admin) + err = c.WriteSecrets(ctx, []byte(test.UserSecrets), admin) require.NoError(err) secret, err := data.GetSecret(symmetricSecret) require.NoError(err) @@ -304,7 +311,7 @@ func TestWriteSecrets_Legacy(t *testing.T) { "Key": "` + base64.StdEncoding.EncodeToString([]byte("MarbleRun Unit Test")) + `" } }`) - err = c.WriteSecrets(genericSecret, admin) + err = c.WriteSecrets(ctx, genericSecret, admin) require.NoError(err) secret, err = data.GetSecret("genericSecret") require.NoError(err) @@ -316,7 +323,7 @@ func TestWriteSecrets_Legacy(t *testing.T) { "Key": "` + base64.StdEncoding.EncodeToString([]byte{0x41, 0x41, 0x00, 0x41}) + `" } }`) - err = c.WriteSecrets(genericSecret, admin) + err = c.WriteSecrets(ctx, genericSecret, admin) assert.Error(err) // try to set a secret incorrect size @@ -325,7 +332,7 @@ func TestWriteSecrets_Legacy(t *testing.T) { "Key": "` + base64.StdEncoding.EncodeToString([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + `" } }`) - err = c.WriteSecrets(invalidSecret, admin) + err = c.WriteSecrets(ctx, invalidSecret, admin) assert.Error(err) } @@ -333,9 +340,10 @@ func TestUpdateManifest_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) c, data := setupAPI(t) + ctx := context.Background() // Set manifest - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) admin, err := data.GetUser("admin") @@ -352,7 +360,7 @@ func TestUpdateManifest_Legacy(t *testing.T) { assert.NoError(err) // Update manifest - err = c.UpdateManifest([]byte(test.UpdateManifest), admin) + err = c.UpdateManifest(ctx, []byte(test.UpdateManifest), admin) require.NoError(err) // Get new certificates @@ -405,7 +413,7 @@ func TestUpdateManifest_Legacy(t *testing.T) { assert.NoError(err) // updating the manifest should have produced an entry for "frontend" in the updatelog - updateLog, err := c.GetUpdateLog() + updateLog, err := c.GetUpdateLog(ctx) assert.NoError(err) assert.Contains(updateLog, `"package":"frontend"`) } @@ -414,10 +422,11 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) c, data := setupAPI(t) + ctx := context.Background() // Good update manifests // Set manifest (frontend has SecurityVersion 3) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) cPackage, err := data.GetPackage("frontend") assert.NoError(err) @@ -425,14 +434,14 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { // Try to update with unregistered user someUser := user.NewUser("invalid", nil) - err = c.UpdateManifest([]byte(test.UpdateManifest), someUser) + err = c.UpdateManifest(ctx, []byte(test.UpdateManifest), someUser) assert.Error(err) admin, err := data.GetUser("admin") assert.NoError(err) // Try to update manifest (frontend's SecurityVersion should rise from 3 to 5) - err = c.UpdateManifest([]byte(test.UpdateManifest), admin) + err = c.UpdateManifest(ctx, []byte(test.UpdateManifest), admin) require.NoError(err) cUpdatedPackage, err := data.GetPackage("frontend") assert.NoError(err) @@ -446,7 +455,7 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { badUpdateManifest.Packages["nonExisting"] = badUpdateManifest.Packages["frontend"] badRawManifest, err := json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) delete(badUpdateManifest.Packages, "nonExisting") @@ -457,7 +466,7 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { badUpdateManifest.Packages["frontend"] = badModPackage badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) badModPackage.Debug = false @@ -467,7 +476,7 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { badUpdateManifest.Packages["frontend"] = badModPackage badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) // Test if downgrading fails @@ -477,7 +486,7 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { badUpdateManifest.Packages["frontend"] = badModPackage badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) // Test if downgrading fails @@ -487,7 +496,7 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { badUpdateManifest.Packages["frontend"] = badModPackage badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) // Test if removing a package from a currently existing update manifest fails @@ -495,14 +504,14 @@ func TestUpdateManifestInvalid_Legacy(t *testing.T) { delete(badUpdateManifest.Packages, "frontend") badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) // Test what happens if no packages are defined at all badUpdateManifest.Packages = nil badRawManifest, err = json.Marshal(badUpdateManifest) require.NoError(err) - err = c.UpdateManifest(badRawManifest, admin) + err = c.UpdateManifest(ctx, badRawManifest, admin) assert.Error(err) } @@ -540,10 +549,11 @@ func TestUpdateDebugMarble_Legacy(t *testing.T) { }`) assert := assert.New(t) require := require.New(t) + ctx := context.Background() c, data := setupAPI(t) // Set manifest - _, err := c.SetManifest(manifest) + _, err := c.SetManifest(ctx, manifest) require.NoError(err) admin, err := data.GetUser("admin") @@ -554,7 +564,7 @@ func TestUpdateDebugMarble_Legacy(t *testing.T) { // Try to update manifest // frontend's security version, which was previously unset, should now be set to 5 - err = c.UpdateManifest([]byte(test.UpdateManifest), admin) + err = c.UpdateManifest(ctx, []byte(test.UpdateManifest), admin) require.NoError(err) updatedPackage, err := data.GetPackage("frontend") @@ -566,11 +576,12 @@ func TestVerifyUser_Legacy(t *testing.T) { assert := assert.New(t) require := require.New(t) c, _ := setupAPI(t) + ctx := context.Background() adminTestCert, otherTestCert := test.MustSetupTestCerts(test.RecoveryPrivateKey) // Set a manifest containing an admin certificate - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) // Put certificates in slice, as Go's TLS library passes them in an HTTP request @@ -578,12 +589,12 @@ func TestVerifyUser_Legacy(t *testing.T) { otherTestCertSlice := []*x509.Certificate{otherTestCert} // Check if the adminTest certificate is deemed valid (stored in core), and the freshly generated one is deemed false - user, err := c.VerifyUser(adminTestCertSlice) + user, err := c.VerifyUser(ctx, adminTestCertSlice) assert.NoError(err) assert.Equal(*user.Certificate(), *adminTestCert) - _, err = c.VerifyUser(otherTestCertSlice) + _, err = c.VerifyUser(ctx, otherTestCertSlice) assert.Error(err) - _, err = c.VerifyUser(nil) + _, err = c.VerifyUser(ctx, nil) assert.Error(err) } @@ -614,7 +625,6 @@ func setupAPI(t *testing.T) (*ClientAPI, wrapper.Wrapper) { require.NoError(err) return &ClientAPI{ - data: wrapper, core: &fakeCore{ state: state.AcceptingManifest, getStateMsg: "status message", diff --git a/coordinator/core/core.go b/coordinator/core/core.go index f41efc02..59c61335 100644 --- a/coordinator/core/core.go +++ b/coordinator/core/core.go @@ -53,8 +53,7 @@ type Core struct { recovery recovery.Recovery metrics *coreMetrics - store store.Store - data wrapper.Wrapper + txHandle transactionHandle log *zap.Logger eventlog *events.Log @@ -64,9 +63,16 @@ type Core struct { // RequireState checks if the Coordinator is in one of the given states. // This function locks the Core's mutex and therefore should be paired with `defer c.mux.Unlock()`. -func (c *Core) RequireState(states ...state.State) error { +func (c *Core) RequireState(ctx context.Context, states ...state.State) error { c.mux.Lock() - curState, err := c.data.GetState() + + getter, rollback, _, err := wrapper.WrapTransaction(ctx, c.txHandle) + if err != nil { + return err + } + defer rollback() + + curState, err := getter.GetState() if err != nil { return err } @@ -79,16 +85,19 @@ func (c *Core) RequireState(states ...state.State) error { } // AdvanceState advances the state of the Coordinator. -func (c *Core) AdvanceState(newState state.State, tx store.Transaction) error { - txdata := wrapper.New(tx) - curState, err := txdata.GetState() +func (c *Core) AdvanceState(newState state.State, tx interface { + PutState(state.State) error + GetState() (state.State, error) +}, +) error { + curState, err := tx.GetState() if err != nil { return err } if !(curState < newState && newState < state.Max) { panic(fmt.Errorf("cannot advance from %d to %d", curState, newState)) } - return txdata.PutState(newState) + return tx.PutState(newState) } // Unlock the Core's mutex. @@ -97,35 +106,36 @@ func (c *Core) Unlock() { } // NewCore creates and initializes a new Core object. -func NewCore(dnsNames []string, qv quote.Validator, qi quote.Issuer, stor store.Store, recovery recovery.Recovery, zapLogger *zap.Logger, promFactory *promauto.Factory, eventlog *events.Log) (*Core, error) { +func NewCore( + dnsNames []string, qv quote.Validator, qi quote.Issuer, txHandle transactionHandle, + recovery recovery.Recovery, zapLogger *zap.Logger, promFactory *promauto.Factory, eventlog *events.Log, +) (*Core, error) { c := &Core{ qv: qv, qi: qi, recovery: recovery, - store: stor, - data: wrapper.New(stor), + txHandle: txHandle, log: zapLogger, eventlog: eventlog, } c.metrics = newCoreMetrics(promFactory, c, "coordinator") zapLogger.Info("loading state") - recoveryData, loadErr := stor.LoadState() + recoveryData, loadErr := txHandle.LoadState() if err := c.recovery.SetRecoveryData(recoveryData); err != nil { c.log.Error("Could not retrieve recovery data from state. Recovery will be unavailable", zap.Error(err)) } - tx, err := c.store.BeginTransaction() + transaction, rollback, commit, err := wrapper.WrapTransaction(context.Background(), c.txHandle) if err != nil { return nil, err } - defer tx.Rollback() - txdata := wrapper.New(tx) + defer rollback() // set core to uninitialized if no state is set - if _, err := txdata.GetState(); err != nil { + if _, err := transaction.GetState(); err != nil { if errors.Is(err, store.ErrValueUnset) { - if err := txdata.PutState(state.Uninitialized); err != nil { + if err := transaction.PutState(state.Uninitialized); err != nil { return nil, err } } else { @@ -139,34 +149,34 @@ func NewCore(dnsNames []string, qv quote.Validator, qi quote.Issuer, stor store. } // sealed state was found but couldnt be decrypted, go to recovery mode or reset manifest c.log.Error("Failed to decrypt sealed state. Processing with a new state. Use the /recover API endpoint to load an old state, or submit a new manifest to overwrite the old state. Look up the documentation for more information on how to proceed.") - if err := c.setCAData(dnsNames, tx); err != nil { + if err := c.setCAData(dnsNames, transaction); err != nil { return nil, err } - if err := c.AdvanceState(state.Recovery, tx); err != nil { + if err := c.AdvanceState(state.Recovery, transaction); err != nil { return nil, err } - } else if _, err := txdata.GetRawManifest(); errors.Is(err, store.ErrValueUnset) { + } else if _, err := transaction.GetRawManifest(); errors.Is(err, store.ErrValueUnset) { // no state was found, wait for manifest c.log.Info("No sealed state found. Proceeding with new state.") - if err := c.setCAData(dnsNames, tx); err != nil { + if err := c.setCAData(dnsNames, transaction); err != nil { return nil, err } - if err := txdata.PutState(state.AcceptingManifest); err != nil { + if err := transaction.PutState(state.AcceptingManifest); err != nil { return nil, err } } else if err != nil { return nil, err } else { // recovered from a sealed state, reload components and finish the store transaction - stor.SetRecoveryData(recoveryData) + txHandle.SetRecoveryData(recoveryData) } - if err := tx.Commit(); err != nil { + rootCert, err := transaction.GetCertificate(constants.SKCoordinatorRootCert) + if err != nil { return nil, err } - rootCert, err := c.data.GetCertificate(constants.SKCoordinatorRootCert) - if err != nil { + if err := commit(context.Background()); err != nil { return nil, err } @@ -206,8 +216,16 @@ func (c *Core) GetTLSConfig() (*tls.Config, error) { } // GetTLSRootCertificate creates a TLS certificate for the Coordinators self-signed x509 certificate. +// +// This function initializes a read transaction and should not be called from other functions with ongoing transactions. func (c *Core) GetTLSRootCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - curState, err := c.data.GetState() + data, rollback, _, err := wrapper.WrapTransaction(clientHello.Context(), c.txHandle) + if err != nil { + return nil, err + } + defer rollback() + + curState, err := data.GetState() if err != nil { return nil, err } @@ -215,11 +233,11 @@ func (c *Core) GetTLSRootCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cer return nil, errors.New("don't have a cert yet") } - rootCert, err := c.data.GetCertificate(constants.SKCoordinatorRootCert) + rootCert, err := data.GetCertificate(constants.SKCoordinatorRootCert) if err != nil { return nil, err } - rootPrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorRootKey) + rootPrivK, err := data.GetPrivateKey(constants.SKCoordinatorRootKey) if err != nil { return nil, err } @@ -228,8 +246,16 @@ func (c *Core) GetTLSRootCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cer } // GetTLSMarbleRootCertificate creates a TLS certificate for the Coordinator's x509 marbleRoot certificate. +// +// This function initializes a read transaction and should not be called from other functions with ongoing transactions. func (c *Core) GetTLSMarbleRootCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - curState, err := c.data.GetState() + data, rollback, _, err := wrapper.WrapTransaction(clientHello.Context(), c.txHandle) + if err != nil { + return nil, err + } + defer rollback() + + curState, err := data.GetState() if err != nil { return nil, err } @@ -237,11 +263,11 @@ func (c *Core) GetTLSMarbleRootCertificate(clientHello *tls.ClientHelloInfo) (*t return nil, errors.New("don't have a cert yet") } - marbleRootCert, err := c.data.GetCertificate(constants.SKMarbleRootCert) + marbleRootCert, err := data.GetCertificate(constants.SKMarbleRootCert) if err != nil { return nil, err } - intermediatePrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorIntermediateKey) + intermediatePrivK, err := data.GetPrivateKey(constants.SKCoordinatorIntermediateKey) if err != nil { return nil, err } @@ -290,8 +316,14 @@ func getClientTLSCert(ctx context.Context) *x509.Certificate { } // GetState returns the current state of the Coordinator. -func (c *Core) GetState() (state.State, string, error) { - curState, err := c.data.GetState() +func (c *Core) GetState(ctx context.Context) (state.State, string, error) { + data, rollback, _, err := wrapper.WrapTransaction(ctx, c.txHandle) + if err != nil { + return -1, "Cannot determine coordinator status.", fmt.Errorf("initializing read transaction: %w", err) + } + defer rollback() + + curState, err := data.GetState() if err != nil { return -1, "Cannot determine coordinator status.", err } @@ -313,15 +345,13 @@ func (c *Core) GetState() (state.State, string, error) { } // GenerateSecrets generates secrets for the given manifest and parent certificate. -func (c *Core) GenerateSecrets(secrets map[string]manifest.Secret, id uuid.UUID, parentCertificate *x509.Certificate, parentPrivKey *ecdsa.PrivateKey) (map[string]manifest.Secret, error) { +func (c *Core) GenerateSecrets( + secrets map[string]manifest.Secret, id uuid.UUID, + parentCertificate *x509.Certificate, parentPrivKey *ecdsa.PrivateKey, rootPrivK *ecdsa.PrivateKey, +) (map[string]manifest.Secret, error) { // Create a new map so we do not overwrite the entries in the manifest newSecrets := make(map[string]manifest.Secret) - rootPrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorRootKey) - if err != nil { - return nil, err - } - // Generate secrets for name, secret := range secrets { // Skip user defined secrets, these will be uploaded by a user @@ -515,7 +545,11 @@ func (c *Core) generateCertificateForSecret(secret manifest.Secret, parentCertif return secret, nil } -func (c *Core) setCAData(dnsNames []string, tx store.Transaction) error { +func (c *Core) setCAData(dnsNames []string, putter interface { + PutCertificate(name string, cert *x509.Certificate) error + PutPrivateKey(name string, key *ecdsa.PrivateKey) error +}, +) error { rootCert, rootPrivK, err := corecrypto.GenerateCert(dnsNames, constants.CoordinatorName, nil, nil, nil) if err != nil { return err @@ -530,20 +564,19 @@ func (c *Core) setCAData(dnsNames []string, tx store.Transaction) error { return err } - txdata := wrapper.New(tx) - if err := txdata.PutCertificate(constants.SKCoordinatorRootCert, rootCert); err != nil { + if err := putter.PutCertificate(constants.SKCoordinatorRootCert, rootCert); err != nil { return err } - if err := txdata.PutCertificate(constants.SKCoordinatorIntermediateCert, intermediateCert); err != nil { + if err := putter.PutCertificate(constants.SKCoordinatorIntermediateCert, intermediateCert); err != nil { return err } - if err := txdata.PutCertificate(constants.SKMarbleRootCert, marbleRootCert); err != nil { + if err := putter.PutCertificate(constants.SKMarbleRootCert, marbleRootCert); err != nil { return err } - if err := txdata.PutPrivateKey(constants.SKCoordinatorRootKey, rootPrivK); err != nil { + if err := putter.PutPrivateKey(constants.SKCoordinatorRootKey, rootPrivK); err != nil { return err } - if err := txdata.PutPrivateKey(constants.SKCoordinatorIntermediateKey, intermediatePrivK); err != nil { + if err := putter.PutPrivateKey(constants.SKCoordinatorIntermediateKey, intermediatePrivK); err != nil { return err } @@ -559,3 +592,10 @@ type QuoteError struct { func (e QuoteError) Error() string { return fmt.Sprintf("failed to get quote: %v", e.err) } + +type transactionHandle interface { + BeginTransaction(context.Context) (store.Transaction, error) + SetEncryptionKey([]byte) error + SetRecoveryData([]byte) + LoadState() ([]byte, error) +} diff --git a/coordinator/core/core_test.go b/coordinator/core/core_test.go index 72c5ee2f..de19f3d1 100644 --- a/coordinator/core/core_test.go +++ b/coordinator/core/core_test.go @@ -7,8 +7,10 @@ package core import ( + "context" "crypto/ed25519" "crypto/rand" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "math/big" @@ -23,6 +25,7 @@ import ( "github.com/edgelesssys/marblerun/coordinator/seal" "github.com/edgelesssys/marblerun/coordinator/state" "github.com/edgelesssys/marblerun/coordinator/store/stdstore" + "github.com/edgelesssys/marblerun/coordinator/store/wrapper/testutil" "github.com/edgelesssys/marblerun/test" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -39,14 +42,12 @@ func TestCore(t *testing.T) { assert := assert.New(t) c := NewCoreWithMocks() - curState, err := c.data.GetState() - assert.NoError(err) + curState := testutil.GetState(t, c.txHandle) assert.Equal(state.AcceptingManifest, curState) - rootCert, err := c.data.GetCertificate(constants.SKCoordinatorRootCert) - assert.NoError(err) + rootCert := testutil.GetCertificate(t, c.txHandle, constants.SKCoordinatorRootCert) assert.Equal(constants.CoordinatorName, rootCert.Subject.CommonName) - cert, err := c.GetTLSRootCertificate(nil) + cert, err := c.GetTLSRootCertificate(&tls.ClientHelloInfo{}) assert.NoError(err) assert.NotNil(cert) @@ -58,6 +59,7 @@ func TestCore(t *testing.T) { func TestSeal(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() // setup mock zaplogger which can be passed to Core zapLogger, err := zap.NewDevelopment() @@ -73,42 +75,39 @@ func TestSeal(t *testing.T) { require.NoError(err) // Set manifest. This will seal the state. - clientAPI, err := clientapi.New(c.store, c.recovery, c, zapLogger) + clientAPI, err := clientapi.New(c.txHandle, c.recovery, c, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(ctx, []byte(test.ManifestJSON)) require.NoError(err) // Get certificate and signature. - cert, err := c.GetTLSRootCertificate(nil) + cert, err := c.GetTLSRootCertificate(&tls.ClientHelloInfo{}) assert.NoError(err) - signatureRootECDSA, signature, _ := clientAPI.GetManifestSignature() + signatureRootECDSA, signature, _ := clientAPI.GetManifestSignature(ctx) // Get secrets - cSecrets, err := c.data.GetSecretMap() - assert.NoError(err) + cSecrets := testutil.GetSecretMap(t, c.txHandle) // Check sealing with a new core initialized with the sealed state. c2, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) require.NoError(err) - clientAPI, err = clientapi.New(c2.store, c2.recovery, c2, zapLogger) + clientAPI, err = clientapi.New(c2.txHandle, c2.recovery, c2, zapLogger) require.NoError(err) - c2State, err := c2.data.GetState() - assert.NoError(err) + c2State := testutil.GetState(t, c2.txHandle) assert.Equal(state.AcceptingMarbles, c2State) - cert2, err := c2.GetTLSRootCertificate(nil) + cert2, err := c2.GetTLSRootCertificate(&tls.ClientHelloInfo{}) assert.NoError(err) assert.Equal(cert, cert2) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(ctx, []byte(test.ManifestJSON)) assert.Error(err) // Check if the secret specified in the test manifest is unsealed correctly - c2Secrets, err := c2.data.GetSecretMap() - assert.NoError(err) + c2Secrets := testutil.GetSecretMap(t, c2.txHandle) assert.Equal(cSecrets, c2Secrets) - signatureRootECDSA2, signature2, _ := clientAPI.GetManifestSignature() + signatureRootECDSA2, signature2, _ := clientAPI.GetManifestSignature(ctx) assert.Equal(signature, signature2, "manifest signature differs after restart") assert.Equal(signatureRootECDSA, signatureRootECDSA2, "manifest signature root ecdsa differs after restart") } @@ -116,6 +115,7 @@ func TestSeal(t *testing.T) { func TestRecover(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() // setup mock zaplogger which can be passed to Core zapLogger, err := zap.NewDevelopment() @@ -129,20 +129,20 @@ func TestRecover(t *testing.T) { c, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) require.NoError(err) - clientAPI, err := clientapi.New(c.store, c.recovery, c, zapLogger) + clientAPI, err := clientapi.New(c.txHandle, c.recovery, c, zapLogger) require.NoError(err) // new core does not allow recover key := make([]byte, 16) - _, err = clientAPI.Recover(key) + _, err = clientAPI.Recover(ctx, key) assert.Error(err) // Set manifest. This will seal the state. - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(ctx, []byte(test.ManifestJSON)) require.NoError(err) // core does not allow recover after manifest has been set - _, err = clientAPI.Recover(key) + _, err = clientAPI.Recover(ctx, key) assert.Error(err) // Initialize new core and let unseal fail @@ -150,17 +150,15 @@ func TestRecover(t *testing.T) { c2, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) sealer.UnsealError = nil require.NoError(err) - clientAPI, err = clientapi.New(c2.store, c2.recovery, c2, zapLogger) + clientAPI, err = clientapi.New(c2.txHandle, c2.recovery, c2, zapLogger) require.NoError(err) - c2State, err := c2.data.GetState() - assert.NoError(err) + c2State := testutil.GetState(t, c2.txHandle) require.Equal(state.Recovery, c2State) // recover - _, err = clientAPI.Recover(key) - assert.NoError(err) - c2State, err = c2.data.GetState() + _, err = clientAPI.Recover(ctx, key) assert.NoError(err) + c2State = testutil.GetState(t, c2.txHandle) assert.Equal(state.AcceptingMarbles, c2State) } @@ -202,13 +200,11 @@ func TestGenerateSecrets(t *testing.T) { c := NewCoreWithMocks() - rootCert, err := c.data.GetCertificate(constants.SKCoordinatorRootCert) - assert.NoError(err) - rootPrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorRootKey) - assert.NoError(err) + rootCert := testutil.GetCertificate(t, c.txHandle, constants.SKCoordinatorRootCert) + rootPrivK := testutil.GetPrivateKey(t, c.txHandle, constants.SKCoordinatorRootKey) // This should return valid secrets - generatedSecrets, err := c.GenerateSecrets(secretsToGenerate, uuid.Nil, rootCert, rootPrivK) + generatedSecrets, err := c.GenerateSecrets(secretsToGenerate, uuid.Nil, rootCert, rootPrivK, rootPrivK) require.NoError(err) // Check if rawTest1 has 128 Bits/16 Bytes and rawTest2 256 Bits/8 Bytes assert.Len(generatedSecrets["rawTest1"].Public, 16) @@ -228,7 +224,7 @@ func TestGenerateSecrets(t *testing.T) { // Make sure a certificate gets a new serial number if its regenerated firstSerial := generatedSecrets["cert-rsa-test"].Cert.SerialNumber - secondGeneration, err := c.GenerateSecrets(generatedSecrets, uuid.Nil, rootCert, rootPrivK) + secondGeneration, err := c.GenerateSecrets(generatedSecrets, uuid.Nil, rootCert, rootPrivK, rootPrivK) assert.NoError(err) assert.NotEqualValues(*firstSerial, *secondGeneration["cert-rsa-test"].Cert.SerialNumber) @@ -270,31 +266,31 @@ func TestGenerateSecrets(t *testing.T) { assert.NoError(err) // Check if we get an empty secret map as output for an empty map as input - generatedSecrets, err = c.GenerateSecrets(secretsEmptyMap, uuid.Nil, rootCert, rootPrivK) + generatedSecrets, err = c.GenerateSecrets(secretsEmptyMap, uuid.Nil, rootCert, rootPrivK, rootPrivK) require.NoError(err) assert.IsType(map[string]manifest.Secret{}, generatedSecrets) assert.Len(generatedSecrets, 0) // Check if we get an empty secret map as output for nil - generatedSecrets, err = c.GenerateSecrets(nil, uuid.Nil, rootCert, rootPrivK) + generatedSecrets, err = c.GenerateSecrets(nil, uuid.Nil, rootCert, rootPrivK, rootPrivK) require.NoError(err) assert.IsType(map[string]manifest.Secret{}, generatedSecrets) assert.Len(generatedSecrets, 0) // If no size is specified, the function should fail - _, err = c.GenerateSecrets(secretsNoSize, uuid.Nil, rootCert, rootPrivK) + _, err = c.GenerateSecrets(secretsNoSize, uuid.Nil, rootCert, rootPrivK, rootPrivK) assert.Error(err) // Also, it should fail if we try to generate a secret with an unknown type - _, err = c.GenerateSecrets(secretsInvalidType, uuid.Nil, rootCert, rootPrivK) + _, err = c.GenerateSecrets(secretsInvalidType, uuid.Nil, rootCert, rootPrivK, rootPrivK) assert.Error(err) // If Ed25519 key size is specified, we should fail - _, err = c.GenerateSecrets(secretsEd25519WrongKeySize, uuid.Nil, rootCert, rootPrivK) + _, err = c.GenerateSecrets(secretsEd25519WrongKeySize, uuid.Nil, rootCert, rootPrivK, rootPrivK) assert.Error(err) // However, for ECDSA we fail as we can have multiple curves - _, err = c.GenerateSecrets(secretsECDSAWrongKeySize, uuid.Nil, rootCert, rootPrivK) + _, err = c.GenerateSecrets(secretsECDSAWrongKeySize, uuid.Nil, rootCert, rootPrivK, rootPrivK) assert.Error(err) } @@ -313,20 +309,15 @@ func TestUnsetRestart(t *testing.T) { // create a new core, this seals the state with only certificate and keys c1, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) require.NoError(err) - c1State, err := c1.data.GetState() - assert.NoError(err) + c1State := testutil.GetState(t, c1.txHandle) assert.Equal(state.AcceptingManifest, c1State) - cCert, err := c1.data.GetCertificate(constants.SKCoordinatorRootCert) - assert.NoError(err) + cCert := testutil.GetCertificate(t, c1.txHandle, constants.SKCoordinatorRootCert) // create a second core, this should overwrite the previously sealed certificate and keys since no manifest was set c2, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) require.NoError(err) - c2State, err := c2.data.GetState() - assert.NoError(err) + c2State := testutil.GetState(t, c2.txHandle) assert.Equal(state.AcceptingManifest, c2State) - c2Cert, err := c2.data.GetCertificate(constants.SKCoordinatorRootCert) - assert.NoError(err) - + c2Cert := testutil.GetCertificate(t, c2.txHandle, constants.SKCoordinatorRootCert) assert.NotEqual(*cCert, *c2Cert) } diff --git a/coordinator/core/marbleapi.go b/coordinator/core/marbleapi.go index f0aadd8b..676e7277 100644 --- a/coordinator/core/marbleapi.go +++ b/coordinator/core/marbleapi.go @@ -61,7 +61,7 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A c.metrics.marbleAPI.activation.WithLabelValues(req.GetMarbleType(), req.GetUUID()).Inc() defer c.mux.Unlock() - if err := c.RequireState(state.AcceptingMarbles); err != nil { + if err := c.RequireState(ctx, state.AcceptingMarbles); err != nil { return nil, status.Error(codes.FailedPrecondition, "cannot accept marbles in current state") } @@ -77,7 +77,15 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A c.log.Error("Couldn't get marble TLS certificate") return nil, status.Error(codes.Unauthenticated, "couldn't get marble TLS certificate") } - if err := c.verifyManifestRequirement(tlsCert, req.GetQuote(), req.GetMarbleType()); err != nil { + + txdata, rollback, commit, err := wrapper.WrapTransaction(ctx, c.txHandle) + if err != nil { + c.log.Error("Initialize store transaction failed", zap.Error(err)) + return nil, status.Errorf(codes.Internal, "initializing store transaction: %s", err) + } + defer rollback() + + if err := c.verifyManifestRequirement(txdata, tlsCert, req.GetQuote(), req.GetMarbleType()); err != nil { c.log.Error("Marble verification failed", zap.Error(err)) return nil, status.Errorf(codes.PermissionDenied, "marble verification failed: %s", err) } @@ -89,31 +97,36 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A } // Generate marble authentication secrets - authSecrets, err := c.generateMarbleAuthSecrets(req, marbleUUID) + authSecrets, err := c.generateMarbleAuthSecrets(txdata, req, marbleUUID) if err != nil { c.log.Error("Generating marble authentication secrets failed", zap.Error(err)) return nil, status.Errorf(codes.Internal, "generating marble authentication secrets: %s", err) } - marbleRootCert, err := c.data.GetCertificate(constants.SKMarbleRootCert) + marbleRootCert, err := txdata.GetCertificate(constants.SKMarbleRootCert) if err != nil { c.log.Error("Couldn't retrieve marble root certificate", zap.Error(err)) return nil, status.Errorf(codes.Internal, "retrieving marbleRootCert certificate: %s", err) } - intermediatePrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorIntermediateKey) + rootPrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorRootKey) + if err != nil { + c.log.Error("Couldn't retrieve marbleRootCert private key", zap.Error(err)) + return nil, status.Errorf(codes.Internal, "retrieving marble root private key: %s", err) + } + intermediatePrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorIntermediateKey) if err != nil { c.log.Error("Couldn't retrieve marbleRootCert private key", zap.Error(err)) return nil, status.Errorf(codes.Internal, "retrieving marble root private key: %s", err) } - secrets, err := c.data.GetSecretMap() + secrets, err := txdata.GetSecretMap() if err != nil { c.log.Error("Loading secrets from store failed", zap.Error(err)) return nil, status.Errorf(codes.Internal, "retrieving secrets: %s", err) } // Generate unique (= per marble) secrets - privateSecrets, err := c.GenerateSecrets(secrets, marbleUUID, marbleRootCert, intermediatePrivK) + privateSecrets, err := c.GenerateSecrets(secrets, marbleUUID, marbleRootCert, intermediatePrivK, rootPrivK) if err != nil { c.log.Error("Couldn't generate specified secrets for the given manifest", zap.Error(err)) return nil, status.Errorf(codes.Internal, "generating secrets for marble: %s", err) @@ -124,14 +137,14 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A secrets[k] = v } - marble, err := c.data.GetMarble(req.MarbleType) + marble, err := txdata.GetMarble(req.MarbleType) if err != nil { c.log.Error("Loading marble config failed", zap.Error(err)) return nil, status.Errorf(codes.Internal, "retrieving marble config: %s", err) } // add TTLS config to Env - if err := c.setTTLSConfig(marble, authSecrets, secrets); err != nil { + if err := c.setTTLSConfig(txdata, marble, authSecrets, secrets); err != nil { c.log.Error("Couldn't create TTLS config", zap.Error(err)) return nil, status.Errorf(codes.Internal, "creating TTLS config: %s", err) } @@ -147,20 +160,16 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A Parameters: params, } - tx, err := c.store.BeginTransaction() - if err != nil { - c.log.Error("Initialize store transaction failed", zap.Error(err)) - return nil, status.Errorf(codes.Internal, "initializing store transaction: %s", err) - } - defer tx.Rollback() - - if err := (wrapper.New(tx)).IncrementActivations(req.GetMarbleType()); err != nil { - c.log.Error("Could not increment activations", zap.Error(err)) - return nil, status.Errorf(codes.Internal, "incrementing marble activations: %s", err) - } - if err := tx.Commit(); err != nil { - c.log.Error("Committing store transaction failed", zap.Error(err)) - return nil, status.Errorf(codes.Internal, "committing store transaction: %s", err) + // We only need to commit any data to the store if we have a limit on the number of activations + if marble.MaxActivations > 0 { + if err := txdata.IncrementActivations(req.GetMarbleType()); err != nil { + c.log.Error("Could not increment activations", zap.Error(err)) + return nil, status.Errorf(codes.Internal, "incrementing marble activations: %s", err) + } + if err := commit(ctx); err != nil { + c.log.Error("Committing store transaction failed", zap.Error(err)) + return nil, status.Errorf(codes.Internal, "committing store transaction: %s", err) + } } c.metrics.marbleAPI.activationSuccess.WithLabelValues(req.GetMarbleType(), req.GetUUID()).Inc() @@ -174,8 +183,8 @@ func (c *Core) Activate(ctx context.Context, req *rpc.ActivationReq) (res *rpc.A } // verifyManifestRequirement verifies marble attempting to register with respect to manifest. -func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote []byte, marbleType string) error { - marble, err := c.data.GetMarble(marbleType) +func (c *Core) verifyManifestRequirement(txdata storeGetter, tlsCert *x509.Certificate, certQuote []byte, marbleType string) error { + marble, err := txdata.GetMarble(marbleType) if err != nil { if errors.Is(err, store.ErrValueUnset) { return fmt.Errorf("unknown marble type requested") @@ -183,7 +192,7 @@ func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote [] return fmt.Errorf("loading marble data: %w", err) } - pkg, err := c.data.GetPackage(marble.Package) + pkg, err := txdata.GetPackage(marble.Package) if err != nil { if errors.Is(err, store.ErrValueUnset) { return fmt.Errorf("undefined package %q", marble.Package) @@ -191,7 +200,7 @@ func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote [] return fmt.Errorf("loading package data: %w", err) } - infraIter, err := c.data.GetIterator(request.Infrastructure) + infraIter, err := txdata.GetIterator(request.Infrastructure) if err != nil { return fmt.Errorf("getting infrastructure iterator: %w", err) } @@ -208,7 +217,7 @@ func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote [] if err != nil { return err } - infra, err := c.data.GetInfrastructure(infraName) + infra, err := txdata.GetInfrastructure(infraName) if err != nil { return fmt.Errorf("loading infrastructure: %w", err) } @@ -224,7 +233,7 @@ func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote [] } // check activation budget (MaxActivations == 0 means infinite budget) - activations, err := c.data.GetActivations(marbleType) + activations, err := txdata.GetActivations(marbleType) if err != nil { return fmt.Errorf("could not retrieve activations for marble type %q: %w", marbleType, err) } @@ -235,7 +244,7 @@ func (c *Core) verifyManifestRequirement(tlsCert *x509.Certificate, certQuote [] } // generateCertFromCSR signs the CSR from marble attempting to register. -func (c *Core) generateCertFromCSR(csrReq []byte, pubk ecdsa.PublicKey, marbleUUID string) ([]byte, error) { +func (c *Core) generateCertFromCSR(txdata storeGetter, csrReq []byte, pubk ecdsa.PublicKey, marbleUUID string) ([]byte, error) { // parse and verify CSR csr, err := x509.ParseCertificateRequest(csrReq) if err != nil { @@ -250,11 +259,11 @@ func (c *Core) generateCertFromCSR(csrReq []byte, pubk ecdsa.PublicKey, marbleUU return nil, fmt.Errorf("generating certificate serial number: %w", err) } - marbleRootCert, err := c.data.GetCertificate(constants.SKMarbleRootCert) + marbleRootCert, err := txdata.GetCertificate(constants.SKMarbleRootCert) if err != nil { return nil, fmt.Errorf("loading marble root certificate: %w", err) } - intermediatePrivK, err := c.data.GetPrivateKey(constants.SKCoordinatorIntermediateKey) + intermediatePrivK, err := txdata.GetPrivateKey(constants.SKCoordinatorIntermediateKey) if err != nil { return nil, fmt.Errorf("loading marble root certificate private key: %w", err) } @@ -367,7 +376,7 @@ func parseSecrets(data string, tplFunc template.FuncMap, secretsWrapped secretsW return templateResult.String(), nil } -func (c *Core) generateMarbleAuthSecrets(req *rpc.ActivationReq, marbleUUID uuid.UUID) (reservedSecrets, error) { +func (c *Core) generateMarbleAuthSecrets(txdata storeGetter, req *rpc.ActivationReq, marbleUUID uuid.UUID) (reservedSecrets, error) { // generate key-pair for marble privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -383,7 +392,7 @@ func (c *Core) generateMarbleAuthSecrets(req *rpc.ActivationReq, marbleUUID uuid } // Generate Marble certificate - certRaw, err := c.generateCertFromCSR(req.GetCSR(), privk.PublicKey, marbleUUID.String()) + certRaw, err := c.generateCertFromCSR(txdata, req.GetCSR(), privk.PublicKey, marbleUUID.String()) if err != nil { return reservedSecrets{}, err } @@ -393,7 +402,7 @@ func (c *Core) generateMarbleAuthSecrets(req *rpc.ActivationReq, marbleUUID uuid return reservedSecrets{}, err } - marbleRootCert, err := c.data.GetCertificate(constants.SKMarbleRootCert) + marbleRootCert, err := txdata.GetCertificate(constants.SKMarbleRootCert) if err != nil { return reservedSecrets{}, err } @@ -406,7 +415,7 @@ func (c *Core) generateMarbleAuthSecrets(req *rpc.ActivationReq, marbleUUID uuid return authSecrets, nil } -func (c *Core) setTTLSConfig(marble manifest.Marble, specialSecrets reservedSecrets, userSecrets map[string]manifest.Secret) error { +func (c *Core) setTTLSConfig(txdata storeGetter, marble manifest.Marble, specialSecrets reservedSecrets, userSecrets map[string]manifest.Secret) error { if len(marble.TLS) == 0 { return nil } @@ -416,7 +425,7 @@ func (c *Core) setTTLSConfig(marble manifest.Marble, specialSecrets reservedSecr ttlsConf["tls"]["Incoming"] = make(map[string]map[string]interface{}) ttlsConf["tls"]["Outgoing"] = make(map[string]map[string]interface{}) - marbleRootCert, err := c.data.GetCertificate(constants.SKMarbleRootCert) + marbleRootCert, err := txdata.GetCertificate(constants.SKMarbleRootCert) if err != nil { return err } @@ -431,7 +440,7 @@ func (c *Core) setTTLSConfig(marble manifest.Marble, specialSecrets reservedSecr stringClientKey := string(pem.EncodeToMemory(&pemClientKey)) for _, tagName := range marble.TLS { - tag, err := c.data.GetTLS(tagName) + tag, err := txdata.GetTLS(tagName) if err != nil { return err } @@ -480,3 +489,17 @@ func (c *Core) setTTLSConfig(marble manifest.Marble, specialSecrets reservedSecr return nil } + +type storeGetter interface { + GetActivations(name string) (uint, error) + GetCertificate(name string) (*x509.Certificate, error) + GetInfrastructure(name string) (quote.InfrastructureProperties, error) + GetIterator(prefix string) (wrapper.Iterator, error) + GetManifest() (manifest.Manifest, error) + GetMarble(marble string) (manifest.Marble, error) + GetPackage(name string) (quote.PackageProperties, error) + GetPrivateKey(name string) (*ecdsa.PrivateKey, error) + GetSecretMap() (map[string]manifest.Secret, error) + GetSecret(name string) (manifest.Secret, error) + GetTLS(name string) (manifest.TLStag, error) +} diff --git a/coordinator/core/marbleapi_test.go b/coordinator/core/marbleapi_test.go index f4fb3ada..a01c9777 100644 --- a/coordinator/core/marbleapi_test.go +++ b/coordinator/core/marbleapi_test.go @@ -29,6 +29,7 @@ import ( "github.com/edgelesssys/marblerun/coordinator/seal" "github.com/edgelesssys/marblerun/coordinator/state" "github.com/edgelesssys/marblerun/coordinator/store/stdstore" + "github.com/edgelesssys/marblerun/coordinator/store/wrapper/testutil" "github.com/edgelesssys/marblerun/test" "github.com/edgelesssys/marblerun/util" "github.com/google/uuid" @@ -71,19 +72,19 @@ func TestActivate(t *testing.T) { } // try to activate first backend marble prematurely before manifest is set - spawner.newMarble("backendFirst", "Azure", false) + spawner.newMarble(t, "backendFirst", "Azure", false) // set manifest - clientAPI, err := clientapi.New(coreServer.store, coreServer.recovery, coreServer, zapLogger) + clientAPI, err := clientapi.New(coreServer.txHandle, coreServer.recovery, coreServer, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(context.Background(), []byte(test.ManifestJSON)) require.NoError(err) // activate first backend - spawner.newMarble("backendFirst", "Azure", true) + spawner.newMarble(t, "backendFirst", "Azure", true) // try to activate another first backend - spawner.newMarble("backendFirst", "Azure", false) + spawner.newMarble(t, "backendFirst", "Azure", false) // activate 10 other backend pickInfra := func(i int) string { @@ -93,12 +94,12 @@ func TestActivate(t *testing.T) { return "Alibaba" } for i := 0; i < 10; i++ { - spawner.newMarbleAsync("backendOther", pickInfra(i), true) + spawner.newMarbleAsync(t, "backendOther", pickInfra(i), true) } // activate 10 frontend for i := 0; i < 10; i++ { - spawner.newMarbleAsync("frontend", pickInfra(i), true) + spawner.newMarbleAsync(t, "frontend", pickInfra(i), true) } spawner.wg.Wait() @@ -123,7 +124,7 @@ type marbleSpawner struct { backendOtherUniqueCert x509.Certificate } -func (ms *marbleSpawner) newMarble(marbleType string, infraName string, shouldSucceed bool) string { +func (ms *marbleSpawner) newMarble(t *testing.T, marbleType string, infraName string, shouldSucceed bool) string { cert, csr, _ := util.MustGenerateTestMarbleCredentials() // create mock quote using values from the manifest @@ -144,7 +145,7 @@ func (ms *marbleSpawner) newMarble(marbleType string, infraName string, shouldSu }, } - ctx := peer.NewContext(context.TODO(), &peer.Peer{ + ctx := peer.NewContext(context.Background(), &peer.Peer{ AuthInfo: tlsInfo, }) @@ -204,12 +205,9 @@ func (ms *marbleSpawner) newMarble(marbleType string, infraName string, shouldSu ms.assert.Equal(cert.DNSNames, newLeafCert.DNSNames) ms.assert.Equal(cert.IPAddresses, newLeafCert.IPAddresses) - rootCert, err := ms.coreServer.data.GetCertificate(constants.SKCoordinatorRootCert) - ms.assert.NoError(err) - intermediateCert, err := ms.coreServer.data.GetCertificate(constants.SKCoordinatorIntermediateCert) - ms.assert.NoError(err) - marbleRootCert, err := ms.coreServer.data.GetCertificate(constants.SKMarbleRootCert) - ms.assert.NoError(err) + rootCert := testutil.GetCertificate(t, ms.coreServer.txHandle, constants.SKCoordinatorRootCert) + intermediateCert := testutil.GetCertificate(t, ms.coreServer.txHandle, constants.SKCoordinatorIntermediateCert) + marbleRootCert := testutil.GetCertificate(t, ms.coreServer.txHandle, constants.SKMarbleRootCert) // Check Signature for both, intermediate certificate and leaf certificate ms.assert.NoError(rootCert.CheckSignature(intermediateCert.SignatureAlgorithm, intermediateCert.RawTBSCertificate, intermediateCert.Signature)) ms.assert.NoError(newMarbleRootCert.CheckSignature(newMarbleRootCert.SignatureAlgorithm, newMarbleRootCert.RawTBSCertificate, newMarbleRootCert.Signature)) @@ -294,10 +292,10 @@ func (ms *marbleSpawner) newMarble(marbleType string, infraName string, shouldSu return uuidStr } -func (ms *marbleSpawner) newMarbleAsync(marbleType string, infraName string, shouldSucceed bool) { +func (ms *marbleSpawner) newMarbleAsync(t *testing.T, marbleType string, infraName string, shouldSucceed bool) { ms.wg.Add(1) go func() { - ms.newMarble(marbleType, infraName, shouldSucceed) + ms.newMarble(t, marbleType, infraName, shouldSucceed) ms.wg.Done() }() } @@ -452,6 +450,7 @@ func TestParseSecrets(t *testing.T) { func TestSecurityLevelUpdate(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() // parse manifest var manifest manifest.Manifest @@ -480,40 +479,37 @@ func TestSecurityLevelUpdate(t *testing.T) { coreServer: coreServer, } // set manifest - clientAPI, err := clientapi.New(coreServer.store, coreServer.recovery, coreServer, zapLogger) + clientAPI, err := clientapi.New(coreServer.txHandle, coreServer.recovery, coreServer, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err = clientAPI.SetManifest(ctx, []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) - admin, err := coreServer.data.GetUser("admin") - assert.NoError(err) + admin := testutil.GetUser(t, coreServer.txHandle, "admin") // try to activate another first backend, should succeed as SecurityLevel matches the definition in the manifest - spawner.newMarble("frontend", "Azure", true) + spawner.newMarble(t, "frontend", "Azure", true) // update manifest - err = clientAPI.UpdateManifest([]byte(test.UpdateManifest), admin) + err = clientAPI.UpdateManifest(ctx, []byte(test.UpdateManifest), admin) require.NoError(err) // try to activate another first backend, should fail as required SecurityLevel is now higher after manifest update - spawner.newMarble("frontend", "Azure", false) + spawner.newMarble(t, "frontend", "Azure", false) // Use a new core and test if updated manifest persisted after restart coreServer2, err := NewCore([]string{"localhost"}, validator, issuer, stdstore.New(sealer), recovery, zapLogger, nil, nil) require.NoError(err) - coreServer2State, err := coreServer2.data.GetState() - assert.NoError(err) - coreServer2UpdatedPkg, err := coreServer2.data.GetPackage("frontend") - assert.NoError(err) + coreServer2State := testutil.GetState(t, coreServer2.txHandle) + coreServer2UpdatedPkg := testutil.GetPackage(t, coreServer2.txHandle, "frontend") assert.Equal(state.AcceptingMarbles, coreServer2State) assert.EqualValues(5, *coreServer2UpdatedPkg.SecurityVersion) // This should still fail after a restart, as the update manifest should have been reloaded from the sealed state correctly spawner.coreServer = coreServer2 - spawner.newMarble("frontend", "Azure", false) + spawner.newMarble(t, "frontend", "Azure", false) } -func (ms *marbleSpawner) shortMarbleActivation(marbleType string, infraName string) { +func (ms *marbleSpawner) shortMarbleActivation(t *testing.T, marbleType string, infraName string) { cert, csr, _ := util.MustGenerateTestMarbleCredentials() // create mock quote using values from the manifest @@ -534,7 +530,7 @@ func (ms *marbleSpawner) shortMarbleActivation(marbleType string, infraName stri }, } - ctx := peer.NewContext(context.TODO(), &peer.Peer{ + ctx := peer.NewContext(context.Background(), &peer.Peer{ AuthInfo: tlsInfo, }) @@ -551,8 +547,7 @@ func (ms *marbleSpawner) shortMarbleActivation(marbleType string, infraName stri // Validate response params := resp.GetParameters() // Get the marble from the manifest set on the coreServer since this one sets default values for empty values - coreServerManifest, err := ms.coreServer.data.GetManifest() - ms.assert.NoError(err) + coreServerManifest := testutil.GetManifest(t, ms.coreServer.txHandle) marble = coreServerManifest.Marbles[marbleType] // Validate Files for k, v := range marble.Parameters.Files { @@ -593,10 +588,10 @@ func TestActivateWithMissingParameters(t *testing.T) { coreServer: coreServer, } // set manifest - clientAPI, err := clientapi.New(coreServer.store, coreServer.recovery, coreServer, zapLogger) + clientAPI, err := clientapi.New(coreServer.txHandle, coreServer.recovery, coreServer, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSONMissingParameters)) + _, err = clientAPI.SetManifest(context.Background(), []byte(test.ManifestJSONMissingParameters)) require.NoError(err) - spawner.shortMarbleActivation("frontend", "Azure") + spawner.shortMarbleActivation(t, "frontend", "Azure") } diff --git a/coordinator/core/metrics.go b/coordinator/core/metrics.go index ddc2e2f4..0b111ea9 100644 --- a/coordinator/core/metrics.go +++ b/coordinator/core/metrics.go @@ -7,6 +7,8 @@ package core import ( + "context" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -33,7 +35,7 @@ func newCoreMetrics(factory *promauto.Factory, core *Core, namespace string) *co Help: "State of the Coordinator.", }, func() float64 { - state, err := core.data.GetState() + state, _, err := core.GetState(context.Background()) if err != nil { return float64(0) } diff --git a/coordinator/core/metrics_test.go b/coordinator/core/metrics_test.go index 07e0fd6b..cf6ff5fa 100644 --- a/coordinator/core/metrics_test.go +++ b/coordinator/core/metrics_test.go @@ -7,6 +7,7 @@ package core import ( + "context" "encoding/json" "testing" @@ -17,6 +18,7 @@ import ( "github.com/edgelesssys/marblerun/coordinator/seal" "github.com/edgelesssys/marblerun/coordinator/state" "github.com/edgelesssys/marblerun/coordinator/store/stdstore" + "github.com/edgelesssys/marblerun/coordinator/store/wrapper/testutil" "github.com/edgelesssys/marblerun/test" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -29,6 +31,7 @@ import ( func TestStoreWrapperMetrics(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() zapLogger, err := zap.NewDevelopment() require.NoError(err) @@ -47,9 +50,9 @@ func TestStoreWrapperMetrics(t *testing.T) { assert.Equal(1, promtest.CollectAndCount(c.metrics.coordinatorState)) assert.Equal(float64(state.AcceptingManifest), promtest.ToFloat64(c.metrics.coordinatorState)) - clientAPI, err := clientapi.New(c.store, c.recovery, c, zapLogger) + clientAPI, err := clientapi.New(c.txHandle, c.recovery, c, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(ctx, []byte(test.ManifestJSON)) require.NoError(err) assert.Equal(1, promtest.CollectAndCount(c.metrics.coordinatorState)) assert.Equal(float64(state.AcceptingMarbles), promtest.ToFloat64(c.metrics.coordinatorState)) @@ -66,14 +69,13 @@ func TestStoreWrapperMetrics(t *testing.T) { assert.Equal(1, promtest.CollectAndCount(c.metrics.coordinatorState)) assert.Equal(float64(state.Recovery), promtest.ToFloat64(c.metrics.coordinatorState)) - clientAPI, err = clientapi.New(c.store, c.recovery, c, zapLogger) + clientAPI, err = clientapi.New(c.txHandle, c.recovery, c, zapLogger) require.NoError(err) key := make([]byte, 16) - _, err = clientAPI.Recover(key) - require.NoError(err) - state, err := c.data.GetState() + _, err = clientAPI.Recover(ctx, key) require.NoError(err) + state := testutil.GetState(t, c.txHandle) assert.Equal(1, promtest.CollectAndCount(c.metrics.coordinatorState)) assert.Equal(float64(state), promtest.ToFloat64(c.metrics.coordinatorState)) } @@ -116,27 +118,27 @@ func TestMarbleAPIMetrics(t *testing.T) { } // try to activate first backend marble prematurely before manifest is set - uuid := spawner.newMarble("backendFirst", "Azure", false) + uuid := spawner.newMarble(t, "backendFirst", "Azure", false) promtest.CollectAndCount(metrics.activation) promtest.CollectAndCount(metrics.activationSuccess) assert.Equal(float64(1), promtest.ToFloat64(metrics.activation.WithLabelValues("backendFirst", uuid))) assert.Equal(float64(0), promtest.ToFloat64(metrics.activationSuccess.WithLabelValues("backendFirst", uuid))) // set manifest - clientAPI, err := clientapi.New(c.store, c.recovery, c, zapLogger) + clientAPI, err := clientapi.New(c.txHandle, c.recovery, c, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(context.Background(), []byte(test.ManifestJSON)) require.NoError(err) // activate first backend - uuid = spawner.newMarble("backendFirst", "Azure", true) + uuid = spawner.newMarble(t, "backendFirst", "Azure", true) promtest.CollectAndCount(metrics.activation) promtest.CollectAndCount(metrics.activationSuccess) assert.Equal(float64(1), promtest.ToFloat64(metrics.activation.WithLabelValues("backendFirst", uuid))) assert.Equal(float64(1), promtest.ToFloat64(metrics.activationSuccess.WithLabelValues("backendFirst", uuid))) // try to activate another first backend - uuid = spawner.newMarble("backendFirst", "Azure", false) + uuid = spawner.newMarble(t, "backendFirst", "Azure", false) promtest.CollectAndCount(metrics.activation) promtest.CollectAndCount(metrics.activationSuccess) assert.Equal(float64(1), promtest.ToFloat64(metrics.activation.WithLabelValues("backendFirst", uuid))) diff --git a/coordinator/core/openssl_test.go b/coordinator/core/openssl_test.go index f7cf3e51..df4f1d5f 100644 --- a/coordinator/core/openssl_test.go +++ b/coordinator/core/openssl_test.go @@ -52,16 +52,16 @@ func TestOpenSSLVerify(t *testing.T) { // create core validator := quote.NewMockValidator() issuer := quote.NewMockIssuer() - store := stdstore.New(&seal.MockSealer{}) + stor := stdstore.New(&seal.MockSealer{}) recovery := recovery.NewSinglePartyRecovery() - coreServer, err := NewCore([]string{"localhost"}, validator, issuer, store, recovery, zapLogger, nil, nil) + coreServer, err := NewCore([]string{"localhost"}, validator, issuer, stor, recovery, zapLogger, nil, nil) require.NoError(err) require.NotNil(coreServer) // set manifest - clientAPI, err := clientapi.New(coreServer.store, coreServer.recovery, coreServer, zapLogger) + clientAPI, err := clientapi.New(coreServer.txHandle, coreServer.recovery, coreServer, zapLogger) require.NoError(err) - _, err = clientAPI.SetManifest([]byte(test.ManifestJSON)) + _, err = clientAPI.SetManifest(context.Background(), []byte(test.ManifestJSON)) require.NoError(err) // create marble @@ -86,7 +86,7 @@ func TestOpenSSLVerify(t *testing.T) { }, } - ctx := peer.NewContext(context.TODO(), &peer.Peer{ + ctx := peer.NewContext(context.Background(), &peer.Peer{ AuthInfo: tlsInfo, }) diff --git a/coordinator/server/client_api.go b/coordinator/server/client_api.go index d500b01f..7271e43f 100644 --- a/coordinator/server/client_api.go +++ b/coordinator/server/client_api.go @@ -86,7 +86,7 @@ type clientAPIServer struct { // 200: StatusResponse // 500: ErrorResponse func (s *clientAPIServer) statusGet(w http.ResponseWriter, r *http.Request) { - statusCode, status, err := s.api.GetStatus() + statusCode, status, err := s.api.GetStatus(r.Context()) if err != nil { writeJSONError(w, err.Error(), http.StatusInternalServerError) return @@ -130,7 +130,7 @@ func (s *clientAPIServer) statusGet(w http.ResponseWriter, r *http.Request) { // 200: ManifestResponse // 500: ErrorResponse func (s *clientAPIServer) manifestGet(w http.ResponseWriter, r *http.Request) { - signatureRootECDSA, signature, manifest := s.api.GetManifestSignature() + signatureRootECDSA, signature, manifest := s.api.GetManifestSignature(r.Context()) writeJSON(w, ManifestSignatureResp{ ManifestSignatureRootECDSA: signatureRootECDSA, ManifestSignature: hex.EncodeToString(signature), @@ -161,7 +161,7 @@ func (s *clientAPIServer) manifestPost(w http.ResponseWriter, r *http.Request) { writeJSONError(w, err.Error(), http.StatusInternalServerError) return } - recoverySecretMap, err := s.api.SetManifest(manifest) + recoverySecretMap, err := s.api.SetManifest(r.Context(), manifest) if err != nil { writeJSONError(w, err.Error(), http.StatusBadRequest) return @@ -217,7 +217,7 @@ func (s *clientAPIServer) manifestPost(w http.ResponseWriter, r *http.Request) { // 200: CertQuoteResponse // 500: ErrorResponse func (s *clientAPIServer) quoteGet(w http.ResponseWriter, r *http.Request) { - cert, quote, err := s.api.GetCertQuote() + cert, quote, err := s.api.GetCertQuote(r.Context()) if err != nil { writeJSONError(w, err.Error(), http.StatusInternalServerError) return @@ -250,7 +250,7 @@ func (s *clientAPIServer) recoverPost(w http.ResponseWriter, r *http.Request) { } // Perform recover and receive amount of remaining secrets (for multi-party recovery) - remaining, err := s.api.Recover(key) + remaining, err := s.api.Recover(r.Context(), key) if err != nil { writeJSONError(w, err.Error(), http.StatusInternalServerError) return @@ -277,7 +277,7 @@ func (s *clientAPIServer) recoverPost(w http.ResponseWriter, r *http.Request) { // 200: UpdateLogResponse // 500: ErrorResponse func (s *clientAPIServer) updateGet(w http.ResponseWriter, r *http.Request) { - updateLog, err := s.api.GetUpdateLog() + updateLog, err := s.api.GetUpdateLog(r.Context()) if err != nil { writeJSONError(w, err.Error(), http.StatusInternalServerError) } @@ -311,7 +311,7 @@ func (s *clientAPIServer) updatePost(w http.ResponseWriter, r *http.Request) { writeJSONError(w, err.Error(), http.StatusInternalServerError) return } - err = s.api.UpdateManifest(updateManifest, user) + err = s.api.UpdateManifest(r.Context(), updateManifest, user) if err != nil { writeJSONError(w, err.Error(), http.StatusBadRequest) return @@ -362,7 +362,7 @@ func (s *clientAPIServer) secretsGet(w http.ResponseWriter, r *http.Request) { return } } - response, err := s.api.GetSecrets(requestedSecrets, user) + response, err := s.api.GetSecrets(r.Context(), requestedSecrets, user) if err != nil { writeJSONError(w, err.Error(), http.StatusBadRequest) return @@ -402,7 +402,7 @@ func (s *clientAPIServer) secretsPost(w http.ResponseWriter, r *http.Request) { writeJSONError(w, err.Error(), http.StatusInternalServerError) return } - if err := s.api.WriteSecrets(secretManifest, user); err != nil { + if err := s.api.WriteSecrets(r.Context(), secretManifest, user); err != nil { writeJSONError(w, err.Error(), http.StatusBadRequest) return } @@ -433,7 +433,7 @@ func (s *clientAPIServer) verifyUser(w http.ResponseWriter, r *http.Request) *us writeJSONError(w, "no client certificate provided", http.StatusUnauthorized) return nil } - verifiedUser, err := s.api.VerifyUser(r.TLS.PeerCertificates) + verifiedUser, err := s.api.VerifyUser(r.Context(), r.TLS.PeerCertificates) if err != nil { writeJSONError(w, "unauthorized user", http.StatusUnauthorized) return nil diff --git a/coordinator/server/server.go b/coordinator/server/server.go index d7c129cb..eb4d37e0 100644 --- a/coordinator/server/server.go +++ b/coordinator/server/server.go @@ -8,6 +8,7 @@ package server import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -32,16 +33,16 @@ import ( ) type clientAPI interface { - SetManifest(rawManifest []byte) (recoverySecretMap map[string][]byte, err error) - GetCertQuote() (cert string, certQuote []byte, err error) - GetManifestSignature() (manifestSignatureRootECDSA, manifestSignature, manifest []byte) - GetSecrets(requestedSecrets []string, requestUser *user.User) (map[string]manifest.Secret, error) - GetStatus() (statusCode state.State, status string, err error) - GetUpdateLog() (updateLog string, err error) - Recover(encryptionKey []byte) (int, error) - VerifyUser(clientCerts []*x509.Certificate) (*user.User, error) - UpdateManifest(rawUpdateManifest []byte, updater *user.User) error - WriteSecrets(rawSecretManifest []byte, updater *user.User) error + SetManifest(ctx context.Context, rawManifest []byte) (recoverySecretMap map[string][]byte, err error) + GetCertQuote(context.Context) (cert string, certQuote []byte, err error) + GetManifestSignature(context.Context) (manifestSignatureRootECDSA, manifestSignature, manifest []byte) + GetSecrets(ctx context.Context, requestedSecrets []string, requestUser *user.User) (map[string]manifest.Secret, error) + GetStatus(context.Context) (statusCode state.State, status string, err error) + GetUpdateLog(context.Context) (updateLog string, err error) + Recover(ctx context.Context, encryptionKey []byte) (int, error) + VerifyUser(ctx context.Context, clientCerts []*x509.Certificate) (*user.User, error) + UpdateManifest(ctx context.Context, rawUpdateManifest []byte, updater *user.User) error + WriteSecrets(ctx context.Context, rawSecretManifest []byte, updater *user.User) error } // RunMarbleServer starts a gRPC with the given Coordinator core. diff --git a/coordinator/server/server_test.go b/coordinator/server/server_test.go index c3680fc0..0fb71d68 100644 --- a/coordinator/server/server_test.go +++ b/coordinator/server/server_test.go @@ -7,6 +7,7 @@ package server import ( + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -56,7 +57,7 @@ func TestManifest(t *testing.T) { mux.ServeHTTP(resp, req) require.Equal(http.StatusOK, resp.Code) - sigRootECDSA, sig, manifest := c.GetManifestSignature() + sigRootECDSA, sig, manifest := c.GetManifestSignature(context.Background()) assert.JSONEq(`{"status":"success","data":{"ManifestSignatureRootECDSA":"`+base64.StdEncoding.EncodeToString(sigRootECDSA)+`","ManifestSignature":"`+hex.EncodeToString(sig)+`","Manifest":"`+base64.StdEncoding.EncodeToString(manifest)+`"}}`, resp.Body.String()) // try setting manifest again, should fail @@ -103,7 +104,7 @@ func TestGetUpdateLog(t *testing.T) { // Setup mock core and set a manifest c := newTestClientAPI(t) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(context.Background(), []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) mux := CreateServeMux(c, nil) @@ -120,7 +121,7 @@ func TestUpdate(t *testing.T) { // Setup mock core and set a manifest c := newTestClientAPI(t) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(context.Background(), []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) mux := CreateServeMux(c, nil) @@ -137,7 +138,7 @@ func TestReadSecret(t *testing.T) { // Setup mock core and set a manifest c := newTestClientAPI(t) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(context.Background(), []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) mux := CreateServeMux(c, nil) @@ -154,7 +155,7 @@ func TestSetSecret(t *testing.T) { // Setup mock core and set a manifest c := newTestClientAPI(t) - _, err := c.SetManifest([]byte(test.ManifestJSONWithRecoveryKey)) + _, err := c.SetManifest(context.Background(), []byte(test.ManifestJSONWithRecoveryKey)) require.NoError(err) mux := CreateServeMux(c, nil) diff --git a/coordinator/store/stdstore/stdstore.go b/coordinator/store/stdstore/stdstore.go index b7bc4e74..b73c18a3 100644 --- a/coordinator/store/stdstore/stdstore.go +++ b/coordinator/store/stdstore/stdstore.go @@ -7,6 +7,7 @@ package stdstore import ( + "context" "encoding/json" "fmt" "strings" @@ -57,7 +58,7 @@ func (s *StdStore) Put(request string, requestData []byte) error { if err := tx.Put(request, requestData); err != nil { return err } - return tx.Commit() + return tx.Commit(context.Background()) } // Delete removes a value from StdStore. @@ -70,7 +71,7 @@ func (s *StdStore) Delete(request string) error { if err := tx.Delete(request); err != nil { return err } - return tx.Commit() + return tx.Commit(context.Background()) } // Iterator returns an iterator for keys saved in StdStore with a given prefix. @@ -87,7 +88,7 @@ func (s *StdStore) Iterator(prefix string) (store.Iterator, error) { } // BeginTransaction starts a new transaction. -func (s *StdStore) BeginTransaction() (store.Transaction, error) { +func (s *StdStore) BeginTransaction(_ context.Context) (store.Transaction, error) { return s.beginTransaction() } @@ -199,7 +200,7 @@ func (t *StdTransaction) Iterator(prefix string) (store.Iterator, error) { } // Commit ends a transaction and persists the changes. -func (t *StdTransaction) Commit() error { +func (t *StdTransaction) Commit(_ context.Context) error { if err := t.store.commit(t.data); err != nil { return err } diff --git a/coordinator/store/stdstore/stdstore_test.go b/coordinator/store/stdstore/stdstore_test.go index fe9bb6fa..f3111f7a 100644 --- a/coordinator/store/stdstore/stdstore_test.go +++ b/coordinator/store/stdstore/stdstore_test.go @@ -7,6 +7,7 @@ package stdstore import ( + "context" "testing" "github.com/edgelesssys/marblerun/coordinator/seal" @@ -16,6 +17,7 @@ import ( func TestStdStore(t *testing.T) { assert := assert.New(t) + ctx := context.Background() str := New(&seal.MockSealer{}) _, err := str.LoadState() @@ -29,11 +31,11 @@ func TestStdStore(t *testing.T) { assert.Error(err) // test Put method - tx, err := str.BeginTransaction() + tx, err := str.BeginTransaction(ctx) assert.NoError(err) assert.NoError(tx.Put("test:input", testData1)) assert.NoError(tx.Put("another:input", testData2)) - assert.NoError(tx.Commit()) + assert.NoError(tx.Commit(ctx)) // make sure values have been set val, err := str.Get("test:input") @@ -123,6 +125,7 @@ func TestStdStoreSealing(t *testing.T) { func TestStdStoreRollback(t *testing.T) { assert := assert.New(t) + ctx := context.Background() store := New(&seal.MockSealer{}) _, err := store.LoadState() @@ -133,13 +136,13 @@ func TestStdStoreRollback(t *testing.T) { testData3 := []byte("and even more data") // save data to store and seal - tx, err := store.BeginTransaction() + tx, err := store.BeginTransaction(ctx) assert.NoError(err) assert.NoError(tx.Put("test:input", testData1)) - assert.NoError(tx.Commit()) + assert.NoError(tx.Commit(ctx)) // save more data to store - tx, err = store.BeginTransaction() + tx, err = store.BeginTransaction(ctx) assert.NoError(err) assert.NoError(tx.Put("another:input", testData2)) @@ -152,10 +155,10 @@ func TestStdStoreRollback(t *testing.T) { assert.Error(err) // save something new - tx, err = store.BeginTransaction() + tx, err = store.BeginTransaction(ctx) assert.NoError(err) assert.NoError(tx.Put("last:input", testData3)) - assert.NoError(tx.Commit()) + assert.NoError(tx.Commit(ctx)) // verify values val, err = store.Get("test:input") diff --git a/coordinator/store/store.go b/coordinator/store/store.go index bbe1ae7a..a4d2435d 100644 --- a/coordinator/store/store.go +++ b/coordinator/store/store.go @@ -7,19 +7,14 @@ package store import ( + "context" "errors" ) // Store is the interface for persistence. type Store interface { // BeginTransaction starts a new transaction. - BeginTransaction() (Transaction, error) - // Get returns a value from store by key. - Get(string) ([]byte, error) - // Put saves a value to store by key. - Put(string, []byte) error - // Iterator returns an Iterator for a given prefix. - Iterator(string) (Iterator, error) + BeginTransaction(context.Context) (Transaction, error) // SetEncryptionKey sets the encryption key for the store. SetEncryptionKey([]byte) error // SetRecoveryData sets recovery data for the store. @@ -39,7 +34,7 @@ type Transaction interface { // Iterator returns an Iterator for a given prefix Iterator(string) (Iterator, error) // Commit ends a transaction and persists the changes - Commit() error + Commit(context.Context) error // Rollback aborts a transaction. Noop if already committed. Rollback() } diff --git a/coordinator/store/wrapper/testutil/testutil.go b/coordinator/store/wrapper/testutil/testutil.go new file mode 100644 index 00000000..6f07428e --- /dev/null +++ b/coordinator/store/wrapper/testutil/testutil.go @@ -0,0 +1,142 @@ +// Copyright (c) Edgeless Systems GmbH. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// Package testutil provides utility functions to access store values in unit tests. +package testutil + +import ( + "context" + "crypto/ecdsa" + "crypto/x509" + "testing" + + "github.com/edgelesssys/marblerun/coordinator/manifest" + "github.com/edgelesssys/marblerun/coordinator/quote" + "github.com/edgelesssys/marblerun/coordinator/state" + "github.com/edgelesssys/marblerun/coordinator/store" + "github.com/edgelesssys/marblerun/coordinator/store/wrapper" + "github.com/edgelesssys/marblerun/coordinator/user" + "github.com/stretchr/testify/require" +) + +type transactionHandle interface { + BeginTransaction(context.Context) (store.Transaction, error) +} + +// GetActivations returns the number of activations for a given Marble. +func GetActivations(t *testing.T, txHandle transactionHandle, name string) uint { + return get(t, txHandle, func(tx wrapper.Wrapper) (uint, error) { + return tx.GetActivations(name) + }) +} + +// GetCertificate returns the certificate with the given name. +func GetCertificate(t *testing.T, txHandle transactionHandle, name string) *x509.Certificate { + return get(t, txHandle, func(tx wrapper.Wrapper) (*x509.Certificate, error) { + return tx.GetCertificate(name) + }) +} + +// GetInfrastructure returns infrastructure information. +func GetInfrastructure(t *testing.T, txHandle transactionHandle, name string) quote.InfrastructureProperties { + return get(t, txHandle, func(tx wrapper.Wrapper) (quote.InfrastructureProperties, error) { + return tx.GetInfrastructure(name) + }) +} + +// GetMarble returns the marble with the given name. +func GetMarble(t *testing.T, txHandle transactionHandle, name string) manifest.Marble { + return get(t, txHandle, func(tx wrapper.Wrapper) (manifest.Marble, error) { + return tx.GetMarble(name) + }) +} + +// GetPackage returns the package with the given name. +func GetPackage(t *testing.T, txHandle transactionHandle, name string) quote.PackageProperties { + return get(t, txHandle, func(tx wrapper.Wrapper) (quote.PackageProperties, error) { + return tx.GetPackage(name) + }) +} + +// GetPrivateKey returns the private key with the given name. +func GetPrivateKey(t *testing.T, txHandle transactionHandle, name string) *ecdsa.PrivateKey { + return get(t, txHandle, func(tx wrapper.Wrapper) (*ecdsa.PrivateKey, error) { + return tx.GetPrivateKey(name) + }) +} + +// GetManifest returns the manifest. +func GetManifest(t *testing.T, txHandle transactionHandle) manifest.Manifest { + return get(t, txHandle, func(tx wrapper.Wrapper) (manifest.Manifest, error) { + return tx.GetManifest() + }) +} + +// GetRawManifest returns the raw manifest. +func GetRawManifest(t *testing.T, txHandle transactionHandle) []byte { + return get(t, txHandle, func(tx wrapper.Wrapper) ([]byte, error) { + return tx.GetRawManifest() + }) +} + +// GetManifestSignature returns the manifest signature. +func GetManifestSignature(t *testing.T, txHandle transactionHandle) []byte { + return get(t, txHandle, func(tx wrapper.Wrapper) ([]byte, error) { + return tx.GetManifestSignature() + }) +} + +// GetSecret returns the secret with the given name. +func GetSecret(t *testing.T, txHandle transactionHandle, name string) manifest.Secret { + return get(t, txHandle, func(tx wrapper.Wrapper) (manifest.Secret, error) { + return tx.GetSecret(name) + }) +} + +// GetSecretMap returns a map of all secrets in the store. +func GetSecretMap(t *testing.T, txHandle transactionHandle) map[string]manifest.Secret { + return get(t, txHandle, func(tx wrapper.Wrapper) (map[string]manifest.Secret, error) { + return tx.GetSecretMap() + }) +} + +// GetState returns the current state of the store. +func GetState(t *testing.T, txHandle transactionHandle) state.State { + return get(t, txHandle, func(tx wrapper.Wrapper) (state.State, error) { + return tx.GetState() + }) +} + +// GetTLS returns the TLS config with the given name. +func GetTLS(t *testing.T, txHandle transactionHandle, name string) manifest.TLStag { + return get(t, txHandle, func(tx wrapper.Wrapper) (manifest.TLStag, error) { + return tx.GetTLS(name) + }) +} + +// GetUpdateLog returns the update log. +func GetUpdateLog(t *testing.T, txHandle transactionHandle) string { + return get(t, txHandle, func(tx wrapper.Wrapper) (string, error) { + return tx.GetUpdateLog() + }) +} + +// GetUser returns the user with the given name. +func GetUser(t *testing.T, txHandle transactionHandle, name string) *user.User { + return get(t, txHandle, func(tx wrapper.Wrapper) (*user.User, error) { + return tx.GetUser(name) + }) +} + +func get[T any](t *testing.T, txHandle transactionHandle, getter func(wrapper.Wrapper) (T, error)) T { + t.Helper() + tx, rollback, _, err := wrapper.WrapTransaction(context.Background(), txHandle) + require.NoError(t, err) + defer rollback() + val, err := getter(tx) + require.NoError(t, err) + return val +} diff --git a/coordinator/store/wrapper/wrapper.go b/coordinator/store/wrapper/wrapper.go index 0a6401e0..ac2c8b21 100644 --- a/coordinator/store/wrapper/wrapper.go +++ b/coordinator/store/wrapper/wrapper.go @@ -7,6 +7,7 @@ package wrapper import ( + "context" "crypto/ecdsa" "crypto/x509" "encoding/json" @@ -22,6 +23,17 @@ import ( "github.com/edgelesssys/marblerun/coordinator/user" ) +// WrapTransaction initializes a transaction using the given handle, +// and returns a wrapper for the transaction, as well as rollback and commit functions. +func WrapTransaction(ctx context.Context, txHandle transactionHandle, +) (wrapper Wrapper, rollback func(), commit func(context.Context) error, err error) { + tx, err := txHandle.BeginTransaction(ctx) + if err != nil { + return Wrapper{}, nil, nil, err + } + return New(tx), tx.Rollback, tx.Commit, nil +} + // Wrapper wraps store functions to provide a more convenient interface, // and provides a type-safe way to access the store. type Wrapper struct { @@ -318,3 +330,7 @@ type dataStore interface { // Iterator returns an Iterator for a given prefix Iterator(string) (store.Iterator, error) } + +type transactionHandle interface { + BeginTransaction(context.Context) (store.Transaction, error) +} diff --git a/coordinator/store/wrapper/wrapper_test.go b/coordinator/store/wrapper/wrapper_test.go index 9bf03ded..6d695dec 100644 --- a/coordinator/store/wrapper/wrapper_test.go +++ b/coordinator/store/wrapper/wrapper_test.go @@ -7,6 +7,7 @@ package wrapper import ( + "context" "testing" "github.com/edgelesssys/marblerun/coordinator/constants" @@ -30,6 +31,7 @@ func TestMain(m *testing.M) { func TestStoreWrapper(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() store := stdstore.New(&seal.MockSealer{}) rawManifest := []byte(test.ManifestJSON) @@ -47,7 +49,7 @@ func TestStoreWrapper(t *testing.T) { // save values to store data := New(store) - tx, err := store.BeginTransaction() + tx, err := store.BeginTransaction(ctx) assert.NoError(err) txdata := New(tx) assert.NoError(txdata.PutCertificate("some-cert", someCert)) @@ -56,7 +58,7 @@ func TestStoreWrapper(t *testing.T) { assert.NoError(txdata.PutSecret("test-secret", testSecret)) assert.NoError(txdata.PutState(curState)) assert.NoError(txdata.PutUser(testUser)) - assert.NoError(tx.Commit()) + assert.NoError(tx.Commit(ctx)) // see if we can retrieve them again savedCert, err := data.GetCertificate("some-cert") @@ -82,17 +84,18 @@ func TestStoreWrapper(t *testing.T) { func TestStoreWrapperRollback(t *testing.T) { assert := assert.New(t) require := require.New(t) + ctx := context.Background() stor := stdstore.New(&seal.MockSealer{}) data := New(stor) startingState := state.AcceptingManifest - tx, err := stor.BeginTransaction() + tx, err := stor.BeginTransaction(ctx) require.NoError(err) require.NoError(New(tx).PutState(state.AcceptingManifest)) - require.NoError(tx.Commit()) + require.NoError(tx.Commit(ctx)) - tx, err = stor.BeginTransaction() + tx, err = stor.BeginTransaction(ctx) require.NoError(err) require.NoError(New(tx).PutState(state.AcceptingMarbles)) require.NoError(New(tx).PutRawManifest([]byte("manifes")))