diff --git a/src/control/cmd/daos_agent/config.go b/src/control/cmd/daos_agent/config.go index c9d08d19744..76d8a295cf3 100644 --- a/src/control/cmd/daos_agent/config.go +++ b/src/control/cmd/daos_agent/config.go @@ -43,21 +43,22 @@ func (rm refreshMinutes) Duration() time.Duration { // Config defines the agent configuration. type Config struct { - SystemName string `yaml:"name"` - AccessPoints []string `yaml:"access_points"` - ControlPort int `yaml:"port"` - RuntimeDir string `yaml:"runtime_dir"` - LogFile string `yaml:"log_file"` - LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` - TransportConfig *security.TransportConfig `yaml:"transport_config"` - DisableCache bool `yaml:"disable_caching,omitempty"` - CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` - DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` - ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` - FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` - TelemetryPort int `yaml:"telemetry_port,omitempty"` - TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` - TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` + SystemName string `yaml:"name"` + AccessPoints []string `yaml:"access_points"` + ControlPort int `yaml:"port"` + RuntimeDir string `yaml:"runtime_dir"` + LogFile string `yaml:"log_file"` + LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` + CredentialConfig *security.CredentialConfig `yaml:"credential_config"` + TransportConfig *security.TransportConfig `yaml:"transport_config"` + DisableCache bool `yaml:"disable_caching,omitempty"` + CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` + DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` + ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` + FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` + TelemetryPort int `yaml:"telemetry_port,omitempty"` + TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` + TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` } // TelemetryExportEnabled returns true if client telemetry export is enabled. @@ -112,12 +113,13 @@ func LoadConfig(cfgPath string) (*Config, error) { func DefaultConfig() *Config { localServer := fmt.Sprintf("localhost:%d", build.DefaultControlPort) return &Config{ - SystemName: build.DefaultSystemName, - ControlPort: build.DefaultControlPort, - AccessPoints: []string{localServer}, - RuntimeDir: defaultRuntimeDir, - LogFile: defaultLogFile, - LogLevel: common.DefaultControlLogLevel, - TransportConfig: security.DefaultAgentTransportConfig(), + SystemName: build.DefaultSystemName, + ControlPort: build.DefaultControlPort, + AccessPoints: []string{localServer}, + RuntimeDir: defaultRuntimeDir, + LogFile: defaultLogFile, + LogLevel: common.DefaultControlLogLevel, + TransportConfig: security.DefaultAgentTransportConfig(), + CredentialConfig: &security.CredentialConfig{}, } } diff --git a/src/control/cmd/daos_agent/config_test.go b/src/control/cmd/daos_agent/config_test.go index e21920e3735..6f0b7562660 100644 --- a/src/control/cmd/daos_agent/config_test.go +++ b/src/control/cmd/daos_agent/config_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2021-2023 Intel Corporation. +// (C) Copyright 2021-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -46,6 +46,13 @@ control_log_mask: debug disable_caching: true cache_expiration: 30 disable_auto_evict: true +credential_config: + cache_lifetime: 10m + client_user_map: + 1000: + user: frodo + group: baggins + groups: ["ringbearers"] transport_config: allow_insecure: true exclude_fabric_ifaces: ["ib3"] @@ -104,12 +111,13 @@ transport_config: "without optional items": { path: withoutOptCfg, expResult: &Config{ - SystemName: "shire", - AccessPoints: []string{"one:10001", "two:10001"}, - ControlPort: 4242, - RuntimeDir: "/tmp/runtime", - LogFile: "/home/frodo/logfile", - LogLevel: common.DefaultControlLogLevel, + SystemName: "shire", + AccessPoints: []string{"one:10001", "two:10001"}, + ControlPort: 4242, + RuntimeDir: "/tmp/runtime", + LogFile: "/home/frodo/logfile", + LogLevel: common.DefaultControlLogLevel, + CredentialConfig: &security.CredentialConfig{}, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, @@ -132,6 +140,16 @@ transport_config: DisableCache: true, CacheExpiration: refreshMinutes(30 * time.Minute), DisableAutoEvict: true, + CredentialConfig: &security.CredentialConfig{ + CacheLifetime: time.Minute * 10, + ClientUserMap: map[uint32]*security.MappedClientUser{ + 1000: { + User: "frodo", + Group: "baggins", + Groups: []string{"ringbearers"}, + }, + }, + }, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, diff --git a/src/control/cmd/daos_agent/infocache.go b/src/control/cmd/daos_agent/infocache.go index cb777396ff1..9fc942f9290 100644 --- a/src/control/cmd/daos_agent/infocache.go +++ b/src/control/cmd/daos_agent/infocache.go @@ -80,11 +80,13 @@ func getFabricScanFn(log logging.Logger, cfg *Config, scanner *hardware.FabricSc } type cacheItem struct { - sync.Mutex + sync.RWMutex lastCached time.Time refreshInterval time.Duration } +// isStale returns true if the cache item is stale. +// NB: Should be run under a lock to protect lastCached. func (ci *cacheItem) isStale() bool { if ci.refreshInterval == 0 { return false @@ -92,6 +94,8 @@ func (ci *cacheItem) isStale() bool { return ci.lastCached.Add(ci.refreshInterval).Before(time.Now()) } +// isCached returns true if the cache item is cached. +// NB: Should be run under at least a read lock to protect lastCached. func (ci *cacheItem) isCached() bool { return !ci.lastCached.Equal(time.Time{}) } @@ -130,16 +134,30 @@ func (ci *cachedAttachInfo) Key() string { return sysAttachInfoKey(ci.system) } -// NeedsRefresh checks whether the cached data needs to be refreshed. -func (ci *cachedAttachInfo) NeedsRefresh() bool { +// needsRefresh checks whether the cached data needs to be refreshed. +func (ci *cachedAttachInfo) needsRefresh() bool { if ci == nil { return false } return !ci.isCached() || ci.isStale() } -// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. -func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached data if it needs to be refreshed. +func (ci *cachedAttachInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if ci == nil { + return cache.NoopRelease, false, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + if ci.needsRefresh() { + return ci.Unlock, true, ci.refresh(ctx) + } + return ci.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (ci *cachedAttachInfo) refresh(ctx context.Context) error { if ci == nil { return errors.New("cachedAttachInfo is nil") } @@ -155,6 +173,16 @@ func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { return nil } +// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. +func (ci *cachedAttachInfo) Refresh(ctx context.Context) (func(), error) { + if ci == nil { + return cache.NoopRelease, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + return ci.Unlock, ci.refresh(ctx) +} + type cachedFabricInfo struct { cacheItem fetch fabricScanFn @@ -172,17 +200,31 @@ func (cfi *cachedFabricInfo) Key() string { return fabricKey } -// NeedsRefresh indicates that the fabric information does not need to be refreshed unless it has +// needsRefresh indicates that the fabric information does not need to be refreshed unless it has // never been populated. -func (cfi *cachedFabricInfo) NeedsRefresh() bool { +func (cfi *cachedFabricInfo) needsRefresh() bool { if cfi == nil { return false } return !cfi.isCached() } -// Refresh scans the hardware for information about the fabric devices and caches the result. -func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached fabric information if it needs to be refreshed. +func (cfi *cachedFabricInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if cfi == nil { + return cache.NoopRelease, false, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + if cfi.needsRefresh() { + return cfi.Unlock, true, cfi.refresh(ctx) + } + return cfi.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (cfi *cachedFabricInfo) refresh(ctx context.Context) error { if cfi == nil { return errors.New("cachedFabricInfo is nil") } @@ -197,6 +239,16 @@ func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { return nil } +// Refresh scans the hardware for information about the fabric devices and caches the result. +func (cfi *cachedFabricInfo) Refresh(ctx context.Context) (func(), error) { + if cfi == nil { + return cache.NoopRelease, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + return cfi.Unlock, cfi.refresh(ctx) +} + // InfoCache is a cache for the results of expensive operations needed by the agent. type InfoCache struct { log logging.Logger @@ -350,15 +402,14 @@ func (c *InfoCache) GetAttachInfo(ctx context.Context, sys string) (*control.Get } createItem := func() (cache.Item, error) { c.log.Debugf("cache miss for %s", sysAttachInfoKey(sys)) - cai := newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo) - return cai, nil + return newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo), nil } item, release, err := c.cache.GetOrCreate(ctx, sysAttachInfoKey(sys), createItem) - defer release() if err != nil { return nil, errors.Wrap(err, "getting attach info from cache") } + defer release() cai, ok := item.(*cachedAttachInfo) if !ok { diff --git a/src/control/cmd/daos_agent/infocache_test.go b/src/control/cmd/daos_agent/infocache_test.go index e86c44bfc0c..36954849468 100644 --- a/src/control/cmd/daos_agent/infocache_test.go +++ b/src/control/cmd/daos_agent/infocache_test.go @@ -149,14 +149,22 @@ func TestAgent_cachedAttachInfo_Key(t *testing.T) { } } -func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedAttachInfo_RefreshIfNeeded(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + mockClient := control.NewMockInvoker(log, &control.MockInvokerConfig{}) + + noopGetAttachInfo := func(_ context.Context, _ control.UnaryInvoker, _ *control.GetAttachInfoReq) (*control.GetAttachInfoResp, error) { + return nil, nil + } + for name, tc := range map[string]struct { ai *cachedAttachInfo expResult bool }{ "nil": {}, "never cached": { - ai: newCachedAttachInfo(0, "test", nil, nil), + ai: newCachedAttachInfo(0, "test", mockClient, noopGetAttachInfo), expResult: true, }, "no refresh": { @@ -164,6 +172,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { cacheItem: cacheItem{ lastCached: time.Now().Add(-time.Minute), }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, @@ -173,6 +183,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Minute), refreshInterval: time.Second, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, expResult: true, @@ -183,12 +195,15 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Second), refreshInterval: time.Minute, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, } { t.Run(name, func(t *testing.T) { - test.AssertEqual(t, tc.expResult, tc.ai.NeedsRefresh(), "") + _, refreshed, _ := tc.ai.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -271,7 +286,8 @@ func TestAgent_cachedAttachInfo_Refresh(t *testing.T) { } } - err := ai.Refresh(test.Context(t)) + release, err := ai.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) @@ -320,7 +336,7 @@ func TestAgent_cachedFabricInfo_Key(t *testing.T) { } } -func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedFabricInfo_RefreshIfNeeded(t *testing.T) { for name, tc := range map[string]struct { nilCache bool cacheTime time.Time @@ -342,11 +358,14 @@ func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { var cfi *cachedFabricInfo if !tc.nilCache { - cfi = newCachedFabricInfo(log, nil) + cfi = newCachedFabricInfo(log, func(_ context.Context, _ ...string) (*NUMAFabric, error) { + return nil, nil + }) cfi.cacheItem.lastCached = tc.cacheTime } - test.AssertEqual(t, tc.expResult, cfi.NeedsRefresh(), "") + _, refreshed, _ := cfi.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -416,7 +435,8 @@ func TestAgent_cachedFabricInfo_Refresh(t *testing.T) { } } - err := cfi.Refresh(test.Context(t)) + release, err := cfi.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) diff --git a/src/control/cmd/daos_agent/mgmt_rpc.go b/src/control/cmd/daos_agent/mgmt_rpc.go index 75dc337e313..51cb80984c8 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc.go +++ b/src/control/cmd/daos_agent/mgmt_rpc.go @@ -8,7 +8,6 @@ package main import ( "net" - "sync" "github.com/pkg/errors" "golang.org/x/net/context" @@ -22,6 +21,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -34,16 +34,13 @@ import ( // Management Service proxy, handling dRPCs sent by libdaos by forwarding them // to MS. type mgmtModule struct { - attachInfoMutex sync.RWMutex - fabricMutex sync.RWMutex - log logging.Logger sys string ctlInvoker control.Invoker cache *InfoCache monitor *procMon cliMetricsSrc *promexp.ClientSource - useDefaultNUMA bool + useDefaultNUMA atm.Bool numaGetter hardware.ProcessNUMAProvider } @@ -161,14 +158,14 @@ func (mod *mgmtModule) handleGetAttachInfo(ctx context.Context, reqb []byte, pid } func (mod *mgmtModule) getNUMANode(ctx context.Context, pid int32) (uint, error) { - if mod.useDefaultNUMA { + if mod.useDefaultNUMA.IsTrue() { return 0, nil } numaNode, err := mod.numaGetter.GetNUMANodeIDForPID(ctx, pid) if errors.Is(err, hardware.ErrNoNUMANodes) { mod.log.Debug("system is not NUMA-aware") - mod.useDefaultNUMA = true + mod.useDefaultNUMA.SetTrue() return 0, nil } else if err != nil { return 0, errors.Wrapf(err, "failed to get NUMA node ID for pid %d", pid) diff --git a/src/control/cmd/daos_agent/mgmt_rpc_test.go b/src/control/cmd/daos_agent/mgmt_rpc_test.go index 59fcb507a81..965f3dc47c9 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc_test.go +++ b/src/control/cmd/daos_agent/mgmt_rpc_test.go @@ -27,6 +27,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -334,7 +335,7 @@ func TestAgent_mgmtModule_getNUMANode(t *testing.T) { mod := &mgmtModule{ log: log, - useDefaultNUMA: tc.useDefaultNUMA, + useDefaultNUMA: atm.NewBool(tc.useDefaultNUMA), numaGetter: tc.numaGetter, } diff --git a/src/control/cmd/daos_agent/security_rpc.go b/src/control/cmd/daos_agent/security_rpc.go index 906fe53ad8b..7fc6b942cd9 100644 --- a/src/control/cmd/daos_agent/security_rpc.go +++ b/src/control/cmd/daos_agent/security_rpc.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -8,44 +8,155 @@ package main import ( "context" + "fmt" "net" + "os/user" + "time" + + "github.com/pkg/errors" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" "github.com/daos-stack/daos/src/control/security/auth" ) -// SecurityModule is the security drpc module struct -type SecurityModule struct { - log logging.Logger - ext auth.UserExt - config *security.TransportConfig +type ( + // credSignerFn defines the function signature for signing credentials. + credSignerFn func(context.Context, *auth.CredentialRequest) (*auth.Credential, error) + + // credentialCache implements a cache for signed credentials. + credentialCache struct { + log logging.Logger + cache *cache.ItemCache + credLifetime time.Duration + cacheMissFn credSignerFn + } + + // cachedCredential wraps a cached credential and implements the cache.ExpirableItem interface. + cachedCredential struct { + cacheItem + key string + expiredAt time.Time + cred *auth.Credential + } + + // securityConfig defines configuration parameters for SecurityModule. + securityConfig struct { + credentials *security.CredentialConfig + transport *security.TransportConfig + } + + // SecurityModule is the security drpc module struct + SecurityModule struct { + log logging.Logger + signCredential credSignerFn + credCache *credentialCache + + config *securityConfig + } +) + +// NewSecurityModule creates a new module with the given initialized TransportConfig. +func NewSecurityModule(log logging.Logger, cfg *securityConfig) *SecurityModule { + var credCache *credentialCache + credSigner := auth.GetSignedCredential + if cfg.credentials.CacheLifetime > 0 { + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: cfg.credentials.CacheLifetime, + cacheMissFn: auth.GetSignedCredential, + } + credSigner = cache.getSignedCredential + log.Noticef("credential cache enabled (entry lifetime: %s)", cfg.credentials.CacheLifetime) + } + + return &SecurityModule{ + log: log, + signCredential: credSigner, + credCache: credCache, + config: cfg, + } +} + +func credReqKey(req *auth.CredentialRequest) string { + return fmt.Sprintf("%d:%d:%s", req.DomainInfo.Uid(), req.DomainInfo.Gid(), req.DomainInfo.Ctx()) } -// NewSecurityModule creates a new module with the given initialized TransportConfig -func NewSecurityModule(log logging.Logger, tc *security.TransportConfig) *SecurityModule { - mod := SecurityModule{ - log: log, - config: tc, +func (cred *cachedCredential) Key() string { + if cred == nil { + return "" } - mod.ext = &auth.External{} - return &mod + + return cred.key +} + +func (cred *cachedCredential) IsExpired() bool { + if cred == nil || cred.cred == nil || cred.expiredAt.IsZero() { + return true + } + + return time.Now().After(cred.expiredAt) +} + +func (cc *credentialCache) getSignedCredential(ctx context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + key := credReqKey(req) + + createItem := func() (cache.Item, error) { + cc.log.Tracef("cache miss for %s", key) + cred, err := cc.cacheMissFn(ctx, req) + if err != nil { + return nil, err + } + cc.log.Tracef("getting credential for %s", key) + return newCachedCredential(key, cred, cc.credLifetime) + } + + item, release, err := cc.cache.GetOrCreate(ctx, key, createItem) + if err != nil { + return nil, errors.Wrap(err, "getting cached credential from cache") + } + defer release() + + cachedCred, ok := item.(*cachedCredential) + if !ok { + return nil, errors.New("invalid cached credential") + } + + return cachedCred.cred, nil +} + +func newCachedCredential(key string, cred *auth.Credential, lifetime time.Duration) (*cachedCredential, error) { + if cred == nil { + return nil, errors.New("credential is nil") + } + + return &cachedCredential{ + key: key, + cred: cred, + expiredAt: time.Now().Add(lifetime), + }, nil } // HandleCall is the handler for calls to the SecurityModule -func (m *SecurityModule) HandleCall(_ context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { +func (m *SecurityModule) HandleCall(ctx context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { if method != drpc.MethodRequestCredentials { return nil, drpc.UnknownMethodFailure() } - return m.getCredential(session) + return m.getCredential(ctx, session) } // getCredentials generates a signed user credential based on the data attached to // the Unix Domain Socket. -func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { +func (m *SecurityModule) getCredential(ctx context.Context, session *drpc.Session) ([]byte, error) { + if session == nil { + return nil, drpc.NewFailureWithMessage("session is nil") + } + uConn, ok := session.Conn.(*net.UnixConn) if !ok { return nil, drpc.NewFailureWithMessage("connection is not a unix socket") @@ -57,20 +168,39 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { return m.credRespWithStatus(daos.MiscError) } - signingKey, err := m.config.PrivateKey() + signingKey, err := m.config.transport.PrivateKey() if err != nil { m.log.Errorf("%s: failed to get signing key: %s", info, err) // something is wrong with the cert config return m.credRespWithStatus(daos.BadCert) } - cred, err := auth.AuthSysRequestFromCreds(m.ext, info, signingKey) + req := auth.NewCredentialRequest(info, signingKey) + cred, err := m.signCredential(ctx, req) if err != nil { - m.log.Errorf("%s: failed to get AuthSys struct: %s", info, err) - return m.credRespWithStatus(daos.MiscError) + if err := func() error { + if !errors.Is(err, user.UnknownUserIdError(info.Uid())) { + return err + } + + mu := m.config.credentials.ClientUserMap.Lookup(info.Uid()) + if mu == nil { + return user.UnknownUserIdError(info.Uid()) + } + + req.WithUserAndGroup(mu.User, mu.Group, mu.Groups...) + cred, err = m.signCredential(ctx, req) + if err != nil { + return err + } + + return nil + }(); err != nil { + m.log.Errorf("%s: failed to get user credential: %s", info, err) + return m.credRespWithStatus(daos.MiscError) + } } - m.log.Tracef("%s: successfully signed credential", info) resp := &auth.GetCredResp{Cred: cred} return drpc.Marshal(resp) } diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 1c682aff1ca..52e298f0e88 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2022 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -7,14 +7,22 @@ package main import ( + "context" "errors" "net" + "os/user" + "syscall" "testing" + "time" + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" @@ -25,7 +33,7 @@ func TestAgentSecurityModule_ID(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) test.AssertEqual(t, mod.ID(), drpc.ModuleSecurityAgent, "wrong drpc module") } @@ -35,15 +43,18 @@ func newTestSession(t *testing.T, log logging.Logger, conn net.Conn) *drpc.Sessi return drpc.NewSession(conn, svc) } -func defaultTestTransportConfig() *security.TransportConfig { - return &security.TransportConfig{AllowInsecure: true} +func defaultTestSecurityConfig() *securityConfig { + return &securityConfig{ + transport: &security.TransportConfig{AllowInsecure: true}, + credentials: &security.CredentialConfig{}, + } } func TestAgentSecurityModule_BadMethod(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) method, err := mod.ID().GetMethod(-1) if method != nil { t.Errorf("Expected no method, got %+v", method) @@ -74,6 +85,7 @@ func setupTestUnixConn(t *testing.T) (*net.UnixConn, func()) { } func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { + t.Helper() client := drpc.NewClientConnection(path) if err := client.Connect(test.Context(t)); err != nil { t.Fatalf("Failed to connect: %v", err) @@ -82,6 +94,8 @@ func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { } func expectCredResp(t *testing.T, respBytes []byte, expStatus int32, expCred bool) { + t.Helper() + if respBytes == nil { t.Error("Expected non-nil response") } @@ -104,7 +118,7 @@ func TestAgentSecurityModule_RequestCreds_OK(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -118,7 +132,7 @@ func TestAgentSecurityModule_RequestCreds_NotUnixConn(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, &net.TCPConn{}) test.CmpErr(t, drpc.NewFailureWithMessage("connection is not a unix socket"), err) @@ -137,7 +151,7 @@ func TestAgentSecurityModule_RequestCreds_NotConnected(t *testing.T) { defer cleanup() conn.Close() // can't get uid/gid from a closed connection - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -156,7 +170,10 @@ func TestAgentSecurityModule_RequestCreds_BadConfig(t *testing.T) { defer cleanup() // Empty TransportConfig is incomplete - mod := NewSecurityModule(log, &security.TransportConfig{}) + mod := NewSecurityModule(log, &securityConfig{ + transport: &security.TransportConfig{}, + credentials: &security.CredentialConfig{}, + }) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -174,10 +191,9 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) - mod.ext = &auth.MockExt{ - LookupUserIDErr: errors.New("LookupUserID"), - LookupGroupIDErr: errors.New("LookupGroupID"), + mod := NewSecurityModule(log, defaultTestSecurityConfig()) + mod.signCredential = func(_ context.Context, _ *auth.CredentialRequest) (*auth.Credential, error) { + return nil, errors.New("LookupUserID") } respBytes, err := callRequestCreds(mod, t, log, conn) @@ -187,3 +203,219 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { expectCredResp(t, respBytes, int32(daos.MiscError), false) } + +type signCredentialResp struct { + cred *auth.Credential + err error +} + +func TestAgent_SecurityRPC_getCredential(t *testing.T) { + testCred := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("test-token")}, + Origin: "test-origin", + } + miscErrBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: int32(daos.MiscError), + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal misc error: %v", err) + } + successBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: 0, + Cred: testCred, + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal success: %v", err) + } + + for name, tc := range map[string]struct { + secCfg *securityConfig + responses []signCredentialResp + expBytes []byte + expErr error + }{ + "lookup miss": { + secCfg: defaultTestSecurityConfig(), + responses: []signCredentialResp{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + }, + expBytes: miscErrBytes, + }, + "lookup OK": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []signCredentialResp{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: testCred, + err: nil, + }, + }, + expBytes: successBytes, + }, + "lookup OK, but retried request fails": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []signCredentialResp{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: nil, + err: errors.New("oops"), + }, + }, + expBytes: miscErrBytes, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + // Set up a real unix socket so we can make a real connection + conn, cleanup := setupTestUnixConn(t) + defer cleanup() + + mod := NewSecurityModule(log, tc.secCfg) + mod.signCredential = func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + var idx int + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }() + + respBytes, gotErr := callRequestCreds(mod, t, log, conn) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expBytes, respBytes); diff != "" { + t.Errorf("unexpected response (-want +got):\n%s", diff) + } + }) + } +} + +func TestAgent_SecurityCachedCredentials(t *testing.T) { + cred0 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group2")}, + Origin: "test-origin", + } + cred1 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group3")}, + Origin: "test-origin", + } + + for name, tc := range map[string]struct { + lifetime time.Duration + req *auth.CredentialRequest + responses []signCredentialResp + exp *auth.Credential + }{ + "cache hit": { + lifetime: time.Second, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred0, + }, + "expired entry": { + lifetime: time.Nanosecond, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred1, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: tc.lifetime, + cacheMissFn: func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + var idx int + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }(), + } + + // Prime the cache with a single entry. + _, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Request a second time with the same credentials. + cred, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cmpOpts := cmp.Options{ + protocmp.Transform(), + } + if diff := cmp.Diff(tc.exp, cred, cmpOpts...); diff != "" { + t.Errorf("unexpected credential (-want +got):\n%s", diff) + } + }) + } +} diff --git a/src/control/cmd/daos_agent/start.go b/src/control/cmd/daos_agent/start.go index e5416ee874b..1ee2f4e7d76 100644 --- a/src/control/cmd/daos_agent/start.go +++ b/src/control/cmd/daos_agent/start.go @@ -114,7 +114,11 @@ func (cmd *startCmd) Execute(_ []string) error { } drpcRegStart := time.Now() - drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, cmd.cfg.TransportConfig)) + secCfg := &securityConfig{ + transport: cmd.cfg.TransportConfig, + credentials: cmd.cfg.CredentialConfig, + } + drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, secCfg)) mgmtMod := &mgmtModule{ log: cmd.Logger, sys: cmd.cfg.SystemName, diff --git a/src/control/lib/cache/cache.go b/src/control/lib/cache/cache.go index c31d82b0a61..9ecc123240c 100644 --- a/src/control/lib/cache/cache.go +++ b/src/control/lib/cache/cache.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -18,20 +18,33 @@ import ( "github.com/daos-stack/daos/src/control/logging" ) -type Item interface { - Lock() - Unlock() - Key() string - Refresh(ctx context.Context) error - NeedsRefresh() bool -} +type ( + // Item defines an interface for a cached item. + Item interface { + sync.Locker + Key() string + } -// ItemCache is a mechanism for caching Items to keys. -type ItemCache struct { - log logging.Logger - mutex sync.RWMutex - items map[string]Item -} + // ExpirableItem is an Item that defines its own expiration criteria. + ExpirableItem interface { + Item + IsExpired() bool + } + + // RefreshableItem is an Item that defines its own refresh criteria and method. + RefreshableItem interface { + Item + Refresh(ctx context.Context) (func(), error) + RefreshIfNeeded(ctx context.Context) (func(), bool, error) + } + + // ItemCache is a mechanism for caching Items to keys. + ItemCache struct { + log logging.Logger + mutex sync.RWMutex + items map[string]Item + } +) // NewItemCache creates a new ItemCache. func NewItemCache(log logging.Logger) *ItemCache { @@ -119,22 +132,24 @@ func (e *errKeyNotFound) Error() string { return fmt.Sprintf("key %q not found", e.key) } -func noopRelease() {} +// NoopRelease is a no-op function that does nothing, but can +// be safely returned in lieu of a real lock release function. +func NoopRelease() {} // GetOrCreate returns an item from the cache if it exists, otherwise it creates // the item using the given function and caches it. The item must be released // by the caller when it is safe to be modified. func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCreateFunc) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } if missFn == nil { - return nil, noopRelease, errors.Errorf("item create function is required") + return nil, NoopRelease, errors.Errorf("item create function is required") } ic.mutex.Lock() @@ -145,33 +160,39 @@ func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCre ic.log.Debugf("failed to get item for key %q: %s", key, err.Error()) item, err = missFn() if err != nil { - return nil, noopRelease, errors.Wrapf(err, "create item for %q", key) + return nil, NoopRelease, errors.Wrapf(err, "create item for %q", key) } ic.log.Debugf("created item for key %q", key) ic.set(item) } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } // Get returns an item from the cache if it exists, otherwise it returns an // error. The item must be released by the caller when it is safe to be modified. func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } ic.mutex.Lock() @@ -179,25 +200,35 @@ func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) item, err := ic.get(key) if err != nil { - return nil, noopRelease, err + return nil, NoopRelease, err } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } func (ic *ItemCache) get(key string) (Item, error) { - val, ok := ic.items[key] + item, ok := ic.items[key] if ok { - return val, nil + if ei, ok := item.(ExpirableItem); ok && ei.IsExpired() { + delete(ic.items, key) + } else { + return item, nil + } } return nil, &errKeyNotFound{key: key} } @@ -229,10 +260,12 @@ func (ic *ItemCache) refreshItem(ctx context.Context, key string) error { return err } - item.Lock() - defer item.Unlock() - if err := item.Refresh(ctx); err != nil { - return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + if ri, ok := item.(RefreshableItem); ok { + release, err := ri.Refresh(ctx) + if err != nil { + return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + } + release() } return nil diff --git a/src/control/lib/cache/cache_test.go b/src/control/lib/cache/cache_test.go index 6d1aeb465ea..0a34e8bf606 100644 --- a/src/control/lib/cache/cache_test.go +++ b/src/control/lib/cache/cache_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -10,11 +10,12 @@ import ( "context" "testing" - "github.com/daos-stack/daos/src/control/common/test" - "github.com/daos-stack/daos/src/control/logging" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/pkg/errors" + + "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/logging" ) func TestCache_NewItemCache(t *testing.T) { @@ -51,11 +52,18 @@ func (m *mockItem) Key() string { return m.ItemKey } -func (m *mockItem) Refresh(ctx context.Context) error { - return m.RefreshErr +func (m *mockItem) Refresh(ctx context.Context) (func(), error) { + return NoopRelease, m.RefreshErr +} + +func (m *mockItem) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if m.needsRefresh() { + return NoopRelease, true, m.RefreshErr + } + return NoopRelease, false, nil } -func (m *mockItem) NeedsRefresh() bool { +func (m *mockItem) needsRefresh() bool { return m.NeedsRefreshResult } diff --git a/src/control/lib/control/pool.go b/src/control/lib/control/pool.go index e94ae71a243..b188d7f2014 100644 --- a/src/control/lib/control/pool.go +++ b/src/control/lib/control/pool.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "math" + "os/user" "sort" "strings" "time" @@ -29,7 +30,6 @@ import ( "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/ranklist" "github.com/daos-stack/daos/src/control/logging" - "github.com/daos-stack/daos/src/control/security/auth" "github.com/daos-stack/daos/src/control/server/storage" "github.com/daos-stack/daos/src/control/system" ) @@ -51,22 +51,19 @@ func checkUUID(uuidStr string) error { // formatNameGroup converts system names to principals, If user or group is not // provided, the effective user and/or effective group will be used. -func formatNameGroup(ext auth.UserExt, usr string, grp string) (string, string, error) { +func formatNameGroup(usr string, grp string) (string, string, error) { if usr == "" || grp == "" { - eUsr, err := ext.Current() + eUsr, err := user.Current() if err != nil { return "", "", err } if usr == "" { - usr = eUsr.Username() + usr = eUsr.Username } if grp == "" { - gid, err := eUsr.Gid() - if err != nil { - return "", "", err - } - eGrp, err := ext.LookupGroupID(gid) + gid := eUsr.Gid + eGrp, err := user.LookupGroupId(gid) if err != nil { return "", "", err } @@ -153,7 +150,7 @@ func (pcr *PoolCreateReq) MarshalJSON() ([]byte, error) { // request, filling in any missing fields with reasonable defaults. func genPoolCreateRequest(in *PoolCreateReq) (out *mgmtpb.PoolCreateReq, err error) { // ensure pool ownership is set up correctly - in.User, in.UserGroup, err = formatNameGroup(&auth.External{}, in.User, in.UserGroup) + in.User, in.UserGroup, err = formatNameGroup(in.User, in.UserGroup) if err != nil { return } diff --git a/src/control/security/auth/auth_sys.go b/src/control/security/auth/auth_sys.go index 74a4d006f20..fcedcb95bd5 100644 --- a/src/control/security/auth/auth_sys.go +++ b/src/control/security/auth/auth_sys.go @@ -8,6 +8,7 @@ package auth import ( "bytes" + "context" "crypto" "os" "os/user" @@ -17,87 +18,10 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/proto" + "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" ) -// User is an interface wrapping a representation of a specific system user. -type User interface { - Username() string - GroupIDs() ([]uint32, error) - Gid() (uint32, error) -} - -// UserExt is an interface that wraps system user-related external functions. -type UserExt interface { - Current() (User, error) - LookupUserID(uid uint32) (User, error) - LookupGroupID(gid uint32) (*user.Group, error) -} - -// UserInfo is an exported implementation of the security.User interface. -type UserInfo struct { - Info *user.User -} - -// Username is a wrapper for user.Username. -func (u *UserInfo) Username() string { - return u.Info.Username -} - -// GroupIDs is a wrapper for user.GroupIds. -func (u *UserInfo) GroupIDs() ([]uint32, error) { - gidStrs, err := u.Info.GroupIds() - if err != nil { - return nil, err - } - - gids := []uint32{} - for _, gstr := range gidStrs { - gid, err := strconv.Atoi(gstr) - if err != nil { - continue - } - gids = append(gids, uint32(gid)) - } - - return gids, nil -} - -// Gid is a wrapper for user.Gid. -func (u *UserInfo) Gid() (uint32, error) { - gid, err := strconv.Atoi(u.Info.Gid) - - return uint32(gid), errors.Wrap(err, "user gid") -} - -// External is an exported implementation of the UserExt interface. -type External struct{} - -// LookupUserId is a wrapper for user.LookupId. -func (e *External) LookupUserID(uid uint32) (User, error) { - uidStr := strconv.FormatUint(uint64(uid), 10) - info, err := user.LookupId(uidStr) - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - -// LookupGroupId is a wrapper for user.LookupGroupId. -func (e *External) LookupGroupID(gid uint32) (*user.Group, error) { - gidStr := strconv.FormatUint(uint64(gid), 10) - return user.LookupGroupId(gidStr) -} - -// Current is a wrapper for user.Current. -func (e *External) Current() (User, error) { - info, err := user.Current() - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - // VerifierFromToken will return a SHA512 hash of the token data. If a signing key // is passed in it will additionally sign the hash of the token. func VerifierFromToken(key crypto.PublicKey, token *Token) ([]byte, error) { @@ -146,60 +70,188 @@ func sysNameToPrincipalName(name string) string { return name + "@" } -// AuthSysRequestFromCreds takes the domain info credentials gathered -// during the dRPC request and creates an AuthSys security request to obtain -// a handle from the management service. -func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing crypto.PrivateKey) (*Credential, error) { - if creds == nil { - return nil, errors.New("No credentials supplied") +func stripHostName(name string) string { + return strings.Split(name, ".")[0] +} + +// GetMachineName returns the "short" hostname by stripping the domain from the FQDN. +func GetMachineName() (string, error) { + name, err := os.Hostname() + if err != nil { + return "", err + } + + return stripHostName(name), nil +} + +type ( + // CredentialRequest defines the request parameters for GetSignedCredential. + CredentialRequest struct { + DomainInfo *security.DomainInfo + SigningKey crypto.PrivateKey + getHostnameFn func() (string, error) + getUserFn func(string) (*user.User, error) + getGroupFn func(string) (*user.Group, error) + getGroupIdsFn func() ([]string, error) + getGroupNamesFn func() ([]string, error) + } +) + +// NewCredentialRequest returns a properly initialized CredentialRequest. +func NewCredentialRequest(info *security.DomainInfo, key crypto.PrivateKey) *CredentialRequest { + req := &CredentialRequest{ + DomainInfo: info, + SigningKey: key, + getHostnameFn: GetMachineName, + getUserFn: user.LookupId, + getGroupFn: user.LookupGroupId, + } + req.getGroupIdsFn = func() ([]string, error) { + u, err := req.user() + if err != nil { + return nil, err + } + return u.GroupIds() + } + req.getGroupNamesFn = func() ([]string, error) { + groupIds, err := req.getGroupIdsFn() + if err != nil { + return nil, err + } + + groupNames := make([]string, len(groupIds)) + for i, gID := range groupIds { + g, err := req.getGroupFn(gID) + if err != nil { + return nil, err + } + groupNames[i] = g.Name + } + + return groupNames, nil + } + + return req +} + +func (r *CredentialRequest) hostname() (string, error) { + if r.getHostnameFn == nil { + return "", errors.New("hostname lookup function not set") } - userInfo, err := ext.LookupUserID(creds.Uid()) + hostname, err := r.getHostnameFn() if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup uid %v", - creds.Uid()) + return "", errors.Wrap(err, "failed to get hostname") + } + return stripHostName(hostname), nil +} + +func (r *CredentialRequest) user() (*user.User, error) { + if r.getUserFn == nil { + return nil, errors.New("user lookup function not set") } + return r.getUserFn(strconv.Itoa(int(r.DomainInfo.Uid()))) +} - groupInfo, err := ext.LookupGroupID(creds.Gid()) +func (r *CredentialRequest) userPrincipal() (string, error) { + u, err := r.user() if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup gid %v", - creds.Gid()) + return "", err + } + return sysNameToPrincipalName(u.Username), nil +} + +func (r *CredentialRequest) group() (*user.Group, error) { + if r.getGroupFn == nil { + return nil, errors.New("group lookup function not set") } + return r.getGroupFn(strconv.Itoa(int(r.DomainInfo.Gid()))) +} - groups, err := userInfo.GroupIDs() +func (r *CredentialRequest) groupPrincipal() (string, error) { + g, err := r.group() if err != nil { - return nil, errors.Wrapf(err, "Failed to get group IDs for user %v", - userInfo.Username()) + return "", err } + return sysNameToPrincipalName(g.Name), nil +} - name, err := os.Hostname() +func (r *CredentialRequest) groupPrincipals() ([]string, error) { + if r.getGroupNamesFn == nil { + return nil, errors.New("groupNames function not set") + } + + groupNames, err := r.getGroupNamesFn() + if err != nil { + return nil, errors.Wrap(err, "failed to get group names") + } + + for i, g := range groupNames { + groupNames[i] = sysNameToPrincipalName(g) + } + return groupNames, nil +} + +// WithUserAndGroup provides an override to set the user, group, and optional list +// of group names to be used for the request. +func (r *CredentialRequest) WithUserAndGroup(userStr, groupStr string, groupStrs ...string) { + r.getUserFn = func(id string) (*user.User, error) { + return &user.User{ + Uid: id, + Gid: id, + Username: userStr, + }, nil + } + r.getGroupFn = func(id string) (*user.Group, error) { + return &user.Group{ + Gid: id, + Name: groupStr, + }, nil + } + r.getGroupNamesFn = func() ([]string, error) { + return groupStrs, nil + } +} + +// GetSignedCredential returns a credential based on the provided domain info and +// signing key. +func GetSignedCredential(ctx context.Context, req *CredentialRequest) (*Credential, error) { + if req == nil { + return nil, errors.Errorf("%T is nil", req) + } + + if req.DomainInfo == nil { + return nil, errors.New("No domain info supplied") + } + + hostname, err := req.hostname() if err != nil { - name = "unavailable" + return nil, err } - // Strip the domain off of the Hostname - host := strings.Split(name, ".")[0] + userPrinc, err := req.userPrincipal() + if err != nil { + return nil, err + } - var groupList = []string{} + groupPrinc, err := req.groupPrincipal() + if err != nil { + return nil, err + } - // Convert groups to gids - for _, gid := range groups { - gInfo, err := ext.LookupGroupID(gid) - if err != nil { - // Skip this group - continue - } - groupList = append(groupList, sysNameToPrincipalName(gInfo.Name)) + groupPrincs, err := req.groupPrincipals() + if err != nil { + return nil, err } // Craft AuthToken sys := Sys{ Stamp: 0, - Machinename: host, - User: sysNameToPrincipalName(userInfo.Username()), - Group: sysNameToPrincipalName(groupInfo.Name), - Groups: groupList, - Secctx: creds.Ctx()} + Machinename: hostname, + User: userPrinc, + Group: groupPrinc, + Groups: groupPrincs, + Secctx: req.DomainInfo.Ctx()} // Marshal our AuthSys token into a byte array tokenBytes, err := proto.Marshal(&sys) @@ -210,7 +262,7 @@ func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing cr Flavor: Flavor_AUTH_SYS, Data: tokenBytes} - verifier, err := VerifierFromToken(signing, &token) + verifier, err := VerifierFromToken(req.SigningKey, &token) if err != nil { return nil, errors.WithMessage(err, "Unable to generate verifier") } @@ -224,6 +276,7 @@ func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing cr Verifier: &verifierToken, Origin: "agent"} + logging.FromContext(ctx).Tracef("%s: successfully signed credential", req.DomainInfo) return &credential, nil } diff --git a/src/control/security/auth/auth_sys_test.go b/src/control/security/auth/auth_sys_test.go index fee5c69931b..a421d1b067c 100644 --- a/src/control/security/auth/auth_sys_test.go +++ b/src/control/security/auth/auth_sys_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -8,14 +8,13 @@ package auth import ( "errors" - "fmt" "os/user" "syscall" "testing" "google.golang.org/protobuf/proto" - . "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/security" ) @@ -29,7 +28,7 @@ func expectAuthSysErrorForToken(t *testing.T, badToken *Token, expectedErrorMess t.Error("Expected a nil AuthSys") } - CmpErr(t, errors.New(expectedErrorMessage), err) + test.CmpErr(t, errors.New(expectedErrorMessage), err) } // AuthSysFromAuthToken tests @@ -82,30 +81,61 @@ func TestAuthSysFromAuthToken_SucceedsWithGoodToken(t *testing.T) { t.Fatal("Got a nil AuthSys") } - AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), + test.AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), "Stamps don't match") - AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), + test.AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), "Machinenames don't match") - AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), + test.AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), "Owners don't match") - AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), + test.AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), "Groups don't match") - AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), + test.AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), "Group lists aren't the same length") - AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), + test.AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), "Secctx don't match") } -// AuthSysRequestFromCreds tests +func testHostnameFn(expErr error, hostname string) func() (string, error) { + return func() (string, error) { + if expErr != nil { + return "", expErr + } + return hostname, nil + } +} -func TestAuthSysRequestFromCreds_failsIfDomainInfoNil(t *testing.T) { - result, err := AuthSysRequestFromCreds(&MockExt{}, nil, nil) +func testUserFn(expErr error, userName string) func(string) (*user.User, error) { + return func(uid string) (*user.User, error) { + if expErr != nil { + return nil, expErr + } + return &user.User{ + Uid: uid, + Gid: uid, + Username: userName, + }, nil + } +} - if result != nil { - t.Error("Expected a nil request") +func testGroupFn(expErr error, groupName string) func(string) (*user.Group, error) { + return func(gid string) (*user.Group, error) { + if expErr != nil { + return nil, expErr + } + return &user.Group{ + Gid: gid, + Name: groupName, + }, nil } +} - ExpectError(t, err, "No credentials supplied", "") +func testGroupNamesFn(expErr error, groupNames ...string) func() ([]string, error) { + return func() ([]string, error) { + if expErr != nil { + return nil, expErr + } + return groupNames, nil + } } func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { @@ -116,44 +146,10 @@ func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { return security.InitDomainInfo(creds, "test") } -func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - gid := uint32(2001) - gids := []uint32{1, 2, 3} - expectedUser := "myuser" - expectedGroup := "mygroup" - expectedGroupList := []string{"group1", "group2", "group3"} - creds := getTestCreds(uid, gid) - - ext.LookupUserIDResult = &MockUser{ - username: expectedUser, - groupIDs: gids, - } - ext.LookupGroupIDResults = []*user.Group{ - { - Name: expectedGroup, - }, - } - - for _, grp := range expectedGroupList { - ext.LookupGroupIDResults = append(ext.LookupGroupIDResults, - &user.Group{ - Name: grp, - }) - } - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if result == nil { - t.Fatal("Credential was nil") - } +func verifyCredential(t *testing.T, cred *Credential, expHostname, expUserPrinc, expGroupPrinc string, expGroupPrincs ...string) { + t.Helper() - token := result.GetToken() + token := cred.GetToken() if token == nil { t.Fatal("Token was nil") } @@ -163,111 +159,128 @@ func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { } authsys := &Sys{} - err = proto.Unmarshal(token.GetData(), authsys) + err := proto.Unmarshal(token.GetData(), authsys) if err != nil { t.Fatal("Failed to unmarshal token data") } - if authsys.GetUser() != expectedUser+"@" { + if authsys.GetMachinename() != expHostname { + t.Errorf("AuthSys had bad hostname: %v", authsys.GetMachinename()) + } + + if authsys.GetUser() != expUserPrinc { t.Errorf("AuthSys had bad username: %v", authsys.GetUser()) } - if authsys.GetGroup() != expectedGroup+"@" { + if authsys.GetGroup() != expGroupPrinc { t.Errorf("AuthSys had bad group name: %v", authsys.GetGroup()) } for i, group := range authsys.GetGroups() { - if group != expectedGroupList[i]+"@" { + if group != expGroupPrincs[i] { t.Errorf("AuthSys had bad group in list (idx %v): %v", i, group) } } } -func TestAuthSysRequestFromCreds_UidLookupFails(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - creds := getTestCreds(uid, 500) - - ext.LookupUserIDErr = errors.New("LookupUserID test error") - expectedErr := fmt.Errorf("Failed to lookup uid %v: %v", uid, - ext.LookupUserIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GidLookupFails(t *testing.T) { - ext := &MockExt{} - gid := uint32(205) - creds := getTestCreds(12, gid) - - ext.LookupUserIDResult = &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, - } - - ext.LookupGroupIDErr = errors.New("LookupGroupID test error") - expectedErr := fmt.Errorf("Failed to lookup gid %v: %v", gid, - ext.LookupGroupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GroupIDListFails(t *testing.T) { - ext := &MockExt{} - creds := getTestCreds(12, 15) - testUser := &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, +func TestAuth_GetSignedCred(t *testing.T) { + testHostname := "test-host.domain.foo" + testUsername := "test-user" + testGroup := "test-group" + testGroupList := []string{"group1", "group2", "group3"} + + expectedHostname := "test-host" + expectedUser := testUsername + "@" + expectedGroup := testGroup + "@" + expectedGroupList := make([]string, len(testGroupList)) + for i, group := range testGroupList { + expectedGroupList[i] = group + "@" } - ext.LookupUserIDResult = testUser - - ext.LookupGroupIDResults = []*user.Group{ - { - Name: "group@", + for name, tc := range map[string]struct { + req *CredentialRequest + expErr error + }{ + "nil request": { + req: nil, + expErr: errors.New("is nil"), + }, + "nil DomainInfo": { + req: &CredentialRequest{}, + expErr: errors.New("No domain info supplied"), + }, + "bad hostname": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(errors.New("bad hostname"), "") + return req + }(), + expErr: errors.New("bad hostname"), + }, + "bad uid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getUserFn = testUserFn(errors.New("bad uid"), "") + return req + }(), + expErr: errors.New("bad uid"), + }, + "bad gid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupFn = testGroupFn(errors.New("bad gid"), "") + return req + }(), + expErr: errors.New("bad gid"), + }, + "bad group IDs": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupIdsFn = testGroupNamesFn(errors.New("bad group IDs")) + return req + }(), + expErr: errors.New("bad group IDs"), + }, + "bad group names": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupNamesFn = testGroupNamesFn(errors.New("bad group names")) + return req + }(), + expErr: errors.New("bad group names"), }, + "valid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(nil, testHostname) + req.getUserFn = testUserFn(nil, testUsername) + req.getGroupFn = testGroupFn(nil, testGroup) + req.getGroupNamesFn = testGroupNamesFn(nil, testGroupList...) + return req + }(), + }, + } { + t.Run(name, func(t *testing.T) { + cred, gotErr := GetSignedCredential(test.Context(t), tc.req) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + + verifyCredential(t, cred, expectedHostname, expectedUser, expectedGroup, expectedGroupList...) + }) } +} - testUser.groupIDErr = errors.New("GroupIDs test error") - expectedErr := fmt.Errorf("Failed to get group IDs for user %v: %v", - testUser.username, - testUser.groupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) +func TestAuth_CredentialRequestOverrides(t *testing.T) { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(nil, "test-host") + req.WithUserAndGroup("test-user", "test-group", "test-secondary") - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") + cred, err := GetSignedCredential(test.Context(t), req) + if err != nil { + t.Fatalf("Failed to get credential: %s", err) } - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } + verifyCredential(t, cred, "test-host", "test-user@", "test-group@", "test-secondary@") } diff --git a/src/control/security/auth/mocks.go b/src/control/security/auth/mocks.go deleted file mode 100644 index 3cde0f9471b..00000000000 --- a/src/control/security/auth/mocks.go +++ /dev/null @@ -1,65 +0,0 @@ -// -// (C) Copyright 2020-2021 Intel Corporation. -// -// SPDX-License-Identifier: BSD-2-Clause-Patent -// - -package auth - -import ( - "os/user" - - "github.com/pkg/errors" -) - -// Mocks - -type MockUser struct { - username string - groupIDs []uint32 - groupIDErr error -} - -func (u *MockUser) Username() string { - return u.username -} - -func (u *MockUser) GroupIDs() ([]uint32, error) { - return u.groupIDs, u.groupIDErr -} - -func (u *MockUser) Gid() (uint32, error) { - if len(u.groupIDs) == 0 { - return 0, errors.New("no mock gids to return") - } - return u.groupIDs[0], nil -} - -type MockExt struct { - LookupUserIDUid uint32 - LookupUserIDResult User - LookupUserIDErr error - LookupGroupIDGid uint32 - LookupGroupIDResults []*user.Group - LookupGroupIDCallCount uint32 - LookupGroupIDErr error -} - -func (e *MockExt) Current() (User, error) { - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupUserID(uid uint32) (User, error) { - e.LookupUserIDUid = uid - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupGroupID(gid uint32) (*user.Group, error) { - e.LookupGroupIDGid = gid - var result *user.Group - if len(e.LookupGroupIDResults) > 0 { - result = e.LookupGroupIDResults[e.LookupGroupIDCallCount] - } - e.LookupGroupIDCallCount++ - return result, e.LookupGroupIDErr -} diff --git a/src/control/security/config.go b/src/control/security/config.go index 8eef8584593..1913fd89fcd 100644 --- a/src/control/security/config.go +++ b/src/control/security/config.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -13,6 +13,7 @@ import ( "fmt" "io/fs" "os" + "strconv" "time" "github.com/pkg/errors" @@ -32,6 +33,71 @@ const ( defaultInsecure = false ) +// MappedClientUser represents a client user that is mapped to a uid. +type MappedClientUser struct { + User string `yaml:"user"` + Group string `yaml:"group"` + Groups []string `yaml:"groups"` +} + +const ( + defaultMapUser = "default" + defaultMapKey = ^uint32(0) +) + +// ClientUserMap is a map of uids to mapped client users. +type ClientUserMap map[uint32]*MappedClientUser + +func (cm *ClientUserMap) UnmarshalYAML(unmarshal func(interface{}) error) error { + strKeyMap := make(map[string]*MappedClientUser) + if err := unmarshal(&strKeyMap); err != nil { + return err + } + + tmp := make(ClientUserMap) + for strKey, value := range strKeyMap { + var key uint32 + switch strKey { + case defaultMapUser: + key = defaultMapKey + default: + parsedKey, err := strconv.ParseUint(strKey, 10, 32) + if err != nil { + return errors.Wrapf(err, "invalid uid %s", strKey) + } + + switch parsedKey { + case uint64(defaultMapKey): + return errors.Errorf("uid %d is reserved", parsedKey) + default: + key = uint32(parsedKey) + } + } + + tmp[key] = value + } + *cm = tmp + + return nil +} + +// Lookup attempts to resolve the supplied uid to a mapped +// client user. If the uid is not in the map, the default map key +// is returned. If the default map key is not found, nil is returned. +func (cm ClientUserMap) Lookup(uid uint32) *MappedClientUser { + if mu, found := cm[uid]; found { + return mu + } + return cm[defaultMapKey] +} + +// CredentialConfig contains configuration details for managing user +// credentials. +type CredentialConfig struct { + CacheLifetime time.Duration `yaml:"cache_lifetime,omitempty"` + ClientUserMap ClientUserMap `yaml:"client_user_map,omitempty"` +} + // TransportConfig contains all the information on whether or not to use // certificates and their location if their use is specified. type TransportConfig struct { diff --git a/src/control/security/config_test.go b/src/control/security/config_test.go index 8c94a99810a..450b9f6fbd2 100644 --- a/src/control/security/config_test.go +++ b/src/control/security/config_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -11,14 +11,17 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "fmt" "io/ioutil" "os" "testing" "time" - "github.com/daos-stack/daos/src/control/common/test" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "gopkg.in/yaml.v2" + + "github.com/daos-stack/daos/src/control/common/test" ) func InsecureTC() *TransportConfig { @@ -353,3 +356,100 @@ func TestSecurity_DefaultTransportConfigs(t *testing.T) { }) } } + +func TestSecurity_ClientUserMap(t *testing.T) { + for name, tc := range map[string]struct { + cfgYaml string + expMap ClientUserMap + expErr error + }{ + "empty": {}, + "defaultKey": { + cfgYaml: fmt.Sprintf(` +%d: + user: whoops +`, defaultMapKey), + expErr: errors.New("reserved"), + }, + "invalid uid (negative)": { + cfgYaml: ` +-1: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid uid (words)": { + cfgYaml: ` +blah: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid mapped user": { + cfgYaml: ` +1234: +user: whoops +`, + expErr: errors.New("unmarshal error"), + }, + "good": { + cfgYaml: ` +default: + user: banana + group: rama + groups: [ding, dong] +1234: + user: abc + group: def + groups: [yabba, dabba, doo] +5678: + user: ghi + group: jkl + groups: [mno, pqr, stu] +`, + expMap: ClientUserMap{ + defaultMapKey: { + User: "banana", + Group: "rama", + Groups: []string{"ding", "dong"}, + }, + 1234: { + User: "abc", + Group: "def", + Groups: []string{"yabba", "dabba", "doo"}, + }, + 5678: { + User: "ghi", + Group: "jkl", + Groups: []string{"mno", "pqr", "stu"}, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + var result ClientUserMap + err := yaml.Unmarshal([]byte(tc.cfgYaml), &result) + test.CmpErr(t, tc.expErr, err) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expMap, result); diff != "" { + t.Fatalf("unexpected ClientUserMap (-want, +got)\n %s", diff) + } + + for uid, exp := range tc.expMap { + gotUser := result.Lookup(uid) + if diff := cmp.Diff(exp.User, gotUser.User); diff != "" { + t.Fatalf("unexpected User (-want, +got)\n %s", diff) + } + } + + if expDefUser, found := tc.expMap[defaultMapKey]; found { + gotDefUser := result.Lookup(1234567) + if diff := cmp.Diff(expDefUser, gotDefUser); diff != "" { + t.Fatalf("unexpected DefaultUser (-want, +got)\n %s", diff) + } + } + }) + } +} diff --git a/src/control/security/domain_info.go b/src/control/security/domain_info.go index 15bfb06dd77..f0812c17f37 100644 --- a/src/control/security/domain_info.go +++ b/src/control/security/domain_info.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2021 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -63,6 +63,11 @@ func (d *DomainInfo) String() string { return outStr } +// Pid returns the PID obtained from the domain socket +func (d *DomainInfo) Pid() int32 { + return d.creds.Pid +} + // Uid returns the UID obtained from the domain socket func (d *DomainInfo) Uid() uint32 { return d.creds.Uid diff --git a/src/gurt/dlog.c b/src/gurt/dlog.c index 47635e21221..5c153943744 100644 --- a/src/gurt/dlog.c +++ b/src/gurt/dlog.c @@ -775,7 +775,7 @@ static int d_log_str2pri(const char *pstr, size_t len) * handle some quirks */ - if (strncasecmp(pstr, "ERR", len) == 0) + if (strncasecmp(pstr, "ERR", len) == 0 || strncasecmp(pstr, "ERROR", len) == 0) /* has trailing space in the array */ return DLOG_ERR; if (((strncasecmp(pstr, "DEBUG", len) == 0) || diff --git a/utils/config/daos_agent.yml b/utils/config/daos_agent.yml index 4a7f13b3654..cab0bd88976 100644 --- a/utils/config/daos_agent.yml +++ b/utils/config/daos_agent.yml @@ -47,8 +47,33 @@ ## default 0 (do not retain telemetry after client exit) #telemetry_retain: 1m -## Transport Credentials Specifying certificates to secure communications -# +## Configuration for user credential management. +#credential_config: +# # If the agent should be able to resolve unknown client uids and gids +# # (e.g. when running in a container) into ACL principal names, then a +# # client user map may be defined. The optional "default" uid is a special +# # case and applies if no other matches are found. +# client_user_map: +# default: +# user: nobody +# group: nobody +# 1000: +# user: ralph +# group: stanley +# +# # Optionally cache generated credentials with the specified cache +# # lifetime. By default, a credential is generated for every client +# # process that connects to a pool. If the credential cache is +# # enabled, then local client processes connecting with stable +# # uid:gid associations may take advantage of the cached credential +# # and reduce some agent overhead. For heavily-loaded client nodes +# # with many frequent (e.g. hundreds per minute) client connections, +# # a lifetime of 1-5 minutes may be a reasonable tradeoff between +# # performance and responsiveness to user/group database updates. +# cred_cache_lifetime: 1m +# +## Configuration for SSL certificates used to secure management traffic +# and authenticate/authorize management components. #transport_config: # # In order to disable transport security, uncomment and set allow_insecure # # to true. Not recommended for production configurations. @@ -60,6 +85,7 @@ # cert: /etc/daos/certs/agent.crt # # Key portion of Agent Certificate # key: /etc/daos/certs/agent.key +# # Use the given directory for creating unix domain sockets #