diff --git a/map.go b/map.go
index eee9fcb..a2730c3 100644
--- a/map.go
+++ b/map.go
@@ -294,6 +294,34 @@ func (m *Map) LoadAndStore(key string, value interface{}) (actual interface{}, l
 	)
 }
 
+// LoadOrTryCompute returns the existing value for the key if present.
+// Otherwise, it tries to compute the value using the provided function
+// and, if successful, returns the computed value. The loaded result is
+// true if the value was loaded, or false if it was stored. If the
+// compute attempt was cancelled, nil is returned.
+//
+// This call locks a hash table bucket while the compute function
+// is executed. This means that modifications to other entries in
+// the bucket will be blocked until valueFn returns. Consider this
+// when the function includes long-running operations.
+func (m *Map) LoadOrTryCompute(
+	key string,
+	valueFn func() (newValue interface{}, cancel bool),
+) (value interface{}, loaded bool) {
+	return m.doCompute(
+		key,
+		func(interface{}, bool) (interface{}, bool) {
+			nv, c := valueFn()
+			if !c {
+				return nv, false
+			}
+			return nil, true
+		},
+		true,
+		false,
+	)
+}
+
 // LoadOrCompute returns the existing value for the key if present.
 // Otherwise, it computes the value using the provided function and
 // returns the computed value. The loaded result is true if the value
diff --git a/map_test.go b/map_test.go
index 749f76d..8d102bc 100644
--- a/map_test.go
+++ b/map_test.go
@@ -353,7 +353,63 @@ func TestMapLoadOrCompute_FunctionCalledOnce(t *testing.T) {
 			return v
 		})
 	}
+	m.Range(func(k string, v interface{}) bool {
+		if vi, ok := v.(int); !ok || strconv.Itoa(vi) != k {
+			t.Fatalf("%sth key is not equal to value %d", k, v)
+		}
+		return true
+	})
+}
+
+func TestMapLoadOrTryCompute(t *testing.T) {
+	const numEntries = 1000
+	m := NewMap()
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
+			return i, true
+		})
+		if loaded {
+			t.Fatalf("value not computed for %d", i)
+		}
+		if v != nil {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+	if m.Size() != 0 {
+		t.Fatalf("zero map size expected: %d", m.Size())
+	}
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
+			return i, false
+		})
+		if loaded {
+			t.Fatalf("value not computed for %d", i)
+		}
+		if v != i {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
+			return i, false
+		})
+		if !loaded {
+			t.Fatalf("value not loaded for %d", i)
+		}
+		if v != i {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+}
+
+func TestMapLoadOrTryCompute_FunctionCalledOnce(t *testing.T) {
+	m := NewMap()
+	for i := 0; i < 100; {
+		m.LoadOrTryCompute(strconv.Itoa(i), func() (v interface{}, cancel bool) {
+			v, i = i, i+1
+			return v, false
+		})
+	}
 	m.Range(func(k string, v interface{}) bool {
 		if vi, ok := v.(int); !ok || strconv.Itoa(vi) != k {
 			t.Fatalf("%sth key is not equal to value %d", k, v)
diff --git a/mapof.go b/mapof.go
index 39a1aa5..cbf9cac 100644
--- a/mapof.go
+++ b/mapof.go
@@ -258,6 +258,34 @@ func (m *MapOf[K, V]) LoadOrCompute(key K, valueFn func() V) (actual V, loaded b
 	)
 }
 
+// LoadOrTryCompute returns the existing value for the key if present.
+// Otherwise, it tries to compute the value using the provided function
+// and, if successful, returns the computed value. The loaded result is
+// true if the value was loaded, or false if it was stored. If the
+// compute attempt was cancelled, the zero value of type V is returned.
+//
+// This call locks a hash table bucket while the compute function
+// is executed. This means that modifications to other entries in
+// the bucket will be blocked until valueFn returns. Consider this
+// when the function includes long-running operations.
+func (m *MapOf[K, V]) LoadOrTryCompute(
+	key K,
+	valueFn func() (newValue V, cancel bool),
+) (value V, loaded bool) {
+	return m.doCompute(
+		key,
+		func(V, bool) (V, bool) {
+			nv, c := valueFn()
+			if !c {
+				return nv, false
+			}
+			return nv, true // nv is ignored
+		},
+		true,
+		false,
+	)
+}
+
 // Compute either sets the computed new value for the key or deletes
 // the value for the key. When the delete result of the valueFn function
 // is set to true, the value will be deleted, if it exists. When delete
diff --git a/mapof_test.go b/mapof_test.go
index e707909..15b073a 100644
--- a/mapof_test.go
+++ b/mapof_test.go
@@ -373,6 +373,63 @@ func TestMapOfLoadOrCompute_FunctionCalledOnce(t *testing.T) {
 	})
 }
 
+func TestMapOfLoadOrTryCompute(t *testing.T) {
+	const numEntries = 1000
+	m := NewMapOf[string, int]()
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
+			return i, true
+		})
+		if loaded {
+			t.Fatalf("value not computed for %d", i)
+		}
+		if v != 0 {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+	if m.Size() != 0 {
+		t.Fatalf("zero map size expected: %d", m.Size())
+	}
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
+			return i, false
+		})
+		if loaded {
+			t.Fatalf("value not computed for %d", i)
+		}
+		if v != i {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+	for i := 0; i < numEntries; i++ {
+		v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
+			return i, false
+		})
+		if !loaded {
+			t.Fatalf("value not loaded for %d", i)
+		}
+		if v != i {
+			t.Fatalf("values do not match for %d: %v", i, v)
+		}
+	}
+}
+
+func TestMapOfLoadOrTryCompute_FunctionCalledOnce(t *testing.T) {
+	m := NewMapOf[int, int]()
+	for i := 0; i < 100; {
+		m.LoadOrTryCompute(i, func() (newValue int, cancel bool) {
+			newValue, i = i, i+1
+			return newValue, false
+		})
+	}
+	m.Range(func(k, v int) bool {
+		if k != v {
+			t.Fatalf("%dth key is not equal to value %d", k, v)
+		}
+		return true
+	})
+}
+
 func TestMapOfCompute(t *testing.T) {
 	m := NewMapOf[string, int]()
 	// Store a new value.
diff --git a/spscqueue.go b/spscqueue.go
index d370b22..6e4f84b 100644
--- a/spscqueue.go
+++ b/spscqueue.go
@@ -15,17 +15,17 @@ import (
 // Based on the data structure from the following article:
 // https://rigtorp.se/ringbuffer/
 type SPSCQueue struct {
-	cap   uint64
-	p_idx uint64
+	cap  uint64
+	pidx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad0         [cacheLineSize - 8]byte
-	p_cached_idx uint64
+	pad0       [cacheLineSize - 8]byte
+	pcachedIdx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad1  [cacheLineSize - 8]byte
-	c_idx uint64
+	pad1 [cacheLineSize - 8]byte
+	cidx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad2         [cacheLineSize - 8]byte
-	c_cached_idx uint64
+	pad2       [cacheLineSize - 8]byte
+	ccachedIdx uint64
 	//lint:ignore U1000 prevents false sharing
 	pad3  [cacheLineSize - 8]byte
 	items []interface{}
@@ -48,21 +48,21 @@ func NewSPSCQueue(capacity int) *SPSCQueue {
 // full and the item was inserted.
 func (q *SPSCQueue) TryEnqueue(item interface{}) bool {
 	// relaxed memory order would be enough here
-	idx := atomic.LoadUint64(&q.p_idx)
-	next_idx := idx + 1
-	if next_idx == q.cap {
-		next_idx = 0
+	idx := atomic.LoadUint64(&q.pidx)
+	nextIdx := idx + 1
+	if nextIdx == q.cap {
+		nextIdx = 0
 	}
-	cached_idx := q.c_cached_idx
-	if next_idx == cached_idx {
-		cached_idx = atomic.LoadUint64(&q.c_idx)
-		q.c_cached_idx = cached_idx
-		if next_idx == cached_idx {
+	cachedIdx := q.ccachedIdx
+	if nextIdx == cachedIdx {
+		cachedIdx = atomic.LoadUint64(&q.cidx)
+		q.ccachedIdx = cachedIdx
+		if nextIdx == cachedIdx {
 			return false
 		}
 	}
 	q.items[idx] = item
-	atomic.StoreUint64(&q.p_idx, next_idx)
+	atomic.StoreUint64(&q.pidx, nextIdx)
 	return true
 }
 
@@ -71,22 +71,22 @@ func (q *SPSCQueue) TryEnqueue(item interface{}) bool {
 // indicates that the queue isn't empty and an item was retrieved.
 func (q *SPSCQueue) TryDequeue() (item interface{}, ok bool) {
 	// relaxed memory order would be enough here
-	idx := atomic.LoadUint64(&q.c_idx)
-	cached_idx := q.p_cached_idx
-	if idx == cached_idx {
-		cached_idx = atomic.LoadUint64(&q.p_idx)
-		q.p_cached_idx = cached_idx
-		if idx == cached_idx {
+	idx := atomic.LoadUint64(&q.cidx)
+	cachedIdx := q.pcachedIdx
+	if idx == cachedIdx {
+		cachedIdx = atomic.LoadUint64(&q.pidx)
+		q.pcachedIdx = cachedIdx
+		if idx == cachedIdx {
 			return
 		}
 	}
 	item = q.items[idx]
 	q.items[idx] = nil
 	ok = true
-	next_idx := idx + 1
-	if next_idx == q.cap {
-		next_idx = 0
+	nextIdx := idx + 1
+	if nextIdx == q.cap {
+		nextIdx = 0
 	}
-	atomic.StoreUint64(&q.c_idx, next_idx)
+	atomic.StoreUint64(&q.cidx, nextIdx)
 	return
 }
diff --git a/spscqueueof.go b/spscqueueof.go
index cf3a13b..3ae132e 100644
--- a/spscqueueof.go
+++ b/spscqueueof.go
@@ -18,17 +18,17 @@ import (
 // Based on the data structure from the following article:
 // https://rigtorp.se/ringbuffer/
 type SPSCQueueOf[I any] struct {
-	cap   uint64
-	p_idx uint64
+	cap  uint64
+	pidx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad0         [cacheLineSize - 8]byte
-	p_cached_idx uint64
+	pad0       [cacheLineSize - 8]byte
+	pcachedIdx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad1  [cacheLineSize - 8]byte
-	c_idx uint64
+	pad1 [cacheLineSize - 8]byte
+	cidx uint64
 	//lint:ignore U1000 prevents false sharing
-	pad2         [cacheLineSize - 8]byte
-	c_cached_idx uint64
+	pad2       [cacheLineSize - 8]byte
+	ccachedIdx uint64
 	//lint:ignore U1000 prevents false sharing
 	pad3  [cacheLineSize - 8]byte
 	items []I
@@ -51,21 +51,21 @@ func NewSPSCQueueOf[I any](capacity int) *SPSCQueueOf[I] {
 // full and the item was inserted.
 func (q *SPSCQueueOf[I]) TryEnqueue(item I) bool {
 	// relaxed memory order would be enough here
-	idx := atomic.LoadUint64(&q.p_idx)
+	idx := atomic.LoadUint64(&q.pidx)
 	next_idx := idx + 1
 	if next_idx == q.cap {
 		next_idx = 0
 	}
-	cached_idx := q.c_cached_idx
+	cached_idx := q.ccachedIdx
 	if next_idx == cached_idx {
-		cached_idx = atomic.LoadUint64(&q.c_idx)
-		q.c_cached_idx = cached_idx
+		cached_idx = atomic.LoadUint64(&q.cidx)
+		q.ccachedIdx = cached_idx
 		if next_idx == cached_idx {
 			return false
 		}
 	}
 	q.items[idx] = item
-	atomic.StoreUint64(&q.p_idx, next_idx)
+	atomic.StoreUint64(&q.pidx, next_idx)
 	return true
 }
 
@@ -74,11 +74,11 @@ func (q *SPSCQueueOf[I]) TryEnqueue(item I) bool {
 // indicates that the queue isn't empty and an item was retrieved.
 func (q *SPSCQueueOf[I]) TryDequeue() (item I, ok bool) {
 	// relaxed memory order would be enough here
-	idx := atomic.LoadUint64(&q.c_idx)
-	cached_idx := q.p_cached_idx
+	idx := atomic.LoadUint64(&q.cidx)
+	cached_idx := q.pcachedIdx
 	if idx == cached_idx {
-		cached_idx = atomic.LoadUint64(&q.p_idx)
-		q.p_cached_idx = cached_idx
+		cached_idx = atomic.LoadUint64(&q.pidx)
+		q.pcachedIdx = cached_idx
 		if idx == cached_idx {
 			return
 		}
 	}
@@ -91,6 +91,6 @@ func (q *SPSCQueueOf[I]) TryDequeue() (item I, ok bool) {
 	if next_idx == q.cap {
 		next_idx = 0
 	}
-	atomic.StoreUint64(&q.c_idx, next_idx)
+	atomic.StoreUint64(&q.cidx, next_idx)
 	return
 }
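
Reviewer note, not part of the patch: a minimal sketch of how the new LoadOrTryCompute API reads at a call site. It assumes the github.com/puzpuzpuz/xsync/v3 import path; the map key and dial address are made-up examples. The point of the cancel flag is that a failed compute leaves the map untouched, so a later call retries instead of caching a broken value.

```go
package main

import (
	"fmt"
	"net"

	"github.com/puzpuzpuz/xsync/v3"
)

func main() {
	conns := xsync.NewMapOf[string, net.Conn]()
	// The callback dials lazily. Returning cancel == true backs out of
	// the compute: nothing is stored, the zero value (nil) is returned,
	// and the next lookup for this key will attempt the dial again.
	conn, loaded := conns.LoadOrTryCompute("db-primary", func() (net.Conn, bool) {
		c, err := net.Dial("tcp", "127.0.0.1:5432")
		if err != nil {
			return nil, true // cancel the compute attempt
		}
		return c, false // store the freshly dialed connection
	})
	fmt.Println(conn, loaded)
}
```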
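Also not part of the patch: a sketch of the intended single-producer/single-consumer usage of SPSCQueue after the field renames, under the same import-path assumption. TryEnqueue and TryDequeue are each safe from exactly one goroutine; the Gosched-based spinning below is just one way to handle a full or empty queue.

```go
package main

import (
	"fmt"
	"runtime"

	"github.com/puzpuzpuz/xsync/v3"
)

func main() {
	q := xsync.NewSPSCQueue(128)
	done := make(chan struct{})

	// Single producer: only this goroutine calls TryEnqueue.
	go func() {
		for i := 0; i < 1000; i++ {
			for !q.TryEnqueue(i) {
				runtime.Gosched() // queue full, let the consumer catch up
			}
		}
		close(done)
	}()

	// Single consumer: only this goroutine calls TryDequeue.
	sum := 0
	for received := 0; received < 1000; {
		if v, ok := q.TryDequeue(); ok {
			sum += v.(int)
			received++
		} else {
			runtime.Gosched() // queue empty, wait for the producer
		}
	}
	<-done
	fmt.Println("sum:", sum)
}
```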