diff --git a/cli/cmd/common.go b/cli/cmd/common.go index 5c875d5db7..ffb914cb60 100644 --- a/cli/cmd/common.go +++ b/cli/cmd/common.go @@ -6,11 +6,18 @@ package cmd import ( "context" _ "embed" + "fmt" + "log/slog" "os" "path/filepath" "time" "github.com/edgelesssys/contrast/cli/telemetry" + "github.com/edgelesssys/contrast/internal/atls" + "github.com/edgelesssys/contrast/internal/attestation/snp" + "github.com/edgelesssys/contrast/internal/fsstore" + "github.com/edgelesssys/contrast/internal/logger" + "github.com/edgelesssys/contrast/internal/manifest" "github.com/spf13/cobra" ) @@ -72,3 +79,27 @@ func withTelemetry(runFunc func(*cobra.Command, []string) error) func(*cobra.Com return cmdErr } } + +// validatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest. +func validatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) { + kdsDir, err := cachedir("kds") + if err != nil { + return nil, fmt.Errorf("getting cache dir: %w", err) + } + log.Debug("Using KDS cache dir", "dir", kdsDir) + kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) + kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) + + opts, err := m.SNPValidateOpts() + if err != nil { + return nil, fmt.Errorf("getting SNP validate options: %w", err) + } + + var validators []atls.Validator + for _, opt := range opts { + validators = append(validators, snp.NewValidator(opt, []manifest.HexString{manifest.NewHexString(hostData)}, kdsGetter, + logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), + )) + } + return validators, nil +} diff --git a/cli/cmd/recover.go b/cli/cmd/recover.go index 46ab1a7a05..d6be38d8a2 100644 --- a/cli/cmd/recover.go +++ b/cli/cmd/recover.go @@ -11,10 +11,7 @@ import ( "path/filepath" "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/attestation/snp" - "github.com/edgelesssys/contrast/internal/fsstore" "github.com/edgelesssys/contrast/internal/grpc/dialer" - "github.com/edgelesssys/contrast/internal/logger" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/userapi" "github.com/spf13/cobra" @@ -76,22 +73,12 @@ func runRecover(cmd *cobra.Command, _ []string) error { return fmt.Errorf("decrypting seed: %w", err) } - kdsDir, err := cachedir("kds") + validators, err := validatorsFromManifest(&m, log, flags.policy) if err != nil { - return fmt.Errorf("getting cache dir: %w", err) + return fmt.Errorf("getting validators: %w", err) } - log.Debug("Using KDS cache dir", "dir", kdsDir) - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) - if err != nil { - return fmt.Errorf("generating validate opts: %w", err) - } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey) + dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey) log.Debug("Dialing coordinator", "endpoint", flags.coordinator) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) diff --git a/cli/cmd/set.go b/cli/cmd/set.go index b74400f7d5..cffe98ae13 100644 --- a/cli/cmd/set.go +++ b/cli/cmd/set.go @@ -18,11 +18,8 @@ import ( "time" "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/attestation/snp" - "github.com/edgelesssys/contrast/internal/fsstore" "github.com/edgelesssys/contrast/internal/grpc/dialer" grpcRetry "github.com/edgelesssys/contrast/internal/grpc/retry" - "github.com/edgelesssys/contrast/internal/logger" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/retry" "github.com/edgelesssys/contrast/internal/spinner" @@ -101,22 +98,11 @@ func runSet(cmd *cobra.Command, args []string) error { return fmt.Errorf("checking policies match manifest: %w", err) } - kdsDir, err := cachedir("kds") + validators, err := validatorsFromManifest(&m, log, flags.policy) if err != nil { - return fmt.Errorf("getting cache dir: %w", err) + return fmt.Errorf("getting validators: %w", err) } - log.Debug("Using KDS cache dir", "dir", kdsDir) - - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) - if err != nil { - return fmt.Errorf("generating validate opts: %w", err) - } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey) + dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) if err != nil { diff --git a/cli/cmd/verify.go b/cli/cmd/verify.go index adb06568c5..40d2b5ac8b 100644 --- a/cli/cmd/verify.go +++ b/cli/cmd/verify.go @@ -13,10 +13,7 @@ import ( "path/filepath" "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/attestation/snp" - "github.com/edgelesssys/contrast/internal/fsstore" "github.com/edgelesssys/contrast/internal/grpc/dialer" - "github.com/edgelesssys/contrast/internal/logger" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/userapi" "github.com/spf13/cobra" @@ -71,22 +68,11 @@ func runVerify(cmd *cobra.Command, _ []string) error { return fmt.Errorf("validating manifest: %w", err) } - kdsDir, err := cachedir("kds") + validators, err := validatorsFromManifest(&m, log, flags.policy) if err != nil { - return fmt.Errorf("getting cache dir: %w", err) + return fmt.Errorf("getting validators: %w", err) } - log.Debug("Using KDS cache dir", "dir", kdsDir) - - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) - if err != nil { - return fmt.Errorf("generating validate opts: %w", err) - } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.New(atls.NoIssuer, validator, &net.Dialer{}) + dialer := dialer.New(atls.NoIssuer, validators, &net.Dialer{}) log.Debug("Dialing coordinator", "endpoint", flags.coordinator) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) @@ -174,17 +160,6 @@ func parseVerifyFlags(cmd *cobra.Command) (*verifyFlags, error) { }, nil } -func newCoordinatorValidateOptsGen(mnfst manifest.Manifest, hostData []byte) (*snp.StaticValidateOptsGenerator, error) { - validateOpts, err := mnfst.AKSValidateOpts() - if err != nil { - return nil, err - } - validateOpts.HostData = hostData - return &snp.StaticValidateOptsGenerator{ - Opts: validateOpts, - }, nil -} - func writeFilelist(dir string, filelist map[string][]byte) error { if dir != "" { if err := os.MkdirAll(dir, 0o755); err != nil { diff --git a/coordinator/internal/authority/authority.go b/coordinator/internal/authority/authority.go index 82cb45b033..82a2869239 100644 --- a/coordinator/internal/authority/authority.go +++ b/coordinator/internal/authority/authority.go @@ -15,8 +15,6 @@ import ( "github.com/edgelesssys/contrast/internal/ca" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/userapi" - "github.com/google/go-sev-guest/proto/sevsnp" - "github.com/google/go-sev-guest/validate" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -171,19 +169,3 @@ type State struct { latest *history.LatestTransition generation int } - -// SNPValidateOpts returns SNP validation options from reference values. -// -// It also ensures that the policy hash in the report's HOSTDATA is allowed by the current -// manifest. -// TODO(msanft): make the manifest authoritative and allow other types of reference values. -func (s *State) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) { - mnfst := s.Manifest - - hostData := manifest.NewHexString(report.HostData) - if _, ok := mnfst.Policies[hostData]; !ok { - return nil, fmt.Errorf("hostdata %s not found in manifest", hostData) - } - - return mnfst.AKSValidateOpts() -} diff --git a/coordinator/internal/authority/authority_test.go b/coordinator/internal/authority/authority_test.go index e7cd8beb81..4a4f73eb32 100644 --- a/coordinator/internal/authority/authority_test.go +++ b/coordinator/internal/authority/authority_test.go @@ -16,7 +16,6 @@ import ( "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/platforms" "github.com/edgelesssys/contrast/internal/userapi" - "github.com/google/go-sev-guest/proto/sevsnp" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/spf13/afero" @@ -37,8 +36,6 @@ func TestSNPValidateOpts(t *testing.T) { require := require.New(t) a, _ := newAuthority(t) _, mnfstBytes, policies := newManifest(t) - policyHash := sha256.Sum256(policies[0]) - report := &sevsnp.Report{HostData: policyHash[:]} req := &userapi.SetManifestRequest{ Manifest: mnfstBytes, @@ -47,16 +44,9 @@ func TestSNPValidateOpts(t *testing.T) { _, err := a.SetManifest(context.Background(), req) require.NoError(err) - opts, err := a.state.Load().SNPValidateOpts(report) + gens, err := a.state.Load().Manifest.SNPValidateOpts() require.NoError(err) - require.NotNil(opts) - - // Change to unknown policy hash in HostData. - report.HostData[0]++ - - opts, err = a.state.Load().SNPValidateOpts(report) - require.Error(err) - require.Nil(opts) + require.NotNil(gens) } // TODO(burgerdev): test ValidateCallback and GetCertBundle diff --git a/coordinator/internal/authority/credentials.go b/coordinator/internal/authority/credentials.go index 9d1add75c3..0b2e3772fd 100644 --- a/coordinator/internal/authority/credentials.go +++ b/coordinator/internal/authority/credentials.go @@ -15,6 +15,7 @@ import ( "github.com/edgelesssys/contrast/internal/atls" "github.com/edgelesssys/contrast/internal/attestation/snp" "github.com/edgelesssys/contrast/internal/logger" + "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/memstore" "github.com/google/go-sev-guest/proto/sevsnp" "github.com/prometheus/client_golang/prometheus" @@ -72,11 +73,24 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A authInfo := AuthInfo{State: state} - validator := snp.NewValidatorWithCallbacks(state, c.kdsGetter, - logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}), - c.attestationFailuresCounter, &authInfo) + opts, err := state.Manifest.SNPValidateOpts() + if err != nil { + return nil, nil, fmt.Errorf("generating SNP validation options: %w", err) + } - serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, []atls.Validator{validator}) + var allowedHostDataEntries []manifest.HexString + for entry := range state.Manifest.Policies { + allowedHostDataEntries = append(allowedHostDataEntries, entry) + } + + var validators []atls.Validator + for _, opt := range opts { + validator := snp.NewValidatorWithCallbacks(opt, allowedHostDataEntries, c.kdsGetter, + logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}), + c.attestationFailuresCounter, &authInfo) + validators = append(validators, validator) + } + serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators) if err != nil { return nil, nil, err } diff --git a/initializer/main.go b/initializer/main.go index 5d423f7e65..cd1ec516ae 100644 --- a/initializer/main.go +++ b/initializer/main.go @@ -60,7 +60,9 @@ func run() (retErr error) { } requestCert := func() (*meshapi.NewMeshCertResponse, error) { - dial := dialer.NewWithKey(issuer, atls.NoValidator, &net.Dialer{}, privKey) + // Supply an empty list of validators, as the coordinator does not need to be + // validated by the initializer. + dial := dialer.NewWithKey(issuer, atls.NoValidators, &net.Dialer{}, privKey) conn, err := dial.Dial(ctx, net.JoinHostPort(coordinatorHostname, meshapi.Port)) if err != nil { return nil, fmt.Errorf("dialing: %w", err) diff --git a/internal/atls/atls.go b/internal/atls/atls.go index 4b9490343c..168901dbd6 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -28,8 +28,8 @@ import ( const attestationTimeout = 30 * time.Second var ( - // NoValidator skips validation of the server's attestation document. - NoValidator Validator + // NoValidators skips validation of the server's attestation document. + NoValidators = []Validator{} // NoIssuer skips embedding the client's attestation document. NoIssuer Issuer diff --git a/internal/attestation/snp/validator.go b/internal/attestation/snp/validator.go index 431d537f76..de1bc7d52c 100644 --- a/internal/attestation/snp/validator.go +++ b/internal/attestation/snp/validator.go @@ -10,8 +10,10 @@ import ( "encoding/hex" "fmt" "log/slog" + "slices" "github.com/edgelesssys/contrast/internal/attestation/reportdata" + "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/oid" "github.com/google/go-sev-guest/abi" "github.com/google/go-sev-guest/proto/sevsnp" @@ -24,11 +26,12 @@ import ( // Validator validates attestation statements. type Validator struct { - validateOptsGen validateOptsGenerator - callbackers []validateCallbacker - kdsGetter trust.HTTPSGetter - logger *slog.Logger - metrics metrics + opts *validate.Options + allowedHostDataEntries []manifest.HexString // Allowed host data entries in the report. If any of these is present, the report is considered valid. + callbackers []validateCallbacker + kdsGetter trust.HTTPSGetter + logger *slog.Logger + metrics metrics } type metrics struct { @@ -40,38 +43,29 @@ type validateCallbacker interface { reportRaw, nonce, peerPublicKey []byte) error } -type validateOptsGenerator interface { - SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) -} - -// StaticValidateOptsGenerator returns validate.Options generator that returns -// static validation options. -type StaticValidateOptsGenerator struct { - Opts *validate.Options -} - -// SNPValidateOpts return the SNP validation options. -func (v *StaticValidateOptsGenerator) SNPValidateOpts(_ *sevsnp.Report) (*validate.Options, error) { - return v.Opts, nil -} - // NewValidator returns a new Validator. -func NewValidator(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger) *Validator { +func NewValidator(opts *validate.Options, allowedHostDataEntries []manifest.HexString, + kdsGetter trust.HTTPSGetter, log *slog.Logger, +) *Validator { return &Validator{ - validateOptsGen: optsGen, - kdsGetter: kdsGetter, - logger: log, + opts: opts, + allowedHostDataEntries: allowedHostDataEntries, + kdsGetter: kdsGetter, + logger: log, } } // NewValidatorWithCallbacks returns a new Validator with callbacks. -func NewValidatorWithCallbacks(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger, attestataionFailures prometheus.Counter, callbacks ...validateCallbacker) *Validator { +func NewValidatorWithCallbacks(opts *validate.Options, allowedHostDataEntries []manifest.HexString, kdsGetter trust.HTTPSGetter, + log *slog.Logger, attestationFailures prometheus.Counter, callbacks ...validateCallbacker, +) *Validator { return &Validator{ - validateOptsGen: optsGen, - callbackers: callbacks, - kdsGetter: kdsGetter, - logger: log, - metrics: metrics{attestationFailures: attestataionFailures}, + opts: opts, + allowedHostDataEntries: allowedHostDataEntries, + callbackers: callbacks, + kdsGetter: kdsGetter, + logger: log, + metrics: metrics{attestationFailures: attestationFailures}, } } @@ -125,16 +119,20 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte // Validate the report data. reportDataExpected := reportdata.Construct(peerPublicKey, nonce) - validateOpts, err := v.validateOptsGen.SNPValidateOpts(attestation.Report) - if err != nil { - return fmt.Errorf("generating validation options: %w", err) - } - validateOpts.ReportData = reportDataExpected[:] - if err := validate.SnpAttestation(attestation, validateOpts); err != nil { + v.opts.ReportData = reportDataExpected[:] + if err := validate.SnpAttestation(attestation, v.opts); err != nil { return fmt.Errorf("validating report claims: %w", err) } v.logger.Info("Successfully validated report data") + // Validate the host data. + + if !slices.ContainsFunc(v.allowedHostDataEntries, func(entry manifest.HexString) bool { + return manifest.NewHexString(attestation.Report.HostData) == entry + }) { + return fmt.Errorf("host data not allowed (found: %v allowed: %v)", attestation.Report.HostData, v.allowedHostDataEntries) + } + // Run callbacks. for _, callbacker := range v.callbackers { diff --git a/internal/grpc/dialer/dialer.go b/internal/grpc/dialer/dialer.go index 1ca6e0f622..e33c8521e0 100644 --- a/internal/grpc/dialer/dialer.go +++ b/internal/grpc/dialer/dialer.go @@ -18,38 +18,34 @@ import ( // Dialer can open grpc client connections with different levels of ATLS encryption / verification. type Dialer struct { - issuer atls.Issuer - validator atls.Validator - netDialer NetDialer - privKey *ecdsa.PrivateKey + issuer atls.Issuer + validators []atls.Validator + netDialer NetDialer + privKey *ecdsa.PrivateKey } // New creates a new Dialer. -func New(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer) *Dialer { +func New(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer) *Dialer { return &Dialer{ - issuer: issuer, - validator: validator, - netDialer: netDialer, + issuer: issuer, + validators: validators, + netDialer: netDialer, } } // NewWithKey creates a new Dialer with the given private key. -func NewWithKey(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer { +func NewWithKey(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer { return &Dialer{ - issuer: issuer, - validator: validator, - netDialer: netDialer, - privKey: privKey, + issuer: issuer, + validators: validators, + netDialer: netDialer, + privKey: privKey, } } // Dial creates a new grpc client connection to the given target using the atls validator. func (d *Dialer) Dial(_ context.Context, target string) (*grpc.ClientConn, error) { - var validators []atls.Validator - if d.validator != nil { - validators = append(validators, d.validator) - } - credentials := atlscredentials.NewWithKey(d.issuer, validators, d.privKey) + credentials := atlscredentials.NewWithKey(d.issuer, d.validators, d.privKey) return grpc.NewClient(target, d.grpcWithDialer(), diff --git a/internal/manifest/constants.go b/internal/manifest/constants.go index a9478b236b..d28f807190 100644 --- a/internal/manifest/constants.go +++ b/internal/manifest/constants.go @@ -13,27 +13,13 @@ import ( // Default returns a default manifest with reference values for the given platform. func Default(platform platforms.Platform) (*Manifest, error) { embeddedRefValues := GetEmbeddedReferenceValues() + refValues, err := embeddedRefValues.ForPlatform(platform) if err != nil { return nil, fmt.Errorf("get reference values for platform %s: %w", platform, err) } - mnfst := Manifest{} - switch platform { - case platforms.AKSCloudHypervisorSNP: - return &Manifest{ - ReferenceValues: ReferenceValues{ - AKS: refValues.AKS, - }, - }, nil - case platforms.RKE2QEMUTDX, platforms.K3sQEMUTDX: - return &Manifest{ - ReferenceValues: ReferenceValues{ - BareMetalTDX: refValues.BareMetalTDX, - }, - }, nil - } - return &mnfst, nil + return &Manifest{ReferenceValues: *refValues}, nil } // GetEmbeddedReferenceValues returns the reference values embedded in the binary. diff --git a/internal/manifest/manifest.go b/internal/manifest/manifest.go index 8bd5b7214f..7fbf21821c 100644 --- a/internal/manifest/manifest.go +++ b/internal/manifest/manifest.go @@ -6,6 +6,7 @@ package manifest import ( "crypto/sha256" "encoding/base64" + "errors" "fmt" "github.com/google/go-sev-guest/abi" @@ -16,7 +17,10 @@ import ( // Manifest is the Coordinator manifest and contains the reference values of the deployment. type Manifest struct { // policyHash/HOSTDATA -> commonName - Policies map[HexString]PolicyEntry + Policies map[HexString]PolicyEntry + // ReferenceValues specifies the allowed TEE configurations in the deployment. If ANY + // of the reference values validates the attestation report of the workload, + // the workload is considered valid. ReferenceValues ReferenceValues WorkloadOwnerKeyDigests []HexString SeedshareOwnerPubKeys []HexString @@ -65,18 +69,18 @@ func (p Policy) Hash() HexString { // Validate checks the validity of all fields in the reference values. func (r ReferenceValues) Validate() error { - if r.AKS != nil { - if err := r.AKS.Validate(); err != nil { - return fmt.Errorf("validating AKS reference values: %w", err) + for _, v := range r.SNP { + if err := v.Validate(); err != nil { + return fmt.Errorf("validating SNP reference values: %w", err) } } - if r.BareMetalTDX != nil { - if err := r.BareMetalTDX.Validate(); err != nil { - return fmt.Errorf("validating bare metal TDX reference values: %w", err) + for _, v := range r.TDX { + if err := v.Validate(); err != nil { + return fmt.Errorf("validating TDX reference values: %w", err) } } - if r.BareMetalTDX == nil && r.AKS == nil { + if len(r.SNP)+len(r.TDX) == 0 { return fmt.Errorf("reference values in manifest cannot be empty. Is the chosen platform supported?") } @@ -84,14 +88,14 @@ func (r ReferenceValues) Validate() error { } // Validate checks the validity of all fields in the AKS reference values. -func (r AKSReferenceValues) Validate() error { - if r.SNP.MinimumTCB.BootloaderVersion == nil { +func (r SNPReferenceValues) Validate() error { + if r.MinimumTCB.BootloaderVersion == nil { return fmt.Errorf("field BootloaderVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.TEEVersion == nil { + } else if r.MinimumTCB.TEEVersion == nil { return fmt.Errorf("field TEEVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.SNPVersion == nil { + } else if r.MinimumTCB.SNPVersion == nil { return fmt.Errorf("field SNPVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.MicrocodeVersion == nil { + } else if r.MinimumTCB.MicrocodeVersion == nil { return fmt.Errorf("field MicrocodeVersion in manifest cannot be empty") } @@ -103,7 +107,7 @@ func (r AKSReferenceValues) Validate() error { } // Validate checks the validity of all fields in the bare metal TDX reference values. -func (r BareMetalTDXReferenceValues) Validate() error { +func (r TDXReferenceValues) Validate() error { if r.TrustedMeasurement == "" { return fmt.Errorf("field TrustedMeasurement in manifest cannot be empty") } @@ -140,40 +144,56 @@ func (m *Manifest) Validate() error { return nil } -// AKSValidateOpts returns validate options populated with the manifest's -// AKS reference values and trusted measurement. -func (m *Manifest) AKSValidateOpts() (*validate.Options, error) { - if m.ReferenceValues.AKS == nil { - return nil, fmt.Errorf("no AKS reference values present in manifest") +// TODO(msanft): add generic validation interface for other attestation types. + +// SNPValidateOpts returns validate options generators populated with the manifest's +// SNP reference values and trusted measurement for the given runtime. +func (m *Manifest) SNPValidateOpts() ([]*validate.Options, error) { + if len(m.ReferenceValues.SNP) == 0 { + return nil, errors.New("reference values cannot be empty") } if err := m.Validate(); err != nil { return nil, fmt.Errorf("validating manifest: %w", err) } - trustedMeasurement, err := m.ReferenceValues.AKS.TrustedMeasurement.Bytes() - if err != nil { - return nil, fmt.Errorf("failed to convert TrustedMeasurement from manifest to byte slices: %w", err) + + var out []*validate.Options + for _, refVal := range m.ReferenceValues.SNP { + if len(refVal.TrustedMeasurement) == 0 { + return nil, errors.New("trusted measurement cannot be empty") + } + + trustedMeasurement, err := refVal.TrustedMeasurement.Bytes() + if err != nil { + return nil, fmt.Errorf("failed to convert TrustedMeasurement from manifest to byte slices: %w", err) + } + + out = append(out, &validate.Options{ + Measurement: trustedMeasurement, + GuestPolicy: abi.SnpPolicy{ + Debug: false, + SMT: true, + }, + VMPL: new(int), // VMPL0 + MinimumTCB: kds.TCBParts{ + BlSpl: refVal.MinimumTCB.BootloaderVersion.UInt8(), + TeeSpl: refVal.MinimumTCB.TEEVersion.UInt8(), + SnpSpl: refVal.MinimumTCB.SNPVersion.UInt8(), + UcodeSpl: refVal.MinimumTCB.MicrocodeVersion.UInt8(), + }, + MinimumLaunchTCB: kds.TCBParts{ + BlSpl: refVal.MinimumTCB.BootloaderVersion.UInt8(), + TeeSpl: refVal.MinimumTCB.TEEVersion.UInt8(), + SnpSpl: refVal.MinimumTCB.SNPVersion.UInt8(), + UcodeSpl: refVal.MinimumTCB.MicrocodeVersion.UInt8(), + }, + PermitProvisionalFirmware: true, + }) + } + + if len(out) == 0 { + return nil, errors.New("no SNP reference values found in manifest") } - return &validate.Options{ - Measurement: trustedMeasurement, - GuestPolicy: abi.SnpPolicy{ - Debug: false, - SMT: true, - }, - VMPL: new(int), // VMPL0 - MinimumTCB: kds.TCBParts{ - BlSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.BootloaderVersion.UInt8(), - TeeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.TEEVersion.UInt8(), - SnpSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.SNPVersion.UInt8(), - UcodeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.MicrocodeVersion.UInt8(), - }, - MinimumLaunchTCB: kds.TCBParts{ - BlSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.BootloaderVersion.UInt8(), - TeeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.TEEVersion.UInt8(), - SnpSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.SNPVersion.UInt8(), - UcodeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.MicrocodeVersion.UInt8(), - }, - PermitProvisionalFirmware: true, - }, nil + return out, nil } diff --git a/internal/manifest/manifest_test.go b/internal/manifest/manifest_test.go index a6f98589e4..7c1db77f42 100644 --- a/internal/manifest/manifest_test.go +++ b/internal/manifest/manifest_test.go @@ -93,9 +93,11 @@ func TestValidate(t *testing.T) { m: &Manifest{ Policies: map[HexString]PolicyEntry{HexString(""): {}}, ReferenceValues: ReferenceValues{ - AKS: &AKSReferenceValues{ - SNP: mnf.ReferenceValues.AKS.SNP, - TrustedMeasurement: "", + SNP: []SNPReferenceValues{ + { + MinimumTCB: mnf.ReferenceValues.SNP[0].MinimumTCB, + TrustedMeasurement: "", + }, }, }, }, @@ -109,6 +111,7 @@ func TestValidate(t *testing.T) { wantErr: true, }, } + for i, tc := range testCases { t.Run(strconv.Itoa(i), func(t *testing.T) { assert := assert.New(t) @@ -124,22 +127,25 @@ func TestValidate(t *testing.T) { func TestAKSValidateOpts(t *testing.T) { assert := assert.New(t) + require := require.New(t) m, err := Default(platforms.AKSCloudHypervisorSNP) - require.NoError(t, err) + require.NoError(err) - opts, err := m.AKSValidateOpts() - assert.NoError(err) + opts, err := m.SNPValidateOpts() + require.NoError(err) + require.Len(opts, 1) - tcb := m.ReferenceValues.AKS.SNP.MinimumTCB + tcb := m.ReferenceValues.SNP[0].MinimumTCB assert.NotNil(tcb.BootloaderVersion) assert.NotNil(tcb.TEEVersion) assert.NotNil(tcb.SNPVersion) assert.NotNil(tcb.MicrocodeVersion) - trustedMeasurement, err := m.ReferenceValues.AKS.TrustedMeasurement.Bytes() + trustedMeasurement, err := m.ReferenceValues.SNP[0].TrustedMeasurement.Bytes() assert.NoError(err) - assert.Equal(trustedMeasurement, opts.Measurement) + + assert.Equal(trustedMeasurement, opts[0].Measurement) tcbParts := kds.TCBParts{ BlSpl: tcb.BootloaderVersion.UInt8(), @@ -147,6 +153,6 @@ func TestAKSValidateOpts(t *testing.T) { SnpSpl: tcb.SNPVersion.UInt8(), UcodeSpl: tcb.MicrocodeVersion.UInt8(), } - assert.Equal(tcbParts, opts.MinimumTCB) - assert.Equal(tcbParts, opts.MinimumLaunchTCB) + assert.Equal(tcbParts, opts[0].MinimumTCB) + assert.Equal(tcbParts, opts[0].MinimumLaunchTCB) } diff --git a/internal/manifest/referencevalues.go b/internal/manifest/referencevalues.go index e63503ea6a..065fe2c504 100644 --- a/internal/manifest/referencevalues.go +++ b/internal/manifest/referencevalues.go @@ -19,35 +19,30 @@ import ( //go:embed assets/reference-values.json var EmbeddedReferenceValuesJSON []byte -// ReferenceValues contains the workload-independent reference values for each platform. +// ReferenceValues contains the workload-independent reference values for each TEE type. type ReferenceValues struct { - // AKS holds the reference values for AKS. - AKS *AKSReferenceValues `json:"aks,omitempty"` - // BareMetalTDX holds the reference values for TDX on bare metal. - BareMetalTDX *BareMetalTDXReferenceValues `json:"bareMetalTDX,omitempty"` + // SNP holds the reference values for SNP. + SNP []SNPReferenceValues `json:"snp,omitempty"` + // TDX holds the reference values for TDX. + TDX []TDXReferenceValues `json:"tdx,omitempty"` } -// EmbeddedReferenceValues is a map of runtime handler names to reference values, as -// embedded in the binary. +// EmbeddedReferenceValues is a map of runtime handler names to a list of reference values +// for the runtime handler, as embedded in the binary. type EmbeddedReferenceValues map[string]ReferenceValues -// AKSReferenceValues contains reference values for AKS. -type AKSReferenceValues struct { - SNP SNPReferenceValues +// SNPReferenceValues contains reference values for SEV-SNP. +type SNPReferenceValues struct { + MinimumTCB SNPTCB TrustedMeasurement HexString } -// BareMetalTDXReferenceValues contains reference values for BareMetalTDX. -type BareMetalTDXReferenceValues struct { +// TDXReferenceValues contains reference values for TDX. +type TDXReferenceValues struct { TrustedMeasurement HexString } -// SNPReferenceValues contains reference values for the SNP report. -type SNPReferenceValues struct { - MinimumTCB SNPTCB -} - -// SNPTCB represents a set of SNP TCB values. +// SNPTCB represents a set of SEV-SNP TCB values. type SNPTCB struct { BootloaderVersion *SVN TEEVersion *SVN diff --git a/packages/by-name/contrast/package.nix b/packages/by-name/contrast/package.nix index 5d5a11a01b..abefbdee96 100644 --- a/packages/by-name/contrast/package.nix +++ b/packages/by-name/contrast/package.nix @@ -52,32 +52,38 @@ let k3s-qemu-snp-handler = runtimeHandler "k3s-qemu-snp" kata.contrast-node-installer-image.runtimeHash; aksRefVals = { - aks = { - snp = { + snp = [ + { minimumTCB = { bootloaderVersion = 3; teeVersion = 0; snpVersion = 8; microcodeVersion = 115; }; - }; - trustedMeasurement = lib.removeSuffix "\n" (builtins.readFile microsoft.kata-igvm.launch-digest); - }; + trustedMeasurement = lib.removeSuffix "\n" (builtins.readFile microsoft.kata-igvm.launch-digest); + } + ]; }; snpRefVals = { - inherit (aksRefVals.aks) snp; - trustedMeasurement = lib.removeSuffix "\n" ( - builtins.readFile "${kata.contrast-node-installer-image.runtimeHash}" - ); + snp = [ + { + inherit (builtins.head aksRefVals.snp) minimumTCB; + trustedMeasurement = lib.removeSuffix "\n" ( + builtins.readFile "${kata.contrast-node-installer-image.runtimeHash}" + ); + } + ]; }; tdxRefVals = { - bareMetalTDX = { - trustedMeasurement = lib.removeSuffix "\n" ( - builtins.readFile "${kata.contrast-node-installer-image.runtimeHash}" - ); - }; + tdx = [ + { + trustedMeasurement = lib.removeSuffix "\n" ( + builtins.readFile "${kata.contrast-node-installer-image.runtimeHash}" + ); + } + ]; }; in builtins.toFile "reference-values.json" (