diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 451db895d04cd..37e6d9b3d3f3c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1089,7 +1089,7 @@ func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) { // Add the special join-only principal used for joining sessions. // All users have access to this and join RBAC rules are checked after the connection is established. - allowedLogins = append(allowedLogins, "-teleport-internal-join") + allowedLogins = append(allowedLogins, teleport.SSHSessionJoinPrincipal) requestedResourcesStr, err := types.ResourceIDsToString(req.checker.GetAllowedResourceIDs()) if err != nil { diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 9bc50876f4858..b0a6905f0442e 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -272,8 +272,8 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type } func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool { - evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind()) - modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles}) + evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind(), tracker.GetHostUser()) + modes := evaluator.CanJoin(SessionAccessContext{Username: a.context.User.GetName(), Roles: joinerRoles}) if len(modes) == 0 { return false diff --git a/lib/auth/session_access.go b/lib/auth/session_access.go index 8ad4e2348e5c3..eb42f718d8ef1 100644 --- a/lib/auth/session_access.go +++ b/lib/auth/session_access.go @@ -43,14 +43,16 @@ type SessionAccessEvaluator struct { kind types.SessionKind policySets []*types.SessionTrackerPolicySet isModerated bool + owner string } // NewSessionAccessEvaluator creates a new session access evaluator for a given session kind // and a set of roles attached to the host user. -func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind) SessionAccessEvaluator { +func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind, owner string) SessionAccessEvaluator { e := SessionAccessEvaluator{ kind: kind, policySets: policySets, + owner: owner, } for _, policySet := range policySets { @@ -189,6 +191,11 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.Sess return preAccessControlsModes(e.kind) } + // Session owners can always join their own sessions. + if user.Username == e.owner { + return []types.SessionParticipantMode{types.SessionPeerMode, types.SessionModeratorMode, types.SessionObserverMode} + } + var modes []types.SessionParticipantMode // Loop over every allow policy attached the participant and check it's applicability. diff --git a/lib/auth/session_access_test.go b/lib/auth/session_access_test.go index c2794c058609c..02ebdd80097a1 100644 --- a/lib/auth/session_access_test.go +++ b/lib/auth/session_access_test.go @@ -28,6 +28,7 @@ type startTestCase struct { host []types.Role sessionKinds []types.SessionKind participants []SessionAccessContext + owner string expected []bool } @@ -185,7 +186,7 @@ func TestSessionAccessStart(t *testing.T) { } for i, kind := range testCase.sessionKinds { - evaluator := NewSessionAccessEvaluator(policies, kind) + evaluator := NewSessionAccessEvaluator(policies, kind, testCase.owner) result, _, err := evaluator.FulfilledFor(testCase.participants) require.NoError(t, err) require.Equal(t, testCase.expected[i], result) @@ -199,6 +200,7 @@ type joinTestCase struct { host types.Role sessionKinds []types.SessionKind participant SessionAccessContext + owner string expected []bool } @@ -250,6 +252,25 @@ func successGlobJoinTestCase(t *testing.T) joinTestCase { } } +func successSameUserJoinTestCase(t *testing.T) joinTestCase { + hostRole, err := types.NewRole("host", types.RoleSpecV5{}) + require.NoError(t, err) + participantRole, err := types.NewRole("participant", types.RoleSpecV5{}) + require.NoError(t, err) + + return joinTestCase{ + name: "successSameUser", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind}, + participant: SessionAccessContext{ + Username: "john", + Roles: []types.Role{participantRole}, + }, + owner: "john", + expected: []bool{true}, + } +} + func failRoleJoinTestCase(t *testing.T) joinTestCase { hostRole, err := types.NewRole("host", types.RoleSpecV5{}) require.NoError(t, err) @@ -314,6 +335,7 @@ func TestSessionAccessJoin(t *testing.T) { testCases := []joinTestCase{ successJoinTestCase(t), successGlobJoinTestCase(t), + successSameUserJoinTestCase(t), failRoleJoinTestCase(t), failKindJoinTestCase(t), versionDefaultJoinTestCase(t), @@ -323,7 +345,7 @@ func TestSessionAccessJoin(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { for i, kind := range testCase.sessionKinds { policy := testCase.host.GetSessionPolicySet() - evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind) + evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind, testCase.owner) result := evaluator.CanJoin(testCase.participant) require.Equal(t, testCase.expected[i], len(result) > 0) } diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index ea0939c52502d..1aa3bb60d8f42 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -1011,7 +1011,7 @@ func (f *Forwarder) execNonInteractive(ctx *authContext, w http.ResponseWriter, policySets = append(policySets, &policySet) } - authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind) + authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName()) canStart, _, err := authorizer.FulfilledFor(nil) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/kube/proxy/sess.go b/lib/kube/proxy/sess.go index 0a405bada0f6a..ebb38f7effc4c 100644 --- a/lib/kube/proxy/sess.go +++ b/lib/kube/proxy/sess.go @@ -319,7 +319,7 @@ func newSession(ctx authContext, forwarder *Forwarder, req *http.Request, params } q := req.URL.Query() - accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind) + accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName()) io := srv.NewTermManager() diff --git a/lib/services/role.go b/lib/services/role.go index d64b9803b079c..251d25b9a1951 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -122,7 +122,7 @@ func NewImplicitRole() types.Role { // // Used in tests only. func RoleForUser(u types.User) types.Role { - role, _ := types.NewRoleV3(RoleNameForUser(u.GetName()), types.RoleSpecV5{ + role, _ := types.NewRole(RoleNameForUser(u.GetName()), types.RoleSpecV5{ Options: types.RoleOptions{ CertificateFormat: constants.CertificateFormatStandard, MaxSessionTTL: types.NewDuration(defaults.MaxCertDuration), @@ -150,6 +150,14 @@ func RoleForUser(u types.User) types.Role { types.NewRule(types.KindLock, RW()), types.NewRule(types.KindToken, RW()), }, + JoinSessions: []*types.SessionJoinPolicy{ + { + Name: "foo", + Roles: []string{"*"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }, + }, }, }) return role diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index cd124ccbfc70b..9e53913b9d2ae 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -928,6 +928,15 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, scx) + case sshutils.AgentForwardRequest: + // to maintain interoperability with OpenSSH, agent forwarding requests + // should never fail, all errors should be logged and we should continue + // processing requests. + err := s.handleAgentForward(ch, req, scx) + if err != nil { + s.log.Debug(err) + } + return nil default: return trace.AccessDenied("attempted %v request in join-only mode", req.Type) } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index dc8059baf9351..9fed71e35d208 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1594,6 +1594,22 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, serverContext) + case sshutils.AgentForwardRequest: + // This happens when SSH client has agent forwarding enabled, in this case + // client sends a special request, in return SSH server opens new channel + // that uses SSH protocol for agent drafted here: + // https://tools.ietf.org/html/draft-ietf-secsh-agent-02 + // the open ssh proto spec that we implement is here: + // http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent + + // to maintain interoperability with OpenSSH, agent forwarding requests + // should never fail, all errors should be logged and we should continue + // processing requests. + err := s.handleAgentForwardNode(req, serverContext) + if err != nil { + log.Warn(err) + } + return nil default: return trace.AccessDenied("attempted %v request in join-only mode", req.Type) } diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 8747a08f72b1f..5b166672ec038 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -544,7 +544,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio if err != nil { return nil, trace.Wrap(err) } - if existing.Login != rsess.Login { + if existing.Login != rsess.Login && rsess.Login != teleport.SSHSessionJoinPrincipal { return nil, trace.AccessDenied( "can't switch users from %v to %v for session %v", rsess.Login, existing.Login, id) @@ -570,7 +570,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio stopC: make(chan struct{}), startTime: startTime, serverCtx: ctx.srv.Context(), - access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind), + access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, ctx.Identity.TeleportUser), scx: ctx, presenceEnabled: ctx.Identity.Certificate.Extensions[teleport.CertExtensionMFAVerified] != "", io: NewTermManager(), @@ -1544,7 +1544,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { s.mu.Lock() defer s.mu.Unlock() - if s.login != p.login { + if s.login != p.login && p.login != teleport.SSHSessionJoinPrincipal { return trace.AccessDenied( "can't switch users from %v to %v for session %v", s.login, p.login, s.id) @@ -1637,7 +1637,8 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionParticipantMode) (*party, error) { if ctx.Identity.TeleportUser != s.initiator { accessContext := auth.SessionAccessContext{ - Roles: ctx.Identity.AccessChecker.Roles(), + Username: ctx.Identity.TeleportUser, + Roles: ctx.Identity.AccessChecker.Roles(), } modes := s.access.CanJoin(accessContext) @@ -1761,18 +1762,13 @@ func (p *party) closeUnderSessionLock() { // on an interval until the session tracker is closed. func (s *session) trackSession(teleportUser string, policySet []*types.SessionTrackerPolicySet) error { trackerSpec := types.SessionTrackerSpecV1{ - SessionID: s.id.String(), - Kind: string(types.SSHSessionKind), - State: types.SessionState_SessionStatePending, - Hostname: s.registry.Srv.GetInfo().GetHostname(), - Address: s.scx.srv.ID(), - ClusterName: s.scx.ClusterName, - Login: s.login, - Participants: []types.Participant{{ - ID: teleportUser, - User: teleportUser, - LastActive: s.registry.clock.Now(), - }}, + SessionID: s.id.String(), + Kind: string(types.SSHSessionKind), + State: types.SessionState_SessionStatePending, + Hostname: s.registry.Srv.GetInfo().GetHostname(), + Address: s.scx.srv.ID(), + ClusterName: s.scx.ClusterName, + Login: s.login, HostUser: teleportUser, Reason: s.scx.env[teleport.EnvSSHSessionReason], HostPolicies: policySet, diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0cef0bf6470ac..f639e328f6021 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1970,8 +1970,6 @@ func (h *Handler) clusterNodesGet(w http.ResponseWriter, r *http.Request, p http // // {"server_id": "uuid", "login": "admin", "term": {"h": 120, "w": 100}, "sid": "123"} // -// Session id can be empty -// // Successful response is a websocket stream that allows read write to the server // func (h *Handler) siteNodeConnect( @@ -2130,7 +2128,7 @@ func (h *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http sessions := make([]session.Session, 0, len(trackers)) for _, tracker := range trackers { - if tracker.GetSessionKind() == types.SSHSessionKind { + if tracker.GetSessionKind() == types.SSHSessionKind && tracker.GetState() != types.SessionState_SessionStateTerminated { sessions = append(sessions, trackerToLegacySession(tracker, p.ByName("site"))) } } @@ -2163,7 +2161,7 @@ func (h *Handler) siteSessionGet(w http.ResponseWriter, r *http.Request, p httpr return nil, trace.Wrap(err) } - if tracker.GetSessionKind() != types.SSHSessionKind { + if tracker.GetSessionKind() != types.SSHSessionKind || tracker.GetState() == types.SessionState_SessionStateTerminated { return nil, trace.NotFound("session %v not found", sessionID) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 0c581a1a2c487..9c7bd44822ba1 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -3312,6 +3312,10 @@ func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int, p b return []events.EventFields{}, nil } +func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { + return nil, trace.NotFound("foo") +} + func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...session.ID) (*websocket.Conn, error) { var sessionID session.ID if len(opts) == 0 { diff --git a/lib/web/terminal.go b/lib/web/terminal.go index de11ce877f338..90319812dff67 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -84,6 +84,7 @@ type TerminalRequest struct { type AuthProvider interface { GetNodes(ctx context.Context, namespace string) ([]types.Server, error) GetSessionEvents(namespace string, sid session.ID, after int, includePrintEvents bool) ([]events.EventFields, error) + GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) } // NewTerminal creates a web-based terminal based on WebSockets and returns a @@ -116,6 +117,17 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv return nil, trace.BadParameter("invalid server name %q: %v", req.Server, err) } + var join bool + _, err = authProvider.GetSessionTracker(ctx, string(req.SessionID)) + switch { + case trace.IsNotFound(err): + join = false + case err != nil: + return nil, trace.Wrap(err) + default: + join = true + } + return &TerminalHandler{ log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentWebsocket, @@ -129,6 +141,7 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv encoder: unicode.UTF8.NewEncoder(), decoder: unicode.UTF8.NewDecoder(), wsLock: &sync.Mutex{}, + join: join, }, nil } @@ -178,6 +191,9 @@ type TerminalHandler struct { closeOnce sync.Once wsLock *sync.Mutex + + // join is set if we're joining an existing session + join bool } // Serve builds a connect to the remote node and then pumps back two types of @@ -301,8 +317,13 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie // communicate over the websocket. stream := t.asTerminalStream(ws) + if t.join { + clientConfig.HostLogin = teleport.SSHSessionJoinPrincipal + } else { + clientConfig.HostLogin = t.params.Login + } + clientConfig.ForwardAgent = client.ForwardAgentLocal - clientConfig.HostLogin = t.params.Login clientConfig.Namespace = t.params.Namespace clientConfig.Stdout = stream clientConfig.Stderr = stream