Skip to content

Commit

Permalink
Merge pull request #17 from froodian/froodian-fix-data-race
Browse files Browse the repository at this point in the history
Froodian fix data race
  • Loading branch information
ReneKroon authored Apr 26, 2019
2 parents 5a11b50 + 2e915b8 commit 4a025f8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
16 changes: 8 additions & 8 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type expireCallback func(key string, value interface{})

// Cache is a synchronized map of items that can auto-expire once stale
type Cache struct {
mutex sync.Mutex
mutex sync.RWMutex
ttl time.Duration
items map[string]*item
expireCallback expireCallback
Expand All @@ -26,11 +26,11 @@ type Cache struct {
}

func (cache *Cache) getItem(key string) (*item, bool) {
cache.mutex.Lock()
cache.mutex.RLock()

item, exists := cache.items[key]
if !exists || item.expired() {
cache.mutex.Unlock()
cache.mutex.RUnlock()
return nil, false
}

Expand All @@ -45,7 +45,7 @@ func (cache *Cache) getItem(key string) (*item, bool) {
cache.priorityQueue.update(item)
}

cache.mutex.Unlock()
cache.mutex.RUnlock()
cache.expirationNotificationTrigger(item)
return item, exists
}
Expand Down Expand Up @@ -169,12 +169,13 @@ func (cache *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration)
func (cache *Cache) Get(key string) (interface{}, bool) {
item, exists := cache.getItem(key)
if exists {
cache.mutex.RLock()
defer cache.mutex.RUnlock()
return item.data, true
}
return nil, false
}

// Remove a key from the cache, returns true if the object was in the cache
func (cache *Cache) Remove(key string) bool {
cache.mutex.Lock()
object, exists := cache.items[key]
Expand All @@ -191,13 +192,12 @@ func (cache *Cache) Remove(key string) bool {

// Count returns the number of items in the cache
func (cache *Cache) Count() int {
cache.mutex.Lock()
cache.mutex.RLock()
length := len(cache.items)
cache.mutex.Unlock()
cache.mutex.RUnlock()
return length
}

// SetTTL sets the global TTL for the cache, can be overridden on a per element basis.
func (cache *Cache) SetTTL(ttl time.Duration) {
cache.mutex.Lock()
cache.ttl = ttl
Expand Down
38 changes: 38 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,44 @@ func TestCache_SkipTtlExtensionOnHit(t *testing.T) {
}
}

func TestCache_SkipTtlExtensionOnHit_ForRacesAcrossGoroutines(t *testing.T) {
cache := NewCache()
cache.SetTTL(time.Minute * 1)
cache.SkipTtlExtensionOnHit(true)

var wgSet sync.WaitGroup
var wgGet sync.WaitGroup

n := 500
wgSet.Add(1)
go func() {
for i := 0; i < n; i++ {
wgSet.Add(1)

go func() {
cache.Set("test", false)
wgSet.Done()
}()
}
wgSet.Done()
}()
wgGet.Add(1)
go func() {
for i := 0; i < n; i++ {
wgGet.Add(1)

go func() {
cache.Get("test")
wgGet.Done()
}()
}
wgGet.Done()
}()

wgGet.Wait()
wgSet.Wait()
}

// test github issue #14
// Testing expiration callback would continue with the next item in list, even when it exceeds list lengths
func TestCache_SetCheckExpirationCallback(t *testing.T) {
Expand Down

0 comments on commit 4a025f8

Please sign in to comment.