From 0deb0f5870abd48a0fb1772fe81c713b7afa7aad Mon Sep 17 00:00:00 2001 From: Mariano Kunzi Date: Mon, 5 Apr 2021 19:15:51 -0300 Subject: [PATCH] AWS KMS Server Keymanager (#2066) Signed-off-by: Mariano Kunzi --- conf/server/server_full.conf | 11 + doc/plugin_server_keymanager_aws_kms.md | 58 + doc/spire_server.md | 5 +- go.mod | 5 + go.sum | 23 + pkg/server/catalog/catalog.go | 2 + pkg/server/plugin/keymanager/awskms/awskms.go | 844 ++++++++ .../plugin/keymanager/awskms/awskms_test.go | 1830 +++++++++++++++++ pkg/server/plugin/keymanager/awskms/client.go | 37 + .../plugin/keymanager/awskms/client_fake.go | 546 +++++ .../plugin/keymanager/awskms/fetcher.go | 143 ++ test/spiretest/logs.go | 15 + 12 files changed, 3517 insertions(+), 2 deletions(-) create mode 100644 doc/plugin_server_keymanager_aws_kms.md create mode 100644 pkg/server/plugin/keymanager/awskms/awskms.go create mode 100644 pkg/server/plugin/keymanager/awskms/awskms_test.go create mode 100644 pkg/server/plugin/keymanager/awskms/client.go create mode 100644 pkg/server/plugin/keymanager/awskms/client_fake.go create mode 100644 pkg/server/plugin/keymanager/awskms/fetcher.go diff --git a/conf/server/server_full.conf b/conf/server/server_full.conf index 9799fcb8c5..c6358c1f72 100644 --- a/conf/server/server_full.conf +++ b/conf/server/server_full.conf @@ -188,6 +188,17 @@ plugins { } } + # KeyManager "aws_kms": A key manager for signing SVIDs which only generates and stores keys in AWS KMS + # KeyManager "aws_kms" { + # plugin_data { + # region: AWS Region to use. + # region = "" + # + # key_metadata_file: A file path location where information about generated keys will be persisted + # key_metadata_file = "./file_path" + # } + # } + # KeyManager "disk": A disk-based key manager for signing SVIDs. # KeyManager "disk" { # plugin_data { diff --git a/doc/plugin_server_keymanager_aws_kms.md b/doc/plugin_server_keymanager_aws_kms.md new file mode 100644 index 0000000000..7cc339b1b3 --- /dev/null +++ b/doc/plugin_server_keymanager_aws_kms.md @@ -0,0 +1,58 @@ +# Server plugin: KeyManager "aws_kms" + +The `aws_kms` key manager plugin leverages the AWS Key Management Service (KMS) to create, maintain and rotate key pairs (as [Customer Master Keys](https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#master_keys), or CMKs), and sign SVIDs as needed, with the private key never leaving KMS. + +## Configuration + +The plugin accepts the following configuration options: + +| Key | Type | Required | Description | Default | +| ------------------- | ------ | ------------------------------------- | ------------------------------------------------------- | ---------------------------------------------------- | +| access_key_id | string | see [AWS KMS Access](#aws-kms-access) | The Access Key Id used to authenticate to KMS | Value of the AWS_ACCESS_KEY_ID environment variable | +| secret_access_key | string | see [AWS KMS Access](#aws-kms-access) | The Secret Access Key used to authenticate to KMS | Value of the AWS_SECRET_ACCESS_KEY environment variable | +| region | string | yes | The region where the keys will be stored | | +| key_metadata_file | string | yes | A file path location where information about generated keys will be persisted | | + +### Alias and Key Management + +The plugin assigns [aliases](https://docs.aws.amazon.com/kms/latest/developerguide/kms-alias.html) to the Customer Master Keys that manages. The aliases are used to identify and name keys that are managed by the plugin. + +Aliases managed by the plugin have the following form: `alias/SPIRE_SERVER/{TRUST_DOMAIN}/{SERVER_ID}/{KEY_ID}`. The `{SERVER_ID}` is an auto-generated ID unique to the server and is persisted in the _Key Metadata File_ (see the `key_metadata_file` configurable). This ID allows multiple servers in the same trust domain (e.g. servers in HA deployments) to manage keys with identical `{KEY_ID}`'s without collision. + +If the _Key Metadata File_ is not found on server startup, the file is recreated, with a new auto-generated server ID. Consequently, if the file is lost, the plugin will not be able to identify keys that it has previously managed and will recreate new keys on demand. + +The plugin attempts to detect and prune stale aliases. To facilitate stale alias detection, the plugin actively updates the `LastUpdatedDate` field on all aliases every 6 hours. The plugin periodically scans aliases. Any alias encountered with a `LastUpdatedDate` older than two weeks is removed, along with its associated key. + +The plugin also attempts to detect and prune stale keys. All keys managed by the plugin are assigned a `Description` of the form `SPIRE_SERVER/{TRUST_DOMAIN}`. The plugin periodically scans the keys. Any key with a `Description` matching the proper form, that is both unassociated with any alias and has a `CreationDate` older than 48 hours, is removed. + +### AWS KMS Access + +Access to AWS KMS can be given by either setting the `access_key_id` and `secret_access_key`, or by ensuring that the plugin runs on an EC2 instance with a given IAM role that has a specific set of permissions. + +The IAM role must have an attached policy with the following permissions: + +- `kms:CreateAlias` +- `kms:CreateKey` +- `kms:DescribeKey` +- `kms:GetPublicKey` +- `kms:ListKeys` +- `kms:ListAliases` +- `kms:ScheduleKeyDeletion` +- `kms:Sign` +- `kms:UpdateAlias` +- `kms:DeleteAlias` + +## Sample Plugin Configuration + +``` +KeyManager "aws_kms" { + plugin_data { + region = "us-east-2" + key_metadata_file = "./key_metadata" + } +} +``` + +## Supported Key Types and TTL + +The plugin supports all the key types supported by SPIRE: `rsa-2048`, `rsa-4096`, `ec-p256`, and `ec-p384`. diff --git a/doc/spire_server.md b/doc/spire_server.md index 0ad66f1328..3ee9539870 100644 --- a/doc/spire_server.md +++ b/doc/spire_server.md @@ -18,8 +18,9 @@ This document is a configuration reference for SPIRE Server. It includes informa | Type | Name | Description | | ---- | ---- | ----------- | | DataStore | [sql](/doc/plugin_server_datastore_sql.md) | An sql database storage for SQLite, PostgreSQL and MySQL databases for the SPIRE datastore | -| KeyManager | [disk](/doc/plugin_server_keymanager_disk.md) | A disk-based key manager for signing SVIDs | -| KeyManager | [memory](/doc/plugin_server_keymanager_memory.md) | A key manager for signing SVIDs which only stores keys in memory and does not actually persist them anywhere | +| KeyManager | [aws_kms](/doc/plugin_server_keymanager_awskms.md) | A key manager which manages keys in AWS KMS | +| KeyManager | [disk](/doc/plugin_server_keymanager_disk.md) | A key manager which manages keys persisted on disk | +| KeyManager | [memory](/doc/plugin_server_keymanager_memory.md) | A key manager which manages unpersisted keys in memory | | NodeAttestor | [aws_iid](/doc/plugin_server_nodeattestor_aws_iid.md) | A node attestor which attests agent identity using an AWS Instance Identity Document | | NodeAttestor | [azure_msi](/doc/plugin_server_nodeattestor_azure_msi.md) | A node attestor which attests agent identity using an Azure MSI token | | NodeAttestor | [gcp_iit](/doc/plugin_server_nodeattestor_gcp_iit.md) | A node attestor which attests agent identity using a GCP Instance Identity Token | diff --git a/go.mod b/go.mod index a9b5dfef47..e8b05a043a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,10 @@ require ( github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 github.com/armon/go-metrics v0.3.2 github.com/aws/aws-sdk-go v1.28.9 + github.com/aws/aws-sdk-go-v2 v1.2.0 + github.com/aws/aws-sdk-go-v2/config v1.1.1 + github.com/aws/aws-sdk-go-v2/credentials v1.1.1 + github.com/aws/aws-sdk-go-v2/service/kms v1.1.1 github.com/blang/semver v3.5.1+incompatible github.com/cenkalti/backoff/v3 v3.0.0 github.com/containerd/containerd v1.3.2 // indirect @@ -64,6 +68,7 @@ require ( go.uber.org/goleak v0.10.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/tools v0.1.0 diff --git a/go.sum b/go.sum index 42af309c1b..e7be38c46a 100644 --- a/go.sum +++ b/go.sum @@ -122,6 +122,24 @@ github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:l github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.28.9 h1:grIuBQc+p3dTRXerh5+2OxSuWFi0iXuxbFdTSg0jaW0= github.com/aws/aws-sdk-go v1.28.9/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go-v2 v1.2.0 h1:BS+UYpbsElC82gB+2E2jiCBg36i8HlubTB/dO/moQ9c= +github.com/aws/aws-sdk-go-v2 v1.2.0/go.mod h1:zEQs02YRBw1DjK0PoJv3ygDYOFTre1ejlJWl8FwAuQo= +github.com/aws/aws-sdk-go-v2/config v1.1.1 h1:ZAoq32boMzcaTW9bcUacBswAmHTbvlvDJICgHFZuECo= +github.com/aws/aws-sdk-go-v2/config v1.1.1/go.mod h1:0XsVy9lBI/BCXm+2Tuvt39YmdHwS5unDQmxZOYe8F5Y= +github.com/aws/aws-sdk-go-v2/credentials v1.1.1 h1:NbvWIM1Mx6sNPTxowHgS2ewXCRp+NGTzUYb/96FZJbY= +github.com/aws/aws-sdk-go-v2/credentials v1.1.1/go.mod h1:mM2iIjwl7LULWtS6JCACyInboHirisUUdkBPoTHMOUo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.2 h1:EtEU7WRaWliitZh2nmuxEXrN0Cb8EgPUFGIoTMeqbzI= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.2/go.mod h1:3hGg3PpiEjHnrkrlasTfxFqUsZ2GCk/fMUn4CbKgSkM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.2 h1:4AH9fFjUlVktQMznF+YN33aWNXaR4VgDXyP28qokJC0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.2/go.mod h1:45MfaXZ0cNbeuT0KQ1XJylq8A6+OpVV2E5kvY/Kq+u8= +github.com/aws/aws-sdk-go-v2/service/kms v1.1.1 h1:rK1edW1dLtSGr1551ttHqQopajK4Pv9C4ez70dVMQaI= +github.com/aws/aws-sdk-go-v2/service/kms v1.1.1/go.mod h1:6K5oOoDdnkW/h+Jv+xOA+tvgI6lwGBT9igkJGL1ypaY= +github.com/aws/aws-sdk-go-v2/service/sso v1.1.1 h1:37QubsarExl5ZuCBlnRP+7l1tNwZPBSTqpTBrPH98RU= +github.com/aws/aws-sdk-go-v2/service/sso v1.1.1/go.mod h1:SuZJxklHxLAXgLTc1iFXbEWkXs7QRTQpCLGaKIprQW0= +github.com/aws/aws-sdk-go-v2/service/sts v1.1.1 h1:TJoIfnIFubCX0ACVeJ0w46HEH5MwjwYN4iFhuYIhfIY= +github.com/aws/aws-sdk-go-v2/service/sts v1.1.1/go.mod h1:Wi0EBZwiz/K44YliU0EKxqTCJGUfYTWXrrBwkq736bM= +github.com/aws/smithy-go v1.1.0 h1:D6CSsM3gdxaGaqXnPgOBCeL6Mophqzu7KJOu7zW78sU= +github.com/aws/smithy-go v1.1.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -461,6 +479,10 @@ github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/json-iterator/go v1.1.6 h1:MrUvLMLTMxbqFJ9kzlvat/rYZqZnW3u4wkLzWTaFwKs= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -807,6 +829,7 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/pkg/server/catalog/catalog.go b/pkg/server/catalog/catalog.go index b681db2151..d27fdb76c8 100644 --- a/pkg/server/catalog/catalog.go +++ b/pkg/server/catalog/catalog.go @@ -18,6 +18,7 @@ import ( ds_sql "github.com/spiffe/spire/pkg/server/plugin/datastore/sql" "github.com/spiffe/spire/pkg/server/plugin/hostservices" "github.com/spiffe/spire/pkg/server/plugin/keymanager" + km_awskms "github.com/spiffe/spire/pkg/server/plugin/keymanager/awskms" km_disk "github.com/spiffe/spire/pkg/server/plugin/keymanager/disk" km_memory "github.com/spiffe/spire/pkg/server/plugin/keymanager/memory" "github.com/spiffe/spire/pkg/server/plugin/nodeattestor" @@ -73,6 +74,7 @@ var ( // KeyManagers km_disk.BuiltIn(), km_memory.BuiltIn(), + km_awskms.BuiltIn(), // Notifiers no_k8sbundle.BuiltIn(), no_gcs_bundle.BuiltIn(), diff --git a/pkg/server/plugin/keymanager/awskms/awskms.go b/pkg/server/plugin/keymanager/awskms/awskms.go new file mode 100644 index 0000000000..4ba465f474 --- /dev/null +++ b/pkg/server/plugin/keymanager/awskms/awskms.go @@ -0,0 +1,844 @@ +package awskms + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "strings" + "sync" + "time" + + "github.com/andres-erbsen/clock" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/gofrs/uuid" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/hcl" + "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/proto/spire/common/plugin" + keymanagerv0 "github.com/spiffe/spire/proto/spire/plugin/server/keymanager/v0" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + pluginName = "aws_kms" + aliasPrefix = "alias/SPIRE_SERVER/" + + keyArnTag = "key_arn" + aliasNameTag = "alias_name" + reasonTag = "reason" + + refreshAliasesFrequency = time.Hour * 6 + disposeAliasesFrequency = time.Hour * 24 + aliasThreshold = time.Hour * 24 * 14 // two weeks + + disposeKeysFrequency = time.Hour * 48 + keyThreshold = time.Hour * 48 +) + +func BuiltIn() catalog.Plugin { + return builtin(New()) +} + +func builtin(p *Plugin) catalog.Plugin { + return catalog.MakePlugin(pluginName, keymanagerv0.PluginServer(p)) +} + +type keyEntry struct { + Arn string + AliasName string + PublicKey *keymanagerv0.PublicKey +} + +type pluginHooks struct { + newClient func(ctx context.Context, config *Config) (kmsClient, error) + clk clock.Clock + // just for testing + scheduleDeleteSignal chan error + refreshAliasesSignal chan error + disposeAliasesSignal chan error + disposeKeysSignal chan error +} + +// Plugin is the main representation of this keymanager plugin +type Plugin struct { + keymanagerv0.UnsafeKeyManagerServer + log hclog.Logger + mu sync.RWMutex + entries map[string]keyEntry + kmsClient kmsClient + trustDomain string + serverID string + scheduleDelete chan string + cancelTasks context.CancelFunc + hooks pluginHooks +} + +// Config provides configuration context for the plugin +type Config struct { + AccessKeyID string `hcl:"access_key_id" json:"access_key_id"` + SecretAccessKey string `hcl:"secret_access_key" json:"secret_access_key"` + Region string `hcl:"region" json:"region"` + KeyMetadataFile string `hcl:"key_metadata_file" json:"key_metadata_file"` +} + +// New returns an instantiated plugin +func New() *Plugin { + return newPlugin(newKMSClient) +} + +func newPlugin(newClient func(ctx context.Context, config *Config) (kmsClient, error)) *Plugin { + return &Plugin{ + entries: make(map[string]keyEntry), + hooks: pluginHooks{ + newClient: newClient, + clk: clock.New(), + }, + scheduleDelete: make(chan string, 120), + } +} + +// SetLogger sets a logger +func (p *Plugin) SetLogger(log hclog.Logger) { + p.log = log +} + +// Configure sets up the plugin +func (p *Plugin) Configure(ctx context.Context, req *plugin.ConfigureRequest) (*plugin.ConfigureResponse, error) { + config, err := parseAndValidateConfig(req.Configuration) + if err != nil { + return nil, err + } + + serverID, err := loadServerID(config.KeyMetadataFile) + if err != nil { + return nil, err + } + p.log.Debug("Loaded server id", "server_id", serverID) + + kc, err := p.hooks.newClient(ctx, config) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create KMS client: %v", err) + } + + fetcher := &keyFetcher{ + log: p.log, + kmsClient: kc, + serverID: serverID, + trustDomain: req.GlobalConfig.TrustDomain, + } + p.log.Debug("Fetching key aliases from KMS") + keyEntries, err := fetcher.fetchKeyEntries(ctx) + if err != nil { + return nil, err + } + + p.mu.Lock() + defer p.mu.Unlock() + + p.setCache(keyEntries) + p.kmsClient = kc + p.trustDomain = req.GlobalConfig.TrustDomain + p.serverID = serverID + + // cancels previous tasks in case of re configure + if p.cancelTasks != nil { + p.cancelTasks() + } + + // start tasks + ctx, p.cancelTasks = context.WithCancel(context.Background()) + go p.scheduleDeleteTask(ctx) + go p.refreshAliasesTask(ctx) + go p.disposeAliasesTask(ctx) + go p.disposeKeysTask(ctx) + + return &plugin.ConfigureResponse{}, nil +} + +// GenerateKey creates a key in KMS. If a key already exists in the local storage, it is updated. +func (p *Plugin) GenerateKey(ctx context.Context, req *keymanagerv0.GenerateKeyRequest) (*keymanagerv0.GenerateKeyResponse, error) { + if req.KeyId == "" { + return nil, status.Error(codes.InvalidArgument, "key id is required") + } + if req.KeyType == keymanagerv0.KeyType_UNSPECIFIED_KEY_TYPE { + return nil, status.Error(codes.InvalidArgument, "key type is required") + } + + p.mu.Lock() + defer p.mu.Unlock() + + spireKeyID := req.KeyId + newKeyEntry, err := p.createKey(ctx, spireKeyID, req.KeyType) + if err != nil { + return nil, err + } + + err = p.assignAlias(ctx, newKeyEntry) + if err != nil { + return nil, err + } + + p.entries[spireKeyID] = *newKeyEntry + + return &keymanagerv0.GenerateKeyResponse{ + PublicKey: newKeyEntry.PublicKey, + }, nil +} + +// SignData creates a digital signature for the data to be signed +func (p *Plugin) SignData(ctx context.Context, req *keymanagerv0.SignDataRequest) (*keymanagerv0.SignDataResponse, error) { + if req.KeyId == "" { + return nil, status.Error(codes.InvalidArgument, "key id is required") + } + if req.SignerOpts == nil { + return nil, status.Error(codes.InvalidArgument, "signer opts is required") + } + + p.mu.RLock() + defer p.mu.RUnlock() + + keyEntry, hasKey := p.entries[req.KeyId] + if !hasKey { + return nil, status.Errorf(codes.NotFound, "no such key %q", req.KeyId) + } + + signingAlgo, err := signingAlgorithmForKMS(keyEntry.PublicKey.Type, req.SignerOpts) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + signResp, err := p.kmsClient.Sign(ctx, &kms.SignInput{ + KeyId: &keyEntry.Arn, + Message: req.Data, + MessageType: types.MessageTypeDigest, + SigningAlgorithm: signingAlgo, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to sign: %v", err) + } + + return &keymanagerv0.SignDataResponse{Signature: signResp.Signature}, nil +} + +// GetPublicKey returns the public key for a given key +func (p *Plugin) GetPublicKey(ctx context.Context, req *keymanagerv0.GetPublicKeyRequest) (*keymanagerv0.GetPublicKeyResponse, error) { + if req.KeyId == "" { + return nil, status.Error(codes.InvalidArgument, "key id is required") + } + + p.mu.RLock() + defer p.mu.RUnlock() + + entry, ok := p.entries[req.KeyId] + if !ok { + return nil, status.Errorf(codes.NotFound, "no such key %q", req.KeyId) + } + + return &keymanagerv0.GetPublicKeyResponse{ + PublicKey: entry.PublicKey, + }, nil +} + +// GetPublicKeys return the publicKey for all the keys +func (p *Plugin) GetPublicKeys(context.Context, *keymanagerv0.GetPublicKeysRequest) (*keymanagerv0.GetPublicKeysResponse, error) { + var keys []*keymanagerv0.PublicKey + p.mu.RLock() + defer p.mu.RUnlock() + for _, key := range p.entries { + keys = append(keys, key.PublicKey) + } + + return &keymanagerv0.GetPublicKeysResponse{PublicKeys: keys}, nil +} + +// GetPluginInfo returns information about this plugin +func (p *Plugin) GetPluginInfo(context.Context, *plugin.GetPluginInfoRequest) (*plugin.GetPluginInfoResponse, error) { + return &plugin.GetPluginInfoResponse{}, nil +} + +func (p *Plugin) createKey(ctx context.Context, spireKeyID string, keyType keymanagerv0.KeyType) (*keyEntry, error) { + description := p.descriptionFromSpireKeyID(spireKeyID) + keySpec, ok := keySpecFromKeyType(keyType) + if !ok { + return nil, status.Errorf(codes.Internal, "unsupported key type: %v", keyType) + } + + createKeyInput := &kms.CreateKeyInput{ + Description: aws.String(description), + KeyUsage: types.KeyUsageTypeSignVerify, + CustomerMasterKeySpec: keySpec, + } + + key, err := p.kmsClient.CreateKey(ctx, createKeyInput) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create key: %v", err) + } + if key == nil || key.KeyMetadata == nil || key.KeyMetadata.Arn == nil { + return nil, status.Error(codes.Internal, "malformed create key response") + } + p.log.Debug("Key created", keyArnTag, *key.KeyMetadata.Arn) + + pub, err := p.kmsClient.GetPublicKey(ctx, &kms.GetPublicKeyInput{KeyId: key.KeyMetadata.Arn}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get public key: %v", err) + } + if pub == nil || pub.KeyId == nil || len(pub.PublicKey) == 0 { + return nil, status.Error(codes.Internal, "malformed get public key response") + } + + return &keyEntry{ + Arn: *key.KeyMetadata.Arn, + AliasName: p.aliasFromSpireKeyID(spireKeyID), + PublicKey: &keymanagerv0.PublicKey{ + Id: spireKeyID, + Type: keyType, + PkixData: pub.PublicKey, + }, + }, nil +} + +func (p *Plugin) assignAlias(ctx context.Context, entry *keyEntry) error { + oldEntry, hasOldEntry := p.entries[entry.PublicKey.Id] + + if !hasOldEntry { + // create alias + _, err := p.kmsClient.CreateAlias(ctx, &kms.CreateAliasInput{ + AliasName: aws.String(entry.AliasName), + TargetKeyId: &entry.Arn, + }) + if err != nil { + return status.Errorf(codes.Internal, "failed to create alias: %v", err) + } + p.log.Debug("Alias created", aliasNameTag, entry.AliasName, keyArnTag, entry.Arn) + } else { + // update alias + _, err := p.kmsClient.UpdateAlias(ctx, &kms.UpdateAliasInput{ + AliasName: aws.String(entry.AliasName), + TargetKeyId: &entry.Arn, + }) + if err != nil { + return status.Errorf(codes.Internal, "failed to update alias: %v", err) + } + p.log.Debug("Alias updated", aliasNameTag, entry.AliasName, keyArnTag, entry.Arn) + + select { + case p.scheduleDelete <- oldEntry.Arn: + p.log.Debug("Key enqueued for deletion", keyArnTag, oldEntry.Arn) + default: + p.log.Error("Failed to enqueue key for deletion", keyArnTag, oldEntry.Arn) + } + } + return nil +} + +func (p *Plugin) setCache(keyEntries []*keyEntry) { + // clean previous cache + p.entries = make(map[string]keyEntry) + + // add results to cache + for _, e := range keyEntries { + p.entries[e.PublicKey.Id] = *e + p.log.Debug("Key loaded", keyArnTag, e.Arn, aliasNameTag, e.AliasName) + } +} + +// scheduleDeleteTask ia a long running task that deletes keys that were rotated +func (p *Plugin) scheduleDeleteTask(ctx context.Context) { + backoffMin := 1 * time.Second + backoffMax := 60 * time.Second + backoff := backoffMin + + for { + select { + case <-ctx.Done(): + return + case keyArn := <-p.scheduleDelete: + log := p.log.With(keyArnTag, keyArn) + _, err := p.kmsClient.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{ + KeyId: aws.String(keyArn), + PendingWindowInDays: aws.Int32(7), + }) + + if err == nil { + log.Debug("Key deleted") + backoff = backoffMin + p.notifyDelete(nil) + continue + } + + var notFoundErr *types.NotFoundException + if errors.As(err, ¬FoundErr) { + log.Error("Failed to schedule key deletion", reasonTag, "No such key") + p.notifyDelete(err) + continue + } + + var invalidArnErr *types.InvalidArnException + if errors.As(err, &invalidArnErr) { + log.Error("Failed to schedule key deletion", reasonTag, "Invalid ARN") + p.notifyDelete(err) + continue + } + + var invalidState *types.KMSInvalidStateException + if errors.As(err, &invalidState) { + log.Error("Failed to schedule key deletion", reasonTag, "Key was on invalid state for deletion") + p.notifyDelete(err) + continue + } + + log.Error("It was not possible to schedule key for deletion", reasonTag, err) + select { + case p.scheduleDelete <- keyArn: + log.Debug("Key re-enqueued for deletion") + default: + log.Error("Failed to re-enqueue key for deletion") + } + p.notifyDelete(nil) + backoff = min(backoff*2, backoffMax) + p.hooks.clk.Sleep(backoff) + } + } +} + +// refreshAliasesTask will update the alias of all keys in the cache every 6 hours. +// Aliases will be updated to the same key they already have. +// The consequence of this is that the field LastUpdatedDate in each alias belonging to the server will be set to the current date. +// This is all with the goal of being able to detect keys that are not in use by any server. +func (p *Plugin) refreshAliasesTask(ctx context.Context) { + ticker := p.hooks.clk.Ticker(refreshAliasesFrequency) + defer ticker.Stop() + + p.notifyRefreshAliases(nil) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := p.refreshAliases(ctx) + p.notifyRefreshAliases(err) + } + } +} + +func (p *Plugin) refreshAliases(ctx context.Context) error { + p.log.Debug("Refreshing aliases") + p.mu.RLock() + defer p.mu.RUnlock() + var errs []string + for _, entry := range p.entries { + _, err := p.kmsClient.UpdateAlias(ctx, &kms.UpdateAliasInput{ + AliasName: &entry.AliasName, + TargetKeyId: &entry.Arn, + }) + if err != nil { + p.log.Error("Failed to refresh alias", aliasNameTag, entry.AliasName, keyArnTag, entry.Arn, reasonTag, err) + errs = append(errs, err.Error()) + } + } + + if errs != nil { + return fmt.Errorf(strings.Join(errs, ": ")) + } + return nil +} + +// disposeAliasesTask will be run every 24hs. +// It will delete aliases that have a LastUpdatedDate value older than two weeks. +// It will also delete the keys associated with them. +// It will only delete aliases belonging to the current trust domain but not the current server. +// disposeAliasesTask relies on how aliases are built with prefixes to do all this. +// Alias example: `alias/SPIRE_SERVER/{TRUST_DOMAIN}/{SERVER_ID}/{KEY_ID}` +func (p *Plugin) disposeAliasesTask(ctx context.Context) { + ticker := p.hooks.clk.Ticker(disposeAliasesFrequency) + defer ticker.Stop() + + p.notifyDisposeAliases(nil) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := p.disposeAliases(ctx) + p.notifyDisposeAliases(err) + } + } +} + +func (p *Plugin) disposeAliases(ctx context.Context) error { + p.log.Debug("Looking for aliases in trust domain to dispose") + paginator := kms.NewListAliasesPaginator(p.kmsClient, &kms.ListAliasesInput{Limit: aws.Int32(100)}) + var errs []string + + for { + aliasesResp, err := paginator.NextPage(ctx) + switch { + case err != nil: + p.log.Error("Failed to fetch aliases to dispose", reasonTag, err) + return err + case aliasesResp == nil: + p.log.Error("Failed to fetch aliases to dispose: nil response") + return err + } + + for _, alias := range aliasesResp.Aliases { + switch { + case alias.AliasName == nil || alias.LastUpdatedDate == nil || alias.AliasArn == nil: + continue + // if alias does not belong to trust domain skip + case !strings.HasPrefix(*alias.AliasName, p.aliasPrefixForTrustDomain()): + continue + // if alias belongs to current server skip + case strings.HasPrefix(*alias.AliasName, p.aliasPrefixForServer()): + continue + } + + now := p.hooks.clk.Now() + diff := now.Sub(*alias.LastUpdatedDate) + if diff < aliasThreshold { + continue + } + log := p.log.With(aliasNameTag, alias.AliasName) + log.Debug("Found alias in trust domain beyond threshold") + + describeResp, err := p.kmsClient.DescribeKey(ctx, &kms.DescribeKeyInput{KeyId: alias.AliasArn}) + switch { + case err != nil: + log.Error("Failed to clean up old KMS keys.", reasonTag, fmt.Errorf("AWS API DescribeKey failed: %v", err)) + errs = append(errs, err.Error()) + continue + case describeResp == nil || describeResp.KeyMetadata == nil || describeResp.KeyMetadata.Arn == nil: + log.Error("Failed to clean up old KMS keys", reasonTag, "Missing data in AWS API DescribeKey response") + continue + case !describeResp.KeyMetadata.Enabled: + continue + } + log = log.With(keyArnTag, *describeResp.KeyMetadata.Arn) + + _, err = p.kmsClient.DeleteAlias(ctx, &kms.DeleteAliasInput{AliasName: alias.AliasName}) + if err != nil { + log.Error("Failed to clean up old KMS keys.", reasonTag, fmt.Errorf("AWS API DeleteAlias failed: %v", err)) + errs = append(errs, err.Error()) + continue + } + + select { + case p.scheduleDelete <- *describeResp.KeyMetadata.Arn: + log.Debug("Key enqueued for deletion") + default: + log.Error("Failed to enqueue key for deletion") + } + } + + if !paginator.HasMorePages() { + break + } + } + + if errs != nil { + return fmt.Errorf(strings.Join(errs, ": ")) + } + + return nil +} + +// disposeKeysTask will be run every 48hs. +// It will delete keys that have a CreationDate value older than 48hs. +// It will only delete keys belonging to the current trust domain and without an alias. +// disposeKeysTask relies on how the keys description is built to do all this. +// Key description example: `SPIRE_SERVER/{TRUST_DOMAIN}` +// Keys belonging to a server should never be without an alias. +// The goal of this task is to remove keys that ended in this invalid state during a failure on alias assignment. +func (p *Plugin) disposeKeysTask(ctx context.Context) { + ticker := p.hooks.clk.Ticker(disposeKeysFrequency) + defer ticker.Stop() + + p.notifyDisposeKeys(nil) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := p.disposeKeys(ctx) + p.notifyDisposeKeys(err) + } + } +} + +func (p *Plugin) disposeKeys(ctx context.Context) error { + p.log.Debug("Looking for keys in trust domain to dispose") + paginator := kms.NewListKeysPaginator(p.kmsClient, &kms.ListKeysInput{Limit: aws.Int32(1000)}) + var errs []string + + for { + keysResp, err := paginator.NextPage(ctx) + switch { + case err != nil: + p.log.Error("Failed to fetch keys to dispose", reasonTag, err) + return err + case keysResp == nil: + p.log.Error("Failed to fetch keys to dispose: nil response") + return err + } + + for _, key := range keysResp.Keys { + if key.KeyArn == nil { + continue + } + + log := p.log.With(keyArnTag, key.KeyArn) + + describeResp, err := p.kmsClient.DescribeKey(ctx, &kms.DescribeKeyInput{KeyId: key.KeyArn}) + switch { + case err != nil: + log.Error("Failed to describe key to dispose", reasonTag, err) + errs = append(errs, err.Error()) + continue + case describeResp == nil || + describeResp.KeyMetadata == nil || + describeResp.KeyMetadata.Description == nil || + describeResp.KeyMetadata.CreationDate == nil: + log.Error("Malformed describe key response while trying to dispose") + continue + case !describeResp.KeyMetadata.Enabled: + continue + } + + // if key does not belong to trust domain, skip it + if *describeResp.KeyMetadata.Description != p.descriptionPrefixForTrustDomain() { + continue + } + + // if key has alias, skip it + aliasesResp, err := p.kmsClient.ListAliases(ctx, &kms.ListAliasesInput{KeyId: key.KeyArn, Limit: aws.Int32(1)}) + switch { + case err != nil: + log.Error("Failed to fetch alias for key", reasonTag, err) + errs = append(errs, err.Error()) + continue + case aliasesResp == nil || len(aliasesResp.Aliases) > 0: + continue + } + + now := p.hooks.clk.Now() + diff := now.Sub(*describeResp.KeyMetadata.CreationDate) + if diff < keyThreshold { + continue + } + + log.Debug("Found key in trust domain beyond threshold") + + select { + case p.scheduleDelete <- *describeResp.KeyMetadata.Arn: + log.Debug("Key enqueued for deletion") + default: + log.Error("Failed to enqueue key for deletion") + } + } + + if !paginator.HasMorePages() { + break + } + } + if errs != nil { + return fmt.Errorf(strings.Join(errs, ": ")) + } + + return nil +} + +func (p *Plugin) aliasFromSpireKeyID(spireKeyID string) string { + return path.Join(p.aliasPrefixForServer(), spireKeyID) +} + +func (p *Plugin) descriptionFromSpireKeyID(spireKeyID string) string { + return path.Join(p.descriptionPrefixForTrustDomain(), spireKeyID) +} + +func (p *Plugin) descriptionPrefixForTrustDomain() string { + trustDomain := sanitizeTrustDomain(p.trustDomain) + return path.Join("SPIRE_SERVER_KEY/", trustDomain) +} + +func (p *Plugin) aliasPrefixForServer() string { + return path.Join(p.aliasPrefixForTrustDomain(), p.serverID) +} + +func (p *Plugin) aliasPrefixForTrustDomain() string { + trustDomain := sanitizeTrustDomain(p.trustDomain) + return path.Join(aliasPrefix, trustDomain) +} + +func (p *Plugin) notifyDelete(err error) { + if p.hooks.scheduleDeleteSignal != nil { + p.hooks.scheduleDeleteSignal <- err + } +} + +func (p *Plugin) notifyRefreshAliases(err error) { + if p.hooks.refreshAliasesSignal != nil { + p.hooks.refreshAliasesSignal <- err + } +} + +func (p *Plugin) notifyDisposeAliases(err error) { + if p.hooks.disposeAliasesSignal != nil { + p.hooks.disposeAliasesSignal <- err + } +} + +func (p *Plugin) notifyDisposeKeys(err error) { + if p.hooks.disposeKeysSignal != nil { + p.hooks.disposeKeysSignal <- err + } +} + +func sanitizeTrustDomain(trustDomain string) string { + return strings.Replace(trustDomain, ".", "_", -1) +} + +// parseAndValidateConfig returns an error if any configuration provided does not meet acceptable criteria +func parseAndValidateConfig(c string) (*Config, error) { + config := new(Config) + + if err := hcl.Decode(config, c); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + } + + if config.Region == "" { + return nil, status.Error(codes.InvalidArgument, "configuration is missing a region") + } + + if config.KeyMetadataFile == "" { + return nil, status.Error(codes.InvalidArgument, "configuration is missing server id file path") + } + + return config, nil +} + +func signingAlgorithmForKMS(keyType keymanagerv0.KeyType, signerOpts interface{}) (types.SigningAlgorithmSpec, error) { + var ( + hashAlgo keymanagerv0.HashAlgorithm + isPSS bool + ) + + switch opts := signerOpts.(type) { + case *keymanagerv0.SignDataRequest_HashAlgorithm: + hashAlgo = opts.HashAlgorithm + isPSS = false + case *keymanagerv0.SignDataRequest_PssOptions: + if opts.PssOptions == nil { + return "", errors.New("PSS options are required") + } + hashAlgo = opts.PssOptions.HashAlgorithm + isPSS = true + // opts.PssOptions.SaltLength is handled by KMS. The salt length matches the bits of the hashing algorithm. + default: + return "", fmt.Errorf("unsupported signer opts type %T", opts) + } + + isRSA := keyType == keymanagerv0.KeyType_RSA_2048 || keyType == keymanagerv0.KeyType_RSA_4096 + + switch { + case hashAlgo == keymanagerv0.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM: + return "", errors.New("hash algorithm is required") + case keyType == keymanagerv0.KeyType_EC_P256 && hashAlgo == keymanagerv0.HashAlgorithm_SHA256: + return types.SigningAlgorithmSpecEcdsaSha256, nil + case keyType == keymanagerv0.KeyType_EC_P384 && hashAlgo == keymanagerv0.HashAlgorithm_SHA384: + return types.SigningAlgorithmSpecEcdsaSha384, nil + case isRSA && !isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA256: + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil + case isRSA && !isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA384: + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil + case isRSA && !isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA512: + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil + case isRSA && isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA256: + return types.SigningAlgorithmSpecRsassaPssSha256, nil + case isRSA && isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA384: + return types.SigningAlgorithmSpecRsassaPssSha384, nil + case isRSA && isPSS && hashAlgo == keymanagerv0.HashAlgorithm_SHA512: + return types.SigningAlgorithmSpecRsassaPssSha512, nil + default: + return "", fmt.Errorf("unsupported combination of keytype: %v and hashing algorithm: %v", keyType, hashAlgo) + } +} + +func keyTypeFromKeySpec(keySpec types.CustomerMasterKeySpec) (keymanagerv0.KeyType, bool) { + switch keySpec { + case types.CustomerMasterKeySpecRsa2048: + return keymanagerv0.KeyType_RSA_2048, true + case types.CustomerMasterKeySpecRsa4096: + return keymanagerv0.KeyType_RSA_4096, true + case types.CustomerMasterKeySpecEccNistP256: + return keymanagerv0.KeyType_EC_P256, true + case types.CustomerMasterKeySpecEccNistP384: + return keymanagerv0.KeyType_EC_P384, true + default: + return keymanagerv0.KeyType_UNSPECIFIED_KEY_TYPE, false + } +} + +func keySpecFromKeyType(keyType keymanagerv0.KeyType) (types.CustomerMasterKeySpec, bool) { + switch keyType { + case keymanagerv0.KeyType_RSA_2048: + return types.CustomerMasterKeySpecRsa2048, true + case keymanagerv0.KeyType_RSA_4096: + return types.CustomerMasterKeySpecRsa4096, true + case keymanagerv0.KeyType_EC_P256: + return types.CustomerMasterKeySpecEccNistP256, true + case keymanagerv0.KeyType_EC_P384: + return types.CustomerMasterKeySpecEccNistP384, true + default: + return "", false + } +} + +func min(x, y time.Duration) time.Duration { + if x < y { + return x + } + return y +} + +func loadServerID(idPath string) (string, error) { + // get id from path + data, err := ioutil.ReadFile(idPath) + switch { + case errors.Is(err, os.ErrNotExist): + return createServerID(idPath) + case err != nil: + return "", status.Errorf(codes.Internal, "failed to read server id from path: %v", err) + } + + // validate what we got is a uuid + serverID, err := uuid.FromString(string(data)) + if err != nil { + return "", status.Errorf(codes.Internal, "failed to parse server id from path: %v", err) + } + return serverID.String(), nil +} + +func createServerID(idPath string) (string, error) { + // generate id + u, err := uuid.NewV4() + if err != nil { + return "", status.Errorf(codes.Internal, "failed to generate id for server: %v", err) + } + id := u.String() + + // persist id + err = ioutil.WriteFile(idPath, []byte(id), 0600) + if err != nil { + return "", status.Errorf(codes.Internal, "failed to persist server id on path: %v", err) + } + return id, nil +} diff --git a/pkg/server/plugin/keymanager/awskms/awskms_test.go b/pkg/server/plugin/keymanager/awskms/awskms_test.go new file mode 100644 index 0000000000..2dbdb78cb7 --- /dev/null +++ b/pkg/server/plugin/keymanager/awskms/awskms_test.go @@ -0,0 +1,1830 @@ +package awskms + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "path" + "testing" + "time" + + "github.com/andres-erbsen/clock" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/proto/spire/common/plugin" + keymanagerv0 "github.com/spiffe/spire/proto/spire/plugin/server/keymanager/v0" + "github.com/spiffe/spire/test/spiretest" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" +) + +const ( + // Defaults used for testing + validAccessKeyID = "AKIAIOSFODNN7EXAMPLE" + validSecretAccessKey = "secret" + validRegion = "us-west-2" + validServerIDFile = "server_id_test" + validServerID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + keyID = "abcd-fghi" + KeyArn = "arn:aws:kms:region:1234:key/abcd-fghi" + aliasName = "alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/spireKeyID" + spireKeyID = "spireKeyID" + testTimeout = 60 * time.Second +) + +var ( + ctx = context.Background() + unixEpoch = time.Unix(0, 0) + refreshedDate = unixEpoch.Add(6 * time.Hour) +) + +type pluginTest struct { + plugin *Plugin + fakeClient *kmsClientFake + logHook *test.Hook + clockHook *clock.Mock +} + +func setupTest(t *testing.T) *pluginTest { + log, logHook := test.NewNullLogger() + log.Level = logrus.DebugLevel + + c := clock.NewMock() + fakeClient := newKMSClientFake(t, c) + kmsPlugin := newPlugin(func(ctx context.Context, c *Config) (kmsClient, error) { + return fakeClient, nil + }) + kmsCatalog := catalog.MakePlugin(pluginName, keymanagerv0.PluginServer(kmsPlugin)) + var km keymanagerv0.KeyManager + spiretest.LoadPlugin(t, kmsCatalog, &km, spiretest.Logger(log)) + + kmsPlugin.hooks.clk = c + + return &pluginTest{ + plugin: kmsPlugin, + fakeClient: fakeClient, + logHook: logHook, + clockHook: c, + } +} + +func TestConfigure(t *testing.T) { + for _, tt := range []struct { + name string + err string + code codes.Code + configureRequest *plugin.ConfigureRequest + fakeEntries []fakeKeyEntry + listAliasesErr string + describeKeyErr string + getPublicKeyErr string + }{ + + { + name: "pass with keys", + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + }, + { + AliasName: aws.String(aliasName + "01"), + KeyID: aws.String(keyID + "01"), + KeySpec: types.CustomerMasterKeySpecRsa2048, + Enabled: true, + PublicKey: []byte("foo"), + }, + { + AliasName: aws.String(aliasName + "02"), + KeyID: aws.String(keyID + "02"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + }, + { + AliasName: aws.String(aliasName + "03"), + KeyID: aws.String(keyID + "03"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + { + AliasName: aws.String(aliasName + "04"), + KeyID: aws.String(keyID + "04"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/wrong_prefix"), + KeyID: aws.String("foo_id"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + }, + { + name: "pass without keys", + configureRequest: configureRequestWithDefaults(t), + }, + { + name: "missing access key id", + configureRequest: configureRequestWithVars("", "secret_access_key", "region", getKeyMetadataFile(t)), + }, + { + name: "missing secret access key", + configureRequest: configureRequestWithVars("access_key", "", "region", getKeyMetadataFile(t)), + }, + { + name: "missing region", + configureRequest: configureRequestWithVars("access_key_id", "secret_access_key", "", getKeyMetadataFile(t)), + err: "configuration is missing a region", + code: codes.InvalidArgument, + }, + { + name: "missing server id file path", + configureRequest: configureRequestWithVars("access_key_id", "secret_access_key", "region", ""), + err: "configuration is missing server id file path", + code: codes.InvalidArgument, + }, + { + name: "new server id file path", + configureRequest: configureRequestWithVars("access_key_id", "secret_access_key", "region", getEmptyKeyMetadataFile(t)), + }, + { + name: "decode error", + configureRequest: configureRequestWithString("{ malformed json }"), + err: "unable to decode configuration: 1:11: illegal char", + code: codes.InvalidArgument, + }, + { + name: "list aliases error", + err: "failed to fetch aliases: fake list aliases error", + code: codes.Internal, + configureRequest: configureRequestWithDefaults(t), + listAliasesErr: "fake list aliases error", + }, + { + name: "describe key error", + err: "failed to describe key: describe key error", + code: codes.Internal, + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa2048, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + describeKeyErr: "describe key error", + }, + { + name: "unsupported key error", + err: "unsupported key spec: unsupported key spec", + code: codes.Internal, + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: "unsupported key spec", + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + }, + { + name: "get public key error", + err: "failed to fetch aliases: failed to get public key: get public key error", + code: codes.Internal, + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + getPublicKeyErr: "get public key error", + }, + + { + name: "disabled key", + err: "failed to fetch aliases: found disabled SPIRE key: \"arn:aws:kms:region:1234:key/abcd-fghi\", alias: \"arn:aws:kms:region:1234:alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/spireKeyID\"", + code: codes.FailedPrecondition, + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: false, + PublicKey: []byte("foo"), + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + ts.fakeClient.setListAliasesErr(tt.listAliasesErr) + ts.fakeClient.setDescribeKeyErr(tt.describeKeyErr) + ts.fakeClient.setgetPublicKeyErr(tt.getPublicKeyErr) + + // exercise + _, err := ts.plugin.Configure(ctx, tt.configureRequest) + + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + return + } + + require.NoError(t, err) + }) + } +} + +func TestGenerateKey(t *testing.T) { + for _, tt := range []struct { + name string + err string + code codes.Code + logs []spiretest.LogEntry + waitForDelete bool + fakeEntries []fakeKeyEntry + request *keymanagerv0.GenerateKeyRequest + createKeyErr string + getPublicKeyErr string + scheduleKeyDeletionErr error + createAliasErr string + updateAliasErr string + }{ + { + name: "success: non existing key", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "success: replace old key", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + AliasLastUpdatedDate: &unixEpoch, + }, + }, + waitForDelete: true, + logs: []spiretest.LogEntry{ + { + Level: logrus.DebugLevel, + Message: "Key deleted", + Data: logrus.Fields{ + keyArnTag: KeyArn, + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + }, + }, + { + name: "success: EC 384", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P384, + }, + }, + { + name: "failure unsupported key spec", + err: "unsupported key type: RSA_1024", + code: codes.Internal, + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_1024, + }, + }, + { + name: "success: RSA 2048", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "success: RSA 4096", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_4096, + }, + }, + { + name: "missing key id", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: "", + KeyType: keymanagerv0.KeyType_EC_P256, + }, + err: "key id is required", + code: codes.InvalidArgument, + }, + { + name: "missing key type", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_UNSPECIFIED_KEY_TYPE, + }, + err: "key type is required", + code: codes.InvalidArgument, + }, + { + name: "create key error", + err: "failed to create key: something went wrong", + code: codes.Internal, + createKeyErr: "something went wrong", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "create alias error", + err: "failed to create alias: something went wrong", + code: codes.Internal, + createAliasErr: "something went wrong", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "update alias error", + err: "failed to update alias: something went wrong", + code: codes.Internal, + updateAliasErr: "something went wrong", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + }, + { + name: "get public key error", + err: "failed to get public key: public key error", + code: codes.Internal, + getPublicKeyErr: "public key error", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "schedule delete not found error", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + scheduleKeyDeletionErr: &types.NotFoundException{Message: aws.String("not found")}, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + waitForDelete: true, + logs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to schedule key deletion", + Data: logrus.Fields{ + reasonTag: "No such key", + keyArnTag: KeyArn, + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + }, + }, + { + name: "invalid arn error", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + scheduleKeyDeletionErr: &types.InvalidArnException{Message: aws.String("invalid arn")}, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + waitForDelete: true, + logs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to schedule key deletion", + Data: logrus.Fields{ + reasonTag: "Invalid ARN", + keyArnTag: KeyArn, + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + }, + }, + + { + name: "invalid key state error", + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + scheduleKeyDeletionErr: &types.KMSInvalidStateException{Message: aws.String("invalid state")}, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + waitForDelete: true, + logs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to schedule key deletion", + Data: logrus.Fields{ + reasonTag: "Key was on invalid state for deletion", + keyArnTag: KeyArn, + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + }, + }, + + { + name: "schedule key deletion error", + scheduleKeyDeletionErr: errors.New("schedule key deletion error"), + request: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + waitForDelete: true, + logs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "It was not possible to schedule key for deletion", + Data: logrus.Fields{ + keyArnTag: KeyArn, + "reason": "schedule key deletion error", + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + { + Level: logrus.DebugLevel, + Message: "Key re-enqueued for deletion", + Data: logrus.Fields{ + keyArnTag: KeyArn, + "subsystem_name": "built-in_plugin.aws_kms", + }, + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + ts.fakeClient.setCreateKeyErr(tt.createKeyErr) + ts.fakeClient.setCreateAliasesErr(tt.createAliasErr) + ts.fakeClient.setUpdateAliasErr(tt.updateAliasErr) + ts.fakeClient.setScheduleKeyDeletionErr(tt.scheduleKeyDeletionErr) + deleteSignal := make(chan error) + ts.plugin.hooks.scheduleDeleteSignal = deleteSignal + + _, err := ts.plugin.Configure(ctx, configureRequestWithDefaults(t)) + require.NoError(t, err) + + ts.fakeClient.setgetPublicKeyErr(tt.getPublicKeyErr) + + // exercise + resp, err := ts.plugin.GenerateKey(ctx, tt.request) + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + return + } + + require.NoError(t, err) + require.NotNil(t, resp) + + if !tt.waitForDelete { + return + } + + select { + case <-deleteSignal: + spiretest.AssertLastLogs(t, ts.logHook.AllEntries(), tt.logs) + case <-time.After(testTimeout): + t.Fail() + } + }) + } +} + +func TestSignData(t *testing.T) { + for _, tt := range []struct { + name string + request *keymanagerv0.SignDataRequest + generateKeyRequest *keymanagerv0.GenerateKeyRequest + err string + code codes.Code + signDataError string + }{ + { + name: "pass EC SHA256", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "pass EC SHA384", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA384, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P384, + }, + }, + { + name: "pass RSA 2048 SHA 256", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA 2048 SHA 384", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA384, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA 2048 SHA 512", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA512, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA PSS 2048 SHA 256", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_PssOptions{ + PssOptions: &keymanagerv0.PSSOptions{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + SaltLength: 256, + }, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA PSS 2048 SHA 384", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_PssOptions{ + PssOptions: &keymanagerv0.PSSOptions{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA384, + SaltLength: 384, + }, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA PSS 2048 SHA 512", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_PssOptions{ + PssOptions: &keymanagerv0.PSSOptions{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA512, + SaltLength: 512, + }, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "pass RSA 4096 SHA 256", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_4096, + }, + }, + { + name: "pass RSA PSS 4096 SHA 256", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_PssOptions{ + PssOptions: &keymanagerv0.PSSOptions{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + SaltLength: 256, + }, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_4096, + }, + }, + { + name: "missing key id", + request: &keymanagerv0.SignDataRequest{ + KeyId: "", + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + err: "key id is required", + code: codes.InvalidArgument, + }, + { + name: "missing key signer opts", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + }, + err: "signer opts is required", + code: codes.InvalidArgument, + }, + { + name: "missing hash algorithm", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM, + }, + }, + err: "hash algorithm is required", + code: codes.InvalidArgument, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "unsupported combination", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA512, + }, + }, + err: "unsupported combination of keytype: EC_P256 and hashing algorithm: SHA512", + code: codes.InvalidArgument, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + { + name: "non existing key", + request: &keymanagerv0.SignDataRequest{ + KeyId: "does_not_exists", + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + err: "no such key \"does_not_exists\"", + code: codes.NotFound, + }, + { + name: "pss options nil", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_PssOptions{ + PssOptions: nil, + }, + }, + err: "PSS options are required", + code: codes.InvalidArgument, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_RSA_2048, + }, + }, + { + name: "sign error", + err: "failed to sign: sign error", + code: codes.Internal, + signDataError: "sign error", + request: &keymanagerv0.SignDataRequest{ + KeyId: spireKeyID, + Data: []byte("data"), + SignerOpts: &keymanagerv0.SignDataRequest_HashAlgorithm{ + HashAlgorithm: keymanagerv0.HashAlgorithm_SHA256, + }, + }, + generateKeyRequest: &keymanagerv0.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv0.KeyType_EC_P256, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setSignDataErr(tt.signDataError) + _, err := ts.plugin.Configure(ctx, configureRequestWithDefaults(t)) + require.NoError(t, err) + if tt.generateKeyRequest != nil { + _, err := ts.plugin.GenerateKey(ctx, tt.generateKeyRequest) + require.NoError(t, err) + } + + // exercise + resp, err := ts.plugin.SignData(ctx, tt.request) + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + return + } + + require.NotNil(t, resp) + require.NoError(t, err) + }) + } +} + +func TestGetPublicKey(t *testing.T) { + for _, tt := range []struct { + name string + err string + code codes.Code + fakeEntries []fakeKeyEntry + + keyID string + }{ + { + name: "existing key", + keyID: spireKeyID, + fakeEntries: []fakeKeyEntry{ + + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + }, + { + name: "non existing key", + err: "no such key \"spireKeyID\"", + code: codes.NotFound, + keyID: spireKeyID, + }, + { + name: "missing key id", + err: "key id is required", + code: codes.InvalidArgument, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + + _, err := ts.plugin.Configure(ctx, configureRequestWithDefaults(t)) + require.NoError(t, err) + + // exercise + resp, err := ts.plugin.GetPublicKey(ctx, &keymanagerv0.GetPublicKeyRequest{ + KeyId: tt.keyID, + }) + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + return + } + require.NotNil(t, resp) + require.NoError(t, err) + }) + } +} + +func TestGetPublicKeys(t *testing.T) { + for _, tt := range []struct { + name string + err string + fakeEntries []fakeKeyEntry + }{ + { + name: "existing key", + fakeEntries: []fakeKeyEntry{ + + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + }, + }, + }, + { + name: "non existing keys", + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + _, err := ts.plugin.Configure(ctx, configureRequestWithDefaults(t)) + require.NoError(t, err) + + // exercise + resp, err := ts.plugin.GetPublicKeys(ctx, &keymanagerv0.GetPublicKeysRequest{}) + + if tt.err != "" { + require.Error(t, err) + require.Equal(t, err.Error(), tt.err) + return + } + + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, len(tt.fakeEntries), len(resp.PublicKeys)) + }) + } +} + +func TestGetPluginInfo(t *testing.T) { + for _, tt := range []struct { + name string + err string + + aliases []types.AliasListEntry + }{ + { + name: "pass", + aliases: []types.AliasListEntry{ + { + AliasName: aws.String(aliasName), + TargetKeyId: aws.String(keyID), + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + //setup + ts := setupTest(t) + + //exercise + resp, err := ts.plugin.GetPluginInfo(ctx, &plugin.GetPluginInfoRequest{}) + + require.NotNil(t, resp) + require.NoError(t, err) + }) + } +} + +func TestRefreshAliases(t *testing.T) { + for _, tt := range []struct { + name string + configureRequest *plugin.ConfigureRequest + err string + fakeEntries []fakeKeyEntry + expectedEntries []fakeKeyEntry + updateAliasErr string + }{ + { + name: "refresh aliases error", + configureRequest: configureRequestWithDefaults(t), + err: "update failure", + updateAliasErr: "update failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + { + name: "refresh aliases succeeds", + configureRequest: configureRequestWithDefaults(t), + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_02"), + KeyID: aws.String("key_id_02"), + KeySpec: types.CustomerMasterKeySpecRsa2048, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_03"), + KeyID: aws.String("key_id_03"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_04"), + KeyID: aws.String("key_id_04"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/another_server_id/id_05"), + KeyID: aws.String("key_id_05"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated"), + KeyID: aws.String("key_id_06"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated/unrelated/id_07"), + KeyID: aws.String("key_id_07"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_08"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + + expectedEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + AliasLastUpdatedDate: &refreshedDate, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_02"), + KeyID: aws.String("key_id_02"), + AliasLastUpdatedDate: &refreshedDate, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_03"), + KeyID: aws.String("key_id_03"), + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_04"), + KeyID: aws.String("key_id_04"), + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/another_server_id/id_05"), + KeyID: aws.String("key_id_05"), + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated"), + KeyID: aws.String("key_id_06"), + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated/unrelated/id_07"), + KeyID: aws.String("key_id_07"), + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_08"), + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + ts.fakeClient.setUpdateAliasErr(tt.updateAliasErr) + refreshAliasesSignal := make(chan error) + ts.plugin.hooks.refreshAliasesSignal = refreshAliasesSignal + + // exercise + _, err := ts.plugin.Configure(ctx, tt.configureRequest) + require.NoError(t, err) + + // wait for refresh alias task to be initialized + _ = waitForSignal(t, refreshAliasesSignal) + // move the clock forward so the task is run + ts.clockHook.Add(6 * time.Hour) + // wait for refresh aliases to be run + err = waitForSignal(t, refreshAliasesSignal) + + // assert + if tt.updateAliasErr != "" { + require.NotNil(t, err) + require.Equal(t, tt.err, err.Error()) + return + } + + require.NoError(t, err) + storedAliases := ts.fakeClient.store.aliases + require.Len(t, storedAliases, 7) + storedKeys := ts.fakeClient.store.keyEntries + require.Len(t, storedKeys, len(tt.expectedEntries)) + for _, expected := range tt.expectedEntries { + if expected.AliasName == nil { + continue + } + // check aliases + alias, ok := storedAliases[*expected.AliasName] + require.True(t, ok, "Expected alias was not present on end result: %q", *expected.AliasName) + require.EqualValues(t, expected.AliasLastUpdatedDate.String(), alias.KeyEntry.AliasLastUpdatedDate.String(), *expected.AliasName) + + // check keys + key, ok := storedKeys[*expected.KeyID] + require.True(t, ok, "Expected alias was not present on end result: %q", *expected.KeyID) + require.EqualValues(t, expected.AliasLastUpdatedDate.String(), key.AliasLastUpdatedDate.String(), *expected.KeyID) + } + }) + } +} + +func TestDisposeAliases(t *testing.T) { + for _, tt := range []struct { + name string + configureRequest *plugin.ConfigureRequest + err string + fakeEntries []fakeKeyEntry + expectedEntries []fakeKeyEntry + listAliasesErr string + describeKeyErr string + deleteAliasErr string + }{ + { + name: "dispose aliases succeeds", + configureRequest: configureRequestWithDefaults(t), + + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_02"), + KeyID: aws.String("key_id_02"), + KeySpec: types.CustomerMasterKeySpecRsa2048, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_03"), + KeyID: aws.String("key_id_03"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_04"), + KeyID: aws.String("key_id_04"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/another_server/id_05"), + KeyID: aws.String("key_id_05"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated"), + KeyID: aws.String("key_id_06"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated/unrelated/id_07"), + KeyID: aws.String("key_id_07"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_08"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_09"), + KeyID: aws.String("key_id_09"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: false, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + + expectedEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_02"), + KeyID: aws.String("key_id_02"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_04"), + KeyID: aws.String("key_id_04"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/another_server/id_05"), + KeyID: aws.String("key_id_05"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated"), + KeyID: aws.String("key_id_06"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated/unrelated/id_07"), + KeyID: aws.String("key_id_07"), + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_09"), + KeyID: aws.String("key_id_09"), + }, + }, + }, + { + name: "list aliases error", + configureRequest: configureRequestWithDefaults(t), + err: "list aliases failure", + listAliasesErr: "list aliases failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + { + name: "describe key error", + configureRequest: configureRequestWithDefaults(t), + err: "describe key failure", + describeKeyErr: "describe key failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + { + name: "delete alias error", + configureRequest: configureRequestWithDefaults(t), + err: "delete alias failure", + deleteAliasErr: "delete alias failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server/id_01"), + KeyID: aws.String("key_id_01"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + // this is so dispose keys blocks on init and allows to test dispose aliases isolated + ts.plugin.hooks.disposeKeysSignal = make(chan error) + disposeAliasesSignal := make(chan error) + ts.plugin.hooks.disposeAliasesSignal = disposeAliasesSignal + deleteSignal := make(chan error) + ts.plugin.hooks.scheduleDeleteSignal = deleteSignal + + // exercise + _, err := ts.plugin.Configure(ctx, tt.configureRequest) + require.NoError(t, err) + + ts.fakeClient.setListAliasesErr(tt.listAliasesErr) + ts.fakeClient.setDescribeKeyErr(tt.describeKeyErr) + ts.fakeClient.setDeleteAliasErr(tt.deleteAliasErr) + + // wait for dispose aliases task to be initialized + _ = waitForSignal(t, disposeAliasesSignal) + // move the clock forward so the task is run + ts.clockHook.Add(aliasThreshold) + // wait for dispose aliases to be run + // first run at 24hs won't dispose keys due to threshold being two weeks + _ = waitForSignal(t, disposeAliasesSignal) + // wait for dispose aliases to be run + err = waitForSignal(t, disposeAliasesSignal) + //assert errors + if tt.err != "" { + require.NotNil(t, err) + require.Equal(t, tt.err, err.Error()) + return + } + // wait for schedule delete to be run + _ = waitForSignal(t, deleteSignal) + //assert end result + require.NoError(t, err) + storedAliases := ts.fakeClient.store.aliases + require.Len(t, storedAliases, 7) + storedKeys := ts.fakeClient.store.keyEntries + require.Len(t, storedKeys, 8) + + for _, expected := range tt.expectedEntries { + if expected.AliasName == nil { + continue + } + // check aliases + _, ok := storedAliases[*expected.AliasName] + require.True(t, ok, "Expected alias was not present on end result: %q", *expected.AliasName) + // check keys + _, ok = storedKeys[*expected.KeyID] + require.True(t, ok, "Expected alias was not present on end result: %q", *expected.KeyID) + } + }) + } +} + +func TestDisposeKeys(t *testing.T) { + for _, tt := range []struct { + name string + configureRequest *plugin.ConfigureRequest + err string + fakeEntries []fakeKeyEntry + expectedEntries []fakeKeyEntry + listKeysErr string + describeKeyErr string + listAliasesErr string + }{ + { + name: "dispose keys succeeds", + configureRequest: configureRequestWithDefaults(t), + + fakeEntries: []fakeKeyEntry{ + { + AliasName: nil, + KeyID: aws.String("key_id_01"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_02"), + KeyID: aws.String("key_id_02"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa2048, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/another_server_id/id_03"), + KeyID: aws.String("key_id_03"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_04"), + KeyID: aws.String("key_id_04"), + Description: aws.String("SPIRE_SERVER_KEY/another_td"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/another_td/another_server_id/id_05"), + KeyID: aws.String("key_id_05"), + Description: aws.String("SPIRE_SERVER_KEY/another_td"), + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated"), + KeyID: aws.String("key_id_06"), + Description: nil, + KeySpec: types.CustomerMasterKeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/unrelated/unrelated/id_07"), + KeyID: aws.String("key_id_07"), + Description: nil, + KeySpec: types.CustomerMasterKeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_08"), + Description: nil, + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: aws.String("alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/id_01"), + KeyID: aws.String("key_id_09"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_10"), + Description: aws.String("SPIRE_SERVER_KEY/another_td"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_11"), + Description: aws.String("SPIRE_SERVER_KEY/"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_12"), + Description: aws.String("SPIRE_SERVER_KEY"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_13"), + Description: aws.String("test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_14"), + Description: aws.String("unrelated"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_15"), + Description: aws.String("disabled key"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: false, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + { + AliasName: nil, + KeyID: aws.String("key_id_16"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org/extra"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + + expectedEntries: []fakeKeyEntry{ + { + KeyID: aws.String("key_id_02"), + }, + { + KeyID: aws.String("key_id_03"), + }, + { + KeyID: aws.String("key_id_04"), + }, + { + KeyID: aws.String("key_id_05"), + }, + { + KeyID: aws.String("key_id_06"), + }, + { + KeyID: aws.String("key_id_07"), + }, + { + KeyID: aws.String("key_id_08"), + }, + { + KeyID: aws.String("key_id_09"), + }, + { + KeyID: aws.String("key_id_10"), + }, + { + KeyID: aws.String("key_id_11"), + }, + { + KeyID: aws.String("key_id_12"), + }, + { + KeyID: aws.String("key_id_13"), + }, + { + KeyID: aws.String("key_id_14"), + }, + { + KeyID: aws.String("key_id_15"), + }, + { + KeyID: aws.String("key_id_16"), + }, + }, + }, + { + name: "list keys error", + configureRequest: configureRequestWithDefaults(t), + err: "list keys failure", + listKeysErr: "list keys failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: nil, + KeyID: aws.String("key_id_01"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + { + name: "list aliases error", + configureRequest: configureRequestWithDefaults(t), + err: "list aliases failure", + listAliasesErr: "list aliases failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: nil, + KeyID: aws.String("key_id_01"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + { + name: "describe key error", + configureRequest: configureRequestWithDefaults(t), + err: "describe key failure", + describeKeyErr: "describe key failure", + fakeEntries: []fakeKeyEntry{ + { + AliasName: nil, + KeyID: aws.String("key_id_01"), + Description: aws.String("SPIRE_SERVER_KEY/test_example_org"), + KeySpec: types.CustomerMasterKeySpecRsa4096, + Enabled: true, + PublicKey: []byte("foo"), + CreationDate: &unixEpoch, + AliasLastUpdatedDate: &unixEpoch, + }, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // setup + ts := setupTest(t) + ts.fakeClient.setEntries(tt.fakeEntries) + + // this is so dispose aliases blocks on init and allows to test dispose keys isolated + ts.plugin.hooks.disposeAliasesSignal = make(chan error) + disposeKeysSignal := make(chan error) + ts.plugin.hooks.disposeKeysSignal = disposeKeysSignal + deleteSignal := make(chan error) + ts.plugin.hooks.scheduleDeleteSignal = deleteSignal + + // exercise + _, err := ts.plugin.Configure(ctx, tt.configureRequest) + require.NoError(t, err) + + ts.fakeClient.setListKeysErr(tt.listKeysErr) + ts.fakeClient.setDescribeKeyErr(tt.describeKeyErr) + ts.fakeClient.setListAliasesErr(tt.listAliasesErr) + + // wait for dispose kesy task to be initialized + _ = waitForSignal(t, disposeKeysSignal) + // move the clock forward so the task is run + ts.clockHook.Add(48 * time.Hour) + // wait for dispose keys to be run + err = waitForSignal(t, disposeKeysSignal) + //assert errors + if tt.err != "" { + require.NotNil(t, err) + require.Equal(t, tt.err, err.Error()) + return + } + // wait for schedule delete to be run + _ = waitForSignal(t, deleteSignal) + + // assert + storedKeys := ts.fakeClient.store.keyEntries + require.Len(t, storedKeys, len(tt.expectedEntries)) + for _, expected := range tt.expectedEntries { + _, ok := storedKeys[*expected.KeyID] + require.True(t, ok, "Expected key was not present on end result: %q", *expected.KeyID) + } + }) + } +} + +func configureRequestWithString(config string) *plugin.ConfigureRequest { + return &plugin.ConfigureRequest{ + Configuration: config, + } +} + +func configureRequestWithVars(accessKeyID, secretAccessKey, region, keyMetadataFile string) *plugin.ConfigureRequest { + return &plugin.ConfigureRequest{ + Configuration: fmt.Sprintf(`{ + "access_key_id": "%s", + "secret_access_key": "%s", + "region":"%s", + "key_metadata_file":"%s" + }`, + accessKeyID, + secretAccessKey, + region, + keyMetadataFile), + GlobalConfig: &plugin.ConfigureRequest_GlobalConfig{TrustDomain: "test.example.org"}, + } +} + +func configureRequestWithDefaults(t *testing.T) *plugin.ConfigureRequest { + return &plugin.ConfigureRequest{ + Configuration: serializedConfiguration(validAccessKeyID, validSecretAccessKey, validRegion, getKeyMetadataFile(t)), + GlobalConfig: &plugin.ConfigureRequest_GlobalConfig{TrustDomain: "test.example.org"}, + } +} + +func serializedConfiguration(accessKeyID, secretAccessKey, region string, keyMetadataFile string) string { + return fmt.Sprintf(`{ + "access_key_id": "%s", + "secret_access_key": "%s", + "region":"%s", + "key_metadata_file":"%s" + }`, + accessKeyID, + secretAccessKey, + region, + keyMetadataFile) +} + +func getKeyMetadataFile(t *testing.T) string { + tempDir := t.TempDir() + tempFilePath := path.Join(tempDir, validServerIDFile) + err := ioutil.WriteFile(tempFilePath, []byte(validServerID), 0600) + if err != nil { + t.Error(err) + } + return tempFilePath +} + +func getEmptyKeyMetadataFile(t *testing.T) string { + tempDir := t.TempDir() + return path.Join(tempDir, validServerIDFile) +} + +func waitForSignal(t *testing.T, ch chan error) error { + select { + case err := <-ch: + return err + case <-time.After(testTimeout): + t.Fail() + } + return nil +} diff --git a/pkg/server/plugin/keymanager/awskms/client.go b/pkg/server/plugin/keymanager/awskms/client.go new file mode 100644 index 0000000000..6692c5f2d4 --- /dev/null +++ b/pkg/server/plugin/keymanager/awskms/client.go @@ -0,0 +1,37 @@ +package awskms + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/kms" +) + +type kmsClient interface { + CreateKey(context.Context, *kms.CreateKeyInput, ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) + CreateAlias(context.Context, *kms.CreateAliasInput, ...func(*kms.Options)) (*kms.CreateAliasOutput, error) + UpdateAlias(context.Context, *kms.UpdateAliasInput, ...func(*kms.Options)) (*kms.UpdateAliasOutput, error) + GetPublicKey(context.Context, *kms.GetPublicKeyInput, ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + ListAliases(context.Context, *kms.ListAliasesInput, ...func(*kms.Options)) (*kms.ListAliasesOutput, error) + ScheduleKeyDeletion(context.Context, *kms.ScheduleKeyDeletionInput, ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) + Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) + ListKeys(context.Context, *kms.ListKeysInput, ...func(*kms.Options)) (*kms.ListKeysOutput, error) + DeleteAlias(context.Context, *kms.DeleteAliasInput, ...func(*kms.Options)) (*kms.DeleteAliasOutput, error) +} + +func newKMSClient(ctx context.Context, c *Config) (kmsClient, error) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(c.Region), + ) + if err != nil { + return nil, err + } + + if c.SecretAccessKey != "" && c.AccessKeyID != "" { + cfg.Credentials = credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, "") + } + + return kms.NewFromConfig(cfg), nil +} diff --git a/pkg/server/plugin/keymanager/awskms/client_fake.go b/pkg/server/plugin/keymanager/awskms/client_fake.go new file mode 100644 index 0000000000..f9dad7564d --- /dev/null +++ b/pkg/server/plugin/keymanager/awskms/client_fake.go @@ -0,0 +1,546 @@ +package awskms + +import ( + "context" + "crypto" + "crypto/x509" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/andres-erbsen/clock" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/spiffe/spire/test/testkey" +) + +type kmsClientFake struct { + t *testing.T + store fakeStore + mu sync.RWMutex + testKeys testkey.Keys + createKeyErr error + describeKeyErr error + getPublicKeyErr error + listAliasesErr error + createAliasErr error + updateAliasErr error + scheduleKeyDeletionErr error + signErr error + listKeysErr error + deleteAliasErr error +} + +func newKMSClientFake(t *testing.T, c *clock.Mock) *kmsClientFake { + return &kmsClientFake{ + t: t, + store: newFakeStore(c), + } +} + +func (k *kmsClientFake) CreateKey(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.createKeyErr != nil { + return nil, k.createKeyErr + } + + var privateKey crypto.PrivateKey + var publicKey crypto.PublicKey + + switch input.CustomerMasterKeySpec { + case types.CustomerMasterKeySpecEccNistP256: + key := k.testKeys.NewEC256(k.t) + privateKey = key + publicKey = &key.PublicKey + case types.CustomerMasterKeySpecEccNistP384: + key := k.testKeys.NewEC384(k.t) + privateKey = key + publicKey = &key.PublicKey + case types.CustomerMasterKeySpecRsa2048: + key := k.testKeys.NewRSA2048(k.t) + privateKey = key + publicKey = &key.PublicKey + case types.CustomerMasterKeySpecRsa4096: + key := k.testKeys.NewRSA4096(k.t) + privateKey = key + publicKey = &key.PublicKey + default: + return nil, fmt.Errorf("unknown key type %q", input.CustomerMasterKeySpec) + } + + pkixData, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return nil, err + } + + keyEntry := &fakeKeyEntry{ + Description: input.Description, + CreationDate: aws.Time(time.Unix(0, 0)), + PublicKey: pkixData, + privateKey: privateKey, + KeySpec: input.CustomerMasterKeySpec, + Enabled: true, + } + + k.store.SaveKeyEntry(keyEntry) + + return &kms.CreateKeyOutput{ + KeyMetadata: &types.KeyMetadata{ + KeyId: keyEntry.KeyID, + Arn: keyEntry.Arn, + Description: keyEntry.Description, + CreationDate: keyEntry.CreationDate, + }, + }, nil +} + +func (k *kmsClientFake) DescribeKey(ctx context.Context, input *kms.DescribeKeyInput, opts ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.describeKeyErr != nil { + return nil, k.describeKeyErr + } + + keyEntry, err := k.store.FetchKeyEntry(*input.KeyId) + if err != nil { + return nil, err + } + + return &kms.DescribeKeyOutput{ + KeyMetadata: &types.KeyMetadata{ + KeyId: keyEntry.KeyID, + Arn: keyEntry.Arn, + CustomerMasterKeySpec: keyEntry.KeySpec, + Enabled: keyEntry.Enabled, + Description: keyEntry.Description, + CreationDate: keyEntry.CreationDate, + }, + }, nil +} + +func (k *kmsClientFake) GetPublicKey(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.getPublicKeyErr != nil { + return nil, k.getPublicKeyErr + } + + keyEntry, err := k.store.FetchKeyEntry(*input.KeyId) + if err != nil { + return nil, err + } + + return &kms.GetPublicKeyOutput{ + KeyId: keyEntry.KeyID, + PublicKey: keyEntry.PublicKey, + }, nil +} + +func (k *kmsClientFake) ListAliases(ctw context.Context, input *kms.ListAliasesInput, opts ...func(*kms.Options)) (*kms.ListAliasesOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.listAliasesErr != nil { + return nil, k.listAliasesErr + } + + if input.KeyId != nil { + keyEntry, err := k.store.FetchKeyEntry(*input.KeyId) + switch { + case err != nil: + return nil, err + case keyEntry.AliasName != nil: + aliasesResp := []types.AliasListEntry{{ + AliasName: keyEntry.AliasName, + AliasArn: aws.String(aliasArnFromAliasName(*keyEntry.AliasName)), + TargetKeyId: keyEntry.KeyID, + LastUpdatedDate: keyEntry.AliasLastUpdatedDate, + }} + return &kms.ListAliasesOutput{Aliases: aliasesResp}, nil + default: + return &kms.ListAliasesOutput{Aliases: []types.AliasListEntry{}}, nil + } + } + + var aliasesResp []types.AliasListEntry + for _, alias := range k.store.ListAliases() { + aliasesResp = append(aliasesResp, types.AliasListEntry{ + AliasName: alias.AliasName, + AliasArn: aws.String(aliasArnFromAliasName(*alias.AliasName)), + TargetKeyId: alias.KeyEntry.KeyID, + LastUpdatedDate: alias.KeyEntry.AliasLastUpdatedDate, + }) + } + + return &kms.ListAliasesOutput{Aliases: aliasesResp}, nil +} + +func (k *kmsClientFake) ScheduleKeyDeletion(ctx context.Context, input *kms.ScheduleKeyDeletionInput, opts ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.scheduleKeyDeletionErr != nil { + return nil, k.scheduleKeyDeletionErr + } + + k.store.DeleteKeyEntry(*input.KeyId) + + return &kms.ScheduleKeyDeletionOutput{}, nil +} + +func (k *kmsClientFake) Sign(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + + if k.signErr != nil { + return nil, k.signErr + } + + _, err := k.store.FetchKeyEntry(*input.KeyId) + if err != nil { + return nil, err + } + + //TODO: do actual signing + return &kms.SignOutput{Signature: input.Message}, nil +} + +func (k *kmsClientFake) CreateAlias(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.createAliasErr != nil { + return nil, k.createAliasErr + } + + err := k.store.SaveAlias(*input.TargetKeyId, *input.AliasName) + if err != nil { + return nil, err + } + + return &kms.CreateAliasOutput{}, nil +} + +func (k *kmsClientFake) UpdateAlias(ctw context.Context, input *kms.UpdateAliasInput, opts ...func(*kms.Options)) (*kms.UpdateAliasOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.updateAliasErr != nil { + return nil, k.updateAliasErr + } + + err := k.store.SaveAlias(*input.TargetKeyId, *input.AliasName) + if err != nil { + return nil, err + } + + return &kms.UpdateAliasOutput{}, nil +} + +func (k *kmsClientFake) ListKeys(ctw context.Context, input *kms.ListKeysInput, opts ...func(*kms.Options)) (*kms.ListKeysOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.listKeysErr != nil { + return nil, k.listKeysErr + } + + var keysResp []types.KeyListEntry + for _, keyEntry := range k.store.ListKeyEntries() { + keysResp = append(keysResp, types.KeyListEntry{ + KeyArn: keyEntry.Arn, + KeyId: keyEntry.KeyID, + }) + } + + return &kms.ListKeysOutput{Keys: keysResp}, nil +} + +func (k *kmsClientFake) DeleteAlias(ctx context.Context, params *kms.DeleteAliasInput, optFns ...func(*kms.Options)) (*kms.DeleteAliasOutput, error) { + k.mu.RLock() + defer k.mu.RUnlock() + if k.deleteAliasErr != nil { + return nil, k.deleteAliasErr + } + + k.store.DeleteAlias(*params.AliasName) + return nil, nil +} + +func (k *kmsClientFake) setEntries(entries []fakeKeyEntry) { + k.mu.Lock() + defer k.mu.Unlock() + if entries == nil { + return + } + for _, e := range entries { + if e.KeyID != nil { + newEntry := e + k.store.SaveKeyEntry(&newEntry) + } + if e.AliasName != nil { + err := k.store.SaveAlias(*e.KeyID, *e.AliasName) + if err != nil { + k.t.Error(err) + } + } + } +} + +func (k *kmsClientFake) setCreateKeyErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.createKeyErr = errors.New(fakeError) + } +} +func (k *kmsClientFake) setDescribeKeyErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.describeKeyErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setgetPublicKeyErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.getPublicKeyErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setListAliasesErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.listAliasesErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setCreateAliasesErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.createAliasErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setUpdateAliasErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.updateAliasErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setScheduleKeyDeletionErr(fakeError error) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != nil { + k.scheduleKeyDeletionErr = fakeError + } +} + +func (k *kmsClientFake) setSignDataErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.signErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setListKeysErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.listKeysErr = errors.New(fakeError) + } +} + +func (k *kmsClientFake) setDeleteAliasErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.deleteAliasErr = errors.New(fakeError) + } +} + +const ( + fakeKeyArnPrefix = "arn:aws:kms:region:1234:key/" + fakeAliasArnPrefix = "arn:aws:kms:region:1234:" +) + +type fakeStore struct { + keyEntries map[string]*fakeKeyEntry // don't user ara for key + aliases map[string]fakeAlias // don't user ara for key + mu sync.RWMutex + nextID int + clk *clock.Mock +} + +func newFakeStore(c *clock.Mock) fakeStore { + return fakeStore{ + keyEntries: make(map[string]*fakeKeyEntry), + aliases: make(map[string]fakeAlias), + clk: c, + } +} + +type fakeKeyEntry struct { + KeyID *string + Arn *string + Description *string + CreationDate *time.Time + AliasName *string // Only one alias per key. "Real" KMS supports many aliases per key + AliasLastUpdatedDate *time.Time + PublicKey []byte + privateKey crypto.PrivateKey + Enabled bool + KeySpec types.CustomerMasterKeySpec +} + +type fakeAlias struct { + AliasName *string + AliasArn *string + KeyEntry *fakeKeyEntry +} + +func (fs *fakeStore) SaveKeyEntry(input *fakeKeyEntry) { + if input.KeyID == nil { + input.KeyID = aws.String(strconv.Itoa(fs.nextID)) + fs.nextID++ + } + input.Arn = aws.String(arnFromKeyID(*input.KeyID)) + + fs.mu.Lock() + defer fs.mu.Unlock() + + fs.keyEntries[*input.KeyID] = input +} + +func (fs *fakeStore) DeleteKeyEntry(keyID string) { + fs.mu.Lock() + defer fs.mu.Unlock() + + delete(fs.keyEntries, keyID) + delete(fs.keyEntries, keyIDFromArn(keyID)) + + for k, v := range fs.aliases { + if *v.KeyEntry.KeyID == keyID || *v.KeyEntry.Arn == keyID { + delete(fs.aliases, k) + } + } +} + +func (fs *fakeStore) SaveAlias(targetKeyID, aliasName string) error { + fs.mu.Lock() + defer fs.mu.Unlock() + + keyEntry, err := fs.fetchKeyEntry(targetKeyID) + if err != nil { + return err + } + + keyEntry.AliasName = &aliasName + keyEntry.AliasLastUpdatedDate = aws.Time(fs.clk.Now()) + + fs.aliases[aliasName] = fakeAlias{ + AliasName: aws.String(aliasName), + AliasArn: aws.String(aliasArnFromAliasName(aliasName)), + KeyEntry: keyEntry, + } + + return nil +} + +func (fs *fakeStore) DeleteAlias(aliasName string) { + fs.mu.Lock() + defer fs.mu.Unlock() + + delete(fs.aliases, aliasName) +} + +func (fs *fakeStore) ListKeyEntries() []fakeKeyEntry { + fs.mu.RLock() + defer fs.mu.RUnlock() + + var keyEntries []fakeKeyEntry + for _, v := range fs.keyEntries { + keyEntries = append(keyEntries, *v) + } + return keyEntries +} + +func (fs *fakeStore) ListAliases() []fakeAlias { + fs.mu.RLock() + defer fs.mu.RUnlock() + + var aliases []fakeAlias + for _, v := range fs.aliases { + aliases = append(aliases, fakeAlias{ + AliasName: v.AliasName, + AliasArn: v.AliasArn, + KeyEntry: &fakeKeyEntry{ + KeyID: v.KeyEntry.KeyID, + Arn: v.KeyEntry.Arn, + Description: v.KeyEntry.Description, + CreationDate: v.KeyEntry.CreationDate, + AliasName: v.KeyEntry.AliasName, + AliasLastUpdatedDate: v.KeyEntry.AliasLastUpdatedDate, + PublicKey: v.KeyEntry.PublicKey, + privateKey: v.KeyEntry.privateKey, + Enabled: v.KeyEntry.Enabled, + KeySpec: v.KeyEntry.KeySpec, + }, + }) + } + return aliases +} + +func (fs *fakeStore) FetchKeyEntry(id string) (*fakeKeyEntry, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.fetchKeyEntry(id) +} + +func (fs *fakeStore) fetchKeyEntry(id string) (*fakeKeyEntry, error) { + keyEntry, ok := fs.keyEntries[id] + if ok { + return keyEntry, nil + } + + keyEntry, ok = fs.keyEntries[keyIDFromArn(id)] + if ok { + return keyEntry, nil + } + + aliasEntry, ok := fs.aliases[id] + if ok { + return aliasEntry.KeyEntry, nil + } + + aliasEntry, ok = fs.aliases[aliasNameFromArn(id)] + if ok { + return aliasEntry.KeyEntry, nil + } + + return &fakeKeyEntry{}, fmt.Errorf("no such key %q", id) +} + +func aliasArnFromAliasName(aliasName string) string { + return fakeAliasArnPrefix + aliasName +} + +func aliasNameFromArn(arn string) string { + return strings.TrimPrefix(arn, fakeAliasArnPrefix) +} + +func arnFromKeyID(keyID string) string { + return fakeKeyArnPrefix + keyID +} + +func keyIDFromArn(arn string) string { + return strings.TrimPrefix(arn, fakeKeyArnPrefix) +} diff --git a/pkg/server/plugin/keymanager/awskms/fetcher.go b/pkg/server/plugin/keymanager/awskms/fetcher.go new file mode 100644 index 0000000000..4dba6a6645 --- /dev/null +++ b/pkg/server/plugin/keymanager/awskms/fetcher.go @@ -0,0 +1,143 @@ +package awskms + +import ( + "context" + "path" + "strings" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/hashicorp/go-hclog" + keymanagerv0 "github.com/spiffe/spire/proto/spire/plugin/server/keymanager/v0" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type keyFetcher struct { + log hclog.Logger + kmsClient kmsClient + serverID string + trustDomain string +} + +func (kf *keyFetcher) fetchKeyEntries(ctx context.Context) ([]*keyEntry, error) { + var keyEntries []*keyEntry + var keyEntriesMutex sync.Mutex + paginator := kms.NewListAliasesPaginator(kf.kmsClient, &kms.ListAliasesInput{Limit: aws.Int32(100)}) + g, ctx := errgroup.WithContext(ctx) + + for { + aliasesResp, err := paginator.NextPage(ctx) + switch { + case err != nil: + return nil, status.Errorf(codes.Internal, "failed to fetch aliases: %v", err) + case aliasesResp == nil: + return nil, status.Errorf(codes.Internal, "failed to fetch aliases: nil response") + } + + kf.log.Debug("Found aliases", "num_aliases", len(aliasesResp.Aliases)) + + for _, alias := range aliasesResp.Aliases { + // Ensure the alias has a name. This check is purely defensive + // since aliases should always have a name. + if alias.AliasName == nil { + continue + } + + spireKeyID, ok := kf.spireKeyIDFromAlias(*alias.AliasName) + // ignore aliases/keys not belonging to this server + if !ok { + continue + } + + // The following checks are purely defensive but we want to ensure + // we don't try and handle an alias with a malformed shape. + switch { + case alias.AliasArn == nil: + return nil, status.Errorf(codes.Internal, "failed to fetch aliases: found SPIRE alias without arn: name=%q", *alias.AliasName) + case alias.TargetKeyId == nil: + // this means something external to the plugin created the alias, without associating it to a key. + // it should never happen with CMKs. + return nil, status.Errorf(codes.FailedPrecondition, "failed to fetch aliases: found SPIRE alias without key: name=%q arn=%q", *alias.AliasName, *alias.AliasArn) + } + + a := alias + // trigger a goroutine to get the details of the key + g.Go(func() error { + entry, err := kf.fetchKeyEntryDetails(ctx, a, spireKeyID) + if err != nil { + return err + } + + keyEntriesMutex.Lock() + keyEntries = append(keyEntries, entry) + keyEntriesMutex.Unlock() + return nil + }) + } + + if !paginator.HasMorePages() { + break + } + } + + // wait for all the detail gathering routines to finish + if err := g.Wait(); err != nil { + statusErr := status.Convert(err) + return nil, status.Errorf(statusErr.Code(), "failed to fetch aliases: %v", statusErr.Message()) + } + + return keyEntries, nil +} + +func (kf *keyFetcher) fetchKeyEntryDetails(ctx context.Context, alias types.AliasListEntry, spireKeyID string) (*keyEntry, error) { + describeResp, err := kf.kmsClient.DescribeKey(ctx, &kms.DescribeKeyInput{KeyId: alias.AliasArn}) + switch { + case err != nil: + return nil, status.Errorf(codes.Internal, "failed to describe key: %v", err) + case describeResp == nil || describeResp.KeyMetadata == nil: + return nil, status.Error(codes.Internal, "malformed describe key response") + case describeResp.KeyMetadata.Arn == nil: + return nil, status.Errorf(codes.Internal, "found SPIRE alias without key arn: %q", *alias.AliasArn) + case !describeResp.KeyMetadata.Enabled: + // this means something external to the plugin, deleted or disabled the key without removing the alias + // returning an error provides the opportunity or reverting this in KMS + return nil, status.Errorf(codes.FailedPrecondition, "found disabled SPIRE key: %q, alias: %q", *describeResp.KeyMetadata.Arn, *alias.AliasArn) + } + + keyType, ok := keyTypeFromKeySpec(describeResp.KeyMetadata.CustomerMasterKeySpec) + if !ok { + return nil, status.Errorf(codes.Internal, "unsupported key spec: %v", describeResp.KeyMetadata.CustomerMasterKeySpec) + } + + publicKeyResp, err := kf.kmsClient.GetPublicKey(ctx, &kms.GetPublicKeyInput{KeyId: alias.AliasArn}) + switch { + case err != nil: + return nil, status.Errorf(codes.Internal, "failed to get public key: %v", err) + case publicKeyResp == nil || publicKeyResp.PublicKey == nil || len(publicKeyResp.PublicKey) == 0: + return nil, status.Error(codes.Internal, "malformed get public key response") + } + + return &keyEntry{ + Arn: *describeResp.KeyMetadata.Arn, + AliasName: *alias.AliasName, + PublicKey: &keymanagerv0.PublicKey{ + Id: spireKeyID, + Type: keyType, + PkixData: publicKeyResp.PublicKey, + }, + }, nil +} + +func (kf *keyFetcher) spireKeyIDFromAlias(aliasName string) (string, bool) { + trustDomain := sanitizeTrustDomain(kf.trustDomain) + prefix := path.Join(aliasPrefix, trustDomain, kf.serverID) + "/" + trimmed := strings.TrimPrefix(aliasName, prefix) + if trimmed == aliasName { + return "", false + } + return trimmed, true +} diff --git a/test/spiretest/logs.go b/test/spiretest/logs.go index 96eee96464..d2a0d185a5 100644 --- a/test/spiretest/logs.go +++ b/test/spiretest/logs.go @@ -34,6 +34,21 @@ func AssertLogsAnyOrder(t *testing.T, entries []*logrus.Entry, expected []LogEnt assert.ElementsMatch(t, expected, convertLogEntries(entries), "unexpected logs") } +func AssertLastLogs(t *testing.T, entries []*logrus.Entry, expected []LogEntry) { + for _, entry := range entries { + for key, field := range entry.Data { + entry.Data[key] = fmt.Sprint(field) + } + } + + removeLen := len(entries) - len(expected) + if removeLen > 0 { + assert.Equal(t, expected, convertLogEntries(entries[removeLen:]), "unexpected logs") + return + } + assert.Equal(t, expected, convertLogEntries(entries), "unexpected logs") +} + func convertLogEntries(entries []*logrus.Entry) (out []LogEntry) { for _, entry := range entries { out = append(out, LogEntry{