Skip to content

Commit

Permalink
fix #1539
Browse files Browse the repository at this point in the history
Former-commit-id: f2f277cd5cbe781ce596adc7840a1b1bc3b3bfc6
  • Loading branch information
kataras committed Jun 19, 2020
1 parent 9724592 commit 3f98b39
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 31 deletions.
8 changes: 6 additions & 2 deletions _examples/sessions/overview/example/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ func NewApp(sess *sessions.Sessions) *iris.Application {
// set session values.
app.Get("/set", func(ctx iris.Context) {
session := sessions.Get(ctx)
isNew := session.IsNew()

session.Set("name", "iris")

ctx.Writef("All ok session set to: %s", session.GetString("name"))
ctx.Writef("All ok session set to: %s [isNew=%t]", session.GetString("name"), isNew)
})

app.Get("/get", func(ctx iris.Context) {
Expand Down Expand Up @@ -68,9 +70,11 @@ func NewApp(sess *sessions.Sessions) *iris.Application {

key := ctx.Params().Get("key")
value := ctx.Params().Get("value")
isNew := session.IsNew()

session.Set(key, value)

ctx.Writef("All ok session value of the '%s' is: %s", key, session.GetString(key))
ctx.Writef("All ok session value of the '%s' is: %s [isNew=%t]", key, session.GetString(key), isNew)
})

app.Get("/get/{key}", func(ctx iris.Context) {
Expand Down
15 changes: 5 additions & 10 deletions mvc/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,27 +663,22 @@ func TestApplicationDependency(t *testing.T) {
// Authenticated type.
type Authenticated int64

// BasePublicPrivateController base controller between public and private controllers.
type BasePublicPrivateController struct {
// BasePrivateController base controller for private controllers.
type BasePrivateController struct {
CurrentUserID Authenticated
Ctx iris.Context
Ctx iris.Context // not-used.
}

type publicController struct {
Ctx iris.Context
Ctx iris.Context // not-used.
}

// Get desc
// Route / [GET]
func (c *publicController) Get() iris.Map {
return iris.Map{"data": "things"}
}

// privateController serves the "public-private" Customer API.
type privateController struct{ BasePublicPrivateController }
type privateController struct{ BasePrivateController }

// Get desc
// Route / [GET]
func (c *privateController) Get() iris.Map {
return iris.Map{"id": c.CurrentUserID}
}
Expand Down
42 changes: 26 additions & 16 deletions sessions/sessiondb/badger/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ func (db *Database) Get(sid string, key string) (value interface{}) {
return
}

// validSessionItem reports whether the current iterator's item key
// is a value of the session id "prefix".
func validSessionItem(key, prefix []byte) bool {
return len(key) > len(prefix) && bytes.Equal(key[0:len(prefix)], prefix)
}

// Visit loops through all session keys and values.
func (db *Database) Visit(sid string, cb func(key string, value interface{})) {
prefix := makePrefix(sid)
Expand All @@ -166,30 +172,28 @@ func (db *Database) Visit(sid string, cb func(key string, value interface{})) {
iter := txn.NewIterator(badger.DefaultIteratorOptions)
defer iter.Close()

for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() {
item := iter.Item()
var value interface{}
for iter.Rewind(); ; iter.Next() {
if !iter.Valid() {
break
}

// err := item.Value(func(valueBytes []byte) {
// if err := sessions.DefaultTranscoder.Unmarshal(valueBytes, &value); err != nil {
// golog.Error(err)
// }
// })
item := iter.Item()
key := item.Key()
if !validSessionItem(key, prefix) {
continue
}

// if err != nil {
// golog.Error(err)
// continue
// }
var value interface{}

err := item.Value(func(valueBytes []byte) error {
return sessions.DefaultTranscoder.Unmarshal(valueBytes, &value)
})
if err != nil {
golog.Error(err)
golog.Errorf("[sessionsdb.badger.Visit] %v", err)
continue
}

cb(string(bytes.TrimPrefix(item.Key(), prefix)), value)
cb(string(bytes.TrimPrefix(key, prefix)), value)
}
}

Expand All @@ -207,8 +211,14 @@ func (db *Database) Len(sid string) (n int) {
txn := db.Service.NewTransaction(false)
iter := txn.NewIterator(iterOptionsNoValues)

for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() {
n++
for iter.Rewind(); ; iter.Next() {
if !iter.Valid() {
break
}

if validSessionItem(iter.Item().Key(), prefix) {
n++
}
}

iter.Close()
Expand Down
7 changes: 6 additions & 1 deletion sessions/sessiondb/redis/driver_radix.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,12 @@ func (r *RadixDriver) getKeys(cursor, prefix string) ([]string, error) {
return nil, err
}

keys := res.keys[0:]
if len(res.keys) <= 1 {
return nil, nil
}

keys := res.keys[1:] // the first one is always the session id itself.

if res.cur != "0" {
moreKeys, err := r.getKeys(res.cur, prefix)
if err != nil {
Expand Down
13 changes: 11 additions & 2 deletions sessions/sessiondb/redis/driver_redigo.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,19 @@ func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix stri
if len(replies) == 2 {
// take the second, it must contain the slice of keys.
if keysSliceAsBytes, ok := replies[1].([]interface{}); ok {
keys := make([]string, len(keysSliceAsBytes))
n := len(keysSliceAsBytes) - 1 // scan match returns the session id key too.
if n <= 0 {
return nil, nil
}

keys := make([]string, n)

for i, k := range keysSliceAsBytes {
keys[i] = fmt.Sprintf("%s", k)[len(r.Config.Prefix):]
key := fmt.Sprintf("%s", k)[len(r.Config.Prefix):]
if key == prefix {
continue // it's the session id itself.
}
keys[i] = key
}

if cur := fmt.Sprintf("%s", replies[0]); cur != "0" {
Expand Down
7 changes: 7 additions & 0 deletions sessions/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ func (s *Sessions) Start(ctx context.Context, cookieOptions ...context.CookieOpt
sid := s.config.SessionIDGenerator(ctx)

sess := s.provider.Init(s, sid, s.config.Expires)
// n := s.provider.db.Len(sid)
// fmt.Printf("db.Len(%s) = %d\n", sid, n)
// if n > 0 {
// s.provider.db.Visit(sid, func(key string, value interface{}) {
// fmt.Printf("%s=%s\n", key, value)
// })
// }
sess.isNew = s.provider.db.Len(sid) == 0

s.updateCookie(ctx, sid, s.config.Expires, cookieOptions...)
Expand Down

0 comments on commit 3f98b39

Please sign in to comment.