From 946bb723100d737a3c1d88d50376d4d3ebfc7a6d Mon Sep 17 00:00:00 2001 From: Kris Jacque Date: Tue, 23 Apr 2024 06:06:40 -0600 Subject: [PATCH 1/3] DAOS-15686 gurt: Accept ERROR as a log mask string (#14211) A change further up in the stack revealed that "ERROR" wasn't accepted as a log mask string at the engine level. Signed-off-by: Kris Jacque --- src/gurt/dlog.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gurt/dlog.c b/src/gurt/dlog.c index 47635e21221..5c153943744 100644 --- a/src/gurt/dlog.c +++ b/src/gurt/dlog.c @@ -775,7 +775,7 @@ static int d_log_str2pri(const char *pstr, size_t len) * handle some quirks */ - if (strncasecmp(pstr, "ERR", len) == 0) + if (strncasecmp(pstr, "ERR", len) == 0 || strncasecmp(pstr, "ERROR", len) == 0) /* has trailing space in the array */ return DLOG_ERR; if (((strncasecmp(pstr, "DEBUG", len) == 0) || From b2edc292386b4693914af557791eb4303aaabfb9 Mon Sep 17 00:00:00 2001 From: Michael MacDonald Date: Tue, 14 May 2024 02:53:56 +0000 Subject: [PATCH 2/3] DAOS-15849 control: Add client uid map to agent config Allow daos_agent to optionally handle unresolvable client uids via custom mapping. In deployments where the agent may not have access to the same user namespace as client applications (e.g. in containerized deployments), the client_user_map can provide a fallback mechanism for resolving the client uids to known usernames for the purpose of applying ACL permissions tests. Example agent config: credential_config: client_user_map: default: user: nobody group: nobody 1000: user: joe group: blow Features: control Required-githooks: true Change-Id: I72905ccc5ddee27fc2101aa4358a14e352c86253 Signed-off-by: Michael MacDonald --- src/control/cmd/daos_agent/config.go | 46 +-- src/control/cmd/daos_agent/config_test.go | 30 +- src/control/cmd/daos_agent/security_rpc.go | 71 +++-- .../cmd/daos_agent/security_rpc_test.go | 154 +++++++++- src/control/cmd/daos_agent/start.go | 6 +- src/control/lib/control/pool.go | 17 +- src/control/security/auth/auth_sys.go | 274 ++++++++++------- src/control/security/auth/auth_sys_test.go | 285 +++++++++--------- src/control/security/auth/mocks.go | 65 ---- src/control/security/config.go | 67 +++- src/control/security/config_test.go | 104 ++++++- src/control/security/domain_info.go | 7 +- utils/config/daos_agent.yml | 20 +- 13 files changed, 759 insertions(+), 387 deletions(-) delete mode 100644 src/control/security/auth/mocks.go diff --git a/src/control/cmd/daos_agent/config.go b/src/control/cmd/daos_agent/config.go index c9d08d19744..76d8a295cf3 100644 --- a/src/control/cmd/daos_agent/config.go +++ b/src/control/cmd/daos_agent/config.go @@ -43,21 +43,22 @@ func (rm refreshMinutes) Duration() time.Duration { // Config defines the agent configuration. type Config struct { - SystemName string `yaml:"name"` - AccessPoints []string `yaml:"access_points"` - ControlPort int `yaml:"port"` - RuntimeDir string `yaml:"runtime_dir"` - LogFile string `yaml:"log_file"` - LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` - TransportConfig *security.TransportConfig `yaml:"transport_config"` - DisableCache bool `yaml:"disable_caching,omitempty"` - CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` - DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` - ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` - FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` - TelemetryPort int `yaml:"telemetry_port,omitempty"` - TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` - TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` + SystemName string `yaml:"name"` + AccessPoints []string `yaml:"access_points"` + ControlPort int `yaml:"port"` + RuntimeDir string `yaml:"runtime_dir"` + LogFile string `yaml:"log_file"` + LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` + CredentialConfig *security.CredentialConfig `yaml:"credential_config"` + TransportConfig *security.TransportConfig `yaml:"transport_config"` + DisableCache bool `yaml:"disable_caching,omitempty"` + CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` + DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` + ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` + FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` + TelemetryPort int `yaml:"telemetry_port,omitempty"` + TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` + TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` } // TelemetryExportEnabled returns true if client telemetry export is enabled. @@ -112,12 +113,13 @@ func LoadConfig(cfgPath string) (*Config, error) { func DefaultConfig() *Config { localServer := fmt.Sprintf("localhost:%d", build.DefaultControlPort) return &Config{ - SystemName: build.DefaultSystemName, - ControlPort: build.DefaultControlPort, - AccessPoints: []string{localServer}, - RuntimeDir: defaultRuntimeDir, - LogFile: defaultLogFile, - LogLevel: common.DefaultControlLogLevel, - TransportConfig: security.DefaultAgentTransportConfig(), + SystemName: build.DefaultSystemName, + ControlPort: build.DefaultControlPort, + AccessPoints: []string{localServer}, + RuntimeDir: defaultRuntimeDir, + LogFile: defaultLogFile, + LogLevel: common.DefaultControlLogLevel, + TransportConfig: security.DefaultAgentTransportConfig(), + CredentialConfig: &security.CredentialConfig{}, } } diff --git a/src/control/cmd/daos_agent/config_test.go b/src/control/cmd/daos_agent/config_test.go index e21920e3735..162b60bbae0 100644 --- a/src/control/cmd/daos_agent/config_test.go +++ b/src/control/cmd/daos_agent/config_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2021-2023 Intel Corporation. +// (C) Copyright 2021-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -46,6 +46,12 @@ control_log_mask: debug disable_caching: true cache_expiration: 30 disable_auto_evict: true +credential_config: + client_user_map: + 1000: + user: frodo + group: baggins + groups: ["ringbearers"] transport_config: allow_insecure: true exclude_fabric_ifaces: ["ib3"] @@ -104,12 +110,13 @@ transport_config: "without optional items": { path: withoutOptCfg, expResult: &Config{ - SystemName: "shire", - AccessPoints: []string{"one:10001", "two:10001"}, - ControlPort: 4242, - RuntimeDir: "/tmp/runtime", - LogFile: "/home/frodo/logfile", - LogLevel: common.DefaultControlLogLevel, + SystemName: "shire", + AccessPoints: []string{"one:10001", "two:10001"}, + ControlPort: 4242, + RuntimeDir: "/tmp/runtime", + LogFile: "/home/frodo/logfile", + LogLevel: common.DefaultControlLogLevel, + CredentialConfig: &security.CredentialConfig{}, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, @@ -132,6 +139,15 @@ transport_config: DisableCache: true, CacheExpiration: refreshMinutes(30 * time.Minute), DisableAutoEvict: true, + CredentialConfig: &security.CredentialConfig{ + ClientUserMap: map[uint32]*security.MappedClientUser{ + 1000: { + User: "frodo", + Group: "baggins", + Groups: []string{"ringbearers"}, + }, + }, + }, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, diff --git a/src/control/cmd/daos_agent/security_rpc.go b/src/control/cmd/daos_agent/security_rpc.go index 906fe53ad8b..8843f18ed6f 100644 --- a/src/control/cmd/daos_agent/security_rpc.go +++ b/src/control/cmd/daos_agent/security_rpc.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -9,6 +9,9 @@ package main import ( "context" "net" + "os/user" + + "github.com/pkg/errors" "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/lib/daos" @@ -17,21 +20,31 @@ import ( "github.com/daos-stack/daos/src/control/security/auth" ) -// SecurityModule is the security drpc module struct -type SecurityModule struct { - log logging.Logger - ext auth.UserExt - config *security.TransportConfig -} +type ( + credSignerFn func(*auth.CredentialRequest) (*auth.Credential, error) + + // securityConfig defines configuration parameters for SecurityModule. + securityConfig struct { + credentials *security.CredentialConfig + transport *security.TransportConfig + } + + // SecurityModule is the security drpc module struct + SecurityModule struct { + log logging.Logger + signCredential credSignerFn + + config *securityConfig + } +) // NewSecurityModule creates a new module with the given initialized TransportConfig -func NewSecurityModule(log logging.Logger, tc *security.TransportConfig) *SecurityModule { - mod := SecurityModule{ - log: log, - config: tc, +func NewSecurityModule(log logging.Logger, cfg *securityConfig) *SecurityModule { + return &SecurityModule{ + log: log, + signCredential: auth.GetSignedCredential, + config: cfg, } - mod.ext = &auth.External{} - return &mod } // HandleCall is the handler for calls to the SecurityModule @@ -46,6 +59,10 @@ func (m *SecurityModule) HandleCall(_ context.Context, session *drpc.Session, me // getCredentials generates a signed user credential based on the data attached to // the Unix Domain Socket. func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { + if session == nil { + return nil, drpc.NewFailureWithMessage("session is nil") + } + uConn, ok := session.Conn.(*net.UnixConn) if !ok { return nil, drpc.NewFailureWithMessage("connection is not a unix socket") @@ -57,17 +74,37 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { return m.credRespWithStatus(daos.MiscError) } - signingKey, err := m.config.PrivateKey() + signingKey, err := m.config.transport.PrivateKey() if err != nil { m.log.Errorf("%s: failed to get signing key: %s", info, err) // something is wrong with the cert config return m.credRespWithStatus(daos.BadCert) } - cred, err := auth.AuthSysRequestFromCreds(m.ext, info, signingKey) + req := auth.NewCredentialRequest(info, signingKey) + cred, err := m.signCredential(req) if err != nil { - m.log.Errorf("%s: failed to get AuthSys struct: %s", info, err) - return m.credRespWithStatus(daos.MiscError) + if err := func() error { + if !errors.Is(err, user.UnknownUserIdError(info.Uid())) { + return err + } + + mu := m.config.credentials.ClientUserMap.Lookup(info.Uid()) + if mu == nil { + return user.UnknownUserIdError(info.Uid()) + } + + req.WithUserAndGroup(mu.User, mu.Group, mu.Groups...) + cred, err = m.signCredential(req) + if err != nil { + return err + } + + return nil + }(); err != nil { + m.log.Errorf("%s: failed to get user credential: %s", info, err) + return m.credRespWithStatus(daos.MiscError) + } } m.log.Tracef("%s: successfully signed credential", info) diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 1c682aff1ca..08ee4755b27 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2022 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -9,8 +9,11 @@ package main import ( "errors" "net" + "os/user" "testing" + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" "github.com/daos-stack/daos/src/control/common/test" @@ -35,8 +38,11 @@ func newTestSession(t *testing.T, log logging.Logger, conn net.Conn) *drpc.Sessi return drpc.NewSession(conn, svc) } -func defaultTestTransportConfig() *security.TransportConfig { - return &security.TransportConfig{AllowInsecure: true} +func defaultTestSecurityConfig() *securityConfig { + return &securityConfig{ + transport: &security.TransportConfig{AllowInsecure: true}, + credentials: &security.CredentialConfig{}, + } } func TestAgentSecurityModule_BadMethod(t *testing.T) { @@ -74,6 +80,7 @@ func setupTestUnixConn(t *testing.T) (*net.UnixConn, func()) { } func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { + t.Helper() client := drpc.NewClientConnection(path) if err := client.Connect(test.Context(t)); err != nil { t.Fatalf("Failed to connect: %v", err) @@ -82,6 +89,8 @@ func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { } func expectCredResp(t *testing.T, respBytes []byte, expStatus int32, expCred bool) { + t.Helper() + if respBytes == nil { t.Error("Expected non-nil response") } @@ -104,7 +113,7 @@ func TestAgentSecurityModule_RequestCreds_OK(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -118,7 +127,7 @@ func TestAgentSecurityModule_RequestCreds_NotUnixConn(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, &net.TCPConn{}) test.CmpErr(t, drpc.NewFailureWithMessage("connection is not a unix socket"), err) @@ -137,7 +146,7 @@ func TestAgentSecurityModule_RequestCreds_NotConnected(t *testing.T) { defer cleanup() conn.Close() // can't get uid/gid from a closed connection - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -156,7 +165,9 @@ func TestAgentSecurityModule_RequestCreds_BadConfig(t *testing.T) { defer cleanup() // Empty TransportConfig is incomplete - mod := NewSecurityModule(log, &security.TransportConfig{}) + mod := NewSecurityModule(log, &securityConfig{ + transport: &security.TransportConfig{}, + }) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -174,10 +185,9 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) - mod.ext = &auth.MockExt{ - LookupUserIDErr: errors.New("LookupUserID"), - LookupGroupIDErr: errors.New("LookupGroupID"), + mod := NewSecurityModule(log, defaultTestSecurityConfig()) + mod.signCredential = func(_ *auth.CredentialRequest) (*auth.Credential, error) { + return nil, errors.New("LookupUserID") } respBytes, err := callRequestCreds(mod, t, log, conn) @@ -187,3 +197,125 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { expectCredResp(t, respBytes, int32(daos.MiscError), false) } + +func TestAgent_SecurityRPC_getCredential(t *testing.T) { + type response struct { + cred *auth.Credential + err error + } + testCred := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("test-token")}, + Origin: "test-origin", + } + miscErrBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: int32(daos.MiscError), + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal misc error: %v", err) + } + successBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: 0, + Cred: testCred, + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal success: %v", err) + } + + for name, tc := range map[string]struct { + secCfg *securityConfig + responses []response + expBytes []byte + expErr error + }{ + "lookup miss": { + secCfg: defaultTestSecurityConfig(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + }, + expBytes: miscErrBytes, + }, + "lookup OK": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: testCred, + err: nil, + }, + }, + expBytes: successBytes, + }, + "lookup OK, but retried request fails": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: nil, + err: errors.New("oops"), + }, + }, + expBytes: miscErrBytes, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + // Set up a real unix socket so we can make a real connection + conn, cleanup := setupTestUnixConn(t) + defer cleanup() + + mod := NewSecurityModule(log, tc.secCfg) + mod.signCredential = func() func(req *auth.CredentialRequest) (*auth.Credential, error) { + var idx int + return func(req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }() + + respBytes, gotErr := callRequestCreds(mod, t, log, conn) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expBytes, respBytes); diff != "" { + t.Errorf("unexpected response (-want +got):\n%s", diff) + } + }) + } +} diff --git a/src/control/cmd/daos_agent/start.go b/src/control/cmd/daos_agent/start.go index e5416ee874b..1ee2f4e7d76 100644 --- a/src/control/cmd/daos_agent/start.go +++ b/src/control/cmd/daos_agent/start.go @@ -114,7 +114,11 @@ func (cmd *startCmd) Execute(_ []string) error { } drpcRegStart := time.Now() - drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, cmd.cfg.TransportConfig)) + secCfg := &securityConfig{ + transport: cmd.cfg.TransportConfig, + credentials: cmd.cfg.CredentialConfig, + } + drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, secCfg)) mgmtMod := &mgmtModule{ log: cmd.Logger, sys: cmd.cfg.SystemName, diff --git a/src/control/lib/control/pool.go b/src/control/lib/control/pool.go index 979660891af..46386c986e4 100644 --- a/src/control/lib/control/pool.go +++ b/src/control/lib/control/pool.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "math" + "os/user" "sort" "strconv" "strings" @@ -30,7 +31,6 @@ import ( "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/ranklist" "github.com/daos-stack/daos/src/control/logging" - "github.com/daos-stack/daos/src/control/security/auth" "github.com/daos-stack/daos/src/control/server/storage" "github.com/daos-stack/daos/src/control/system" ) @@ -52,22 +52,19 @@ func checkUUID(uuidStr string) error { // formatNameGroup converts system names to principals, If user or group is not // provided, the effective user and/or effective group will be used. -func formatNameGroup(ext auth.UserExt, usr string, grp string) (string, string, error) { +func formatNameGroup(usr string, grp string) (string, string, error) { if usr == "" || grp == "" { - eUsr, err := ext.Current() + eUsr, err := user.Current() if err != nil { return "", "", err } if usr == "" { - usr = eUsr.Username() + usr = eUsr.Username } if grp == "" { - gid, err := eUsr.Gid() - if err != nil { - return "", "", err - } - eGrp, err := ext.LookupGroupID(gid) + gid := eUsr.Gid + eGrp, err := user.LookupGroupId(gid) if err != nil { return "", "", err } @@ -154,7 +151,7 @@ func (pcr *PoolCreateReq) MarshalJSON() ([]byte, error) { // request, filling in any missing fields with reasonable defaults. func genPoolCreateRequest(in *PoolCreateReq) (out *mgmtpb.PoolCreateReq, err error) { // ensure pool ownership is set up correctly - in.User, in.UserGroup, err = formatNameGroup(&auth.External{}, in.User, in.UserGroup) + in.User, in.UserGroup, err = formatNameGroup(in.User, in.UserGroup) if err != nil { return } diff --git a/src/control/security/auth/auth_sys.go b/src/control/security/auth/auth_sys.go index 74a4d006f20..e9255f407d3 100644 --- a/src/control/security/auth/auth_sys.go +++ b/src/control/security/auth/auth_sys.go @@ -20,84 +20,6 @@ import ( "github.com/daos-stack/daos/src/control/security" ) -// User is an interface wrapping a representation of a specific system user. -type User interface { - Username() string - GroupIDs() ([]uint32, error) - Gid() (uint32, error) -} - -// UserExt is an interface that wraps system user-related external functions. -type UserExt interface { - Current() (User, error) - LookupUserID(uid uint32) (User, error) - LookupGroupID(gid uint32) (*user.Group, error) -} - -// UserInfo is an exported implementation of the security.User interface. -type UserInfo struct { - Info *user.User -} - -// Username is a wrapper for user.Username. -func (u *UserInfo) Username() string { - return u.Info.Username -} - -// GroupIDs is a wrapper for user.GroupIds. -func (u *UserInfo) GroupIDs() ([]uint32, error) { - gidStrs, err := u.Info.GroupIds() - if err != nil { - return nil, err - } - - gids := []uint32{} - for _, gstr := range gidStrs { - gid, err := strconv.Atoi(gstr) - if err != nil { - continue - } - gids = append(gids, uint32(gid)) - } - - return gids, nil -} - -// Gid is a wrapper for user.Gid. -func (u *UserInfo) Gid() (uint32, error) { - gid, err := strconv.Atoi(u.Info.Gid) - - return uint32(gid), errors.Wrap(err, "user gid") -} - -// External is an exported implementation of the UserExt interface. -type External struct{} - -// LookupUserId is a wrapper for user.LookupId. -func (e *External) LookupUserID(uid uint32) (User, error) { - uidStr := strconv.FormatUint(uint64(uid), 10) - info, err := user.LookupId(uidStr) - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - -// LookupGroupId is a wrapper for user.LookupGroupId. -func (e *External) LookupGroupID(gid uint32) (*user.Group, error) { - gidStr := strconv.FormatUint(uint64(gid), 10) - return user.LookupGroupId(gidStr) -} - -// Current is a wrapper for user.Current. -func (e *External) Current() (User, error) { - info, err := user.Current() - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - // VerifierFromToken will return a SHA512 hash of the token data. If a signing key // is passed in it will additionally sign the hash of the token. func VerifierFromToken(key crypto.PublicKey, token *Token) ([]byte, error) { @@ -146,60 +68,188 @@ func sysNameToPrincipalName(name string) string { return name + "@" } -// AuthSysRequestFromCreds takes the domain info credentials gathered -// during the dRPC request and creates an AuthSys security request to obtain -// a handle from the management service. -func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing crypto.PrivateKey) (*Credential, error) { - if creds == nil { - return nil, errors.New("No credentials supplied") +func stripHostName(name string) string { + return strings.Split(name, ".")[0] +} + +// GetMachineName returns the "short" hostname by stripping the domain from the FQDN. +func GetMachineName() (string, error) { + name, err := os.Hostname() + if err != nil { + return "", err + } + + return stripHostName(name), nil +} + +type ( + // CredentialRequest defines the request parameters for GetSignedCredential. + CredentialRequest struct { + DomainInfo *security.DomainInfo + SigningKey crypto.PrivateKey + getHostnameFn func() (string, error) + getUserFn func(string) (*user.User, error) + getGroupFn func(string) (*user.Group, error) + getGroupIdsFn func() ([]string, error) + getGroupNamesFn func() ([]string, error) + } +) + +// NewCredentialRequest returns a properly initialized CredentialRequest. +func NewCredentialRequest(info *security.DomainInfo, key crypto.PrivateKey) *CredentialRequest { + req := &CredentialRequest{ + DomainInfo: info, + SigningKey: key, + getHostnameFn: GetMachineName, + getUserFn: user.LookupId, + getGroupFn: user.LookupGroupId, + } + req.getGroupIdsFn = func() ([]string, error) { + u, err := req.user() + if err != nil { + return nil, err + } + return u.GroupIds() + } + req.getGroupNamesFn = func() ([]string, error) { + groupIds, err := req.getGroupIdsFn() + if err != nil { + return nil, err + } + + groupNames := make([]string, len(groupIds)) + for i, gID := range groupIds { + g, err := req.getGroupFn(gID) + if err != nil { + return nil, err + } + groupNames[i] = g.Name + } + + return groupNames, nil + } + + return req +} + +func (r *CredentialRequest) hostname() (string, error) { + if r.getHostnameFn == nil { + return "", errors.New("hostname lookup function not set") } - userInfo, err := ext.LookupUserID(creds.Uid()) + hostname, err := r.getHostnameFn() if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup uid %v", - creds.Uid()) + return "", errors.Wrap(err, "failed to get hostname") + } + return stripHostName(hostname), nil +} + +func (r *CredentialRequest) user() (*user.User, error) { + if r.getUserFn == nil { + return nil, errors.New("user lookup function not set") } + return r.getUserFn(strconv.Itoa(int(r.DomainInfo.Uid()))) +} - groupInfo, err := ext.LookupGroupID(creds.Gid()) +func (r *CredentialRequest) userPrincipal() (string, error) { + u, err := r.user() if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup gid %v", - creds.Gid()) + return "", err + } + return sysNameToPrincipalName(u.Username), nil +} + +func (r *CredentialRequest) group() (*user.Group, error) { + if r.getGroupFn == nil { + return nil, errors.New("group lookup function not set") } + return r.getGroupFn(strconv.Itoa(int(r.DomainInfo.Gid()))) +} - groups, err := userInfo.GroupIDs() +func (r *CredentialRequest) groupPrincipal() (string, error) { + g, err := r.group() if err != nil { - return nil, errors.Wrapf(err, "Failed to get group IDs for user %v", - userInfo.Username()) + return "", err } + return sysNameToPrincipalName(g.Name), nil +} - name, err := os.Hostname() +func (r *CredentialRequest) groupPrincipals() ([]string, error) { + if r.getGroupNamesFn == nil { + return nil, errors.New("groupNames function not set") + } + + groupNames, err := r.getGroupNamesFn() + if err != nil { + return nil, errors.Wrap(err, "failed to get group names") + } + + for i, g := range groupNames { + groupNames[i] = sysNameToPrincipalName(g) + } + return groupNames, nil +} + +// WithUserAndGroup provides an override to set the user, group, and optional list +// of group names to be used for the request. +func (r *CredentialRequest) WithUserAndGroup(userStr, groupStr string, groupStrs ...string) { + r.getUserFn = func(id string) (*user.User, error) { + return &user.User{ + Uid: id, + Gid: id, + Username: userStr, + }, nil + } + r.getGroupFn = func(id string) (*user.Group, error) { + return &user.Group{ + Gid: id, + Name: groupStr, + }, nil + } + r.getGroupNamesFn = func() ([]string, error) { + return groupStrs, nil + } +} + +// GetSignedCredential returns a credential based on the provided domain info and +// signing key. +func GetSignedCredential(req *CredentialRequest) (*Credential, error) { + if req == nil { + return nil, errors.Errorf("%T is nil", req) + } + + if req.DomainInfo == nil { + return nil, errors.New("No domain info supplied") + } + + hostname, err := req.hostname() if err != nil { - name = "unavailable" + return nil, err } - // Strip the domain off of the Hostname - host := strings.Split(name, ".")[0] + userPrinc, err := req.userPrincipal() + if err != nil { + return nil, err + } - var groupList = []string{} + groupPrinc, err := req.groupPrincipal() + if err != nil { + return nil, err + } - // Convert groups to gids - for _, gid := range groups { - gInfo, err := ext.LookupGroupID(gid) - if err != nil { - // Skip this group - continue - } - groupList = append(groupList, sysNameToPrincipalName(gInfo.Name)) + groupPrincs, err := req.groupPrincipals() + if err != nil { + return nil, err } // Craft AuthToken sys := Sys{ Stamp: 0, - Machinename: host, - User: sysNameToPrincipalName(userInfo.Username()), - Group: sysNameToPrincipalName(groupInfo.Name), - Groups: groupList, - Secctx: creds.Ctx()} + Machinename: hostname, + User: userPrinc, + Group: groupPrinc, + Groups: groupPrincs, + Secctx: req.DomainInfo.Ctx()} // Marshal our AuthSys token into a byte array tokenBytes, err := proto.Marshal(&sys) @@ -210,7 +260,7 @@ func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing cr Flavor: Flavor_AUTH_SYS, Data: tokenBytes} - verifier, err := VerifierFromToken(signing, &token) + verifier, err := VerifierFromToken(req.SigningKey, &token) if err != nil { return nil, errors.WithMessage(err, "Unable to generate verifier") } diff --git a/src/control/security/auth/auth_sys_test.go b/src/control/security/auth/auth_sys_test.go index fee5c69931b..bbca64fe233 100644 --- a/src/control/security/auth/auth_sys_test.go +++ b/src/control/security/auth/auth_sys_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -8,14 +8,13 @@ package auth import ( "errors" - "fmt" "os/user" "syscall" "testing" "google.golang.org/protobuf/proto" - . "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/security" ) @@ -29,7 +28,7 @@ func expectAuthSysErrorForToken(t *testing.T, badToken *Token, expectedErrorMess t.Error("Expected a nil AuthSys") } - CmpErr(t, errors.New(expectedErrorMessage), err) + test.CmpErr(t, errors.New(expectedErrorMessage), err) } // AuthSysFromAuthToken tests @@ -82,30 +81,61 @@ func TestAuthSysFromAuthToken_SucceedsWithGoodToken(t *testing.T) { t.Fatal("Got a nil AuthSys") } - AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), + test.AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), "Stamps don't match") - AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), + test.AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), "Machinenames don't match") - AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), + test.AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), "Owners don't match") - AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), + test.AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), "Groups don't match") - AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), + test.AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), "Group lists aren't the same length") - AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), + test.AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), "Secctx don't match") } -// AuthSysRequestFromCreds tests +func testHostnameFn(expErr error, hostname string) func() (string, error) { + return func() (string, error) { + if expErr != nil { + return "", expErr + } + return hostname, nil + } +} -func TestAuthSysRequestFromCreds_failsIfDomainInfoNil(t *testing.T) { - result, err := AuthSysRequestFromCreds(&MockExt{}, nil, nil) +func testUserFn(expErr error, userName string) func(string) (*user.User, error) { + return func(uid string) (*user.User, error) { + if expErr != nil { + return nil, expErr + } + return &user.User{ + Uid: uid, + Gid: uid, + Username: userName, + }, nil + } +} - if result != nil { - t.Error("Expected a nil request") +func testGroupFn(expErr error, groupName string) func(string) (*user.Group, error) { + return func(gid string) (*user.Group, error) { + if expErr != nil { + return nil, expErr + } + return &user.Group{ + Gid: gid, + Name: groupName, + }, nil } +} - ExpectError(t, err, "No credentials supplied", "") +func testGroupNamesFn(expErr error, groupNames ...string) func() ([]string, error) { + return func() ([]string, error) { + if expErr != nil { + return nil, expErr + } + return groupNames, nil + } } func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { @@ -116,44 +146,10 @@ func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { return security.InitDomainInfo(creds, "test") } -func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - gid := uint32(2001) - gids := []uint32{1, 2, 3} - expectedUser := "myuser" - expectedGroup := "mygroup" - expectedGroupList := []string{"group1", "group2", "group3"} - creds := getTestCreds(uid, gid) - - ext.LookupUserIDResult = &MockUser{ - username: expectedUser, - groupIDs: gids, - } - ext.LookupGroupIDResults = []*user.Group{ - { - Name: expectedGroup, - }, - } - - for _, grp := range expectedGroupList { - ext.LookupGroupIDResults = append(ext.LookupGroupIDResults, - &user.Group{ - Name: grp, - }) - } - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if result == nil { - t.Fatal("Credential was nil") - } +func verifyCredential(t *testing.T, cred *Credential, expHostname, expUserPrinc, expGroupPrinc string, expGroupPrincs ...string) { + t.Helper() - token := result.GetToken() + token := cred.GetToken() if token == nil { t.Fatal("Token was nil") } @@ -163,111 +159,128 @@ func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { } authsys := &Sys{} - err = proto.Unmarshal(token.GetData(), authsys) + err := proto.Unmarshal(token.GetData(), authsys) if err != nil { t.Fatal("Failed to unmarshal token data") } - if authsys.GetUser() != expectedUser+"@" { + if authsys.GetMachinename() != expHostname { + t.Errorf("AuthSys had bad hostname: %v", authsys.GetMachinename()) + } + + if authsys.GetUser() != expUserPrinc { t.Errorf("AuthSys had bad username: %v", authsys.GetUser()) } - if authsys.GetGroup() != expectedGroup+"@" { + if authsys.GetGroup() != expGroupPrinc { t.Errorf("AuthSys had bad group name: %v", authsys.GetGroup()) } for i, group := range authsys.GetGroups() { - if group != expectedGroupList[i]+"@" { + if group != expGroupPrincs[i] { t.Errorf("AuthSys had bad group in list (idx %v): %v", i, group) } } } -func TestAuthSysRequestFromCreds_UidLookupFails(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - creds := getTestCreds(uid, 500) - - ext.LookupUserIDErr = errors.New("LookupUserID test error") - expectedErr := fmt.Errorf("Failed to lookup uid %v: %v", uid, - ext.LookupUserIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GidLookupFails(t *testing.T) { - ext := &MockExt{} - gid := uint32(205) - creds := getTestCreds(12, gid) - - ext.LookupUserIDResult = &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, - } - - ext.LookupGroupIDErr = errors.New("LookupGroupID test error") - expectedErr := fmt.Errorf("Failed to lookup gid %v: %v", gid, - ext.LookupGroupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GroupIDListFails(t *testing.T) { - ext := &MockExt{} - creds := getTestCreds(12, 15) - testUser := &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, +func TestAuth_GetSignedCred(t *testing.T) { + testHostname := "test-host.domain.foo" + testUsername := "test-user" + testGroup := "test-group" + testGroupList := []string{"group1", "group2", "group3"} + + expectedHostname := "test-host" + expectedUser := testUsername + "@" + expectedGroup := testGroup + "@" + expectedGroupList := make([]string, len(testGroupList)) + for i, group := range testGroupList { + expectedGroupList[i] = group + "@" } - ext.LookupUserIDResult = testUser - - ext.LookupGroupIDResults = []*user.Group{ - { - Name: "group@", + for name, tc := range map[string]struct { + req *CredentialRequest + expErr error + }{ + "nil request": { + req: nil, + expErr: errors.New("is nil"), + }, + "nil DomainInfo": { + req: &CredentialRequest{}, + expErr: errors.New("No domain info supplied"), + }, + "bad hostname": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(errors.New("bad hostname"), "") + return req + }(), + expErr: errors.New("bad hostname"), + }, + "bad uid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getUserFn = testUserFn(errors.New("bad uid"), "") + return req + }(), + expErr: errors.New("bad uid"), + }, + "bad gid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupFn = testGroupFn(errors.New("bad gid"), "") + return req + }(), + expErr: errors.New("bad gid"), + }, + "bad group IDs": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupIdsFn = testGroupNamesFn(errors.New("bad group IDs")) + return req + }(), + expErr: errors.New("bad group IDs"), + }, + "bad group names": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupNamesFn = testGroupNamesFn(errors.New("bad group names")) + return req + }(), + expErr: errors.New("bad group names"), }, + "valid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(nil, testHostname) + req.getUserFn = testUserFn(nil, testUsername) + req.getGroupFn = testGroupFn(nil, testGroup) + req.getGroupNamesFn = testGroupNamesFn(nil, testGroupList...) + return req + }(), + }, + } { + t.Run(name, func(t *testing.T) { + cred, gotErr := GetSignedCredential(tc.req) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + + verifyCredential(t, cred, expectedHostname, expectedUser, expectedGroup, expectedGroupList...) + }) } +} - testUser.groupIDErr = errors.New("GroupIDs test error") - expectedErr := fmt.Errorf("Failed to get group IDs for user %v: %v", - testUser.username, - testUser.groupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) +func TestAuth_CredentialRequestOverrides(t *testing.T) { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostnameFn = testHostnameFn(nil, "test-host") + req.WithUserAndGroup("test-user", "test-group", "test-secondary") - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") + cred, err := GetSignedCredential(req) + if err != nil { + t.Fatalf("Failed to get credential: %s", err) } - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } + verifyCredential(t, cred, "test-host", "test-user@", "test-group@", "test-secondary@") } diff --git a/src/control/security/auth/mocks.go b/src/control/security/auth/mocks.go deleted file mode 100644 index 3cde0f9471b..00000000000 --- a/src/control/security/auth/mocks.go +++ /dev/null @@ -1,65 +0,0 @@ -// -// (C) Copyright 2020-2021 Intel Corporation. -// -// SPDX-License-Identifier: BSD-2-Clause-Patent -// - -package auth - -import ( - "os/user" - - "github.com/pkg/errors" -) - -// Mocks - -type MockUser struct { - username string - groupIDs []uint32 - groupIDErr error -} - -func (u *MockUser) Username() string { - return u.username -} - -func (u *MockUser) GroupIDs() ([]uint32, error) { - return u.groupIDs, u.groupIDErr -} - -func (u *MockUser) Gid() (uint32, error) { - if len(u.groupIDs) == 0 { - return 0, errors.New("no mock gids to return") - } - return u.groupIDs[0], nil -} - -type MockExt struct { - LookupUserIDUid uint32 - LookupUserIDResult User - LookupUserIDErr error - LookupGroupIDGid uint32 - LookupGroupIDResults []*user.Group - LookupGroupIDCallCount uint32 - LookupGroupIDErr error -} - -func (e *MockExt) Current() (User, error) { - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupUserID(uid uint32) (User, error) { - e.LookupUserIDUid = uid - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupGroupID(gid uint32) (*user.Group, error) { - e.LookupGroupIDGid = gid - var result *user.Group - if len(e.LookupGroupIDResults) > 0 { - result = e.LookupGroupIDResults[e.LookupGroupIDCallCount] - } - e.LookupGroupIDCallCount++ - return result, e.LookupGroupIDErr -} diff --git a/src/control/security/config.go b/src/control/security/config.go index 8eef8584593..2bda4bf4a00 100644 --- a/src/control/security/config.go +++ b/src/control/security/config.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -13,6 +13,7 @@ import ( "fmt" "io/fs" "os" + "strconv" "time" "github.com/pkg/errors" @@ -32,6 +33,70 @@ const ( defaultInsecure = false ) +// MappedClientUser represents a client user that is mapped to a uid. +type MappedClientUser struct { + User string `yaml:"user"` + Group string `yaml:"group"` + Groups []string `yaml:"groups"` +} + +const ( + defaultMapUser = "default" + defaultMapKey = ^uint32(0) +) + +// ClientUserMap is a map of uids to mapped client users. +type ClientUserMap map[uint32]*MappedClientUser + +func (cm *ClientUserMap) UnmarshalYAML(unmarshal func(interface{}) error) error { + strKeyMap := make(map[string]*MappedClientUser) + if err := unmarshal(&strKeyMap); err != nil { + return err + } + + tmp := make(ClientUserMap) + for strKey, value := range strKeyMap { + var key uint32 + switch strKey { + case defaultMapUser: + key = defaultMapKey + default: + parsedKey, err := strconv.ParseUint(strKey, 10, 32) + if err != nil { + return errors.Wrapf(err, "invalid uid %s", strKey) + } + + switch parsedKey { + case uint64(defaultMapKey): + return errors.Errorf("uid %d is reserved", parsedKey) + default: + key = uint32(parsedKey) + } + } + + tmp[key] = value + } + *cm = tmp + + return nil +} + +// Lookup attempts to resolve the supplied uid to a mapped +// client user. If the uid is not in the map, the default map key +// is returned. If the default map key is not found, nil is returned. +func (cm ClientUserMap) Lookup(uid uint32) *MappedClientUser { + if mu, found := cm[uid]; found { + return mu + } + return cm[defaultMapKey] +} + +// CredentialConfig contains configuration details for managing user +// credentials. +type CredentialConfig struct { + ClientUserMap ClientUserMap `yaml:"client_user_map,omitempty"` +} + // TransportConfig contains all the information on whether or not to use // certificates and their location if their use is specified. type TransportConfig struct { diff --git a/src/control/security/config_test.go b/src/control/security/config_test.go index 8c94a99810a..450b9f6fbd2 100644 --- a/src/control/security/config_test.go +++ b/src/control/security/config_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -11,14 +11,17 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "fmt" "io/ioutil" "os" "testing" "time" - "github.com/daos-stack/daos/src/control/common/test" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "gopkg.in/yaml.v2" + + "github.com/daos-stack/daos/src/control/common/test" ) func InsecureTC() *TransportConfig { @@ -353,3 +356,100 @@ func TestSecurity_DefaultTransportConfigs(t *testing.T) { }) } } + +func TestSecurity_ClientUserMap(t *testing.T) { + for name, tc := range map[string]struct { + cfgYaml string + expMap ClientUserMap + expErr error + }{ + "empty": {}, + "defaultKey": { + cfgYaml: fmt.Sprintf(` +%d: + user: whoops +`, defaultMapKey), + expErr: errors.New("reserved"), + }, + "invalid uid (negative)": { + cfgYaml: ` +-1: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid uid (words)": { + cfgYaml: ` +blah: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid mapped user": { + cfgYaml: ` +1234: +user: whoops +`, + expErr: errors.New("unmarshal error"), + }, + "good": { + cfgYaml: ` +default: + user: banana + group: rama + groups: [ding, dong] +1234: + user: abc + group: def + groups: [yabba, dabba, doo] +5678: + user: ghi + group: jkl + groups: [mno, pqr, stu] +`, + expMap: ClientUserMap{ + defaultMapKey: { + User: "banana", + Group: "rama", + Groups: []string{"ding", "dong"}, + }, + 1234: { + User: "abc", + Group: "def", + Groups: []string{"yabba", "dabba", "doo"}, + }, + 5678: { + User: "ghi", + Group: "jkl", + Groups: []string{"mno", "pqr", "stu"}, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + var result ClientUserMap + err := yaml.Unmarshal([]byte(tc.cfgYaml), &result) + test.CmpErr(t, tc.expErr, err) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expMap, result); diff != "" { + t.Fatalf("unexpected ClientUserMap (-want, +got)\n %s", diff) + } + + for uid, exp := range tc.expMap { + gotUser := result.Lookup(uid) + if diff := cmp.Diff(exp.User, gotUser.User); diff != "" { + t.Fatalf("unexpected User (-want, +got)\n %s", diff) + } + } + + if expDefUser, found := tc.expMap[defaultMapKey]; found { + gotDefUser := result.Lookup(1234567) + if diff := cmp.Diff(expDefUser, gotDefUser); diff != "" { + t.Fatalf("unexpected DefaultUser (-want, +got)\n %s", diff) + } + } + }) + } +} diff --git a/src/control/security/domain_info.go b/src/control/security/domain_info.go index 15bfb06dd77..f0812c17f37 100644 --- a/src/control/security/domain_info.go +++ b/src/control/security/domain_info.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2021 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -63,6 +63,11 @@ func (d *DomainInfo) String() string { return outStr } +// Pid returns the PID obtained from the domain socket +func (d *DomainInfo) Pid() int32 { + return d.creds.Pid +} + // Uid returns the UID obtained from the domain socket func (d *DomainInfo) Uid() uint32 { return d.creds.Uid diff --git a/utils/config/daos_agent.yml b/utils/config/daos_agent.yml index 4a7f13b3654..3e7e8f1ab95 100644 --- a/utils/config/daos_agent.yml +++ b/utils/config/daos_agent.yml @@ -47,7 +47,23 @@ ## default 0 (do not retain telemetry after client exit) #telemetry_retain: 1m -## Transport Credentials Specifying certificates to secure communications +## Configuration for user credential management. +# +#credential_config: +# # If the agent should be able to resolve unknown client uids and gids +# # (e.g. when running in a container) into ACL principal names, then a +# # client user map may be defined. The optional "default" uid is a special +# # case and applies if no other matches are found. +# client_user_map: +# default: +# user: nobody +# group: nobody +# 1000: +# user: ralph +# group: stanley + +## Configuration for SSL certificates used to secure management traffic +# and authenticate/authorize management components. # #transport_config: # # In order to disable transport security, uncomment and set allow_insecure @@ -60,7 +76,7 @@ # cert: /etc/daos/certs/agent.crt # # Key portion of Agent Certificate # key: /etc/daos/certs/agent.key - +# # Use the given directory for creating unix domain sockets # # NOTE: Do not change this when running under systemd control. If it needs to From 944c359c53afdb614b6910ea8332fa190f77b086 Mon Sep 17 00:00:00 2001 From: Michael MacDonald Date: Mon, 20 May 2024 21:03:50 +0000 Subject: [PATCH 3/3] DAOS-15874 control: Add optional credential cache to agent On heavily-loaded client nodes where many processes are being launched by the same user or users, the admin may optionally enable the credential cache in the agent in order to lower agent overhead caused by generating identical credentials for each process owned by a user. The agent-generated credential is presented by the client process during pool/container connection and is used to evaluate ACL permissions for that connection. Example config: credential_config: cache_lifetime: 1m Features: control Required-githooks: true Change-Id: I6ae2a8be1dd97ef14e0ccef0283d65bc1fabc4ed Signed-off-by: Michael MacDonald --- src/control/cmd/daos_agent/config_test.go | 2 + src/control/cmd/daos_agent/infocache.go | 75 ++++++++-- src/control/cmd/daos_agent/infocache_test.go | 36 +++-- src/control/cmd/daos_agent/mgmt_rpc.go | 11 +- src/control/cmd/daos_agent/mgmt_rpc_test.go | 3 +- src/control/cmd/daos_agent/security_rpc.go | 111 +++++++++++++-- .../cmd/daos_agent/security_rpc_test.go | 128 ++++++++++++++++-- src/control/lib/cache/cache.go | 117 ++++++++++------ src/control/lib/cache/cache_test.go | 20 ++- src/control/security/auth/auth_sys.go | 5 +- src/control/security/auth/auth_sys_test.go | 4 +- src/control/security/config.go | 1 + utils/config/daos_agent.yml | 16 ++- 13 files changed, 424 insertions(+), 105 deletions(-) diff --git a/src/control/cmd/daos_agent/config_test.go b/src/control/cmd/daos_agent/config_test.go index 162b60bbae0..6f0b7562660 100644 --- a/src/control/cmd/daos_agent/config_test.go +++ b/src/control/cmd/daos_agent/config_test.go @@ -47,6 +47,7 @@ disable_caching: true cache_expiration: 30 disable_auto_evict: true credential_config: + cache_lifetime: 10m client_user_map: 1000: user: frodo @@ -140,6 +141,7 @@ transport_config: CacheExpiration: refreshMinutes(30 * time.Minute), DisableAutoEvict: true, CredentialConfig: &security.CredentialConfig{ + CacheLifetime: time.Minute * 10, ClientUserMap: map[uint32]*security.MappedClientUser{ 1000: { User: "frodo", diff --git a/src/control/cmd/daos_agent/infocache.go b/src/control/cmd/daos_agent/infocache.go index cb777396ff1..9fc942f9290 100644 --- a/src/control/cmd/daos_agent/infocache.go +++ b/src/control/cmd/daos_agent/infocache.go @@ -80,11 +80,13 @@ func getFabricScanFn(log logging.Logger, cfg *Config, scanner *hardware.FabricSc } type cacheItem struct { - sync.Mutex + sync.RWMutex lastCached time.Time refreshInterval time.Duration } +// isStale returns true if the cache item is stale. +// NB: Should be run under a lock to protect lastCached. func (ci *cacheItem) isStale() bool { if ci.refreshInterval == 0 { return false @@ -92,6 +94,8 @@ func (ci *cacheItem) isStale() bool { return ci.lastCached.Add(ci.refreshInterval).Before(time.Now()) } +// isCached returns true if the cache item is cached. +// NB: Should be run under at least a read lock to protect lastCached. func (ci *cacheItem) isCached() bool { return !ci.lastCached.Equal(time.Time{}) } @@ -130,16 +134,30 @@ func (ci *cachedAttachInfo) Key() string { return sysAttachInfoKey(ci.system) } -// NeedsRefresh checks whether the cached data needs to be refreshed. -func (ci *cachedAttachInfo) NeedsRefresh() bool { +// needsRefresh checks whether the cached data needs to be refreshed. +func (ci *cachedAttachInfo) needsRefresh() bool { if ci == nil { return false } return !ci.isCached() || ci.isStale() } -// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. -func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached data if it needs to be refreshed. +func (ci *cachedAttachInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if ci == nil { + return cache.NoopRelease, false, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + if ci.needsRefresh() { + return ci.Unlock, true, ci.refresh(ctx) + } + return ci.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (ci *cachedAttachInfo) refresh(ctx context.Context) error { if ci == nil { return errors.New("cachedAttachInfo is nil") } @@ -155,6 +173,16 @@ func (ci *cachedAttachInfo) Refresh(ctx context.Context) error { return nil } +// Refresh contacts the remote management server and refreshes the GetAttachInfo cache. +func (ci *cachedAttachInfo) Refresh(ctx context.Context) (func(), error) { + if ci == nil { + return cache.NoopRelease, errors.New("cachedAttachInfo is nil") + } + + ci.Lock() + return ci.Unlock, ci.refresh(ctx) +} + type cachedFabricInfo struct { cacheItem fetch fabricScanFn @@ -172,17 +200,31 @@ func (cfi *cachedFabricInfo) Key() string { return fabricKey } -// NeedsRefresh indicates that the fabric information does not need to be refreshed unless it has +// needsRefresh indicates that the fabric information does not need to be refreshed unless it has // never been populated. -func (cfi *cachedFabricInfo) NeedsRefresh() bool { +func (cfi *cachedFabricInfo) needsRefresh() bool { if cfi == nil { return false } return !cfi.isCached() } -// Refresh scans the hardware for information about the fabric devices and caches the result. -func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { +// RefreshIfNeeded refreshes the cached fabric information if it needs to be refreshed. +func (cfi *cachedFabricInfo) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if cfi == nil { + return cache.NoopRelease, false, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + if cfi.needsRefresh() { + return cfi.Unlock, true, cfi.refresh(ctx) + } + return cfi.Unlock, false, nil +} + +// refresh implements the actual refresh logic. +// NB: Should be run under a lock. +func (cfi *cachedFabricInfo) refresh(ctx context.Context) error { if cfi == nil { return errors.New("cachedFabricInfo is nil") } @@ -197,6 +239,16 @@ func (cfi *cachedFabricInfo) Refresh(ctx context.Context) error { return nil } +// Refresh scans the hardware for information about the fabric devices and caches the result. +func (cfi *cachedFabricInfo) Refresh(ctx context.Context) (func(), error) { + if cfi == nil { + return cache.NoopRelease, errors.New("cachedFabricInfo is nil") + } + + cfi.Lock() + return cfi.Unlock, cfi.refresh(ctx) +} + // InfoCache is a cache for the results of expensive operations needed by the agent. type InfoCache struct { log logging.Logger @@ -350,15 +402,14 @@ func (c *InfoCache) GetAttachInfo(ctx context.Context, sys string) (*control.Get } createItem := func() (cache.Item, error) { c.log.Debugf("cache miss for %s", sysAttachInfoKey(sys)) - cai := newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo) - return cai, nil + return newCachedAttachInfo(c.attachInfoRefresh, sys, c.client, c.getAttachInfo), nil } item, release, err := c.cache.GetOrCreate(ctx, sysAttachInfoKey(sys), createItem) - defer release() if err != nil { return nil, errors.Wrap(err, "getting attach info from cache") } + defer release() cai, ok := item.(*cachedAttachInfo) if !ok { diff --git a/src/control/cmd/daos_agent/infocache_test.go b/src/control/cmd/daos_agent/infocache_test.go index e86c44bfc0c..36954849468 100644 --- a/src/control/cmd/daos_agent/infocache_test.go +++ b/src/control/cmd/daos_agent/infocache_test.go @@ -149,14 +149,22 @@ func TestAgent_cachedAttachInfo_Key(t *testing.T) { } } -func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedAttachInfo_RefreshIfNeeded(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + mockClient := control.NewMockInvoker(log, &control.MockInvokerConfig{}) + + noopGetAttachInfo := func(_ context.Context, _ control.UnaryInvoker, _ *control.GetAttachInfoReq) (*control.GetAttachInfoResp, error) { + return nil, nil + } + for name, tc := range map[string]struct { ai *cachedAttachInfo expResult bool }{ "nil": {}, "never cached": { - ai: newCachedAttachInfo(0, "test", nil, nil), + ai: newCachedAttachInfo(0, "test", mockClient, noopGetAttachInfo), expResult: true, }, "no refresh": { @@ -164,6 +172,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { cacheItem: cacheItem{ lastCached: time.Now().Add(-time.Minute), }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, @@ -173,6 +183,8 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Minute), refreshInterval: time.Second, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, expResult: true, @@ -183,12 +195,15 @@ func TestAgent_cachedAttachInfo_NeedsRefresh(t *testing.T) { lastCached: time.Now().Add(-time.Second), refreshInterval: time.Minute, }, + rpcClient: mockClient, + fetch: noopGetAttachInfo, lastResponse: &control.GetAttachInfoResp{}, }, }, } { t.Run(name, func(t *testing.T) { - test.AssertEqual(t, tc.expResult, tc.ai.NeedsRefresh(), "") + _, refreshed, _ := tc.ai.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -271,7 +286,8 @@ func TestAgent_cachedAttachInfo_Refresh(t *testing.T) { } } - err := ai.Refresh(test.Context(t)) + release, err := ai.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) @@ -320,7 +336,7 @@ func TestAgent_cachedFabricInfo_Key(t *testing.T) { } } -func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { +func TestAgent_cachedFabricInfo_RefreshIfNeeded(t *testing.T) { for name, tc := range map[string]struct { nilCache bool cacheTime time.Time @@ -342,11 +358,14 @@ func TestAgent_cachedFabricInfo_NeedsRefresh(t *testing.T) { var cfi *cachedFabricInfo if !tc.nilCache { - cfi = newCachedFabricInfo(log, nil) + cfi = newCachedFabricInfo(log, func(_ context.Context, _ ...string) (*NUMAFabric, error) { + return nil, nil + }) cfi.cacheItem.lastCached = tc.cacheTime } - test.AssertEqual(t, tc.expResult, cfi.NeedsRefresh(), "") + _, refreshed, _ := cfi.RefreshIfNeeded(test.Context(t)) + test.AssertEqual(t, tc.expResult, refreshed, "") }) } } @@ -416,7 +435,8 @@ func TestAgent_cachedFabricInfo_Refresh(t *testing.T) { } } - err := cfi.Refresh(test.Context(t)) + release, err := cfi.Refresh(test.Context(t)) + release() test.CmpErr(t, tc.expErr, err) diff --git a/src/control/cmd/daos_agent/mgmt_rpc.go b/src/control/cmd/daos_agent/mgmt_rpc.go index 75dc337e313..51cb80984c8 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc.go +++ b/src/control/cmd/daos_agent/mgmt_rpc.go @@ -8,7 +8,6 @@ package main import ( "net" - "sync" "github.com/pkg/errors" "golang.org/x/net/context" @@ -22,6 +21,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -34,16 +34,13 @@ import ( // Management Service proxy, handling dRPCs sent by libdaos by forwarding them // to MS. type mgmtModule struct { - attachInfoMutex sync.RWMutex - fabricMutex sync.RWMutex - log logging.Logger sys string ctlInvoker control.Invoker cache *InfoCache monitor *procMon cliMetricsSrc *promexp.ClientSource - useDefaultNUMA bool + useDefaultNUMA atm.Bool numaGetter hardware.ProcessNUMAProvider } @@ -161,14 +158,14 @@ func (mod *mgmtModule) handleGetAttachInfo(ctx context.Context, reqb []byte, pid } func (mod *mgmtModule) getNUMANode(ctx context.Context, pid int32) (uint, error) { - if mod.useDefaultNUMA { + if mod.useDefaultNUMA.IsTrue() { return 0, nil } numaNode, err := mod.numaGetter.GetNUMANodeIDForPID(ctx, pid) if errors.Is(err, hardware.ErrNoNUMANodes) { mod.log.Debug("system is not NUMA-aware") - mod.useDefaultNUMA = true + mod.useDefaultNUMA.SetTrue() return 0, nil } else if err != nil { return 0, errors.Wrapf(err, "failed to get NUMA node ID for pid %d", pid) diff --git a/src/control/cmd/daos_agent/mgmt_rpc_test.go b/src/control/cmd/daos_agent/mgmt_rpc_test.go index 59fcb507a81..965f3dc47c9 100644 --- a/src/control/cmd/daos_agent/mgmt_rpc_test.go +++ b/src/control/cmd/daos_agent/mgmt_rpc_test.go @@ -27,6 +27,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/fault" "github.com/daos-stack/daos/src/control/fault/code" + "github.com/daos-stack/daos/src/control/lib/atm" "github.com/daos-stack/daos/src/control/lib/control" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" @@ -334,7 +335,7 @@ func TestAgent_mgmtModule_getNUMANode(t *testing.T) { mod := &mgmtModule{ log: log, - useDefaultNUMA: tc.useDefaultNUMA, + useDefaultNUMA: atm.NewBool(tc.useDefaultNUMA), numaGetter: tc.numaGetter, } diff --git a/src/control/cmd/daos_agent/security_rpc.go b/src/control/cmd/daos_agent/security_rpc.go index 8843f18ed6f..7fc6b942cd9 100644 --- a/src/control/cmd/daos_agent/security_rpc.go +++ b/src/control/cmd/daos_agent/security_rpc.go @@ -8,12 +8,15 @@ package main import ( "context" + "fmt" "net" "os/user" + "time" "github.com/pkg/errors" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" @@ -21,7 +24,24 @@ import ( ) type ( - credSignerFn func(*auth.CredentialRequest) (*auth.Credential, error) + // credSignerFn defines the function signature for signing credentials. + credSignerFn func(context.Context, *auth.CredentialRequest) (*auth.Credential, error) + + // credentialCache implements a cache for signed credentials. + credentialCache struct { + log logging.Logger + cache *cache.ItemCache + credLifetime time.Duration + cacheMissFn credSignerFn + } + + // cachedCredential wraps a cached credential and implements the cache.ExpirableItem interface. + cachedCredential struct { + cacheItem + key string + expiredAt time.Time + cred *auth.Credential + } // securityConfig defines configuration parameters for SecurityModule. securityConfig struct { @@ -33,32 +53,106 @@ type ( SecurityModule struct { log logging.Logger signCredential credSignerFn + credCache *credentialCache config *securityConfig } ) -// NewSecurityModule creates a new module with the given initialized TransportConfig +// NewSecurityModule creates a new module with the given initialized TransportConfig. func NewSecurityModule(log logging.Logger, cfg *securityConfig) *SecurityModule { + var credCache *credentialCache + credSigner := auth.GetSignedCredential + if cfg.credentials.CacheLifetime > 0 { + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: cfg.credentials.CacheLifetime, + cacheMissFn: auth.GetSignedCredential, + } + credSigner = cache.getSignedCredential + log.Noticef("credential cache enabled (entry lifetime: %s)", cfg.credentials.CacheLifetime) + } + return &SecurityModule{ log: log, - signCredential: auth.GetSignedCredential, + signCredential: credSigner, + credCache: credCache, config: cfg, } } +func credReqKey(req *auth.CredentialRequest) string { + return fmt.Sprintf("%d:%d:%s", req.DomainInfo.Uid(), req.DomainInfo.Gid(), req.DomainInfo.Ctx()) +} + +func (cred *cachedCredential) Key() string { + if cred == nil { + return "" + } + + return cred.key +} + +func (cred *cachedCredential) IsExpired() bool { + if cred == nil || cred.cred == nil || cred.expiredAt.IsZero() { + return true + } + + return time.Now().After(cred.expiredAt) +} + +func (cc *credentialCache) getSignedCredential(ctx context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + key := credReqKey(req) + + createItem := func() (cache.Item, error) { + cc.log.Tracef("cache miss for %s", key) + cred, err := cc.cacheMissFn(ctx, req) + if err != nil { + return nil, err + } + cc.log.Tracef("getting credential for %s", key) + return newCachedCredential(key, cred, cc.credLifetime) + } + + item, release, err := cc.cache.GetOrCreate(ctx, key, createItem) + if err != nil { + return nil, errors.Wrap(err, "getting cached credential from cache") + } + defer release() + + cachedCred, ok := item.(*cachedCredential) + if !ok { + return nil, errors.New("invalid cached credential") + } + + return cachedCred.cred, nil +} + +func newCachedCredential(key string, cred *auth.Credential, lifetime time.Duration) (*cachedCredential, error) { + if cred == nil { + return nil, errors.New("credential is nil") + } + + return &cachedCredential{ + key: key, + cred: cred, + expiredAt: time.Now().Add(lifetime), + }, nil +} + // HandleCall is the handler for calls to the SecurityModule -func (m *SecurityModule) HandleCall(_ context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { +func (m *SecurityModule) HandleCall(ctx context.Context, session *drpc.Session, method drpc.Method, body []byte) ([]byte, error) { if method != drpc.MethodRequestCredentials { return nil, drpc.UnknownMethodFailure() } - return m.getCredential(session) + return m.getCredential(ctx, session) } // getCredentials generates a signed user credential based on the data attached to // the Unix Domain Socket. -func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { +func (m *SecurityModule) getCredential(ctx context.Context, session *drpc.Session) ([]byte, error) { if session == nil { return nil, drpc.NewFailureWithMessage("session is nil") } @@ -82,7 +176,7 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } req := auth.NewCredentialRequest(info, signingKey) - cred, err := m.signCredential(req) + cred, err := m.signCredential(ctx, req) if err != nil { if err := func() error { if !errors.Is(err, user.UnknownUserIdError(info.Uid())) { @@ -95,7 +189,7 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } req.WithUserAndGroup(mu.User, mu.Group, mu.Groups...) - cred, err = m.signCredential(req) + cred, err = m.signCredential(ctx, req) if err != nil { return err } @@ -107,7 +201,6 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { } } - m.log.Tracef("%s: successfully signed credential", info) resp := &auth.GetCredResp{Cred: cred} return drpc.Marshal(resp) } diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 08ee4755b27..52e298f0e88 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -7,17 +7,22 @@ package main import ( + "context" "errors" "net" "os/user" + "syscall" "testing" + "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/drpc" + "github.com/daos-stack/daos/src/control/lib/cache" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" @@ -28,7 +33,7 @@ func TestAgentSecurityModule_ID(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) test.AssertEqual(t, mod.ID(), drpc.ModuleSecurityAgent, "wrong drpc module") } @@ -49,7 +54,7 @@ func TestAgentSecurityModule_BadMethod(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, nil) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) method, err := mod.ID().GetMethod(-1) if method != nil { t.Errorf("Expected no method, got %+v", method) @@ -166,7 +171,8 @@ func TestAgentSecurityModule_RequestCreds_BadConfig(t *testing.T) { // Empty TransportConfig is incomplete mod := NewSecurityModule(log, &securityConfig{ - transport: &security.TransportConfig{}, + transport: &security.TransportConfig{}, + credentials: &security.CredentialConfig{}, }) respBytes, err := callRequestCreds(mod, t, log, conn) @@ -186,7 +192,7 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { defer cleanup() mod := NewSecurityModule(log, defaultTestSecurityConfig()) - mod.signCredential = func(_ *auth.CredentialRequest) (*auth.Credential, error) { + mod.signCredential = func(_ context.Context, _ *auth.CredentialRequest) (*auth.Credential, error) { return nil, errors.New("LookupUserID") } respBytes, err := callRequestCreds(mod, t, log, conn) @@ -198,11 +204,12 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { expectCredResp(t, respBytes, int32(daos.MiscError), false) } +type signCredentialResp struct { + cred *auth.Credential + err error +} + func TestAgent_SecurityRPC_getCredential(t *testing.T) { - type response struct { - cred *auth.Credential - err error - } testCred := &auth.Credential{ Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("test-token")}, Origin: "test-origin", @@ -227,13 +234,13 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { for name, tc := range map[string]struct { secCfg *securityConfig - responses []response + responses []signCredentialResp expBytes []byte expErr error }{ "lookup miss": { secCfg: defaultTestSecurityConfig(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -251,7 +258,7 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { } return cfg }(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -273,7 +280,7 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { } return cfg }(), - responses: []response{ + responses: []signCredentialResp{ { cred: nil, err: user.UnknownUserIdError(unix.Getuid()), @@ -295,9 +302,9 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { defer cleanup() mod := NewSecurityModule(log, tc.secCfg) - mod.signCredential = func() func(req *auth.CredentialRequest) (*auth.Credential, error) { + mod.signCredential = func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { var idx int - return func(req *auth.CredentialRequest) (*auth.Credential, error) { + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { defer func() { if idx < len(tc.responses)-1 { idx++ @@ -319,3 +326,96 @@ func TestAgent_SecurityRPC_getCredential(t *testing.T) { }) } } + +func TestAgent_SecurityCachedCredentials(t *testing.T) { + cred0 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group2")}, + Origin: "test-origin", + } + cred1 := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("user,group1,group3")}, + Origin: "test-origin", + } + + for name, tc := range map[string]struct { + lifetime time.Duration + req *auth.CredentialRequest + responses []signCredentialResp + exp *auth.Credential + }{ + "cache hit": { + lifetime: time.Second, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred0, + }, + "expired entry": { + lifetime: time.Nanosecond, + req: &auth.CredentialRequest{ + DomainInfo: security.InitDomainInfo(&syscall.Ucred{Uid: 1234, Gid: 5678}, ""), + }, + responses: []signCredentialResp{ + { + cred: cred0, + err: nil, + }, + { + cred: cred1, + err: nil, + }, + }, + exp: cred1, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + cache := &credentialCache{ + log: log, + cache: cache.NewItemCache(log), + credLifetime: tc.lifetime, + cacheMissFn: func() func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + var idx int + return func(_ context.Context, req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }(), + } + + // Prime the cache with a single entry. + _, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Request a second time with the same credentials. + cred, err := cache.getSignedCredential(test.Context(t), tc.req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cmpOpts := cmp.Options{ + protocmp.Transform(), + } + if diff := cmp.Diff(tc.exp, cred, cmpOpts...); diff != "" { + t.Errorf("unexpected credential (-want +got):\n%s", diff) + } + }) + } +} diff --git a/src/control/lib/cache/cache.go b/src/control/lib/cache/cache.go index c31d82b0a61..9ecc123240c 100644 --- a/src/control/lib/cache/cache.go +++ b/src/control/lib/cache/cache.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -18,20 +18,33 @@ import ( "github.com/daos-stack/daos/src/control/logging" ) -type Item interface { - Lock() - Unlock() - Key() string - Refresh(ctx context.Context) error - NeedsRefresh() bool -} +type ( + // Item defines an interface for a cached item. + Item interface { + sync.Locker + Key() string + } -// ItemCache is a mechanism for caching Items to keys. -type ItemCache struct { - log logging.Logger - mutex sync.RWMutex - items map[string]Item -} + // ExpirableItem is an Item that defines its own expiration criteria. + ExpirableItem interface { + Item + IsExpired() bool + } + + // RefreshableItem is an Item that defines its own refresh criteria and method. + RefreshableItem interface { + Item + Refresh(ctx context.Context) (func(), error) + RefreshIfNeeded(ctx context.Context) (func(), bool, error) + } + + // ItemCache is a mechanism for caching Items to keys. + ItemCache struct { + log logging.Logger + mutex sync.RWMutex + items map[string]Item + } +) // NewItemCache creates a new ItemCache. func NewItemCache(log logging.Logger) *ItemCache { @@ -119,22 +132,24 @@ func (e *errKeyNotFound) Error() string { return fmt.Sprintf("key %q not found", e.key) } -func noopRelease() {} +// NoopRelease is a no-op function that does nothing, but can +// be safely returned in lieu of a real lock release function. +func NoopRelease() {} // GetOrCreate returns an item from the cache if it exists, otherwise it creates // the item using the given function and caches it. The item must be released // by the caller when it is safe to be modified. func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCreateFunc) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } if missFn == nil { - return nil, noopRelease, errors.Errorf("item create function is required") + return nil, NoopRelease, errors.Errorf("item create function is required") } ic.mutex.Lock() @@ -145,33 +160,39 @@ func (ic *ItemCache) GetOrCreate(ctx context.Context, key string, missFn ItemCre ic.log.Debugf("failed to get item for key %q: %s", key, err.Error()) item, err = missFn() if err != nil { - return nil, noopRelease, errors.Wrapf(err, "create item for %q", key) + return nil, NoopRelease, errors.Wrapf(err, "create item for %q", key) } ic.log.Debugf("created item for key %q", key) ic.set(item) } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } // Get returns an item from the cache if it exists, otherwise it returns an // error. The item must be released by the caller when it is safe to be modified. func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) { if ic == nil { - return nil, noopRelease, errors.New("nil ItemCache") + return nil, NoopRelease, errors.New("nil ItemCache") } if key == "" { - return nil, noopRelease, errors.Errorf("empty string is an invalid key") + return nil, NoopRelease, errors.Errorf("empty string is an invalid key") } ic.mutex.Lock() @@ -179,25 +200,35 @@ func (ic *ItemCache) Get(ctx context.Context, key string) (Item, func(), error) item, err := ic.get(key) if err != nil { - return nil, noopRelease, err + return nil, NoopRelease, err } - item.Lock() - if item.NeedsRefresh() { - if err := item.Refresh(ctx); err != nil { - item.Unlock() - return nil, noopRelease, errors.Wrapf(err, "fetch data for %q", key) + var release func() + if ri, ok := item.(RefreshableItem); ok { + var refreshed bool + release, refreshed, err = ri.RefreshIfNeeded(ctx) + if err != nil { + return nil, NoopRelease, errors.Wrapf(err, "fetch data for %q", key) } - ic.log.Debugf("refreshed item %q", key) + if refreshed { + ic.log.Debugf("refreshed item %q", key) + } + } else { + item.Lock() + release = item.Unlock } - return item, item.Unlock, nil + return item, release, nil } func (ic *ItemCache) get(key string) (Item, error) { - val, ok := ic.items[key] + item, ok := ic.items[key] if ok { - return val, nil + if ei, ok := item.(ExpirableItem); ok && ei.IsExpired() { + delete(ic.items, key) + } else { + return item, nil + } } return nil, &errKeyNotFound{key: key} } @@ -229,10 +260,12 @@ func (ic *ItemCache) refreshItem(ctx context.Context, key string) error { return err } - item.Lock() - defer item.Unlock() - if err := item.Refresh(ctx); err != nil { - return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + if ri, ok := item.(RefreshableItem); ok { + release, err := ri.Refresh(ctx) + if err != nil { + return errors.Wrapf(err, "failed to refresh cached item %q", item.Key()) + } + release() } return nil diff --git a/src/control/lib/cache/cache_test.go b/src/control/lib/cache/cache_test.go index 6d1aeb465ea..0a34e8bf606 100644 --- a/src/control/lib/cache/cache_test.go +++ b/src/control/lib/cache/cache_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2023 Intel Corporation. +// (C) Copyright 2023-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -10,11 +10,12 @@ import ( "context" "testing" - "github.com/daos-stack/daos/src/control/common/test" - "github.com/daos-stack/daos/src/control/logging" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/pkg/errors" + + "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/logging" ) func TestCache_NewItemCache(t *testing.T) { @@ -51,11 +52,18 @@ func (m *mockItem) Key() string { return m.ItemKey } -func (m *mockItem) Refresh(ctx context.Context) error { - return m.RefreshErr +func (m *mockItem) Refresh(ctx context.Context) (func(), error) { + return NoopRelease, m.RefreshErr +} + +func (m *mockItem) RefreshIfNeeded(ctx context.Context) (func(), bool, error) { + if m.needsRefresh() { + return NoopRelease, true, m.RefreshErr + } + return NoopRelease, false, nil } -func (m *mockItem) NeedsRefresh() bool { +func (m *mockItem) needsRefresh() bool { return m.NeedsRefreshResult } diff --git a/src/control/security/auth/auth_sys.go b/src/control/security/auth/auth_sys.go index e9255f407d3..fcedcb95bd5 100644 --- a/src/control/security/auth/auth_sys.go +++ b/src/control/security/auth/auth_sys.go @@ -8,6 +8,7 @@ package auth import ( "bytes" + "context" "crypto" "os" "os/user" @@ -17,6 +18,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/proto" + "github.com/daos-stack/daos/src/control/logging" "github.com/daos-stack/daos/src/control/security" ) @@ -213,7 +215,7 @@ func (r *CredentialRequest) WithUserAndGroup(userStr, groupStr string, groupStrs // GetSignedCredential returns a credential based on the provided domain info and // signing key. -func GetSignedCredential(req *CredentialRequest) (*Credential, error) { +func GetSignedCredential(ctx context.Context, req *CredentialRequest) (*Credential, error) { if req == nil { return nil, errors.Errorf("%T is nil", req) } @@ -274,6 +276,7 @@ func GetSignedCredential(req *CredentialRequest) (*Credential, error) { Verifier: &verifierToken, Origin: "agent"} + logging.FromContext(ctx).Tracef("%s: successfully signed credential", req.DomainInfo) return &credential, nil } diff --git a/src/control/security/auth/auth_sys_test.go b/src/control/security/auth/auth_sys_test.go index bbca64fe233..a421d1b067c 100644 --- a/src/control/security/auth/auth_sys_test.go +++ b/src/control/security/auth/auth_sys_test.go @@ -261,7 +261,7 @@ func TestAuth_GetSignedCred(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - cred, gotErr := GetSignedCredential(tc.req) + cred, gotErr := GetSignedCredential(test.Context(t), tc.req) test.CmpErr(t, tc.expErr, gotErr) if tc.expErr != nil { return @@ -277,7 +277,7 @@ func TestAuth_CredentialRequestOverrides(t *testing.T) { req.getHostnameFn = testHostnameFn(nil, "test-host") req.WithUserAndGroup("test-user", "test-group", "test-secondary") - cred, err := GetSignedCredential(req) + cred, err := GetSignedCredential(test.Context(t), req) if err != nil { t.Fatalf("Failed to get credential: %s", err) } diff --git a/src/control/security/config.go b/src/control/security/config.go index 2bda4bf4a00..1913fd89fcd 100644 --- a/src/control/security/config.go +++ b/src/control/security/config.go @@ -94,6 +94,7 @@ func (cm ClientUserMap) Lookup(uid uint32) *MappedClientUser { // CredentialConfig contains configuration details for managing user // credentials. type CredentialConfig struct { + CacheLifetime time.Duration `yaml:"cache_lifetime,omitempty"` ClientUserMap ClientUserMap `yaml:"client_user_map,omitempty"` } diff --git a/utils/config/daos_agent.yml b/utils/config/daos_agent.yml index 3e7e8f1ab95..cab0bd88976 100644 --- a/utils/config/daos_agent.yml +++ b/utils/config/daos_agent.yml @@ -48,7 +48,6 @@ #telemetry_retain: 1m ## Configuration for user credential management. -# #credential_config: # # If the agent should be able to resolve unknown client uids and gids # # (e.g. when running in a container) into ACL principal names, then a @@ -61,10 +60,20 @@ # 1000: # user: ralph # group: stanley - +# +# # Optionally cache generated credentials with the specified cache +# # lifetime. By default, a credential is generated for every client +# # process that connects to a pool. If the credential cache is +# # enabled, then local client processes connecting with stable +# # uid:gid associations may take advantage of the cached credential +# # and reduce some agent overhead. For heavily-loaded client nodes +# # with many frequent (e.g. hundreds per minute) client connections, +# # a lifetime of 1-5 minutes may be a reasonable tradeoff between +# # performance and responsiveness to user/group database updates. +# cred_cache_lifetime: 1m +# ## Configuration for SSL certificates used to secure management traffic # and authenticate/authorize management components. -# #transport_config: # # In order to disable transport security, uncomment and set allow_insecure # # to true. Not recommended for production configurations. @@ -77,6 +86,7 @@ # # Key portion of Agent Certificate # key: /etc/daos/certs/agent.key # + # Use the given directory for creating unix domain sockets # # NOTE: Do not change this when running under systemd control. If it needs to