From 944c359c53afdb614b6910ea8332fa190f77b086 Mon Sep 17 00:00:00 2001 From: Michael MacDonald Date: Mon, 20 May 2024 21:03:50 +0000 Subject: [PATCH] DAOS-15874 control: Add optional credential cache to agent On heavily-loaded client nodes where many processes are being launched by the same user or users, the admin may optionally enable the credential cache in the agent in order to lower agent overhead caused by generating identical credentials for each process owned by a user. The agent-generated credential is presented by the client process during pool/container connection and is used to evaluate ACL permissions for that connection. Example config: credential_config: cache_lifetime: 1m Features: control Required-githooks: true Change-Id: I6ae2a8be1dd97ef14e0ccef0283d65bc1fabc4ed Signed-off-by: Michael MacDonald --- src/control/cmd/daos_agent/config_test.go | 2 + src/control/cmd/daos_agent/infocache.go | 75 ++++++++-- src/control/cmd/daos_agent/infocache_test.go | 36 +++-- src/control/cmd/daos_agent/mgmt_rpc.go | 11 +- src/control/cmd/daos_agent/mgmt_rpc_test.go | 3 +- src/control/cmd/daos_agent/security_rpc.go | 111 +++++++++++++-- .../cmd/daos_agent/security_rpc_test.go | 128 ++++++++++++++++-- src/control/lib/cache/cache.go | 117 ++++++++++------ src/control/lib/cache/cache_test.go | 20 ++- src/control/security/auth/auth_sys.go | 5 +- src/control/security/auth/auth_sys_test.go | 4 +- src/control/security/config.go | 1 + utils/config/daos_agent.yml | 16 ++- 13 files changed, 424 insertions(+), 105 deletions(-) diff --git a/src/control/cmd/daos_agent/config_test.go b/src/control/cmd/daos_agent/config_test.go index 162b60bbae0..6f0b7562660 100644 --- a/src/control/cmd/daos_agent/config_test.go +++ b/src/control/cmd/daos_agent/config_test.go @@ -47,6 +47,7 @@ disable_caching: true cache_expiration: 30 disable_auto_evict: true credential_config: + cache_lifetime: 10m client_user_map: 1000: user: frodo @@ -140,6 +141,7 @@ transport_config: CacheExpiration: refreshMinutes(30 * time.Minute), DisableAutoEvict: true, CredentialConfig: &security.CredentialConfig{ + CacheLifetime: time.Minute * 10, ClientUserMap: map[uint32]*security.MappedClientUser{ 1000: { User: "frodo", diff --git a/src/control/cmd/daos_agent/infocache.go b/src/control/cmd/daos_agent/infocache.go index cb777396ff1..9fc942f9290 100644 --- a/src/control/cmd/daos_agent/infocache.go +++ b/src/control/cmd/daos_agent/infocache.go @@ -80,11 +80,13 @@ func getFabricScanFn(log logging.Logger, cfg *Config, scanner *hardware.FabricSc } type cacheItem struct { - sync.Mutex + sync.RWMutex lastCached time.Time refreshInterval time.Duration } +// isStale returns true if the cache item is stale. +// NB: Should be run under a lock to protect lastCached. func (ci *cacheItem) isStale() bool { if ci.refreshInterval == 0 { return false @@ -92,6 +94,8 @@ func (ci *cacheItem) isStale() bool { return ci.lastCached.Add(ci.refreshInterval).Before(time.Now()) } +// isCached returns true if the cache item is cached. +// NB: Should be run under at least a read lock to protect lastCached. func (ci *cacheItem) isCached() bool { return !ci.lastCached.Equal(time.Time{}) } @@ -130,16 +134,30 @@ func (ci *cachedAttachInfo) Key() string { return sysAttachInfoKey(ci.system) } -// NeedsRefresh checks whether the cached data needs to be refreshed. -func (ci *cachedAttachInfo) NeedsRefresh() bool { +// needsRefresh checks whether the cached data needs to be refreshed. +func (ci *cachedAttachInfo) needsRefresh() bool { if ci == nil { return false } return !ci.isCached() || ci.isStale() } -// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. -func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached data if it needs to be refreshed. +func (ci *cachedAttachInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if ci == nil { + return cache.NoopRelease, false, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + if ci.needsRefresh() { + return ci.Unlock, true, ci.refresh(ctx) + } + return ci.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (ci *cachedAttachInfo) refresh(ctx context.Context) error { if ci == nil { return errors.New("cachedAttachInfo is nil") } @@ -155,6 +173,16 @@ func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { return nil } +// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. +func (ci *cachedAttachInfo) Refresh(ctx context.Context) (func(), error) { + if ci == nil { + return cache.NoopRelease, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + return ci.Unlock, ci.refresh(ctx) +} + type cachedFabricInfo struct { cacheItem fetch fabricScanFn @@ -172,17 +200,31 @@ func (cfi *cachedFabricInfo) Key() string { return fabricKey } -// NeedsRefresh indicates that the fabric information does not need to be refreshed unless it has +// needsRefresh indicates that the fabric information does not need to be refreshed unless it has // never been populated. -func (cfi *cachedFabricInfo) NeedsRefresh() bool { +func (cfi *cachedFabricInfo) needsRefresh() bool { if cfi == nil { return false } return !cfi.isCached() } -// Refresh scans the hardware for information about the fabric devices and caches the result. -func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached fabric information if it needs to be refreshed. +func (cfi *cachedFabricInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if cfi == nil { + return cache.NoopRelease, false, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + if cfi.needsRefresh() { + return cfi.Unlock, true, cfi.refresh(ctx) + } + return cfi.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (cfi *cachedFabricInfo) refresh(ctx context.Context) error { if cfi == nil { return errors.New("cachedFabricInfo is nil") } @@ -197,6 +239,16 @@ func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { return nil } +// Refresh scans the hardware for information about the fabric devices and caches the result. +func (cfi *cachedFabricInfo) Refresh(ctx context.Context) (func(), error) { + if cfi == nil { + return cache.NoopRelease, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + return cfi.Unlock, cfi.refresh(ctx) +} + // InfoCache is a cache for the results of expensive operations needed by the agent. type InfoCache struct { log logging.Logger @@ -350,15 +402,14 @@ func (c *InfoCache) GetAttachInfo(ctx context.Context, sys string) (*control.Get } createItem := func() (cache.Item, error) { c.log.Debugf("cache miss for %s", sysAttachInfoKey(sys)) - cai := newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo) - return cai, nil + return newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo), nil } item, release, err := c.cache.GetOrCreate(ctx, sysAttachInfoKey(sys), createItem) - defer release() if err != nil { return nil, errors.Wrap(err, "getting attach info from cache") } + defer release() cai, ok := item.(*cachedAttachInfo) if !ok { diff --git a/src/control/cmd/daos_agent/infocache_test.go b/src/control/cmd/daos_agent/infocache_test.go index e86c44bfc0c..36954849468 100644 --- a/src/control/cmd/daos_agent/infocache_test.go +++ b/src/control/cmd/daos_agent/infocache_test.go @@ -149,14 +149,22 @@ func TestAgent_cachedAttachInfo_Key(t *testing.T) { } } -func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedAttachInfo_RefreshIfNeeded(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + mockClient := control.NewMockInvoker(log, &control.MockInvokerConfig{}) + + noopGetAttachInfo := func(_ context.Context, _ control.UnaryInvoker, _ *control.GetAttachInfoReq) (*control.GetAttachInfoResp, error) { + return nil, nil + } + for name, tc := range map[string]struct { ai *cachedAttachInfo expResult bool }{ "nil": {}, "never cached": { - ai: newCachedAttachInfo(0, "test", nil, nil), + ai: newCachedAttachInfo(0, "test", mockClient, noopGetAttachInfo), expResult: true, }, "no refresh": { @@ -164,6 +172,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { cacheItem: cacheItem{ lastCached: time.Now().Add(-time.Minute), }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, @@ -173,6 +183,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Minute), refreshInterval: time.Second, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, expResult: true, @@ -183,12 +195,15 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Second), refreshInterval: time.Minute, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, } { t.Run(name, func(t *testing.T) { - test.AssertEqual(t, tc.expResult, tc.ai.NeedsRefresh(), "") + _, refreshed, _ := tc.ai.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -271,7 +286,8 @@ func TestAgent_cachedAttachInfo_Refresh(t *testing.T) { } } - err := ai.Refresh(test.Context(t)) + release, err := ai.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) @@ -320,7 +336,7 @@ func TestAgent_cachedFabricInfo_Key(t *testing.T) { } } -func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedFabricInfo_RefreshIfNeeded(t *testing.T) { for name, tc := range map[string]struct { nilCache bool cacheTime time.Time @@ -342,11 +358,14 @@ func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { var cfi *cachedFabricInfo if !tc.nilCache { - cfi = newCachedFabricInfo(log, nil) + cfi = newCachedFabricInfo(log, func(_ context.Context, _ ...string) (*NUMAFabric, error) { + return nil, nil + }) cfi.cacheItem.lastCached = tc.cacheTime } - test.AssertEqual(t, tc.expResult, cfi.NeedsRefresh(), "") + _, refreshed, _ := cfi.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -416,7 +435,8 @@ func TestAgent_cachedFabricInfo_Refresh(t *testing.T) { } } - err := cfi.Refresh(test.Context(t)) + release, err := cfi.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) diff --git a/src/control/cmd/daos_agent/mgmt_rpc.go b/src/control/cmd/daos_agent/mgmt_rpc.go index 75dc337e313..51cb80984c8 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc.go +++ b/src/control/cmd/daos_agent/mgmt_rpc.go @@ -8,7 +8,6 @@ package main import ( "net" - "sync" "github.com/pkg/errors" "golang.org/x/net/context" @@ -22,6 +21,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -34,16 +34,13 @@ import ( // Management Service proxy, handling dRPCs sent by libdaos by forwarding them // to MS. type mgmtModule struct { - attachInfoMutex sync.RWMutex - fabricMutex sync.RWMutex - log logging.Logger sys string ctlInvoker control.Invoker cache *InfoCache monitor *procMon cliMetricsSrc *promexp.ClientSource - useDefaultNUMA bool + useDefaultNUMA atm.Bool numaGetter hardware.ProcessNUMAProvider } @@ -161,14 +158,14 @@ func (mod *mgmtModule) handleGetAttachInfo(ctx context.Context, reqb []byte, pid } func (mod *mgmtModule) getNUMANode(ctx context.Context, pid int32) (uint, error) { - if mod.useDefaultNUMA { + if mod.useDefaultNUMA.IsTrue() { return 0, nil } numaNode, err := mod.numaGetter.GetNUMANodeIDForPID(ctx, pid) if errors.Is(err, hardware.ErrNoNUMANodes) { mod.log.Debug("system is not NUMA-aware") - mod.useDefaultNUMA = true + mod.useDefaultNUMA.SetTrue() return 0, nil } else if err != nil { return 0, errors.Wrapf(err, "failed to get NUMA node ID for pid %d", pid) diff --git a/src/control/cmd/daos_agent/mgmt_rpc_test.go b/src/control/cmd/daos_agent/mgmt_rpc_test.go index 59fcb507a81..965f3dc47c9 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc_test.go +++ b/src/control/cmd/daos_agent/mgmt_rpc_test.go @@ -27,6 +27,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -334,7 +335,7 @@ func TestAgent_mgmtModule_getNUMANode(t *testing.T) { mod := &mgmtModule{ log: log, - useDefaultNUMA: tc.useDefaultNUMA, + useDefaultNUMA: atm.NewBool(tc.useDefaultNUMA), numaGetter: tc.numaGetter, } diff --git a/src/control/cmd/daos_agent/security_rpc.go b/src/control/cmd/daos_agent/security_rpc.go index 8843f18ed6f..7fc6b942cd9 100644 --- a/src/control/cmd/daos_agent/security_rpc.go +++ b/src/control/cmd/daos_agent/security_rpc.go @@ -8,12 +8,15 @@ package main import ( "context" + "fmt" "net" "os/user" + "time" "github.com/pkg/errors" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" @@ -21,7 +24,24 @@ import ( ) type ( - credSignerFn func(*auth.CredentialRequest) (*auth.Credential, error) + // credSignerFn defines the function signature for signing credentials. + credSignerFn func(context.Context, *auth.CredentialRequest) (*auth.Credential, error) + + // credentialCache implements a cache for signed credentials. + credentialCache struct { + log logging.Logger + cache *cache.ItemCache + credLifetime time.Duration + cacheMissFn credSignerFn + } + + // cachedCredential wraps a cached credential and implements the cache.ExpirableItem interface. + cachedCredential struct { + cacheItem + key string + expiredAt time.Time + cred *auth.Credential + } // securityConfig defines configuration parameters for SecurityModule. securityConfig struct { @@ -33,32 +53,106 @@ type ( SecurityModule struct { log logging.Logger signCredential credSignerFn + credCache *credentialCache config *securityConfig } ) -// NewSecurityModule creates a new module with the given initialized TransportConfig +// NewSecurityModule creates a new module with the given initialized TransportConfig. func NewSecurityModule(log logging.Logger, cfg *securityConfig) *SecurityModule { + var credCache *credentialCache + credSigner := auth.GetSignedCredential + if cfg.credentials.CacheLifetime > 0 { + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: cfg.credentials.CacheLifetime, + cacheMissFn: auth.GetSignedCredential, + } + credSigner = cache.getSignedCredential + log.Noticef("credential cache enabled (entry lifetime: %s)", cfg.credentials.CacheLifetime) + } + return &SecurityModule{ log: log, - signCredential: auth.GetSignedCredential, + signCredential: credSigner, + credCache: credCache, config: cfg, } } +func credReqKey(req *auth.CredentialRequest) string { + return fmt.Sprintf("%d:%d:%s", req.DomainInfo.Uid(), req.DomainInfo.Gid(), req.DomainInfo.Ctx()) +} + +func (cred *cachedCredential) Key() string { + if cred == nil { + return "" + } + + return cred.key +} + +func (cred *cachedCredential) IsExpired() bool { + if cred == nil || cred.cred == nil || cred.expiredAt.IsZero() { + return true + } + + return time.Now().After(cred.expiredAt) +} + +func (cc *credentialCache) getSignedCredential(ctx context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + key := credReqKey(req) + + createItem := func() (cache.Item, error) { + cc.log.Tracef("cache miss for %s", key) + cred, err := cc.cacheMissFn(ctx, req) + if err != nil { + return nil, err + } + cc.log.Tracef("getting credential for %s", key) + return newCachedCredential(key, cred, cc.credLifetime) + } + + item, release, err := cc.cache.GetOrCreate(ctx, key, createItem) + if err != nil { + return nil, errors.Wrap(err, "getting cached credential from cache") + } + defer release() + + cachedCred, ok := item.(*cachedCredential) + if !ok { + return nil, errors.New("invalid cached credential") + } + + return cachedCred.cred, nil +} + +func newCachedCredential(key string, cred *auth.Credential, lifetime time.Duration) (*cachedCredential, error) { + if cred == nil { + return nil, errors.New("credential is nil") + } + + return &cachedCredential{ + key: key, + cred: cred, + expiredAt: time.Now().Add(lifetime), + }, nil +} + // HandleCall is the handler for calls to the SecurityModule -func (m *SecurityModule) HandleCall(_ context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { +func (m *SecurityModule) HandleCall(ctx context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { if method != drpc.MethodRequestCredentials { return nil, drpc.UnknownMethodFailure() } - return m.getCredential(session) + return m.getCredential(ctx, session) } // getCredentials generates a signed user credential based on the data attached to // the Unix Domain Socket. -func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { +func (m *SecurityModule) getCredential(ctx context.Context, session *drpc.Session) ([]byte, error) { if session == nil { return nil, drpc.NewFailureWithMessage("session is nil") } @@ -82,7 +176,7 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } req := auth.NewCredentialRequest(info, signingKey) - cred, err := m.signCredential(req) + cred, err := m.signCredential(ctx, req) if err != nil { if err := func() error { if !errors.Is(err, user.UnknownUserIdError(info.Uid())) { @@ -95,7 +189,7 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } req.WithUserAndGroup(mu.User, mu.Group, mu.Groups...) - cred, err = m.signCredential(req) + cred, err = m.signCredential(ctx, req) if err != nil { return err } @@ -107,7 +201,6 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } } - m.log.Tracef("%s: successfully signed credential", info) resp := &auth.GetCredResp{Cred: cred} return drpc.Marshal(resp) } diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 08ee4755b27..52e298f0e88 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -7,17 +7,22 @@ package main import ( + "context" "errors" "net" "os/user" + "syscall" "testing" + "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" @@ -28,7 +33,7 @@ func TestAgentSecurityModule_ID(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) test.AssertEqual(t, mod.ID(), drpc.ModuleSecurityAgent, "wrong drpc module") } @@ -49,7 +54,7 @@ func TestAgentSecurityModule_BadMethod(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) method, err := mod.ID().GetMethod(-1) if method != nil { t.Errorf("Expected no method, got %+v", method) @@ -166,7 +171,8 @@ func TestAgentSecurityModule_RequestCreds_BadConfig(t *testing.T) { // Empty TransportConfig is incomplete mod := NewSecurityModule(log, &securityConfig{ - transport: &security.TransportConfig{}, + transport: &security.TransportConfig{}, + credentials: &security.CredentialConfig{}, }) respBytes, err := callRequestCreds(mod, t, log, conn) @@ -186,7 +192,7 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { defer cleanup() mod := NewSecurityModule(log, defaultTestSecurityConfig()) - mod.signCredential = func(_ *auth.CredentialRequest) (*auth.Credential, error) { + mod.signCredential = func(_ context.Context, _ *auth.CredentialRequest) (*auth.Credential, error) { return nil, errors.New("LookupUserID") } respBytes, err := callRequestCreds(mod, t, log, conn) @@ -198,11 +204,12 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { expectCredResp(t, respBytes, int32(daos.MiscError), false) } +type signCredentialResp struct { + cred *auth.Credential + err error +} + func TestAgent_SecurityRPC_getCredential(t *testing.T) { - type response struct { - cred *auth.Credential - err error - } testCred := &auth.Credential{ Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("test-token")}, Origin: "test-origin", @@ -227,13 +234,13 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { for name, tc := range map[string]struct { secCfg *securityConfig - responses []response + responses []signCredentialResp expBytes []byte expErr error }{ "lookup miss": { secCfg: defaultTestSecurityConfig(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -251,7 +258,7 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { } return cfg }(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -273,7 +280,7 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { } return cfg }(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -295,9 +302,9 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { defer cleanup() mod := NewSecurityModule(log, tc.secCfg) - mod.signCredential = func() func(req *auth.CredentialRequest) (*auth.Credential, error) { + mod.signCredential = func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { var idx int - return func(req *auth.CredentialRequest) (*auth.Credential, error) { + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { defer func() { if idx < len(tc.responses)-1 { idx++ @@ -319,3 +326,96 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { }) } } + +func TestAgent_SecurityCachedCredentials(t *testing.T) { + cred0 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group2")}, + Origin: "test-origin", + } + cred1 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group3")}, + Origin: "test-origin", + } + + for name, tc := range map[string]struct { + lifetime time.Duration + req *auth.CredentialRequest + responses []signCredentialResp + exp *auth.Credential + }{ + "cache hit": { + lifetime: time.Second, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred0, + }, + "expired entry": { + lifetime: time.Nanosecond, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred1, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: tc.lifetime, + cacheMissFn: func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + var idx int + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }(), + } + + // Prime the cache with a single entry. + _, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Request a second time with the same credentials. + cred, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cmpOpts := cmp.Options{ + protocmp.Transform(), + } + if diff := cmp.Diff(tc.exp, cred, cmpOpts...); diff != "" { + t.Errorf("unexpected credential (-want +got):\n%s", diff) + } + }) + } +} diff --git a/src/control/lib/cache/cache.go b/src/control/lib/cache/cache.go index c31d82b0a61..9ecc123240c 100644 --- a/src/control/lib/cache/cache.go +++ b/src/control/lib/cache/cache.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -18,20 +18,33 @@ import ( "github.com/daos-stack/daos/src/control/logging" ) -type Item interface { - Lock() - Unlock() - Key() string - Refresh(ctx context.Context) error - NeedsRefresh() bool -} +type ( + // Item defines an interface for a cached item. + Item interface { + sync.Locker + Key() string + } -// ItemCache is a mechanism for caching Items to keys. -type ItemCache struct { - log logging.Logger - mutex sync.RWMutex - items map[string]Item -} + // ExpirableItem is an Item that defines its own expiration criteria. + ExpirableItem interface { + Item + IsExpired() bool + } + + // RefreshableItem is an Item that defines its own refresh criteria and method. + RefreshableItem interface { + Item + Refresh(ctx context.Context) (func(), error) + RefreshIfNeeded(ctx context.Context) (func(), bool, error) + } + + // ItemCache is a mechanism for caching Items to keys. + ItemCache struct { + log logging.Logger + mutex sync.RWMutex + items map[string]Item + } +) // NewItemCache creates a new ItemCache. func NewItemCache(log logging.Logger) *ItemCache { @@ -119,22 +132,24 @@ func (e *errKeyNotFound) Error() string { return fmt.Sprintf("key %q not found", e.key) } -func noopRelease() {} +// NoopRelease is a no-op function that does nothing, but can +// be safely returned in lieu of a real lock release function. +func NoopRelease() {} // GetOrCreate returns an item from the cache if it exists, otherwise it creates // the item using the given function and caches it. The item must be released // by the caller when it is safe to be modified. func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCreateFunc) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } if missFn == nil { - return nil, noopRelease, errors.Errorf("item create function is required") + return nil, NoopRelease, errors.Errorf("item create function is required") } ic.mutex.Lock() @@ -145,33 +160,39 @@ func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCre ic.log.Debugf("failed to get item for key %q: %s", key, err.Error()) item, err = missFn() if err != nil { - return nil, noopRelease, errors.Wrapf(err, "create item for %q", key) + return nil, NoopRelease, errors.Wrapf(err, "create item for %q", key) } ic.log.Debugf("created item for key %q", key) ic.set(item) } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } // Get returns an item from the cache if it exists, otherwise it returns an // error. The item must be released by the caller when it is safe to be modified. func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } ic.mutex.Lock() @@ -179,25 +200,35 @@ func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) item, err := ic.get(key) if err != nil { - return nil, noopRelease, err + return nil, NoopRelease, err } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } func (ic *ItemCache) get(key string) (Item, error) { - val, ok := ic.items[key] + item, ok := ic.items[key] if ok { - return val, nil + if ei, ok := item.(ExpirableItem); ok && ei.IsExpired() { + delete(ic.items, key) + } else { + return item, nil + } } return nil, &errKeyNotFound{key: key} } @@ -229,10 +260,12 @@ func (ic *ItemCache) refreshItem(ctx context.Context, key string) error { return err } - item.Lock() - defer item.Unlock() - if err := item.Refresh(ctx); err != nil { - return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + if ri, ok := item.(RefreshableItem); ok { + release, err := ri.Refresh(ctx) + if err != nil { + return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + } + release() } return nil diff --git a/src/control/lib/cache/cache_test.go b/src/control/lib/cache/cache_test.go index 6d1aeb465ea..0a34e8bf606 100644 --- a/src/control/lib/cache/cache_test.go +++ b/src/control/lib/cache/cache_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -10,11 +10,12 @@ import ( "context" "testing" - "github.com/daos-stack/daos/src/control/common/test" - "github.com/daos-stack/daos/src/control/logging" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/pkg/errors" + + "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/logging" ) func TestCache_NewItemCache(t *testing.T) { @@ -51,11 +52,18 @@ func (m *mockItem) Key() string { return m.ItemKey } -func (m *mockItem) Refresh(ctx context.Context) error { - return m.RefreshErr +func (m *mockItem) Refresh(ctx context.Context) (func(), error) { + return NoopRelease, m.RefreshErr +} + +func (m *mockItem) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if m.needsRefresh() { + return NoopRelease, true, m.RefreshErr + } + return NoopRelease, false, nil } -func (m *mockItem) NeedsRefresh() bool { +func (m *mockItem) needsRefresh() bool { return m.NeedsRefreshResult } diff --git a/src/control/security/auth/auth_sys.go b/src/control/security/auth/auth_sys.go index e9255f407d3..fcedcb95bd5 100644 --- a/src/control/security/auth/auth_sys.go +++ b/src/control/security/auth/auth_sys.go @@ -8,6 +8,7 @@ package auth import ( "bytes" + "context" "crypto" "os" "os/user" @@ -17,6 +18,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/proto" + "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" ) @@ -213,7 +215,7 @@ func (r *CredentialRequest) WithUserAndGroup(userStr, groupStr string, groupStrs // GetSignedCredential returns a credential based on the provided domain info and // signing key. -func GetSignedCredential(req *CredentialRequest) (*Credential, error) { +func GetSignedCredential(ctx context.Context, req *CredentialRequest) (*Credential, error) { if req == nil { return nil, errors.Errorf("%T is nil", req) } @@ -274,6 +276,7 @@ func GetSignedCredential(req *CredentialRequest) (*Credential, error) { Verifier: &verifierToken, Origin: "agent"} + logging.FromContext(ctx).Tracef("%s: successfully signed credential", req.DomainInfo) return &credential, nil } diff --git a/src/control/security/auth/auth_sys_test.go b/src/control/security/auth/auth_sys_test.go index bbca64fe233..a421d1b067c 100644 --- a/src/control/security/auth/auth_sys_test.go +++ b/src/control/security/auth/auth_sys_test.go @@ -261,7 +261,7 @@ func TestAuth_GetSignedCred(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - cred, gotErr := GetSignedCredential(tc.req) + cred, gotErr := GetSignedCredential(test.Context(t), tc.req) test.CmpErr(t, tc.expErr, gotErr) if tc.expErr != nil { return @@ -277,7 +277,7 @@ func TestAuth_CredentialRequestOverrides(t *testing.T) { req.getHostnameFn = testHostnameFn(nil, "test-host") req.WithUserAndGroup("test-user", "test-group", "test-secondary") - cred, err := GetSignedCredential(req) + cred, err := GetSignedCredential(test.Context(t), req) if err != nil { t.Fatalf("Failed to get credential: %s", err) } diff --git a/src/control/security/config.go b/src/control/security/config.go index 2bda4bf4a00..1913fd89fcd 100644 --- a/src/control/security/config.go +++ b/src/control/security/config.go @@ -94,6 +94,7 @@ func (cm ClientUserMap) Lookup(uid uint32) *MappedClientUser { // CredentialConfig contains configuration details for managing user // credentials. type CredentialConfig struct { + CacheLifetime time.Duration `yaml:"cache_lifetime,omitempty"` ClientUserMap ClientUserMap `yaml:"client_user_map,omitempty"` } diff --git a/utils/config/daos_agent.yml b/utils/config/daos_agent.yml index 3e7e8f1ab95..cab0bd88976 100644 --- a/utils/config/daos_agent.yml +++ b/utils/config/daos_agent.yml @@ -48,7 +48,6 @@ #telemetry_retain: 1m ## Configuration for user credential management. -# #credential_config: # # If the agent should be able to resolve unknown client uids and gids # # (e.g. when running in a container) into ACL principal names, then a @@ -61,10 +60,20 @@ # 1000: # user: ralph # group: stanley - +# +# # Optionally cache generated credentials with the specified cache +# # lifetime. By default, a credential is generated for every client +# # process that connects to a pool. If the credential cache is +# # enabled, then local client processes connecting with stable +# # uid:gid associations may take advantage of the cached credential +# # and reduce some agent overhead. For heavily-loaded client nodes +# # with many frequent (e.g. hundreds per minute) client connections, +# # a lifetime of 1-5 minutes may be a reasonable tradeoff between +# # performance and responsiveness to user/group database updates. +# cred_cache_lifetime: 1m +# ## Configuration for SSL certificates used to secure management traffic # and authenticate/authorize management components. -# #transport_config: # # In order to disable transport security, uncomment and set allow_insecure # # to true. Not recommended for production configurations. @@ -77,6 +86,7 @@ # # Key portion of Agent Certificate # key: /etc/daos/certs/agent.key # + # Use the given directory for creating unix domain sockets # # NOTE: Do not change this when running under systemd control. If it needs to