Implement server-assisted client side caching
holykol committed Oct 2, 2020
1 parent 6f8b898 commit 105000c
Showing 2 changed files with 206 additions and 13 deletions.
170 changes: 157 additions & 13 deletions redis/pool.go
@@ -44,6 +44,19 @@ var (
errConnClosed = errors.New("redigo: connection closed")
)

type cacheStrategy int

// Server-assisted cache strategies.
// See https://redis.io/topics/client-side-caching
const (
// Do not cache results locally (default)
CacheStrategyNone cacheStrategy = 0
// Let the server keep track of the keys requested on each connection
CacheStrategyInvalidate cacheStrategy = 1
// Subscribe to invalidation broadcasts manually (not yet implemented)
CacheStrategyBroadcast cacheStrategy = 2
)

// Pool maintains a pool of connections. The application calls the Get method
// to get a connection from the pool and the connection's Close method to
// return the connection's resources to the pool.
@@ -161,15 +174,21 @@ type Pool struct {
// the pool does not close connections based on age.
MaxConnLifetime time.Duration

// cacheStrategy controls how the Pool handles server-assisted caching,
// introduced in Redis 6.0. Caching is disabled by default.
cacheStrategy cacheStrategy

chInitialized uint32 // set to 1 when field ch is initialized

mu sync.Mutex // mu protects the following fields
closed bool // set to true when the pool is closed.
active int // the number of open connections in the pool
ch chan struct{} // limits open connections when p.Wait is true
idle idleList // idle connections
waitCount int64 // total number of connections waited for.
waitDuration time.Duration // total time waited for new connections.
mu sync.Mutex // mu protects the following fields
closed bool // set to true when the pool is closed.
active int // the number of open connections in the pool
ch chan struct{} // limits open connections when p.Wait is true
idle idleList // idle connections
waitCount int64 // total number of connections waited for.
waitDuration time.Duration // total time waited for new connections.
invalidatorID int // ID of the connection that receives cache invalidation notifications
localCache sync.Map // locally cached replies, keyed by Redis key
}

// NewPool creates a new pool.
@@ -265,9 +284,39 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
p.mu.Unlock()
return errorConn{err}, err
}

p.mu.Lock()
invalidatorID := p.invalidatorID
p.mu.Unlock()

if p.cacheStrategy == CacheStrategyInvalidate && invalidatorID != 0 {
// Ask the server to track the keys requested on this connection and
// redirect invalidation notifications to the invalidator connection
_, err = c.Do("CLIENT", "TRACKING", "on", "REDIRECT", invalidatorID)
if err != nil {
c = nil
p.mu.Lock()
p.active--
if p.ch != nil && !p.closed {
p.ch <- struct{}{}
}
p.mu.Unlock()
return errorConn{err}, err
}
}

return &activeConn{p: p, pc: &poolConn{c: c, created: nowFunc()}}, nil
}

// EnableCaching enables server-assisted caching with the given strategy.
// For CacheStrategyInvalidate it also starts the connection that listens
// for invalidation notifications.
func (p *Pool) EnableCaching(strategy cacheStrategy) error {
p.cacheStrategy = strategy
if strategy == CacheStrategyInvalidate {
return p.cacheInvalidator()
}
return nil
}

// PoolStats contains pool statistics.
type PoolStats struct {
// ActiveCount is the number of connections in the pool. The count includes
@@ -283,16 +332,27 @@ type PoolStats struct {
// WaitDuration is the total time blocked waiting for a new connection.
// This value is currently not guaranteed to be 100% accurate.
WaitDuration time.Duration

// CachedEntries is the number of keys cached locally via server-assisted caching.
CachedEntries int
}

// Stats returns pool's statistics.
func (p *Pool) Stats() PoolStats {
p.mu.Lock()
// Count the number of locally cached entries
cachedEntries := 0
p.localCache.Range(func(key, value interface{}) bool {
cachedEntries++
return true
})

stats := PoolStats{
ActiveCount: p.active,
IdleCount: p.idle.count,
WaitCount: p.waitCount,
WaitDuration: p.waitDuration,
ActiveCount: p.active,
IdleCount: p.idle.count,
WaitCount: p.waitCount,
WaitDuration: p.waitDuration,
CachedEntries: cachedEntries,
}
p.mu.Unlock()

@@ -437,6 +497,61 @@ func (p *Pool) put(pc *poolConn, forceClose bool) error {
return nil
}

// cacheInvalidator starts a housekeeping goroutine that handles cache invalidation
func (p *Pool) cacheInvalidator() error {
// Connection used for invalidation
conn, err := p.Dial()
if err != nil {
return err
}

id, err := Int(conn.Do("CLIENT", "ID"))
if err != nil {
return err
}

p.mu.Lock()
p.invalidatorID = id
p.mu.Unlock()

// Subscribe to the invalidation channel
if err := conn.Send("SUBSCRIBE", "__redis__:invalidate"); err != nil {
return err
}
if err := conn.Flush(); err != nil {
return err
}

go func() {
for {
p.mu.Lock()
if p.closed {
defer p.mu.Unlock()
conn.Close()
return
}
p.mu.Unlock()

event, err := Values(conn.Receive())
if err != nil {
continue
}

// Pub/sub replies arrive as [type, channel, payload]
if len(event) < 3 {
continue
}

etype, _ := String(event[0], nil)

if etype == "message" {
// Remove revoked keys from local cache
keys, err := Strings(event[2], nil)
if err != nil { // Redis sends a nil payload in some cases (e.g. when the database is flushed); skip it
continue
}

for _, key := range keys {
p.localCache.Delete(key)
}
}
}
}()

return nil
}

type activeConn struct {
p *Pool
pc *poolConn
@@ -513,7 +628,10 @@ func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
}
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return pc.c.Do(commandName, args...)

return ac.cacheLoader(commandName, args, func() (interface{}, error) {
return ac.pc.c.Do(commandName, args...)
})
}

func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
Expand All @@ -527,7 +645,33 @@ func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, a
}
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return cwt.DoWithTimeout(timeout, commandName, args...)

return ac.cacheLoader(commandName, args, func() (interface{}, error) {
return cwt.DoWithTimeout(timeout, commandName, args...)
})
}

// cacheLoader is loader-style middleware that caches replies locally. For now only GET commands are cached.
func (ac *activeConn) cacheLoader(commandName string, args []interface{}, f func() (interface{}, error)) (interface{}, error) {
if ac.p.cacheStrategy == CacheStrategyInvalidate && commandName == "GET" && len(args) == 1 {
key, ok := args[0].(string)
if !ok {
return f() // only plain string keys are cached
}

// Happy path
if entry, ok := ac.p.localCache.Load(key); ok {
return entry, nil
}

resp, err := f()
if err != nil {
return resp, err
}

// Cache response
ac.p.localCache.Store(key, resp)
return resp, nil
}

return f()
}

func (ac *activeConn) Send(commandName string, args ...interface{}) error {
49 changes: 49 additions & 0 deletions redis/pool_test.go
@@ -934,3 +934,52 @@ func TestWaitPoolGetCanceledContext(t *testing.T) {
}
})
}

func TestPoolServerAssistedCaching(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
Dial: d.dial,
}

if err := p.EnableCaching(redis.CacheStrategyInvalidate); err != nil {
d.t.Fatal("Unexpected error: ", err)
}

defer p.Close()

conn := p.Get()
defer conn.Close()

_, err := conn.Do("SET", "key", "value")
if err != nil {
d.t.Error(err)
}

// The GET populates the local cache and the server starts tracking the key
_, err = conn.Do("GET", "key")
if err != nil {
d.t.Error(err)
}

n := p.Stats().CachedEntries
if n != 1 {
d.t.Fatalf("got %v entries, should be 1", n)
}

conn2 := p.Get()
defer conn2.Close()

// Invalidate the cached key from another connection
_, err = conn2.Do("SET", "key", "another value")
if err != nil {
d.t.Error(err)
}

// Invalidation happens asynchronously; give it a moment to arrive
time.Sleep(time.Millisecond * 50)

n = p.Stats().CachedEntries
if n != 0 {
d.t.Fatalf("got %v entries, should be 0", n)
}
}
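
For reference, here is a minimal usage sketch of the API introduced in this commit. It is illustrative only: the import path, server address, and pool options are assumptions, and it expects a running Redis 6+ server with client-side caching support.

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/gomodule/redigo/redis" // assumed import path; use the module that contains this commit
)

func main() {
	pool := &redis.Pool{
		Dial: func() (redis.Conn, error) {
			return redis.Dial("tcp", "localhost:6379") // assumes a local Redis 6+ server
		},
		MaxIdle:     5,
		IdleTimeout: time.Minute,
	}
	defer pool.Close()

	// Dials a dedicated invalidation connection, subscribes it to
	// __redis__:invalidate, and redirects key tracking from pooled
	// connections to it.
	if err := pool.EnableCaching(redis.CacheStrategyInvalidate); err != nil {
		log.Fatal(err)
	}

	conn := pool.Get()
	defer conn.Close()

	if _, err := conn.Do("SET", "greeting", "hello"); err != nil {
		log.Fatal(err)
	}

	// The first GET asks the server and stores the reply locally; repeated
	// GETs of the same key are served from the local cache until the server
	// invalidates it.
	greeting, err := redis.String(conn.Do("GET", "greeting"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(greeting, "cached entries:", pool.Stats().CachedEntries)
}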
