Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always allow session owners to join own sessions + only list active trackers in WebUI #13660

Merged
merged 19 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) {

// Add the special join-only principal used for joining sessions.
// All users have access to this and join RBAC rules are checked after the connection is established.
allowedLogins = append(allowedLogins, "-teleport-internal-join")
allowedLogins = append(allowedLogins, teleport.SSHSessionJoinPrincipal)

requestedResourcesStr, err := types.ResourceIDsToString(req.checker.GetAllowedResourceIDs())
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type
}

func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool {
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind())
modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind(), tracker.GetHostUser())
modes := evaluator.CanJoin(SessionAccessContext{Username: a.context.User.GetName(), Roles: joinerRoles})

if len(modes) == 0 {
return false
Expand Down
9 changes: 8 additions & 1 deletion lib/auth/session_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ type SessionAccessEvaluator struct {
kind types.SessionKind
policySets []*types.SessionTrackerPolicySet
isModerated bool
owner string
}

// NewSessionAccessEvaluator creates a new session access evaluator for a given session kind
// and a set of roles attached to the host user.
func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind) SessionAccessEvaluator {
func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind, owner string) SessionAccessEvaluator {
e := SessionAccessEvaluator{
kind: kind,
policySets: policySets,
owner: owner,
}

for _, policySet := range policySets {
Expand Down Expand Up @@ -189,6 +191,11 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.Sess
return preAccessControlsModes(e.kind)
}

// Session owners can always join their own sessions.
if user.Username == e.owner {
return []types.SessionParticipantMode{types.SessionPeerMode, types.SessionModeratorMode, types.SessionObserverMode}
}

var modes []types.SessionParticipantMode

// Loop over every allow policy attached the participant and check it's applicability.
Expand Down
26 changes: 24 additions & 2 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type startTestCase struct {
host []types.Role
sessionKinds []types.SessionKind
participants []SessionAccessContext
owner string
expected []bool
}

Expand Down Expand Up @@ -185,7 +186,7 @@ func TestSessionAccessStart(t *testing.T) {
}

for i, kind := range testCase.sessionKinds {
evaluator := NewSessionAccessEvaluator(policies, kind)
evaluator := NewSessionAccessEvaluator(policies, kind, testCase.owner)
result, _, err := evaluator.FulfilledFor(testCase.participants)
require.NoError(t, err)
require.Equal(t, testCase.expected[i], result)
Expand All @@ -199,6 +200,7 @@ type joinTestCase struct {
host types.Role
sessionKinds []types.SessionKind
participant SessionAccessContext
owner string
expected []bool
}

Expand Down Expand Up @@ -250,6 +252,25 @@ func successGlobJoinTestCase(t *testing.T) joinTestCase {
}
}

func successSameUserJoinTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV5{})
require.NoError(t, err)
participantRole, err := types.NewRole("participant", types.RoleSpecV5{})
require.NoError(t, err)

return joinTestCase{
name: "successSameUser",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "john",
Roles: []types.Role{participantRole},
},
owner: "john",
expected: []bool{true},
}
}

func failRoleJoinTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV5{})
require.NoError(t, err)
Expand Down Expand Up @@ -314,6 +335,7 @@ func TestSessionAccessJoin(t *testing.T) {
testCases := []joinTestCase{
successJoinTestCase(t),
successGlobJoinTestCase(t),
successSameUserJoinTestCase(t),
failRoleJoinTestCase(t),
failKindJoinTestCase(t),
versionDefaultJoinTestCase(t),
Expand All @@ -323,7 +345,7 @@ func TestSessionAccessJoin(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
for i, kind := range testCase.sessionKinds {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind)
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind, testCase.owner)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected[i], len(result) > 0)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ func (f *Forwarder) execNonInteractive(ctx *authContext, w http.ResponseWriter,
policySets = append(policySets, &policySet)
}

authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind)
authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName())
canStart, _, err := authorizer.FulfilledFor(nil)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func newSession(ctx authContext, forwarder *Forwarder, req *http.Request, params
}

q := req.URL.Query()
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind)
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName())

io := srv.NewTermManager()

Expand Down
10 changes: 9 additions & 1 deletion lib/services/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func NewImplicitRole() types.Role {
//
// Used in tests only.
func RoleForUser(u types.User) types.Role {
role, _ := types.NewRoleV3(RoleNameForUser(u.GetName()), types.RoleSpecV5{
role, _ := types.NewRole(RoleNameForUser(u.GetName()), types.RoleSpecV5{
Options: types.RoleOptions{
CertificateFormat: constants.CertificateFormatStandard,
MaxSessionTTL: types.NewDuration(defaults.MaxCertDuration),
Expand Down Expand Up @@ -150,6 +150,14 @@ func RoleForUser(u types.User) types.Role {
types.NewRule(types.KindLock, RW()),
types.NewRule(types.KindToken, RW()),
},
JoinSessions: []*types.SessionJoinPolicy{
{
Name: "foo",
Roles: []string{"*"},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionPeerMode)},
},
},
},
})
return role
Expand Down
9 changes: 9 additions & 0 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,15 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request,
// SSH will send them anyway but it seems fine to silently drop them.
case sshutils.SubsystemRequest:
return s.handleSubsystem(ctx, ch, req, scx)
case sshutils.AgentForwardRequest:
// to maintain interoperability with OpenSSH, agent forwarding requests
// should never fail, all errors should be logged and we should continue
// processing requests.
err := s.handleAgentForward(ch, req, scx)
if err != nil {
s.log.Debug(err)
}
return nil
default:
return trace.AccessDenied("attempted %v request in join-only mode", req.Type)
}
Expand Down
16 changes: 16 additions & 0 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,22 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request,
// SSH will send them anyway but it seems fine to silently drop them.
case sshutils.SubsystemRequest:
return s.handleSubsystem(ctx, ch, req, serverContext)
case sshutils.AgentForwardRequest:
// This happens when SSH client has agent forwarding enabled, in this case
// client sends a special request, in return SSH server opens new channel
// that uses SSH protocol for agent drafted here:
// https://tools.ietf.org/html/draft-ietf-secsh-agent-02
// the open ssh proto spec that we implement is here:
// http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent

// to maintain interoperability with OpenSSH, agent forwarding requests
// should never fail, all errors should be logged and we should continue
// processing requests.
err := s.handleAgentForwardNode(req, serverContext)
if err != nil {
log.Warn(err)
}
return nil
default:
return trace.AccessDenied("attempted %v request in join-only mode", req.Type)
}
Expand Down
28 changes: 12 additions & 16 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
if err != nil {
return nil, trace.Wrap(err)
}
if existing.Login != rsess.Login {
if existing.Login != rsess.Login && rsess.Login != teleport.SSHSessionJoinPrincipal {
return nil, trace.AccessDenied(
"can't switch users from %v to %v for session %v",
rsess.Login, existing.Login, id)
Expand All @@ -570,7 +570,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
stopC: make(chan struct{}),
startTime: startTime,
serverCtx: ctx.srv.Context(),
access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind),
access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, ctx.Identity.TeleportUser),
scx: ctx,
presenceEnabled: ctx.Identity.Certificate.Extensions[teleport.CertExtensionMFAVerified] != "",
io: NewTermManager(),
Expand Down Expand Up @@ -1544,7 +1544,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
s.mu.Lock()
defer s.mu.Unlock()

if s.login != p.login {
if s.login != p.login && p.login != teleport.SSHSessionJoinPrincipal {
return trace.AccessDenied(
"can't switch users from %v to %v for session %v",
s.login, p.login, s.id)
Expand Down Expand Up @@ -1637,7 +1637,8 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionParticipantMode) (*party, error) {
if ctx.Identity.TeleportUser != s.initiator {
accessContext := auth.SessionAccessContext{
Roles: ctx.Identity.AccessChecker.Roles(),
Username: ctx.Identity.TeleportUser,
Roles: ctx.Identity.AccessChecker.Roles(),
}

modes := s.access.CanJoin(accessContext)
Expand Down Expand Up @@ -1761,18 +1762,13 @@ func (p *party) closeUnderSessionLock() {
// on an interval until the session tracker is closed.
func (s *session) trackSession(teleportUser string, policySet []*types.SessionTrackerPolicySet) error {
trackerSpec := types.SessionTrackerSpecV1{
SessionID: s.id.String(),
Kind: string(types.SSHSessionKind),
State: types.SessionState_SessionStatePending,
Hostname: s.registry.Srv.GetInfo().GetHostname(),
Address: s.scx.srv.ID(),
ClusterName: s.scx.ClusterName,
Login: s.login,
Participants: []types.Participant{{
ID: teleportUser,
User: teleportUser,
LastActive: s.registry.clock.Now(),
}},
SessionID: s.id.String(),
Kind: string(types.SSHSessionKind),
State: types.SessionState_SessionStatePending,
Hostname: s.registry.Srv.GetInfo().GetHostname(),
Address: s.scx.srv.ID(),
ClusterName: s.scx.ClusterName,
Login: s.login,
HostUser: teleportUser,
Reason: s.scx.env[teleport.EnvSSHSessionReason],
HostPolicies: policySet,
Expand Down
6 changes: 2 additions & 4 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1970,8 +1970,6 @@ func (h *Handler) clusterNodesGet(w http.ResponseWriter, r *http.Request, p http
//
// {"server_id": "uuid", "login": "admin", "term": {"h": 120, "w": 100}, "sid": "123"}
//
// Session id can be empty
//
// Successful response is a websocket stream that allows read write to the server
//
func (h *Handler) siteNodeConnect(
Expand Down Expand Up @@ -2130,7 +2128,7 @@ func (h *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http

sessions := make([]session.Session, 0, len(trackers))
for _, tracker := range trackers {
if tracker.GetSessionKind() == types.SSHSessionKind {
if tracker.GetSessionKind() == types.SSHSessionKind && tracker.GetState() != types.SessionState_SessionStateTerminated {
sessions = append(sessions, trackerToLegacySession(tracker, p.ByName("site")))
}
}
Expand Down Expand Up @@ -2163,7 +2161,7 @@ func (h *Handler) siteSessionGet(w http.ResponseWriter, r *http.Request, p httpr
return nil, trace.Wrap(err)
}

if tracker.GetSessionKind() != types.SSHSessionKind {
if tracker.GetSessionKind() != types.SSHSessionKind || tracker.GetState() == types.SessionState_SessionStateTerminated {
return nil, trace.NotFound("session %v not found", sessionID)
}

Expand Down
4 changes: 4 additions & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3312,6 +3312,10 @@ func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int, p b
return []events.EventFields{}, nil
}

func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) {
return nil, trace.NotFound("foo")
}

func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...session.ID) (*websocket.Conn, error) {
var sessionID session.ID
if len(opts) == 0 {
Expand Down
23 changes: 22 additions & 1 deletion lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type TerminalRequest struct {
type AuthProvider interface {
GetNodes(ctx context.Context, namespace string) ([]types.Server, error)
GetSessionEvents(namespace string, sid session.ID, after int, includePrintEvents bool) ([]events.EventFields, error)
GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error)
}

// NewTerminal creates a web-based terminal based on WebSockets and returns a
Expand Down Expand Up @@ -116,6 +117,17 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv
return nil, trace.BadParameter("invalid server name %q: %v", req.Server, err)
}

var join bool
_, err = authProvider.GetSessionTracker(ctx, string(req.SessionID))
switch {
case trace.IsNotFound(err):
join = false
case err != nil:
return nil, trace.Wrap(err)
default:
join = true
}

return &TerminalHandler{
log: logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentWebsocket,
Expand All @@ -129,6 +141,7 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv
encoder: unicode.UTF8.NewEncoder(),
decoder: unicode.UTF8.NewDecoder(),
wsLock: &sync.Mutex{},
join: join,
}, nil
}

Expand Down Expand Up @@ -178,6 +191,9 @@ type TerminalHandler struct {
closeOnce sync.Once

wsLock *sync.Mutex

// join is set if we're joining an existing session
join bool
}

// Serve builds a connect to the remote node and then pumps back two types of
Expand Down Expand Up @@ -301,8 +317,13 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie
// communicate over the websocket.
stream := t.asTerminalStream(ws)

if t.join {
clientConfig.HostLogin = teleport.SSHSessionJoinPrincipal
} else {
clientConfig.HostLogin = t.params.Login
}

clientConfig.ForwardAgent = client.ForwardAgentLocal
clientConfig.HostLogin = t.params.Login
clientConfig.Namespace = t.params.Namespace
clientConfig.Stdout = stream
clientConfig.Stderr = stream
Expand Down