diff --git a/HISTORY.md b/HISTORY.md index 3a8467a5a8..a850939b6f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -371,6 +371,8 @@ Other Improvements: ![DBUG routes](https://iris-go.com/images/v12.2.0-dbug2.png?v=0) +- Fix [#1553](https://github.com/kataras/iris/issues/1553). + - New `DirOptions.PushTargets` and `PushTargetsRegexp` to push index' assets to the client without additional requests. Inspirated by issue [#1562](https://github.com/kataras/iris/issues/1562). Example matching all `.js, .css and .ico` files (recursively): ```go diff --git a/sessions/sessiondb/redis/database.go b/sessions/sessiondb/redis/database.go index 71f7040d5b..b3453831a0 100644 --- a/sessions/sessiondb/redis/database.go +++ b/sessions/sessiondb/redis/database.go @@ -3,6 +3,7 @@ package redis import ( "crypto/tls" "errors" + "strings" "time" "github.com/kataras/iris/v12/sessions" @@ -133,11 +134,12 @@ func (db *Database) Config() *Config { // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime { - seconds, hasExpiration, found := db.c.Driver.TTL(sid) + key := db.makeKey(sid, "") + seconds, hasExpiration, found := db.c.Driver.TTL(key) if !found { // fmt.Printf("db.Acquire expires: %s. Seconds: %v\n", expires, expires.Seconds()) // not found, create an entry with ttl and return an empty lifetime, session manager will do its job. - if err := db.c.Driver.Set(sid, sid, int64(expires.Seconds())); err != nil { + if err := db.c.Driver.Set(key, sid, int64(expires.Seconds())); err != nil { golog.Debug(err) } @@ -154,11 +156,14 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime // OnUpdateExpiration will re-set the database's session's entry ttl. // https://redis.io/commands/expire#refreshing-expires func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error { - return db.c.Driver.UpdateTTLMany(sid, int64(newExpires.Seconds())) + return db.c.Driver.UpdateTTLMany(db.makeKey(sid, ""), int64(newExpires.Seconds())) } func (db *Database) makeKey(sid, key string) string { - return sid + db.c.Delim + key + if key == "" { + return db.c.Prefix + sid + } + return db.c.Prefix + sid + db.c.Delim + key } // Set sets a key value of a specific session. @@ -183,20 +188,23 @@ func (db *Database) Get(sid string, key string) (value interface{}) { return } -func (db *Database) get(key string, outPtr interface{}) { +func (db *Database) get(key string, outPtr interface{}) error { data, err := db.c.Driver.Get(key) if err != nil { // not found. - return + return err } if err = sessions.DefaultTranscoder.Unmarshal(data.([]byte), outPtr); err != nil { golog.Debugf("unable to unmarshal value of key: '%s': %v", key, err) + return err } + + return nil } func (db *Database) keys(sid string) []string { - keys, err := db.c.Driver.GetKeys(sid) + keys, err := db.c.Driver.GetKeys(db.makeKey(sid, "")) if err != nil { golog.Debugf("unable to get all redis keys of session '%s': %v", sid, err) return nil @@ -211,6 +219,7 @@ func (db *Database) Visit(sid string, cb func(key string, value interface{})) { for _, key := range keys { var value interface{} // new value each time, we don't know what user will do in "cb". db.get(key, &value) + key = strings.TrimPrefix(key, db.c.Prefix+sid+db.c.Delim) cb(key, value) } } @@ -245,7 +254,7 @@ func (db *Database) Release(sid string) { // clear all $sid-$key. db.Clear(sid) // and remove the $sid. - err := db.c.Driver.Delete(sid) + err := db.c.Driver.Delete(db.c.Prefix + sid) if err != nil { golog.Debugf("Database.Release.Driver.Delete: %s: %v", sid, err) } diff --git a/sessions/sessiondb/redis/driver_radix.go b/sessions/sessiondb/redis/driver_radix.go index 87bb259719..663ad08617 100644 --- a/sessions/sessiondb/redis/driver_radix.go +++ b/sessions/sessiondb/redis/driver_radix.go @@ -149,18 +149,12 @@ func (r *RadixDriver) CloseConnection() error { // Set sets a key-value to the redis store. // The expiration is setted by the secondsLifetime. func (r *RadixDriver) Set(key string, value interface{}, secondsLifetime int64) error { - // fmt.Printf("%#+v. %T. %s\n", value, value, value) - - // if vB, ok := value.([]byte); ok && secondsLifetime <= 0 { - // return r.pool.Do(radix.Cmd(nil, "MSET", r.Config.Prefix+key, string(vB))) - // } - var cmd radix.CmdAction // if has expiration, then use the "EX" to delete the key automatically. if secondsLifetime > 0 { - cmd = radix.FlatCmd(nil, "SETEX", r.Config.Prefix+key, secondsLifetime, value) + cmd = radix.FlatCmd(nil, "SETEX", key, secondsLifetime, value) } else { - cmd = radix.FlatCmd(nil, "SET", r.Config.Prefix+key, value) // MSET same performance... + cmd = radix.FlatCmd(nil, "SET", key, value) // MSET same performance... } return r.pool.Do(cmd) @@ -168,11 +162,11 @@ func (r *RadixDriver) Set(key string, value interface{}, secondsLifetime int64) // Get returns value, err by its key // returns nil and a filled error if something bad happened. -func (r *RadixDriver) Get(key string) (interface{}, error) { +func (r *RadixDriver) Get(key string /* full key */) (interface{}, error) { var redisVal interface{} mn := radix.MaybeNil{Rcv: &redisVal} - err := r.pool.Do(radix.Cmd(&mn, "GET", r.Config.Prefix+key)) + err := r.pool.Do(radix.Cmd(&mn, "GET", key)) if err != nil { return nil, err } @@ -186,7 +180,7 @@ func (r *RadixDriver) Get(key string) (interface{}, error) { // Read more at: https://redis.io/commands/ttl func (r *RadixDriver) TTL(key string) (seconds int64, hasExpiration bool, found bool) { var redisVal interface{} - err := r.pool.Do(radix.Cmd(&redisVal, "TTL", r.Config.Prefix+key)) + err := r.pool.Do(radix.Cmd(&redisVal, "TTL", key)) if err != nil { return -2, false, false } @@ -198,9 +192,9 @@ func (r *RadixDriver) TTL(key string) (seconds int64, hasExpiration bool, found return } -func (r *RadixDriver) updateTTLConn(key string, newSecondsLifeTime int64) error { +func (r *RadixDriver) updateTTLConn(key string /* full key */, newSecondsLifeTime int64) error { var reply int - err := r.pool.Do(radix.FlatCmd(&reply, "EXPIRE", r.Config.Prefix+key, newSecondsLifeTime)) + err := r.pool.Do(radix.FlatCmd(&reply, "EXPIRE", key, newSecondsLifeTime)) if err != nil { return err } @@ -229,8 +223,8 @@ func (r *RadixDriver) UpdateTTL(key string, newSecondsLifeTime int64) error { // UpdateTTLMany like `UpdateTTL` but for all keys starting with that "prefix", // it is a bit faster operation if you need to update all sessions keys (although it can be even faster if we used hash but this will limit other features), // look the `sessions/Database#OnUpdateExpiration` for example. -func (r *RadixDriver) UpdateTTLMany(prefix string, newSecondsLifeTime int64) error { - keys, err := r.getKeys("0", prefix) +func (r *RadixDriver) UpdateTTLMany(prefix string /* prefix is the sid */, newSecondsLifeTime int64) error { + keys, err := r.getKeys("0", prefix, true) if err != nil { return err } @@ -284,21 +278,26 @@ func (s *scanResult) UnmarshalRESP(br *bufio.Reader) error { return (resp2.Any{I: &s.keys}).UnmarshalRESP(br) } -func (r *RadixDriver) getKeys(cursor, prefix string) ([]string, error) { +func (r *RadixDriver) getKeys(cursor, prefix string, includeSID bool) ([]string, error) { var res scanResult - err := r.pool.Do(radix.Cmd(&res, "SCAN", cursor, "MATCH", r.Config.Prefix+prefix+"*", "COUNT", "300000")) + + if !includeSID { + prefix += r.Config.Delim // delim can be used for fast matching of only keys. + } + pattern := prefix + "*" + + err := r.pool.Do(radix.Cmd(&res, "SCAN", cursor, "MATCH", pattern, "COUNT", "300000")) if err != nil { return nil, err } - if len(res.keys) <= 1 { + if len(res.keys) == 0 { return nil, nil } - keys := res.keys[1:] // the first one is always the session id itself. - + keys := res.keys[0:] if res.cur != "0" { - moreKeys, err := r.getKeys(res.cur, prefix) + moreKeys, err := r.getKeys(res.cur, prefix, includeSID) if err != nil { return nil, err } @@ -312,26 +311,11 @@ func (r *RadixDriver) getKeys(cursor, prefix string) ([]string, error) { // GetKeys returns all redis keys using the "SCAN" with MATCH command. // Read more at: https://redis.io/commands/scan#the-match-option. func (r *RadixDriver) GetKeys(prefix string) ([]string, error) { - return r.getKeys("0", prefix) + return r.getKeys("0", prefix, false) } -// // GetBytes returns bytes representation of a value based on given "key". -// func (r *Service) GetBytes(key string) ([]byte, error) { -// var redisVal []byte -// mn := radix.MaybeNil{Rcv: &redisVal} -// err := r.pool.Do(radix.Cmd(&mn, "GET", r.Config.Prefix+key)) -// if err != nil { -// return nil, err -// } -// if mn.Nil { -// return nil, ErrKeyNotFound.Format(key) -// } - -// return redisVal, nil -// } - // Delete removes redis entry by specific key func (r *RadixDriver) Delete(key string) error { - err := r.pool.Do(radix.Cmd(nil, "DEL", r.Config.Prefix+key)) + err := r.pool.Do(radix.Cmd(nil, "DEL", key)) return err } diff --git a/sessions/sessiondb/redis/driver_redigo.go b/sessions/sessiondb/redis/driver_redigo.go index d34f6247fc..0c3cbfdb7b 100644 --- a/sessions/sessiondb/redis/driver_redigo.go +++ b/sessions/sessiondb/redis/driver_redigo.go @@ -61,9 +61,9 @@ func (r *RedigoDriver) Set(key string, value interface{}, secondsLifetime int64) // if has expiration, then use the "EX" to delete the key automatically. if secondsLifetime > 0 { - _, err = c.Do("SETEX", r.Config.Prefix+key, secondsLifetime, value) + _, err = c.Do("SETEX", key, secondsLifetime, value) } else { - _, err = c.Do("SET", r.Config.Prefix+key, value) + _, err = c.Do("SET", key, value) } return @@ -78,7 +78,7 @@ func (r *RedigoDriver) Get(key string) (interface{}, error) { return nil, err } - redisVal, err := c.Do("GET", r.Config.Prefix+key) + redisVal, err := c.Do("GET", key) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func (r *RedigoDriver) Get(key string) (interface{}, error) { func (r *RedigoDriver) TTL(key string) (seconds int64, hasExpiration bool, found bool) { c := r.pool.Get() defer c.Close() - redisVal, err := c.Do("TTL", r.Config.Prefix+key) + redisVal, err := c.Do("TTL", key) if err != nil { return -2, false, false } @@ -106,7 +106,7 @@ func (r *RedigoDriver) TTL(key string) (seconds int64, hasExpiration bool, found } func (r *RedigoDriver) updateTTLConn(c redis.Conn, key string, newSecondsLifeTime int64) error { - reply, err := c.Do("EXPIRE", r.Config.Prefix+key, newSecondsLifeTime) + reply, err := c.Do("EXPIRE", key, newSecondsLifeTime) if err != nil { return err } @@ -143,14 +143,14 @@ func (r *RedigoDriver) UpdateTTL(key string, newSecondsLifeTime int64) error { // UpdateTTLMany like `UpdateTTL` but for all keys starting with that "prefix", // it is a bit faster operation if you need to update all sessions keys (although it can be even faster if we used hash but this will limit other features), // look the `sessions/Database#OnUpdateExpiration` for example. -func (r *RedigoDriver) UpdateTTLMany(prefix string, newSecondsLifeTime int64) error { +func (r *RedigoDriver) UpdateTTLMany(prefix string /* prefix is the sid */, newSecondsLifeTime int64) error { c := r.pool.Get() defer c.Close() if err := c.Err(); err != nil { return err } - keys, err := r.getKeysConn(c, 0, prefix) + keys, err := r.getKeysConn(c, 0, prefix, true) if err != nil { return err } @@ -184,8 +184,13 @@ func (r *RedigoDriver) GetAll() (interface{}, error) { return redisVal, nil } -func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix string) ([]string, error) { - if err := c.Send("SCAN", cursor, "MATCH", r.Config.Prefix+prefix+"*", "COUNT", 300000); err != nil { +func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix string, includeSID bool) ([]string, error) { + if !includeSID { + prefix += r.Config.Delim // delim can be used for fast matching of only keys. + } + pattern := prefix + "*" + + if err := c.Send("SCAN", cursor, "MATCH", pattern, "COUNT", 300000); err != nil { return nil, err } @@ -205,7 +210,7 @@ func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix stri if len(replies) == 2 { // take the second, it must contain the slice of keys. if keysSliceAsBytes, ok := replies[1].([]interface{}); ok { - n := len(keysSliceAsBytes) - 1 // scan match returns the session id key too. + n := len(keysSliceAsBytes) if n <= 0 { return nil, nil } @@ -213,15 +218,15 @@ func (r *RedigoDriver) getKeysConn(c redis.Conn, cursor interface{}, prefix stri keys := make([]string, n) for i, k := range keysSliceAsBytes { - key := fmt.Sprintf("%s", k)[len(r.Config.Prefix):] - if key == prefix { - continue // it's the session id itself. - } - keys[i] = key + // key := fmt.Sprintf("%s", k)[len(r.Config.Prefix):] + // if key == prefix { + // continue // it's the session id itself. + // } + keys[i] = fmt.Sprintf("%s", k) } if cur := fmt.Sprintf("%s", replies[0]); cur != "0" { - moreKeys, err := r.getKeysConn(c, cur, prefix) + moreKeys, err := r.getKeysConn(c, cur, prefix, includeSID) if err != nil { return nil, err } @@ -247,7 +252,7 @@ func (r *RedigoDriver) GetKeys(prefix string) ([]string, error) { return nil, err } - return r.getKeysConn(c, 0, prefix) + return r.getKeysConn(c, 0, prefix, false) } // GetBytes returns value, err by its key @@ -260,7 +265,7 @@ func (r *RedigoDriver) GetBytes(key string) ([]byte, error) { return nil, err } - redisVal, err := c.Do("GET", r.Config.Prefix+key) + redisVal, err := c.Do("GET", key) if err != nil { return nil, err } @@ -276,7 +281,7 @@ func (r *RedigoDriver) Delete(key string) error { c := r.pool.Get() defer c.Close() - _, err := c.Do("DEL", r.Config.Prefix+key) + _, err := c.Do("DEL", key) return err } diff --git a/sessions/sessions.go b/sessions/sessions.go index 1eb061ba4e..b95b03a256 100644 --- a/sessions/sessions.go +++ b/sessions/sessions.go @@ -20,15 +20,30 @@ type Sessions struct { config Config provider *provider - handlerCookieOpts []context.CookieOption // see `Handler`. + cookieOptions []context.CookieOption // options added on each session cookie action. } // New returns a new fast, feature-rich sessions manager // it can be adapted to an iris station func New(cfg Config) *Sessions { + var cookieOptions []context.CookieOption + if cfg.AllowReclaim { + cookieOptions = append(cookieOptions, context.CookieAllowReclaim(cfg.Cookie)) + } + if !cfg.DisableSubdomainPersistence { + cookieOptions = append(cookieOptions, context.CookieAllowSubdomains(cfg.Cookie)) + } + if cfg.CookieSecureTLS { + cookieOptions = append(cookieOptions, context.CookieSecure) + } + if cfg.Encoding != nil { + cookieOptions = append(cookieOptions, context.CookieEncoding(cfg.Encoding, cfg.Cookie)) + } + return &Sessions{ - config: cfg.Validate(), - provider: newProvider(), + cookieOptions: cookieOptions, + config: cfg.Validate(), + provider: newProvider(), } } @@ -38,9 +53,10 @@ func (s *Sessions) UseDatabase(db Database) { s.provider.RegisterDatabase(db) } -// GetCookieOptions returns any cookie options registered for the `Handler` method. +// GetCookieOptions returns the cookie options registered +// for this sessions manager based on the configuration. func (s *Sessions) GetCookieOptions() []context.CookieOption { - return s.handlerCookieOpts + return s.cookieOptions } // updateCookie gains the ability of updating the session browser cookie to any method which wants to update it @@ -65,7 +81,25 @@ func (s *Sessions) updateCookie(ctx *context.Context, sid string, expires time.D cookie.MaxAge = int(time.Until(cookie.Expires).Seconds()) } - ctx.UpsertCookie(cookie, options...) + s.upsertCookie(ctx, cookie, options) +} + +func (s *Sessions) upsertCookie(ctx *context.Context, cookie *http.Cookie, cookieOptions []context.CookieOption) { + opts := s.cookieOptions + if len(cookieOptions) > 0 { + opts = append(opts, cookieOptions...) + } + + ctx.UpsertCookie(cookie, opts...) +} + +func (s *Sessions) getCookie(ctx *context.Context, cookieOptions []context.CookieOption) string { + opts := s.cookieOptions + if len(cookieOptions) > 0 { + opts = append(opts, cookieOptions...) + } + + return ctx.GetCookie(s.config.Cookie, opts...) } // Start creates or retrieves an existing session for the particular request. @@ -77,7 +111,7 @@ func (s *Sessions) updateCookie(ctx *context.Context, sid string, expires time.D // // NOTE: Use `app.Use(sess.Handler())` instead, avoid using `Start` manually. func (s *Sessions) Start(ctx *context.Context, cookieOptions ...context.CookieOption) *Session { - cookieValue := ctx.GetCookie(s.config.Cookie, cookieOptions...) + cookieValue := s.getCookie(ctx, cookieOptions) if cookieValue == "" { // cookie doesn't exist, let's generate a session and set a cookie. sid := s.config.SessionIDGenerator(ctx) @@ -106,27 +140,9 @@ const sessionContextKey = "iris.session" // To return the request's Session call the `Get(ctx)` package-level function. // // Call `Handler()` once per sessions manager. -func (s *Sessions) Handler(cookieOptions ...context.CookieOption) context.Handler { - s.handlerCookieOpts = cookieOptions - - var requestOptions []context.CookieOption - if s.config.AllowReclaim { - requestOptions = append(requestOptions, context.CookieAllowReclaim(s.config.Cookie)) - } - if !s.config.DisableSubdomainPersistence { - requestOptions = append(requestOptions, context.CookieAllowSubdomains(s.config.Cookie)) - } - if s.config.CookieSecureTLS { - requestOptions = append(requestOptions, context.CookieSecure) - } - if s.config.Encoding != nil { - requestOptions = append(requestOptions, context.CookieEncoding(s.config.Encoding, s.config.Cookie)) - } - +func (s *Sessions) Handler(requestOptions ...context.CookieOption) context.Handler { return func(ctx *context.Context) { - ctx.AddCookieOptions(requestOptions...) // request life-cycle options. - - session := s.Start(ctx, cookieOptions...) // this cookie's end-developer's custom options. + session := s.Start(ctx, requestOptions...) // this cookie's end-developer's custom options. ctx.Values().Set(sessionContextKey, session) ctx.Next() @@ -170,7 +186,7 @@ func (s *Sessions) ShiftExpiration(ctx *context.Context, cookieOptions ...contex // It will return `ErrNotFound` when trying to update expiration on a non-existence or not valid session entry. // It will return `ErrNotImplemented` if a database is used and it does not support this feature, yet. func (s *Sessions) UpdateExpiration(ctx *context.Context, expires time.Duration, cookieOptions ...context.CookieOption) error { - cookieValue := ctx.GetCookie(s.config.Cookie) + cookieValue := s.getCookie(ctx, cookieOptions) if cookieValue == "" { return ErrNotFound } @@ -204,7 +220,7 @@ func (s *Sessions) OnDestroy(listeners ...DestroyListener) { // use `Sessions#Start` method for renewal // or use the Session's Destroy method which does keep the session entry with its values cleared. func (s *Sessions) Destroy(ctx *context.Context) { - cookieValue := ctx.GetCookie(s.config.Cookie) + cookieValue := s.getCookie(ctx, nil) if cookieValue == "" { // nothing to destroy return }