From cf889dde02ef1d81be09cf6570323953251648f9 Mon Sep 17 00:00:00 2001 From: Michael MacDonald Date: Tue, 11 Jun 2024 08:47:00 -0400 Subject: [PATCH] DAOS-15849 control: Add client uid map to agent config (#14381) Allow daos_agent to optionally handle unresolvable client uids via custom mapping. In deployments where the agent may not have access to the same user namespace as client applications (e.g. in containerized deployments), the client_user_map can provide a fallback mechanism for resolving the client uids to known usernames for the purpose of applying ACL permissions tests. Example agent config: credential_config: client_user_map: default: user: nobody group: nobody 1000: user: joe group: blow Signed-off-by: Michael MacDonald --- src/control/cmd/daos_agent/config.go | 48 +-- src/control/cmd/daos_agent/config_test.go | 30 +- src/control/cmd/daos_agent/security_rpc.go | 71 ++++- .../cmd/daos_agent/security_rpc_test.go | 155 ++++++++- src/control/cmd/daos_agent/start.go | 6 +- src/control/lib/control/pool.go | 21 +- src/control/lib/control/pool_test.go | 5 - src/control/security/auth/auth_sys.go | 273 +++++++++------- src/control/security/auth/auth_sys_test.go | 294 ++++++++++-------- src/control/security/auth/mocks.go | 91 ------ src/control/security/config.go | 65 ++++ src/control/security/config_test.go | 104 ++++++- src/control/security/domain_info.go | 7 +- utils/config/daos_agent.yml | 20 +- 14 files changed, 768 insertions(+), 422 deletions(-) delete mode 100644 src/control/security/auth/mocks.go diff --git a/src/control/cmd/daos_agent/config.go b/src/control/cmd/daos_agent/config.go index cf86df85ac8..1c5ea0d3f5e 100644 --- a/src/control/cmd/daos_agent/config.go +++ b/src/control/cmd/daos_agent/config.go @@ -42,23 +42,24 @@ func (rm refreshMinutes) Duration() time.Duration { // Config defines the agent configuration. type Config struct { - SystemName string `yaml:"name"` - AccessPoints []string `yaml:"access_points"` - ControlPort int `yaml:"port"` - RuntimeDir string `yaml:"runtime_dir"` - LogFile string `yaml:"log_file"` - LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` - TransportConfig *security.TransportConfig `yaml:"transport_config"` - DisableCache bool `yaml:"disable_caching,omitempty"` - CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` - DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` - EvictOnStart bool `yaml:"enable_evict_on_start,omitempty"` - ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` - FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` - ProviderIdx uint // TODO SRS-31: Enable with multiprovider functionality - TelemetryPort int `yaml:"telemetry_port,omitempty"` - TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` - TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` + SystemName string `yaml:"name"` + AccessPoints []string `yaml:"access_points"` + ControlPort int `yaml:"port"` + RuntimeDir string `yaml:"runtime_dir"` + LogFile string `yaml:"log_file"` + LogLevel common.ControlLogLevel `yaml:"control_log_mask,omitempty"` + CredentialConfig *security.CredentialConfig `yaml:"credential_config"` + TransportConfig *security.TransportConfig `yaml:"transport_config"` + DisableCache bool `yaml:"disable_caching,omitempty"` + CacheExpiration refreshMinutes `yaml:"cache_expiration,omitempty"` + DisableAutoEvict bool `yaml:"disable_auto_evict,omitempty"` + EvictOnStart bool `yaml:"enable_evict_on_start,omitempty"` + ExcludeFabricIfaces common.StringSet `yaml:"exclude_fabric_ifaces,omitempty"` + FabricInterfaces []*NUMAFabricConfig `yaml:"fabric_ifaces,omitempty"` + ProviderIdx uint // TODO SRS-31: Enable with multiprovider functionality + TelemetryPort int `yaml:"telemetry_port,omitempty"` + TelemetryEnabled bool `yaml:"telemetry_enabled,omitempty"` + TelemetryRetain time.Duration `yaml:"telemetry_retain,omitempty"` } // TelemetryExportEnabled returns true if client telemetry export is enabled. @@ -113,11 +114,12 @@ func LoadConfig(cfgPath string) (*Config, error) { func DefaultConfig() *Config { localServer := fmt.Sprintf("localhost:%d", build.DefaultControlPort) return &Config{ - SystemName: build.DefaultSystemName, - ControlPort: build.DefaultControlPort, - AccessPoints: []string{localServer}, - RuntimeDir: defaultRuntimeDir, - LogLevel: common.DefaultControlLogLevel, - TransportConfig: security.DefaultAgentTransportConfig(), + SystemName: build.DefaultSystemName, + ControlPort: build.DefaultControlPort, + AccessPoints: []string{localServer}, + RuntimeDir: defaultRuntimeDir, + LogLevel: common.DefaultControlLogLevel, + TransportConfig: security.DefaultAgentTransportConfig(), + CredentialConfig: &security.CredentialConfig{}, } } diff --git a/src/control/cmd/daos_agent/config_test.go b/src/control/cmd/daos_agent/config_test.go index 3a58e2c5616..7e8802be99f 100644 --- a/src/control/cmd/daos_agent/config_test.go +++ b/src/control/cmd/daos_agent/config_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2021-2023 Intel Corporation. +// (C) Copyright 2021-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -46,6 +46,12 @@ control_log_mask: debug disable_caching: true cache_expiration: 30 disable_auto_evict: true +credential_config: + client_user_map: + 1000: + user: frodo + group: baggins + groups: ["ringbearers"] transport_config: allow_insecure: true exclude_fabric_ifaces: ["ib3"] @@ -104,12 +110,13 @@ transport_config: "without optional items": { path: withoutOptCfg, expResult: &Config{ - SystemName: "shire", - AccessPoints: []string{"one:10001", "two:10001"}, - ControlPort: 4242, - RuntimeDir: "/tmp/runtime", - LogFile: "/home/frodo/logfile", - LogLevel: common.DefaultControlLogLevel, + SystemName: "shire", + AccessPoints: []string{"one:10001", "two:10001"}, + ControlPort: 4242, + RuntimeDir: "/tmp/runtime", + LogFile: "/home/frodo/logfile", + LogLevel: common.DefaultControlLogLevel, + CredentialConfig: &security.CredentialConfig{}, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, @@ -132,6 +139,15 @@ transport_config: DisableCache: true, CacheExpiration: refreshMinutes(30 * time.Minute), DisableAutoEvict: true, + CredentialConfig: &security.CredentialConfig{ + ClientUserMap: map[uint32]*security.MappedClientUser{ + 1000: { + User: "frodo", + Group: "baggins", + Groups: []string{"ringbearers"}, + }, + }, + }, TransportConfig: &security.TransportConfig{ AllowInsecure: true, CertificateConfig: DefaultConfig().TransportConfig.CertificateConfig, diff --git a/src/control/cmd/daos_agent/security_rpc.go b/src/control/cmd/daos_agent/security_rpc.go index 906fe53ad8b..8843f18ed6f 100644 --- a/src/control/cmd/daos_agent/security_rpc.go +++ b/src/control/cmd/daos_agent/security_rpc.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -9,6 +9,9 @@ package main import ( "context" "net" + "os/user" + + "github.com/pkg/errors" "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/lib/daos" @@ -17,21 +20,31 @@ import ( "github.com/daos-stack/daos/src/control/security/auth" ) -// SecurityModule is the security drpc module struct -type SecurityModule struct { - log logging.Logger - ext auth.UserExt - config *security.TransportConfig -} +type ( + credSignerFn func(*auth.CredentialRequest) (*auth.Credential, error) + + // securityConfig defines configuration parameters for SecurityModule. + securityConfig struct { + credentials *security.CredentialConfig + transport *security.TransportConfig + } + + // SecurityModule is the security drpc module struct + SecurityModule struct { + log logging.Logger + signCredential credSignerFn + + config *securityConfig + } +) // NewSecurityModule creates a new module with the given initialized TransportConfig -func NewSecurityModule(log logging.Logger, tc *security.TransportConfig) *SecurityModule { - mod := SecurityModule{ - log: log, - config: tc, +func NewSecurityModule(log logging.Logger, cfg *securityConfig) *SecurityModule { + return &SecurityModule{ + log: log, + signCredential: auth.GetSignedCredential, + config: cfg, } - mod.ext = &auth.External{} - return &mod } // HandleCall is the handler for calls to the SecurityModule @@ -46,6 +59,10 @@ func (m *SecurityModule) HandleCall(_ context.Context, session *drpc.Session, me // getCredentials generates a signed user credential based on the data attached to // the Unix Domain Socket. func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { + if session == nil { + return nil, drpc.NewFailureWithMessage("session is nil") + } + uConn, ok := session.Conn.(*net.UnixConn) if !ok { return nil, drpc.NewFailureWithMessage("connection is not a unix socket") @@ -57,17 +74,37 @@ func (m *SecurityModule) getCredential(session *drpc.Session) ([]byte, error) { return m.credRespWithStatus(daos.MiscError) } - signingKey, err := m.config.PrivateKey() + signingKey, err := m.config.transport.PrivateKey() if err != nil { m.log.Errorf("%s: failed to get signing key: %s", info, err) // something is wrong with the cert config return m.credRespWithStatus(daos.BadCert) } - cred, err := auth.AuthSysRequestFromCreds(m.ext, info, signingKey) + req := auth.NewCredentialRequest(info, signingKey) + cred, err := m.signCredential(req) if err != nil { - m.log.Errorf("%s: failed to get AuthSys struct: %s", info, err) - return m.credRespWithStatus(daos.MiscError) + if err := func() error { + if !errors.Is(err, user.UnknownUserIdError(info.Uid())) { + return err + } + + mu := m.config.credentials.ClientUserMap.Lookup(info.Uid()) + if mu == nil { + return user.UnknownUserIdError(info.Uid()) + } + + req.WithUserAndGroup(mu.User, mu.Group, mu.Groups...) + cred, err = m.signCredential(req) + if err != nil { + return err + } + + return nil + }(); err != nil { + m.log.Errorf("%s: failed to get user credential: %s", info, err) + return m.credRespWithStatus(daos.MiscError) + } } m.log.Tracef("%s: successfully signed credential", info) diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 89fb83b8852..048506afa04 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2022 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -9,8 +9,11 @@ package main import ( "errors" "net" + "os/user" "testing" + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" "github.com/daos-stack/daos/src/control/common/test" @@ -35,8 +38,11 @@ func newTestSession(t *testing.T, log logging.Logger, conn net.Conn) *drpc.Sessi return drpc.NewSession(conn, svc) } -func defaultTestTransportConfig() *security.TransportConfig { - return &security.TransportConfig{AllowInsecure: true} +func defaultTestSecurityConfig() *securityConfig { + return &securityConfig{ + transport: &security.TransportConfig{AllowInsecure: true}, + credentials: &security.CredentialConfig{}, + } } func TestAgentSecurityModule_BadMethod(t *testing.T) { @@ -74,6 +80,7 @@ func setupTestUnixConn(t *testing.T) (*net.UnixConn, func()) { } func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { + t.Helper() client := drpc.NewClientConnection(path) if err := client.Connect(test.Context(t)); err != nil { t.Fatalf("Failed to connect: %v", err) @@ -82,6 +89,8 @@ func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { } func expectCredResp(t *testing.T, respBytes []byte, expStatus int32, expCred bool) { + t.Helper() + if respBytes == nil { t.Error("Expected non-nil response") } @@ -104,8 +113,7 @@ func TestAgentSecurityModule_RequestCreds_OK(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) - mod.ext = auth.NewMockExtWithUser("agent-test", 0, 0) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -119,7 +127,7 @@ func TestAgentSecurityModule_RequestCreds_NotUnixConn(t *testing.T) { log, buf := logging.NewTestLogger(t.Name()) defer test.ShowBufferOnFailure(t, buf) - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, &net.TCPConn{}) test.CmpErr(t, drpc.NewFailureWithMessage("connection is not a unix socket"), err) @@ -138,7 +146,7 @@ func TestAgentSecurityModule_RequestCreds_NotConnected(t *testing.T) { defer cleanup() conn.Close() // can't get uid/gid from a closed connection - mod := NewSecurityModule(log, defaultTestTransportConfig()) + mod := NewSecurityModule(log, defaultTestSecurityConfig()) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -157,7 +165,9 @@ func TestAgentSecurityModule_RequestCreds_BadConfig(t *testing.T) { defer cleanup() // Empty TransportConfig is incomplete - mod := NewSecurityModule(log, &security.TransportConfig{}) + mod := NewSecurityModule(log, &securityConfig{ + transport: &security.TransportConfig{}, + }) respBytes, err := callRequestCreds(mod, t, log, conn) if err != nil { @@ -175,10 +185,9 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { conn, cleanup := setupTestUnixConn(t) defer cleanup() - mod := NewSecurityModule(log, defaultTestTransportConfig()) - mod.ext = &auth.MockExt{ - LookupUserIDErr: errors.New("LookupUserID"), - LookupGroupIDErr: errors.New("LookupGroupID"), + mod := NewSecurityModule(log, defaultTestSecurityConfig()) + mod.signCredential = func(_ *auth.CredentialRequest) (*auth.Credential, error) { + return nil, errors.New("LookupUserID") } respBytes, err := callRequestCreds(mod, t, log, conn) @@ -188,3 +197,125 @@ func TestAgentSecurityModule_RequestCreds_BadUid(t *testing.T) { expectCredResp(t, respBytes, int32(daos.MiscError), false) } + +func TestAgent_SecurityRPC_getCredential(t *testing.T) { + type response struct { + cred *auth.Credential + err error + } + testCred := &auth.Credential{ + Token: &auth.Token{Flavor: auth.Flavor_AUTH_SYS, Data: []byte("test-token")}, + Origin: "test-origin", + } + miscErrBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: int32(daos.MiscError), + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal misc error: %v", err) + } + successBytes, err := proto.Marshal( + &auth.GetCredResp{ + Status: 0, + Cred: testCred, + }, + ) + if err != nil { + t.Fatalf("Couldn't marshal success: %v", err) + } + + for name, tc := range map[string]struct { + secCfg *securityConfig + responses []response + expBytes []byte + expErr error + }{ + "lookup miss": { + secCfg: defaultTestSecurityConfig(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + }, + expBytes: miscErrBytes, + }, + "lookup OK": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: testCred, + err: nil, + }, + }, + expBytes: successBytes, + }, + "lookup OK, but retried request fails": { + secCfg: func() *securityConfig { + cfg := defaultTestSecurityConfig() + cfg.credentials.ClientUserMap = security.ClientUserMap{ + uint32(unix.Getuid()): &security.MappedClientUser{ + User: "test-user", + }, + } + return cfg + }(), + responses: []response{ + { + cred: nil, + err: user.UnknownUserIdError(unix.Getuid()), + }, + { + cred: nil, + err: errors.New("oops"), + }, + }, + expBytes: miscErrBytes, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + // Set up a real unix socket so we can make a real connection + conn, cleanup := setupTestUnixConn(t) + defer cleanup() + + mod := NewSecurityModule(log, tc.secCfg) + mod.signCredential = func() credSignerFn { + var idx int + return func(req *auth.CredentialRequest) (*auth.Credential, error) { + defer func() { + if idx < len(tc.responses)-1 { + idx++ + } + }() + t.Logf("returning response %d: %+v", idx, tc.responses[idx]) + return tc.responses[idx].cred, tc.responses[idx].err + } + }() + + respBytes, gotErr := callRequestCreds(mod, t, log, conn) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expBytes, respBytes); diff != "" { + t.Errorf("unexpected response (-want +got):\n%s", diff) + } + }) + } +} diff --git a/src/control/cmd/daos_agent/start.go b/src/control/cmd/daos_agent/start.go index b8b7ead40d4..b47f914ce1e 100644 --- a/src/control/cmd/daos_agent/start.go +++ b/src/control/cmd/daos_agent/start.go @@ -106,7 +106,11 @@ func (cmd *startCmd) Execute(_ []string) error { } drpcRegStart := time.Now() - drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, cmd.cfg.TransportConfig)) + secCfg := &securityConfig{ + transport: cmd.cfg.TransportConfig, + credentials: cmd.cfg.CredentialConfig, + } + drpcServer.RegisterRPCModule(NewSecurityModule(cmd.Logger, secCfg)) mgmtMod := &mgmtModule{ log: cmd.Logger, sys: cmd.cfg.SystemName, diff --git a/src/control/lib/control/pool.go b/src/control/lib/control/pool.go index 3c3cb8a146d..2b5bfc89efd 100644 --- a/src/control/lib/control/pool.go +++ b/src/control/lib/control/pool.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "math" + "os/user" "sort" "strings" "time" @@ -28,7 +29,6 @@ import ( "github.com/daos-stack/daos/src/control/fault/code" "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/ranklist" - "github.com/daos-stack/daos/src/control/security/auth" "github.com/daos-stack/daos/src/control/server/storage" "github.com/daos-stack/daos/src/control/system" ) @@ -50,22 +50,19 @@ func checkUUID(uuidStr string) error { // formatNameGroup converts system names to principals, If user or group is not // provided, the effective user and/or effective group will be used. -func formatNameGroup(ext auth.UserExt, usr string, grp string) (string, string, error) { +func formatNameGroup(usr string, grp string) (string, string, error) { if usr == "" || grp == "" { - eUsr, err := ext.Current() + eUsr, err := user.Current() if err != nil { return "", "", err } if usr == "" { - usr = eUsr.Username() + usr = eUsr.Username } if grp == "" { - gid, err := eUsr.Gid() - if err != nil { - return "", "", err - } - eGrp, err := ext.LookupGroupID(gid) + gid := eUsr.Gid + eGrp, err := user.LookupGroupId(gid) if err != nil { return "", "", err } @@ -202,7 +199,6 @@ type ( // PoolCreateReq contains the parameters for a pool create request. PoolCreateReq struct { poolRequest - userExt auth.UserExt User string UserGroup string ACL *AccessControlList `json:"-"` @@ -287,11 +283,8 @@ func poolCreateReqChkSizes(log debugLogger, getMaxPoolSz maxPoolSizeGetter, req } func poolCreateGenPBReq(ctx context.Context, rpcClient UnaryInvoker, in *PoolCreateReq) (out *mgmtpb.PoolCreateReq, err error) { - if in.userExt == nil { - in.userExt = &auth.External{} - } // ensure pool ownership is set up correctly - in.User, in.UserGroup, err = formatNameGroup(in.userExt, in.User, in.UserGroup) + in.User, in.UserGroup, err = formatNameGroup(in.User, in.UserGroup) if err != nil { return } diff --git a/src/control/lib/control/pool_test.go b/src/control/lib/control/pool_test.go index e145e9eaa8f..ca3a7e68a9d 100644 --- a/src/control/lib/control/pool_test.go +++ b/src/control/lib/control/pool_test.go @@ -28,7 +28,6 @@ import ( "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/ranklist" "github.com/daos-stack/daos/src/control/logging" - "github.com/daos-stack/daos/src/control/security/auth" "github.com/daos-stack/daos/src/control/server/storage" "github.com/daos-stack/daos/src/control/system" ) @@ -511,7 +510,6 @@ func TestControl_poolCreateReqChkSizes(t *testing.T) { } func TestControl_PoolCreate(t *testing.T) { - mockExt := auth.NewMockExtWithUser("poolTest", 0, 0) mockTierRatios := []float64{0.06, 0.94} mockTierBytes := []uint64{humanize.GiByte * 6, humanize.GiByte * 94} validReq := &PoolCreateReq{ @@ -704,9 +702,6 @@ func TestControl_PoolCreate(t *testing.T) { ctx := test.Context(t) mi := NewMockInvoker(log, mic) - if tc.req.userExt == nil { - tc.req.userExt = mockExt - } gotResp, gotErr := PoolCreate(ctx, mi, tc.req) test.CmpErr(t, tc.expErr, gotErr) if tc.expErr != nil { diff --git a/src/control/security/auth/auth_sys.go b/src/control/security/auth/auth_sys.go index a04c0627a0d..a96a2c84662 100644 --- a/src/control/security/auth/auth_sys.go +++ b/src/control/security/auth/auth_sys.go @@ -20,84 +20,6 @@ import ( "github.com/daos-stack/daos/src/control/security" ) -// User is an interface wrapping a representation of a specific system user. -type User interface { - Username() string - GroupIDs() ([]uint32, error) - Gid() (uint32, error) -} - -// UserExt is an interface that wraps system user-related external functions. -type UserExt interface { - Current() (User, error) - LookupUserID(uid uint32) (User, error) - LookupGroupID(gid uint32) (*user.Group, error) -} - -// UserInfo is an exported implementation of the security.User interface. -type UserInfo struct { - Info *user.User -} - -// Username is a wrapper for user.Username. -func (u *UserInfo) Username() string { - return u.Info.Username -} - -// GroupIDs is a wrapper for user.GroupIds. -func (u *UserInfo) GroupIDs() ([]uint32, error) { - gidStrs, err := u.Info.GroupIds() - if err != nil { - return nil, err - } - - gids := []uint32{} - for _, gstr := range gidStrs { - gid, err := strconv.Atoi(gstr) - if err != nil { - continue - } - gids = append(gids, uint32(gid)) - } - - return gids, nil -} - -// Gid is a wrapper for user.Gid. -func (u *UserInfo) Gid() (uint32, error) { - gid, err := strconv.Atoi(u.Info.Gid) - - return uint32(gid), errors.Wrap(err, "user gid") -} - -// External is an exported implementation of the UserExt interface. -type External struct{} - -// LookupUserId is a wrapper for user.LookupId. -func (e *External) LookupUserID(uid uint32) (User, error) { - uidStr := strconv.FormatUint(uint64(uid), 10) - info, err := user.LookupId(uidStr) - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - -// LookupGroupId is a wrapper for user.LookupGroupId. -func (e *External) LookupGroupID(gid uint32) (*user.Group, error) { - gidStr := strconv.FormatUint(uint64(gid), 10) - return user.LookupGroupId(gidStr) -} - -// Current is a wrapper for user.Current. -func (e *External) Current() (User, error) { - info, err := user.Current() - if err != nil { - return nil, err - } - return &UserInfo{Info: info}, nil -} - // VerifierFromToken will return a SHA512 hash of the token data. If a signing key // is passed in it will additionally sign the hash of the token. func VerifierFromToken(key crypto.PublicKey, token *Token) ([]byte, error) { @@ -146,6 +68,10 @@ func sysNameToPrincipalName(name string) string { return name + "@" } +func stripHostName(name string) string { + return strings.Split(name, ".")[0] +} + // GetMachineName returns the "short" hostname by stripping the domain from the FQDN. func GetMachineName() (string, error) { name, err := os.Hostname() @@ -153,60 +79,185 @@ func GetMachineName() (string, error) { return "", err } - return strings.Split(name, ".")[0], nil + return stripHostName(name), nil } -// AuthSysRequestFromCreds takes the domain info credentials gathered -// during the dRPC request and creates an AuthSys security request to obtain -// a handle from the management service. -func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing crypto.PrivateKey) (*Credential, error) { - if creds == nil { - return nil, errors.New("No credentials supplied") +type ( + getHostnameFn func() (string, error) + getUserFn func(string) (*user.User, error) + getGroupFn func(string) (*user.Group, error) + getGroupIdsFn func(*CredentialRequest) ([]string, error) + getGroupNamesFn func(*CredentialRequest) ([]string, error) + + // CredentialRequest defines the request parameters for GetSignedCredential. + CredentialRequest struct { + DomainInfo *security.DomainInfo + SigningKey crypto.PrivateKey + getHostname getHostnameFn + getUser getUserFn + getGroup getGroupFn + getGroupIds getGroupIdsFn + getGroupNames getGroupNamesFn } +) - userInfo, err := ext.LookupUserID(creds.Uid()) +func getGroupIds(req *CredentialRequest) ([]string, error) { + u, err := req.user() if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup uid %v", - creds.Uid()) + return nil, err } + return u.GroupIds() +} - groupInfo, err := ext.LookupGroupID(creds.Gid()) +func getGroupNames(req *CredentialRequest) ([]string, error) { + groupIds, err := req.getGroupIds(req) if err != nil { - return nil, errors.Wrapf(err, "Failed to lookup gid %v", - creds.Gid()) + return nil, err + } + + groupNames := make([]string, len(groupIds)) + for i, gID := range groupIds { + g, err := req.getGroup(gID) + if err != nil { + return nil, err + } + groupNames[i] = g.Name + } + + return groupNames, nil +} + +// NewCredentialRequest returns a properly initialized CredentialRequest. +func NewCredentialRequest(info *security.DomainInfo, key crypto.PrivateKey) *CredentialRequest { + return &CredentialRequest{ + DomainInfo: info, + SigningKey: key, + getHostname: GetMachineName, + getUser: user.LookupId, + getGroup: user.LookupGroupId, + getGroupIds: getGroupIds, + getGroupNames: getGroupNames, } +} - groups, err := userInfo.GroupIDs() +func (r *CredentialRequest) hostname() (string, error) { + if r.getHostname == nil { + return "", errors.New("hostname lookup function not set") + } + + hostname, err := r.getHostname() if err != nil { - return nil, errors.Wrapf(err, "Failed to get group IDs for user %v", - userInfo.Username()) + return "", errors.Wrap(err, "failed to get hostname") + } + return stripHostName(hostname), nil +} + +func (r *CredentialRequest) user() (*user.User, error) { + if r.getUser == nil { + return nil, errors.New("user lookup function not set") } + return r.getUser(strconv.Itoa(int(r.DomainInfo.Uid()))) +} - host, err := GetMachineName() +func (r *CredentialRequest) userPrincipal() (string, error) { + u, err := r.user() if err != nil { - host = "unavailable" + return "", err } + return sysNameToPrincipalName(u.Username), nil +} - var groupList = []string{} +func (r *CredentialRequest) group() (*user.Group, error) { + if r.getGroup == nil { + return nil, errors.New("group lookup function not set") + } + return r.getGroup(strconv.Itoa(int(r.DomainInfo.Gid()))) +} - // Convert groups to gids - for _, gid := range groups { - gInfo, err := ext.LookupGroupID(gid) - if err != nil { - // Skip this group - continue - } - groupList = append(groupList, sysNameToPrincipalName(gInfo.Name)) +func (r *CredentialRequest) groupPrincipal() (string, error) { + g, err := r.group() + if err != nil { + return "", err + } + return sysNameToPrincipalName(g.Name), nil +} + +func (r *CredentialRequest) groupPrincipals() ([]string, error) { + if r.getGroupNames == nil { + return nil, errors.New("groupNames function not set") + } + + groupNames, err := r.getGroupNames(r) + if err != nil { + return nil, errors.Wrap(err, "failed to get group names") + } + + for i, g := range groupNames { + groupNames[i] = sysNameToPrincipalName(g) + } + return groupNames, nil +} + +// WithUserAndGroup provides an override to set the user, group, and optional list +// of group names to be used for the request. +func (r *CredentialRequest) WithUserAndGroup(userStr, groupStr string, groupStrs ...string) { + r.getUser = func(id string) (*user.User, error) { + return &user.User{ + Uid: id, + Gid: id, + Username: userStr, + }, nil + } + r.getGroup = func(id string) (*user.Group, error) { + return &user.Group{ + Gid: id, + Name: groupStr, + }, nil + } + r.getGroupNames = func(*CredentialRequest) ([]string, error) { + return groupStrs, nil + } +} + +// GetSignedCredential returns a credential based on the provided domain info and +// signing key. +func GetSignedCredential(req *CredentialRequest) (*Credential, error) { + if req == nil { + return nil, errors.Errorf("%T is nil", req) + } + + if req.DomainInfo == nil { + return nil, errors.New("No domain info supplied") + } + + hostname, err := req.hostname() + if err != nil { + return nil, err + } + + userPrinc, err := req.userPrincipal() + if err != nil { + return nil, err + } + + groupPrinc, err := req.groupPrincipal() + if err != nil { + return nil, err + } + + groupPrincs, err := req.groupPrincipals() + if err != nil { + return nil, err } // Craft AuthToken sys := Sys{ Stamp: 0, - Machinename: host, - User: sysNameToPrincipalName(userInfo.Username()), - Group: sysNameToPrincipalName(groupInfo.Name), - Groups: groupList, - Secctx: creds.Ctx()} + Machinename: hostname, + User: userPrinc, + Group: groupPrinc, + Groups: groupPrincs, + Secctx: req.DomainInfo.Ctx()} // Marshal our AuthSys token into a byte array tokenBytes, err := proto.Marshal(&sys) @@ -217,7 +268,7 @@ func AuthSysRequestFromCreds(ext UserExt, creds *security.DomainInfo, signing cr Flavor: Flavor_AUTH_SYS, Data: tokenBytes} - verifier, err := VerifierFromToken(signing, &token) + verifier, err := VerifierFromToken(req.SigningKey, &token) if err != nil { return nil, errors.WithMessage(err, "Unable to generate verifier") } diff --git a/src/control/security/auth/auth_sys_test.go b/src/control/security/auth/auth_sys_test.go index fee5c69931b..b55a293af05 100644 --- a/src/control/security/auth/auth_sys_test.go +++ b/src/control/security/auth/auth_sys_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2022 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -8,14 +8,13 @@ package auth import ( "errors" - "fmt" "os/user" "syscall" "testing" "google.golang.org/protobuf/proto" - . "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/common/test" "github.com/daos-stack/daos/src/control/security" ) @@ -29,7 +28,7 @@ func expectAuthSysErrorForToken(t *testing.T, badToken *Token, expectedErrorMess t.Error("Expected a nil AuthSys") } - CmpErr(t, errors.New(expectedErrorMessage), err) + test.CmpErr(t, errors.New(expectedErrorMessage), err) } // AuthSysFromAuthToken tests @@ -82,30 +81,70 @@ func TestAuthSysFromAuthToken_SucceedsWithGoodToken(t *testing.T) { t.Fatal("Got a nil AuthSys") } - AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), + test.AssertEqual(t, authSys.GetStamp(), originalAuthSys.GetStamp(), "Stamps don't match") - AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), + test.AssertEqual(t, authSys.GetMachinename(), originalAuthSys.GetMachinename(), "Machinenames don't match") - AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), + test.AssertEqual(t, authSys.GetUser(), originalAuthSys.GetUser(), "Owners don't match") - AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), + test.AssertEqual(t, authSys.GetGroup(), originalAuthSys.GetGroup(), "Groups don't match") - AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), + test.AssertEqual(t, len(authSys.GetGroups()), len(originalAuthSys.GetGroups()), "Group lists aren't the same length") - AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), + test.AssertEqual(t, authSys.GetSecctx(), originalAuthSys.GetSecctx(), "Secctx don't match") } -// AuthSysRequestFromCreds tests +func testHostnameFn(expErr error, hostname string) getHostnameFn { + return func() (string, error) { + if expErr != nil { + return "", expErr + } + return hostname, nil + } +} + +func testUserFn(expErr error, userName string) getUserFn { + return func(uid string) (*user.User, error) { + if expErr != nil { + return nil, expErr + } + return &user.User{ + Uid: uid, + Gid: uid, + Username: userName, + }, nil + } +} -func TestAuthSysRequestFromCreds_failsIfDomainInfoNil(t *testing.T) { - result, err := AuthSysRequestFromCreds(&MockExt{}, nil, nil) +func testGroupFn(expErr error, groupName string) getGroupFn { + return func(gid string) (*user.Group, error) { + if expErr != nil { + return nil, expErr + } + return &user.Group{ + Gid: gid, + Name: groupName, + }, nil + } +} - if result != nil { - t.Error("Expected a nil request") +func testGroupIdsFn(expErr error, groupNames ...string) getGroupIdsFn { + return func(*CredentialRequest) ([]string, error) { + if expErr != nil { + return nil, expErr + } + return groupNames, nil } +} - ExpectError(t, err, "No credentials supplied", "") +func testGroupNamesFn(expErr error, groupNames ...string) getGroupNamesFn { + return func(*CredentialRequest) ([]string, error) { + if expErr != nil { + return nil, expErr + } + return groupNames, nil + } } func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { @@ -116,44 +155,10 @@ func getTestCreds(uid uint32, gid uint32) *security.DomainInfo { return security.InitDomainInfo(creds, "test") } -func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - gid := uint32(2001) - gids := []uint32{1, 2, 3} - expectedUser := "myuser" - expectedGroup := "mygroup" - expectedGroupList := []string{"group1", "group2", "group3"} - creds := getTestCreds(uid, gid) - - ext.LookupUserIDResult = &MockUser{ - username: expectedUser, - groupIDs: gids, - } - ext.LookupGroupIDResults = []*user.Group{ - { - Name: expectedGroup, - }, - } - - for _, grp := range expectedGroupList { - ext.LookupGroupIDResults = append(ext.LookupGroupIDResults, - &user.Group{ - Name: grp, - }) - } - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if result == nil { - t.Fatal("Credential was nil") - } +func verifyCredential(t *testing.T, cred *Credential, expHostname, expUserPrinc, expGroupPrinc string, expGroupPrincs ...string) { + t.Helper() - token := result.GetToken() + token := cred.GetToken() if token == nil { t.Fatal("Token was nil") } @@ -163,111 +168,128 @@ func TestAuthSysRequestFromCreds_returnsAuthSys(t *testing.T) { } authsys := &Sys{} - err = proto.Unmarshal(token.GetData(), authsys) + err := proto.Unmarshal(token.GetData(), authsys) if err != nil { t.Fatal("Failed to unmarshal token data") } - if authsys.GetUser() != expectedUser+"@" { + if authsys.GetMachinename() != expHostname { + t.Errorf("AuthSys had bad hostname: %v", authsys.GetMachinename()) + } + + if authsys.GetUser() != expUserPrinc { t.Errorf("AuthSys had bad username: %v", authsys.GetUser()) } - if authsys.GetGroup() != expectedGroup+"@" { + if authsys.GetGroup() != expGroupPrinc { t.Errorf("AuthSys had bad group name: %v", authsys.GetGroup()) } for i, group := range authsys.GetGroups() { - if group != expectedGroupList[i]+"@" { + if group != expGroupPrincs[i] { t.Errorf("AuthSys had bad group in list (idx %v): %v", i, group) } } } -func TestAuthSysRequestFromCreds_UidLookupFails(t *testing.T) { - ext := &MockExt{} - uid := uint32(15) - creds := getTestCreds(uid, 500) - - ext.LookupUserIDErr = errors.New("LookupUserID test error") - expectedErr := fmt.Errorf("Failed to lookup uid %v: %v", uid, - ext.LookupUserIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GidLookupFails(t *testing.T) { - ext := &MockExt{} - gid := uint32(205) - creds := getTestCreds(12, gid) - - ext.LookupUserIDResult = &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, - } - - ext.LookupGroupIDErr = errors.New("LookupGroupID test error") - expectedErr := fmt.Errorf("Failed to lookup gid %v: %v", gid, - ext.LookupGroupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } - - if err == nil { - t.Fatal("Expected an error") - } - - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } -} - -func TestAuthSysRequestFromCreds_GroupIDListFails(t *testing.T) { - ext := &MockExt{} - creds := getTestCreds(12, 15) - testUser := &MockUser{ - username: "user@", - groupIDs: []uint32{1, 2}, +func TestAuth_GetSignedCred(t *testing.T) { + testHostname := "test-host.domain.foo" + testUsername := "test-user" + testGroup := "test-group" + testGroupList := []string{"group1", "group2", "group3"} + + expectedHostname := "test-host" + expectedUser := testUsername + "@" + expectedGroup := testGroup + "@" + expectedGroupList := make([]string, len(testGroupList)) + for i, group := range testGroupList { + expectedGroupList[i] = group + "@" } - ext.LookupUserIDResult = testUser - - ext.LookupGroupIDResults = []*user.Group{ - { - Name: "group@", + for name, tc := range map[string]struct { + req *CredentialRequest + expErr error + }{ + "nil request": { + req: nil, + expErr: errors.New("is nil"), + }, + "nil DomainInfo": { + req: &CredentialRequest{}, + expErr: errors.New("No domain info supplied"), + }, + "bad hostname": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostname = testHostnameFn(errors.New("bad hostname"), "") + return req + }(), + expErr: errors.New("bad hostname"), + }, + "bad uid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getUser = testUserFn(errors.New("bad uid"), "") + return req + }(), + expErr: errors.New("bad uid"), + }, + "bad gid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroup = testGroupFn(errors.New("bad gid"), "") + return req + }(), + expErr: errors.New("bad gid"), + }, + "bad group IDs": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupIds = testGroupIdsFn(errors.New("bad group IDs")) + return req + }(), + expErr: errors.New("bad group IDs"), }, + "bad group names": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getGroupNames = testGroupNamesFn(errors.New("bad group names")) + return req + }(), + expErr: errors.New("bad group names"), + }, + "valid": { + req: func() *CredentialRequest { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostname = testHostnameFn(nil, testHostname) + req.getUser = testUserFn(nil, testUsername) + req.getGroup = testGroupFn(nil, testGroup) + req.getGroupNames = testGroupNamesFn(nil, testGroupList...) + return req + }(), + }, + } { + t.Run(name, func(t *testing.T) { + cred, gotErr := GetSignedCredential(tc.req) + test.CmpErr(t, tc.expErr, gotErr) + if tc.expErr != nil { + return + } + + verifyCredential(t, cred, expectedHostname, expectedUser, expectedGroup, expectedGroupList...) + }) } +} - testUser.groupIDErr = errors.New("GroupIDs test error") - expectedErr := fmt.Errorf("Failed to get group IDs for user %v: %v", - testUser.username, - testUser.groupIDErr) - - result, err := AuthSysRequestFromCreds(ext, creds, nil) - - if result != nil { - t.Error("Expected a nil result") - } +func TestAuth_CredentialRequestOverrides(t *testing.T) { + req := NewCredentialRequest(getTestCreds(1, 2), nil) + req.getHostname = testHostnameFn(nil, "test-host") + req.WithUserAndGroup("test-user", "test-group", "test-secondary") - if err == nil { - t.Fatal("Expected an error") + cred, err := GetSignedCredential(req) + if err != nil { + t.Fatalf("Failed to get credential: %s", err) } - if err.Error() != expectedErr.Error() { - t.Errorf("Expected error '%v', got '%v'", expectedErr, err) - } + verifyCredential(t, cred, "test-host", "test-user@", "test-group@", "test-secondary@") } diff --git a/src/control/security/auth/mocks.go b/src/control/security/auth/mocks.go deleted file mode 100644 index 1897d7b9a1d..00000000000 --- a/src/control/security/auth/mocks.go +++ /dev/null @@ -1,91 +0,0 @@ -// -// (C) Copyright 2020-2021 Intel Corporation. -// -// SPDX-License-Identifier: BSD-2-Clause-Patent -// - -package auth - -import ( - "os/user" - "strconv" - - "github.com/pkg/errors" -) - -// Mocks - -type MockUser struct { - username string - uid uint32 - groupIDs []uint32 - groupIDErr error -} - -func (u *MockUser) Username() string { - return u.username -} - -func (u *MockUser) GroupIDs() ([]uint32, error) { - return u.groupIDs, u.groupIDErr -} - -func (u *MockUser) Gid() (uint32, error) { - if len(u.groupIDs) == 0 { - return 0, errors.New("no mock gids to return") - } - return u.groupIDs[0], nil -} - -func NewMockExtWithUser(name string, uid uint32, gids ...uint32) *MockExt { - me := &MockExt{ - LookupUserIDResult: &MockUser{ - uid: uid, - username: name, - groupIDs: gids, - }, - } - - if len(gids) > 0 { - for _, gid := range gids { - me.LookupGroupIDResults = append(me.LookupGroupIDResults, &user.Group{ - Gid: strconv.Itoa(int(gid)), - }) - } - } - - return me -} - -type MockExt struct { - LookupUserIDUid uint32 - LookupUserIDResult User - LookupUserIDErr error - LookupGroupIDGid uint32 - LookupGroupIDResults []*user.Group - LookupGroupIDCallCount uint32 - LookupGroupIDErr error -} - -func (e *MockExt) Current() (User, error) { - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupUserID(uid uint32) (User, error) { - e.LookupUserIDUid = uid - return e.LookupUserIDResult, e.LookupUserIDErr -} - -func (e *MockExt) LookupGroupID(gid uint32) (*user.Group, error) { - e.LookupGroupIDGid = gid - var result *user.Group - if len(e.LookupGroupIDResults) > 0 { - resultIdx := int(e.LookupGroupIDCallCount) - if len(e.LookupGroupIDResults) <= resultIdx { - resultIdx = len(e.LookupGroupIDResults) - 1 - } - result = e.LookupGroupIDResults[resultIdx] - } - e.LookupGroupIDCallCount++ - return result, e.LookupGroupIDErr -} diff --git a/src/control/security/config.go b/src/control/security/config.go index 2f8cda7c25e..3558a7735d1 100644 --- a/src/control/security/config.go +++ b/src/control/security/config.go @@ -13,6 +13,7 @@ import ( "fmt" "io/fs" "os" + "strconv" "time" "github.com/pkg/errors" @@ -32,6 +33,70 @@ const ( defaultInsecure = false ) +// MappedClientUser represents a client user that is mapped to a uid. +type MappedClientUser struct { + User string `yaml:"user"` + Group string `yaml:"group"` + Groups []string `yaml:"groups"` +} + +const ( + defaultMapUser = "default" + defaultMapKey = ^uint32(0) +) + +// ClientUserMap is a map of uids to mapped client users. +type ClientUserMap map[uint32]*MappedClientUser + +func (cm *ClientUserMap) UnmarshalYAML(unmarshal func(interface{}) error) error { + strKeyMap := make(map[string]*MappedClientUser) + if err := unmarshal(&strKeyMap); err != nil { + return err + } + + tmp := make(ClientUserMap) + for strKey, value := range strKeyMap { + var key uint32 + switch strKey { + case defaultMapUser: + key = defaultMapKey + default: + parsedKey, err := strconv.ParseUint(strKey, 10, 32) + if err != nil { + return errors.Wrapf(err, "invalid uid %s", strKey) + } + + switch parsedKey { + case uint64(defaultMapKey): + return errors.Errorf("uid %d is reserved", parsedKey) + default: + key = uint32(parsedKey) + } + } + + tmp[key] = value + } + *cm = tmp + + return nil +} + +// Lookup attempts to resolve the supplied uid to a mapped +// client user. If the uid is not in the map, the default map key +// is returned. If the default map key is not found, nil is returned. +func (cm ClientUserMap) Lookup(uid uint32) *MappedClientUser { + if mu, found := cm[uid]; found { + return mu + } + return cm[defaultMapKey] +} + +// CredentialConfig contains configuration details for managing user +// credentials. +type CredentialConfig struct { + ClientUserMap ClientUserMap `yaml:"client_user_map,omitempty"` +} + // TransportConfig contains all the information on whether or not to use // certificates and their location if their use is specified. type TransportConfig struct { diff --git a/src/control/security/config_test.go b/src/control/security/config_test.go index 72fb1345bf5..4e33f3769b8 100644 --- a/src/control/security/config_test.go +++ b/src/control/security/config_test.go @@ -11,6 +11,7 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "fmt" "io/ioutil" "os" "path/filepath" @@ -18,10 +19,12 @@ import ( "testing" "time" - "github.com/daos-stack/daos/src/control/common/test" - "github.com/daos-stack/daos/src/control/fault" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" + "gopkg.in/yaml.v2" + + "github.com/daos-stack/daos/src/control/common/test" + "github.com/daos-stack/daos/src/control/fault" ) func InsecureTC() *TransportConfig { @@ -409,3 +412,100 @@ func TestSecurity_DefaultTransportConfigs(t *testing.T) { }) } } + +func TestSecurity_ClientUserMap(t *testing.T) { + for name, tc := range map[string]struct { + cfgYaml string + expMap ClientUserMap + expErr error + }{ + "empty": {}, + "defaultKey": { + cfgYaml: fmt.Sprintf(` +%d: + user: whoops +`, defaultMapKey), + expErr: errors.New("reserved"), + }, + "invalid uid (negative)": { + cfgYaml: ` +-1: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid uid (words)": { + cfgYaml: ` +blah: + user: whoops +`, + expErr: errors.New("invalid uid"), + }, + "invalid mapped user": { + cfgYaml: ` +1234: +user: whoops +`, + expErr: errors.New("unmarshal error"), + }, + "good": { + cfgYaml: ` +default: + user: banana + group: rama + groups: [ding, dong] +1234: + user: abc + group: def + groups: [yabba, dabba, doo] +5678: + user: ghi + group: jkl + groups: [mno, pqr, stu] +`, + expMap: ClientUserMap{ + defaultMapKey: { + User: "banana", + Group: "rama", + Groups: []string{"ding", "dong"}, + }, + 1234: { + User: "abc", + Group: "def", + Groups: []string{"yabba", "dabba", "doo"}, + }, + 5678: { + User: "ghi", + Group: "jkl", + Groups: []string{"mno", "pqr", "stu"}, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + var result ClientUserMap + err := yaml.Unmarshal([]byte(tc.cfgYaml), &result) + test.CmpErr(t, tc.expErr, err) + if tc.expErr != nil { + return + } + if diff := cmp.Diff(tc.expMap, result); diff != "" { + t.Fatalf("unexpected ClientUserMap (-want, +got)\n %s", diff) + } + + for uid, exp := range tc.expMap { + gotUser := result.Lookup(uid) + if diff := cmp.Diff(exp.User, gotUser.User); diff != "" { + t.Fatalf("unexpected User (-want, +got)\n %s", diff) + } + } + + if expDefUser, found := tc.expMap[defaultMapKey]; found { + gotDefUser := result.Lookup(1234567) + if diff := cmp.Diff(expDefUser, gotDefUser); diff != "" { + t.Fatalf("unexpected DefaultUser (-want, +got)\n %s", diff) + } + } + }) + } +} diff --git a/src/control/security/domain_info.go b/src/control/security/domain_info.go index 15bfb06dd77..f0812c17f37 100644 --- a/src/control/security/domain_info.go +++ b/src/control/security/domain_info.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2018-2021 Intel Corporation. +// (C) Copyright 2018-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -63,6 +63,11 @@ func (d *DomainInfo) String() string { return outStr } +// Pid returns the PID obtained from the domain socket +func (d *DomainInfo) Pid() int32 { + return d.creds.Pid +} + // Uid returns the UID obtained from the domain socket func (d *DomainInfo) Uid() uint32 { return d.creds.Uid diff --git a/utils/config/daos_agent.yml b/utils/config/daos_agent.yml index 24af15f955e..6268cf98f76 100644 --- a/utils/config/daos_agent.yml +++ b/utils/config/daos_agent.yml @@ -47,7 +47,23 @@ ## default 0 (do not retain telemetry after client exit) #telemetry_retain: 1m -## Transport Credentials Specifying certificates to secure communications +## Configuration for user credential management. +# +#credential_config: +# # If the agent should be able to resolve unknown client uids and gids +# # (e.g. when running in a container) into ACL principal names, then a +# # client user map may be defined. The optional "default" uid is a special +# # case and applies if no other matches are found. +# client_user_map: +# default: +# user: nobody +# group: nobody +# 1000: +# user: ralph +# group: stanley + +## Configuration for SSL certificates used to secure management traffic +# and authenticate/authorize management components. # #transport_config: # # In order to disable transport security, uncomment and set allow_insecure @@ -60,7 +76,7 @@ # cert: /etc/daos/certs/agent.crt # # Key portion of Agent Certificate # key: /etc/daos/certs/agent.key - +# # Use the given directory for creating unix domain sockets # # NOTE: Do not change this when running under systemd control. If it needs to