From d62ccb8009e347c7683a53acdbe8bc8d2766fd35 Mon Sep 17 00:00:00 2001 From: Timofey Date: Fri, 2 Oct 2020 11:57:15 +0300 Subject: [PATCH] Allow configuring caching directly via Pool struct --- redis/pool.go | 38 ++++++++++++++++++-------------------- redis/pool_test.go | 7 ++----- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/redis/pool.go b/redis/pool.go index b62d29e2..6f1fbba2 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -176,7 +176,7 @@ type Pool struct { // Controls how Pool handles server-assisted caching // introduced in Redis 6.0. This is disabled by default - cacheStrategy cacheStrategy + CacheStrategy cacheStrategy chInitialized uint32 // set to 1 when field ch is initialized @@ -218,6 +218,17 @@ func (p *Pool) Get() Conn { // If the function completes without error, then the application must close the // returned connection. func (p *Pool) GetContext(ctx context.Context) (Conn, error) { + // Start cache invalidator on first call to the pool + p.mu.Lock() + if p.CacheStrategy == CacheStrategyInvalidate && p.invalidatorID == 0 { + id, err := p.cacheInvalidator() + if err != nil { + return errorConn{err}, err + } + p.invalidatorID = id + } + p.mu.Unlock() + // Wait until there is a vacant connection in the pool. waited, err := p.waitVacantConn(ctx) if err != nil { @@ -289,7 +300,7 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) { invalidatorID := p.invalidatorID p.mu.Unlock() - if p.cacheStrategy == CacheStrategyInvalidate && invalidatorID != 0 { + if p.CacheStrategy == CacheStrategyInvalidate && invalidatorID != 0 { // Ask server to watch the keys requested by this connection and // send invalidation notifications to our housekeeper process _, err = c.Do("CLIENT", "TRACKING", "on", "REDIRECT", invalidatorID) @@ -308,15 +319,6 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) { return &activeConn{p: p, pc: &poolConn{c: c, created: nowFunc()}}, nil } -// EnableCaching method -func (p *Pool) EnableCaching(strategy cacheStrategy) error { - p.cacheStrategy = strategy - if strategy == CacheStrategyInvalidate { - return p.cacheInvalidator() - } - return nil -} - // PoolStats contains pool statistics. type PoolStats struct { // ActiveCount is the number of connections in the pool. The count includes @@ -498,22 +500,18 @@ func (p *Pool) put(pc *poolConn, forceClose bool) error { } // cacheInvalidator start housekeeping goroutine handling cache invalidation -func (p *Pool) cacheInvalidator() error { +func (p *Pool) cacheInvalidator() (int, error) { // Connection used for invalidation conn, err := p.Dial() if err != nil { - return err + return 0, err } id, err := Int(conn.Do("CLIENT", "ID")) if err != nil { - return err + return 0, err } - p.mu.Lock() - p.invalidatorID = id - p.mu.Unlock() - // Subscribe to revocation channel conn.Send("SUBSCRIBE", "__redis__:invalidate") conn.Flush() @@ -549,7 +547,7 @@ func (p *Pool) cacheInvalidator() error { } }() - return nil + return id, nil } type activeConn struct { @@ -653,7 +651,7 @@ func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, a // Loader-style middleware to cache requests. For now it only caches GET commands func (ac *activeConn) cacheLoader(commandName string, args []interface{}, f func() (interface{}, error)) (interface{}, error) { - if ac.p.cacheStrategy == CacheStrategyInvalidate && commandName == "GET" { + if ac.p.CacheStrategy == CacheStrategyInvalidate && commandName == "GET" { key := args[0].(string) // Happy path diff --git a/redis/pool_test.go b/redis/pool_test.go index ce0ad3b3..ad069f79 100644 --- a/redis/pool_test.go +++ b/redis/pool_test.go @@ -938,11 +938,8 @@ func TestWaitPoolGetCanceledContext(t *testing.T) { func TestPoolServerAssistedCaching(t *testing.T) { d := poolDialer{t: t} p := &redis.Pool{ - Dial: d.dial, - } - - if err := p.EnableCaching(redis.CacheStrategyInvalidate); err != nil { - d.t.Fatal("Unexpected error: ", err) + Dial: d.dial, + CacheStrategy: redis.CacheStrategyInvalidate, } defer p.Close()