From 6451acb171f88c1274f2f925bca503c0b0992472 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Tue, 10 Dec 2024 15:53:23 +0000 Subject: [PATCH] [v16] Workload ID: Add WorkloadIdentity local service and cache config (#49942) (#49990) * Workload ID: Add WorkloadIdentity local service and cache config (#49942) * Add WorkloadIdentity store and cache * Update lib/services/local/workload_identity.go Co-authored-by: Edward Dowling * Update lib/services/local/workload_identity.go Co-authored-by: Edward Dowling * Update lib/cache/resource_workload_identity.go Co-authored-by: Edoardo Spadolini --------- Co-authored-by: Edward Dowling Co-authored-by: Edoardo Spadolini * Fix for v16 * Fix dodgy backport with incorrect Kind --------- Co-authored-by: Edward Dowling Co-authored-by: Edoardo Spadolini --- lib/auth/accesspoint/accesspoint.go | 2 + lib/auth/auth.go | 9 + lib/auth/authclient/api.go | 7 + lib/auth/helpers.go | 1 + lib/auth/init.go | 4 + lib/cache/cache.go | 12 + lib/cache/cache_test.go | 13 + lib/cache/collections.go | 11 + lib/cache/resource_workload_identity.go | 119 +++++++ lib/cache/resource_workload_identity_test.go | 74 ++++ lib/service/service.go | 1 + lib/services/local/events.go | 30 ++ lib/services/local/workload_identity.go | 118 ++++++ lib/services/local/workload_identity_test.go | 355 +++++++++++++++++++ lib/services/workload_identity.go | 122 +++++++ lib/services/workload_identity_test.go | 231 ++++++++++++ 16 files changed, 1109 insertions(+) create mode 100644 lib/cache/resource_workload_identity.go create mode 100644 lib/cache/resource_workload_identity_test.go create mode 100644 lib/services/local/workload_identity.go create mode 100644 lib/services/local/workload_identity_test.go create mode 100644 lib/services/workload_identity.go create mode 100644 lib/services/workload_identity_test.go diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go index 5c0bc9b957ee0..ff60731708600 100644 --- a/lib/auth/accesspoint/accesspoint.go +++ b/lib/auth/accesspoint/accesspoint.go @@ -103,6 +103,7 @@ type Config struct { Users services.UsersService WebSession types.WebSessionInterface WebToken types.WebTokenInterface + WorkloadIdentity cache.WorkloadIdentityReader WindowsDesktops services.WindowsDesktops AutoUpdateService services.AutoUpdateServiceGetter } @@ -198,6 +199,7 @@ func NewCache(cfg Config) (*cache.Cache, error) { Users: cfg.Users, WebSession: cfg.WebSession, WebToken: cfg.WebToken, + WorkloadIdentity: cfg.WorkloadIdentity, WindowsDesktops: cfg.WindowsDesktops, } diff --git a/lib/auth/auth.go b/lib/auth/auth.go index ec02d335ff141..2abd052512ec5 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -359,6 +359,13 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { return nil, trace.Wrap(err, "creating SPIFFEFederation service") } } + if cfg.WorkloadIdentity == nil { + workloadIdentity, err := local.NewWorkloadIdentityService(cfg.Backend) + if err != nil { + return nil, trace.Wrap(err, "creating WorkloadIdentity service") + } + cfg.WorkloadIdentity = workloadIdentity + } limiter, err := limiter.NewConnectionsLimiter(limiter.Config{ MaxConnections: defaults.LimiterMaxConcurrentSignatures, @@ -455,6 +462,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { BotInstance: cfg.BotInstance, SPIFFEFederations: cfg.SPIFFEFederations, StaticHostUser: cfg.StaticHostUsers, + WorkloadIdentities: cfg.WorkloadIdentity, } as := Server{ @@ -668,6 +676,7 @@ type Services struct { services.BotInstance services.StaticHostUser services.AutoUpdateService + services.WorkloadIdentities } // GetWebSession returns existing web session described by req. diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go index 710427d8615e3..d6b51a154c1f5 100644 --- a/lib/auth/authclient/api.go +++ b/lib/auth/authclient/api.go @@ -37,6 +37,7 @@ import ( userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" @@ -1211,6 +1212,12 @@ type Cache interface { // GetAccessGraphSettings returns the access graph settings. GetAccessGraphSettings(context.Context) (*clusterconfigpb.AccessGraphSettings, error) + // GetWorkloadIdentity gets a WorkloadIdentity by name. + GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) + // ListWorkloadIdentities lists all SPIFFE Federations using Google style + // pagination. + ListWorkloadIdentities(ctx context.Context, pageSize int, lastToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + // ListStaticHostUsers lists static host users. ListStaticHostUsers(ctx context.Context, pageSize int, startKey string) ([]*userprovisioningpb.StaticHostUser, string, error) // GetStaticHostUser returns a static host user by name. diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 22bc3b221e1fd..d4242442ab85f 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -347,6 +347,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { SecReports: svces.SecReports, SnowflakeSession: svces.Identity, SPIFFEFederations: svces.SPIFFEFederations, + WorkloadIdentity: svces.WorkloadIdentities, StaticHostUsers: svces.StaticHostUser, Trust: svces.TrustInternal, UserGroups: svces.UserGroups, diff --git a/lib/auth/init.go b/lib/auth/init.go index 2203b353a2857..8c92b8da41f4a 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -316,6 +316,10 @@ type InitConfig struct { // SPIFFEFederations is a service that manages storing SPIFFE federations. SPIFFEFederations services.SPIFFEFederations + // WorkloadIdentity is the service for storing and retrieving + // WorkloadIdentity resources. + WorkloadIdentity services.WorkloadIdentities + // StaticHostUsers is a service that manages host users that should be // created on SSH nodes. StaticHostUsers services.StaticHostUser diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 6540919dbd22b..1f44c0b07e182 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -190,6 +190,7 @@ func ForAuth(cfg Config) Config { {Kind: types.KindUserTask}, {Kind: types.KindAutoUpdateVersion}, {Kind: types.KindAutoUpdateConfig}, + {Kind: types.KindWorkloadIdentity}, } cfg.QueueSize = defaults.AuthQueueSize // We don't want to enable partial health for auth cache because auth uses an event stream @@ -536,6 +537,7 @@ type Cache struct { accessMontoringRuleCache services.AccessMonitoringRules spiffeFederationCache spiffeFederationCacher staticHostUsersCache *local.StaticHostUserService + workloadIdentityCache workloadIdentityCacher // closed indicates that the cache has been closed closed atomic.Bool @@ -716,6 +718,9 @@ type Config struct { SPIFFEFederations SPIFFEFederationReader // StaticHostUsers is the static host user service. StaticHostUsers services.StaticHostUser + // WorkloadIdentity is the upstream Workload Identities service that we're + // caching + WorkloadIdentity WorkloadIdentityReader // Backend is a backend for local cache Backend backend.Backend // MaxRetryPeriod is the maximum period between cache retries on failures @@ -969,6 +974,12 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } + workloadIdentityCache, err := local.NewWorkloadIdentityService(config.Backend) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + staticHostUserCache, err := local.NewStaticHostUserService(config.Backend) if err != nil { cancel() @@ -1019,6 +1030,7 @@ func New(config Config) (*Cache, error) { kubeWaitingContsCache: kubeWaitingContsCache, spiffeFederationCache: spiffeFederationCache, staticHostUsersCache: staticHostUserCache, + workloadIdentityCache: workloadIdentityCache, Logger: log.WithFields(log.Fields{ teleport.ComponentKey: config.Component, }), diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index 19e625561dbc1..cccb8a6ad3a24 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -138,6 +138,7 @@ type testPack struct { spiffeFederations *local.SPIFFEFederationService staticHostUsers services.StaticHostUser autoUpdateService services.AutoUpdateService + workloadIdentity *local.WorkloadIdentityService } // testFuncs are functions to support testing an object in a cache. @@ -350,6 +351,12 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) { } p.spiffeFederations = spiffeFederationsSvc + workloadIdentitySvc, err := local.NewWorkloadIdentityService(p.backend) + if err != nil { + return nil, trace.Wrap(err) + } + p.workloadIdentity = workloadIdentitySvc + databaseObjectsSvc, err := local.NewDatabaseObjectService(p.backend) if err != nil { return nil, trace.Wrap(err) @@ -428,6 +435,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption) DatabaseObjects: p.databaseObjects, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -836,6 +844,7 @@ func TestCompletenessInit(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -917,6 +926,7 @@ func TestCompletenessReset(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -1124,6 +1134,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, neverOK: true, // ensure reads are never healthy @@ -1216,6 +1227,7 @@ func initStrategy(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -3454,6 +3466,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) { types.KindUserTask: types.Resource153ToLegacy(newUserTasks(t)), types.KindAutoUpdateConfig: types.Resource153ToLegacy(newAutoUpdateConfig(t)), types.KindAutoUpdateVersion: types.Resource153ToLegacy(newAutoUpdateVersion(t)), + types.KindWorkloadIdentity: types.Resource153ToLegacy(newWorkloadIdentity("some_identifier")), } for name, cfg := range cases { diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 84a02648881cd..8dcb5723ea4a8 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -39,6 +39,7 @@ import ( userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" @@ -265,6 +266,7 @@ type cacheCollections struct { spiffeFederations collectionReader[SPIFFEFederationReader] autoUpdateConfigs collectionReader[autoUpdateConfigGetter] autoUpdateVersions collectionReader[autoUpdateVersionGetter] + workloadIdentity collectionReader[WorkloadIdentityReader] } // setupCollections returns a registry of collections. @@ -784,6 +786,15 @@ func setupCollections(c *Cache, watches []types.WatchKind) (*cacheCollections, e watch: watch, } collections.byKind[resourceKind] = collections.accessGraphSettings + case types.KindWorkloadIdentity: + if c.Config.WorkloadIdentity == nil { + return nil, trace.BadParameter("missing parameter WorkloadIdentity") + } + collections.workloadIdentity = &genericCollection[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader, workloadIdentityExecutor]{ + cache: c, + watch: watch, + } + collections.byKind[resourceKind] = collections.workloadIdentity case types.KindAutoUpdateConfig: if c.AutoUpdateService == nil { return nil, trace.BadParameter("missing parameter AutoUpdateService") diff --git a/lib/cache/resource_workload_identity.go b/lib/cache/resource_workload_identity.go new file mode 100644 index 0000000000000..75efb50fedbd5 --- /dev/null +++ b/lib/cache/resource_workload_identity.go @@ -0,0 +1,119 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//nolint:unused // Because the executors generate a large amount of false positives. +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +// WorkloadIdentityReader is an interface that defines the methods for getting +// WorkloadIdentity. This is returned as the reader for the WorkloadIdentity +// collection but is also used by the executor to read the full list of +// WorkloadIdentity on initialization. +type WorkloadIdentityReader interface { + ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) +} + +// workloadIdentityCacher is used for storing and retrieving WorkloadIdentity +// from the cache's local backend. +type workloadIdentityCacher interface { + WorkloadIdentityReader + UpsertWorkloadIdentity(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity) (*workloadidentityv1pb.WorkloadIdentity, error) + DeleteWorkloadIdentity(ctx context.Context, name string) error + DeleteAllWorkloadIdentities(ctx context.Context) error +} + +type workloadIdentityExecutor struct{} + +var _ executor[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader] = workloadIdentityExecutor{} + +func (workloadIdentityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + var out []*workloadidentityv1pb.WorkloadIdentity + var nextToken string + for { + var page []*workloadidentityv1pb.WorkloadIdentity + var err error + + const defaultPageSize = 0 + page, nextToken, err = cache.Config.WorkloadIdentity.ListWorkloadIdentities(ctx, defaultPageSize, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + out = append(out, page...) + if nextToken == "" { + break + } + } + return out, nil +} + +func (workloadIdentityExecutor) upsert(ctx context.Context, cache *Cache, resource *workloadidentityv1pb.WorkloadIdentity) error { + _, err := cache.workloadIdentityCache.UpsertWorkloadIdentity(ctx, resource) + return trace.Wrap(err) +} + +func (workloadIdentityExecutor) deleteAll(ctx context.Context, cache *Cache) error { + return trace.Wrap(cache.workloadIdentityCache.DeleteAllWorkloadIdentities(ctx)) +} + +func (workloadIdentityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { + return trace.Wrap(cache.workloadIdentityCache.DeleteWorkloadIdentity(ctx, resource.GetName())) +} + +func (workloadIdentityExecutor) isSingleton() bool { return false } + +func (workloadIdentityExecutor) getReader(cache *Cache, cacheOK bool) WorkloadIdentityReader { + if cacheOK { + return cache.workloadIdentityCache + } + return cache.Config.WorkloadIdentity +} + +// ListWorkloadIdentities returns a paginated list of WorkloadIdentity resources. +func (c *Cache) ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListWorkloadIdentities") + defer span.End() + + rg, err := readCollectionCache(c, c.collections.workloadIdentity) + if err != nil { + return nil, "", trace.Wrap(err) + } + defer rg.Release() + out, nextKey, err := rg.reader.ListWorkloadIdentities(ctx, pageSize, nextToken) + return out, nextKey, trace.Wrap(err) +} + +// GetWorkloadIdentity returns a single WorkloadIdentity by name +func (c *Cache) GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetWorkloadIdentity") + defer span.End() + + rg, err := readCollectionCache(c, c.collections.workloadIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + out, err := rg.reader.GetWorkloadIdentity(ctx, name) + return out, trace.Wrap(err) +} diff --git a/lib/cache/resource_workload_identity_test.go b/lib/cache/resource_workload_identity_test.go new file mode 100644 index 0000000000000..da82d64fec27c --- /dev/null +++ b/lib/cache/resource_workload_identity_test.go @@ -0,0 +1,74 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +func newWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity { + return &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + } +} + +func TestWorkloadIdentity(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources153(t, p, testFuncs153[*workloadidentityv1pb.WorkloadIdentity]{ + newResource: func(s string) (*workloadidentityv1pb.WorkloadIdentity, error) { + return newWorkloadIdentity(s), nil + }, + + create: func(ctx context.Context, item *workloadidentityv1pb.WorkloadIdentity) error { + _, err := p.workloadIdentity.CreateWorkloadIdentity(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + items, _, err := p.workloadIdentity.ListWorkloadIdentities(ctx, 0, "") + return items, trace.Wrap(err) + }, + deleteAll: func(ctx context.Context) error { + return p.workloadIdentity.DeleteAllWorkloadIdentities(ctx) + }, + + cacheList: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + items, _, err := p.cache.ListWorkloadIdentities(ctx, 0, "") + return items, trace.Wrap(err) + }, + cacheGet: p.cache.GetWorkloadIdentity, + }) +} diff --git a/lib/service/service.go b/lib/service/service.go index e2a35c8df2f07..86363450dd1a4 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2429,6 +2429,7 @@ func (process *TeleportProcess) newAccessCacheForServices(cfg accesspoint.Config cfg.WebSession = services.Identity.WebSessions() cfg.WebToken = services.Identity.WebTokens() cfg.WindowsDesktops = services.WindowsDesktops + cfg.WorkloadIdentity = services.WorkloadIdentities cfg.AutoUpdateService = services.AutoUpdateService return accesspoint.NewCache(cfg) diff --git a/lib/services/local/events.go b/lib/services/local/events.go index 994f54c18cd4e..7847ce6ccc1a4 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -234,6 +234,8 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch types.Watch) (type parser = newAccessGraphSettingsParser() case types.KindStaticHostUser: parser = newStaticHostUserParser() + case types.KindWorkloadIdentity: + parser = newWorkloadIdentityParser() default: if watch.AllowPartialSuccess { continue @@ -2607,3 +2609,31 @@ func (p *accessGraphSettingsParser) parse(event backend.Event) (types.Resource, return nil, trace.BadParameter("event %v is not supported", event.Type) } } + +func newWorkloadIdentityParser() *workloadIdentityParser { + return &workloadIdentityParser{ + baseParser: newBaseParser(backend.NewKey(workloadIdentityPrefix)), + } +} + +type workloadIdentityParser struct { + baseParser +} + +func (p *workloadIdentityParser) parse(event backend.Event) (types.Resource, error) { + switch event.Type { + case types.OpDelete: + return resourceHeader(event, types.KindWorkloadIdentity, types.V1, 0) + case types.OpPut: + resource, err := services.UnmarshalWorkloadIdentity( + event.Item.Value, + services.WithExpires(event.Item.Expires), + services.WithRevision(event.Item.Revision)) + if err != nil { + return nil, trace.Wrap(err, "unmarshalling resource from event") + } + return types.Resource153ToLegacy(resource), nil + default: + return nil, trace.BadParameter("event %v is not supported", event.Type) + } +} diff --git a/lib/services/local/workload_identity.go b/lib/services/local/workload_identity.go new file mode 100644 index 0000000000000..b9efd70c5d9d1 --- /dev/null +++ b/lib/services/local/workload_identity.go @@ -0,0 +1,118 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package local + +import ( + "context" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local/generic" +) + +const ( + workloadIdentityPrefix = "workload_identity" +) + +// WorkloadIdentityService exposes backend functionality for storing +// WorkloadIdentity resources +type WorkloadIdentityService struct { + service *generic.ServiceWrapper[*workloadidentityv1pb.WorkloadIdentity] +} + +// NewWorkloadIdentityService creates a new WorkloadIdentityService +func NewWorkloadIdentityService(b backend.Backend) (*WorkloadIdentityService, error) { + service, err := generic.NewServiceWrapper( + generic.ServiceWrapperConfig[*workloadidentityv1pb.WorkloadIdentity]{ + Backend: b, + ResourceKind: types.KindWorkloadIdentity, + BackendPrefix: workloadIdentityPrefix, + MarshalFunc: services.MarshalWorkloadIdentity, + UnmarshalFunc: services.UnmarshalWorkloadIdentity, + ValidateFunc: services.ValidateWorkloadIdentity, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &WorkloadIdentityService{ + service: service, + }, nil +} + +// CreateWorkloadIdentity inserts a new WorkloadIdentity into the backend. +func (b *WorkloadIdentityService) CreateWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + created, err := b.service.CreateResource(ctx, resource) + return created, trace.Wrap(err) +} + +// GetWorkloadIdentity retrieves a specific WorkloadIdentity given a name +func (b *WorkloadIdentityService) GetWorkloadIdentity( + ctx context.Context, name string, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + resource, err := b.service.GetResource(ctx, name) + return resource, trace.Wrap(err) +} + +// ListWorkloadIdentities lists all WorkloadIdentities using a given page size +// and last key. +func (b *WorkloadIdentityService) ListWorkloadIdentities( + ctx context.Context, pageSize int, currentToken string, +) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) { + r, nextToken, err := b.service.ListResources(ctx, pageSize, currentToken) + return r, nextToken, trace.Wrap(err) +} + +// DeleteWorkloadIdentity deletes a specific WorkloadIdentity. +func (b *WorkloadIdentityService) DeleteWorkloadIdentity( + ctx context.Context, name string, +) error { + return trace.Wrap(b.service.DeleteResource(ctx, name)) +} + +// DeleteAllWorkloadIdentities deletes all SPIFFE resources, this is typically +// only meant to be used by the cache. +func (b *WorkloadIdentityService) DeleteAllWorkloadIdentities( + ctx context.Context, +) error { + return trace.Wrap(b.service.DeleteAllResources(ctx)) +} + +// UpsertWorkloadIdentity upserts a WorkloadIdentitys. Prefer using +// CreateWorkloadIdentity. This is only designed for usage by the cache. +func (b *WorkloadIdentityService) UpsertWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + upserted, err := b.service.UpsertResource(ctx, resource) + return upserted, trace.Wrap(err) +} + +// UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must +// already exist, and, condition update semantics are used - e.g the submitted +// resource must have a revision matching the revision of the resource in the +// backend. +func (b *WorkloadIdentityService) UpdateWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + updated, err := b.service.ConditionalUpdateResource(ctx, resource) + return updated, trace.Wrap(err) +} diff --git a/lib/services/local/workload_identity_test.go b/lib/services/local/workload_identity_test.go new file mode 100644 index 0000000000000..acba05d9c8e4a --- /dev/null +++ b/lib/services/local/workload_identity_test.go @@ -0,0 +1,355 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package local + +import ( + "context" + "fmt" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/memory" +) + +func setupWorkloadIdentityServiceTest( + t *testing.T, +) (context.Context, *WorkloadIdentityService) { + t.Parallel() + ctx := context.Background() + clock := clockwork.NewFakeClock() + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + service, err := NewWorkloadIdentityService(backend.NewSanitizer(mem)) + require.NoError(t, err) + return ctx, service +} + +func newValidWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity { + return &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/test", + }, + }, + } +} + +func TestWorkloadIdentityService_CreateWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("validation occurs", func(t *testing.T) { + out, err := service.CreateWorkloadIdentity(ctx, newValidWorkloadIdentity("")) + require.ErrorContains(t, err, "metadata.name: is required") + require.Nil(t, out) + }) + t.Run("no upsert", func(t *testing.T) { + res := newValidWorkloadIdentity("duplicate") + _, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + _, err = service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.Error(t, err) + require.True(t, trace.IsAlreadyExists(err)) + }) +} + +func TestWorkloadIdentityService_UpsertWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + got, err := service.UpsertWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + + // Ensure we can upsert over an existing resource + _, err = service.UpsertWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + }) + t.Run("validation occurs", func(t *testing.T) { + out, err := service.UpdateWorkloadIdentity(ctx, newValidWorkloadIdentity("")) + require.ErrorContains(t, err, "metadata.name: is required") + require.Nil(t, out) + }) +} + +func TestWorkloadIdentityService_ListWorkloadIdentities(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + // Create entities to list + createdObjects := []*workloadidentityv1pb.WorkloadIdentity{} + // Create 49 entities to test an incomplete page at the end. + for i := 0; i < 49; i++ { + created, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity(fmt.Sprintf("%d", i)), + ) + require.NoError(t, err) + createdObjects = append(createdObjects, created) + } + t.Run("default page size", func(t *testing.T) { + page, nextToken, err := service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Len(t, page, 49) + require.Empty(t, nextToken) + + // Expect that we get all the things we have created + for _, created := range createdObjects { + slices.ContainsFunc(page, func(resource *workloadidentityv1pb.WorkloadIdentity) bool { + return proto.Equal(created, resource) + }) + } + }) + t.Run("pagination", func(t *testing.T) { + fetched := []*workloadidentityv1pb.WorkloadIdentity{} + token := "" + iterations := 0 + for { + iterations++ + page, nextToken, err := service.ListWorkloadIdentities(ctx, 10, token) + require.NoError(t, err) + fetched = append(fetched, page...) + if nextToken == "" { + break + } + token = nextToken + } + require.Equal(t, 5, iterations) + + require.Len(t, fetched, 49) + // Expect that we get all the things we have created + for _, created := range createdObjects { + slices.ContainsFunc(fetched, func(resource *workloadidentityv1pb.WorkloadIdentity) bool { + return proto.Equal(created, resource) + }) + } + }) +} + +func TestWorkloadIdentityService_GetWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + _, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + got, err := service.GetWorkloadIdentity(ctx, "example") + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("not found", func(t *testing.T) { + _, err := service.GetWorkloadIdentity(ctx, "not-found") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) +} + +func TestWorkloadIdentityService_DeleteWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + _, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("example"), + ) + require.NoError(t, err) + + _, err = service.GetWorkloadIdentity(ctx, "example") + require.NoError(t, err) + + err = service.DeleteWorkloadIdentity(ctx, "example") + require.NoError(t, err) + + _, err = service.GetWorkloadIdentity(ctx, "example") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) + t.Run("not found", func(t *testing.T) { + err := service.DeleteWorkloadIdentity(ctx, "foo.example.com") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) +} + +func TestWorkloadIdentityService_DeleteAllWorkloadIdentities(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + _, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("1"), + ) + require.NoError(t, err) + _, err = service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("2"), + ) + require.NoError(t, err) + + page, _, err := service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Len(t, page, 2) + + err = service.DeleteAllWorkloadIdentities(ctx) + require.NoError(t, err) + + page, _, err = service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Empty(t, page) +} + +func TestWorkloadIdentityService_UpdateWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + // Create first to support updating + toCreate := newValidWorkloadIdentity("example") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + got.Spec.Spiffe.Id = "/changed" + got2, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got2.Metadata.Revision) + require.Empty(t, cmp.Diff( + got, + got2, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("validation occurs", func(t *testing.T) { + // Create first to support updating + toCreate := newValidWorkloadIdentity("example2") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + got.Spec.Spiffe.Id = "" + got2, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.ErrorContains(t, err, "spec.spiffe.id: is required") + require.Nil(t, got2) + }) + t.Run("cond update blocks", func(t *testing.T) { + toCreate := newValidWorkloadIdentity("example4") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + // We'll now update it twice, but on the second update, we will use the + // revision from the creation not the second update. + _, err = service.UpdateWorkloadIdentity( + ctx, + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + _, err = service.UpdateWorkloadIdentity( + ctx, + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.ErrorIs(t, err, backend.ErrIncorrectRevision) + }) + t.Run("no upsert", func(t *testing.T) { + toUpdate := newValidWorkloadIdentity("example3") + _, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toUpdate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.Error(t, err) + }) +} diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go new file mode 100644 index 0000000000000..89b87ba0d2473 --- /dev/null +++ b/lib/services/workload_identity.go @@ -0,0 +1,122 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package services + +import ( + "context" + "strings" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +// WorkloadIdentities is an interface over the WorkloadIdentities service. This +// interface may also be implemented by a client to allow remote and local +// consumers to access the resource in a similar way. +type WorkloadIdentities interface { + // GetWorkloadIdentity gets a SPIFFE Federation by name. + GetWorkloadIdentity( + ctx context.Context, name string, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // ListWorkloadIdentities lists all WorkloadIdentities using Google style + // pagination. + ListWorkloadIdentities( + ctx context.Context, pageSize int, lastToken string, + ) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + // CreateWorkloadIdentity creates a new WorkloadIdentity. + CreateWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // DeleteWorkloadIdentity deletes a SPIFFE Federation by name. + DeleteWorkloadIdentity(ctx context.Context, name string) error + // UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must + // already exist, and, condition update semantics are used - e.g the submitted + // resource must have a revision matching the revision of the resource in the + // backend. + UpdateWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // UpsertWorkloadIdentity creates or updates a WorkloadIdentity. + UpsertWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) +} + +// MarshalWorkloadIdentity marshals the WorkloadIdentity object into a JSON byte +// array. +func MarshalWorkloadIdentity( + object *workloadidentityv1pb.WorkloadIdentity, opts ...MarshalOption, +) ([]byte, error) { + return MarshalProtoResource(object, opts...) +} + +// UnmarshalWorkloadIdentity unmarshals the WorkloadIdentity object from a +// JSON byte array. +func UnmarshalWorkloadIdentity( + data []byte, opts ...MarshalOption, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + return UnmarshalProtoResource[*workloadidentityv1pb.WorkloadIdentity](data, opts...) +} + +// ValidateWorkloadIdentity validates the WorkloadIdentity object. This is +// performed prior to writing to the backend. +func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error { + switch { + case s == nil: + return trace.BadParameter("object cannot be nil") + case s.Version != types.V1: + return trace.BadParameter("version: only %q is supported", types.V1) + case s.Kind != types.KindWorkloadIdentity: + return trace.BadParameter("kind: must be %q", types.KindWorkloadIdentity) + case s.Metadata == nil: + return trace.BadParameter("metadata: is required") + case s.Metadata.Name == "": + return trace.BadParameter("metadata.name: is required") + case s.Spec == nil: + return trace.BadParameter("spec: is required") + case s.Spec.Spiffe.Id == "": + return trace.BadParameter("spec.spiffe.id: is required") + case !strings.HasPrefix(s.Spec.Spiffe.Id, "/"): + return trace.BadParameter("spec.spiffe.id: must start with a /") + } + + for i, rule := range s.GetSpec().GetRules().GetAllow() { + if len(rule.Conditions) == 0 { + return trace.BadParameter("spec.rules.allow[%d].conditions: must be non-empty", i) + } + for j, condition := range rule.Conditions { + if condition.Attribute == "" { + return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j) + } + // Ensure exactly one operator is set. + operatorsSet := 0 + if condition.Equals != "" { + operatorsSet++ + } + if operatorsSet == 0 || operatorsSet > 1 { + return trace.BadParameter( + "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d", + i, j, operatorsSet, + ) + } + } + } + + return nil +} diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go new file mode 100644 index 0000000000000..429612ed48555 --- /dev/null +++ b/lib/services/workload_identity_test.go @@ -0,0 +1,231 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package services + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +func TestWorkloadIdentityMarshaling(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in *workloadidentityv1pb.WorkloadIdentity + }{ + { + name: "normal", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotBytes, err := MarshalWorkloadIdentity(tc.in) + require.NoError(t, err) + // Test that unmarshaling gives us the same object + got, err := UnmarshalWorkloadIdentity(gotBytes) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc.in, got, protocmp.Transform())) + }) + } +} + +func TestValidateWorkloadIdentity(t *testing.T) { + t.Parallel() + + var errContains = func(contains string) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.ErrorContains(t, err, contains, msgAndArgs...) + } + } + + testCases := []struct { + name string + in *workloadidentityv1pb.WorkloadIdentity + requireErr require.ErrorAssertionFunc + }{ + { + name: "success - full", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "example", + Equals: "foo", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: require.NoError, + }, + { + name: "success - minimal", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: require.NoError, + }, + { + name: "missing name", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("metadata.name: is required"), + }, + { + name: "missing spiffe id", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{}, + }, + }, + requireErr: errContains("spec.spiffe.id: is required"), + }, + { + name: "spiffe id must have leading /", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "example", + }, + }, + }, + requireErr: errContains("spec.spiffe.id: must start with a /"), + }, + { + name: "missing attribute", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "", + Equals: "foo", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("spec.rules.allow[0].conditions[0].attribute: must be non-empty"), + }, + { + name: "missing operator", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "example", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateWorkloadIdentity(tc.in) + tc.requireErr(t, err) + }) + } +}