From 80384f32b25abf03b3f99329795099929b6c4f24 Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Fri, 19 Jun 2020 05:54:21 +0300 Subject: [PATCH] fix #1539 --- .../sessions/overview/example/example.go | 8 +++- mvc/controller_test.go | 15 +++---- sessions/sessiondb/badger/database.go | 42 ++++++++++++------- sessions/sessiondb/redis/driver_radix.go | 7 +++- sessions/sessiondb/redis/driver_redigo.go | 13 +++++- sessions/sessions.go | 7 ++++ 6 files changed, 61 insertions(+), 31 deletions(-) diff --git a/_examples/sessions/overview/example/example.go b/_examples/sessions/overview/example/example.go index f7312c4a9..64509b513 100644 --- a/_examples/sessions/overview/example/example.go +++ b/_examples/sessions/overview/example/example.go @@ -36,9 +36,11 @@ func NewApp(sess *sessions.Sessions) *iris.Application { // set session values. app.Get("/set", func(ctx iris.Context) { session := sessions.Get(ctx) + isNew := session.IsNew() + session.Set("name", "iris") - ctx.Writef("All ok session set to: %s", session.GetString("name")) + ctx.Writef("All ok session set to: %s [isNew=%t]", session.GetString("name"), isNew) }) app.Get("/get", func(ctx iris.Context) { @@ -68,9 +70,11 @@ func NewApp(sess *sessions.Sessions) *iris.Application { key := ctx.Params().Get("key") value := ctx.Params().Get("value") + isNew := session.IsNew() + session.Set(key, value) - ctx.Writef("All ok session value of the '%s' is: %s", key, session.GetString(key)) + ctx.Writef("All ok session value of the '%s' is: %s [isNew=%t]", key, session.GetString(key), isNew) }) app.Get("/get/{key}", func(ctx iris.Context) { diff --git a/mvc/controller_test.go b/mvc/controller_test.go index 4e70832f3..6ff0a7c73 100644 --- a/mvc/controller_test.go +++ b/mvc/controller_test.go @@ -663,27 +663,22 @@ func TestApplicationDependency(t *testing.T) { // Authenticated type. type Authenticated int64 -// BasePublicPrivateController base controller between public and private controllers. -type BasePublicPrivateController struct { +// BasePrivateController base controller for private controllers. +type BasePrivateController struct { CurrentUserID Authenticated - Ctx iris.Context + Ctx iris.Context // not-used. } type publicController struct { - Ctx iris.Context + Ctx iris.Context // not-used. } -// Get desc -// Route / [GET] func (c *publicController) Get() iris.Map { return iris.Map{"data": "things"} } -// privateController serves the "public-private" Customer API. -type privateController struct{ BasePublicPrivateController } +type privateController struct{ BasePrivateController } -// Get desc -// Route / [GET] func (c *privateController) Get() iris.Map { return iris.Map{"id": c.CurrentUserID} } diff --git a/sessions/sessiondb/badger/database.go b/sessions/sessiondb/badger/database.go index bd14189f7..c7accb7cd 100644 --- a/sessions/sessiondb/badger/database.go +++ b/sessions/sessiondb/badger/database.go @@ -156,6 +156,12 @@ func (db *Database) Get(sid string, key string) (value interface{}) { return } +// validSessionItem reports whether the current iterator's item key +// is a value of the session id "prefix". +func validSessionItem(key, prefix []byte) bool { + return len(key) > len(prefix) && bytes.Equal(key[0:len(prefix)], prefix) +} + // Visit loops through all session keys and values. func (db *Database) Visit(sid string, cb func(key string, value interface{})) { prefix := makePrefix(sid) @@ -166,30 +172,28 @@ func (db *Database) Visit(sid string, cb func(key string, value interface{})) { iter := txn.NewIterator(badger.DefaultIteratorOptions) defer iter.Close() - for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() { - item := iter.Item() - var value interface{} + for iter.Rewind(); ; iter.Next() { + if !iter.Valid() { + break + } - // err := item.Value(func(valueBytes []byte) { - // if err := sessions.DefaultTranscoder.Unmarshal(valueBytes, &value); err != nil { - // golog.Error(err) - // } - // }) + item := iter.Item() + key := item.Key() + if !validSessionItem(key, prefix) { + continue + } - // if err != nil { - // golog.Error(err) - // continue - // } + var value interface{} err := item.Value(func(valueBytes []byte) error { return sessions.DefaultTranscoder.Unmarshal(valueBytes, &value) }) if err != nil { - golog.Error(err) + golog.Errorf("[sessionsdb.badger.Visit] %v", err) continue } - cb(string(bytes.TrimPrefix(item.Key(), prefix)), value) + cb(string(bytes.TrimPrefix(key, prefix)), value) } } @@ -207,8 +211,14 @@ func (db *Database) Len(sid string) (n int) { txn := db.Service.NewTransaction(false) iter := txn.NewIterator(iterOptionsNoValues) - for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() { - n++ + for iter.Rewind(); ; iter.Next() { + if !iter.Valid() { + break + } + + if validSessionItem(iter.Item().Key(), prefix) { + n++ + } } iter.Close() diff --git a/sessions/sessiondb/redis/driver_radix.go b/sessions/sessiondb/redis/driver_radix.go index 1ec53c25f..c06cfd489 100644 --- a/sessions/sessiondb/redis/driver_radix.go +++ b/sessions/sessiondb/redis/driver_radix.go @@ -263,7 +263,12 @@ func (r *RadixDriver) getKeys(cursor, prefix string) ([]string, error) { return nil, err } - keys := res.keys[0:] + if len(res.keys) <= 1 { + return nil, nil + } + + keys := res.keys[1:] // the first one is always the session id itself. + if res.cur != "0" { moreKeys, err := r.getKeys(res.cur, prefix) if err != nil { diff --git a/sessions/sessiondb/redis/driver_redigo.go b/sessions/sessiondb/redis/driver_redigo.go index 81ee00448..d34f6247f 100644 --- a/sessions/sessiondb/redis/driver_redigo.go +++ b/sessions/sessiondb/redis/driver_redigo.go @@ -205,10 +205,19 @@ func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix stri if len(replies) == 2 { // take the second, it must contain the slice of keys. if keysSliceAsBytes, ok := replies[1].([]interface{}); ok { - keys := make([]string, len(keysSliceAsBytes)) + n := len(keysSliceAsBytes) - 1 // scan match returns the session id key too. + if n <= 0 { + return nil, nil + } + + keys := make([]string, n) for i, k := range keysSliceAsBytes { - keys[i] = fmt.Sprintf("%s", k)[len(r.Config.Prefix):] + key := fmt.Sprintf("%s", k)[len(r.Config.Prefix):] + if key == prefix { + continue // it's the session id itself. + } + keys[i] = key } if cur := fmt.Sprintf("%s", replies[0]); cur != "0" { diff --git a/sessions/sessions.go b/sessions/sessions.go index c049d5f1b..9bed639c3 100644 --- a/sessions/sessions.go +++ b/sessions/sessions.go @@ -83,6 +83,13 @@ func (s *Sessions) Start(ctx context.Context, cookieOptions ...context.CookieOpt sid := s.config.SessionIDGenerator(ctx) sess := s.provider.Init(s, sid, s.config.Expires) + // n := s.provider.db.Len(sid) + // fmt.Printf("db.Len(%s) = %d\n", sid, n) + // if n > 0 { + // s.provider.db.Visit(sid, func(key string, value interface{}) { + // fmt.Printf("%s=%s\n", key, value) + // }) + // } sess.isNew = s.provider.db.Len(sid) == 0 s.updateCookie(ctx, sid, s.config.Expires, cookieOptions...)