Skip to content

Commit

Permalink
[v10] Always allow session owners to join own sessions + only list ac…
Browse files Browse the repository at this point in the history
…tive trackers in WebUI (#13764)
  • Loading branch information
xacrimon committed Jul 4, 2022
1 parent 005712b commit 45b486d
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 38 deletions.
4 changes: 2 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type
}

func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool {
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind())
modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind(), tracker.GetHostUser())
modes := evaluator.CanJoin(SessionAccessContext{Username: a.context.User.GetName(), Roles: joinerRoles})

if len(modes) == 0 {
return false
Expand Down
9 changes: 8 additions & 1 deletion lib/auth/session_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ type SessionAccessEvaluator struct {
kind types.SessionKind
policySets []*types.SessionTrackerPolicySet
isModerated bool
owner string
}

// NewSessionAccessEvaluator creates a new session access evaluator for a given session kind
// and a set of roles attached to the host user.
func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind) SessionAccessEvaluator {
func NewSessionAccessEvaluator(policySets []*types.SessionTrackerPolicySet, kind types.SessionKind, owner string) SessionAccessEvaluator {
e := SessionAccessEvaluator{
kind: kind,
policySets: policySets,
owner: owner,
}

for _, policySet := range policySets {
Expand Down Expand Up @@ -188,6 +190,11 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.Sess
return preAccessControlsModes(e.kind)
}

// Session owners can always join their own sessions.
if user.Username == e.owner {
return []types.SessionParticipantMode{types.SessionPeerMode, types.SessionModeratorMode, types.SessionObserverMode}
}

var modes []types.SessionParticipantMode

// Loop over every allow policy attached the participant and check it's applicability.
Expand Down
81 changes: 54 additions & 27 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type startTestCase struct {
host types.Role
sessionKind types.SessionKind
participants []SessionAccessContext
owner string
expected bool
}

Expand Down Expand Up @@ -178,8 +179,11 @@ func TestSessionAccessStart(t *testing.T) {

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
var policies []*types.SessionTrackerPolicySet
policySet := testCase.host.GetSessionPolicySet()
policies = append(policies, &policySet)

evaluator := NewSessionAccessEvaluator(policies, testCase.sessionKind, testCase.owner)
result, _, err := evaluator.FulfilledFor(testCase.participants)
require.NoError(t, err)
require.Equal(t, testCase.expected, result)
Expand All @@ -188,11 +192,12 @@ func TestSessionAccessStart(t *testing.T) {
}

type joinTestCase struct {
name string
host types.Role
sessionKind types.SessionKind
participant SessionAccessContext
expected bool
name string
host types.Role
sessionKinds []types.SessionKind
participant SessionAccessContext
owner string
expected []bool
}

func successJoinTestCase(t *testing.T) joinTestCase {
Expand All @@ -208,14 +213,14 @@ func successJoinTestCase(t *testing.T) joinTestCase {
}})

return joinTestCase{
name: "success",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "success",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: true,
expected: []bool{true},
}
}

Expand All @@ -232,14 +237,33 @@ func successGlobJoinTestCase(t *testing.T) joinTestCase {
}})

return joinTestCase{
name: "success",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "success",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: true,
expected: []bool{true},
}
}

func successSameUserJoinTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV5{})
require.NoError(t, err)
participantRole, err := types.NewRole("participant", types.RoleSpecV5{})
require.NoError(t, err)

return joinTestCase{
name: "successSameUser",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "john",
Roles: []types.Role{participantRole},
},
owner: "john",
expected: []bool{true},
}
}

Expand All @@ -250,14 +274,14 @@ func failRoleJoinTestCase(t *testing.T) joinTestCase {
require.NoError(t, err)

return joinTestCase{
name: "failRole",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failRole",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: false,
expected: []bool{false},
}
}

Expand All @@ -274,31 +298,34 @@ func failKindJoinTestCase(t *testing.T) joinTestCase {
}})

return joinTestCase{
name: "failKind",
host: hostRole,
sessionKind: types.SSHSessionKind,
name: "failKind",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: false,
expected: []bool{false},
}
}

func TestSessionAccessJoin(t *testing.T) {
testCases := []joinTestCase{
successJoinTestCase(t),
successGlobJoinTestCase(t),
successSameUserJoinTestCase(t),
failRoleJoinTestCase(t),
failKindJoinTestCase(t),
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected, len(result) > 0)
for i, kind := range testCase.sessionKinds {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind, testCase.owner)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected[i], len(result) > 0)
}
})
}
}
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ func (f *Forwarder) execNonInteractive(ctx *authContext, w http.ResponseWriter,
policySets = append(policySets, &policySet)
}

authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind)
authorizer := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName())
canStart, _, err := authorizer.FulfilledFor(nil)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func newSession(ctx authContext, forwarder *Forwarder, req *http.Request, params
}

q := req.URL.Query()
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind)
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, ctx.User.GetName())

io := srv.NewTermManager()

Expand Down
8 changes: 8 additions & 0 deletions lib/services/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ func RoleForUser(u types.User) types.Role {
types.NewRule(types.KindLock, RW()),
types.NewRule(types.KindToken, RW()),
},
JoinSessions: []*types.SessionJoinPolicy{
{
Name: "foo",
Roles: []string{"*"},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionPeerMode)},
},
},
},
})
return role
Expand Down
5 changes: 3 additions & 2 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
stopC: make(chan struct{}),
startTime: startTime,
serverCtx: ctx.srv.Context(),
access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind),
access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, ctx.Identity.TeleportUser),
scx: ctx,
presenceEnabled: ctx.Identity.Certificate.Extensions[teleport.CertExtensionMFAVerified] != "",
io: NewTermManager(),
Expand Down Expand Up @@ -1493,7 +1493,8 @@ func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionPar
if ctx.Identity.TeleportUser != s.initiator {
roles := []types.Role(ctx.Identity.RoleSet)
accessContext := auth.SessionAccessContext{
Roles: roles,
Username: ctx.Identity.TeleportUser,
Roles: roles,
}

modes := s.access.CanJoin(accessContext)
Expand Down
6 changes: 2 additions & 4 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1893,8 +1893,6 @@ func (h *Handler) clusterNodesGet(w http.ResponseWriter, r *http.Request, p http
//
// {"server_id": "uuid", "login": "admin", "term": {"h": 120, "w": 100}, "sid": "123"}
//
// Session id can be empty
//
// Successful response is a websocket stream that allows read write to the server
//
func (h *Handler) siteNodeConnect(
Expand Down Expand Up @@ -2053,7 +2051,7 @@ func (h *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http

sessions := make([]session.Session, 0, len(trackers))
for _, tracker := range trackers {
if tracker.GetSessionKind() == types.SSHSessionKind {
if tracker.GetSessionKind() == types.SSHSessionKind && tracker.GetState() != types.SessionState_SessionStateTerminated {
sessions = append(sessions, trackerToLegacySession(tracker, p.ByName("site")))
}
}
Expand Down Expand Up @@ -2086,7 +2084,7 @@ func (h *Handler) siteSessionGet(w http.ResponseWriter, r *http.Request, p httpr
return nil, trace.Wrap(err)
}

if tracker.GetSessionKind() != types.SSHSessionKind {
if tracker.GetSessionKind() != types.SSHSessionKind || tracker.GetState() == types.SessionState_SessionStateTerminated {
return nil, trace.NotFound("session %v not found", sessionID)
}

Expand Down

0 comments on commit 45b486d

Please sign in to comment.