diff --git a/.travis.yml b/.travis.yml index 5f29dde..2e53593 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - 1.12.6 - 1.12 - 1.11 git: diff --git a/cache.go b/cache.go index ede6abf..e52e902 100644 --- a/cache.go +++ b/cache.go @@ -13,7 +13,7 @@ type expireCallback func(key string, value interface{}) // Cache is a synchronized map of items that can auto-expire once stale type Cache struct { - mutex sync.RWMutex + mutex sync.Mutex ttl time.Duration items map[string]*item expireCallback expireCallback @@ -25,13 +25,10 @@ type Cache struct { skipTTLExtension bool } -func (cache *Cache) getItem(key string) (*item, bool) { - cache.mutex.RLock() - +func (cache *Cache) getItem(key string) (*item, bool, bool) { item, exists := cache.items[key] if !exists || item.expired() { - cache.mutex.RUnlock() - return nil, false + return nil, false, false } if item.ttl >= 0 && (item.ttl > 0 || cache.ttl > 0) { @@ -45,12 +42,15 @@ func (cache *Cache) getItem(key string) (*item, bool) { cache.priorityQueue.update(item) } - cache.mutex.RUnlock() - cache.expirationNotificationTrigger(item) - return item, exists + expirationNotification := false + if cache.expirationTime.After(time.Now().Add(item.ttl)) { + expirationNotification = true + } + return item, exists, expirationNotification } func (cache *Cache) startExpirationProcessing() { + timer := time.NewTimer(time.Hour) for { var sleepTime time.Duration cache.mutex.Lock() @@ -74,7 +74,7 @@ func (cache *Cache) startExpirationProcessing() { cache.expirationTime = time.Now().Add(sleepTime) cache.mutex.Unlock() - timer := time.NewTimer(sleepTime) + timer.Reset(sleepTime) select { case <-timer.C: timer.Stop() @@ -119,16 +119,6 @@ func (cache *Cache) startExpirationProcessing() { } } -func (cache *Cache) expirationNotificationTrigger(item *item) { - cache.mutex.Lock() - if cache.expirationTime.After(time.Now().Add(item.ttl)) { - cache.mutex.Unlock() - cache.expirationNotification <- true - } else { - cache.mutex.Unlock() - } -} - // Set is a thread-safe way to add new items to the map func (cache *Cache) Set(key string, data interface{}) { cache.SetWithTTL(key, data, ItemExpireWithGlobalTTL) @@ -136,8 +126,8 @@ func (cache *Cache) Set(key string, data interface{}) { // SetWithTTL is a thread-safe way to add new items to the map with individual ttl func (cache *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) { - item, exists := cache.getItem(key) cache.mutex.Lock() + item, exists, _ := cache.getItem(key) if exists { item.data = data @@ -170,13 +160,18 @@ func (cache *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) // Get is a thread-safe way to lookup items // Every lookup, also touches the item, hence extending it's life func (cache *Cache) Get(key string) (interface{}, bool) { - item, exists := cache.getItem(key) + cache.mutex.Lock() + item, exists, triggerExpirationNotification := cache.getItem(key) + + var dataToReturn interface{} if exists { - cache.mutex.RLock() - defer cache.mutex.RUnlock() - return item.data, true + dataToReturn = item.data } - return nil, false + cache.mutex.Unlock() + if triggerExpirationNotification { + cache.expirationNotification <- true + } + return dataToReturn, exists } func (cache *Cache) Remove(key string) bool { @@ -195,9 +190,9 @@ func (cache *Cache) Remove(key string) bool { // Count returns the number of items in the cache func (cache *Cache) Count() int { - cache.mutex.RLock() + cache.mutex.Lock() length := len(cache.items) - cache.mutex.RUnlock() + cache.mutex.Unlock() return length } diff --git a/cache_test.go b/cache_test.go index 711e4c8..272aeb5 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,6 +1,7 @@ package ttlcache import ( + "math/rand" "testing" "time" @@ -34,6 +35,50 @@ func TestCache_SkipTtlExtensionOnHit(t *testing.T) { } } +func TestCache_ForRacesAcrossGoroutines(t *testing.T) { + cache := NewCache() + cache.SetTTL(time.Minute * 1) + cache.SkipTtlExtensionOnHit(false) + + var wgSet sync.WaitGroup + var wgGet sync.WaitGroup + + n := 500 + wgSet.Add(1) + go func() { + for i := 0; i < n; i++ { + wgSet.Add(1) + + go func(i int) { + time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000))) + if i%2 == 0 { + cache.Set(fmt.Sprintf("test%d", i /10), false) + } else { + cache.SetWithTTL(fmt.Sprintf("test%d", i /10), false, time.Second*59) + } + wgSet.Done() + }(i) + } + wgSet.Done() + }() + wgGet.Add(1) + go func() { + for i := 0; i < n; i++ { + wgGet.Add(1) + + go func(i int) { + time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000))) + cache.Get(fmt.Sprintf("test%d", i /10)) + wgGet.Done() + }(i) + } + wgGet.Done() + }() + + wgGet.Wait() + wgSet.Wait() +} + func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) { cache := NewCache() cache.SetTTL(time.Minute * 1) @@ -48,10 +93,15 @@ func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) { for i := 0; i < n; i++ { wgSet.Add(1) - go func() { - cache.Set("test", false) + go func(i int) { + time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000))) + if i%2 == 0 { + cache.Set(fmt.Sprintf("test%d", i /10), false) + } else { + cache.SetWithTTL(fmt.Sprintf("test%d", i /10), false, time.Second*59) + } wgSet.Done() - }() + }(i) } wgSet.Done() }() @@ -60,10 +110,11 @@ func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) { for i := 0; i < n; i++ { wgGet.Add(1) - go func() { - cache.Get("test") + go func(i int) { + time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000))) + cache.Get(fmt.Sprintf("test%d", i /10)) wgGet.Done() - }() + }(i) } wgGet.Done() }()