Skip to content

Commit

Permalink
fix: listing sessions query (ory#2958)
Browse files Browse the repository at this point in the history
Closes ory#2930
  • Loading branch information
Ajay Kelkar authored Dec 20, 2022
1 parent dd64052 commit 67346e6
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 334 deletions.
25 changes: 18 additions & 7 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
s.Identity = i
}

s.Active = s.IsActive()
return &s, nil
}

Expand All @@ -77,7 +78,11 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
q := c.Where("nid = ?", nid)
if active != nil {
q = q.Where("active = ?", *active)
if *active {
q.Where("active = ? AND expires_at >= ?", *active, time.Now().UTC())
} else {
q.Where("(active = ? OR expires_at < ?)", *active, time.Now().UTC())
}
}

// if len(expandables) > 0 {
Expand Down Expand Up @@ -106,6 +111,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
return err
}

sess.Active = sess.IsActive()
sess.Identity = i
}
}
Expand Down Expand Up @@ -133,7 +139,11 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
q = q.Where("id != ?", except)
}
if active != nil {
q = q.Where("active = ?", *active)
if *active {
q.Where("active = ? AND expires_at >= ?", *active, time.Now().UTC())
} else {
q.Where("(active = ? OR expires_at < ?)", *active, time.Now().UTC())
}
}
if len(expandables) > 0 {
q = q.Eager(expandables.ToEager()...)
Expand All @@ -152,12 +162,13 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
}

if expandables.Has(session.ExpandSessionIdentity) {
for _, s := range s {
i, err := p.GetIdentity(ctx, s.IdentityID)
if err != nil {
return err
}
i, err := p.GetIdentity(ctx, iID)
if err != nil {
return sqlcon.HandleError(err)
}

for _, s := range s {
s.Active = s.IsActive()
s.Identity = i
}
}
Expand Down
173 changes: 64 additions & 109 deletions session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,48 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
assert.False(t, session.Active)
})

t.Run("case=session status should be false when session expiry is past", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)

s.ExpiresAt = time.Now().Add(-time.Hour * 1)
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s))

assert.NotEqual(t, uuid.Nil, s.ID)
assert.NotEqual(t, uuid.Nil, s.Identity.ID)

req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s.ID.String(), nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "false", gjson.GetBytes(body, "active").String(), "%s", body)
})

t.Run("case=session status should be false for inactive identity", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
var s1 *Session
require.NoError(t, faker.FakeData(&s1))
s1.Active = true
s1.Identity.State = identity.StateInactive
require.NoError(t, reg.Persister().CreateIdentity(ctx, s1.Identity))

assert.Equal(t, uuid.Nil, s1.ID)
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s1))
assert.NotEqual(t, uuid.Nil, s1.ID)
assert.NotEqual(t, uuid.Nil, s1.Identity.ID)

req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s1.ID.String()+"?expand=Identity", nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "false", gjson.GetBytes(body, "active").String(), "%s", body)
})

req, _ := http.NewRequest("DELETE", ts.URL+"/admin/identities/"+s.Identity.ID.String()+"/sessions", nil)
res, err := client.Do(req)
require.NoError(t, err)
Expand All @@ -649,52 +691,6 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
})
})

t.Run("case=session status should be false for inactive identity", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
var s *Session
require.NoError(t, faker.FakeData(&s))
s.Active = true
s.Identity.State = identity.StateInactive
require.NoError(t, reg.Persister().CreateIdentity(ctx, s.Identity))

assert.Equal(t, uuid.Nil, s.ID)
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s))
assert.NotEqual(t, uuid.Nil, s.ID)
assert.NotEqual(t, uuid.Nil, s.Identity.ID)

req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s.ID.String()+"?expand=Identity", nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "false", gjson.GetBytes(body, "active").String(), "%s", body)
})

t.Run("case=session status should be false when session expiry is past", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
var s *Session
require.NoError(t, faker.FakeData(&s))
s.Active = true
s.ExpiresAt = time.Now().Add(-time.Hour * 1)
require.NoError(t, reg.Persister().CreateIdentity(ctx, s.Identity))

assert.Equal(t, uuid.Nil, s.ID)
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, s))
assert.NotEqual(t, uuid.Nil, s.ID)
assert.NotEqual(t, uuid.Nil, s.Identity.ID)

req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/"+s.ID.String(), nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "false", gjson.GetBytes(body, "active").String(), "%s", body)
})

t.Run("case=should return 400 when bad UUID is sent", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)

Expand All @@ -719,8 +715,9 @@ func TestHandlerAdminSessionManagement(t *testing.T) {

t.Run("case=should return pagination headers on list response", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
i := identity.NewIdentity("")
require.NoError(t, reg.IdentityManager().Create(ctx, i))
var i *identity.Identity
require.NoError(t, faker.FakeData(&i))
require.NoError(t, reg.Persister().CreateIdentity(ctx, i))

numSessions := 5
numSessionsActive := 2
Expand All @@ -731,78 +728,35 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
sess[j].Identity = i
if j < numSessionsActive {
sess[j].Active = true
sess[j].ExpiresAt = time.Now().Add(time.Hour)
} else {
sess[j].Active = false
sess[j].ExpiresAt = time.Now().Add(-time.Hour)
}
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
}

for _, tc := range []struct {
activeOnly string
expectedTotalCount int
expectedSessionIds []uuid.UUID
}{
{
activeOnly: "true",
expectedTotalCount: numSessionsActive,
expectedSessionIds: []uuid.UUID{sess[0].ID, sess[1].ID},
},
{
activeOnly: "false",
expectedTotalCount: numSessions - numSessionsActive,
expectedSessionIds: []uuid.UUID{sess[2].ID, sess[3].ID, sess[4].ID},
},
{
activeOnly: "",
expectedTotalCount: numSessions,
expectedSessionIds: []uuid.UUID{sess[0].ID, sess[1].ID, sess[2].ID, sess[3].ID, sess[4].ID},
},
} {
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
reqURL := ts.URL + "/admin/identities/" + i.ID.String() + "/sessions"
if tc.activeOnly != "" {
reqURL += "?active=" + tc.activeOnly
}
req, _ := http.NewRequest("GET", reqURL, nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
require.Equal(t, tc.expectedTotalCount, totalCount)
require.NotEqual(t, "", res.Header.Get("Link"))
})
}
})
sessions, _, _ := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandEverything)
require.Equal(t, 5, len(sessions))

t.Run("case=should respect active on list", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
i := identity.NewIdentity("")
require.NoError(t, reg.IdentityManager().Create(ctx, i))

sess := make([]Session, 2)
for j := range sess {
require.NoError(t, faker.FakeData(&sess[j]))
sess[j].Identity = i
sess[j].Active = j%2 == 0
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
}

for _, tc := range []struct {
activeOnly string
expectedIDs []uuid.UUID
}{
{
activeOnly: "true",
expectedIDs: []uuid.UUID{sess[0].ID},
},
{
activeOnly: "false",
expectedIDs: []uuid.UUID{sess[1].ID},
},
{
activeOnly: "",
expectedIDs: []uuid.UUID{sess[0].ID, sess[1].ID},
},
} {
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
reqURL := ts.URL + "/admin/identities/" + i.ID.String() + "/sessions"
if tc.activeOnly != "" {
reqURL += "?active=" + tc.activeOnly
Expand All @@ -812,17 +766,18 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

var sessions []Session
require.NoError(t, json.NewDecoder(res.Body).Decode(&sessions))
require.Equal(t, len(sessions), len(tc.expectedIDs))

for _, id := range tc.expectedIDs {
found := false
for _, s := range sessions {
found = found || s.ID == id
}
assert.True(t, found)
var actualSessions []Session
require.NoError(t, json.NewDecoder(res.Body).Decode(&actualSessions))
actualSessionIds := make([]uuid.UUID, 0)
for _, s := range actualSessions {
actualSessionIds = append(actualSessionIds, s.ID)
}

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
assert.Equal(t, len(tc.expectedSessionIds), totalCount)
assert.NotEqual(t, "", res.Header.Get("Link"))
assert.ElementsMatch(t, tc.expectedSessionIds, actualSessionIds)
})
}
})
Expand Down
11 changes: 0 additions & 11 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,6 @@ func (s Session) TableName(ctx context.Context) string {
return "sessions"
}

func (s Session) MarshalJSON() ([]byte, error) {
type sl Session
s.Active = s.IsActive()

result, err := json.Marshal(sl(s))
if err != nil {
return nil, err
}
return result, nil
}

func (s *Session) CompletedLoginFor(method identity.CredentialsType, aal identity.AuthenticatorAssuranceLevel) {
s.AMR = append(s.AMR, AuthenticationMethod{Method: method, AAL: aal, CompletedAt: time.Now().UTC()})
}
Expand Down
Loading

0 comments on commit 67346e6

Please sign in to comment.