diff --git a/redis/pool.go b/redis/pool.go index 4a468cf0..b62d29e2 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -44,6 +44,19 @@ var ( errConnClosed = errors.New("redigo: connection closed") ) +type cacheStrategy int + +// Server-assisted cache strategy +// See https://redis.io/topics/client-side-caching +const ( + // Do not cache results locally (default) + CacheStrategyNone cacheStrategy = 0 + // Let server to keep track of requested keys + CacheStrategyInvalidate cacheStrategy = 1 + // Subscribe to invalidation broadcasts manually (not yet implemented) + CacheStrategyBroadcast cacheStrategy = 2 +) + // Pool maintains a pool of connections. The application calls the Get method // to get a connection from the pool and the connection's Close method to // return the connection's resources to the pool. @@ -161,15 +174,21 @@ type Pool struct { // the pool does not close connections based on age. MaxConnLifetime time.Duration + // Controls how Pool handles server-assisted caching + // introduced in Redis 6.0. This is disabled by default + cacheStrategy cacheStrategy + chInitialized uint32 // set to 1 when field ch is initialized - mu sync.Mutex // mu protects the following fields - closed bool // set to true when the pool is closed. - active int // the number of open connections in the pool - ch chan struct{} // limits open connections when p.Wait is true - idle idleList // idle connections - waitCount int64 // total number of connections waited for. - waitDuration time.Duration // total time waited for new connections. + mu sync.Mutex // mu protects the following fields + closed bool // set to true when the pool is closed. + active int // the number of open connections in the pool + ch chan struct{} // limits open connections when p.Wait is true + idle idleList // idle connections + waitCount int64 // total number of connections waited for. + waitDuration time.Duration // total time waited for new connections. + invalidatorID int // Connection that is receiving cache revocation notifications + localCache sync.Map // Local cache } // NewPool creates a new pool. @@ -265,9 +284,39 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) { p.mu.Unlock() return errorConn{err}, err } + + p.mu.Lock() + invalidatorID := p.invalidatorID + p.mu.Unlock() + + if p.cacheStrategy == CacheStrategyInvalidate && invalidatorID != 0 { + // Ask server to watch the keys requested by this connection and + // send invalidation notifications to our housekeeper process + _, err = c.Do("CLIENT", "TRACKING", "on", "REDIRECT", invalidatorID) + if err != nil { + c = nil + p.mu.Lock() + p.active-- + if p.ch != nil && !p.closed { + p.ch <- struct{}{} + } + p.mu.Unlock() + return errorConn{err}, err + } + } + return &activeConn{p: p, pc: &poolConn{c: c, created: nowFunc()}}, nil } +// EnableCaching method +func (p *Pool) EnableCaching(strategy cacheStrategy) error { + p.cacheStrategy = strategy + if strategy == CacheStrategyInvalidate { + return p.cacheInvalidator() + } + return nil +} + // PoolStats contains pool statistics. type PoolStats struct { // ActiveCount is the number of connections in the pool. The count includes @@ -283,16 +332,27 @@ type PoolStats struct { // WaitDuration is the total time blocked waiting for a new connection. // This value is currently not guaranteed to be 100% accurate. WaitDuration time.Duration + + // Number of keys saved locally via server-assisted caching + CachedEntries int } // Stats returns pool's statistics. func (p *Pool) Stats() PoolStats { p.mu.Lock() + // Count total number of entries + cachedEntries := 0 + p.localCache.Range(func(key, value interface{}) bool { + cachedEntries++ + return true + }) + stats := PoolStats{ - ActiveCount: p.active, - IdleCount: p.idle.count, - WaitCount: p.waitCount, - WaitDuration: p.waitDuration, + ActiveCount: p.active, + IdleCount: p.idle.count, + WaitCount: p.waitCount, + WaitDuration: p.waitDuration, + CachedEntries: cachedEntries, } p.mu.Unlock() @@ -437,6 +497,61 @@ func (p *Pool) put(pc *poolConn, forceClose bool) error { return nil } +// cacheInvalidator start housekeeping goroutine handling cache invalidation +func (p *Pool) cacheInvalidator() error { + // Connection used for invalidation + conn, err := p.Dial() + if err != nil { + return err + } + + id, err := Int(conn.Do("CLIENT", "ID")) + if err != nil { + return err + } + + p.mu.Lock() + p.invalidatorID = id + p.mu.Unlock() + + // Subscribe to revocation channel + conn.Send("SUBSCRIBE", "__redis__:invalidate") + conn.Flush() + + go func() { + for { + p.mu.Lock() + if p.closed { + defer p.mu.Unlock() + conn.Close() + return + } + p.mu.Unlock() + + event, err := Values(conn.Receive()) + if err != nil { + continue + } + + etype, _ := String(event[0], nil) + + if etype == "message" { + // Remove revoked keys from local cache + keys, err := Strings(event[2], nil) + if err != nil { // Sometimes Redis decides to send for some reason + continue + } + + for _, key := range keys { + p.localCache.Delete(key) + } + } + } + }() + + return nil +} + type activeConn struct { p *Pool pc *poolConn @@ -513,7 +628,10 @@ func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interfa } ci := lookupCommandInfo(commandName) ac.state = (ac.state | ci.Set) &^ ci.Clear - return pc.c.Do(commandName, args...) + + return ac.cacheLoader(commandName, args, func() (interface{}, error) { + return ac.pc.c.Do(commandName, args...) + }) } func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) { @@ -527,7 +645,33 @@ func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, a } ci := lookupCommandInfo(commandName) ac.state = (ac.state | ci.Set) &^ ci.Clear - return cwt.DoWithTimeout(timeout, commandName, args...) + + return ac.cacheLoader(commandName, args, func() (interface{}, error) { + return cwt.DoWithTimeout(timeout, commandName, args...) + }) +} + +// Loader-style middleware to cache requests. For now it only caches GET commands +func (ac *activeConn) cacheLoader(commandName string, args []interface{}, f func() (interface{}, error)) (interface{}, error) { + if ac.p.cacheStrategy == CacheStrategyInvalidate && commandName == "GET" { + key := args[0].(string) + + // Happy path + if entry, ok := ac.p.localCache.Load(key); ok { + return entry, nil + } + + resp, err := f() + if err != nil { + return resp, err + } + + // Cache response + ac.p.localCache.Store(key, resp) + return resp, nil + } + + return f() } func (ac *activeConn) Send(commandName string, args ...interface{}) error { diff --git a/redis/pool_test.go b/redis/pool_test.go index a8efa515..ce0ad3b3 100644 --- a/redis/pool_test.go +++ b/redis/pool_test.go @@ -934,3 +934,52 @@ func TestWaitPoolGetCanceledContext(t *testing.T) { } }) } + +func TestPoolServerAssistedCaching(t *testing.T) { + d := poolDialer{t: t} + p := &redis.Pool{ + Dial: d.dial, + } + + if err := p.EnableCaching(redis.CacheStrategyInvalidate); err != nil { + d.t.Fatal("Unexpected error: ", err) + } + + defer p.Close() + + conn := p.Get() + defer conn.Close() + + _, err := conn.Do("SET", "key", "value") + if err != nil { + d.t.Error(err) + } + + // Server knows we cached the results + _, err = conn.Do("GET", "key") + if err != nil { + d.t.Error(err) + } + + n := p.Stats().CachedEntries + if p.Stats().CachedEntries != 1 { + d.t.Fatalf("got %v entries, should be 1", n) + } + + conn2 := p.Get() + defer conn2.Close() + + // Invalidate cache on the other side + _, err = conn2.Do("SET", "key", "another value") + if err != nil { + d.t.Error(err) + } + + // Invalidation is happening async + time.Sleep(time.Millisecond * 50) + + n = p.Stats().CachedEntries + if n != 0 { + d.t.Fatalf("got %v entries, should be 0", n) + } +}