diff --git a/changelog/18.0/18.0.0/summary.md b/changelog/18.0/18.0.0/summary.md index bc3ca5e6004..96f3cdf8541 100644 --- a/changelog/18.0/18.0.0/summary.md +++ b/changelog/18.0/18.0.0/summary.md @@ -83,10 +83,20 @@ Throttler related `vttablet` flags: - `--throttle_check_as_check_self` is deprecated and will be removed in `v19.0` - `--throttler-config-via-topo` is deprecated after assumed `true` in `v17.0`. It will be removed in a future version. +Cache related `vttablet` flags: + +- `--queryserver-config-query-cache-lfu` is deprecated and will be removed in `v19.0`. The query cache always uses an LFU implementation now. +- `--queryserver-config-query-cache-size` is deprecated and will be removed in `v19.0`. This option only applied to LRU caches, which are now unsupported. + Buffering related `vtgate` flags: - `--buffer_implementation` is deprecated and will be removed in `v19.0` +Cache related `vtgate` flags: + +- `--gate_query_cache_lfu` is deprecated and will be removed in `v19.0`. The query cache always uses an LFU implementation now. +- `--gate_query_cache_size` is deprecated and will be removed in `v19.0`. This option only applied to LRU caches, which are now unsupported. + VTGate flag: - `--schema_change_signal_user` is deprecated and will be removed in `v19.0` diff --git a/go.mod b/go.mod index 9b2692ed3b7..639a22edc6b 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/Azure/azure-storage-blob-go v0.15.0 github.com/DataDog/datadog-go v4.8.3+incompatible github.com/HdrHistogram/hdrhistogram-go v0.9.0 // indirect - github.com/PuerkitoBio/goquery v1.5.1 github.com/aquarapid/vaultlib v0.5.1 github.com/armon/go-metrics v0.4.1 // indirect github.com/aws/aws-sdk-go v1.44.258 @@ -76,7 +75,7 @@ require ( golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.14.0 golang.org/x/oauth2 v0.7.0 - golang.org/x/sys v0.11.0 // indirect + golang.org/x/sys v0.11.0 golang.org/x/term v0.11.0 golang.org/x/text v0.12.0 golang.org/x/time v0.3.0 @@ -97,6 +96,7 @@ require ( github.com/Shopify/toxiproxy/v2 v2.5.0 github.com/bndr/gotabulate v1.1.2 + github.com/gammazero/deque v0.2.1 github.com/google/safehtml v0.1.0 github.com/hashicorp/go-version v1.6.0 github.com/kr/pretty v0.3.1 @@ -124,7 +124,6 @@ require ( github.com/DataDog/go-tuf v0.3.0--fix-localmeta-fork // indirect github.com/DataDog/sketches-go v1.4.1 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect - github.com/andybalholm/cascadia v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect diff --git a/go.sum b/go.sum index b5d7eb888c7..ffadd6498b9 100644 --- a/go.sum +++ b/go.sum @@ -96,16 +96,12 @@ github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpz github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= -github.com/PuerkitoBio/goquery v1.5.1 h1:PSPBGne8NIUWw+/7vFBV+kG2J/5MOjbzc7154OaKCSE= -github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/Shopify/toxiproxy/v2 v2.5.0
h1:i4LPT+qrSlKNtQf5QliVjdP08GyAH8+BUIc9gT0eahc= github.com/Shopify/toxiproxy/v2 v2.5.0/go.mod h1:yhM2epWtAmel9CB8r2+L+PCmhH6yH2pITaPAo7jxJl0= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/andybalholm/cascadia v1.1.0 h1:BuuO6sSfQNFRu1LppgbD25Hr2vLYW25JvxHs5zzsLTo= -github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/aquarapid/vaultlib v0.5.1 h1:vuLWR6bZzLHybjJBSUYPgZlIp6KZ+SXeHLRRYTuk6d4= github.com/aquarapid/vaultlib v0.5.1/go.mod h1:yT7AlEXtuabkxylOc/+Ulyp18tff1+QjgNLTnFWTlOs= @@ -195,6 +191,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -707,7 +705,6 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/go/cache/cache.go b/go/cache/cache.go deleted file mode 100644 index a801d075fde..00000000000 --- a/go/cache/cache.go +++ /dev/null @@ -1,91 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cache - -// Cache is a generic interface type for a data structure that keeps recently used -// objects in memory and evicts them when it becomes full. -type Cache interface { - Get(key string) (any, bool) - Set(key string, val any) bool - ForEach(callback func(any) bool) - - Delete(key string) - Clear() - - // Wait waits for all pending operations on the cache to settle. Since cache writes - // are asynchronous, a write may not be immediately accessible unless the user - // manually calls Wait. - Wait() - - Len() int - Evictions() int64 - Hits() int64 - Misses() int64 - UsedCapacity() int64 - MaxCapacity() int64 - SetCapacity(int64) - - // Close shuts down this cache and stops any background goroutines. - Close() -} - -type cachedObject interface { - CachedSize(alloc bool) int64 -} - -// NewDefaultCacheImpl returns the default cache implementation for Vitess. The options in the -// Config struct control the memory and entry limits for the cache, and the underlying cache -// implementation. -func NewDefaultCacheImpl(cfg *Config) Cache { - switch { - case cfg == nil: - return &nullCache{} - - case cfg.LFU: - if cfg.MaxEntries == 0 || cfg.MaxMemoryUsage == 0 { - return &nullCache{} - } - return NewRistrettoCache(cfg.MaxEntries, cfg.MaxMemoryUsage, func(val any) int64 { - return val.(cachedObject).CachedSize(true) - }) - - default: - if cfg.MaxEntries == 0 { - return &nullCache{} - } - return NewLRUCache(cfg.MaxEntries, func(_ any) int64 { - return 1 - }) - } -} - -// Config is the configuration options for a cache instance -type Config struct { - // MaxEntries is the estimated amount of entries that the cache will hold at capacity - MaxEntries int64 - // MaxMemoryUsage is the maximum amount of memory the cache can handle - MaxMemoryUsage int64 - // LFU toggles whether to use a new cache implementation with a TinyLFU admission policy - LFU bool -} - -// DefaultConfig is the default configuration for a cache instance in Vitess -var DefaultConfig = &Config{ - MaxEntries: 5000, - MaxMemoryUsage: 32 * 1024 * 1024, - LFU: true, -} diff --git a/go/cache/cache_test.go b/go/cache/cache_test.go deleted file mode 100644 index 911a3bb207b..00000000000 --- a/go/cache/cache_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package cache - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/cache/ristretto" -) - -func TestNewDefaultCacheImpl(t *testing.T) { - assertNullCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*nullCache) - require.True(t, ok) - } - - assertLFUCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*ristretto.Cache) - require.True(t, ok) - } - - assertLRUCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*LRUCache) - require.True(t, ok) - } - - tests := []struct { - cfg *Config - verify func(t *testing.T, cache Cache) - }{ - {&Config{MaxEntries: 0, MaxMemoryUsage: 0, LFU: false}, assertNullCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 0, LFU: true}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 0, LFU: false}, assertLRUCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 1000, LFU: false}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 1000, LFU: false}, assertLRUCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 0, LFU: true}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 1000, LFU: true}, assertLFUCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 
1000, LFU: true}, assertNullCache}, - } - for _, tt := range tests { - t.Run(fmt.Sprintf("%d.%d.%v", tt.cfg.MaxEntries, tt.cfg.MaxMemoryUsage, tt.cfg.LFU), func(t *testing.T) { - cache := NewDefaultCacheImpl(tt.cfg) - tt.verify(t, cache) - }) - } -} diff --git a/go/cache/lru_cache.go b/go/cache/lru_cache.go index 8cc89ac55a4..d845265b77b 100644 --- a/go/cache/lru_cache.go +++ b/go/cache/lru_cache.go @@ -29,8 +29,6 @@ import ( "time" ) -var _ Cache = &LRUCache{} - // LRUCache is a typical LRU cache implementation. If the cache // reaches the capacity, the least recently used item is deleted from // the cache. Note the capacity is not the number of items, but the diff --git a/go/cache/null.go b/go/cache/null.go deleted file mode 100644 index 2e1eeeb0d2d..00000000000 --- a/go/cache/null.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cache - -// nullCache is a no-op cache that does not store items -type nullCache struct{} - -// Get never returns anything on the nullCache -func (n *nullCache) Get(_ string) (any, bool) { - return nil, false -} - -// Set is a no-op in the nullCache -func (n *nullCache) Set(_ string, _ any) bool { - return false -} - -// ForEach iterates the nullCache, which is always empty -func (n *nullCache) ForEach(_ func(any) bool) {} - -// Delete is a no-op in the nullCache -func (n *nullCache) Delete(_ string) {} - -// Clear is a no-op in the nullCache -func (n *nullCache) Clear() {} - -// Wait is a no-op in the nullcache -func (n *nullCache) Wait() {} - -func (n *nullCache) Len() int { - return 0 -} - -// Hits returns number of cache hits since creation -func (n *nullCache) Hits() int64 { - return 0 -} - -// Hits returns number of cache misses since creation -func (n *nullCache) Misses() int64 { - return 0 -} - -// Capacity returns the capacity of the nullCache, which is always 0 -func (n *nullCache) UsedCapacity() int64 { - return 0 -} - -// Capacity returns the capacity of the nullCache, which is always 0 -func (n *nullCache) MaxCapacity() int64 { - return 0 -} - -// SetCapacity sets the capacity of the null cache, which is a no-op -func (n *nullCache) SetCapacity(_ int64) {} - -func (n *nullCache) Evictions() int64 { - return 0 -} - -func (n *nullCache) Close() {} diff --git a/go/cache/ristretto.go b/go/cache/ristretto.go deleted file mode 100644 index 6d6f596a5b9..00000000000 --- a/go/cache/ristretto.go +++ /dev/null @@ -1,28 +0,0 @@ -package cache - -import ( - "vitess.io/vitess/go/cache/ristretto" -) - -var _ Cache = &ristretto.Cache{} - -// NewRistrettoCache returns a Cache implementation based on Ristretto -func NewRistrettoCache(maxEntries, maxCost int64, cost func(any) int64) *ristretto.Cache { - // The TinyLFU paper recommends to allocate 10x times the max entries amount as counters - // for the admission policy; since our caches are small and we're very interested on admission - // accuracy, we're a bit more greedy than 10x - const CounterRatio = 12 - - config := ristretto.Config{ - NumCounters: 
maxEntries * CounterRatio, - MaxCost: maxCost, - BufferItems: 64, - Metrics: true, - Cost: cost, - } - cache, err := ristretto.NewCache(&config) - if err != nil { - panic(err) - } - return cache -} diff --git a/go/cache/ristretto/bloom/bbloom.go b/go/cache/ristretto/bloom/bbloom.go deleted file mode 100644 index 9d6b1080a2e..00000000000 --- a/go/cache/ristretto/bloom/bbloom.go +++ /dev/null @@ -1,149 +0,0 @@ -// The MIT License (MIT) -// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt - -// Permission is hereby granted, free of charge, to any person obtaining a copy of -// this software and associated documentation files (the "Software"), to deal in -// the Software without restriction, including without limitation the rights to -// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software is furnished to do so, -// subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package bloom - -import ( - "math" - "unsafe" -) - -// helper -var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128} - -func getSize(ui64 uint64) (size uint64, exponent uint64) { - if ui64 < uint64(512) { - ui64 = uint64(512) - } - size = uint64(1) - for size < ui64 { - size <<= 1 - exponent++ - } - return size, exponent -} - -// NewBloomFilterWithErrorRate returns a new bloomfilter with optimal size for the given -// error rate -func NewBloomFilterWithErrorRate(numEntries uint64, wrongs float64) *Bloom { - size := -1 * float64(numEntries) * math.Log(wrongs) / math.Pow(0.69314718056, 2) - locs := math.Ceil(0.69314718056 * size / float64(numEntries)) - return NewBloomFilter(uint64(size), uint64(locs)) -} - -// NewBloomFilter returns a new bloomfilter. -func NewBloomFilter(entries, locs uint64) (bloomfilter *Bloom) { - size, exponent := getSize(entries) - bloomfilter = &Bloom{ - sizeExp: exponent, - size: size - 1, - setLocs: locs, - shift: 64 - exponent, - } - bloomfilter.Size(size) - return bloomfilter -} - -// Bloom filter -type Bloom struct { - bitset []uint64 - ElemNum uint64 - sizeExp uint64 - size uint64 - setLocs uint64 - shift uint64 -} - -// <--- http://www.cse.yorku.ca/~oz/hash.html -// modified Berkeley DB Hash (32bit) -// hash is casted to l, h = 16bit fragments -// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) { -// hash := uint64(len(*b)) -// for _, c := range *b { -// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash -// } -// h = hash >> bl.shift -// l = hash << bl.shift >> bl.shift -// return l, h -// } - -// Add adds hash of a key to the bloomfilter. -func (bl *Bloom) Add(hash uint64) { - h := hash >> bl.shift - l := hash << bl.shift >> bl.shift - for i := uint64(0); i < bl.setLocs; i++ { - bl.Set((h + i*l) & bl.size) - bl.ElemNum++ - } -} - -// Has checks if bit(s) for entry hash is/are set, -// returns true if the hash was added to the Bloom Filter. 
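For context on the sizing math in `NewBloomFilterWithErrorRate` above: it is the standard Bloom filter formula, where for `n` entries at target false-positive rate `p` the optimal bit count is `m = -n*ln(p)/ln(2)^2` and the optimal number of hash locations is `k = ceil(ln(2)*m/n)`; `getSize` then rounds the bit count up to a power of two (minimum 512) so the index modulo can be done with a mask. A minimal standalone sketch of the same computation (an illustrative helper, not part of the removed package):

```go
package main

import (
	"fmt"
	"math"
)

// optimalBloomParams mirrors the sizing arithmetic of the removed
// NewBloomFilterWithErrorRate: m bits and k hash locations for n
// entries at false-positive rate p. Illustrative helper only.
func optimalBloomParams(n uint64, p float64) (bits, locs uint64) {
	m := -1 * float64(n) * math.Log(p) / (math.Ln2 * math.Ln2)
	k := math.Ceil(math.Ln2 * m / float64(n))
	return uint64(m), uint64(k)
}

func main() {
	// One million entries at a 1% false-positive rate needs roughly
	// 9.6M bits (~1.2 MB) and 7 hash locations.
	bits, locs := optimalBloomParams(1_000_000, 0.01)
	fmt.Printf("bits=%d (~%.1f MB), locations=%d\n", bits, float64(bits)/8e6, locs)
}
```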
-func (bl Bloom) Has(hash uint64) bool { - h := hash >> bl.shift - l := hash << bl.shift >> bl.shift - for i := uint64(0); i < bl.setLocs; i++ { - if !bl.IsSet((h + i*l) & bl.size) { - return false - } - } - return true -} - -// AddIfNotHas only Adds hash, if it's not present in the bloomfilter. -// Returns true if hash was added. -// Returns false if hash was already registered in the bloomfilter. -func (bl *Bloom) AddIfNotHas(hash uint64) bool { - if bl.Has(hash) { - return false - } - bl.Add(hash) - return true -} - -// TotalSize returns the total size of the bloom filter. -func (bl *Bloom) TotalSize() int { - // The bl struct has 5 members and each one is 8 byte. The bitset is a - // uint64 byte slice. - return len(bl.bitset)*8 + 5*8 -} - -// Size makes Bloom filter with as bitset of size sz. -func (bl *Bloom) Size(sz uint64) { - bl.bitset = make([]uint64, sz>>6) -} - -// Clear resets the Bloom filter. -func (bl *Bloom) Clear() { - clear(bl.bitset) -} - -// Set sets the bit[idx] of bitset. -func (bl *Bloom) Set(idx uint64) { - ptr := unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)) - *(*uint8)(ptr) |= mask[idx%8] -} - -// IsSet checks if bit[idx] of bitset is set, returns true/false. -func (bl *Bloom) IsSet(idx uint64) bool { - ptr := unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)) - r := ((*(*uint8)(ptr)) >> (idx % 8)) & 1 - return r == 1 -} diff --git a/go/cache/ristretto/bloom/bbloom_test.go b/go/cache/ristretto/bloom/bbloom_test.go deleted file mode 100644 index 7d280988bae..00000000000 --- a/go/cache/ristretto/bloom/bbloom_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package bloom - -import ( - "crypto/rand" - "os" - "testing" - - _flag "vitess.io/vitess/go/internal/flag" - "vitess.io/vitess/go/vt/log" - - "vitess.io/vitess/go/hack" -) - -var ( - wordlist1 [][]byte - n = uint64(1 << 16) - bf *Bloom -) - -func TestMain(m *testing.M) { - // hack to get rid of an "ERROR: logging before flag.Parse" - _flag.TrickGlog() - wordlist1 = make([][]byte, n) - for i := range wordlist1 { - b := make([]byte, 32) - _, _ = rand.Read(b) - wordlist1[i] = b - } - log.Info("Benchmarks relate to 2**16 OP. 
--> output/65536 op/ns") - - os.Exit(m.Run()) -} - -func TestM_NumberOfWrongs(t *testing.T) { - bf = NewBloomFilter(n*10, 7) - - cnt := 0 - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - if !bf.AddIfNotHas(hash) { - cnt++ - } - } - log.Infof("Bloomfilter New(7* 2**16, 7) (-> size=%v bit): \n Check for 'false positives': %v wrong positive 'Has' results on 2**16 entries => %v %%", len(bf.bitset)<<6, cnt, float64(cnt)/float64(n)) - -} - -func BenchmarkM_New(b *testing.B) { - for r := 0; r < b.N; r++ { - _ = NewBloomFilter(n*10, 7) - } -} - -func BenchmarkM_Clear(b *testing.B) { - bf = NewBloomFilter(n*10, 7) - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Add(hash) - } - b.ResetTimer() - for r := 0; r < b.N; r++ { - bf.Clear() - } -} - -func BenchmarkM_Add(b *testing.B) { - bf = NewBloomFilter(n*10, 7) - b.ResetTimer() - for r := 0; r < b.N; r++ { - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Add(hash) - } - } - -} - -func BenchmarkM_Has(b *testing.B) { - b.ResetTimer() - for r := 0; r < b.N; r++ { - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Has(hash) - } - } -} diff --git a/go/cache/ristretto/cache.go b/go/cache/ristretto/cache.go deleted file mode 100644 index aa6aa2c2870..00000000000 --- a/go/cache/ristretto/cache.go +++ /dev/null @@ -1,701 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package ristretto is a fast, fixed size, in-memory cache with a dual focus on -// throughput and hit ratio performance. You can easily add Ristretto to an -// existing system and keep the most valuable data where you need it. -package ristretto - -import ( - "bytes" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" - "unsafe" - - "vitess.io/vitess/go/hack" -) - -var ( - // TODO: find the optimal value for this or make it configurable - setBufSize = 32 * 1024 -) - -func defaultStringHash(key string) (uint64, uint64) { - const Seed1 = uint64(0x1122334455667788) - const Seed2 = uint64(0x8877665544332211) - return hack.RuntimeStrhash(key, Seed1), hack.RuntimeStrhash(key, Seed2) -} - -type itemCallback func(*Item) - -// CacheItemSize is the overhead in bytes for every stored cache item -var CacheItemSize = hack.RuntimeAllocSize(int64(unsafe.Sizeof(storeItem{}))) - -// Cache is a thread-safe implementation of a hashmap with a TinyLFU admission -// policy and a Sampled LFU eviction policy. You can use the same Cache instance -// from as many goroutines as you want. -type Cache struct { - // store is the central concurrent hashmap where key-value items are stored. - store store - // policy determines what gets let in to the cache and what gets kicked out. - policy policy - // getBuf is a custom ring buffer implementation that gets pushed to when - // keys are read. 
- getBuf *ringBuffer - // setBuf is a buffer allowing us to batch/drop Sets during times of high - // contention. - setBuf chan *Item - // onEvict is called for item evictions. - onEvict itemCallback - // onReject is called when an item is rejected via admission policy. - onReject itemCallback - // onExit is called whenever a value goes out of scope from the cache. - onExit func(any) - // KeyToHash function is used to customize the key hashing algorithm. - // Each key will be hashed using the provided function. If keyToHash value - // is not set, the default keyToHash function is used. - keyToHash func(string) (uint64, uint64) - // stop is used to stop the processItems goroutine. - stop chan struct{} - // indicates whether cache is closed. - isClosed atomic.Bool - // cost calculates cost from a value. - cost func(value any) int64 - // ignoreInternalCost dictates whether to ignore the cost of internally storing - // the item in the cost calculation. - ignoreInternalCost bool - // Metrics contains a running log of important statistics like hits, misses, - // and dropped items. - Metrics *Metrics -} - -// Config is passed to NewCache for creating new Cache instances. -type Config struct { - // NumCounters determines the number of counters (keys) to keep that hold - // access frequency information. It's generally a good idea to have more - // counters than the max cache capacity, as this will improve eviction - // accuracy and subsequent hit ratios. - // - // For example, if you expect your cache to hold 1,000,000 items when full, - // NumCounters should be 10,000,000 (10x). Each counter takes up 4 bits, so - // keeping 10,000,000 counters would require 5MB of memory. - NumCounters int64 - // MaxCost can be considered as the cache capacity, in whatever units you - // choose to use. - // - // For example, if you want the cache to have a max capacity of 100MB, you - // would set MaxCost to 100,000,000 and pass an item's number of bytes as - // the `cost` parameter for calls to Set. If new items are accepted, the - // eviction process will take care of making room for the new item and not - // overflowing the MaxCost value. - MaxCost int64 - // BufferItems determines the size of Get buffers. - // - // Unless you have a rare use case, using `64` as the BufferItems value - // results in good performance. - BufferItems int64 - // Metrics determines whether cache statistics are kept during the cache's - // lifetime. There *is* some overhead to keeping statistics, so you should - // only set this flag to true when testing or throughput performance isn't a - // major factor. - Metrics bool - // OnEvict is called for every eviction and passes the hashed key, value, - // and cost to the function. - OnEvict func(item *Item) - // OnReject is called for every rejection done via the policy. - OnReject func(item *Item) - // OnExit is called whenever a value is removed from cache. This can be - // used to do manual memory deallocation. Would also be called on eviction - // and rejection of the value. - OnExit func(val any) - // KeyToHash function is used to customize the key hashing algorithm. - // Each key will be hashed using the provided function. If keyToHash value - // is not set, the default keyToHash function is used. - KeyToHash func(string) (uint64, uint64) - // Cost evaluates a value and outputs a corresponding cost. This function - // is ran after Set is called for a new item or an item update with a cost - // param of 0. 
- Cost func(value any) int64 - // IgnoreInternalCost set to true indicates to the cache that the cost of - // internally storing the value should be ignored. This is useful when the - // cost passed to set is not using bytes as units. Keep in mind that setting - // this to true will increase the memory usage. - IgnoreInternalCost bool -} - -type itemFlag byte - -const ( - itemNew itemFlag = iota - itemDelete - itemUpdate -) - -// Item is passed to setBuf so items can eventually be added to the cache. -type Item struct { - flag itemFlag - Key uint64 - Conflict uint64 - Value any - Cost int64 - wg *sync.WaitGroup -} - -// NewCache returns a new Cache instance and any configuration errors, if any. -func NewCache(config *Config) (*Cache, error) { - switch { - case config.NumCounters == 0: - return nil, errors.New("NumCounters can't be zero") - case config.MaxCost == 0: - return nil, errors.New("Capacity can't be zero") - case config.BufferItems == 0: - return nil, errors.New("BufferItems can't be zero") - } - policy := newPolicy(config.NumCounters, config.MaxCost) - cache := &Cache{ - store: newStore(), - policy: policy, - getBuf: newRingBuffer(policy, config.BufferItems), - setBuf: make(chan *Item, setBufSize), - keyToHash: config.KeyToHash, - stop: make(chan struct{}), - cost: config.Cost, - ignoreInternalCost: config.IgnoreInternalCost, - } - cache.onExit = func(val any) { - if config.OnExit != nil && val != nil { - config.OnExit(val) - } - } - cache.onEvict = func(item *Item) { - if config.OnEvict != nil { - config.OnEvict(item) - } - cache.onExit(item.Value) - } - cache.onReject = func(item *Item) { - if config.OnReject != nil { - config.OnReject(item) - } - cache.onExit(item.Value) - } - if cache.keyToHash == nil { - cache.keyToHash = defaultStringHash - } - if config.Metrics { - cache.collectMetrics() - } - // NOTE: benchmarks seem to show that performance decreases the more - // goroutines we have running cache.processItems(), so 1 should - // usually be sufficient - go cache.processItems() - return cache, nil -} - -// Wait blocks until all the current cache operations have been processed in the background -func (c *Cache) Wait() { - if c == nil || c.isClosed.Load() { - return - } - wg := &sync.WaitGroup{} - wg.Add(1) - c.setBuf <- &Item{wg: wg} - wg.Wait() -} - -// Get returns the value (if any) and a boolean representing whether the -// value was found or not. The value can be nil and the boolean can be true at -// the same time. -func (c *Cache) Get(key string) (any, bool) { - if c == nil || c.isClosed.Load() { - return nil, false - } - keyHash, conflictHash := c.keyToHash(key) - c.getBuf.Push(keyHash) - value, ok := c.store.Get(keyHash, conflictHash) - if ok { - c.Metrics.add(hit, keyHash, 1) - } else { - c.Metrics.add(miss, keyHash, 1) - } - return value, ok -} - -// Set attempts to add the key-value item to the cache. If it returns false, -// then the Set was dropped and the key-value item isn't added to the cache. If -// it returns true, there's still a chance it could be dropped by the policy if -// its determined that the key-value item isn't worth keeping, but otherwise the -// item will be added and other items will be evicted in order to make room. -// -// The cost of the entry will be evaluated lazily by the cache's Cost function. -func (c *Cache) Set(key string, value any) bool { - return c.SetWithCost(key, value, 0) -} - -// SetWithCost works like Set but adds a key-value pair to the cache with a specific -// cost. 
The built-in Cost function will not be called to evaluate the object's cost -// and instead the given value will be used. -func (c *Cache) SetWithCost(key string, value any, cost int64) bool { - if c == nil || c.isClosed.Load() { - return false - } - - keyHash, conflictHash := c.keyToHash(key) - i := &Item{ - flag: itemNew, - Key: keyHash, - Conflict: conflictHash, - Value: value, - Cost: cost, - } - // cost is eventually updated. The expiration must also be immediately updated - // to prevent items from being prematurely removed from the map. - if prev, ok := c.store.Update(i); ok { - c.onExit(prev) - i.flag = itemUpdate - } - // Attempt to send item to policy. - select { - case c.setBuf <- i: - return true - default: - if i.flag == itemUpdate { - // Return true if this was an update operation since we've already - // updated the store. For all the other operations (set/delete), we - // return false which means the item was not inserted. - return true - } - c.Metrics.add(dropSets, keyHash, 1) - return false - } -} - -// Delete deletes the key-value item from the cache if it exists. -func (c *Cache) Delete(key string) { - if c == nil || c.isClosed.Load() { - return - } - keyHash, conflictHash := c.keyToHash(key) - // Delete immediately. - _, prev := c.store.Del(keyHash, conflictHash) - c.onExit(prev) - // If we've set an item, it would be applied slightly later. - // So we must push the same item to `setBuf` with the deletion flag. - // This ensures that if a set is followed by a delete, it will be - // applied in the correct order. - c.setBuf <- &Item{ - flag: itemDelete, - Key: keyHash, - Conflict: conflictHash, - } -} - -// Close stops all goroutines and closes all channels. -func (c *Cache) Close() { - if c == nil { - return - } - wasClosed := c.isClosed.Swap(true) - if wasClosed { - return - } - c.Clear() - - // Block until processItems goroutine is returned. - c.stop <- struct{}{} - close(c.stop) - close(c.setBuf) - c.policy.Close() - c.isClosed.Store(true) -} - -// Clear empties the hashmap and zeroes all policy counters. Note that this is -// not an atomic operation (but that shouldn't be a problem as it's assumed that -// Set/Get calls won't be occurring until after this). -func (c *Cache) Clear() { - if c == nil || c.isClosed.Load() { - return - } - // Block until processItems goroutine is returned. - c.stop <- struct{}{} - - // Clear out the setBuf channel. -loop: - for { - select { - case i := <-c.setBuf: - if i.wg != nil { - i.wg.Done() - continue - } - if i.flag != itemUpdate { - // In itemUpdate, the value is already set in the store. So, no need to call - // onEvict here. - c.onEvict(i) - } - default: - break loop - } - } - - // Clear value hashmap and policy data. - c.policy.Clear() - c.store.Clear(c.onEvict) - // Only reset metrics if they're enabled. - if c.Metrics != nil { - c.Metrics.Clear() - } - // Restart processItems goroutine. - go c.processItems() -} - -// Len returns the size of the cache (in entries) -func (c *Cache) Len() int { - if c == nil { - return 0 - } - return c.store.Len() -} - -// UsedCapacity returns the size of the cache (in bytes) -func (c *Cache) UsedCapacity() int64 { - if c == nil { - return 0 - } - return c.policy.Used() -} - -// MaxCapacity returns the max cost of the cache (in bytes) -func (c *Cache) MaxCapacity() int64 { - if c == nil { - return 0 - } - return c.policy.MaxCost() -} - -// SetCapacity updates the maxCost of an existing cache. 
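One point worth flagging before the remaining accessors: writes in this API are asynchronous. `Set`/`SetWithCost` enqueue through `setBuf` and can be dropped outright under contention, so a `Get` issued immediately after a `Set` is not guaranteed to observe the value until `Wait` has drained the buffer. A usage sketch against the removed API exactly as it appears in this diff (the key, value, and config numbers are arbitrary examples):

```go
package main

import (
	"fmt"

	"vitess.io/vitess/go/cache/ristretto" // the package this diff removes
)

func main() {
	cache, err := ristretto.NewCache(&ristretto.Config{
		NumCounters: 1e4, // ~10x the expected entry count, per the Config docs
		MaxCost:     1 << 20,
		BufferItems: 64,
		Metrics:     true,
	})
	if err != nil {
		panic(err)
	}
	defer cache.Close()

	if cache.SetWithCost("plan:select-1", "cached-plan", 128) {
		cache.Wait() // writes are buffered; settle before reading back
	}
	if v, ok := cache.Get("plan:select-1"); ok {
		fmt.Println(v)
	}
}
```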
-func (c *Cache) SetCapacity(maxCost int64) { - if c == nil { - return - } - c.policy.UpdateMaxCost(maxCost) -} - -// Evictions returns the number of evictions -func (c *Cache) Evictions() int64 { - // TODO - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.KeysEvicted()) -} - -// Hits returns the number of cache hits -func (c *Cache) Hits() int64 { - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.Hits()) -} - -// Misses returns the number of cache misses -func (c *Cache) Misses() int64 { - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.Misses()) -} - -// ForEach yields all the values currently stored in the cache to the given callback. -// The callback may return `false` to stop the iteration early. -func (c *Cache) ForEach(forEach func(any) bool) { - if c == nil { - return - } - c.store.ForEach(forEach) -} - -// processItems is ran by goroutines processing the Set buffer. -func (c *Cache) processItems() { - startTs := make(map[uint64]time.Time) - numToKeep := 100000 // TODO: Make this configurable via options. - - trackAdmission := func(key uint64) { - if c.Metrics == nil { - return - } - startTs[key] = time.Now() - if len(startTs) > numToKeep { - for k := range startTs { - if len(startTs) <= numToKeep { - break - } - delete(startTs, k) - } - } - } - onEvict := func(i *Item) { - delete(startTs, i.Key) - if c.onEvict != nil { - c.onEvict(i) - } - } - - for { - select { - case i := <-c.setBuf: - if i.wg != nil { - i.wg.Done() - continue - } - // Calculate item cost value if new or update. - if i.Cost == 0 && c.cost != nil && i.flag != itemDelete { - i.Cost = c.cost(i.Value) - } - if !c.ignoreInternalCost { - // Add the cost of internally storing the object. - i.Cost += CacheItemSize - } - - switch i.flag { - case itemNew: - victims, added := c.policy.Add(i.Key, i.Cost) - if added { - c.store.Set(i) - c.Metrics.add(keyAdd, i.Key, 1) - trackAdmission(i.Key) - } else { - c.onReject(i) - } - for _, victim := range victims { - victim.Conflict, victim.Value = c.store.Del(victim.Key, 0) - onEvict(victim) - } - - case itemUpdate: - c.policy.Update(i.Key, i.Cost) - - case itemDelete: - c.policy.Del(i.Key) // Deals with metrics updates. - _, val := c.store.Del(i.Key, i.Conflict) - c.onExit(val) - } - case <-c.stop: - return - } - } -} - -// collectMetrics just creates a new *Metrics instance and adds the pointers -// to the cache and policy instances. -func (c *Cache) collectMetrics() { - c.Metrics = newMetrics() - c.policy.CollectMetrics(c.Metrics) -} - -type metricType int - -const ( - // The following 2 keep track of hits and misses. - hit = iota - miss - // The following 3 keep track of number of keys added, updated and evicted. - keyAdd - keyUpdate - keyEvict - // The following 2 keep track of cost of keys added and evicted. - costAdd - costEvict - // The following keep track of how many sets were dropped or rejected later. - dropSets - rejectSets - // The following 2 keep track of how many gets were kept and dropped on the - // floor. - dropGets - keepGets - // This should be the final enum. Other enums should be set before this. 
- doNotUse -) - -func stringFor(t metricType) string { - switch t { - case hit: - return "hit" - case miss: - return "miss" - case keyAdd: - return "keys-added" - case keyUpdate: - return "keys-updated" - case keyEvict: - return "keys-evicted" - case costAdd: - return "cost-added" - case costEvict: - return "cost-evicted" - case dropSets: - return "sets-dropped" - case rejectSets: - return "sets-rejected" // by policy. - case dropGets: - return "gets-dropped" - case keepGets: - return "gets-kept" - default: - return "unidentified" - } -} - -// Metrics is a snapshot of performance statistics for the lifetime of a cache instance. -type Metrics struct { - all [doNotUse][]*uint64 -} - -func newMetrics() *Metrics { - s := &Metrics{} - for i := 0; i < doNotUse; i++ { - s.all[i] = make([]*uint64, 256) - slice := s.all[i] - for j := range slice { - slice[j] = new(uint64) - } - } - return s -} - -func (p *Metrics) add(t metricType, hash, delta uint64) { - if p == nil { - return - } - valp := p.all[t] - // Avoid false sharing by padding at least 64 bytes of space between two - // atomic counters which would be incremented. - idx := (hash % 25) * 10 - atomic.AddUint64(valp[idx], delta) -} - -func (p *Metrics) get(t metricType) uint64 { - if p == nil { - return 0 - } - valp := p.all[t] - var total uint64 - for i := range valp { - total += atomic.LoadUint64(valp[i]) - } - return total -} - -// Hits is the number of Get calls where a value was found for the corresponding key. -func (p *Metrics) Hits() uint64 { - return p.get(hit) -} - -// Misses is the number of Get calls where a value was not found for the corresponding key. -func (p *Metrics) Misses() uint64 { - return p.get(miss) -} - -// KeysAdded is the total number of Set calls where a new key-value item was added. -func (p *Metrics) KeysAdded() uint64 { - return p.get(keyAdd) -} - -// KeysUpdated is the total number of Set calls where the value was updated. -func (p *Metrics) KeysUpdated() uint64 { - return p.get(keyUpdate) -} - -// KeysEvicted is the total number of keys evicted. -func (p *Metrics) KeysEvicted() uint64 { - return p.get(keyEvict) -} - -// CostAdded is the sum of costs that have been added (successful Set calls). -func (p *Metrics) CostAdded() uint64 { - return p.get(costAdd) -} - -// CostEvicted is the sum of all costs that have been evicted. -func (p *Metrics) CostEvicted() uint64 { - return p.get(costEvict) -} - -// SetsDropped is the number of Set calls that don't make it into internal -// buffers (due to contention or some other reason). -func (p *Metrics) SetsDropped() uint64 { - return p.get(dropSets) -} - -// SetsRejected is the number of Set calls rejected by the policy (TinyLFU). -func (p *Metrics) SetsRejected() uint64 { - return p.get(rejectSets) -} - -// GetsDropped is the number of Get counter increments that are dropped -// internally. -func (p *Metrics) GetsDropped() uint64 { - return p.get(dropGets) -} - -// GetsKept is the number of Get counter increments that are kept. -func (p *Metrics) GetsKept() uint64 { - return p.get(keepGets) -} - -// Ratio is the number of Hits over all accesses (Hits + Misses). This is the -// percentage of successful Get calls. -func (p *Metrics) Ratio() float64 { - if p == nil { - return 0.0 - } - hits, misses := p.get(hit), p.get(miss) - if hits == 0 && misses == 0 { - return 0.0 - } - return float64(hits) / float64(hits+misses) -} - -// Clear resets all the metrics. 
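The `(hash % 25) * 10` indexing in `Metrics.add` above deserves a note: each logical counter is spread over a 256-slot array so that the slots actually written are 10 words (80 bytes) apart, i.e. more than one 64-byte cache line, which prevents false sharing between goroutines bumping different counters. A toy illustration of the same layout, assuming nothing beyond the standard library:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// shardedCounter spreads increments across padded slots so concurrent
// writers rarely contend on the same CPU cache line, mirroring the
// 256-slot layout of the removed Metrics type.
type shardedCounter struct {
	slots [256]uint64
}

func (c *shardedCounter) add(hash, delta uint64) {
	idx := (hash % 25) * 10 // same spreading formula as Metrics.add
	atomic.AddUint64(&c.slots[idx], delta)
}

func (c *shardedCounter) total() uint64 {
	var sum uint64
	for i := range c.slots {
		sum += atomic.LoadUint64(&c.slots[i])
	}
	return sum
}

func main() {
	var c shardedCounter
	var wg sync.WaitGroup
	for g := uint64(0); g < 8; g++ {
		wg.Add(1)
		go func(seed uint64) {
			defer wg.Done()
			for i := uint64(0); i < 1000; i++ {
				c.add(seed+i, 1)
			}
		}(g * 1000)
	}
	wg.Wait()
	fmt.Println(c.total()) // 8000
}
```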
-func (p *Metrics) Clear() { - if p == nil { - return - } - for i := 0; i < doNotUse; i++ { - for j := range p.all[i] { - atomic.StoreUint64(p.all[i][j], 0) - } - } -} - -// String returns a string representation of the metrics. -func (p *Metrics) String() string { - if p == nil { - return "" - } - var buf bytes.Buffer - for i := 0; i < doNotUse; i++ { - t := metricType(i) - fmt.Fprintf(&buf, "%s: %d ", stringFor(t), p.get(t)) - } - fmt.Fprintf(&buf, "gets-total: %d ", p.get(hit)+p.get(miss)) - fmt.Fprintf(&buf, "hit-ratio: %.2f", p.Ratio()) - return buf.String() -} diff --git a/go/cache/ristretto/cache_test.go b/go/cache/ristretto/cache_test.go deleted file mode 100644 index eda9f9109f3..00000000000 --- a/go/cache/ristretto/cache_test.go +++ /dev/null @@ -1,690 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "fmt" - "math/rand" - "strconv" - "strings" - "sync" - "testing" - "time" - - "vitess.io/vitess/go/vt/log" - - "github.com/stretchr/testify/require" -) - -var wait = time.Millisecond * 10 - -func TestCacheKeyToHash(t *testing.T) { - keyToHashCount := 0 - c, err := NewCache(&Config{ - NumCounters: 10, - MaxCost: 1000, - BufferItems: 64, - IgnoreInternalCost: true, - KeyToHash: func(key string) (uint64, uint64) { - keyToHashCount++ - return defaultStringHash(key) - }, - }) - require.NoError(t, err) - if c.SetWithCost("1", 1, 1) { - time.Sleep(wait) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - c.Delete("1") - } - require.Equal(t, 3, keyToHashCount) -} - -func TestCacheMaxCost(t *testing.T) { - charset := "abcdefghijklmnopqrstuvwxyz0123456789" - key := func() string { - k := make([]byte, 2) - for i := range k { - k[i] = charset[rand.Intn(len(charset))] - } - return string(k) - } - c, err := NewCache(&Config{ - NumCounters: 12960, // 36^2 * 10 - MaxCost: 1e6, // 1mb - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - stop := make(chan struct{}, 8) - for i := 0; i < 8; i++ { - go func() { - for { - select { - case <-stop: - return - default: - time.Sleep(time.Millisecond) - - k := key() - if _, ok := c.Get(k); !ok { - val := "" - if rand.Intn(100) < 10 { - val = "test" - } else { - val = strings.Repeat("a", 1000) - } - c.SetWithCost(key(), val, int64(2+len(val))) - } - } - } - }() - } - for i := 0; i < 20; i++ { - time.Sleep(time.Second) - cacheCost := c.Metrics.CostAdded() - c.Metrics.CostEvicted() - log.Infof("total cache cost: %d", cacheCost) - require.True(t, float64(cacheCost) <= float64(1e6*1.05)) - } - for i := 0; i < 8; i++ { - stop <- struct{}{} - } -} - -func TestUpdateMaxCost(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 10, - MaxCost: 10, - BufferItems: 64, - }) - require.NoError(t, err) - require.Equal(t, int64(10), c.MaxCapacity()) - require.True(t, c.SetWithCost("1", 1, 1)) - time.Sleep(wait) - _, ok := c.Get("1") - // Set is rejected because the cost of the 
entry is too high - // when accounting for the internal cost of storing the entry. - require.False(t, ok) - - // Update the max cost of the cache and retry. - c.SetCapacity(1000) - require.Equal(t, int64(1000), c.MaxCapacity()) - require.True(t, c.SetWithCost("1", 1, 1)) - time.Sleep(wait) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - c.Delete("1") -} - -func TestNewCache(t *testing.T) { - _, err := NewCache(&Config{ - NumCounters: 0, - }) - require.Error(t, err) - - _, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 0, - }) - require.Error(t, err) - - _, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 0, - }) - require.Error(t, err) - - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - require.NotNil(t, c) -} - -func TestNilCache(t *testing.T) { - var c *Cache - val, ok := c.Get("1") - require.False(t, ok) - require.Nil(t, val) - - require.False(t, c.SetWithCost("1", 1, 1)) - c.Delete("1") - c.Clear() - c.Close() -} - -func TestMultipleClose(t *testing.T) { - var c *Cache - c.Close() - - var err error - c, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - c.Close() - c.Close() -} - -func TestSetAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - c.Close() - require.False(t, c.SetWithCost("1", 1, 1)) -} - -func TestClearAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - c.Close() - c.Clear() -} - -func TestGetAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - require.True(t, c.SetWithCost("1", 1, 1)) - c.Close() - - _, ok := c.Get("2") - require.False(t, ok) -} - -func TestDelAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - require.True(t, c.SetWithCost("1", 1, 1)) - c.Close() - - c.Delete("1") -} - -func TestCacheProcessItems(t *testing.T) { - m := &sync.Mutex{} - evicted := make(map[uint64]struct{}) - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - IgnoreInternalCost: true, - Cost: func(value any) int64 { - return int64(value.(int)) - }, - OnEvict: func(item *Item) { - m.Lock() - defer m.Unlock() - evicted[item.Key] = struct{}{} - }, - }) - require.NoError(t, err) - - var key uint64 - var conflict uint64 - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 1, - Cost: 0, - } - time.Sleep(wait) - require.True(t, c.policy.Has(key)) - require.Equal(t, int64(1), c.policy.Cost(key)) - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemUpdate, - Key: key, - Conflict: conflict, - Value: 2, - Cost: 0, - } - time.Sleep(wait) - require.Equal(t, int64(2), c.policy.Cost(key)) - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemDelete, - Key: key, - Conflict: conflict, - } - time.Sleep(wait) - key, conflict = defaultStringHash("1") - val, ok := c.store.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - require.False(t, c.policy.Has(1)) - - key, conflict = defaultStringHash("2") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 2, - Cost: 3, - } - key, conflict = defaultStringHash("3") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - 
Cost: 3, - } - key, conflict = defaultStringHash("4") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - Cost: 3, - } - key, conflict = defaultStringHash("5") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - Cost: 5, - } - time.Sleep(wait) - m.Lock() - require.NotEqual(t, 0, len(evicted)) - m.Unlock() - - defer func() { - require.NotNil(t, recover()) - }() - c.Close() - c.setBuf <- &Item{flag: itemNew} -} - -func TestCacheGet(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - IgnoreInternalCost: true, - Metrics: true, - }) - require.NoError(t, err) - - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - c.store.Set(&i) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - - val, ok = c.Get("2") - require.False(t, ok) - require.Nil(t, val) - - // 0.5 and not 1.0 because we tried Getting each item twice - require.Equal(t, 0.5, c.Metrics.Ratio()) - - c = nil - val, ok = c.Get("0") - require.False(t, ok) - require.Nil(t, val) -} - -// retrySet calls SetWithCost until the item is accepted by the cache. -func retrySet(t *testing.T, c *Cache, key string, value int, cost int64) { - for { - if set := c.SetWithCost(key, value, cost); !set { - time.Sleep(wait) - continue - } - - time.Sleep(wait) - val, ok := c.Get(key) - require.True(t, ok) - require.NotNil(t, val) - require.Equal(t, value, val.(int)) - return - } -} - -func TestCacheSet(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - retrySet(t, c, "1", 1, 1) - - c.SetWithCost("1", 2, 2) - val, ok := c.store.Get(defaultStringHash("1")) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - c.stop <- struct{}{} - for i := 0; i < setBufSize; i++ { - key, conflict := defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemUpdate, - Key: key, - Conflict: conflict, - Value: 1, - Cost: 1, - } - } - require.False(t, c.SetWithCost("2", 2, 1)) - require.Equal(t, uint64(1), c.Metrics.SetsDropped()) - close(c.setBuf) - close(c.stop) - - c = nil - require.False(t, c.SetWithCost("1", 1, 1)) -} - -func TestCacheInternalCost(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - // Get should return false because the cache's cost is too small to store the item - // when accounting for the internal cost. - c.SetWithCost("1", 1, 1) - time.Sleep(wait) - _, ok := c.Get("1") - require.False(t, ok) -} - -func TestCacheDel(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - }) - require.NoError(t, err) - - c.SetWithCost("1", 1, 1) - c.Delete("1") - // The deletes and sets are pushed through the setbuf. It might be possible - // that the delete is not processed before the following get is called. So - // wait for a millisecond for things to be processed. 
- time.Sleep(time.Millisecond) - val, ok := c.Get("1") - require.False(t, ok) - require.Nil(t, val) - - c = nil - defer func() { - require.Nil(t, recover()) - }() - c.Delete("1") -} - -func TestCacheClear(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - c.SetWithCost(strconv.Itoa(i), i, 1) - } - time.Sleep(wait) - require.Equal(t, uint64(10), c.Metrics.KeysAdded()) - - c.Clear() - require.Equal(t, uint64(0), c.Metrics.KeysAdded()) - - for i := 0; i < 10; i++ { - val, ok := c.Get(strconv.Itoa(i)) - require.False(t, ok) - require.Nil(t, val) - } -} - -func TestCacheMetrics(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - c.SetWithCost(strconv.Itoa(i), i, 1) - } - time.Sleep(wait) - m := c.Metrics - require.Equal(t, uint64(10), m.KeysAdded()) -} - -func TestMetrics(t *testing.T) { - newMetrics() -} - -func TestNilMetrics(t *testing.T) { - var m *Metrics - for _, f := range []func() uint64{ - m.Hits, - m.Misses, - m.KeysAdded, - m.KeysEvicted, - m.CostEvicted, - m.SetsDropped, - m.SetsRejected, - m.GetsDropped, - m.GetsKept, - } { - require.Equal(t, uint64(0), f()) - } -} - -func TestMetricsAddGet(t *testing.T) { - m := newMetrics() - m.add(hit, 1, 1) - m.add(hit, 2, 2) - m.add(hit, 3, 3) - require.Equal(t, uint64(6), m.Hits()) - - m = nil - m.add(hit, 1, 1) - require.Equal(t, uint64(0), m.Hits()) -} - -func TestMetricsRatio(t *testing.T) { - m := newMetrics() - require.Equal(t, float64(0), m.Ratio()) - - m.add(hit, 1, 1) - m.add(hit, 2, 2) - m.add(miss, 1, 1) - m.add(miss, 2, 2) - require.Equal(t, 0.5, m.Ratio()) - - m = nil - require.Equal(t, float64(0), m.Ratio()) -} - -func TestMetricsString(t *testing.T) { - m := newMetrics() - m.add(hit, 1, 1) - m.add(miss, 1, 1) - m.add(keyAdd, 1, 1) - m.add(keyUpdate, 1, 1) - m.add(keyEvict, 1, 1) - m.add(costAdd, 1, 1) - m.add(costEvict, 1, 1) - m.add(dropSets, 1, 1) - m.add(rejectSets, 1, 1) - m.add(dropGets, 1, 1) - m.add(keepGets, 1, 1) - require.Equal(t, uint64(1), m.Hits()) - require.Equal(t, uint64(1), m.Misses()) - require.Equal(t, 0.5, m.Ratio()) - require.Equal(t, uint64(1), m.KeysAdded()) - require.Equal(t, uint64(1), m.KeysUpdated()) - require.Equal(t, uint64(1), m.KeysEvicted()) - require.Equal(t, uint64(1), m.CostAdded()) - require.Equal(t, uint64(1), m.CostEvicted()) - require.Equal(t, uint64(1), m.SetsDropped()) - require.Equal(t, uint64(1), m.SetsRejected()) - require.Equal(t, uint64(1), m.GetsDropped()) - require.Equal(t, uint64(1), m.GetsKept()) - - require.NotEqual(t, 0, len(m.String())) - - m = nil - require.Equal(t, 0, len(m.String())) - - require.Equal(t, "unidentified", stringFor(doNotUse)) -} - -func TestCacheMetricsClear(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - c.SetWithCost("1", 1, 1) - stop := make(chan struct{}) - go func() { - for { - select { - case <-stop: - return - default: - c.Get("1") - } - } - }() - time.Sleep(wait) - c.Clear() - stop <- struct{}{} - c.Metrics = nil - c.Metrics.Clear() -} - -// Regression test for bug https://github.com/dgraph-io/ristretto/issues/167 -func TestDropUpdates(t *testing.T) { - originalSetBugSize := setBufSize - defer func() { setBufSize = originalSetBugSize 
}() - - test := func() { - // dropppedMap stores the items dropped from the cache. - droppedMap := make(map[int]struct{}) - lastEvictedSet := int64(-1) - - var err error - handler := func(_ any, value any) { - v := value.(string) - lastEvictedSet, err = strconv.ParseInt(string(v), 10, 32) - require.NoError(t, err) - - _, ok := droppedMap[int(lastEvictedSet)] - if ok { - panic(fmt.Sprintf("val = %+v was dropped but it got evicted. Dropped items: %+v\n", - lastEvictedSet, droppedMap)) - } - } - - // This is important. The race condition shows up only when the setBuf - // is full and that's why we reduce the buf size here. The test will - // try to fill up the setbuf to it's capacity and then perform an - // update on a key. - setBufSize = 10 - - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - OnEvict: func(item *Item) { - if item.Value != nil { - handler(nil, item.Value) - } - }, - }) - require.NoError(t, err) - - for i := 0; i < 5*setBufSize; i++ { - v := fmt.Sprintf("%0100d", i) - // We're updating the same key. - if !c.SetWithCost("0", v, 1) { - // The race condition doesn't show up without this sleep. - time.Sleep(time.Microsecond) - droppedMap[i] = struct{}{} - } - } - // Wait for all the items to be processed. - c.Wait() - // This will cause eviction from the cache. - require.True(t, c.SetWithCost("1", nil, 10)) - c.Close() - } - - // Run the test 100 times since it's not reliable. - for i := 0; i < 100; i++ { - test() - } -} - -func newTestCache() (*Cache, error) { - return NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) -} diff --git a/go/cache/ristretto/policy.go b/go/cache/ristretto/policy.go deleted file mode 100644 index 84cc008cb99..00000000000 --- a/go/cache/ristretto/policy.go +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "math" - "sync" - "sync/atomic" - - "vitess.io/vitess/go/cache/ristretto/bloom" -) - -const ( - // lfuSample is the number of items to sample when looking at eviction - // candidates. 5 seems to be the most optimal number [citation needed]. - lfuSample = 5 -) - -// policy is the interface encapsulating eviction/admission behavior. -// -// TODO: remove this interface and just rename defaultPolicy to policy, as we -// -// are probably only going to use/implement/maintain one policy. -type policy interface { - ringConsumer - // Add attempts to Add the key-cost pair to the Policy. It returns a slice - // of evicted keys and a bool denoting whether or not the key-cost pair - // was added. If it returns true, the key should be stored in cache. - Add(uint64, int64) ([]*Item, bool) - // Has returns true if the key exists in the Policy. - Has(uint64) bool - // Del deletes the key from the Policy. - Del(uint64) - // Cap returns the amount of used capacity. 
- Used() int64 - // Close stops all goroutines and closes all channels. - Close() - // Update updates the cost value for the key. - Update(uint64, int64) - // Cost returns the cost value of a key or -1 if missing. - Cost(uint64) int64 - // Optionally, set stats object to track how policy is performing. - CollectMetrics(*Metrics) - // Clear zeroes out all counters and clears hashmaps. - Clear() - // MaxCost returns the current max cost of the cache policy. - MaxCost() int64 - // UpdateMaxCost updates the max cost of the cache policy. - UpdateMaxCost(int64) -} - -func newPolicy(numCounters, maxCost int64) policy { - return newDefaultPolicy(numCounters, maxCost) -} - -type defaultPolicy struct { - sync.Mutex - admit *tinyLFU - evict *sampledLFU - itemsCh chan []uint64 - stop chan struct{} - isClosed bool - metrics *Metrics - numCounters int64 - maxCost int64 -} - -func newDefaultPolicy(numCounters, maxCost int64) *defaultPolicy { - p := &defaultPolicy{ - admit: newTinyLFU(numCounters), - evict: newSampledLFU(maxCost), - itemsCh: make(chan []uint64, 3), - stop: make(chan struct{}), - numCounters: numCounters, - maxCost: maxCost, - } - go p.processItems() - return p -} - -func (p *defaultPolicy) CollectMetrics(metrics *Metrics) { - p.metrics = metrics - p.evict.metrics = metrics -} - -type policyPair struct { - key uint64 - cost int64 -} - -func (p *defaultPolicy) processItems() { - for { - select { - case items := <-p.itemsCh: - p.Lock() - p.admit.Push(items) - p.Unlock() - case <-p.stop: - return - } - } -} - -func (p *defaultPolicy) Push(keys []uint64) bool { - if p.isClosed { - return false - } - - if len(keys) == 0 { - return true - } - - select { - case p.itemsCh <- keys: - p.metrics.add(keepGets, keys[0], uint64(len(keys))) - return true - default: - p.metrics.add(dropGets, keys[0], uint64(len(keys))) - return false - } -} - -// Add decides whether the item with the given key and cost should be accepted by -// the policy. It returns the list of victims that have been evicted and a boolean -// indicating whether the incoming item should be accepted. -func (p *defaultPolicy) Add(key uint64, cost int64) ([]*Item, bool) { - p.Lock() - defer p.Unlock() - - // Cannot add an item bigger than entire cache. - if cost > p.evict.getMaxCost() { - return nil, false - } - - // No need to go any further if the item is already in the cache. - if has := p.evict.updateIfHas(key, cost); has { - // An update does not count as an addition, so return false. - return nil, false - } - - // If the execution reaches this point, the key doesn't exist in the cache. - // Calculate the remaining room in the cache (usually bytes). - room := p.evict.roomLeft(cost) - if room >= 0 { - // There's enough room in the cache to store the new item without - // overflowing. Do that now and stop here. - p.evict.add(key, cost) - p.metrics.add(costAdd, key, uint64(cost)) - return nil, true - } - - // incHits is the hit count for the incoming item. - incHits := p.admit.Estimate(key) - // sample is the eviction candidate pool to be filled via random sampling. - // TODO: perhaps we should use a min heap here. Right now our time - // complexity is N for finding the min. Min heap should bring it down to - // O(lg N). - sample := make([]*policyPair, 0, lfuSample) - // As items are evicted they will be appended to victims. - victims := make([]*Item, 0) - - // Delete victims until there's enough space or a minKey is found that has - // more hits than incoming item. 
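- // Illustrative walk-through (not in the original source): with maxCost=10 and the cache full, an incoming (key, cost=2) whose estimated hits are 8 repeatedly samples up to lfuSample resident entries and evicts the one with the fewest estimated hits; if a sampled minimum ever has more hits than 8, the incoming key is rejected instead.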
- for ; room < 0; room = p.evict.roomLeft(cost) { - // Fill up empty slots in sample. - sample = p.evict.fillSample(sample) - - // Find minimally used item in sample. - minKey, minHits, minID, minCost := uint64(0), int64(math.MaxInt64), 0, int64(0) - for i, pair := range sample { - // Look up hit count for sample key. - if hits := p.admit.Estimate(pair.key); hits < minHits { - minKey, minHits, minID, minCost = pair.key, hits, i, pair.cost - } - } - - // If the incoming item isn't worth keeping in the policy, reject. - if incHits < minHits { - p.metrics.add(rejectSets, key, 1) - return victims, false - } - - // Delete the victim from metadata. - p.evict.del(minKey) - - // Delete the victim from sample. - sample[minID] = sample[len(sample)-1] - sample = sample[:len(sample)-1] - // Store victim in evicted victims slice. - victims = append(victims, &Item{ - Key: minKey, - Conflict: 0, - Cost: minCost, - }) - } - - p.evict.add(key, cost) - p.metrics.add(costAdd, key, uint64(cost)) - return victims, true -} - -func (p *defaultPolicy) Has(key uint64) bool { - p.Lock() - _, exists := p.evict.keyCosts[key] - p.Unlock() - return exists -} - -func (p *defaultPolicy) Del(key uint64) { - p.Lock() - p.evict.del(key) - p.Unlock() -} - -func (p *defaultPolicy) Used() int64 { - p.Lock() - used := p.evict.used - p.Unlock() - return used -} - -func (p *defaultPolicy) Update(key uint64, cost int64) { - p.Lock() - p.evict.updateIfHas(key, cost) - p.Unlock() -} - -func (p *defaultPolicy) Cost(key uint64) int64 { - p.Lock() - if cost, found := p.evict.keyCosts[key]; found { - p.Unlock() - return cost - } - p.Unlock() - return -1 -} - -func (p *defaultPolicy) Clear() { - p.Lock() - p.admit = newTinyLFU(p.numCounters) - p.evict = newSampledLFU(p.maxCost) - p.Unlock() -} - -func (p *defaultPolicy) Close() { - if p.isClosed { - return - } - - // Block until the p.processItems goroutine returns. - p.stop <- struct{}{} - close(p.stop) - close(p.itemsCh) - p.isClosed = true -} - -func (p *defaultPolicy) MaxCost() int64 { - if p == nil || p.evict == nil { - return 0 - } - return p.evict.getMaxCost() -} - -func (p *defaultPolicy) UpdateMaxCost(maxCost int64) { - if p == nil || p.evict == nil { - return - } - p.evict.updateMaxCost(maxCost) -} - -// sampledLFU is an eviction helper storing key-cost pairs. 
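-// It tracks the total admitted cost against maxCost, and eviction candidates are drawn from keyCosts by (unordered) map iteration in fillSample rather than by keeping entries in frequency order.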
-type sampledLFU struct { - keyCosts map[uint64]int64 - maxCost int64 - used int64 - metrics *Metrics -} - -func newSampledLFU(maxCost int64) *sampledLFU { - return &sampledLFU{ - keyCosts: make(map[uint64]int64), - maxCost: maxCost, - } -} - -func (p *sampledLFU) getMaxCost() int64 { - return atomic.LoadInt64(&p.maxCost) -} - -func (p *sampledLFU) updateMaxCost(maxCost int64) { - atomic.StoreInt64(&p.maxCost, maxCost) -} - -func (p *sampledLFU) roomLeft(cost int64) int64 { - return p.getMaxCost() - (p.used + cost) -} - -func (p *sampledLFU) fillSample(in []*policyPair) []*policyPair { - if len(in) >= lfuSample { - return in - } - for key, cost := range p.keyCosts { - in = append(in, &policyPair{key, cost}) - if len(in) >= lfuSample { - return in - } - } - return in -} - -func (p *sampledLFU) del(key uint64) { - cost, ok := p.keyCosts[key] - if !ok { - return - } - p.used -= cost - delete(p.keyCosts, key) - p.metrics.add(costEvict, key, uint64(cost)) - p.metrics.add(keyEvict, key, 1) -} - -func (p *sampledLFU) add(key uint64, cost int64) { - p.keyCosts[key] = cost - p.used += cost -} - -func (p *sampledLFU) updateIfHas(key uint64, cost int64) bool { - if prev, found := p.keyCosts[key]; found { - // Update the cost of an existing key, but don't worry about evicting. - // Evictions will be handled the next time a new item is added. - p.metrics.add(keyUpdate, key, 1) - if prev > cost { - diff := prev - cost - p.metrics.add(costAdd, key, ^uint64(uint64(diff)-1)) - } else if cost > prev { - diff := cost - prev - p.metrics.add(costAdd, key, uint64(diff)) - } - p.used += cost - prev - p.keyCosts[key] = cost - return true - } - return false -} - -func (p *sampledLFU) clear() { - p.used = 0 - p.keyCosts = make(map[uint64]int64) -} - -// tinyLFU is an admission helper that keeps track of access frequency using -// tiny (4-bit) counters in the form of a count-min sketch. -// tinyLFU is NOT thread safe. -type tinyLFU struct { - freq *cmSketch - door *bloom.Bloom - incrs int64 - resetAt int64 -} - -func newTinyLFU(numCounters int64) *tinyLFU { - return &tinyLFU{ - freq: newCmSketch(numCounters), - door: bloom.NewBloomFilterWithErrorRate(uint64(numCounters), 0.01), - resetAt: numCounters, - } -} - -func (p *tinyLFU) Push(keys []uint64) { - for _, key := range keys { - p.Increment(key) - } -} - -func (p *tinyLFU) Estimate(key uint64) int64 { - hits := p.freq.Estimate(key) - if p.door.Has(key) { - hits++ - } - return hits -} - -func (p *tinyLFU) Increment(key uint64) { - // Flip doorkeeper bit if not already done. - if added := p.door.AddIfNotHas(key); !added { - // Increment count-min counter if doorkeeper bit is already set. - p.freq.Increment(key) - } - p.incrs++ - if p.incrs >= p.resetAt { - p.reset() - } -} - -func (p *tinyLFU) reset() { - // Zero out incrs. - p.incrs = 0 - // clears doorkeeper bits - p.door.Clear() - // halves count-min counters - p.freq.Reset() -} - -func (p *tinyLFU) clear() { - p.incrs = 0 - p.freq.Clear() - p.door.Clear() -} diff --git a/go/cache/ristretto/policy_test.go b/go/cache/ristretto/policy_test.go deleted file mode 100644 index c864b6c74d0..00000000000 --- a/go/cache/ristretto/policy_test.go +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestPolicy(t *testing.T) { - defer func() { - require.Nil(t, recover()) - }() - newPolicy(100, 10) -} - -func TestPolicyMetrics(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.CollectMetrics(newMetrics()) - require.NotNil(t, p.metrics) - require.NotNil(t, p.evict.metrics) -} - -func TestPolicyProcessItems(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.itemsCh <- []uint64{1, 2, 2} - time.Sleep(wait) - p.Lock() - require.Equal(t, int64(2), p.admit.Estimate(2)) - require.Equal(t, int64(1), p.admit.Estimate(1)) - p.Unlock() - - p.stop <- struct{}{} - p.itemsCh <- []uint64{3, 3, 3} - time.Sleep(wait) - p.Lock() - require.Equal(t, int64(0), p.admit.Estimate(3)) - p.Unlock() -} - -func TestPolicyPush(t *testing.T) { - p := newDefaultPolicy(100, 10) - require.True(t, p.Push([]uint64{})) - - keepCount := 0 - for i := 0; i < 10; i++ { - if p.Push([]uint64{1, 2, 3, 4, 5}) { - keepCount++ - } - } - require.NotEqual(t, 0, keepCount) -} - -func TestPolicyAdd(t *testing.T) { - p := newDefaultPolicy(1000, 100) - if victims, added := p.Add(1, 101); victims != nil || added { - t.Fatal("can't add an item bigger than entire cache") - } - p.Lock() - p.evict.add(1, 1) - p.admit.Increment(1) - p.admit.Increment(2) - p.admit.Increment(3) - p.Unlock() - - victims, added := p.Add(1, 1) - require.Nil(t, victims) - require.False(t, added) - - victims, added = p.Add(2, 20) - require.Nil(t, victims) - require.True(t, added) - - victims, added = p.Add(3, 90) - require.NotNil(t, victims) - require.True(t, added) - - victims, added = p.Add(4, 20) - require.NotNil(t, victims) - require.False(t, added) -} - -func TestPolicyHas(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - require.True(t, p.Has(1)) - require.False(t, p.Has(2)) -} - -func TestPolicyDel(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Del(1) - p.Del(2) - require.False(t, p.Has(1)) - require.False(t, p.Has(2)) -} - -func TestPolicyCap(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - require.Equal(t, int64(9), p.MaxCost()-p.Used()) -} - -func TestPolicyUpdate(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Update(1, 2) - p.Lock() - require.Equal(t, int64(2), p.evict.keyCosts[1]) - p.Unlock() -} - -func TestPolicyCost(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 2) - require.Equal(t, int64(2), p.Cost(1)) - require.Equal(t, int64(-1), p.Cost(2)) -} - -func TestPolicyClear(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Add(2, 2) - p.Add(3, 3) - p.Clear() - require.Equal(t, int64(10), p.MaxCost()-p.Used()) - require.False(t, p.Has(1)) - require.False(t, p.Has(2)) - require.False(t, p.Has(3)) -} - -func TestPolicyClose(t *testing.T) { - defer func() { - require.NotNil(t, recover()) - }() - - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Close() - p.itemsCh <- []uint64{1} -} - -func TestPushAfterClose(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Close() - require.False(t, p.Push([]uint64{1, 2})) -} - -func 
TestAddAfterClose(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Close() - p.Add(1, 1) -} - -func TestSampledLFUAdd(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.add(3, 1) - require.Equal(t, int64(4), e.used) - require.Equal(t, int64(2), e.keyCosts[2]) -} - -func TestSampledLFUDel(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.del(2) - require.Equal(t, int64(1), e.used) - _, ok := e.keyCosts[2] - require.False(t, ok) - e.del(4) -} - -func TestSampledLFUUpdate(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - require.True(t, e.updateIfHas(1, 2)) - require.Equal(t, int64(2), e.used) - require.False(t, e.updateIfHas(2, 2)) -} - -func TestSampledLFUClear(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.add(3, 1) - e.clear() - require.Equal(t, 0, len(e.keyCosts)) - require.Equal(t, int64(0), e.used) -} - -func TestSampledLFURoom(t *testing.T) { - e := newSampledLFU(16) - e.add(1, 1) - e.add(2, 2) - e.add(3, 3) - require.Equal(t, int64(6), e.roomLeft(4)) -} - -func TestSampledLFUSample(t *testing.T) { - e := newSampledLFU(16) - e.add(4, 4) - e.add(5, 5) - sample := e.fillSample([]*policyPair{ - {1, 1}, - {2, 2}, - {3, 3}, - }) - k := sample[len(sample)-1].key - require.Equal(t, 5, len(sample)) - require.NotEqual(t, 1, k) - require.NotEqual(t, 2, k) - require.NotEqual(t, 3, k) - require.Equal(t, len(sample), len(e.fillSample(sample))) - e.del(5) - sample = e.fillSample(sample[:len(sample)-2]) - require.Equal(t, 4, len(sample)) -} - -func TestTinyLFUIncrement(t *testing.T) { - a := newTinyLFU(4) - a.Increment(1) - a.Increment(1) - a.Increment(1) - require.True(t, a.door.Has(1)) - require.Equal(t, int64(2), a.freq.Estimate(1)) - - a.Increment(1) - require.False(t, a.door.Has(1)) - require.Equal(t, int64(1), a.freq.Estimate(1)) -} - -func TestTinyLFUEstimate(t *testing.T) { - a := newTinyLFU(8) - a.Increment(1) - a.Increment(1) - a.Increment(1) - require.Equal(t, int64(3), a.Estimate(1)) - require.Equal(t, int64(0), a.Estimate(2)) -} - -func TestTinyLFUPush(t *testing.T) { - a := newTinyLFU(16) - a.Push([]uint64{1, 2, 2, 3, 3, 3}) - require.Equal(t, int64(1), a.Estimate(1)) - require.Equal(t, int64(2), a.Estimate(2)) - require.Equal(t, int64(3), a.Estimate(3)) - require.Equal(t, int64(6), a.incrs) -} - -func TestTinyLFUClear(t *testing.T) { - a := newTinyLFU(16) - a.Push([]uint64{1, 3, 3, 3}) - a.clear() - require.Equal(t, int64(0), a.incrs) - require.Equal(t, int64(0), a.Estimate(3)) -} diff --git a/go/cache/ristretto/ring.go b/go/cache/ristretto/ring.go deleted file mode 100644 index 84d8689ee37..00000000000 --- a/go/cache/ristretto/ring.go +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "sync" -) - -// ringConsumer is the user-defined object responsible for receiving and -// processing items in batches when buffers are drained. 
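-// In this package the consumer is the cache's admission policy: drained batches of key hashes reach defaultPolicy.Push, which feeds them to the tinyLFU frequency sketch.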
-type ringConsumer interface { - Push([]uint64) bool -} - -// ringStripe is a single ring buffer that is not concurrency-safe. -type ringStripe struct { - cons ringConsumer - data []uint64 - capa int -} - -func newRingStripe(cons ringConsumer, capa int64) *ringStripe { - return &ringStripe{ - cons: cons, - data: make([]uint64, 0, capa), - capa: int(capa), - } -} - -// Push appends an item to the ring buffer and drains (copies items and -// sends to Consumer) if full. -func (s *ringStripe) Push(item uint64) { - s.data = append(s.data, item) - // Decide if the ring buffer should be drained. - if len(s.data) >= s.capa { - // Send elements to consumer and create a new ring stripe. - if s.cons.Push(s.data) { - s.data = make([]uint64, 0, s.capa) - } else { - s.data = s.data[:0] - } - } -} - -// ringBuffer stores multiple buffers (stripes) and distributes Pushed items -// between them to lower contention. -// -// This implements the "batching" process described in the BP-Wrapper paper -// (section III part A). -type ringBuffer struct { - pool *sync.Pool -} - -// newRingBuffer returns a striped ring buffer. The given consumer will -// be called when individual stripes are full and need to drain their elements. -func newRingBuffer(cons ringConsumer, capa int64) *ringBuffer { - // LOSSY buffers use a very simple sync.Pool for concurrently reusing - // stripes. We do lose some stripes due to GC (unheld items in sync.Pool - // are cleared), but the performance gains generally outweigh the small - // percentage of elements lost. The performance primarily comes from - // low-level runtime functions used in the standard library that aren't - // available to us (such as runtime_procPin()). - return &ringBuffer{ - pool: &sync.Pool{ - New: func() any { return newRingStripe(cons, capa) }, - }, - } -} - -// Push adds an element to one of the internal stripes and possibly drains if -// the stripe becomes full. -func (b *ringBuffer) Push(item uint64) { - // Reuse or create a new stripe. - stripe := b.pool.Get().(*ringStripe) - stripe.Push(item) - b.pool.Put(stripe) -} diff --git a/go/cache/ristretto/ring_test.go b/go/cache/ristretto/ring_test.go deleted file mode 100644 index 0dbe962ccc6..00000000000 --- a/go/cache/ristretto/ring_test.go +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package ristretto - -import ( - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -type testConsumer struct { - push func([]uint64) - save bool -} - -func (c *testConsumer) Push(items []uint64) bool { - if c.save { - c.push(items) - return true - } - return false -} - -func TestRingDrain(t *testing.T) { - drains := 0 - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - drains++ - }, - save: true, - }, 1) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - require.Equal(t, 100, drains, "buffers shouldn't be dropped with BufferItems == 1") -} - -func TestRingReset(t *testing.T) { - drains := 0 - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - drains++ - }, - save: false, - }, 4) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - require.Equal(t, 0, drains, "testConsumer shouldn't be draining") -} - -func TestRingConsumer(t *testing.T) { - mu := &sync.Mutex{} - drainItems := make(map[uint64]struct{}) - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - mu.Lock() - defer mu.Unlock() - for i := range items { - drainItems[items[i]] = struct{}{} - } - }, - save: true, - }, 4) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - l := len(drainItems) - require.NotEqual(t, 0, l) - require.True(t, l <= 100) -} diff --git a/go/cache/ristretto/sketch.go b/go/cache/ristretto/sketch.go deleted file mode 100644 index c8ad31e8494..00000000000 --- a/go/cache/ristretto/sketch.go +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package ristretto includes multiple probabilistic data structures needed for -// admission/eviction metadata. Most are Counting Bloom Filter variations, but -// a caching-specific feature that is also required is a "freshness" mechanism, -// which basically serves as a "lifetime" process. This freshness mechanism -// was described in the original TinyLFU paper [1], but other mechanisms may -// be better suited for certain data distributions. -// -// [1]: https://arxiv.org/abs/1512.00727 -package ristretto - -import ( - "fmt" - "math/rand" - "time" -) - -// cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily -// based on Damian Gryski's CM4 [1]. -// -// [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go -type cmSketch struct { - rows [cmDepth]cmRow - seed [cmDepth]uint64 - mask uint64 -} - -const ( - // cmDepth is the number of counter copies to store (think of it as rows). - cmDepth = 4 -) - -func newCmSketch(numCounters int64) *cmSketch { - if numCounters == 0 { - panic("cmSketch: bad numCounters") - } - // Get the next power of 2 for better cache performance. - numCounters = next2Power(numCounters) - sketch := &cmSketch{mask: uint64(numCounters - 1)} - // Initialize rows of counters and seeds.
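- // Each of the cmDepth rows gets its own seed; a key's slot in row i is (hash ^ seed[i]) & mask, so the rows are (nearly) independent and Estimate can take the minimum across them.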
- source := rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < cmDepth; i++ { - sketch.seed[i] = source.Uint64() - sketch.rows[i] = newCmRow(numCounters) - } - return sketch -} - -// Increment increments the count(ers) for the specified key. -func (s *cmSketch) Increment(hashed uint64) { - for i := range s.rows { - s.rows[i].increment((hashed ^ s.seed[i]) & s.mask) - } -} - -// Estimate returns the value of the specified key. -func (s *cmSketch) Estimate(hashed uint64) int64 { - min := byte(255) - for i := range s.rows { - val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask) - if val < min { - min = val - } - } - return int64(min) -} - -// Reset halves all counter values. -func (s *cmSketch) Reset() { - for _, r := range s.rows { - r.reset() - } -} - -// Clear zeroes all counters. -func (s *cmSketch) Clear() { - for _, r := range s.rows { - r.clear() - } -} - -// cmRow is a row of bytes, with each byte holding two counters. -type cmRow []byte - -func newCmRow(numCounters int64) cmRow { - return make(cmRow, numCounters/2) -} - -func (r cmRow) get(n uint64) byte { - return byte(r[n/2]>>((n&1)*4)) & 0x0f -} - -func (r cmRow) increment(n uint64) { - // Index of the counter. - i := n / 2 - // Shift distance (even 0, odd 4). - s := (n & 1) * 4 - // Counter value. - v := (r[i] >> s) & 0x0f - // Only increment if not max value (overflow wrap is bad for LFU). - if v < 15 { - r[i] += 1 << s - } -} - -func (r cmRow) reset() { - // Halve each counter. - for i := range r { - r[i] = (r[i] >> 1) & 0x77 - } -} - -func (r cmRow) clear() { - // Zero each counter. - clear(r) -} - -func (r cmRow) string() string { - s := "" - for i := uint64(0); i < uint64(len(r)*2); i++ { - s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f) - } - s = s[:len(s)-1] - return s -} - -// next2Power rounds x up to the next power of 2, if it's not already one. -func next2Power(x int64) int64 { - x-- - x |= x >> 1 - x |= x >> 2 - x |= x >> 4 - x |= x >> 8 - x |= x >> 16 - x |= x >> 32 - x++ - return x -} diff --git a/go/cache/ristretto/sketch_test.go b/go/cache/ristretto/sketch_test.go deleted file mode 100644 index 03804a6d599..00000000000 --- a/go/cache/ristretto/sketch_test.go +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package ristretto - -import ( - "testing" - - "vitess.io/vitess/go/vt/log" - - "github.com/stretchr/testify/require" -) - -func TestSketch(t *testing.T) { - defer func() { - require.NotNil(t, recover()) - }() - - s := newCmSketch(5) - require.Equal(t, uint64(7), s.mask) - newCmSketch(0) -} - -func TestSketchIncrement(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(5) - s.Increment(9) - for i := 0; i < cmDepth; i++ { - if s.rows[i].string() != s.rows[0].string() { - break - } - require.False(t, i == cmDepth-1, "identical rows, bad seeding") - } -} - -func TestSketchEstimate(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(1) - require.Equal(t, int64(2), s.Estimate(1)) - require.Equal(t, int64(0), s.Estimate(0)) -} - -func TestSketchReset(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(1) - s.Increment(1) - s.Increment(1) - s.Reset() - require.Equal(t, int64(2), s.Estimate(1)) -} - -func TestSketchClear(t *testing.T) { - s := newCmSketch(16) - for i := 0; i < 16; i++ { - s.Increment(uint64(i)) - } - s.Clear() - for i := 0; i < 16; i++ { - require.Equal(t, int64(0), s.Estimate(uint64(i))) - } -} - -func TestNext2Power(t *testing.T) { - sz := 12 << 30 - szf := float64(sz) * 0.01 - val := int64(szf) - log.Infof("szf = %.2f val = %d\n", szf, val) - pow := next2Power(val) - log.Infof("pow = %d. mult 4 = %d\n", pow, pow*4) -} - -func BenchmarkSketchIncrement(b *testing.B) { - s := newCmSketch(16) - b.SetBytes(1) - for n := 0; n < b.N; n++ { - s.Increment(1) - } -} - -func BenchmarkSketchEstimate(b *testing.B) { - s := newCmSketch(16) - s.Increment(1) - b.SetBytes(1) - for n := 0; n < b.N; n++ { - s.Estimate(1) - } -} diff --git a/go/cache/ristretto/store.go b/go/cache/ristretto/store.go deleted file mode 100644 index 0e455e7052f..00000000000 --- a/go/cache/ristretto/store.go +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "sync" -) - -// TODO: Do we need this to be a separate struct from Item? -type storeItem struct { - key uint64 - conflict uint64 - value any -} - -// store is the interface fulfilled by all hash map implementations in this -// file. Some hash map implementations are better suited for certain data -// distributions than others, so this allows us to abstract that out for use -// in Ristretto. -// -// Every store is safe for concurrent usage. -type store interface { - // Get returns the value associated with the key parameter. - Get(uint64, uint64) (any, bool) - // Set adds the key-value pair to the Map or updates the value if it's - // already present. The key-value pair is passed as a pointer to an - // item object. - Set(*Item) - // Del deletes the key-value pair from the Map. - Del(uint64, uint64) (uint64, any) - // Update attempts to update the key with a new value and returns true if - // successful. 
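- // Unlike Set, Update never inserts a missing key; the bool lets callers distinguish an in-place update from a fresh insert.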
- Update(*Item) (any, bool) - // Clear clears all contents of the store. - Clear(onEvict itemCallback) - // ForEach yields all the values in the store - ForEach(forEach func(any) bool) - // Len returns the number of entries in the store - Len() int -} - -// newStore returns the default store implementation. -func newStore() store { - return newShardedMap() } - -const numShards uint64 = 256 - -type shardedMap struct { - shards []*lockedMap -} - -func newShardedMap() *shardedMap { - sm := &shardedMap{ - shards: make([]*lockedMap, int(numShards)), - } - for i := range sm.shards { - sm.shards[i] = newLockedMap() - } - return sm -} - -func (sm *shardedMap) Get(key, conflict uint64) (any, bool) { - return sm.shards[key%numShards].get(key, conflict) -} - -func (sm *shardedMap) Set(i *Item) { - if i == nil { - // If item is nil make this Set a no-op. - return - } - - sm.shards[i.Key%numShards].Set(i) -} - -func (sm *shardedMap) Del(key, conflict uint64) (uint64, any) { - return sm.shards[key%numShards].Del(key, conflict) -} - -func (sm *shardedMap) Update(newItem *Item) (any, bool) { - return sm.shards[newItem.Key%numShards].Update(newItem) -} - -func (sm *shardedMap) ForEach(forEach func(any) bool) { - for _, shard := range sm.shards { - if !shard.foreach(forEach) { - break - } - } -} - -func (sm *shardedMap) Len() int { - l := 0 - for _, shard := range sm.shards { - l += shard.Len() - } - return l -} - -func (sm *shardedMap) Clear(onEvict itemCallback) { - for i := uint64(0); i < numShards; i++ { - sm.shards[i].Clear(onEvict) - } -} - -type lockedMap struct { - sync.RWMutex - data map[uint64]storeItem -} - -func newLockedMap() *lockedMap { - return &lockedMap{ - data: make(map[uint64]storeItem), - } -} - -func (m *lockedMap) get(key, conflict uint64) (any, bool) { - m.RLock() - item, ok := m.data[key] - m.RUnlock() - if !ok { - return nil, false - } - if conflict != 0 && (conflict != item.conflict) { - return nil, false - } - return item.value, true -} - -func (m *lockedMap) Set(i *Item) { - if i == nil { - // If the item is nil make this Set a no-op. - return - } - - m.Lock() - defer m.Unlock() - item, ok := m.data[i.Key] - - if ok { - // The item existed already. We need to check the conflict key and reject the - // update if they do not match. Only then is the stored value replaced.
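- // A Conflict of 0 means no conflict key was provided, in which case the check is skipped.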
- if i.Conflict != 0 && (i.Conflict != item.conflict) { - return - } - } - - m.data[i.Key] = storeItem{ - key: i.Key, - conflict: i.Conflict, - value: i.Value, - } -} - -func (m *lockedMap) Del(key, conflict uint64) (uint64, any) { - m.Lock() - item, ok := m.data[key] - if !ok { - m.Unlock() - return 0, nil - } - if conflict != 0 && (conflict != item.conflict) { - m.Unlock() - return 0, nil - } - - delete(m.data, key) - m.Unlock() - return item.conflict, item.value -} - -func (m *lockedMap) Update(newItem *Item) (any, bool) { - m.Lock() - item, ok := m.data[newItem.Key] - if !ok { - m.Unlock() - return nil, false - } - if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) { - m.Unlock() - return nil, false - } - - m.data[newItem.Key] = storeItem{ - key: newItem.Key, - conflict: newItem.Conflict, - value: newItem.Value, - } - - m.Unlock() - return item.value, true -} - -func (m *lockedMap) Len() int { - m.RLock() - l := len(m.data) - m.RUnlock() - return l -} - -func (m *lockedMap) Clear(onEvict itemCallback) { - m.Lock() - i := &Item{} - if onEvict != nil { - for _, si := range m.data { - i.Key = si.key - i.Conflict = si.conflict - i.Value = si.value - onEvict(i) - } - } - m.data = make(map[uint64]storeItem) - m.Unlock() -} - -func (m *lockedMap) foreach(forEach func(any) bool) bool { - m.RLock() - defer m.RUnlock() - for _, si := range m.data { - if !forEach(si.value) { - return false - } - } - return true -} diff --git a/go/cache/ristretto/store_test.go b/go/cache/ristretto/store_test.go deleted file mode 100644 index 54634736a72..00000000000 --- a/go/cache/ristretto/store_test.go +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package ristretto - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStoreSetGet(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - s.Set(&i) - val, ok := s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - i.Value = 3 - s.Set(&i) - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 3, val.(int)) - - key, conflict = defaultStringHash("2") - i = Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - s.Set(&i) - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) -} - -func TestStoreDel(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - s.Del(key, conflict) - val, ok := s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - - s.Del(2, 0) -} - -func TestStoreClear(t *testing.T) { - s := newStore() - for i := 0; i < 1000; i++ { - key, conflict := defaultStringHash(strconv.Itoa(i)) - it := Item{ - Key: key, - Conflict: conflict, - Value: i, - } - s.Set(&it) - } - s.Clear(nil) - for i := 0; i < 1000; i++ { - key, conflict := defaultStringHash(strconv.Itoa(i)) - val, ok := s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - } -} - -func TestStoreUpdate(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - i.Value = 2 - _, ok := s.Update(&i) - require.True(t, ok) - - val, ok := s.Get(key, conflict) - require.True(t, ok) - require.NotNil(t, val) - - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - i.Value = 3 - _, ok = s.Update(&i) - require.True(t, ok) - - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 3, val.(int)) - - key, conflict = defaultStringHash("2") - i = Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - _, ok = s.Update(&i) - require.False(t, ok) - val, ok = s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) -} - -func TestStoreCollision(t *testing.T) { - s := newShardedMap() - s.shards[1].Lock() - s.shards[1].data[1] = storeItem{ - key: 1, - conflict: 0, - value: 1, - } - s.shards[1].Unlock() - val, ok := s.Get(1, 1) - require.False(t, ok) - require.Nil(t, val) - - i := Item{ - Key: 1, - Conflict: 1, - Value: 2, - } - s.Set(&i) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotEqual(t, 2, val.(int)) - - _, ok = s.Update(&i) - require.False(t, ok) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotEqual(t, 2, val.(int)) - - s.Del(1, 1) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotNil(t, val) -} - -func BenchmarkStoreGet(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - s.Get(key, conflict) - } - }) -} - -func BenchmarkStoreSet(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - } - }) -} - -func BenchmarkStoreUpdate(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - 
b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - s.Update(&Item{ - Key: key, - Conflict: conflict, - Value: 2, - }) - } - }) -} diff --git a/go/cache/theine/LICENSE b/go/cache/theine/LICENSE new file mode 100644 index 00000000000..0161260b7b6 --- /dev/null +++ b/go/cache/theine/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Yiling-J + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/go/cache/theine/bf/bf.go b/go/cache/theine/bf/bf.go new file mode 100644 index 00000000000..f68e34d81e3 --- /dev/null +++ b/go/cache/theine/bf/bf.go @@ -0,0 +1,116 @@ +package bf + +import ( + "math" +) + +// Bloomfilter is a small bloom-filter-based cache admission policy (a "doorkeeper") +type Bloomfilter struct { + Filter bitvector // our filter bit vector + M uint32 // size of bit vector in bits + K uint32 // distinct hash functions needed + FalsePositiveRate float64 + Capacity int +} + +func New(falsePositiveRate float64) *Bloomfilter { + d := &Bloomfilter{FalsePositiveRate: falsePositiveRate} + d.EnsureCapacity(320) + return d +} + +// NewWithSize creates a new Bloomfilter with the given size in bytes +func NewWithSize(size uint32) *Bloomfilter { + d := &Bloomfilter{} + bits := size * 8 + m := nextPowerOfTwo(uint32(bits)) + d.M = m + d.Filter = newbv(m) + return d +} + +func (d *Bloomfilter) EnsureCapacity(capacity int) { + if capacity <= d.Capacity { + return + } + capacity = int(nextPowerOfTwo(uint32(capacity))) + bits := float64(capacity) * -math.Log(d.FalsePositiveRate) / (math.Log(2.0) * math.Log(2.0)) // in bits + m := nextPowerOfTwo(uint32(bits)) + + if m < 1024 { + m = 1024 + } + + k := uint32(0.7 * float64(m) / float64(capacity)) + if k < 2 { + k = 2 + } + d.Capacity = capacity + d.M = m + d.Filter = newbv(m) + d.K = k +} + +func (d *Bloomfilter) Exist(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + var o uint = 1 + for i := uint32(0); i < d.K; i++ { + o &= d.Filter.get((h1 + (i * h2)) & (d.M - 1)) + } + return o == 1 +} + +// Insert inserts the hash h into the bloom filter. Returns true if the value +// was already considered to be in the bloom filter.
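+// Like Exist, Insert probes K bits chosen by double hashing: probe i touches bit (h1 + i*h2) & (M - 1), where h1 and h2 are the low and high 32 bits of h.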
+func (d *Bloomfilter) Insert(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + var o uint = 1 + for i := uint32(0); i < d.K; i++ { + o &= d.Filter.getset((h1 + (i * h2)) & (d.M - 1)) + } + return o == 1 +} + +// Reset clears the bloom filter +func (d *Bloomfilter) Reset() { + for i := range d.Filter { + d.Filter[i] = 0 + } +} + +// Internal routines for the bit vector +type bitvector []uint64 + +func newbv(size uint32) bitvector { + return make([]uint64, uint(size+63)/64) +} + +func (b bitvector) get(bit uint32) uint { + shift := bit % 64 + idx := bit / 64 + bb := b[idx] + m := uint64(1) << shift + return uint((bb & m) >> shift) +} + +// getset sets bit 'bit' in the bitvector b and returns its previous value +func (b bitvector) getset(bit uint32) uint { + shift := bit % 64 + idx := bit / 64 + bb := b[idx] + m := uint64(1) << shift + b[idx] |= m + return uint((bb & m) >> shift) +} + +// nextPowerOfTwo returns the smallest power of two >= i +func nextPowerOfTwo(i uint32) uint32 { + n := i - 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n++ + return n +} diff --git a/go/cache/theine/bf/bf_test.go b/go/cache/theine/bf/bf_test.go new file mode 100644 index 00000000000..f0e505766e7 --- /dev/null +++ b/go/cache/theine/bf/bf_test.go @@ -0,0 +1,24 @@ +package bf + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBloom(t *testing.T) { + bf := NewWithSize(5) + bf.FalsePositiveRate = 0.1 + bf.EnsureCapacity(5) + bf.EnsureCapacity(500) + bf.EnsureCapacity(200) + + exist := bf.Insert(123) + require.False(t, exist) + + exist = bf.Exist(123) + require.True(t, exist) + + exist = bf.Exist(456) + require.False(t, exist) +} diff --git a/go/cache/theine/entry.go b/go/cache/theine/entry.go new file mode 100644 index 00000000000..48e3bd5a09a --- /dev/null +++ b/go/cache/theine/entry.go @@ -0,0 +1,93 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package theine + +import "sync/atomic" + +const ( + NEW int8 = iota + REMOVE + UPDATE +) + +type ReadBufItem[K cachekey, V any] struct { + entry *Entry[K, V] + hash uint64 +} +type WriteBufItem[K cachekey, V any] struct { + entry *Entry[K, V] + costChange int64 + code int8 +} + +type MetaData[K cachekey, V any] struct { + prev *Entry[K, V] + next *Entry[K, V] +} + +type Entry[K cachekey, V any] struct { + key K + value V + meta MetaData[K, V] + cost atomic.Int64 + frequency atomic.Int32 + epoch atomic.Uint32 + removed bool + deque bool + root bool + list uint8 // used in slru, probation or protected +} + +func NewEntry[K cachekey, V any](key K, value V, cost int64) *Entry[K, V] { + entry := &Entry[K, V]{ + key: key, + value: value, + } + entry.cost.Store(cost) + return entry +} + +func (e *Entry[K, V]) Next() *Entry[K, V] { + if p := e.meta.next; !p.root { + return e.meta.next + } + return nil +} + +func (e *Entry[K, V]) Prev() *Entry[K, V] { + if p := e.meta.prev; !p.root { + return e.meta.prev + } + return nil +} + +func (e *Entry[K, V]) prev() *Entry[K, V] { + return e.meta.prev +} + +func (e *Entry[K, V]) next() *Entry[K, V] { + return e.meta.next +} + +func (e *Entry[K, V]) setPrev(entry *Entry[K, V]) { + e.meta.prev = entry +} + +func (e *Entry[K, V]) setNext(entry *Entry[K, V]) { + e.meta.next = entry +} diff --git a/go/cache/theine/list.go b/go/cache/theine/list.go new file mode 100644 index 00000000000..19854190cba --- /dev/null +++ b/go/cache/theine/list.go @@ -0,0 +1,205 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "strings" +) + +const ( + LIST_PROBATION uint8 = 1 + LIST_PROTECTED uint8 = 2 +) + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[K cachekey, V any] struct { + root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length (sum of costs) excluding (this) sentinel element + count int // count of entries in list + capacity uint + bounded bool + listType uint8 // LIST_PROBATION or LIST_PROTECTED +} + +// NewList returns an initialized list. +func NewList[K cachekey, V any](size uint, listType uint8) *List[K, V] { + l := &List[K, V]{listType: listType, capacity: size, root: Entry[K, V]{}} + l.root.root = true + l.root.setNext(&l.root) + l.root.setPrev(&l.root) + l.len = 0 + l.capacity = size + if size > 0 { + l.bounded = true + } + return l +} + +func (l *List[K, V]) Reset() { + l.root.setNext(&l.root) + l.root.setPrev(&l.root) + l.len = 0 +} + +// Len returns the length of list l, i.e. the summed cost of its elements. +// The complexity is O(1).
+func (l *List[K, V]) Len() int { return l.len } + +func (l *List[K, V]) display() string { + var s []string + for e := l.Front(); e != nil; e = e.Next() { + s = append(s, fmt.Sprintf("%v", e.key)) + } + return strings.Join(s, "/") +} + +func (l *List[K, V]) displayReverse() string { + var s []string + for e := l.Back(); e != nil; e = e.Prev() { + s = append(s, fmt.Sprintf("%v", e.key)) + } + return strings.Join(s, "/") +} + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[K, V]) Front() *Entry[K, V] { + e := l.root.next() + if e != &l.root { + return e + } + return nil +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[K, V]) Back() *Entry[K, V] { + e := l.root.prev() + if e != &l.root { + return e + } + return nil +} + +// insert inserts e after at, increments l.len, and returns the evicted entry if the capacity was exceeded +func (l *List[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] { + var evicted *Entry[K, V] + if l.bounded && l.len >= int(l.capacity) { + evicted = l.PopTail() + } + e.list = l.listType + e.setPrev(at) + e.setNext(at.next()) + e.prev().setNext(e) + e.next().setPrev(e) + if l.bounded { + l.len += int(e.cost.Load()) + l.count += 1 + } + return evicted +} + +// PushFront pushes the entry to the list head +func (l *List[K, V]) PushFront(e *Entry[K, V]) *Entry[K, V] { + return l.insert(e, &l.root) +} + +// PushBack pushes the entry to the back of the list +func (l *List[K, V]) PushBack(e *Entry[K, V]) *Entry[K, V] { + return l.insert(e, l.root.prev()) +} + +// remove removes e from its list, decrements l.len +func (l *List[K, V]) remove(e *Entry[K, V]) { + e.prev().setNext(e.next()) + e.next().setPrev(e.prev()) + e.setNext(nil) + e.setPrev(nil) + e.list = 0 + if l.bounded { + l.len -= int(e.cost.Load()) + l.count -= 1 + } +} + +// move moves e so that it sits next to at. +func (l *List[K, V]) move(e, at *Entry[K, V]) { + if e == at { + return + } + e.prev().setNext(e.next()) + e.next().setPrev(e.prev()) + + e.setPrev(at) + e.setNext(at.next()) + e.prev().setNext(e) + e.next().setPrev(e) +} + +// Remove removes e from l if e is an element of list l. +// The element must not be nil. +func (l *List[K, V]) Remove(e *Entry[K, V]) { + l.remove(e) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[K, V]) MoveToFront(e *Entry[K, V]) { + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[K, V]) MoveToBack(e *Entry[K, V]) { + l.move(e, l.root.prev()) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[K, V]) MoveBefore(e, mark *Entry[K, V]) { + l.move(e, mark.prev()) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil.
+func (l *List[K, V]) MoveAfter(e, mark *Entry[K, V]) { + l.move(e, mark) +} + +func (l *List[K, V]) PopTail() *Entry[K, V] { + entry := l.root.prev() + if entry != nil && entry != &l.root { + l.remove(entry) + return entry + } + return nil +} + +func (l *List[K, V]) Contains(entry *Entry[K, V]) bool { + for e := l.Front(); e != nil; e = e.Next() { + if e == entry { + return true + } + } + return false +} diff --git a/go/cache/theine/list_test.go b/go/cache/theine/list_test.go new file mode 100644 index 00000000000..aad68f5c142 --- /dev/null +++ b/go/cache/theine/list_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestList(t *testing.T) { + l := NewList[StringKey, string](5, LIST_PROBATION) + require.Equal(t, uint(5), l.capacity) + require.Equal(t, LIST_PROBATION, l.listType) + for i := 0; i < 5; i++ { + evicted := l.PushFront(NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1)) + require.Nil(t, evicted) + } + require.Equal(t, 5, l.len) + require.Equal(t, "4/3/2/1/0", l.display()) + require.Equal(t, "0/1/2/3/4", l.displayReverse()) + + evicted := l.PushFront(NewEntry(StringKey("5"), "", 1)) + require.Equal(t, StringKey("0"), evicted.key) + require.Equal(t, 5, l.len) + require.Equal(t, "5/4/3/2/1", l.display()) + require.Equal(t, "1/2/3/4/5", l.displayReverse()) + + for i := 0; i < 5; i++ { + entry := l.PopTail() + require.Equal(t, StringKey(fmt.Sprintf("%d", i+1)), entry.key) + } + entry := l.PopTail() + require.Nil(t, entry) + + var entries []*Entry[StringKey, string] + for i := 0; i < 5; i++ { + new := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + evicted := l.PushFront(new) + entries = append(entries, new) + require.Nil(t, evicted) + } + require.Equal(t, "4/3/2/1/0", l.display()) + l.MoveToBack(entries[2]) + require.Equal(t, "4/3/1/0/2", l.display()) + require.Equal(t, "2/0/1/3/4", l.displayReverse()) + l.MoveBefore(entries[1], entries[3]) + require.Equal(t, "4/1/3/0/2", l.display()) + require.Equal(t, "2/0/3/1/4", l.displayReverse()) + l.MoveAfter(entries[2], entries[4]) + require.Equal(t, "4/2/1/3/0", l.display()) + require.Equal(t, "0/3/1/2/4", l.displayReverse()) + l.Remove(entries[1]) + require.Equal(t, "4/2/3/0", l.display()) + require.Equal(t, "0/3/2/4", l.displayReverse()) + +} + +func TestListCountCost(t *testing.T) { + l := NewList[StringKey, string](100, LIST_PROBATION) + require.Equal(t, uint(100), l.capacity) + require.Equal(t, LIST_PROBATION, l.listType) + for i := 0; i < 5; i++ { + evicted := l.PushFront(NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 20)) + require.Nil(t, evicted) + } + require.Equal(t, 100, l.len) + require.Equal(t, 5, l.count) + for i := 0; i < 3; i++ { + entry := l.PopTail() + require.NotNil(t, entry) + } + require.Equal(t, 40, l.len) + require.Equal(t, 2, l.count) +} diff --git a/go/cache/theine/mpsc.go b/go/cache/theine/mpsc.go new file mode 100644 index 
00000000000..c00e2ce5a26 --- /dev/null +++ b/go/cache/theine/mpsc.go @@ -0,0 +1,86 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +// This implementation is based on http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue + +import ( + "sync" + "sync/atomic" +) + +type node[V any] struct { + next atomic.Pointer[node[V]] + val V +} + +type Queue[V any] struct { + head, tail atomic.Pointer[node[V]] + nodePool sync.Pool +} + +func NewQueue[V any]() *Queue[V] { + q := &Queue[V]{nodePool: sync.Pool{New: func() any { + return new(node[V]) + }}} + stub := &node[V]{} + q.head.Store(stub) + q.tail.Store(stub) + return q +} + +// Push adds x to the back of the queue. + +// +// Push can be safely called from multiple goroutines +func (q *Queue[V]) Push(x V) { + n := q.nodePool.Get().(*node[V]) + n.val = x + + // current producer acquires head node + prev := q.head.Swap(n) + + // release node to consumer + prev.next.Store(n) +} + +// Pop removes and returns the item at the front of the queue; the boolean is false if the queue is empty +// +// Pop must be called from a single, consumer goroutine +func (q *Queue[V]) Pop() (V, bool) { + tail := q.tail.Load() + next := tail.next.Load() + if next != nil { + var null V + q.tail.Store(next) + v := next.val + next.val = null + tail.next.Store(nil) + q.nodePool.Put(tail) + return v, true + } + var null V + return null, false +} + +// Empty returns true if the queue is empty +// +// Empty must be called from a single, consumer goroutine +func (q *Queue[V]) Empty() bool { + tail := q.tail.Load() + return tail.next.Load() == nil +} diff --git a/go/cache/theine/mpsc_test.go b/go/cache/theine/mpsc_test.go new file mode 100644 index 00000000000..eca50efed3e --- /dev/null +++ b/go/cache/theine/mpsc_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package theine + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQueue_PushPop(t *testing.T) { + q := NewQueue[int]() + + q.Push(1) + q.Push(2) + v, ok := q.Pop() + assert.True(t, ok) + assert.Equal(t, 1, v) + v, ok = q.Pop() + assert.True(t, ok) + assert.Equal(t, 2, v) + _, ok = q.Pop() + assert.False(t, ok) +} + +func TestQueue_Empty(t *testing.T) { + q := NewQueue[int]() + assert.True(t, q.Empty()) + q.Push(1) + assert.False(t, q.Empty()) +} diff --git a/go/cache/theine/singleflight.go b/go/cache/theine/singleflight.go new file mode 100644 index 00000000000..fde56670514 --- /dev/null +++ b/go/cache/theine/singleflight.go @@ -0,0 +1,196 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J +Copyright 2013 The Go Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package theine + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" + "sync/atomic" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[V any] struct { + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val V + err error + + wg sync.WaitGroup + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups atomic.Int32 +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[K comparable, V any] struct { + m map[K]*call[V] // lazily initialized + mu sync.Mutex // protects m + callPool sync.Pool +} + +func NewGroup[K comparable, V any]() *Group[K, V] { + return &Group[K, V]{ + callPool: sync.Pool{New: func() any { + return new(call[V]) + }}, + } +} + +// Result holds the results of Do, so they can be passed +// on a channel. 
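+// (Nothing in this file constructs a Result; it appears to be retained from the upstream singleflight code this implementation derives from.)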
+type Result struct { + Val interface{} + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group[K, V]) Do(key K, fn func() (V, error)) (v V, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[K]*call[V]) + } + if c, ok := g.m[key]; ok { + _ = c.dups.Add(1) + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + // assign value/err before put back to pool to avoid race + v = c.val + err = c.err + n := c.dups.Add(-1) + if n == 0 { + g.callPool.Put(c) + } + return v, err, true + } + c := g.callPool.Get().(*call[V]) + defer func() { + n := c.dups.Add(-1) + if n == 0 { + g.callPool.Put(c) + } + }() + _ = c.dups.Add(1) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, true +} + +// doCall handles the single call for a key. +func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + panic(e) + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} diff --git a/go/cache/theine/singleflight_test.go b/go/cache/theine/singleflight_test.go new file mode 100644 index 00000000000..60b28e69b4e --- /dev/null +++ b/go/cache/theine/singleflight_test.go @@ -0,0 +1,211 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J +Copyright 2013 The Go Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
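+//
+// These tests are adapted from the Go singleflight tests to cover the
+// generic, pool-backed Group implementation in this package.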
+ +package theine + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + g := NewGroup[string, string]() + v, err, _ := g.Do("key", func() (string, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + g := NewGroup[string, string]() + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (string, error) { + return "", someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != "" { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + g := NewGroup[string, string]() + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (string, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s := v; s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test singleflight behaves correctly after Do panic. 
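+// Do re-panics in the executing goroutine and in every duplicate waiter,
+// so all n goroutines below are expected to observe the panic.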
+// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + g := NewGroup[string, string]() + fn := func() (string, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + _, _, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + g := NewGroup[string, int]() + fn := func() (int, error) { + runtime.Goexit() + return 0, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func BenchmarkDo(b *testing.B) { + keys := randKeys(b, 10240, 10) + benchDo(b, NewGroup[string, int](), keys) + +} + +func benchDo(b *testing.B, g *Group[string, int], keys []string) { + keyc := len(keys) + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + _, _, _ = g.Do(keys[i%keyc], func() (int, error) { + return 0, nil + }) + } + }) +} + +func randKeys(b *testing.B, count, length uint) []string { + keys := make([]string, 0, count) + key := make([]byte, length) + + for i := uint(0); i < count; i++ { + if _, err := io.ReadFull(rand.Reader, key); err != nil { + b.Fatalf("Failed to generate random key %d of %d of length %d: %s", i+1, count, length, err) + } + keys = append(keys, string(key)) + } + return keys +} diff --git a/go/cache/theine/sketch.go b/go/cache/theine/sketch.go new file mode 100644 index 00000000000..7d241d94fc8 --- /dev/null +++ b/go/cache/theine/sketch.go @@ -0,0 +1,137 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package theine
+
+type CountMinSketch struct {
+	Table      []uint64
+	Additions  uint
+	SampleSize uint
+	BlockMask  uint
+}
+
+func NewCountMinSketch() *CountMinSketch {
+	s := &CountMinSketch{}
+	s.EnsureCapacity(16)
+	return s
+}
+
+// indexOf returns the table index and the counter's bit offset within that slot together
+func (s *CountMinSketch) indexOf(h uint64, block uint64, offset uint8) (uint, uint) {
+	counterHash := h + uint64(1+offset)*(h>>32)
+	// the maximum index is block + 7; eight adjacent uint64 words (64 bytes) fit in a single cache line
+	index := block + counterHash&1 + uint64(offset<<1)
+	return uint(index), uint((counterHash & 0xF) << 2)
+}
+
+func (s *CountMinSketch) inc(index uint, offset uint) bool {
+	mask := uint64(0xF << offset)
+	if s.Table[index]&mask != mask {
+		s.Table[index] += 1 << offset
+		return true
+	}
+	return false
+}
+
+func (s *CountMinSketch) Add(h uint64) bool {
+	hn := spread(h)
+	block := (hn & uint64(s.BlockMask)) << 3
+	hc := rehash(h)
+	index0, offset0 := s.indexOf(hc, block, 0)
+	index1, offset1 := s.indexOf(hc, block, 1)
+	index2, offset2 := s.indexOf(hc, block, 2)
+	index3, offset3 := s.indexOf(hc, block, 3)
+
+	added := s.inc(index0, offset0)
+	added = s.inc(index1, offset1) || added
+	added = s.inc(index2, offset2) || added
+	added = s.inc(index3, offset3) || added
+
+	if added {
+		s.Additions += 1
+		if s.Additions == s.SampleSize {
+			s.reset()
+			return true
+		}
+	}
+	return false
+}
+
+func (s *CountMinSketch) reset() {
+	for i := range s.Table {
+		s.Table[i] = s.Table[i] >> 1
+	}
+	s.Additions = s.Additions >> 1
+}
+
+func (s *CountMinSketch) count(h uint64, block uint64, offset uint8) uint {
+	index, off := s.indexOf(h, block, offset)
+	count := (s.Table[index] >> off) & 0xF
+	return uint(count)
+}
+
+func (s *CountMinSketch) Estimate(h uint64) uint {
+	hn := spread(h)
+	block := (hn & uint64(s.BlockMask)) << 3
+	hc := rehash(h)
+	m := min(s.count(hc, block, 0), 100)
+	m = min(s.count(hc, block, 1), m)
+	m = min(s.count(hc, block, 2), m)
+	m = min(s.count(hc, block, 3), m)
+	return m
+}
+
+func next2Power(x uint) uint {
+	x--
+	x |= x >> 1
+	x |= x >> 2
+	x |= x >> 4
+	x |= x >> 8
+	x |= x >> 16
+	x |= x >> 32
+	x++
+	return x
+}
+
+func (s *CountMinSketch) EnsureCapacity(size uint) {
+	if len(s.Table) >= int(size) {
+		return
+	}
+	if size < 16 {
+		size = 16
+	}
+	newSize := next2Power(size)
+	s.Table = make([]uint64, newSize)
+	s.SampleSize = 10 * size
+	s.BlockMask = uint((len(s.Table) >> 3) - 1)
+	s.Additions = 0
+}
+
+func spread(h uint64) uint64 {
+	h ^= h >> 17
+	h *= 0xed5ad4bb
+	h ^= h >> 11
+	h *= 0xac4c1b51
+	h ^= h >> 15
+	return h
+}
+
+func rehash(h uint64) uint64 {
+	h *= 0x31848bab
+	h ^= h >> 14
+	return h
+}
diff --git a/go/cache/theine/sketch_test.go b/go/cache/theine/sketch_test.go new file mode 100644 index 00000000000..3437f0cac3c --- /dev/null +++ b/go/cache/theine/sketch_test.go @@ -0,0 +1,54 @@
+package theine
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/cespare/xxhash/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestEnsureCapacity(t *testing.T) {
+	sketch := NewCountMinSketch()
+	sketch.EnsureCapacity(1)
+	require.Equal(t, 16, len(sketch.Table))
+}
+
+func TestSketch(t *testing.T) {
+	sketch := NewCountMinSketch()
+	sketch.EnsureCapacity(100)
+	require.Equal(t, 128, len(sketch.Table))
+	require.Equal(t, uint(1000), sketch.SampleSize)
+	// override SampleSize so the test does not trigger a reset
+	sketch.SampleSize = 5120
+
+	failed := 0
+	for i := 0; i < 500; i++ {
+		key := fmt.Sprintf("key:%d", i)
+		keyh := xxhash.Sum64String(key)
+		sketch.Add(keyh)
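+		// keyh is added five times in total and keyh2 below only three times,
+		// so the final estimates are expected to satisfy es1 >= es2
+		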
sketch.Add(keyh) + sketch.Add(keyh) + sketch.Add(keyh) + sketch.Add(keyh) + key = fmt.Sprintf("key:%d:b", i) + keyh2 := xxhash.Sum64String(key) + sketch.Add(keyh2) + sketch.Add(keyh2) + sketch.Add(keyh2) + + es1 := sketch.Estimate(keyh) + es2 := sketch.Estimate(keyh2) + if es2 > es1 { + failed++ + } + require.True(t, es1 >= 5) + require.True(t, es2 >= 3) + + } + require.True(t, float32(failed)/4000 < 0.1) + require.True(t, sketch.Additions > 3500) + a := sketch.Additions + sketch.reset() + require.Equal(t, a>>1, sketch.Additions) +} diff --git a/go/cache/theine/slru.go b/go/cache/theine/slru.go new file mode 100644 index 00000000000..e3bcb2532b1 --- /dev/null +++ b/go/cache/theine/slru.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +type Slru[K cachekey, V any] struct { + probation *List[K, V] + protected *List[K, V] + maxsize uint +} + +func NewSlru[K cachekey, V any](size uint) *Slru[K, V] { + return &Slru[K, V]{ + maxsize: size, + probation: NewList[K, V](size, LIST_PROBATION), + protected: NewList[K, V](uint(float32(size)*0.8), LIST_PROTECTED), + } +} + +func (s *Slru[K, V]) insert(entry *Entry[K, V]) *Entry[K, V] { + var evicted *Entry[K, V] + if s.probation.Len()+s.protected.Len() >= int(s.maxsize) { + evicted = s.probation.PopTail() + } + s.probation.PushFront(entry) + return evicted +} + +func (s *Slru[K, V]) victim() *Entry[K, V] { + if s.probation.Len()+s.protected.Len() < int(s.maxsize) { + return nil + } + return s.probation.Back() +} + +func (s *Slru[K, V]) access(entry *Entry[K, V]) { + switch entry.list { + case LIST_PROBATION: + s.probation.remove(entry) + evicted := s.protected.PushFront(entry) + if evicted != nil { + s.probation.PushFront(evicted) + } + case LIST_PROTECTED: + s.protected.MoveToFront(entry) + } +} + +func (s *Slru[K, V]) remove(entry *Entry[K, V]) { + switch entry.list { + case LIST_PROBATION: + s.probation.remove(entry) + case LIST_PROTECTED: + s.protected.remove(entry) + } +} + +func (s *Slru[K, V]) updateCost(entry *Entry[K, V], delta int64) { + switch entry.list { + case LIST_PROBATION: + s.probation.len += int(delta) + case LIST_PROTECTED: + s.protected.len += int(delta) + } +} diff --git a/go/cache/theine/store.go b/go/cache/theine/store.go new file mode 100644 index 00000000000..3d86e549867 --- /dev/null +++ b/go/cache/theine/store.go @@ -0,0 +1,615 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package theine + +import ( + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/gammazero/deque" + + "vitess.io/vitess/go/cache/theine/bf" + "vitess.io/vitess/go/hack" +) + +const ( + MaxReadBuffSize = 64 + MinWriteBuffSize = 4 + MaxWriteBuffSize = 1024 +) + +type RemoveReason uint8 + +const ( + REMOVED RemoveReason = iota + EVICTED + EXPIRED +) + +type Shard[K cachekey, V any] struct { + hashmap map[K]*Entry[K, V] + dookeeper *bf.Bloomfilter + deque *deque.Deque[*Entry[K, V]] + group *Group[K, V] + qsize uint + qlen int + counter uint + mu sync.RWMutex +} + +func NewShard[K cachekey, V any](size uint, qsize uint, doorkeeper bool) *Shard[K, V] { + s := &Shard[K, V]{ + hashmap: make(map[K]*Entry[K, V]), + qsize: qsize, + deque: deque.New[*Entry[K, V]](), + group: NewGroup[K, V](), + } + if doorkeeper { + s.dookeeper = bf.New(0.01) + } + return s +} + +func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) { + s.hashmap[key] = entry + if s.dookeeper != nil { + ds := 20 * len(s.hashmap) + if ds > s.dookeeper.Capacity { + s.dookeeper.EnsureCapacity(ds) + } + } +} + +func (s *Shard[K, V]) get(key K) (entry *Entry[K, V], ok bool) { + entry, ok = s.hashmap[key] + return +} + +func (s *Shard[K, V]) delete(entry *Entry[K, V]) bool { + var deleted bool + exist, ok := s.hashmap[entry.key] + if ok && exist == entry { + delete(s.hashmap, exist.key) + deleted = true + } + return deleted +} + +func (s *Shard[K, V]) len() int { + return len(s.hashmap) +} + +type Metrics struct { + evicted atomic.Int64 + hits atomic.Int64 + misses atomic.Int64 +} + +func (m *Metrics) Evicted() int64 { + return m.evicted.Load() +} + +func (m *Metrics) Hits() int64 { + return m.hits.Load() +} + +func (m *Metrics) Misses() int64 { + return m.misses.Load() +} + +func (m *Metrics) Accesses() int64 { + return m.Hits() + m.Misses() +} + +type cachekey interface { + comparable + Hash() uint64 + Hash2() (uint64, uint64) +} + +type HashKey256 [32]byte + +func (h HashKey256) Hash() uint64 { + return uint64(h[0]) | uint64(h[1])<<8 | uint64(h[2])<<16 | uint64(h[3])<<24 | + uint64(h[4])<<32 | uint64(h[5])<<40 | uint64(h[6])<<48 | uint64(h[7])<<56 +} + +func (h HashKey256) Hash2() (uint64, uint64) { + h0 := h.Hash() + h1 := uint64(h[8]) | uint64(h[9])<<8 | uint64(h[10])<<16 | uint64(h[11])<<24 | + uint64(h[12])<<32 | uint64(h[13])<<40 | uint64(h[14])<<48 | uint64(h[15])<<56 + return h0, h1 +} + +type StringKey string + +func (h StringKey) Hash() uint64 { + return hack.RuntimeStrhash(string(h), 13850135847636357301) +} + +func (h StringKey) Hash2() (uint64, uint64) { + h0 := h.Hash() + h1 := ((h0 >> 16) ^ h0) * 0x45d9f3b + h1 = ((h1 >> 16) ^ h1) * 0x45d9f3b + h1 = (h1 >> 16) ^ h1 + return h0, h1 +} + +type cacheval interface { + CachedSize(alloc bool) int64 +} + +type Store[K cachekey, V cacheval] struct { + Metrics Metrics + OnRemoval func(K, V, RemoveReason) + + entryPool sync.Pool + writebuf chan WriteBufItem[K, V] + policy *TinyLfu[K, V] + readbuf *Queue[ReadBufItem[K, V]] + shards []*Shard[K, V] + cap uint + shardCount uint + writebufsize int64 + tailUpdate bool + doorkeeper bool + + mlock sync.Mutex + readCounter atomic.Uint32 + open atomic.Bool +} + +func NewStore[K cachekey, V cacheval](maxsize int64, doorkeeper bool) *Store[K, V] { + writeBufSize := maxsize / 100 + if writeBufSize < MinWriteBuffSize { + writeBufSize = MinWriteBuffSize + } + if writeBufSize > MaxWriteBuffSize { + writeBufSize = MaxWriteBuffSize + } + shardCount := 1 + for shardCount < runtime.GOMAXPROCS(0)*2 { + shardCount *= 2 + } + if 
shardCount < 16 { + shardCount = 16 + } + if shardCount > 128 { + shardCount = 128 + } + dequeSize := int(maxsize) / 100 / shardCount + shardSize := int(maxsize) / shardCount + if shardSize < 50 { + shardSize = 50 + } + policySize := int(maxsize) - (dequeSize * shardCount) + + s := &Store[K, V]{ + cap: uint(maxsize), + policy: NewTinyLfu[K, V](uint(policySize)), + readbuf: NewQueue[ReadBufItem[K, V]](), + writebuf: make(chan WriteBufItem[K, V], writeBufSize), + entryPool: sync.Pool{New: func() any { return &Entry[K, V]{} }}, + shardCount: uint(shardCount), + doorkeeper: doorkeeper, + writebufsize: writeBufSize, + } + s.shards = make([]*Shard[K, V], 0, s.shardCount) + for i := 0; i < int(s.shardCount); i++ { + s.shards = append(s.shards, NewShard[K, V](uint(shardSize), uint(dequeSize), doorkeeper)) + } + + go s.maintenance() + s.open.Store(true) + return s +} + +func (s *Store[K, V]) EnsureOpen() { + if s.open.Swap(true) { + return + } + s.writebuf = make(chan WriteBufItem[K, V], s.writebufsize) + go s.maintenance() +} + +func (s *Store[K, V]) getFromShard(key K, hash uint64, shard *Shard[K, V], epoch uint32) (V, bool) { + new := s.readCounter.Add(1) + shard.mu.RLock() + entry, ok := shard.get(key) + var value V + if ok { + if entry.epoch.Load() < epoch { + s.Metrics.misses.Add(1) + ok = false + } else { + s.Metrics.hits.Add(1) + s.policy.hit.Add(1) + value = entry.value + } + } else { + s.Metrics.misses.Add(1) + } + shard.mu.RUnlock() + switch { + case new < MaxReadBuffSize: + var send ReadBufItem[K, V] + send.hash = hash + if ok { + send.entry = entry + } + s.readbuf.Push(send) + case new == MaxReadBuffSize: + var send ReadBufItem[K, V] + send.hash = hash + if ok { + send.entry = entry + } + s.readbuf.Push(send) + s.drainRead() + } + return value, ok +} + +func (s *Store[K, V]) Get(key K, epoch uint32) (V, bool) { + h, index := s.index(key) + shard := s.shards[index] + return s.getFromShard(key, h, shard, epoch) +} + +func (s *Store[K, V]) GetOrLoad(key K, epoch uint32, load func() (V, error)) (V, bool, error) { + h, index := s.index(key) + shard := s.shards[index] + v, ok := s.getFromShard(key, h, shard, epoch) + if !ok { + loaded, err, _ := shard.group.Do(key, func() (V, error) { + loaded, err := load() + if err == nil { + s.Set(key, loaded, 0, epoch) + } + return loaded, err + }) + return loaded, false, err + } + return v, true, nil +} + +func (s *Store[K, V]) setEntry(shard *Shard[K, V], cost int64, epoch uint32, entry *Entry[K, V]) { + shard.set(entry.key, entry) + // cost larger than deque size, send to policy directly + if cost > int64(shard.qsize) { + shard.mu.Unlock() + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: NEW} + return + } + entry.deque = true + shard.deque.PushFront(entry) + shard.qlen += int(cost) + s.processDeque(shard, epoch) +} + +func (s *Store[K, V]) setInternal(key K, value V, cost int64, epoch uint32) (*Shard[K, V], *Entry[K, V], bool) { + h, index := s.index(key) + shard := s.shards[index] + shard.mu.Lock() + exist, ok := shard.get(key) + if ok { + var costChange int64 + exist.value = value + oldCost := exist.cost.Swap(cost) + if oldCost != cost { + costChange = cost - oldCost + if exist.deque { + shard.qlen += int(costChange) + } + } + shard.mu.Unlock() + exist.epoch.Store(epoch) + if costChange != 0 { + s.writebuf <- WriteBufItem[K, V]{ + entry: exist, code: UPDATE, costChange: costChange, + } + } + return shard, exist, true + } + if s.doorkeeper { + if shard.counter > uint(shard.dookeeper.Capacity) { + shard.dookeeper.Reset() + shard.counter = 0 + } 
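+		// admit only keys that have been seen before: Insert reports a hit when
+		// the hash was already in the doorkeeper bloom filter, so first-time keys
+		// are counted and skipped and one-hit wonders never allocate a cache entry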
+ hit := shard.dookeeper.Insert(h) + if !hit { + shard.counter += 1 + shard.mu.Unlock() + return shard, nil, false + } + } + entry := s.entryPool.Get().(*Entry[K, V]) + entry.frequency.Store(-1) + entry.key = key + entry.value = value + entry.cost.Store(cost) + entry.epoch.Store(epoch) + s.setEntry(shard, cost, epoch, entry) + return shard, entry, true + +} + +func (s *Store[K, V]) Set(key K, value V, cost int64, epoch uint32) bool { + if cost == 0 { + cost = value.CachedSize(true) + } + if cost > int64(s.cap) { + return false + } + _, _, ok := s.setInternal(key, value, cost, epoch) + return ok +} + +type dequeKV[K cachekey, V cacheval] struct { + k K + v V +} + +func (s *Store[K, V]) processDeque(shard *Shard[K, V], epoch uint32) { + if shard.qlen <= int(shard.qsize) { + shard.mu.Unlock() + return + } + var evictedkv []dequeKV[K, V] + var expiredkv []dequeKV[K, V] + + // send to slru + send := make([]*Entry[K, V], 0, 2) + for shard.qlen > int(shard.qsize) { + evicted := shard.deque.PopBack() + evicted.deque = false + shard.qlen -= int(evicted.cost.Load()) + + if evicted.epoch.Load() < epoch { + deleted := shard.delete(evicted) + if deleted { + if s.OnRemoval != nil { + evictedkv = append(evictedkv, dequeKV[K, V]{evicted.key, evicted.value}) + } + s.postDelete(evicted) + s.Metrics.evicted.Add(1) + } + } else { + count := evicted.frequency.Load() + threshold := s.policy.threshold.Load() + if count == -1 { + send = append(send, evicted) + } else { + if int32(count) >= threshold { + send = append(send, evicted) + } else { + deleted := shard.delete(evicted) + // double check because entry maybe removed already by Delete API + if deleted { + if s.OnRemoval != nil { + evictedkv = append(evictedkv, dequeKV[K, V]{evicted.key, evicted.value}) + } + s.postDelete(evicted) + s.Metrics.evicted.Add(1) + } + } + } + } + } + + shard.mu.Unlock() + for _, entry := range send { + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: NEW} + } + if s.OnRemoval != nil { + for _, kv := range evictedkv { + s.OnRemoval(kv.k, kv.v, EVICTED) + } + for _, kv := range expiredkv { + s.OnRemoval(kv.k, kv.v, EXPIRED) + } + } +} + +func (s *Store[K, V]) Delete(key K) { + _, index := s.index(key) + shard := s.shards[index] + shard.mu.Lock() + entry, ok := shard.get(key) + if ok { + shard.delete(entry) + } + shard.mu.Unlock() + if ok { + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: REMOVE} + } +} + +func (s *Store[K, V]) Len() int { + total := 0 + for _, s := range s.shards { + s.mu.RLock() + total += s.len() + s.mu.RUnlock() + } + return total +} + +func (s *Store[K, V]) UsedCapacity() int { + total := 0 + for _, s := range s.shards { + s.mu.RLock() + total += s.qlen + s.mu.RUnlock() + } + return total +} + +func (s *Store[K, V]) MaxCapacity() int { + return int(s.cap) +} + +// spread hash before get index +func (s *Store[K, V]) index(key K) (uint64, int) { + h0, h1 := key.Hash2() + return h0, int(h1 & uint64(s.shardCount-1)) +} + +func (s *Store[K, V]) postDelete(entry *Entry[K, V]) { + var zero V + entry.value = zero + s.entryPool.Put(entry) +} + +// remove entry from cache/policy/timingwheel and add back to pool +func (s *Store[K, V]) removeEntry(entry *Entry[K, V], reason RemoveReason) { + if prev := entry.meta.prev; prev != nil { + s.policy.Remove(entry) + } + switch reason { + case EVICTED, EXPIRED: + _, index := s.index(entry.key) + shard := s.shards[index] + shard.mu.Lock() + deleted := shard.delete(entry) + shard.mu.Unlock() + if deleted { + if s.OnRemoval != nil { + s.OnRemoval(entry.key, entry.value, 
reason) + } + s.postDelete(entry) + s.Metrics.evicted.Add(1) + } + case REMOVED: + // already removed from shard map + if s.OnRemoval != nil { + s.OnRemoval(entry.key, entry.value, reason) + } + } +} + +func (s *Store[K, V]) drainRead() { + s.policy.total.Add(MaxReadBuffSize) + s.mlock.Lock() + for { + v, ok := s.readbuf.Pop() + if !ok { + break + } + s.policy.Access(v) + } + s.mlock.Unlock() + s.readCounter.Store(0) +} + +func (s *Store[K, V]) maintenanceItem(item WriteBufItem[K, V]) { + s.mlock.Lock() + defer s.mlock.Unlock() + + entry := item.entry + if entry == nil { + return + } + + // lock free because store API never read/modify entry metadata + switch item.code { + case NEW: + if entry.removed { + return + } + evicted := s.policy.Set(entry) + if evicted != nil { + s.removeEntry(evicted, EVICTED) + s.tailUpdate = true + } + removed := s.policy.EvictEntries() + for _, e := range removed { + s.tailUpdate = true + s.removeEntry(e, EVICTED) + } + case REMOVE: + entry.removed = true + s.removeEntry(entry, REMOVED) + s.policy.threshold.Store(-1) + case UPDATE: + if item.costChange != 0 { + s.policy.UpdateCost(entry, item.costChange) + removed := s.policy.EvictEntries() + for _, e := range removed { + s.tailUpdate = true + s.removeEntry(e, EVICTED) + } + } + } + item.entry = nil + if s.tailUpdate { + s.policy.UpdateThreshold() + s.tailUpdate = false + } +} + +func (s *Store[K, V]) maintenance() { + tick := time.NewTicker(500 * time.Millisecond) + defer tick.Stop() + + for { + select { + case <-tick.C: + s.mlock.Lock() + s.policy.UpdateThreshold() + s.mlock.Unlock() + + case item, ok := <-s.writebuf: + if !ok { + return + } + s.maintenanceItem(item) + } + } +} + +func (s *Store[K, V]) Range(epoch uint32, f func(key K, value V) bool) { + for _, shard := range s.shards { + shard.mu.RLock() + for _, entry := range shard.hashmap { + if entry.epoch.Load() < epoch { + continue + } + if !f(entry.key, entry.value) { + shard.mu.RUnlock() + return + } + } + shard.mu.RUnlock() + } +} + +func (s *Store[K, V]) Close() { + if !s.open.Swap(false) { + panic("theine.Store: double close") + } + + for _, s := range s.shards { + s.mu.Lock() + clear(s.hashmap) + s.mu.Unlock() + } + close(s.writebuf) +} diff --git a/go/cache/theine/store_test.go b/go/cache/theine/store_test.go new file mode 100644 index 00000000000..880acf30193 --- /dev/null +++ b/go/cache/theine/store_test.go @@ -0,0 +1,82 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package theine + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type cachedint int + +func (ci cachedint) CachedSize(bool) int64 { + return 1 +} + +type keyint int + +func (k keyint) Hash() uint64 { + return uint64(k) +} + +func (k keyint) Hash2() (uint64, uint64) { + return uint64(k), uint64(k) * 333 +} + +func TestProcessDeque(t *testing.T) { + store := NewStore[keyint, cachedint](20000, false) + + evicted := map[keyint]cachedint{} + store.OnRemoval = func(key keyint, value cachedint, reason RemoveReason) { + if reason == EVICTED { + evicted[key] = value + } + } + _, index := store.index(123) + shard := store.shards[index] + shard.qsize = 10 + + for i := keyint(0); i < 5; i++ { + entry := &Entry[keyint, cachedint]{key: i} + entry.cost.Store(1) + store.shards[index].deque.PushFront(entry) + store.shards[index].qlen += 1 + store.shards[index].hashmap[i] = entry + } + + // move 0,1,2 entries to slru + store.Set(123, 123, 8, 0) + require.Equal(t, store.shards[index].deque.Len(), 3) + var keys []keyint + for store.shards[index].deque.Len() != 0 { + e := store.shards[index].deque.PopBack() + keys = append(keys, e.key) + } + require.Equal(t, []keyint{3, 4, 123}, keys) +} + +func TestDoorKeeperDynamicSize(t *testing.T) { + store := NewStore[keyint, cachedint](200000, true) + shard := store.shards[0] + require.True(t, shard.dookeeper.Capacity == 512) + for i := keyint(0); i < 5000; i++ { + shard.set(i, &Entry[keyint, cachedint]{}) + } + require.True(t, shard.dookeeper.Capacity > 100000) +} diff --git a/go/cache/theine/tlfu.go b/go/cache/theine/tlfu.go new file mode 100644 index 00000000000..f7a4f8dec51 --- /dev/null +++ b/go/cache/theine/tlfu.go @@ -0,0 +1,197 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package theine + +import ( + "sync/atomic" +) + +type TinyLfu[K cachekey, V any] struct { + slru *Slru[K, V] + sketch *CountMinSketch + size uint + counter uint + total atomic.Uint32 + hit atomic.Uint32 + hr float32 + threshold atomic.Int32 + lruFactor uint8 + step int8 +} + +func NewTinyLfu[K cachekey, V any](size uint) *TinyLfu[K, V] { + tlfu := &TinyLfu[K, V]{ + size: size, + slru: NewSlru[K, V](size), + sketch: NewCountMinSketch(), + step: 1, + } + // default threshold to -1 so all entries are admitted until cache is full + tlfu.threshold.Store(-1) + return tlfu +} + +func (t *TinyLfu[K, V]) climb() { + total := t.total.Load() + hit := t.hit.Load() + current := float32(hit) / float32(total) + delta := current - t.hr + var diff int8 + if delta > 0.0 { + if t.step < 0 { + t.step -= 1 + } else { + t.step += 1 + } + if t.step < -13 { + t.step = -13 + } else if t.step > 13 { + t.step = 13 + } + newFactor := int8(t.lruFactor) + t.step + if newFactor < 0 { + newFactor = 0 + } else if newFactor > 16 { + newFactor = 16 + } + diff = newFactor - int8(t.lruFactor) + t.lruFactor = uint8(newFactor) + } else if delta < 0.0 { + // reset + if t.step > 0 { + t.step = -1 + } else { + t.step = 1 + } + newFactor := int8(t.lruFactor) + t.step + if newFactor < 0 { + newFactor = 0 + } else if newFactor > 16 { + newFactor = 16 + } + diff = newFactor - int8(t.lruFactor) + t.lruFactor = uint8(newFactor) + } + t.threshold.Add(-int32(diff)) + t.hr = current + t.hit.Store(0) + t.total.Store(0) +} + +func (t *TinyLfu[K, V]) Set(entry *Entry[K, V]) *Entry[K, V] { + t.counter++ + if t.counter > 10*t.size { + t.climb() + t.counter = 0 + } + if entry.meta.prev == nil { + if victim := t.slru.victim(); victim != nil { + freq := int(entry.frequency.Load()) + if freq == -1 { + freq = int(t.sketch.Estimate(entry.key.Hash())) + } + evictedCount := uint(freq) + uint(t.lruFactor) + victimCount := t.sketch.Estimate(victim.key.Hash()) + if evictedCount <= uint(victimCount) { + return entry + } + } else { + count := t.slru.probation.count + t.slru.protected.count + t.sketch.EnsureCapacity(uint(count + count/100)) + } + evicted := t.slru.insert(entry) + return evicted + } + + return nil +} + +func (t *TinyLfu[K, V]) Access(item ReadBufItem[K, V]) { + t.counter++ + if t.counter > 10*t.size { + t.climb() + t.counter = 0 + } + if entry := item.entry; entry != nil { + reset := t.sketch.Add(item.hash) + if reset { + t.threshold.Store(t.threshold.Load() / 2) + } + if entry.meta.prev != nil { + var tail bool + if entry == t.slru.victim() { + tail = true + } + t.slru.access(entry) + if tail { + t.UpdateThreshold() + } + } else { + entry.frequency.Store(int32(t.sketch.Estimate(item.hash))) + } + } else { + reset := t.sketch.Add(item.hash) + if reset { + t.threshold.Store(t.threshold.Load() / 2) + } + } +} + +func (t *TinyLfu[K, V]) Remove(entry *Entry[K, V]) { + t.slru.remove(entry) +} + +func (t *TinyLfu[K, V]) UpdateCost(entry *Entry[K, V], delta int64) { + t.slru.updateCost(entry, delta) +} + +func (t *TinyLfu[K, V]) EvictEntries() []*Entry[K, V] { + removed := []*Entry[K, V]{} + + for t.slru.probation.Len()+t.slru.protected.Len() > int(t.slru.maxsize) { + entry := t.slru.probation.PopTail() + if entry == nil { + break + } + removed = append(removed, entry) + } + for t.slru.probation.Len()+t.slru.protected.Len() > int(t.slru.maxsize) { + entry := t.slru.protected.PopTail() + if entry == nil { + break + } + removed = append(removed, entry) + } + return removed +} + +func (t *TinyLfu[K, V]) UpdateThreshold() { + if 
t.slru.probation.Len()+t.slru.protected.Len() < int(t.slru.maxsize) { + t.threshold.Store(-1) + } else { + tail := t.slru.victim() + if tail != nil { + t.threshold.Store( + int32(t.sketch.Estimate(tail.key.Hash()) - uint(t.lruFactor)), + ) + } else { + // cache is not full + t.threshold.Store(-1) + } + } +} diff --git a/go/cache/theine/tlfu_test.go b/go/cache/theine/tlfu_test.go new file mode 100644 index 00000000000..ac6ddaabdb6 --- /dev/null +++ b/go/cache/theine/tlfu_test.go @@ -0,0 +1,156 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTlfu(t *testing.T) { + tlfu := NewTinyLfu[StringKey, string](1000) + require.Equal(t, uint(1000), tlfu.slru.probation.capacity) + require.Equal(t, uint(800), tlfu.slru.protected.capacity) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + var entries []*Entry[StringKey, string] + for i := 0; i < 200; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + evicted := tlfu.Set(e) + entries = append(entries, e) + require.Nil(t, evicted) + } + + require.Equal(t, 200, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // probation -> protected + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 199, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 199, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + for i := 200; i < 1000; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + entries = append(entries, e) + evicted := tlfu.Set(e) + require.Nil(t, evicted) + } + // access protected + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 999, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + evicted := tlfu.Set(NewEntry(StringKey("0a"), "", 1)) + require.Equal(t, StringKey("0a"), evicted.key) + require.Equal(t, 999, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + victim := tlfu.slru.victim() + require.Equal(t, StringKey("0"), victim.key) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + evicted = tlfu.Set(NewEntry(StringKey("1a"), "", 1)) + require.Equal(t, StringKey("1a"), evicted.key) + require.Equal(t, 998, tlfu.slru.probation.len) + + var entries2 []*Entry[StringKey, string] + for i := 0; i < 1000; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d*", i)), "", 1) + tlfu.Set(e) + entries2 = append(entries2, e) + } + require.Equal(t, 998, tlfu.slru.probation.len) + require.Equal(t, 2, tlfu.slru.protected.len) + + for _, i := range 
[]int{997, 998, 999} { + tlfu.Remove(entries2[i]) + tlfu.slru.probation.display() + tlfu.slru.probation.displayReverse() + tlfu.slru.protected.display() + tlfu.slru.protected.displayReverse() + } + +} + +func TestEvictEntries(t *testing.T) { + tlfu := NewTinyLfu[StringKey, string](500) + require.Equal(t, uint(500), tlfu.slru.probation.capacity) + require.Equal(t, uint(400), tlfu.slru.protected.capacity) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + for i := 0; i < 500; i++ { + tlfu.Set(NewEntry(StringKey(fmt.Sprintf("%d:1", i)), "", 1)) + } + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + new := NewEntry(StringKey("l:10"), "", 10) + new.frequency.Store(10) + tlfu.Set(new) + require.Equal(t, 509, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + // 2. probation length is 509, so remove 9 entries from probation + removed := tlfu.EvictEntries() + for _, rm := range removed { + require.True(t, strings.HasSuffix(string(rm.key), ":1")) + } + require.Equal(t, 9, len(removed)) + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // put l:450 to probation, this will remove 1 entry, probation len is 949 now + // remove 449 entries from probation + new = NewEntry(StringKey("l:450"), "", 450) + new.frequency.Store(10) + tlfu.Set(new) + removed = tlfu.EvictEntries() + require.Equal(t, 449, len(removed)) + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // put l:460 to probation, this will remove 1 entry, probation len is 959 now + // remove all entries except the new l:460 one + new = NewEntry(StringKey("l:460"), "", 460) + new.frequency.Store(10) + tlfu.Set(new) + removed = tlfu.EvictEntries() + require.Equal(t, 41, len(removed)) + require.Equal(t, 460, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // access + tlfu.Access(ReadBufItem[StringKey, string]{entry: new}) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 460, tlfu.slru.protected.len) + new.cost.Store(600) + tlfu.UpdateCost(new, 140) + removed = tlfu.EvictEntries() + require.Equal(t, 1, len(removed)) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + +} diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index 3f2752be084..1d2f665177b 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -35,9 +35,7 @@ Usage of vtgate: --enable_set_var This will enable the use of MySQL's SET_VAR query hint for certain system variables instead of using reserved connections (default true) --enable_system_settings This will enable the system settings to be changed per session at the database connection level (default true) --foreign_key_mode string This is to provide how to handle foreign key constraint in create/alter table. Valid values are: allow, disallow (default "allow") - --gate_query_cache_lfu gate server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries (default true) --gate_query_cache_memory int gate server query cache size in bytes, maximum amount of memory to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. 
(default 33554432) - --gate_query_cache_size int gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a cache. This config controls the expected amount of unique entries in the cache. (default 5000) --gateway_initial_tablet_timeout duration At startup, the tabletGateway will wait up to this duration to get at least one tablet per keyspace/shard/tablet type (default 30s) --grpc-use-effective-groups If set, and SSL is not used, will set the immediate caller's security groups from the effective caller id's groups. --grpc-use-static-authentication-callerid If set, will set the immediate caller id to the username authenticated by the static auth plugin. diff --git a/go/flags/endtoend/vttablet.txt b/go/flags/endtoend/vttablet.txt index 9b42ebb5644..b787ffbfe57 100644 --- a/go/flags/endtoend/vttablet.txt +++ b/go/flags/endtoend/vttablet.txt @@ -225,9 +225,7 @@ Usage of vttablet: --queryserver-config-passthrough-dmls query server pass through all dml statements without rewriting --queryserver-config-pool-conn-max-lifetime duration query server connection max lifetime (in seconds), vttablet manages various mysql connection pools. This config means if a connection has lived at least this long, it connection will be removed from pool upon the next time it is returned to the pool. (default 0s) --queryserver-config-pool-size int query server read pool size, connection pool is used by regular queries (non streaming, not in a transaction) (default 16) - --queryserver-config-query-cache-lfu query server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries (default true) --queryserver-config-query-cache-memory int query server query cache size in bytes, maximum amount of memory to be used for caching. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. (default 33554432) - --queryserver-config-query-cache-size int query server query cache size, maximum number of queries to be cached. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. (default 5000) --queryserver-config-query-pool-timeout duration query server query pool timeout (in seconds), it is how long vttablet waits for a connection from the query pool. If set to 0 (default) then the overall query timeout is used instead. (default 0s) --queryserver-config-query-pool-waiter-cap int query server query pool waiter limit, this is the maximum number of queries that can be queued waiting to get a connection (default 5000) --queryserver-config-query-timeout duration query server query timeout (in seconds), this is the query timeout in vttablet side. If a query takes more than this timeout, it will be killed. 
(default 30s) diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index 5db2078d666..e111e4325f3 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -61,6 +61,7 @@ import ( const ( DefaultCell = "zone1" DefaultStartPort = 6700 + DefaultVttestEnv = "VTTEST=endtoend" ) var ( diff --git a/go/test/endtoend/cluster/mysqlctl_process.go b/go/test/endtoend/cluster/mysqlctl_process.go index c3ab377320a..b5e7cfb5a32 100644 --- a/go/test/endtoend/cluster/mysqlctl_process.go +++ b/go/test/endtoend/cluster/mysqlctl_process.go @@ -148,6 +148,8 @@ ssl_key={{.ServerKey}} } } tmpProcess.Args = append(tmpProcess.Args, "start") + tmpProcess.Env = append(tmpProcess.Env, os.Environ()...) + tmpProcess.Env = append(tmpProcess.Env, DefaultVttestEnv) log.Infof("Starting mysqlctl with command: %v", tmpProcess.Args) return tmpProcess, tmpProcess.Start() } diff --git a/go/test/endtoend/cluster/mysqlctld_process.go b/go/test/endtoend/cluster/mysqlctld_process.go index 4f51e8ca888..9a0f36e3918 100644 --- a/go/test/endtoend/cluster/mysqlctld_process.go +++ b/go/test/endtoend/cluster/mysqlctld_process.go @@ -95,6 +95,7 @@ func (mysqlctld *MysqlctldProcess) Start() error { tempProcess.Stderr = errFile tempProcess.Env = append(tempProcess.Env, os.Environ()...) + tempProcess.Env = append(tempProcess.Env, DefaultVttestEnv) tempProcess.Stdout = os.Stdout tempProcess.Stderr = os.Stderr diff --git a/go/test/endtoend/cluster/topo_process.go b/go/test/endtoend/cluster/topo_process.go index 8fd4bd1c74c..6a9ba1ec438 100644 --- a/go/test/endtoend/cluster/topo_process.go +++ b/go/test/endtoend/cluster/topo_process.go @@ -96,6 +96,7 @@ func (topo *TopoProcess) SetupEtcd() (err error) { topo.proc.Stderr = errFile topo.proc.Env = append(topo.proc.Env, os.Environ()...) + topo.proc.Env = append(topo.proc.Env, DefaultVttestEnv) log.Infof("Starting etcd with command: %v", strings.Join(topo.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtbackup_process.go b/go/test/endtoend/cluster/vtbackup_process.go index be75026bf0d..ba508e8d593 100644 --- a/go/test/endtoend/cluster/vtbackup_process.go +++ b/go/test/endtoend/cluster/vtbackup_process.go @@ -84,6 +84,7 @@ func (vtbackup *VtbackupProcess) Setup() (err error) { vtbackup.proc.Stdout = os.Stdout vtbackup.proc.Env = append(vtbackup.proc.Env, os.Environ()...) + vtbackup.proc.Env = append(vtbackup.proc.Env, DefaultVttestEnv) log.Infof("Running vtbackup with args: %v", strings.Join(vtbackup.proc.Args, " ")) err = vtbackup.proc.Run() diff --git a/go/test/endtoend/cluster/vtctld_process.go b/go/test/endtoend/cluster/vtctld_process.go index 5e85f172ce1..b8f6cf240fc 100644 --- a/go/test/endtoend/cluster/vtctld_process.go +++ b/go/test/endtoend/cluster/vtctld_process.go @@ -74,6 +74,7 @@ func (vtctld *VtctldProcess) Setup(cell string, extraArgs ...string) (err error) vtctld.proc.Stderr = errFile vtctld.proc.Env = append(vtctld.proc.Env, os.Environ()...) + vtctld.proc.Env = append(vtctld.proc.Env, DefaultVttestEnv) log.Infof("Starting vtctld with command: %v", strings.Join(vtctld.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtgate_process.go b/go/test/endtoend/cluster/vtgate_process.go index 3bf29c8af9b..48aecab7c1e 100644 --- a/go/test/endtoend/cluster/vtgate_process.go +++ b/go/test/endtoend/cluster/vtgate_process.go @@ -127,6 +127,7 @@ func (vtgate *VtgateProcess) Setup() (err error) { vtgate.proc.Stderr = errFile vtgate.proc.Env = append(vtgate.proc.Env, os.Environ()...) 
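+	// tag the process with VTTEST=endtoend (DefaultVttestEnv), matching the other cluster processes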
+ vtgate.proc.Env = append(vtgate.proc.Env, DefaultVttestEnv) log.Infof("Running vtgate with command: %v", strings.Join(vtgate.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtorc_process.go b/go/test/endtoend/cluster/vtorc_process.go index 34c4f3295ab..28c6355a175 100644 --- a/go/test/endtoend/cluster/vtorc_process.go +++ b/go/test/endtoend/cluster/vtorc_process.go @@ -133,6 +133,7 @@ func (orc *VTOrcProcess) Setup() (err error) { orc.proc.Stderr = errFile orc.proc.Env = append(orc.proc.Env, os.Environ()...) + orc.proc.Env = append(orc.proc.Env, DefaultVttestEnv) log.Infof("Running vtorc with command: %v", strings.Join(orc.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vttablet_process.go b/go/test/endtoend/cluster/vttablet_process.go index eb44a96f346..9f5e513c7a0 100644 --- a/go/test/endtoend/cluster/vttablet_process.go +++ b/go/test/endtoend/cluster/vttablet_process.go @@ -132,6 +132,7 @@ func (vttablet *VttabletProcess) Setup() (err error) { vttablet.proc.Stderr = errFile vttablet.proc.Env = append(vttablet.proc.Env, os.Environ()...) + vttablet.proc.Env = append(vttablet.proc.Env, DefaultVttestEnv) log.Infof("Running vttablet with command: %v", strings.Join(vttablet.proc.Args, " ")) diff --git a/go/test/endtoend/vreplication/helper_test.go b/go/test/endtoend/vreplication/helper_test.go index ad8fd5a96c7..6d032122430 100644 --- a/go/test/endtoend/vreplication/helper_test.go +++ b/go/test/endtoend/vreplication/helper_test.go @@ -20,19 +20,18 @@ import ( "context" "crypto/rand" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" "os/exec" "regexp" "sort" - "strconv" "strings" "sync/atomic" "testing" "time" - "github.com/PuerkitoBio/goquery" "github.com/buger/jsonparser" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -231,12 +230,28 @@ func waitForRowCountInTablet(t *testing.T, vttablet *cluster.VttabletProcess, da } } -func validateThatQueryExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) bool { - count := getQueryCount(tablet.QueryzURL, matchQuery) +func executeOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) (int, []byte, int, []byte) { + queryStatsURL := fmt.Sprintf("http://%s:%d/debug/query_stats", tablet.TabletHostname, tablet.Port) + + count0, body0 := getQueryCount(t, queryStatsURL, matchQuery) + qr := execVtgateQuery(t, conn, ksName, query) require.NotNil(t, qr) - newCount := getQueryCount(tablet.QueryzURL, matchQuery) - return newCount == count+1 + + count1, body1 := getQueryCount(t, queryStatsURL, matchQuery) + return count0, body0, count1, body1 +} + +func assertQueryExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) { + t.Helper() + count0, body0, count1, body1 := executeOnTablet(t, conn, tablet, ksName, query, matchQuery) + assert.Equalf(t, count0+1, count1, "query %q did not execute in target;\ntried to match %q\nbefore:\n%s\n\nafter:\n%s\n\n", query, matchQuery, body0, body1) +} + +func assertQueryDoesNotExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) { + t.Helper() + count0, body0, count1, body1 := executeOnTablet(t, conn, tablet, ksName, query, matchQuery) + assert.Equalf(t, count0, count1, "query %q executed in target;\ntried to match %q\nbefore:\n%s\n\nafter:\n%s\n\n", query, 
matchQuery, body0, body1) } // waitForWorkflowState waits for all of the given workflow's @@ -351,77 +366,36 @@ func confirmTablesHaveSecondaryKeys(t *testing.T, tablets []*cluster.VttabletPro } } -func getHTTPBody(url string) string { +func getHTTPBody(t *testing.T, url string) []byte { resp, err := http.Get(url) - if err != nil { - log.Infof("http Get returns %+v", err) - return "" - } - if resp.StatusCode != 200 { - log.Infof("http Get returns status %d", resp.StatusCode) - return "" - } - respByte, _ := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() - body := string(respByte) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) return body } -func getQueryCount(url string, query string) int { - var headings, row []string - var rows [][]string - body := getHTTPBody(url) - doc, err := goquery.NewDocumentFromReader(strings.NewReader(body)) - if err != nil { - log.Infof("goquery parsing returns %+v\n", err) - return 0 +func getQueryCount(t *testing.T, url string, query string) (int, []byte) { + body := getHTTPBody(t, url) + + var queryStats []struct { + Query string + QueryCount uint64 } - var queryIndex, countIndex, count int - queryIndex = -1 - countIndex = -1 + err := json.Unmarshal(body, &queryStats) + require.NoError(t, err) - doc.Find("table").Each(func(index int, tablehtml *goquery.Selection) { - tablehtml.Find("tr").Each(func(indextr int, rowhtml *goquery.Selection) { - rowhtml.Find("th").Each(func(indexth int, tableheading *goquery.Selection) { - heading := tableheading.Text() - if heading == "Query" { - queryIndex = indexth - } - if heading == "Count" { - countIndex = indexth - } - headings = append(headings, heading) - }) - rowhtml.Find("td").Each(func(indexth int, tablecell *goquery.Selection) { - row = append(row, tablecell.Text()) - }) - rows = append(rows, row) - row = nil - }) - }) - if queryIndex == -1 || countIndex == -1 { - log.Infof("Queryz response is incorrect") - return 0 - } - for _, row := range rows { - if len(row) != len(headings) { - continue - } - filterChars := []string{"_", "`"} - //Queries seem to include non-printable characters at times and hence equality fails unless these are removed - re := regexp.MustCompile("[[:^ascii:]]") - foundQuery := re.ReplaceAllLiteralString(row[queryIndex], "") - cleanQuery := re.ReplaceAllLiteralString(query, "") - for _, filterChar := range filterChars { - foundQuery = strings.ReplaceAll(foundQuery, filterChar, "") - cleanQuery = strings.ReplaceAll(cleanQuery, filterChar, "") - } - if foundQuery == cleanQuery || strings.Contains(foundQuery, cleanQuery) { - count, _ = strconv.Atoi(row[countIndex]) + for _, q := range queryStats { + if strings.Contains(q.Query, query) { + return int(q.QueryCount), body } } - return count + + return 0, body } func validateDryRunResults(t *testing.T, output string, want []string) { @@ -530,8 +504,8 @@ func getDebugVar(t *testing.T, port int, varPath []string) (string, error) { var err error url := fmt.Sprintf("http://localhost:%d/debug/vars", port) log.Infof("url: %s, varPath: %s", url, strings.Join(varPath, ":")) - body := getHTTPBody(url) - val, _, _, err = jsonparser.Get([]byte(body), varPath...) + body := getHTTPBody(t, url) + val, _, _, err = jsonparser.Get(body, varPath...) 
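+	// jsonparser.Get operates on raw bytes; getHTTPBody now returns []byte, so no conversion is needed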
require.NoError(t, err) return string(val), nil } diff --git a/go/test/endtoend/vreplication/resharding_workflows_v2_test.go b/go/test/endtoend/vreplication/resharding_workflows_v2_test.go index ec20e1d92ca..993da344905 100644 --- a/go/test/endtoend/vreplication/resharding_workflows_v2_test.go +++ b/go/test/endtoend/vreplication/resharding_workflows_v2_test.go @@ -218,7 +218,7 @@ func validateReadsRoute(t *testing.T, tabletTypes string, tablet *cluster.Vttabl for _, tt := range []string{"replica", "rdonly"} { destination := fmt.Sprintf("%s:%s@%s", tablet.Keyspace, tablet.Shard, tt) if strings.Contains(tabletTypes, tt) { - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, tablet, destination, readQuery, readQuery)) + assertQueryExecutesOnTablet(t, vtgateConn, tablet, destination, readQuery, readQuery) } } } @@ -233,17 +233,17 @@ func validateReadsRouteToTarget(t *testing.T, tabletTypes string) { func validateWritesRouteToSource(t *testing.T) { insertQuery := "insert into customer(name, cid) values('tempCustomer2', 200)" - matchInsertQuery := "insert into customer(name, cid) values" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, sourceTab, "customer", insertQuery, matchInsertQuery)) + matchInsertQuery := "insert into customer(`name`, cid) values" + assertQueryExecutesOnTablet(t, vtgateConn, sourceTab, "customer", insertQuery, matchInsertQuery) execVtgateQuery(t, vtgateConn, "customer", "delete from customer where cid > 100") } func validateWritesRouteToTarget(t *testing.T) { insertQuery := "insert into customer(name, cid) values('tempCustomer3', 101)" - matchInsertQuery := "insert into customer(name, cid) values" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, targetTab2, "customer", insertQuery, matchInsertQuery)) + matchInsertQuery := "insert into customer(`name`, cid) values" + assertQueryExecutesOnTablet(t, vtgateConn, targetTab2, "customer", insertQuery, matchInsertQuery) insertQuery = "insert into customer(name, cid) values('tempCustomer3', 102)" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, targetTab1, "customer", insertQuery, matchInsertQuery)) + assertQueryExecutesOnTablet(t, vtgateConn, targetTab1, "customer", insertQuery, matchInsertQuery) execVtgateQuery(t, vtgateConn, "customer", "delete from customer where cid > 100") } diff --git a/go/test/endtoend/vreplication/vdiff_helper_test.go b/go/test/endtoend/vreplication/vdiff_helper_test.go index 9f5ce973c40..982ea04c957 100644 --- a/go/test/endtoend/vreplication/vdiff_helper_test.go +++ b/go/test/endtoend/vreplication/vdiff_helper_test.go @@ -70,7 +70,7 @@ func doVDiff1(t *testing.T, ksWorkflow, cells string) { diffReports := make(map[string]*wrangler.DiffReport) t.Logf("vdiff1 output: %s", output) err = json.Unmarshal([]byte(output), &diffReports) - require.NoError(t, err) + require.NoErrorf(t, err, "full output: %s", output) if len(diffReports) < 1 { t.Fatal("VDiff did not return a valid json response " + output + "\n") } diff --git a/go/test/endtoend/vreplication/vreplication_test.go b/go/test/endtoend/vreplication/vreplication_test.go index 12f8ee7dad6..38c7aa8faa3 100644 --- a/go/test/endtoend/vreplication/vreplication_test.go +++ b/go/test/endtoend/vreplication/vreplication_test.go @@ -787,10 +787,10 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl }) query := "select cid from customer" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", query, query)) + 
+	assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", query, query)
 	insertQuery1 := "insert into customer(cid, name) values(1001, 'tempCustomer1')"
 	matchInsertQuery1 := "insert into customer(cid, `name`) values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */)"
-	require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1))
+	assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1)
 
 	// FIXME for some reason, these inserts fail on mac; need to investigate, possibly some
 	// vreplication bug because of case insensitivity of table names on mac?
@@ -809,7 +809,7 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl
 	vdiff1(t, ksWorkflow, "")
 	switchReadsDryRun(t, workflowType, allCellNames, ksWorkflow, dryRunResultsReadCustomerShard)
 	switchReads(t, workflowType, allCellNames, ksWorkflow, false)
-	require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", query, query))
+	assertQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", query, query)
 
 	var commit func(t *testing.T)
 	if withOpenTx {
@@ -840,14 +840,14 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl
 	ksShards := []string{"product/0", "customer/-80", "customer/80-"}
 	printShardPositions(vc, ksShards)
 	insertQuery2 := "insert into customer(name, cid) values('tempCustomer2', 100)"
-	matchInsertQuery2 := "insert into customer(`name`, cid) values (:vtg1 /* VARCHAR */, :_cid0)"
-	require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2))
+	matchInsertQuery2 := "insert into customer(`name`, cid) values (:vtg1 /* VARCHAR */, :_cid_0)"
+	assertQueryDoesNotExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2)
 	insertQuery2 = "insert into customer(name, cid) values('tempCustomer3', 101)" // ID 101, hence due to reverse_bits in shard 80-
-	require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2))
+	assertQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2)
 	insertQuery2 = "insert into customer(name, cid) values('tempCustomer4', 102)" // ID 102, hence due to reverse_bits in shard -80
-	require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2))
+	assertQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2)
 	execVtgateQuery(t, vtgateConn, "customer", "update customer set meta = convert(x'7b7d' using utf8mb4) where cid = 1")
 
 	if testReverse {
@@ -862,12 +862,12 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl
 		require.Contains(t, output, "'customer.bmd5'")
 
 		insertQuery1 = "insert into customer(cid, name) values(1002, 'tempCustomer5')"
-		require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1))
+		assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1)
 		// both inserts go into 80-; this tests the edge case where a stream (-80) has no relevant new events after the previous switch
 		insertQuery1 = "insert into customer(cid, name) values(1003, 'tempCustomer6')"
-		require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery1, matchInsertQuery1))
+		assertQueryDoesNotExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery1, matchInsertQuery1)
 		insertQuery1 = "insert into customer(cid, name) values(1004, 'tempCustomer7')"
-		require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery1, matchInsertQuery1))
+		assertQueryDoesNotExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery1, matchInsertQuery1)
 
 		waitForNoWorkflowLag(t, vc, targetKs, workflow)
@@ -902,11 +902,11 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl
 		require.True(t, found)
 
 		insertQuery2 = "insert into customer(name, cid) values('tempCustomer8', 103)" // ID 103, hence due to reverse_bits in shard 80-
-		require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2))
+		assertQueryDoesNotExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2)
 		insertQuery2 = "insert into customer(name, cid) values('tempCustomer10', 104)" // ID 104, hence due to reverse_bits in shard -80
-		require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2))
+		assertQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2)
 		insertQuery2 = "insert into customer(name, cid) values('tempCustomer9', 105)" // ID 105, hence due to reverse_bits in shard 80-
-		require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2))
+		assertQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2)
 		execVtgateQuery(t, vtgateConn, "customer", "delete from customer where name like 'tempCustomer%'")
 		waitForRowCountInTablet(t, customerTab1, "customer", "customer", 1)
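Editor's note: the "ID 103, hence due to reverse_bits in shard 80-" comments above reason about where the reverse_bits vindex routes each row. A standalone sketch of that mapping, assuming the vindex's keyspace id is the bit-reversed 64-bit customer id (illustrative only, not the actual implementation in go/vt/vtgate/vindexes):

package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

// reverseBitsKeyspaceID computes the keyspace id as the bit-reversed cid.
func reverseBitsKeyspaceID(cid uint64) [8]byte {
	var ksid [8]byte
	binary.BigEndian.PutUint64(ksid[:], bits.Reverse64(cid))
	return ksid
}

func main() {
	// With the two shards -80 and 80-, only the top bit of the keyspace id
	// decides the shard, which is exactly what the test comments rely on.
	for _, cid := range []uint64{101, 102, 103, 104, 105} {
		ksid := reverseBitsKeyspaceID(cid)
		shard := "-80"
		if ksid[0] >= 0x80 {
			shard = "80-"
		}
		fmt.Printf("cid %d -> keyspace id %x -> shard %s\n", cid, ksid, shard)
	}
	// cid 101 and 103 land in 80-, cid 102 and 104 in -80, cid 105 in 80-,
	// matching the assertions in shardCustomer.
}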
diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go
index 4fa4313e76c..52e30accf03 100644
--- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go
+++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go
@@ -25,11 +25,11 @@ import (
 	"testing"
 	"time"
 
-	"vitess.io/vitess/go/test/endtoend/utils"
-
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"vitess.io/vitess/go/test/endtoend/utils"
+
 	"vitess.io/vitess/go/mysql"
 )
 
@@ -46,38 +46,24 @@ func TestNormalizeAllFields(t *testing.T) {
 	assert.Equal(t, 1, len(qr.Rows), "wrong number of table rows, expected 1 but had %d. Results: %v", len(qr.Rows), qr.Rows)
 
 	// Now need to figure out the best way to check the normalized query in the planner cache...
-	results, err := getPlanCache(fmt.Sprintf("%s:%d", vtParams.Host, clusterInstance.VtgateProcess.Port))
-	require.Nil(t, err)
-	found := false
-	for _, record := range results {
-		key := record["Key"].(string)
-		if key == normalizedInsertQuery {
-			found = true
-			break
-		}
-	}
-	assert.Truef(t, found, "correctly normalized record not found in planner cache %v", results)
+	results := getPlanCache(t, fmt.Sprintf("%s:%d", vtParams.Host, clusterInstance.VtgateProcess.Port))
+	assert.Contains(t, results, normalizedInsertQuery)
 }
 
-func getPlanCache(vtgateHostPort string) ([]map[string]any, error) {
-	var results []map[string]any
+func getPlanCache(t *testing.T, vtgateHostPort string) map[string]any {
+	var results map[string]any
 	client := http.Client{
 		Timeout: 10 * time.Second,
 	}
 	resp, err := client.Get(fmt.Sprintf("http://%s/debug/query_plans", vtgateHostPort))
-	if err != nil {
-		return results, err
-	}
+	require.NoError(t, err)
 	defer resp.Body.Close()
+
 	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return results, err
-	}
+	require.NoError(t, err)
 
 	err = json.Unmarshal(body, &results)
-	if err != nil {
-		return results, err
-	}
+	require.NoErrorf(t, err, "failed to unmarshal results. contents:\n%s\n\n", body)
 
-	return results, nil
+	return results
 }
diff --git a/go/vt/servenv/servenv.go b/go/vt/servenv/servenv.go
index 06942d4f4b9..0c1deddeb10 100644
--- a/go/vt/servenv/servenv.go
+++ b/go/vt/servenv/servenv.go
@@ -516,3 +516,10 @@ func RegisterFlagsForTopoBinaries(registerFlags func(fs *pflag.FlagSet)) {
 		OnParseFor(cmd, registerFlags)
 	}
 }
+
+// TestingEndtoend is true when this Vitess binary is being run as part of an endtoend test suite
+var TestingEndtoend = false
+
+func init() {
+	TestingEndtoend = os.Getenv("VTTEST") == "endtoend"
+}
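Editor's note: a hypothetical fragment of a test harness showing how the new servenv toggle would be used. Exporting VTTEST=endtoend before starting a Vitess binary flips servenv.TestingEndtoend, which (per the executor changes below) makes DefaultPlanCache disable its doorkeeper so caching behaves deterministically. The vtgate flags here are illustrative, not taken from the diff:

package main

import (
	"os"
	"os/exec"
)

func main() {
	// Start vtgate with the endtoend marker in its environment.
	cmd := exec.Command("vtgate", "--port", "15001")
	cmd.Env = append(os.Environ(), "VTTEST=endtoend")
	if err := cmd.Start(); err != nil {
		panic(err)
	}
}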
diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go
index a98e6c3a724..bbeb99e0e36 100644
--- a/go/vt/vtexplain/vtexplain_vtgate.go
+++ b/go/vt/vtexplain/vtexplain_vtgate.go
@@ -25,10 +25,10 @@ import (
 	"sort"
 	"strings"
 
+	"vitess.io/vitess/go/cache/theine"
 	"vitess.io/vitess/go/vt/vtgate/logstats"
 	"vitess.io/vitess/go/vt/vtgate/vindexes"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/vt/topo"
 	"vitess.io/vitess/go/vt/topo/memorytopo"
@@ -74,8 +74,9 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar
 	streamSize := 10
 	var schemaTracker vtgate.SchemaInfo // no schema tracker for these tests
 	queryLogBufferSize := 10
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	vte.vtgateExecutor = vtgate.NewExecutor(ctx, vte.explainTopo, vtexplainCell, resolver, opts.Normalize, false, streamSize, plans, schemaTracker, false, opts.PlannerVersion, streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize))
+	plans := theine.NewStore[vtgate.PlanCacheKey, *engine.Plan](4*1024*1024, false)
+	vte.vtgateExecutor = vtgate.NewExecutor(ctx, vte.explainTopo, vtexplainCell, resolver, opts.Normalize, false, streamSize, plans, schemaTracker, false, opts.PlannerVersion)
+	vte.vtgateExecutor.SetQueryLogger(streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize))
 
 	return nil
 }
@@ -207,29 +208,27 @@ func (vte *VTExplain) vtgateExecute(sql string) ([]*engine.Plan, map[string]*Tab
 	// This will ensure that the commit/rollback order is predictable.
 	vte.sortShardSession()
 
-	// use the plan cache to get the set of plans used for this query, then
-	// clear afterwards for the next run
-	planCache := vte.vtgateExecutor.Plans()
-
 	_, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", vtgate.NewSafeSession(vte.vtgateSession), sql, nil)
 	if err != nil {
 		for _, tc := range vte.explainTopo.TabletConns {
 			tc.tabletQueries = nil
 			tc.mysqlQueries = nil
 		}
-		planCache.Clear()
-
+		vte.vtgateExecutor.ClearPlans()
 		return nil, nil, vterrors.Wrapf(err, "vtexplain execute error in '%s'", sql)
 	}
 
 	var plans []*engine.Plan
-	planCache.ForEach(func(value any) bool {
-		plan := value.(*engine.Plan)
+
+	// use the plan cache to get the set of plans used for this query, then
+	// clear afterwards for the next run
+	vte.vtgateExecutor.ForEachPlan(func(plan *engine.Plan) bool {
 		plan.ExecTime = 0
 		plans = append(plans, plan)
 		return true
 	})
-	planCache.Clear()
+
+	vte.vtgateExecutor.ClearPlans()
 
 	tabletActions := make(map[string]*TabletActions)
 	for shard, tc := range vte.explainTopo.TabletConns {
diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go
index 07b1ba79f42..5b94183a950 100644
--- a/go/vt/vtgate/executor.go
+++ b/go/vt/vtgate/executor.go
@@ -17,11 +17,8 @@ limitations under the License.
 package vtgate
 
 import (
-	"bufio"
 	"bytes"
 	"context"
-	"crypto/sha256"
-	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -33,10 +30,11 @@ import (
 
 	"github.com/spf13/pflag"
 
+	"vitess.io/vitess/go/cache/theine"
 	"vitess.io/vitess/go/streamlog"
+	"vitess.io/vitess/go/vt/vthash"
 
 	"vitess.io/vitess/go/acl"
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/mysql/collations"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/stats"
@@ -105,9 +103,11 @@ type Executor struct {
 	mu           sync.Mutex
 	vschema      *vindexes.VSchema
 	streamSize   int
-	plans        cache.Cache
 	vschemaStats *VSchemaStats
+	plans        *PlanCache
+	epoch        atomic.Uint32
+
 	normalize       bool
 	warnShardedOnly bool
@@ -127,6 +127,15 @@ const pathQueryPlans = "/debug/query_plans"
 const pathScatterStats = "/debug/scatter_stats"
 const pathVSchema = "/debug/vschema"
 
+type PlanCacheKey = theine.HashKey256
+type PlanCache = theine.Store[PlanCacheKey, *engine.Plan]
+
+func DefaultPlanCache() *PlanCache {
+	// when being endtoend tested, disable the doorkeeper to ensure reproducible results
+	doorkeeper := !servenv.TestingEndtoend
+	return theine.NewStore[PlanCacheKey, *engine.Plan](queryPlanCacheMemory, doorkeeper)
+}
+
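Editor's note: the doorkeeper that DefaultPlanCache toggles is a TinyLFU-style admission filter, a key is only admitted to the main cache the second time it is seen, so one-off queries cannot evict hot plans. A toy sketch of the idea (theine's real doorkeeper is a counting bloom filter that is periodically reset, not a map):

package main

import "fmt"

// doorkeeper is a deliberately simplified admission filter.
type doorkeeper struct {
	seen map[string]bool
}

// admit reports whether a key should enter the main cache: the first
// sighting is only remembered, the second one is let through.
func (d *doorkeeper) admit(key string) bool {
	if d.seen[key] {
		return true
	}
	d.seen[key] = true
	return false
}

func main() {
	d := &doorkeeper{seen: make(map[string]bool)}
	fmt.Println(d.admit("select 1")) // false: first sighting, not cached yet
	fmt.Println(d.admit("select 1")) // true: admitted on the second sighting
}

This is also why the endtoend and unit tests disable it: with the doorkeeper on, whether a plan is cached after its first execution is probabilistic.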
 // NewExecutor creates a new Executor.
 func NewExecutor(
 	ctx context.Context,
@@ -135,11 +144,10 @@ func NewExecutor(
 	resolver *Resolver,
 	normalize, warnOnShardedOnly bool,
 	streamSize int,
-	plans cache.Cache,
+	plans *PlanCache,
 	schemaTracker SchemaInfo,
 	noScatter bool,
 	pv plancontext.PlannerVersion,
-	queryLogger *streamlog.StreamLogger[*logstats.LogStats],
 ) *Executor {
 	e := &Executor{
 		serv: serv,
@@ -147,14 +155,13 @@ func NewExecutor(
 		resolver:        resolver,
 		scatterConn:     resolver.scatterConn,
 		txConn:          resolver.scatterConn.txConn,
-		plans:           plans,
 		normalize:       normalize,
 		warnShardedOnly: warnOnShardedOnly,
 		streamSize:      streamSize,
 		schemaTracker:   schemaTracker,
 		allowScatter:    !noScatter,
 		pv:              pv,
-		queryLogger:     queryLogger,
+		plans:           plans,
 	}
 	vschemaacl.Init()
@@ -172,19 +179,19 @@ func NewExecutor(
 		return int64(e.plans.Len())
 	})
 	stats.NewGaugeFunc("QueryPlanCacheSize", "Query plan cache size", func() int64 {
-		return e.plans.UsedCapacity()
+		return int64(e.plans.UsedCapacity())
 	})
 	stats.NewGaugeFunc("QueryPlanCacheCapacity", "Query plan cache capacity", func() int64 {
-		return e.plans.MaxCapacity()
+		return int64(e.plans.MaxCapacity())
 	})
 	stats.NewCounterFunc("QueryPlanCacheEvictions", "Query plan cache evictions", func() int64 {
-		return e.plans.Evictions()
+		return e.plans.Metrics.Evicted()
 	})
 	stats.NewCounterFunc("QueryPlanCacheHits", "Query plan cache hits", func() int64 {
-		return e.plans.Hits()
+		return e.plans.Metrics.Hits()
 	})
 	stats.NewCounterFunc("QueryPlanCacheMisses", "Query plan cache misses", func() int64 {
-		return e.plans.Misses()
+		return e.plans.Metrics.Misses()
 	})
 	servenv.HTTPHandle(pathQueryPlans, e)
 	servenv.HTTPHandle(pathScatterStats, e)
@@ -977,7 +984,7 @@ func (e *Executor) SaveVSchema(vschema *vindexes.VSchema, stats *VSchemaStats) {
 		e.vschema = vschema
 	}
 	e.vschemaStats = stats
-	e.plans.Clear()
+	e.ClearPlans()
 
 	if vschemaCounters != nil {
 		vschemaCounters.Add("Reload", 1)
@@ -1070,37 +1077,23 @@ func (e *Executor) getPlan(
 	return e.cacheAndBuildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds, logStats)
 }
 
-func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query string) string {
-	planHash := sha256.New()
-
-	{
-		// use a bufio.Writer to accumulate writes instead of writing directly to the hasher
-		buf := bufio.NewWriter(planHash)
-		vcursor.keyForPlan(ctx, query, buf)
-		buf.Flush()
-	}
+func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query string) PlanCacheKey {
+	hasher := vthash.New256()
+	vcursor.keyForPlan(ctx, query, hasher)
 
-	return hex.EncodeToString(planHash.Sum(nil))
+	var planKey PlanCacheKey
+	hasher.Sum(planKey[:0])
+	return planKey
 }
 
-func (e *Executor) cacheAndBuildStatement(
+func (e *Executor) buildStatement(
 	ctx context.Context,
 	vcursor *vcursorImpl,
 	query string,
 	stmt sqlparser.Statement,
 	reservedVars *sqlparser.ReservedVars,
 	bindVarNeeds *sqlparser.BindVarNeeds,
-	logStats *logstats.LogStats,
 ) (*engine.Plan, error) {
-	planKey := e.hashPlan(ctx, vcursor, query)
-	planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan()
-	if planCachable {
-		if plan, ok := e.plans.Get(planKey); ok {
-			logStats.CachedPlan = true
-			return plan.(*engine.Plan), nil
-		}
-	}
-
 	plan, err := planbuilder.BuildFromStmt(ctx, query, stmt, reservedVars, vcursor, bindVarNeeds, enableOnlineDDL, enableDirectDDL)
 	if err != nil {
 		return nil, err
@@ -1110,13 +1103,32 @@ func (e *Executor) cacheAndBuildStatement(
 	vcursor.warnings = nil
 
 	err = e.checkThatPlanIsValid(stmt, plan)
-	// Only cache the plan if it is valid (i.e. does not scatter)
-	if err == nil && planCachable {
-		e.plans.Set(planKey, plan)
-	}
 	return plan, err
 }
 
+func (e *Executor) cacheAndBuildStatement(
+	ctx context.Context,
+	vcursor *vcursorImpl,
+	query string,
+	stmt sqlparser.Statement,
+	reservedVars *sqlparser.ReservedVars,
+	bindVarNeeds *sqlparser.BindVarNeeds,
+	logStats *logstats.LogStats,
+) (*engine.Plan, error) {
+	planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan()
+	if planCachable {
+		planKey := e.hashPlan(ctx, vcursor, query)
+
+		var plan *engine.Plan
+		var err error
+		plan, logStats.CachedPlan, err = e.plans.GetOrLoad(planKey, e.epoch.Load(), func() (*engine.Plan, error) {
+			return e.buildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds)
+		})
+		return plan, err
+	}
+	return e.buildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds)
+}
+
 func (e *Executor) canNormalizeStatement(stmt sqlparser.Statement, setVarComment string) bool {
 	return sqlparser.CanNormalize(stmt) || setVarComment != ""
 }
@@ -1149,18 +1161,10 @@ func prepareSetVarComment(vcursor *vcursorImpl, stmt sqlparser.Statement) (strin
 	return strings.TrimSpace(res.String()), nil
 }
 
-type cacheItem struct {
-	Key   string
-	Value *engine.Plan
-}
-
-func (e *Executor) debugCacheEntries() (items []cacheItem) {
-	e.plans.ForEach(func(value any) bool {
-		plan := value.(*engine.Plan)
-		items = append(items, cacheItem{
-			Key:   plan.Original,
-			Value: plan,
-		})
+func (e *Executor) debugCacheEntries() (items map[string]*engine.Plan) {
+	items = make(map[string]*engine.Plan)
+	e.ForEachPlan(func(plan *engine.Plan) bool {
+		items[plan.Original] = plan
 		return true
 	})
 	return
@@ -1198,10 +1202,20 @@ func returnAsJSON(response http.ResponseWriter, stuff any) {
 }
 
 // Plans returns the plan cache
-func (e *Executor) Plans() cache.Cache {
+func (e *Executor) Plans() *PlanCache {
 	return e.plans
 }
 
+func (e *Executor) ForEachPlan(each func(plan *engine.Plan) bool) {
+	e.plans.Range(e.epoch.Load(), func(_ PlanCacheKey, value *engine.Plan) bool {
+		return each(value)
+	})
+}
+
+func (e *Executor) ClearPlans() {
+	e.epoch.Add(1)
+}
+
 func (e *Executor) updateQueryCounts(planType, keyspace, tableName string, shardQueries int64) {
 	queriesProcessed.Add(planType, 1)
 	queriesRouted.Add(planType, shardQueries)
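Editor's note: ClearPlans above never touches the stored entries. Invalidation is just an epoch bump, and reads that present an older epoch (Get, GetOrLoad and Range all take one in this diff) behave as misses. A toy illustration of the trick under that assumption; theine's actual store also tracks entry cost and eviction:

package main

import (
	"fmt"
	"sync/atomic"
)

// epochCache is a toy version of the epoch trick: entries remember the epoch
// they were written in, and bumping the epoch invalidates them all at once
// without walking the map.
type epochCache struct {
	epoch   atomic.Uint32
	entries map[string]entry
}

type entry struct {
	epoch uint32
	value string
}

func (c *epochCache) set(key, value string) {
	c.entries[key] = entry{epoch: c.epoch.Load(), value: value}
}

func (c *epochCache) get(key string) (string, bool) {
	e, ok := c.entries[key]
	if !ok || e.epoch != c.epoch.Load() {
		return "", false // stale entries read as misses
	}
	return e.value, true
}

func (c *epochCache) clear() { c.epoch.Add(1) } // O(1) "clear"

func main() {
	c := &epochCache{entries: make(map[string]entry)}
	c.set("select 1", "plan-1")
	fmt.Println(c.get("select 1")) // plan-1 true
	c.clear()
	fmt.Println(c.get("select 1")) // "" false
}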
diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go
index a1e1e91d53b..107215d6f4d 100644
--- a/go/vt/vtgate/executor_framework_test.go
+++ b/go/vt/vtgate/executor_framework_test.go
@@ -28,9 +28,10 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"vitess.io/vitess/go/cache/theine"
 	"vitess.io/vitess/go/test/utils"
+	"vitess.io/vitess/go/vt/vtgate/engine"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/constants/sidecar"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/streamlog"
@@ -177,8 +178,15 @@ func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup
 	_ = hc.AddTestTablet(cell, "2", 3, KsTestUnsharded, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
 
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+
+	// All these vtgate tests expect plans to be immediately cached after first use;
+	// this is not the actual behavior of the system in a production context because we use a doorkeeper
+	// that sometimes can cause a plan to not be cached the very first time it's seen, to prevent
+	// one-off queries from thrashing the cache. Disable the doorkeeper in the tests to prevent flakiness.
+	plans := theine.NewStore[PlanCacheKey, *engine.Plan](queryPlanCacheMemory, false)
+
+	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	executor.SetQueryLogger(queryLogger)
 
 	key.AnyShardPicker = DestinationAnyShardPickerFirstShard{}
 
@@ -209,8 +217,9 @@ func createCustomExecutor(t testing.TB, vschema string) (executor *Executor, sbc
 	sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
 
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+	plans := DefaultPlanCache()
+	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	executor.SetQueryLogger(queryLogger)
 
 	t.Cleanup(func() {
 		defer utils.EnsureNoLeaks(t)
@@ -246,8 +255,9 @@ func createCustomExecutorSetValues(t testing.TB, vschema string, values []*sqlty
 	sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
 
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+	plans := DefaultPlanCache()
+	executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	executor.SetQueryLogger(queryLogger)
 
 	t.Cleanup(func() {
 		defer utils.EnsureNoLeaks(t)
diff --git a/go/vt/vtgate/executor_scatter_stats.go b/go/vt/vtgate/executor_scatter_stats.go
index cfe0b7435f2..beaa60d7012 100644
--- a/go/vt/vtgate/executor_scatter_stats.go
+++ b/go/vt/vtgate/executor_scatter_stats.go
@@ -62,8 +62,7 @@ func (e *Executor) gatherScatterStats() (statsResults, error) {
 	plans := make([]*engine.Plan, 0)
 	routes := make([]*engine.Route, 0)
 	// First we go over all plans and collect statistics and all query plans for scatter queries
-	e.plans.ForEach(func(value any) bool {
-		plan := value.(*engine.Plan)
+	e.ForEachPlan(func(plan *engine.Plan) bool {
 		scatter := engine.Find(findScatter, plan.Instructions)
 		readOnly := !engine.Exists(isUpdating, plan.Instructions)
 		isScatter := scatter != nil
diff --git a/go/vt/vtgate/executor_scatter_stats_test.go b/go/vt/vtgate/executor_scatter_stats_test.go
index 9a9e54517e7..84dd2744e8b 100644
--- a/go/vt/vtgate/executor_scatter_stats_test.go
+++ b/go/vt/vtgate/executor_scatter_stats_test.go
@@ -19,6 +19,7 @@ package vtgate
 import (
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -67,7 +68,7 @@ func TestScatterStatsHttpWriting(t *testing.T) {
 	_, err = executor.Execute(ctx, nil, "TestExecutorResultsExceeded", session, query4, nil)
 	require.NoError(t, err)
 
-	executor.plans.Wait()
+	time.Sleep(500 * time.Millisecond)
 
 	recorder := httptest.NewRecorder()
 	executor.WriteScatterStats(recorder)
diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go
index b8d2625ec04..4b99eef6c33 100644
--- a/go/vt/vtgate/executor_select_test.go
+++ b/go/vt/vtgate/executor_select_test.go
@@ -37,7 +37,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/sqltypes"
 	"vitess.io/vitess/go/test/utils"
 	"vitess.io/vitess/go/vt/discovery"
@@ -1561,8 +1560,10 @@ func TestStreamSelectIN(t *testing.T) {
 
 func createExecutor(ctx context.Context, serv *sandboxTopo, cell string, resolver *Resolver) *Executor {
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	return NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+	plans := DefaultPlanCache()
+	ex := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	ex.SetQueryLogger(queryLogger)
+	return ex
 }
 
 func TestSelectScatter(t *testing.T) {
@@ -3186,8 +3187,9 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) {
 	}
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	executor := NewExecutor(ctx, serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+	plans := DefaultPlanCache()
+	executor := NewExecutor(ctx, serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	executor.SetQueryLogger(queryLogger)
 	defer executor.Close()
 	// some sleep for all goroutines to start
 	time.Sleep(100 * time.Millisecond)
@@ -4094,12 +4096,12 @@ func TestSelectCFC(t *testing.T) {
 	for {
 		select {
 		case <-timeout:
-			t.Fatal("not able to cache a plan withing 10 seconds.")
+			t.Fatal("not able to cache a plan within 30 seconds.")
 		case <-time.After(5 * time.Millisecond):
 			// should be able to find cache entry before the timeout.
 			cacheItems := executor.debugCacheEntries()
 			for _, item := range cacheItems {
-				if strings.Contains(item.Key, "c2 from tbl_cfc where c1 like") {
+				if strings.Contains(item.Original, "c2 from tbl_cfc where c1 like") {
 					return
 				}
 			}
 		}
 	}
diff --git a/go/vt/vtgate/executor_stream_test.go b/go/vt/vtgate/executor_stream_test.go
index 6f6dcd9f6b4..e5c730eb157 100644
--- a/go/vt/vtgate/executor_stream_test.go
+++ b/go/vt/vtgate/executor_stream_test.go
@@ -27,7 +27,6 @@ import (
 	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
 	"vitess.io/vitess/go/vt/vtgate/logstats"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/vt/discovery"
 
 	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
@@ -67,8 +66,11 @@ func TestStreamSQLSharded(t *testing.T) {
 		_ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil)
 	}
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
-	plans := cache.NewDefaultCacheImpl(cache.DefaultConfig)
-	executor := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger)
+	plans := DefaultPlanCache()
+
+	executor := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4)
+	executor.SetQueryLogger(queryLogger)
+
 	defer executor.Close()
 
 	sql := "stream * from sharded_user_msgs"
diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go
index 17d73997125..2ab45f1ef42 100644
--- a/go/vt/vtgate/executor_test.go
+++ b/go/vt/vtgate/executor_test.go
@@ -28,6 +28,7 @@ import (
 	"sort"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/safehtml/template"
@@ -35,7 +36,6 @@ import (
 	"github.com/stretchr/testify/require"
 	"google.golang.org/protobuf/proto"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/mysql/collations"
 	"vitess.io/vitess/go/mysql/sqlerror"
 	"vitess.io/vitess/go/sqltypes"
@@ -1665,13 +1665,9 @@ func TestGetPlanUnnormalized(t *testing.T) {
 	}
 }
 
-func assertCacheSize(t *testing.T, c cache.Cache, expected int) {
+func assertCacheSize(t *testing.T, c *PlanCache, expected int) {
 	t.Helper()
-	var size int
-	c.ForEach(func(_ any) bool {
-		size++
-		return true
-	})
+	size := c.Len()
 	if size != expected {
 		t.Errorf("getPlan() expected cache to have size %d, but got: %d", expected, size)
 	}
@@ -1682,8 +1678,7 @@ func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string)
 	var plan *engine.Plan
 	if vc == nil {
-		e.plans.ForEach(func(x any) bool {
-			p := x.(*engine.Plan)
+		e.ForEachPlan(func(p *engine.Plan) bool {
 			if p.Original == sql {
 				plan = p
 			}
@@ -1691,9 +1686,7 @@ func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string)
 		})
 	} else {
 		h := e.hashPlan(context.Background(), vc, sql)
-		if p, ok := e.plans.Get(h); ok {
-			plan = p.(*engine.Plan)
-		}
+		plan, _ = e.plans.Get(h, e.epoch.Load())
 	}
 	require.Truef(t, plan != nil, "plan not found for query: %s", sql)
 	return plan
@@ -1712,7 +1705,7 @@ func getPlanCached(t *testing.T, ctx context.Context, e *Executor, vcursor *vcur
 	require.NoError(t, err)
 
 	// Wait for cache to settle
-	e.plans.Wait()
+	time.Sleep(100 * time.Millisecond)
 
 	return plan, logStats
 }
@@ -2147,7 +2140,7 @@ func TestServingKeyspaces(t *testing.T) {
 	hc.BroadcastAll()
 
 	// Clear plan cache, to force re-planning of the query.
-	executor.plans.Clear()
+	executor.ClearPlans()
 	require.ElementsMatch(t, []string{"TestUnsharded"}, gw.GetServingKeyspaces())
 
 	result, err = executor.Execute(ctx, nil, "TestServingKeyspaces", NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil)
 	require.NoError(t, err)
diff --git a/go/vt/vtgate/querylog.go b/go/vt/vtgate/querylog.go
index 94c1b23263b..7425f2feba9 100644
--- a/go/vt/vtgate/querylog.go
+++ b/go/vt/vtgate/querylog.go
@@ -19,7 +19,6 @@ package vtgate
 import (
 	"net/http"
 
-	"vitess.io/vitess/go/cache"
 	"vitess.io/vitess/go/streamlog"
 	"vitess.io/vitess/go/vt/servenv"
 	"vitess.io/vitess/go/vt/vtgate/logstats"
@@ -36,7 +35,7 @@ var (
 	QueryzHandler = "/debug/queryz"
 )
 
-func initQueryLogger(plans cache.Cache) (*streamlog.StreamLogger[*logstats.LogStats], error) {
+func (e *Executor) defaultQueryLogger() error {
 	queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)
 	queryLogger.ServeLogs(QueryLogHandler, streamlog.GetFormatter(queryLogger))
 
@@ -47,15 +46,20 @@ func initQueryLogger(plans cache.Cache) (*streamlog.StreamLogger[*logstats.LogSt
 	})
 
 	servenv.HTTPHandleFunc(QueryzHandler, func(w http.ResponseWriter, r *http.Request) {
-		queryzHandler(plans, w, r)
+		queryzHandler(e, w, r)
 	})
 
 	if queryLogToFile != "" {
 		_, err := queryLogger.LogToFile(queryLogToFile, streamlog.GetFormatter(queryLogger))
 		if err != nil {
-			return nil, err
+			return err
 		}
 	}
 
-	return queryLogger, nil
+	e.queryLogger = queryLogger
+	return nil
+}
+
+func (e *Executor) SetQueryLogger(ql *streamlog.StreamLogger[*logstats.LogStats]) {
+	e.queryLogger = ql
 }
diff --git a/go/vt/vtgate/queryz.go b/go/vt/vtgate/queryz.go
index 6c5d2b89dee..e546fc68c6f 100644
--- a/go/vt/vtgate/queryz.go
+++ b/go/vt/vtgate/queryz.go
@@ -24,8 +24,6 @@ import (
 
 	"github.com/google/safehtml/template"
 
-	"vitess.io/vitess/go/cache"
-
 	"vitess.io/vitess/go/acl"
 	"vitess.io/vitess/go/vt/log"
 	"vitess.io/vitess/go/vt/logz"
@@ -129,7 +127,7 @@ func (s *queryzSorter) Len() int      { return len(s.rows) }
 func (s *queryzSorter) Swap(i, j int) { s.rows[i], s.rows[j] = s.rows[j], s.rows[i] }
 func (s *queryzSorter) Less(i, j int) bool { return s.less(s.rows[i], s.rows[j]) }
 
-func queryzHandler(plans cache.Cache, w http.ResponseWriter, r *http.Request) {
+func queryzHandler(e *Executor, w http.ResponseWriter, r *http.Request) {
 	if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
 		acl.SendError(w, err)
 		return
@@ -145,8 +143,7 @@ func queryzHandler(plans cache.Cache, w http.ResponseWriter, r *http.Request) {
 		},
 	}
 
-	plans.ForEach(func(value any) bool {
-		plan := value.(*engine.Plan)
+	e.ForEachPlan(func(plan *engine.Plan) bool {
 		Value := &queryzRow{
 			Query: logz.Wrappable(sqlparser.TruncateForUI(plan.Original)),
 		}
diff --git a/go/vt/vtgate/queryz_test.go b/go/vt/vtgate/queryz_test.go
index 83fd064df7d..826cb8641d8 100644
--- a/go/vt/vtgate/queryz_test.go
+++ b/go/vt/vtgate/queryz_test.go
@@ -46,7 +46,7 @@ func TestQueryzHandler(t *testing.T) {
 	sql := "select id from user where id = 1"
 	_, err := executorExec(ctx, executor, session, sql, nil)
 	require.NoError(t, err)
-	executor.plans.Wait()
+	time.Sleep(100 * time.Millisecond)
 	plan1 := assertCacheContains(t, executor, nil, "select id from `user` where id = 1")
 	plan1.ExecTime = uint64(1 * time.Millisecond)
 
@@ -54,7 +54,7 @@ func TestQueryzHandler(t *testing.T) {
 	sql = "select id from user"
 	_, err = executorExec(ctx, executor, session, sql, nil)
 	require.NoError(t, err)
-	executor.plans.Wait()
+	time.Sleep(100 * time.Millisecond)
 	plan2 := assertCacheContains(t, executor, nil, "select id from `user`")
 	plan2.ExecTime = uint64(1 * time.Second)
 
@@ -64,7 +64,7 @@ func TestQueryzHandler(t *testing.T) {
 		"name": sqltypes.BytesBindVariable([]byte("myname")),
 	})
 	require.NoError(t, err)
-	executor.plans.Wait()
+	time.Sleep(100 * time.Millisecond)
 	plan3 := assertCacheContains(t, executor, nil, "insert into `user`(id, `name`) values (:id, :name)")
 
 	// vindex insert from above execution
@@ -82,7 +82,7 @@ func TestQueryzHandler(t *testing.T) {
 	plan3.ExecTime = uint64(100 * time.Millisecond)
 	plan4.ExecTime = uint64(200 * time.Millisecond)
 
-	queryzHandler(executor.plans, resp, req)
+	queryzHandler(executor, resp, req)
 	body, _ := io.ReadAll(resp.Body)
 
 	planPattern1 := []string{ `