diff --git a/kit/localcache/README.md b/kit/localcache/README.md new file mode 100644 index 0000000..1de42c3 --- /dev/null +++ b/kit/localcache/README.md @@ -0,0 +1,2 @@ +# go-localcache +in-process cache writen in go and managed by timingwheel diff --git a/kit/localcache/internal/atomic.go b/kit/localcache/internal/atomic.go new file mode 100644 index 0000000..5872c90 --- /dev/null +++ b/kit/localcache/internal/atomic.go @@ -0,0 +1,39 @@ +package internal + +import "sync/atomic" + +type AtomicBool uint32 + +func NewAtomicBool() *AtomicBool { + return new(AtomicBool) +} + +func ForAtomicBool(val bool) *AtomicBool { + ab := NewAtomicBool() + ab.Set(val) + return ab +} + +func (a *AtomicBool) CompareAndSwap(old, new bool) bool { + var oldV, newV uint32 + if old { + oldV = 1 + } + if new { + newV = 1 + } + + return atomic.CompareAndSwapUint32((*uint32)(a), oldV, newV) +} + +func (a *AtomicBool) Set(new bool) { + if new { + atomic.StoreUint32((*uint32)(a), 1) + } else { + atomic.StoreUint32((*uint32)(a), 0) + } +} + +func (a *AtomicBool) True() bool { + return atomic.LoadUint32((*uint32)(a)) == 1 +} diff --git a/kit/localcache/internal/lru.go b/kit/localcache/internal/lru.go new file mode 100644 index 0000000..e420e11 --- /dev/null +++ b/kit/localcache/internal/lru.go @@ -0,0 +1,82 @@ +package internal + +import "container/list" + +type ( + Lru interface { + Add(key string) + Remove(key string) + } + + noneLru struct{} + + keyLru struct { + limit int + evicts *list.List + elements map[string]*list.Element + onEvict func(key string) + } +) + +var ( + _ Lru = &noneLru{} + _ Lru = &keyLru{} +) + +// NewNoneLru return an empty lru implement, do not manager keys. +// when cache have a limit of count, use this to make the flow correct +func NewNoneLru() Lru { + return &noneLru{} +} + +func (l *noneLru) Add(key string) {} + +func (l *noneLru) Remove(key string) {} + +// NewLru return a Lru entry with least-recently-use algorithm +func NewLru(limit int, onEvict func(key string)) Lru { + return &keyLru{ + limit: limit, + evicts: list.New(), + elements: make(map[string]*list.Element), + onEvict: onEvict, + } +} + +func (l *keyLru) Remove(key string) { + if elem, ok := l.elements[key]; ok { + l.removeElem(elem) + } +} + +func (l *keyLru) Add(key string) { + if v, ok := l.elements[key]; ok { + // 元素存在, 移至队首 + l.evicts.MoveToFront(v) + return + } + + // 新增元素 + elem := l.evicts.PushFront(key) + l.elements[key] = elem + + // 超出列表长度, 移除队尾元素 + if l.evicts.Len() > l.limit { + l.removeOldest() + } +} + +func (l *keyLru) removeOldest() { + elem := l.evicts.Back() + l.removeElem(elem) +} + +func (l *keyLru) removeElem(e *list.Element) { + if e == nil { + return + } + l.evicts.Remove(e) + key := e.Value.(string) + delete(l.elements, key) + l.onEvict(key) +} diff --git a/kit/localcache/internal/safemap.go b/kit/localcache/internal/safemap.go new file mode 100644 index 0000000..267dfa8 --- /dev/null +++ b/kit/localcache/internal/safemap.go @@ -0,0 +1,98 @@ +package internal + +import ( + "sync" +) + +const ( + copyThreshold = 1000 + maxDeletion = 10000 +) + +// SafeMap provides a map alternative to avoid memory leak. +// This implementation is not needed until issue below fixed. +// https://github.com/golang/go/issues/20135 +type SafeMap struct { + lock sync.RWMutex + deletionOld int + deletionNew int + dirtyOld map[interface{}]interface{} + dirtyNew map[interface{}]interface{} +} + +// NewSafeMap returns a SafeMap. +func NewSafeMap() *SafeMap { + return &SafeMap{ + dirtyOld: make(map[interface{}]interface{}), + dirtyNew: make(map[interface{}]interface{}), + } +} + +// Del deletes the value with the given key from m. +func (m *SafeMap) Del(key interface{}) { + m.lock.Lock() + if _, ok := m.dirtyOld[key]; ok { + delete(m.dirtyOld, key) + m.deletionOld++ + } else if _, ok := m.dirtyNew[key]; ok { + delete(m.dirtyNew, key) + m.deletionNew++ + } + if m.deletionOld >= maxDeletion && len(m.dirtyOld) < copyThreshold { + for k, v := range m.dirtyOld { + m.dirtyNew[k] = v + } + m.dirtyOld = m.dirtyNew + m.deletionOld = m.deletionNew + m.dirtyNew = make(map[interface{}]interface{}) + m.deletionNew = 0 + } + if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold { + for k, v := range m.dirtyNew { + m.dirtyOld[k] = v + } + m.dirtyNew = make(map[interface{}]interface{}) + m.deletionNew = 0 + } + m.lock.Unlock() +} + +// Get gets the value with the given key from m. +func (m *SafeMap) Get(key interface{}) (interface{}, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + + if val, ok := m.dirtyOld[key]; ok { + return val, true + } + + val, ok := m.dirtyNew[key] + return val, ok +} + +// Set sets the value into m with the given key. +func (m *SafeMap) Set(key, value interface{}) { + m.lock.Lock() + if m.deletionOld <= maxDeletion { + if _, ok := m.dirtyNew[key]; ok { + delete(m.dirtyNew, key) + m.deletionNew++ + } + m.dirtyOld[key] = value + } else { + if _, ok := m.dirtyOld[key]; ok { + delete(m.dirtyOld, key) + m.deletionOld++ + } + m.dirtyNew[key] = value + } + m.lock.Unlock() +} + +// Size returns the size of m. +func (m *SafeMap) Size() int { + m.lock.RLock() + size := len(m.dirtyOld) + len(m.dirtyNew) + m.lock.RUnlock() + return size +} diff --git a/kit/localcache/internal/stat.go b/kit/localcache/internal/stat.go new file mode 100644 index 0000000..2d5a406 --- /dev/null +++ b/kit/localcache/internal/stat.go @@ -0,0 +1,63 @@ +package internal + +import ( + "log" + "sync/atomic" + "time" +) + +//const statInterval = time.Minute +const statInterval = time.Second + +type Stat struct { + name string + hit uint64 + miss uint64 + sizeCallback func() int +} + +func NewStat(name string, sizeCallback func() int) *Stat { + st := &Stat{ + name: name, + sizeCallback: sizeCallback, + } + go st.report() + + return st +} + +// Hit hit counter++ +func (s *Stat) Hit() { + atomic.AddUint64(&s.hit, 1) +} + +// Miss missed counter++ +func (s *Stat) Miss() { + atomic.AddUint64(&s.miss, 1) +} + +// CurrentMinute return hit and missed counter in a minute +func (s *Stat) CurrentMinute() (hit, miss uint64) { + hit = atomic.LoadUint64(&s.hit) + miss = atomic.LoadUint64(&s.miss) + return hit, miss +} + +func (s *Stat) report() { + ticker := time.NewTicker(statInterval) + defer ticker.Stop() + + for range ticker.C { + hit := atomic.SwapUint64(&s.hit, 0) + miss := atomic.SwapUint64(&s.miss, 0) + total := hit + miss + if total == 0 { + log.Printf("cache(%s) - continue", s.name) + continue + } + + percent := float32(hit) / float32(total) + log.Printf("cache(%s) - qpm: %d, hit_ratio: %.1f%%, elements: %d, hit: %d, miss: %d", + s.name, total, percent*100, s.sizeCallback(), hit, miss) + } +} diff --git a/kit/localcache/internal/ticker.go b/kit/localcache/internal/ticker.go new file mode 100644 index 0000000..08a7e89 --- /dev/null +++ b/kit/localcache/internal/ticker.go @@ -0,0 +1,73 @@ +package internal + +import ( + "errors" + "time" +) + +type ( + Ticker interface { + Chan() <-chan time.Time + Stop() + } + + // FakeTicker for test + FakeTicker interface { + Ticker + Done() + Tick() + Wait(d time.Duration) error + } + + fakeTicker struct { + c chan time.Time + done chan struct{} + } + + realTicker struct { + *time.Ticker + } +) + +func NewTicker(d time.Duration) Ticker { + return &realTicker{ + Ticker: time.NewTicker(d), + } +} + +// Chan implement Ticker +func (r *realTicker) Chan() <-chan time.Time { + return r.C +} + +func NewFakeTicker() FakeTicker { + return &fakeTicker{ + c: make(chan time.Time, 1), + done: make(chan struct{}, 1), + } +} + +func (f *fakeTicker) Chan() <-chan time.Time { + return f.c +} + +func (f *fakeTicker) Stop() { + close(f.c) +} + +func (f *fakeTicker) Done() { + f.done <- struct{}{} +} + +func (f *fakeTicker) Tick() { + f.c <- time.Now() +} + +func (f *fakeTicker) Wait(d time.Duration) error { + select { + case <-time.After(d): + return errors.New("timeout") + case <-f.done: + return nil + } +} diff --git a/kit/localcache/internal/timingwheel.go b/kit/localcache/internal/timingwheel.go new file mode 100644 index 0000000..a705698 --- /dev/null +++ b/kit/localcache/internal/timingwheel.go @@ -0,0 +1,346 @@ +package internal + +import ( + "container/list" + "fmt" + "time" + + "github.com/sado0823/go-kitx/pkg/syncx" +) + +const drainWorkers = 8 + +type ( + Execute func(key, value interface{}) + + TimingWheel struct { + interval time.Duration + ticker Ticker + slots []*list.List + timers *SafeMap + tickedPos int + numSlots int + execute Execute + + setChan chan timingEntry + moveChan chan baseEntry + removeChan chan interface{} + drainChan chan func(key, value interface{}) + stopChan chan struct{} + } + + baseEntry struct { + delay time.Duration + key interface{} + } + + timingEntry struct { + baseEntry + value interface{} + circle int + diff int + removed bool + } + + positionEntry struct { + pos int + item *timingEntry + } + + timingTask struct { + key, value interface{} + } +) + +func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { + if interval <= 0 || numSlots <= 0 || execute == nil { + return nil, fmt.Errorf("invalid param, interval: %v, numSlots: %d, execute: %p", interval, numSlots, execute) + } + + return newTimingWheel(interval, numSlots, execute, NewTicker(interval)) +} + +func newTimingWheel(interval time.Duration, numSlots int, execute Execute, ticker Ticker) (*TimingWheel, error) { + tw := &TimingWheel{ + interval: interval, + ticker: ticker, + slots: make([]*list.List, numSlots), + timers: NewSafeMap(), + tickedPos: numSlots - 1, + numSlots: numSlots, + execute: execute, + + setChan: make(chan timingEntry), + moveChan: make(chan baseEntry), + removeChan: make(chan interface{}), + drainChan: make(chan func(key, value interface{})), + stopChan: make(chan struct{}), + } + + tw.initSlots() + go tw.run() + + return tw, nil +} + +// Drain 执行所有任务 +func (tw *TimingWheel) Drain(fn func(key, value interface{})) { + tw.drainChan <- fn +} + +// MoveTimer 在时间轮上根据指定时间移动指定任务 +func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) { + if delay <= 0 || key == nil { + return + } + + tw.moveChan <- baseEntry{ + delay: delay, + key: key, + } +} + +// RemoveTimer 移除时间轮上的指定任务 +func (tw *TimingWheel) RemoveTimer(key interface{}) { + if key == nil { + return + } + + tw.removeChan <- key +} + +// SetTimer 在时间轮上新增任务 +func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { + if delay <= 0 || key == nil { + return + } + + tw.setChan <- timingEntry{ + baseEntry: baseEntry{ + delay: delay, + key: key, + }, + value: value, + } +} + +// Stop 停止时间轮任务轮询 +func (tw TimingWheel) Stop() { + close(tw.stopChan) +} + +func (tw *TimingWheel) initSlots() { + for i := range tw.slots { + tw.slots[i] = list.New() + } +} + +// run chan通信, 多事件轮询 +func (tw *TimingWheel) run() { + for { + select { + case <-tw.ticker.Chan(): + tw.onTick() + case task := <-tw.setChan: + tw.setTask(&task) + case key := <-tw.removeChan: + tw.removeTask(key) + case task := <-tw.moveChan: + tw.moveTask(task) + case fn := <-tw.drainChan: + tw.drainAll(fn) + case <-tw.stopChan: + tw.ticker.Stop() + return + } + } +} + +// 清空所有任务, 能执行优先执行 +func (tw *TimingWheel) drainAll(fn func(key, value interface{})) { + workers := make(chan struct{}, drainWorkers) + + for _, slot := range tw.slots { + for e := slot.Front(); e != nil; { + task := e.Value.(*timingEntry) + next := e.Next() + slot.Remove(e) + e = next + if !task.removed { + workers <- struct{}{} + syncx.GoSave(func() { + defer func() { + <-workers + }() + fn(task.key, task.value) + }) + } + } + } +} + +func (tw *TimingWheel) removeTask(key interface{}) { + val, ok := tw.timers.Get(key) + if !ok { + return + } + + timer := val.(*positionEntry) + timer.item.removed = true + tw.timers.Del(key) +} + +func (tw *TimingWheel) setTask(task *timingEntry) { + // 轮盘最小时间滚动刻度 + if task.delay < tw.interval { + task.delay = tw.interval + } + + if val, ok := tw.timers.Get(task.key); ok { + // 有相同任务 + entry := val.(*positionEntry) + entry.item.value = task.value + tw.moveTask(task.baseEntry) + } else { + // 全新添加任务 + pos, circle := tw.getPosAndCircle(task.delay) + task.circle = circle + tw.slots[pos].PushBack(task) + tw.setTimerPosition(pos, task) + } +} + +// 移动任务 +// 是否能够获取当前任务 +// +// 不能: 退出 +// 能: +// 1) 任务时间小于轮盘时间刻度: 执行任务 +// 2) 多层时间轮任务: 层级-1, 计算层级差值 +// 3) 标记删除旧任务, 设置新任务 +func (tw *TimingWheel) moveTask(task baseEntry) { + val, ok := tw.timers.Get(task.key) + if !ok { + return + } + + timer := val.(*positionEntry) + if task.delay < tw.interval { + syncx.GoSave(func() { + tw.execute(timer.item.key, timer.item.value) + }) + return + } + + pos, circle := tw.getPosAndCircle(task.delay) + if pos >= timer.pos { + // 新任务延后执行 + timer.item.circle = circle + timer.item.diff = pos - timer.pos + } else if circle > 0 { + // 多层级任务 + circle-- + timer.item.circle = circle + timer.item.diff = tw.numSlots + pos - timer.pos + } else { + // 标记为删除任务 + timer.item.removed = true + newItem := &timingEntry{ + baseEntry: task, + value: timer.item.value, + } + tw.slots[pos].PushBack(newItem) + tw.setTimerPosition(pos, newItem) + } + +} + +func (tw *TimingWheel) getPosAndCircle(d time.Duration) (pos, circle int) { + steps := int(d / tw.interval) + pos = (tw.tickedPos + steps) % tw.numSlots + circle = (steps - 1) / tw.numSlots + + return pos, circle +} + +func (tw *TimingWheel) onTick() { + tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots + taskList := tw.slots[tw.tickedPos] + + tw.scanAndRun(taskList) +} + +func (tw *TimingWheel) scanAndRun(taskList *list.List) { + tasks := tw.scanTask(taskList) + tw.runTask(tasks) +} + +// scanTask 轮询所有任务, 执行/删除/移动... +func (tw *TimingWheel) scanTask(taskList *list.List) []timingTask { + var tasks []timingTask + + for e := taskList.Front(); e != nil; { + task := e.Value.(*timingEntry) + if task.removed { + // 标记为可移除任务 + next := e.Next() + taskList.Remove(e) + e = next + continue + } else if task.circle > 0 { + // 多层时间轮任务, 层级 -1 + task.circle-- + e = e.Next() + continue + } else if task.diff > 0 { + // 多层时间轮任务到最底层, 通过diff判断差值, 放入当前轮盘的任务链表中 + next := e.Next() + taskList.Remove(e) + + pos := (tw.tickedPos + task.diff) % tw.numSlots + tw.slots[pos].PushBack(task) + tw.setTimerPosition(pos, task) + task.diff = 0 + e = next + continue + } + + tasks = append(tasks, timingTask{ + key: task.key, + value: task.value, + }) + next := e.Next() + taskList.Remove(e) + tw.timers.Del(task.key) + e = next + } + + return tasks +} + +func (tw *TimingWheel) runTask(tasks []timingTask) { + if len(tasks) == 0 { + return + } + + go func() { + for i := range tasks { + syncx.GoSave(func() { + tw.execute(tasks[i].key, tasks[i].value) + }) + } + }() +} + +func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) { + if v, ok := tw.timers.Get(task.key); ok { + timer := v.(*positionEntry) + timer.item = task + timer.pos = pos + } else { + tw.timers.Set(task.key, &positionEntry{ + pos: pos, + item: task, + }) + } +} diff --git a/kit/localcache/internal/timingwheel_test.go b/kit/localcache/internal/timingwheel_test.go new file mode 100644 index 0000000..150dad2 --- /dev/null +++ b/kit/localcache/internal/timingwheel_test.go @@ -0,0 +1,285 @@ +package internal + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + testStep = time.Minute + waitTime = time.Second +) + +func Test_NewTimingWheel(t *testing.T) { + + testCases := []struct { + name string + interval time.Duration + numSlots int + execute func(key, value interface{}) + pass bool + }{ + { + // error interval + name: "error interval", interval: 0, numSlots: 10, execute: nil, pass: false, + }, + { + // error numSlots + name: "error numSlots", interval: testStep, numSlots: 0, execute: nil, pass: false, + }, + { + // err execute + name: "err execute", interval: testStep, numSlots: 10, execute: nil, pass: false, + }, + { + // correct + name: "correct", interval: testStep, numSlots: 10, execute: func(key, value interface{}) {}, pass: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + tw, err := NewTimingWheel(testCase.interval, testCase.numSlots, testCase.execute) + if testCase.pass { + assert.Nil(t, err) + defer tw.Stop() + } else { + assert.NotNil(t, err) + } + }) + } +} + +// 立刻执行全部任务 +func Test_TimingWheel_Drain(t *testing.T) { + ticker := NewFakeTicker() + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) {}, ticker) + defer tw.Stop() + + tw.SetTimer("a", 1, testStep*1) + tw.SetTimer("b", 3, testStep*3) + tw.SetTimer("c", 5, testStep*5) + + var ( + keys []string + vals []int + mu sync.Mutex + wg sync.WaitGroup + ) + + wg.Add(3) + tw.Drain(func(key, value interface{}) { + mu.Lock() + defer mu.Unlock() + defer wg.Done() + keys = append(keys, key.(string)) + vals = append(vals, value.(int)) + }) + + wg.Wait() + + assert.ElementsMatch(t, keys, []string{"a", "b", "c"}) + assert.ElementsMatch(t, vals, []int{1, 3, 5}) + + var counter int + tw.Drain(func(key, value interface{}) { + counter++ + }) + time.Sleep(time.Millisecond * 100) + assert.Equal(t, 0, counter) +} + +// 任务时间小于轮盘时钟滚动时间, 会立即执行任务 +func Test_TimingWheel_SetTimerSoon(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", key) + assert.Equal(t, 3, value.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + + tw.SetTimer("any", 3, testStep>>1) + ticker.Tick() + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +// 统一任务设置不同的执行时间, 以较长的时间为准 +func Test_TimingWheel_SetTimerTwice(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "a", key) + assert.Equal(t, 5, value.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + + tw.SetTimer("a", 3, testStep*3) + tw.SetTimer("a", 5, testStep*5) + + for i := 0; i < 6; i++ { + ticker.Tick() + } + + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func Test_TimingWheel_SetTimerWrongDelay(t *testing.T) { + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) {}, ticker) + defer tw.Stop() + + assert.NotPanics(t, func() { + tw.SetTimer("a", 3, -testStep) + }) +} + +// 移动时间轮任务 +func Test_TimingWheel_MoveTimer(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 3, func(key, value interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "a", key) + assert.Equal(t, 3, value.(int)) + ticker.Done() + }, ticker) + + tw.SetTimer("a", 3, testStep*4) + tw.MoveTimer("a", testStep*7) + tw.MoveTimer("a", -testStep*7) + tw.MoveTimer("none", testStep) + + for i := 0; i < 5; i++ { + ticker.Tick() + } + assert.False(t, run.True()) + + for i := 0; i < 3; i++ { + ticker.Tick() + } + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +// 移动时间轮, 并且移动的时间小于时间轮刻度 +func Test_TimingWheel_MoveTimerSoon(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 3, func(key, value interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "a", key) + assert.Equal(t, 3, value.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + + tw.SetTimer("a", 3, testStep*4) + tw.MoveTimer("a", testStep>>1) + + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func Test_TimingWheel_MoveTimerEarlier(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 3, func(key, value interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "a", key) + assert.Equal(t, 3, value.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + + tw.SetTimer("a", 3, testStep*7) + tw.MoveTimer("a", testStep*3) + + for i := 0; i < 4; i++ { + ticker.Tick() + } + + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func Test_TimingWheel_RemoveTimer(t *testing.T) { + run := NewAtomicBool() + ticker := NewFakeTicker() + + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) { + run.CompareAndSwap(false, true) + ticker.Done() + }, ticker) + tw.SetTimer("a", 3, testStep*3) + + assert.NotPanics(t, func() { + tw.RemoveTimer("a") + tw.RemoveTimer("none") + tw.RemoveTimer(nil) + }) + + for i := 0; i < 4; i++ { + ticker.Tick() + } + + tw.Stop() + assert.False(t, run.True()) +} + +func Test_TimingWheel_MoveAndRemoveTimer(t *testing.T) { + ticker := NewFakeTicker() + tick := func(counter int) { + for i := 0; i < counter; i++ { + ticker.Tick() + } + } + + var keys []int + tw, _ := newTimingWheel(testStep, 10, func(key, value interface{}) { + assert.Equal(t, "foo", key) + assert.Equal(t, 3, value.(int)) + keys = append(keys, value.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + + tw.SetTimer("foo", 3, testStep*8) + tick(6) + + tw.MoveTimer("foo", testStep*7) + tick(3) + assert.Equal(t, 0, len(keys)) + + tw.RemoveTimer("foo") + tick(30) + time.Sleep(time.Millisecond) + assert.Equal(t, 0, len(keys)) +} + +func BenchmarkTimingWheel(b *testing.B) { + b.ReportAllocs() + + tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {}) + for i := 0; i < b.N; i++ { + tw.SetTimer(i, i, time.Second) + tw.SetTimer(b.N+i, b.N+i, time.Second) + tw.MoveTimer(i, time.Second*time.Duration(i)) + tw.RemoveTimer(i) + } +} diff --git a/kit/localcache/localcache.go b/kit/localcache/localcache.go new file mode 100644 index 0000000..d6200eb --- /dev/null +++ b/kit/localcache/localcache.go @@ -0,0 +1,175 @@ +package localcache + +import ( + "context" + "sync" + "time" + + "github.com/sado0823/go-kitx/kit/localcache/internal" + + "golang.org/x/sync/singleflight" +) + +const ( + defaultName = "proc" + timingWheelSlots = 300 + timingWheelInterval = time.Second +) + +type ( + Cache struct { + name string + lock sync.Mutex + data map[string]interface{} + expire time.Duration + + lru internal.Lru + timingWheel *internal.TimingWheel + sf *singleflight.Group + stat *internal.Stat + } + + Option func(cache *Cache) +) + +func WithName(name string) Option { + return func(cache *Cache) { + cache.name = name + } +} + +func WithLimit(limit int) Option { + return func(cache *Cache) { + cache.lru = internal.NewLru(limit, cache.onEvict) + } +} + +func New(expire time.Duration, opts ...Option) (cache *Cache, err error) { + cache = &Cache{ + data: make(map[string]interface{}), + expire: expire, + lru: internal.NewNoneLru(), + sf: &singleflight.Group{}, + } + + for _, opt := range opts { + opt(cache) + } + + if len(cache.name) == 0 { + cache.name = defaultName + } + + cache.stat = internal.NewStat(cache.name, cache.size) + + var tw *internal.TimingWheel + tw, err = internal.NewTimingWheel(timingWheelInterval, timingWheelSlots, func(key, value interface{}) { + v, ok := key.(string) + if !ok { + return + } + + cache.Del(context.Background(), v) + }) + if err != nil { + return nil, err + } + + cache.timingWheel = tw + return cache, nil +} + +func (c *Cache) Take(ctx context.Context, key string, fetch func(ctx context.Context) (interface{}, error)) (interface{}, error) { + if val, ok := c.doGet(ctx, key); ok { + c.stat.Hit() + return val, nil + } + + var fresh bool + val, err, _ := c.sf.Do(key, func() (interface{}, error) { + // double check + if val, ok := c.doGet(ctx, key); ok { + c.stat.Hit() + return val, nil + } + + v, err := fetch(ctx) + if err != nil { + return nil, err + } + + fresh = true + c.Set(ctx, key, v) + return v, nil + }) + if err != nil { + return nil, err + } + + if fresh { + c.stat.Miss() + return val, nil + } + + c.stat.Hit() + return val, nil +} + +func (c *Cache) Set(_ context.Context, key string, value interface{}) { + c.lock.Lock() + _, ok := c.data[key] + c.data[key] = value + c.lru.Add(key) + c.lock.Unlock() + + if ok { + c.timingWheel.MoveTimer(key, c.expire) + } else { + c.timingWheel.SetTimer(key, value, c.expire) + } +} + +func (c *Cache) Get(ctx context.Context, key string) (value interface{}, ok bool) { + value, ok = c.doGet(ctx, key) + if ok { + c.stat.Hit() + } else { + c.stat.Miss() + } + + return value, ok +} + +func (c *Cache) Del(_ context.Context, key string) { + c.lock.Lock() + delete(c.data, key) + c.lru.Remove(key) + c.lock.Unlock() + + // using chan + c.timingWheel.RemoveTimer(key) +} + +func (c *Cache) doGet(_ context.Context, key string) (value interface{}, ok bool) { + c.lock.Lock() + defer c.lock.Unlock() + + value, ok = c.data[key] + if ok { + c.lru.Add(key) + } + + return value, ok +} + +func (c *Cache) onEvict(key string) { + // already locked + delete(c.data, key) + c.timingWheel.RemoveTimer(key) +} + +func (c *Cache) size() int { + c.lock.Lock() + defer c.lock.Unlock() + return len(c.data) +} diff --git a/kit/localcache/localcache_test.go b/kit/localcache/localcache_test.go new file mode 100644 index 0000000..b0a9c4f --- /dev/null +++ b/kit/localcache/localcache_test.go @@ -0,0 +1,242 @@ +package localcache + +import ( + "context" + "errors" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Set(t *testing.T) { + cache, err := New(time.Second*3, WithName("foo")) + assert.Nil(t, err) + + ctx := context.Background() + + cache.Set(ctx, "foo", "bar") + cache.Set(ctx, "foo2", "bar2") + cache.Set(ctx, "foo3", "bar3") + + value, ok := cache.Get(ctx, "foo") + assert.True(t, ok) + assert.Equal(t, "bar", value) + + value, ok = cache.Get(ctx, "foo2") + assert.True(t, ok) + assert.Equal(t, "bar2", value) +} + +func Test_Del(t *testing.T) { + cache, err := New(time.Second * 3) + assert.Nil(t, err) + + ctx := context.Background() + + cache.Set(ctx, "foo", "bar") + cache.Set(ctx, "foo2", "bar2") + + cache.Del(ctx, "foo") + + value, ok := cache.Get(ctx, "foo") + assert.False(t, ok) + assert.Nil(t, value) + + value, ok = cache.Get(ctx, "foo2") + assert.True(t, ok) + assert.Equal(t, "bar2", value) +} + +func Test_Take(t *testing.T) { + cache, err := New(time.Second * 3) + assert.Nil(t, err) + + var ( + counter int32 + wg sync.WaitGroup + ctx = context.Background() + ) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + value, errN := cache.Take(ctx, "foo", func(ctx context.Context) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(time.Millisecond * 100) + return "bar", nil + }) + assert.Equal(t, "bar", value) + assert.Nil(t, errN) + }() + } + + wg.Wait() + + assert.Equal(t, 1, cache.size()) + assert.Equal(t, int32(1), atomic.LoadInt32(&counter)) +} + +func Test_TakeExists(t *testing.T) { + cache, err := New(time.Second * 3) + assert.Nil(t, err) + + var ( + counter int32 + wg sync.WaitGroup + ctx = context.Background() + ) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cache.Set(ctx, "foo", "bar") + value, errN := cache.Take(ctx, "foo", func(ctx context.Context) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(time.Millisecond * 100) + return "bar", nil + }) + assert.Equal(t, "bar", value) + assert.Nil(t, errN) + }() + } + + wg.Wait() + + assert.Equal(t, 1, cache.size()) + assert.Equal(t, int32(0), atomic.LoadInt32(&counter)) +} + +func Test_TakeError(t *testing.T) { + cache, err := New(time.Second * 3) + assert.Nil(t, err) + + var ( + counter int32 + wg sync.WaitGroup + ctx = context.Background() + errNoob = errors.New("noob") + ) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + value, errN := cache.Take(ctx, "foo", func(ctx context.Context) (interface{}, error) { + atomic.AddInt32(&counter, 1) + time.Sleep(time.Millisecond * 100) + return nil, errNoob + }) + assert.Nil(t, value) + assert.ErrorIs(t, errN, errNoob) + }() + } + + wg.Wait() + + assert.Equal(t, 0, cache.size()) + assert.Equal(t, int32(1), atomic.LoadInt32(&counter)) +} + +func Test_WithLruEvicts(t *testing.T) { + cache, err := New(time.Second*3, WithLimit(3)) + assert.Nil(t, err) + + var ( + ctx = context.Background() + ) + + cache.Set(ctx, "foo1", "bar1") + cache.Set(ctx, "foo2", "bar2") + cache.Set(ctx, "foo3", "bar3") + cache.Set(ctx, "foo4", "bar4") + + get, ok := cache.Get(ctx, "foo1") + assert.False(t, ok) + assert.Nil(t, get) + + get, ok = cache.Get(ctx, "foo2") + assert.True(t, ok) + assert.Equal(t, "bar2", get) + + get, ok = cache.Get(ctx, "foo3") + assert.True(t, ok) + assert.Equal(t, "bar3", get) + + get, ok = cache.Get(ctx, "foo4") + assert.True(t, ok) + assert.Equal(t, "bar4", get) + +} + +func Test_WithLruEvicted(t *testing.T) { + cache, err := New(time.Second*3, WithLimit(3)) + assert.Nil(t, err) + + var ( + ctx = context.Background() + ) + + cache.Set(ctx, "foo1", "bar1") + cache.Set(ctx, "foo2", "bar2") + cache.Set(ctx, "foo3", "bar3") + cache.Set(ctx, "foo4", "bar4") + + get, ok := cache.Get(ctx, "foo1") + assert.Nil(t, get) + assert.False(t, ok) + + get, ok = cache.Get(ctx, "foo2") + assert.Equal(t, "bar2", get) + assert.True(t, ok) + + cache.Set(ctx, "foo5", "bar5") + cache.Set(ctx, "foo6", "bar6") + + get, ok = cache.Get(ctx, "foo3") + assert.Nil(t, get) + assert.False(t, ok) + + get, ok = cache.Get(ctx, "foo4") + assert.Nil(t, get) + assert.False(t, ok) + + get, ok = cache.Get(ctx, "foo2") + assert.Equal(t, "bar2", get) + assert.True(t, ok) +} + +func Benchmark_Cache(b *testing.B) { + cache, err := New(time.Second*5, WithLimit(100000)) + if err != nil { + b.Fatal(err) + } + + ctx := context.Background() + + for i := 0; i < 10000; i++ { + for j := 0; j < 10; j++ { + index := strconv.Itoa(i*10000 + j) + cache.Set(ctx, "key:"+index, "value:"+index) + } + } + + time.Sleep(time.Second * 5) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + for i := 0; i < b.N; i++ { + index := strconv.Itoa(i % 10000) + cache.Get(ctx, "key:"+index) + if i%100 == 0 { + cache.Set(ctx, "key1:"+index, "value1:"+index) + } + } + } + }) +}