Skip to content

Commit

Permalink
Merge pull request #21 from ReneKroon/issue20
Browse files Browse the repository at this point in the history
Issue20
  • Loading branch information
ReneKroon authored Jun 17, 2019
2 parents 841717f + 13264ef commit 823b876
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 35 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
language: go

go:
- 1.12.6
- 1.12
- 1.11
git:
Expand Down
51 changes: 23 additions & 28 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type expireCallback func(key string, value interface{})

// Cache is a synchronized map of items that can auto-expire once stale
type Cache struct {
mutex sync.RWMutex
mutex sync.Mutex
ttl time.Duration
items map[string]*item
expireCallback expireCallback
Expand All @@ -25,13 +25,10 @@ type Cache struct {
skipTTLExtension bool
}

func (cache *Cache) getItem(key string) (*item, bool) {
cache.mutex.RLock()

func (cache *Cache) getItem(key string) (*item, bool, bool) {
item, exists := cache.items[key]
if !exists || item.expired() {
cache.mutex.RUnlock()
return nil, false
return nil, false, false
}

if item.ttl >= 0 && (item.ttl > 0 || cache.ttl > 0) {
Expand All @@ -45,12 +42,15 @@ func (cache *Cache) getItem(key string) (*item, bool) {
cache.priorityQueue.update(item)
}

cache.mutex.RUnlock()
cache.expirationNotificationTrigger(item)
return item, exists
expirationNotification := false
if cache.expirationTime.After(time.Now().Add(item.ttl)) {
expirationNotification = true
}
return item, exists, expirationNotification
}

func (cache *Cache) startExpirationProcessing() {
timer := time.NewTimer(time.Hour)
for {
var sleepTime time.Duration
cache.mutex.Lock()
Expand All @@ -74,7 +74,7 @@ func (cache *Cache) startExpirationProcessing() {
cache.expirationTime = time.Now().Add(sleepTime)
cache.mutex.Unlock()

timer := time.NewTimer(sleepTime)
timer.Reset(sleepTime)
select {
case <-timer.C:
timer.Stop()
Expand Down Expand Up @@ -119,25 +119,15 @@ func (cache *Cache) startExpirationProcessing() {
}
}

func (cache *Cache) expirationNotificationTrigger(item *item) {
cache.mutex.Lock()
if cache.expirationTime.After(time.Now().Add(item.ttl)) {
cache.mutex.Unlock()
cache.expirationNotification <- true
} else {
cache.mutex.Unlock()
}
}

// Set is a thread-safe way to add new items to the map
func (cache *Cache) Set(key string, data interface{}) {
cache.SetWithTTL(key, data, ItemExpireWithGlobalTTL)
}

// SetWithTTL is a thread-safe way to add new items to the map with individual ttl
func (cache *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) {
item, exists := cache.getItem(key)
cache.mutex.Lock()
item, exists, _ := cache.getItem(key)

if exists {
item.data = data
Expand Down Expand Up @@ -170,13 +160,18 @@ func (cache *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration)
// Get is a thread-safe way to lookup items
// Every lookup, also touches the item, hence extending it's life
func (cache *Cache) Get(key string) (interface{}, bool) {
item, exists := cache.getItem(key)
cache.mutex.Lock()
item, exists, triggerExpirationNotification := cache.getItem(key)

var dataToReturn interface{}
if exists {
cache.mutex.RLock()
defer cache.mutex.RUnlock()
return item.data, true
dataToReturn = item.data
}
return nil, false
cache.mutex.Unlock()
if triggerExpirationNotification {
cache.expirationNotification <- true
}
return dataToReturn, exists
}

func (cache *Cache) Remove(key string) bool {
Expand All @@ -195,9 +190,9 @@ func (cache *Cache) Remove(key string) bool {

// Count returns the number of items in the cache
func (cache *Cache) Count() int {
cache.mutex.RLock()
cache.mutex.Lock()
length := len(cache.items)
cache.mutex.RUnlock()
cache.mutex.Unlock()
return length
}

Expand Down
63 changes: 57 additions & 6 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ttlcache

import (
"math/rand"
"testing"
"time"

Expand Down Expand Up @@ -34,6 +35,50 @@ func TestCache_SkipTtlExtensionOnHit(t *testing.T) {
}
}

func TestCache_ForRacesAcrossGoroutines(t *testing.T) {
cache := NewCache()
cache.SetTTL(time.Minute * 1)
cache.SkipTtlExtensionOnHit(false)

var wgSet sync.WaitGroup
var wgGet sync.WaitGroup

n := 500
wgSet.Add(1)
go func() {
for i := 0; i < n; i++ {
wgSet.Add(1)

go func(i int) {
time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000)))
if i%2 == 0 {
cache.Set(fmt.Sprintf("test%d", i /10), false)
} else {
cache.SetWithTTL(fmt.Sprintf("test%d", i /10), false, time.Second*59)
}
wgSet.Done()
}(i)
}
wgSet.Done()
}()
wgGet.Add(1)
go func() {
for i := 0; i < n; i++ {
wgGet.Add(1)

go func(i int) {
time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000)))
cache.Get(fmt.Sprintf("test%d", i /10))
wgGet.Done()
}(i)
}
wgGet.Done()
}()

wgGet.Wait()
wgSet.Wait()
}

func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) {
cache := NewCache()
cache.SetTTL(time.Minute * 1)
Expand All @@ -48,10 +93,15 @@ func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) {
for i := 0; i < n; i++ {
wgSet.Add(1)

go func() {
cache.Set("test", false)
go func(i int) {
time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000)))
if i%2 == 0 {
cache.Set(fmt.Sprintf("test%d", i /10), false)
} else {
cache.SetWithTTL(fmt.Sprintf("test%d", i /10), false, time.Second*59)
}
wgSet.Done()
}()
}(i)
}
wgSet.Done()
}()
Expand All @@ -60,10 +110,11 @@ func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) {
for i := 0; i < n; i++ {
wgGet.Add(1)

go func() {
cache.Get("test")
go func(i int) {
time.Sleep(time.Nanosecond * time.Duration(rand.Int63n(1000000)))
cache.Get(fmt.Sprintf("test%d", i /10))
wgGet.Done()
}()
}(i)
}
wgGet.Done()
}()
Expand Down

0 comments on commit 823b876

Please sign in to comment.