diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go
index 5c0bc9b957ee0..ff60731708600 100644
--- a/lib/auth/accesspoint/accesspoint.go
+++ b/lib/auth/accesspoint/accesspoint.go
@@ -103,6 +103,7 @@ type Config struct {
Users services.UsersService
WebSession types.WebSessionInterface
WebToken types.WebTokenInterface
+ WorkloadIdentity cache.WorkloadIdentityReader
WindowsDesktops services.WindowsDesktops
AutoUpdateService services.AutoUpdateServiceGetter
}
@@ -198,6 +199,7 @@ func NewCache(cfg Config) (*cache.Cache, error) {
Users: cfg.Users,
WebSession: cfg.WebSession,
WebToken: cfg.WebToken,
+ WorkloadIdentity: cfg.WorkloadIdentity,
WindowsDesktops: cfg.WindowsDesktops,
}
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index ec02d335ff141..2abd052512ec5 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -359,6 +359,13 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
return nil, trace.Wrap(err, "creating SPIFFEFederation service")
}
}
+ if cfg.WorkloadIdentity == nil {
+ workloadIdentity, err := local.NewWorkloadIdentityService(cfg.Backend)
+ if err != nil {
+ return nil, trace.Wrap(err, "creating WorkloadIdentity service")
+ }
+ cfg.WorkloadIdentity = workloadIdentity
+ }
limiter, err := limiter.NewConnectionsLimiter(limiter.Config{
MaxConnections: defaults.LimiterMaxConcurrentSignatures,
@@ -455,6 +462,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
BotInstance: cfg.BotInstance,
SPIFFEFederations: cfg.SPIFFEFederations,
StaticHostUser: cfg.StaticHostUsers,
+ WorkloadIdentities: cfg.WorkloadIdentity,
}
as := Server{
@@ -668,6 +676,7 @@ type Services struct {
services.BotInstance
services.StaticHostUser
services.AutoUpdateService
+ services.WorkloadIdentities
}
// GetWebSession returns existing web session described by req.
diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go
index 710427d8615e3..d6b51a154c1f5 100644
--- a/lib/auth/authclient/api.go
+++ b/lib/auth/authclient/api.go
@@ -37,6 +37,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
@@ -1211,6 +1212,12 @@ type Cache interface {
// GetAccessGraphSettings returns the access graph settings.
GetAccessGraphSettings(context.Context) (*clusterconfigpb.AccessGraphSettings, error)
+ // GetWorkloadIdentity gets a WorkloadIdentity by name.
+ GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // ListWorkloadIdentities lists all SPIFFE Federations using Google style
+ // pagination.
+ ListWorkloadIdentities(ctx context.Context, pageSize int, lastToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+
// ListStaticHostUsers lists static host users.
ListStaticHostUsers(ctx context.Context, pageSize int, startKey string) ([]*userprovisioningpb.StaticHostUser, string, error)
// GetStaticHostUser returns a static host user by name.
diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go
index 22bc3b221e1fd..d4242442ab85f 100644
--- a/lib/auth/helpers.go
+++ b/lib/auth/helpers.go
@@ -347,6 +347,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
SecReports: svces.SecReports,
SnowflakeSession: svces.Identity,
SPIFFEFederations: svces.SPIFFEFederations,
+ WorkloadIdentity: svces.WorkloadIdentities,
StaticHostUsers: svces.StaticHostUser,
Trust: svces.TrustInternal,
UserGroups: svces.UserGroups,
diff --git a/lib/auth/init.go b/lib/auth/init.go
index 2203b353a2857..8c92b8da41f4a 100644
--- a/lib/auth/init.go
+++ b/lib/auth/init.go
@@ -316,6 +316,10 @@ type InitConfig struct {
// SPIFFEFederations is a service that manages storing SPIFFE federations.
SPIFFEFederations services.SPIFFEFederations
+ // WorkloadIdentity is the service for storing and retrieving
+ // WorkloadIdentity resources.
+ WorkloadIdentity services.WorkloadIdentities
+
// StaticHostUsers is a service that manages host users that should be
// created on SSH nodes.
StaticHostUsers services.StaticHostUser
diff --git a/lib/cache/cache.go b/lib/cache/cache.go
index 6540919dbd22b..1f44c0b07e182 100644
--- a/lib/cache/cache.go
+++ b/lib/cache/cache.go
@@ -190,6 +190,7 @@ func ForAuth(cfg Config) Config {
{Kind: types.KindUserTask},
{Kind: types.KindAutoUpdateVersion},
{Kind: types.KindAutoUpdateConfig},
+ {Kind: types.KindWorkloadIdentity},
}
cfg.QueueSize = defaults.AuthQueueSize
// We don't want to enable partial health for auth cache because auth uses an event stream
@@ -536,6 +537,7 @@ type Cache struct {
accessMontoringRuleCache services.AccessMonitoringRules
spiffeFederationCache spiffeFederationCacher
staticHostUsersCache *local.StaticHostUserService
+ workloadIdentityCache workloadIdentityCacher
// closed indicates that the cache has been closed
closed atomic.Bool
@@ -716,6 +718,9 @@ type Config struct {
SPIFFEFederations SPIFFEFederationReader
// StaticHostUsers is the static host user service.
StaticHostUsers services.StaticHostUser
+ // WorkloadIdentity is the upstream Workload Identities service that we're
+ // caching
+ WorkloadIdentity WorkloadIdentityReader
// Backend is a backend for local cache
Backend backend.Backend
// MaxRetryPeriod is the maximum period between cache retries on failures
@@ -969,6 +974,12 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}
+ workloadIdentityCache, err := local.NewWorkloadIdentityService(config.Backend)
+ if err != nil {
+ cancel()
+ return nil, trace.Wrap(err)
+ }
+
staticHostUserCache, err := local.NewStaticHostUserService(config.Backend)
if err != nil {
cancel()
@@ -1019,6 +1030,7 @@ func New(config Config) (*Cache, error) {
kubeWaitingContsCache: kubeWaitingContsCache,
spiffeFederationCache: spiffeFederationCache,
staticHostUsersCache: staticHostUserCache,
+ workloadIdentityCache: workloadIdentityCache,
Logger: log.WithFields(log.Fields{
teleport.ComponentKey: config.Component,
}),
diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go
index 19e625561dbc1..cccb8a6ad3a24 100644
--- a/lib/cache/cache_test.go
+++ b/lib/cache/cache_test.go
@@ -138,6 +138,7 @@ type testPack struct {
spiffeFederations *local.SPIFFEFederationService
staticHostUsers services.StaticHostUser
autoUpdateService services.AutoUpdateService
+ workloadIdentity *local.WorkloadIdentityService
}
// testFuncs are functions to support testing an object in a cache.
@@ -350,6 +351,12 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) {
}
p.spiffeFederations = spiffeFederationsSvc
+ workloadIdentitySvc, err := local.NewWorkloadIdentityService(p.backend)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ p.workloadIdentity = workloadIdentitySvc
+
databaseObjectsSvc, err := local.NewDatabaseObjectService(p.backend)
if err != nil {
return nil, trace.Wrap(err)
@@ -428,6 +435,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption)
DatabaseObjects: p.databaseObjects,
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
@@ -836,6 +844,7 @@ func TestCompletenessInit(t *testing.T) {
SPIFFEFederations: p.spiffeFederations,
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
@@ -917,6 +926,7 @@ func TestCompletenessReset(t *testing.T) {
SPIFFEFederations: p.spiffeFederations,
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
@@ -1124,6 +1134,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) {
SPIFFEFederations: p.spiffeFederations,
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
neverOK: true, // ensure reads are never healthy
@@ -1216,6 +1227,7 @@ func initStrategy(t *testing.T) {
SPIFFEFederations: p.spiffeFederations,
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
@@ -3454,6 +3466,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) {
types.KindUserTask: types.Resource153ToLegacy(newUserTasks(t)),
types.KindAutoUpdateConfig: types.Resource153ToLegacy(newAutoUpdateConfig(t)),
types.KindAutoUpdateVersion: types.Resource153ToLegacy(newAutoUpdateVersion(t)),
+ types.KindWorkloadIdentity: types.Resource153ToLegacy(newWorkloadIdentity("some_identifier")),
}
for name, cfg := range cases {
diff --git a/lib/cache/collections.go b/lib/cache/collections.go
index 84a02648881cd..8dcb5723ea4a8 100644
--- a/lib/cache/collections.go
+++ b/lib/cache/collections.go
@@ -39,6 +39,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
@@ -265,6 +266,7 @@ type cacheCollections struct {
spiffeFederations collectionReader[SPIFFEFederationReader]
autoUpdateConfigs collectionReader[autoUpdateConfigGetter]
autoUpdateVersions collectionReader[autoUpdateVersionGetter]
+ workloadIdentity collectionReader[WorkloadIdentityReader]
}
// setupCollections returns a registry of collections.
@@ -784,6 +786,15 @@ func setupCollections(c *Cache, watches []types.WatchKind) (*cacheCollections, e
watch: watch,
}
collections.byKind[resourceKind] = collections.accessGraphSettings
+ case types.KindWorkloadIdentity:
+ if c.Config.WorkloadIdentity == nil {
+ return nil, trace.BadParameter("missing parameter WorkloadIdentity")
+ }
+ collections.workloadIdentity = &genericCollection[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader, workloadIdentityExecutor]{
+ cache: c,
+ watch: watch,
+ }
+ collections.byKind[resourceKind] = collections.workloadIdentity
case types.KindAutoUpdateConfig:
if c.AutoUpdateService == nil {
return nil, trace.BadParameter("missing parameter AutoUpdateService")
diff --git a/lib/cache/resource_workload_identity.go b/lib/cache/resource_workload_identity.go
new file mode 100644
index 0000000000000..75efb50fedbd5
--- /dev/null
+++ b/lib/cache/resource_workload_identity.go
@@ -0,0 +1,119 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//nolint:unused // Because the executors generate a large amount of false positives.
+package cache
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+// WorkloadIdentityReader is an interface that defines the methods for getting
+// WorkloadIdentity. This is returned as the reader for the WorkloadIdentity
+// collection but is also used by the executor to read the full list of
+// WorkloadIdentity on initialization.
+type WorkloadIdentityReader interface {
+ ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+ GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
+}
+
+// workloadIdentityCacher is used for storing and retrieving WorkloadIdentity
+// from the cache's local backend.
+type workloadIdentityCacher interface {
+ WorkloadIdentityReader
+ UpsertWorkloadIdentity(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity) (*workloadidentityv1pb.WorkloadIdentity, error)
+ DeleteWorkloadIdentity(ctx context.Context, name string) error
+ DeleteAllWorkloadIdentities(ctx context.Context) error
+}
+
+type workloadIdentityExecutor struct{}
+
+var _ executor[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader] = workloadIdentityExecutor{}
+
+func (workloadIdentityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ var out []*workloadidentityv1pb.WorkloadIdentity
+ var nextToken string
+ for {
+ var page []*workloadidentityv1pb.WorkloadIdentity
+ var err error
+
+ const defaultPageSize = 0
+ page, nextToken, err = cache.Config.WorkloadIdentity.ListWorkloadIdentities(ctx, defaultPageSize, nextToken)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ out = append(out, page...)
+ if nextToken == "" {
+ break
+ }
+ }
+ return out, nil
+}
+
+func (workloadIdentityExecutor) upsert(ctx context.Context, cache *Cache, resource *workloadidentityv1pb.WorkloadIdentity) error {
+ _, err := cache.workloadIdentityCache.UpsertWorkloadIdentity(ctx, resource)
+ return trace.Wrap(err)
+}
+
+func (workloadIdentityExecutor) deleteAll(ctx context.Context, cache *Cache) error {
+ return trace.Wrap(cache.workloadIdentityCache.DeleteAllWorkloadIdentities(ctx))
+}
+
+func (workloadIdentityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
+ return trace.Wrap(cache.workloadIdentityCache.DeleteWorkloadIdentity(ctx, resource.GetName()))
+}
+
+func (workloadIdentityExecutor) isSingleton() bool { return false }
+
+func (workloadIdentityExecutor) getReader(cache *Cache, cacheOK bool) WorkloadIdentityReader {
+ if cacheOK {
+ return cache.workloadIdentityCache
+ }
+ return cache.Config.WorkloadIdentity
+}
+
+// ListWorkloadIdentities returns a paginated list of WorkloadIdentity resources.
+func (c *Cache) ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/ListWorkloadIdentities")
+ defer span.End()
+
+ rg, err := readCollectionCache(c, c.collections.workloadIdentity)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ defer rg.Release()
+ out, nextKey, err := rg.reader.ListWorkloadIdentities(ctx, pageSize, nextToken)
+ return out, nextKey, trace.Wrap(err)
+}
+
+// GetWorkloadIdentity returns a single WorkloadIdentity by name
+func (c *Cache) GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetWorkloadIdentity")
+ defer span.End()
+
+ rg, err := readCollectionCache(c, c.collections.workloadIdentity)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer rg.Release()
+ out, err := rg.reader.GetWorkloadIdentity(ctx, name)
+ return out, trace.Wrap(err)
+}
diff --git a/lib/cache/resource_workload_identity_test.go b/lib/cache/resource_workload_identity_test.go
new file mode 100644
index 0000000000000..da82d64fec27c
--- /dev/null
+++ b/lib/cache/resource_workload_identity_test.go
@@ -0,0 +1,74 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "context"
+ "testing"
+
+ "github.com/gravitational/trace"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+func newWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity {
+ return &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: name,
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ }
+}
+
+func TestWorkloadIdentity(t *testing.T) {
+ t.Parallel()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ testResources153(t, p, testFuncs153[*workloadidentityv1pb.WorkloadIdentity]{
+ newResource: func(s string) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ return newWorkloadIdentity(s), nil
+ },
+
+ create: func(ctx context.Context, item *workloadidentityv1pb.WorkloadIdentity) error {
+ _, err := p.workloadIdentity.CreateWorkloadIdentity(ctx, item)
+ return trace.Wrap(err)
+ },
+ list: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ items, _, err := p.workloadIdentity.ListWorkloadIdentities(ctx, 0, "")
+ return items, trace.Wrap(err)
+ },
+ deleteAll: func(ctx context.Context) error {
+ return p.workloadIdentity.DeleteAllWorkloadIdentities(ctx)
+ },
+
+ cacheList: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ items, _, err := p.cache.ListWorkloadIdentities(ctx, 0, "")
+ return items, trace.Wrap(err)
+ },
+ cacheGet: p.cache.GetWorkloadIdentity,
+ })
+}
diff --git a/lib/service/service.go b/lib/service/service.go
index e2a35c8df2f07..86363450dd1a4 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -2429,6 +2429,7 @@ func (process *TeleportProcess) newAccessCacheForServices(cfg accesspoint.Config
cfg.WebSession = services.Identity.WebSessions()
cfg.WebToken = services.Identity.WebTokens()
cfg.WindowsDesktops = services.WindowsDesktops
+ cfg.WorkloadIdentity = services.WorkloadIdentities
cfg.AutoUpdateService = services.AutoUpdateService
return accesspoint.NewCache(cfg)
diff --git a/lib/services/local/events.go b/lib/services/local/events.go
index 994f54c18cd4e..0c19086766a0c 100644
--- a/lib/services/local/events.go
+++ b/lib/services/local/events.go
@@ -234,6 +234,8 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch types.Watch) (type
parser = newAccessGraphSettingsParser()
case types.KindStaticHostUser:
parser = newStaticHostUserParser()
+ case types.KindWorkloadIdentity:
+ parser = newWorkloadIdentityParser()
default:
if watch.AllowPartialSuccess {
continue
@@ -2607,3 +2609,31 @@ func (p *accessGraphSettingsParser) parse(event backend.Event) (types.Resource,
return nil, trace.BadParameter("event %v is not supported", event.Type)
}
}
+
+func newWorkloadIdentityParser() *workloadIdentityParser {
+ return &workloadIdentityParser{
+ baseParser: newBaseParser(backend.NewKey(workloadIdentityPrefix)),
+ }
+}
+
+type workloadIdentityParser struct {
+ baseParser
+}
+
+func (p *workloadIdentityParser) parse(event backend.Event) (types.Resource, error) {
+ switch event.Type {
+ case types.OpDelete:
+ return resourceHeader(event, types.KindAccessGraphSettings, types.V1, 0)
+ case types.OpPut:
+ resource, err := services.UnmarshalWorkloadIdentity(
+ event.Item.Value,
+ services.WithExpires(event.Item.Expires),
+ services.WithRevision(event.Item.Revision))
+ if err != nil {
+ return nil, trace.Wrap(err, "unmarshalling resource from event")
+ }
+ return types.Resource153ToLegacy(resource), nil
+ default:
+ return nil, trace.BadParameter("event %v is not supported", event.Type)
+ }
+}
diff --git a/lib/services/local/workload_identity.go b/lib/services/local/workload_identity.go
new file mode 100644
index 0000000000000..e0504e989cbe8
--- /dev/null
+++ b/lib/services/local/workload_identity.go
@@ -0,0 +1,118 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package local
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/backend"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/services/local/generic"
+)
+
+const (
+ workloadIdentityPrefix = "workload_identity"
+)
+
+// WorkloadIdentityService exposes backend functionality for storing
+// WorkloadIdentity resources
+type WorkloadIdentityService struct {
+ service *generic.ServiceWrapper[*workloadidentityv1pb.WorkloadIdentity]
+}
+
+// NewWorkloadIdentityService creates a new WorkloadIdentityService
+func NewWorkloadIdentityService(b backend.Backend) (*WorkloadIdentityService, error) {
+ service, err := generic.NewServiceWrapper(
+ generic.ServiceWrapperConfig[*workloadidentityv1pb.WorkloadIdentity]{
+ Backend: b,
+ ResourceKind: types.KindWorkloadIdentity,
+ BackendPrefix: backend.NewKey(workloadIdentityPrefix),
+ MarshalFunc: services.MarshalWorkloadIdentity,
+ UnmarshalFunc: services.UnmarshalWorkloadIdentity,
+ ValidateFunc: services.ValidateWorkloadIdentity,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &WorkloadIdentityService{
+ service: service,
+ }, nil
+}
+
+// CreateWorkloadIdentity inserts a new WorkloadIdentity into the backend.
+func (b *WorkloadIdentityService) CreateWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ created, err := b.service.CreateResource(ctx, resource)
+ return created, trace.Wrap(err)
+}
+
+// GetWorkloadIdentity retrieves a specific WorkloadIdentity given a name
+func (b *WorkloadIdentityService) GetWorkloadIdentity(
+ ctx context.Context, name string,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ resource, err := b.service.GetResource(ctx, name)
+ return resource, trace.Wrap(err)
+}
+
+// ListWorkloadIdentities lists all WorkloadIdentities using a given page size
+// and last key.
+func (b *WorkloadIdentityService) ListWorkloadIdentities(
+ ctx context.Context, pageSize int, currentToken string,
+) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) {
+ r, nextToken, err := b.service.ListResources(ctx, pageSize, currentToken)
+ return r, nextToken, trace.Wrap(err)
+}
+
+// DeleteWorkloadIdentity deletes a specific WorkloadIdentity.
+func (b *WorkloadIdentityService) DeleteWorkloadIdentity(
+ ctx context.Context, name string,
+) error {
+ return trace.Wrap(b.service.DeleteResource(ctx, name))
+}
+
+// DeleteAllWorkloadIdentities deletes all SPIFFE resources, this is typically
+// only meant to be used by the cache.
+func (b *WorkloadIdentityService) DeleteAllWorkloadIdentities(
+ ctx context.Context,
+) error {
+ return trace.Wrap(b.service.DeleteAllResources(ctx))
+}
+
+// UpsertWorkloadIdentity upserts a WorkloadIdentitys. Prefer using
+// CreateWorkloadIdentity. This is only designed for usage by the cache.
+func (b *WorkloadIdentityService) UpsertWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ upserted, err := b.service.UpsertResource(ctx, resource)
+ return upserted, trace.Wrap(err)
+}
+
+// UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must
+// already exist, and, condition update semantics are used - e.g the submitted
+// resource must have a revision matching the revision of the resource in the
+// backend.
+func (b *WorkloadIdentityService) UpdateWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ updated, err := b.service.ConditionalUpdateResource(ctx, resource)
+ return updated, trace.Wrap(err)
+}
diff --git a/lib/services/local/workload_identity_test.go b/lib/services/local/workload_identity_test.go
new file mode 100644
index 0000000000000..acba05d9c8e4a
--- /dev/null
+++ b/lib/services/local/workload_identity_test.go
@@ -0,0 +1,355 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package local
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/testing/protocmp"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/backend"
+ "github.com/gravitational/teleport/lib/backend/memory"
+)
+
+func setupWorkloadIdentityServiceTest(
+ t *testing.T,
+) (context.Context, *WorkloadIdentityService) {
+ t.Parallel()
+ ctx := context.Background()
+ clock := clockwork.NewFakeClock()
+ mem, err := memory.New(memory.Config{
+ Context: ctx,
+ Clock: clock,
+ })
+ require.NoError(t, err)
+ service, err := NewWorkloadIdentityService(backend.NewSanitizer(mem))
+ require.NoError(t, err)
+ return ctx, service
+}
+
+func newValidWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity {
+ return &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: name,
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/test",
+ },
+ },
+ }
+}
+
+func TestWorkloadIdentityService_CreateWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ out, err := service.CreateWorkloadIdentity(ctx, newValidWorkloadIdentity(""))
+ require.ErrorContains(t, err, "metadata.name: is required")
+ require.Nil(t, out)
+ })
+ t.Run("no upsert", func(t *testing.T) {
+ res := newValidWorkloadIdentity("duplicate")
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ _, err = service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.Error(t, err)
+ require.True(t, trace.IsAlreadyExists(err))
+ })
+}
+
+func TestWorkloadIdentityService_UpsertWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ got, err := service.UpsertWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+
+ // Ensure we can upsert over an existing resource
+ _, err = service.UpsertWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ out, err := service.UpdateWorkloadIdentity(ctx, newValidWorkloadIdentity(""))
+ require.ErrorContains(t, err, "metadata.name: is required")
+ require.Nil(t, out)
+ })
+}
+
+func TestWorkloadIdentityService_ListWorkloadIdentities(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+ // Create entities to list
+ createdObjects := []*workloadidentityv1pb.WorkloadIdentity{}
+ // Create 49 entities to test an incomplete page at the end.
+ for i := 0; i < 49; i++ {
+ created, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity(fmt.Sprintf("%d", i)),
+ )
+ require.NoError(t, err)
+ createdObjects = append(createdObjects, created)
+ }
+ t.Run("default page size", func(t *testing.T) {
+ page, nextToken, err := service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Len(t, page, 49)
+ require.Empty(t, nextToken)
+
+ // Expect that we get all the things we have created
+ for _, created := range createdObjects {
+ slices.ContainsFunc(page, func(resource *workloadidentityv1pb.WorkloadIdentity) bool {
+ return proto.Equal(created, resource)
+ })
+ }
+ })
+ t.Run("pagination", func(t *testing.T) {
+ fetched := []*workloadidentityv1pb.WorkloadIdentity{}
+ token := ""
+ iterations := 0
+ for {
+ iterations++
+ page, nextToken, err := service.ListWorkloadIdentities(ctx, 10, token)
+ require.NoError(t, err)
+ fetched = append(fetched, page...)
+ if nextToken == "" {
+ break
+ }
+ token = nextToken
+ }
+ require.Equal(t, 5, iterations)
+
+ require.Len(t, fetched, 49)
+ // Expect that we get all the things we have created
+ for _, created := range createdObjects {
+ slices.ContainsFunc(fetched, func(resource *workloadidentityv1pb.WorkloadIdentity) bool {
+ return proto.Equal(created, resource)
+ })
+ }
+ })
+}
+
+func TestWorkloadIdentityService_GetWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ got, err := service.GetWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("not found", func(t *testing.T) {
+ _, err := service.GetWorkloadIdentity(ctx, "not-found")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+}
+
+func TestWorkloadIdentityService_DeleteWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("example"),
+ )
+ require.NoError(t, err)
+
+ _, err = service.GetWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+
+ err = service.DeleteWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+
+ _, err = service.GetWorkloadIdentity(ctx, "example")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+ t.Run("not found", func(t *testing.T) {
+ err := service.DeleteWorkloadIdentity(ctx, "foo.example.com")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+}
+
+func TestWorkloadIdentityService_DeleteAllWorkloadIdentities(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("1"),
+ )
+ require.NoError(t, err)
+ _, err = service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("2"),
+ )
+ require.NoError(t, err)
+
+ page, _, err := service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Len(t, page, 2)
+
+ err = service.DeleteAllWorkloadIdentities(ctx)
+ require.NoError(t, err)
+
+ page, _, err = service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Empty(t, page)
+}
+
+func TestWorkloadIdentityService_UpdateWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ // Create first to support updating
+ toCreate := newValidWorkloadIdentity("example")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ got.Spec.Spiffe.Id = "/changed"
+ got2, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got2.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ got,
+ got2,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ // Create first to support updating
+ toCreate := newValidWorkloadIdentity("example2")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ got.Spec.Spiffe.Id = ""
+ got2, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.ErrorContains(t, err, "spec.spiffe.id: is required")
+ require.Nil(t, got2)
+ })
+ t.Run("cond update blocks", func(t *testing.T) {
+ toCreate := newValidWorkloadIdentity("example4")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ // We'll now update it twice, but on the second update, we will use the
+ // revision from the creation not the second update.
+ _, err = service.UpdateWorkloadIdentity(
+ ctx,
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ _, err = service.UpdateWorkloadIdentity(
+ ctx,
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.ErrorIs(t, err, backend.ErrIncorrectRevision)
+ })
+ t.Run("no upsert", func(t *testing.T) {
+ toUpdate := newValidWorkloadIdentity("example3")
+ _, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toUpdate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.Error(t, err)
+ })
+}
diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go
new file mode 100644
index 0000000000000..89b87ba0d2473
--- /dev/null
+++ b/lib/services/workload_identity.go
@@ -0,0 +1,122 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package services
+
+import (
+ "context"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+// WorkloadIdentities is an interface over the WorkloadIdentities service. This
+// interface may also be implemented by a client to allow remote and local
+// consumers to access the resource in a similar way.
+type WorkloadIdentities interface {
+ // GetWorkloadIdentity gets a SPIFFE Federation by name.
+ GetWorkloadIdentity(
+ ctx context.Context, name string,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // ListWorkloadIdentities lists all WorkloadIdentities using Google style
+ // pagination.
+ ListWorkloadIdentities(
+ ctx context.Context, pageSize int, lastToken string,
+ ) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+ // CreateWorkloadIdentity creates a new WorkloadIdentity.
+ CreateWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // DeleteWorkloadIdentity deletes a SPIFFE Federation by name.
+ DeleteWorkloadIdentity(ctx context.Context, name string) error
+ // UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must
+ // already exist, and, condition update semantics are used - e.g the submitted
+ // resource must have a revision matching the revision of the resource in the
+ // backend.
+ UpdateWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // UpsertWorkloadIdentity creates or updates a WorkloadIdentity.
+ UpsertWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+}
+
+// MarshalWorkloadIdentity marshals the WorkloadIdentity object into a JSON byte
+// array.
+func MarshalWorkloadIdentity(
+ object *workloadidentityv1pb.WorkloadIdentity, opts ...MarshalOption,
+) ([]byte, error) {
+ return MarshalProtoResource(object, opts...)
+}
+
+// UnmarshalWorkloadIdentity unmarshals the WorkloadIdentity object from a
+// JSON byte array.
+func UnmarshalWorkloadIdentity(
+ data []byte, opts ...MarshalOption,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ return UnmarshalProtoResource[*workloadidentityv1pb.WorkloadIdentity](data, opts...)
+}
+
+// ValidateWorkloadIdentity validates the WorkloadIdentity object. This is
+// performed prior to writing to the backend.
+func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error {
+ switch {
+ case s == nil:
+ return trace.BadParameter("object cannot be nil")
+ case s.Version != types.V1:
+ return trace.BadParameter("version: only %q is supported", types.V1)
+ case s.Kind != types.KindWorkloadIdentity:
+ return trace.BadParameter("kind: must be %q", types.KindWorkloadIdentity)
+ case s.Metadata == nil:
+ return trace.BadParameter("metadata: is required")
+ case s.Metadata.Name == "":
+ return trace.BadParameter("metadata.name: is required")
+ case s.Spec == nil:
+ return trace.BadParameter("spec: is required")
+ case s.Spec.Spiffe.Id == "":
+ return trace.BadParameter("spec.spiffe.id: is required")
+ case !strings.HasPrefix(s.Spec.Spiffe.Id, "/"):
+ return trace.BadParameter("spec.spiffe.id: must start with a /")
+ }
+
+ for i, rule := range s.GetSpec().GetRules().GetAllow() {
+ if len(rule.Conditions) == 0 {
+ return trace.BadParameter("spec.rules.allow[%d].conditions: must be non-empty", i)
+ }
+ for j, condition := range rule.Conditions {
+ if condition.Attribute == "" {
+ return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j)
+ }
+ // Ensure exactly one operator is set.
+ operatorsSet := 0
+ if condition.Equals != "" {
+ operatorsSet++
+ }
+ if operatorsSet == 0 || operatorsSet > 1 {
+ return trace.BadParameter(
+ "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d",
+ i, j, operatorsSet,
+ )
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go
new file mode 100644
index 0000000000000..429612ed48555
--- /dev/null
+++ b/lib/services/workload_identity_test.go
@@ -0,0 +1,231 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package services
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/testing/protocmp"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+func TestWorkloadIdentityMarshaling(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ in *workloadidentityv1pb.WorkloadIdentity
+ }{
+ {
+ name: "normal",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ gotBytes, err := MarshalWorkloadIdentity(tc.in)
+ require.NoError(t, err)
+ // Test that unmarshaling gives us the same object
+ got, err := UnmarshalWorkloadIdentity(gotBytes)
+ require.NoError(t, err)
+ require.Empty(t, cmp.Diff(tc.in, got, protocmp.Transform()))
+ })
+ }
+}
+
+func TestValidateWorkloadIdentity(t *testing.T) {
+ t.Parallel()
+
+ var errContains = func(contains string) require.ErrorAssertionFunc {
+ return func(t require.TestingT, err error, msgAndArgs ...interface{}) {
+ require.ErrorContains(t, err, contains, msgAndArgs...)
+ }
+ }
+
+ testCases := []struct {
+ name string
+ in *workloadidentityv1pb.WorkloadIdentity
+ requireErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "success - full",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "example",
+ Equals: "foo",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "success - minimal",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "missing name",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{},
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("metadata.name: is required"),
+ },
+ {
+ name: "missing spiffe id",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{},
+ },
+ },
+ requireErr: errContains("spec.spiffe.id: is required"),
+ },
+ {
+ name: "spiffe id must have leading /",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "example",
+ },
+ },
+ },
+ requireErr: errContains("spec.spiffe.id: must start with a /"),
+ },
+ {
+ name: "missing attribute",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "",
+ Equals: "foo",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("spec.rules.allow[0].conditions[0].attribute: must be non-empty"),
+ },
+ {
+ name: "missing operator",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "example",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ err := ValidateWorkloadIdentity(tc.in)
+ tc.requireErr(t, err)
+ })
+ }
+}