diff --git a/changelog/18.0/18.0.0/summary.md b/changelog/18.0/18.0.0/summary.md index bc3ca5e6004..96f3cdf8541 100644 --- a/changelog/18.0/18.0.0/summary.md +++ b/changelog/18.0/18.0.0/summary.md @@ -83,10 +83,20 @@ Throttler related `vttablet` flags: - `--throttle_check_as_check_self` is deprecated and will be removed in `v19.0` - `--throttler-config-via-topo` is deprecated after assumed `true` in `v17.0`. It will be removed in a future version. +Cache related `vttablet` flags: + +- `--queryserver-config-query-cache-lfu` is deprecated and will be removed in `v19.0`. The query cache always uses a LFU implementation now. +- `--queryserver-config-query-cache-size` is deprecated and will be removed in `v19.0`. This option only applied to LRU caches, which are now unsupported. + Buffering related `vtgate` flags: - `--buffer_implementation` is deprecated and will be removed in `v19.0` +Cache related `vtgate` flags: + +- `--gate_query_cache_lfu` is deprecated and will be removed in `v19.0`. The query cache always uses a LFU implementation now. +- `--gate_query_cache_size` is deprecated and will be removed in `v19.0`. This option only applied to LRU caches, which are now unsupported. + VTGate flag: - `--schema_change_signal_user` is deprecated and will be removed in `v19.0` diff --git a/go.mod b/go.mod index 9b2692ed3b7..639a22edc6b 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/Azure/azure-storage-blob-go v0.15.0 github.com/DataDog/datadog-go v4.8.3+incompatible github.com/HdrHistogram/hdrhistogram-go v0.9.0 // indirect - github.com/PuerkitoBio/goquery v1.5.1 github.com/aquarapid/vaultlib v0.5.1 github.com/armon/go-metrics v0.4.1 // indirect github.com/aws/aws-sdk-go v1.44.258 @@ -76,7 +75,7 @@ require ( golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.14.0 golang.org/x/oauth2 v0.7.0 - golang.org/x/sys v0.11.0 // indirect + golang.org/x/sys v0.11.0 golang.org/x/term v0.11.0 golang.org/x/text v0.12.0 golang.org/x/time v0.3.0 @@ -97,6 +96,7 @@ require ( require ( github.com/Shopify/toxiproxy/v2 v2.5.0 github.com/bndr/gotabulate v1.1.2 + github.com/gammazero/deque v0.2.1 github.com/google/safehtml v0.1.0 github.com/hashicorp/go-version v1.6.0 github.com/kr/pretty v0.3.1 @@ -124,7 +124,6 @@ require ( github.com/DataDog/go-tuf v0.3.0--fix-localmeta-fork // indirect github.com/DataDog/sketches-go v1.4.1 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect - github.com/andybalholm/cascadia v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect diff --git a/go.sum b/go.sum index b5d7eb888c7..ffadd6498b9 100644 --- a/go.sum +++ b/go.sum @@ -96,16 +96,12 @@ github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpz github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= -github.com/PuerkitoBio/goquery v1.5.1 h1:PSPBGne8NIUWw+/7vFBV+kG2J/5MOjbzc7154OaKCSE= -github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/Shopify/toxiproxy/v2 v2.5.0 h1:i4LPT+qrSlKNtQf5QliVjdP08GyAH8+BUIc9gT0eahc= github.com/Shopify/toxiproxy/v2 v2.5.0/go.mod h1:yhM2epWtAmel9CB8r2+L+PCmhH6yH2pITaPAo7jxJl0= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/andybalholm/cascadia v1.1.0 h1:BuuO6sSfQNFRu1LppgbD25Hr2vLYW25JvxHs5zzsLTo= -github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/aquarapid/vaultlib v0.5.1 h1:vuLWR6bZzLHybjJBSUYPgZlIp6KZ+SXeHLRRYTuk6d4= github.com/aquarapid/vaultlib v0.5.1/go.mod h1:yT7AlEXtuabkxylOc/+Ulyp18tff1+QjgNLTnFWTlOs= @@ -195,6 +191,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -707,7 +705,6 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/go/cache/cache.go b/go/cache/cache.go deleted file mode 100644 index a801d075fde..00000000000 --- a/go/cache/cache.go +++ /dev/null @@ -1,91 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cache - -// Cache is a generic interface type for a data structure that keeps recently used -// objects in memory and evicts them when it becomes full. -type Cache interface { - Get(key string) (any, bool) - Set(key string, val any) bool - ForEach(callback func(any) bool) - - Delete(key string) - Clear() - - // Wait waits for all pending operations on the cache to settle. Since cache writes - // are asynchronous, a write may not be immediately accessible unless the user - // manually calls Wait. - Wait() - - Len() int - Evictions() int64 - Hits() int64 - Misses() int64 - UsedCapacity() int64 - MaxCapacity() int64 - SetCapacity(int64) - - // Close shuts down this cache and stops any background goroutines. - Close() -} - -type cachedObject interface { - CachedSize(alloc bool) int64 -} - -// NewDefaultCacheImpl returns the default cache implementation for Vitess. The options in the -// Config struct control the memory and entry limits for the cache, and the underlying cache -// implementation. -func NewDefaultCacheImpl(cfg *Config) Cache { - switch { - case cfg == nil: - return &nullCache{} - - case cfg.LFU: - if cfg.MaxEntries == 0 || cfg.MaxMemoryUsage == 0 { - return &nullCache{} - } - return NewRistrettoCache(cfg.MaxEntries, cfg.MaxMemoryUsage, func(val any) int64 { - return val.(cachedObject).CachedSize(true) - }) - - default: - if cfg.MaxEntries == 0 { - return &nullCache{} - } - return NewLRUCache(cfg.MaxEntries, func(_ any) int64 { - return 1 - }) - } -} - -// Config is the configuration options for a cache instance -type Config struct { - // MaxEntries is the estimated amount of entries that the cache will hold at capacity - MaxEntries int64 - // MaxMemoryUsage is the maximum amount of memory the cache can handle - MaxMemoryUsage int64 - // LFU toggles whether to use a new cache implementation with a TinyLFU admission policy - LFU bool -} - -// DefaultConfig is the default configuration for a cache instance in Vitess -var DefaultConfig = &Config{ - MaxEntries: 5000, - MaxMemoryUsage: 32 * 1024 * 1024, - LFU: true, -} diff --git a/go/cache/cache_test.go b/go/cache/cache_test.go deleted file mode 100644 index 911a3bb207b..00000000000 --- a/go/cache/cache_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package cache - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/cache/ristretto" -) - -func TestNewDefaultCacheImpl(t *testing.T) { - assertNullCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*nullCache) - require.True(t, ok) - } - - assertLFUCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*ristretto.Cache) - require.True(t, ok) - } - - assertLRUCache := func(t *testing.T, cache Cache) { - _, ok := cache.(*LRUCache) - require.True(t, ok) - } - - tests := []struct { - cfg *Config - verify func(t *testing.T, cache Cache) - }{ - {&Config{MaxEntries: 0, MaxMemoryUsage: 0, LFU: false}, assertNullCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 0, LFU: true}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 0, LFU: false}, assertLRUCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 1000, LFU: false}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 1000, LFU: false}, assertLRUCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 0, LFU: true}, assertNullCache}, - {&Config{MaxEntries: 100, MaxMemoryUsage: 1000, LFU: true}, assertLFUCache}, - {&Config{MaxEntries: 0, MaxMemoryUsage: 1000, LFU: true}, assertNullCache}, - } - for _, tt := range tests { - t.Run(fmt.Sprintf("%d.%d.%v", tt.cfg.MaxEntries, tt.cfg.MaxMemoryUsage, tt.cfg.LFU), func(t *testing.T) { - cache := NewDefaultCacheImpl(tt.cfg) - tt.verify(t, cache) - }) - } -} diff --git a/go/cache/lru_cache.go b/go/cache/lru_cache.go index 8cc89ac55a4..d845265b77b 100644 --- a/go/cache/lru_cache.go +++ b/go/cache/lru_cache.go @@ -29,8 +29,6 @@ import ( "time" ) -var _ Cache = &LRUCache{} - // LRUCache is a typical LRU cache implementation. If the cache // reaches the capacity, the least recently used item is deleted from // the cache. Note the capacity is not the number of items, but the diff --git a/go/cache/null.go b/go/cache/null.go deleted file mode 100644 index 2e1eeeb0d2d..00000000000 --- a/go/cache/null.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cache - -// nullCache is a no-op cache that does not store items -type nullCache struct{} - -// Get never returns anything on the nullCache -func (n *nullCache) Get(_ string) (any, bool) { - return nil, false -} - -// Set is a no-op in the nullCache -func (n *nullCache) Set(_ string, _ any) bool { - return false -} - -// ForEach iterates the nullCache, which is always empty -func (n *nullCache) ForEach(_ func(any) bool) {} - -// Delete is a no-op in the nullCache -func (n *nullCache) Delete(_ string) {} - -// Clear is a no-op in the nullCache -func (n *nullCache) Clear() {} - -// Wait is a no-op in the nullcache -func (n *nullCache) Wait() {} - -func (n *nullCache) Len() int { - return 0 -} - -// Hits returns number of cache hits since creation -func (n *nullCache) Hits() int64 { - return 0 -} - -// Hits returns number of cache misses since creation -func (n *nullCache) Misses() int64 { - return 0 -} - -// Capacity returns the capacity of the nullCache, which is always 0 -func (n *nullCache) UsedCapacity() int64 { - return 0 -} - -// Capacity returns the capacity of the nullCache, which is always 0 -func (n *nullCache) MaxCapacity() int64 { - return 0 -} - -// SetCapacity sets the capacity of the null cache, which is a no-op -func (n *nullCache) SetCapacity(_ int64) {} - -func (n *nullCache) Evictions() int64 { - return 0 -} - -func (n *nullCache) Close() {} diff --git a/go/cache/ristretto.go b/go/cache/ristretto.go deleted file mode 100644 index 6d6f596a5b9..00000000000 --- a/go/cache/ristretto.go +++ /dev/null @@ -1,28 +0,0 @@ -package cache - -import ( - "vitess.io/vitess/go/cache/ristretto" -) - -var _ Cache = &ristretto.Cache{} - -// NewRistrettoCache returns a Cache implementation based on Ristretto -func NewRistrettoCache(maxEntries, maxCost int64, cost func(any) int64) *ristretto.Cache { - // The TinyLFU paper recommends to allocate 10x times the max entries amount as counters - // for the admission policy; since our caches are small and we're very interested on admission - // accuracy, we're a bit more greedy than 10x - const CounterRatio = 12 - - config := ristretto.Config{ - NumCounters: maxEntries * CounterRatio, - MaxCost: maxCost, - BufferItems: 64, - Metrics: true, - Cost: cost, - } - cache, err := ristretto.NewCache(&config) - if err != nil { - panic(err) - } - return cache -} diff --git a/go/cache/ristretto/bloom/bbloom.go b/go/cache/ristretto/bloom/bbloom.go deleted file mode 100644 index 9d6b1080a2e..00000000000 --- a/go/cache/ristretto/bloom/bbloom.go +++ /dev/null @@ -1,149 +0,0 @@ -// The MIT License (MIT) -// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt - -// Permission is hereby granted, free of charge, to any person obtaining a copy of -// this software and associated documentation files (the "Software"), to deal in -// the Software without restriction, including without limitation the rights to -// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software is furnished to do so, -// subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package bloom - -import ( - "math" - "unsafe" -) - -// helper -var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128} - -func getSize(ui64 uint64) (size uint64, exponent uint64) { - if ui64 < uint64(512) { - ui64 = uint64(512) - } - size = uint64(1) - for size < ui64 { - size <<= 1 - exponent++ - } - return size, exponent -} - -// NewBloomFilterWithErrorRate returns a new bloomfilter with optimal size for the given -// error rate -func NewBloomFilterWithErrorRate(numEntries uint64, wrongs float64) *Bloom { - size := -1 * float64(numEntries) * math.Log(wrongs) / math.Pow(0.69314718056, 2) - locs := math.Ceil(0.69314718056 * size / float64(numEntries)) - return NewBloomFilter(uint64(size), uint64(locs)) -} - -// NewBloomFilter returns a new bloomfilter. -func NewBloomFilter(entries, locs uint64) (bloomfilter *Bloom) { - size, exponent := getSize(entries) - bloomfilter = &Bloom{ - sizeExp: exponent, - size: size - 1, - setLocs: locs, - shift: 64 - exponent, - } - bloomfilter.Size(size) - return bloomfilter -} - -// Bloom filter -type Bloom struct { - bitset []uint64 - ElemNum uint64 - sizeExp uint64 - size uint64 - setLocs uint64 - shift uint64 -} - -// <--- http://www.cse.yorku.ca/~oz/hash.html -// modified Berkeley DB Hash (32bit) -// hash is casted to l, h = 16bit fragments -// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) { -// hash := uint64(len(*b)) -// for _, c := range *b { -// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash -// } -// h = hash >> bl.shift -// l = hash << bl.shift >> bl.shift -// return l, h -// } - -// Add adds hash of a key to the bloomfilter. -func (bl *Bloom) Add(hash uint64) { - h := hash >> bl.shift - l := hash << bl.shift >> bl.shift - for i := uint64(0); i < bl.setLocs; i++ { - bl.Set((h + i*l) & bl.size) - bl.ElemNum++ - } -} - -// Has checks if bit(s) for entry hash is/are set, -// returns true if the hash was added to the Bloom Filter. -func (bl Bloom) Has(hash uint64) bool { - h := hash >> bl.shift - l := hash << bl.shift >> bl.shift - for i := uint64(0); i < bl.setLocs; i++ { - if !bl.IsSet((h + i*l) & bl.size) { - return false - } - } - return true -} - -// AddIfNotHas only Adds hash, if it's not present in the bloomfilter. -// Returns true if hash was added. -// Returns false if hash was already registered in the bloomfilter. -func (bl *Bloom) AddIfNotHas(hash uint64) bool { - if bl.Has(hash) { - return false - } - bl.Add(hash) - return true -} - -// TotalSize returns the total size of the bloom filter. -func (bl *Bloom) TotalSize() int { - // The bl struct has 5 members and each one is 8 byte. The bitset is a - // uint64 byte slice. - return len(bl.bitset)*8 + 5*8 -} - -// Size makes Bloom filter with as bitset of size sz. -func (bl *Bloom) Size(sz uint64) { - bl.bitset = make([]uint64, sz>>6) -} - -// Clear resets the Bloom filter. -func (bl *Bloom) Clear() { - clear(bl.bitset) -} - -// Set sets the bit[idx] of bitset. -func (bl *Bloom) Set(idx uint64) { - ptr := unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)) - *(*uint8)(ptr) |= mask[idx%8] -} - -// IsSet checks if bit[idx] of bitset is set, returns true/false. -func (bl *Bloom) IsSet(idx uint64) bool { - ptr := unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)) - r := ((*(*uint8)(ptr)) >> (idx % 8)) & 1 - return r == 1 -} diff --git a/go/cache/ristretto/bloom/bbloom_test.go b/go/cache/ristretto/bloom/bbloom_test.go deleted file mode 100644 index 7d280988bae..00000000000 --- a/go/cache/ristretto/bloom/bbloom_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package bloom - -import ( - "crypto/rand" - "os" - "testing" - - _flag "vitess.io/vitess/go/internal/flag" - "vitess.io/vitess/go/vt/log" - - "vitess.io/vitess/go/hack" -) - -var ( - wordlist1 [][]byte - n = uint64(1 << 16) - bf *Bloom -) - -func TestMain(m *testing.M) { - // hack to get rid of an "ERROR: logging before flag.Parse" - _flag.TrickGlog() - wordlist1 = make([][]byte, n) - for i := range wordlist1 { - b := make([]byte, 32) - _, _ = rand.Read(b) - wordlist1[i] = b - } - log.Info("Benchmarks relate to 2**16 OP. --> output/65536 op/ns") - - os.Exit(m.Run()) -} - -func TestM_NumberOfWrongs(t *testing.T) { - bf = NewBloomFilter(n*10, 7) - - cnt := 0 - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - if !bf.AddIfNotHas(hash) { - cnt++ - } - } - log.Infof("Bloomfilter New(7* 2**16, 7) (-> size=%v bit): \n Check for 'false positives': %v wrong positive 'Has' results on 2**16 entries => %v %%", len(bf.bitset)<<6, cnt, float64(cnt)/float64(n)) - -} - -func BenchmarkM_New(b *testing.B) { - for r := 0; r < b.N; r++ { - _ = NewBloomFilter(n*10, 7) - } -} - -func BenchmarkM_Clear(b *testing.B) { - bf = NewBloomFilter(n*10, 7) - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Add(hash) - } - b.ResetTimer() - for r := 0; r < b.N; r++ { - bf.Clear() - } -} - -func BenchmarkM_Add(b *testing.B) { - bf = NewBloomFilter(n*10, 7) - b.ResetTimer() - for r := 0; r < b.N; r++ { - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Add(hash) - } - } - -} - -func BenchmarkM_Has(b *testing.B) { - b.ResetTimer() - for r := 0; r < b.N; r++ { - for i := range wordlist1 { - hash := hack.RuntimeMemhash(wordlist1[i], 0) - bf.Has(hash) - } - } -} diff --git a/go/cache/ristretto/cache.go b/go/cache/ristretto/cache.go deleted file mode 100644 index aa6aa2c2870..00000000000 --- a/go/cache/ristretto/cache.go +++ /dev/null @@ -1,701 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package ristretto is a fast, fixed size, in-memory cache with a dual focus on -// throughput and hit ratio performance. You can easily add Ristretto to an -// existing system and keep the most valuable data where you need it. -package ristretto - -import ( - "bytes" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" - "unsafe" - - "vitess.io/vitess/go/hack" -) - -var ( - // TODO: find the optimal value for this or make it configurable - setBufSize = 32 * 1024 -) - -func defaultStringHash(key string) (uint64, uint64) { - const Seed1 = uint64(0x1122334455667788) - const Seed2 = uint64(0x8877665544332211) - return hack.RuntimeStrhash(key, Seed1), hack.RuntimeStrhash(key, Seed2) -} - -type itemCallback func(*Item) - -// CacheItemSize is the overhead in bytes for every stored cache item -var CacheItemSize = hack.RuntimeAllocSize(int64(unsafe.Sizeof(storeItem{}))) - -// Cache is a thread-safe implementation of a hashmap with a TinyLFU admission -// policy and a Sampled LFU eviction policy. You can use the same Cache instance -// from as many goroutines as you want. -type Cache struct { - // store is the central concurrent hashmap where key-value items are stored. - store store - // policy determines what gets let in to the cache and what gets kicked out. - policy policy - // getBuf is a custom ring buffer implementation that gets pushed to when - // keys are read. - getBuf *ringBuffer - // setBuf is a buffer allowing us to batch/drop Sets during times of high - // contention. - setBuf chan *Item - // onEvict is called for item evictions. - onEvict itemCallback - // onReject is called when an item is rejected via admission policy. - onReject itemCallback - // onExit is called whenever a value goes out of scope from the cache. - onExit func(any) - // KeyToHash function is used to customize the key hashing algorithm. - // Each key will be hashed using the provided function. If keyToHash value - // is not set, the default keyToHash function is used. - keyToHash func(string) (uint64, uint64) - // stop is used to stop the processItems goroutine. - stop chan struct{} - // indicates whether cache is closed. - isClosed atomic.Bool - // cost calculates cost from a value. - cost func(value any) int64 - // ignoreInternalCost dictates whether to ignore the cost of internally storing - // the item in the cost calculation. - ignoreInternalCost bool - // Metrics contains a running log of important statistics like hits, misses, - // and dropped items. - Metrics *Metrics -} - -// Config is passed to NewCache for creating new Cache instances. -type Config struct { - // NumCounters determines the number of counters (keys) to keep that hold - // access frequency information. It's generally a good idea to have more - // counters than the max cache capacity, as this will improve eviction - // accuracy and subsequent hit ratios. - // - // For example, if you expect your cache to hold 1,000,000 items when full, - // NumCounters should be 10,000,000 (10x). Each counter takes up 4 bits, so - // keeping 10,000,000 counters would require 5MB of memory. - NumCounters int64 - // MaxCost can be considered as the cache capacity, in whatever units you - // choose to use. - // - // For example, if you want the cache to have a max capacity of 100MB, you - // would set MaxCost to 100,000,000 and pass an item's number of bytes as - // the `cost` parameter for calls to Set. If new items are accepted, the - // eviction process will take care of making room for the new item and not - // overflowing the MaxCost value. - MaxCost int64 - // BufferItems determines the size of Get buffers. - // - // Unless you have a rare use case, using `64` as the BufferItems value - // results in good performance. - BufferItems int64 - // Metrics determines whether cache statistics are kept during the cache's - // lifetime. There *is* some overhead to keeping statistics, so you should - // only set this flag to true when testing or throughput performance isn't a - // major factor. - Metrics bool - // OnEvict is called for every eviction and passes the hashed key, value, - // and cost to the function. - OnEvict func(item *Item) - // OnReject is called for every rejection done via the policy. - OnReject func(item *Item) - // OnExit is called whenever a value is removed from cache. This can be - // used to do manual memory deallocation. Would also be called on eviction - // and rejection of the value. - OnExit func(val any) - // KeyToHash function is used to customize the key hashing algorithm. - // Each key will be hashed using the provided function. If keyToHash value - // is not set, the default keyToHash function is used. - KeyToHash func(string) (uint64, uint64) - // Cost evaluates a value and outputs a corresponding cost. This function - // is ran after Set is called for a new item or an item update with a cost - // param of 0. - Cost func(value any) int64 - // IgnoreInternalCost set to true indicates to the cache that the cost of - // internally storing the value should be ignored. This is useful when the - // cost passed to set is not using bytes as units. Keep in mind that setting - // this to true will increase the memory usage. - IgnoreInternalCost bool -} - -type itemFlag byte - -const ( - itemNew itemFlag = iota - itemDelete - itemUpdate -) - -// Item is passed to setBuf so items can eventually be added to the cache. -type Item struct { - flag itemFlag - Key uint64 - Conflict uint64 - Value any - Cost int64 - wg *sync.WaitGroup -} - -// NewCache returns a new Cache instance and any configuration errors, if any. -func NewCache(config *Config) (*Cache, error) { - switch { - case config.NumCounters == 0: - return nil, errors.New("NumCounters can't be zero") - case config.MaxCost == 0: - return nil, errors.New("Capacity can't be zero") - case config.BufferItems == 0: - return nil, errors.New("BufferItems can't be zero") - } - policy := newPolicy(config.NumCounters, config.MaxCost) - cache := &Cache{ - store: newStore(), - policy: policy, - getBuf: newRingBuffer(policy, config.BufferItems), - setBuf: make(chan *Item, setBufSize), - keyToHash: config.KeyToHash, - stop: make(chan struct{}), - cost: config.Cost, - ignoreInternalCost: config.IgnoreInternalCost, - } - cache.onExit = func(val any) { - if config.OnExit != nil && val != nil { - config.OnExit(val) - } - } - cache.onEvict = func(item *Item) { - if config.OnEvict != nil { - config.OnEvict(item) - } - cache.onExit(item.Value) - } - cache.onReject = func(item *Item) { - if config.OnReject != nil { - config.OnReject(item) - } - cache.onExit(item.Value) - } - if cache.keyToHash == nil { - cache.keyToHash = defaultStringHash - } - if config.Metrics { - cache.collectMetrics() - } - // NOTE: benchmarks seem to show that performance decreases the more - // goroutines we have running cache.processItems(), so 1 should - // usually be sufficient - go cache.processItems() - return cache, nil -} - -// Wait blocks until all the current cache operations have been processed in the background -func (c *Cache) Wait() { - if c == nil || c.isClosed.Load() { - return - } - wg := &sync.WaitGroup{} - wg.Add(1) - c.setBuf <- &Item{wg: wg} - wg.Wait() -} - -// Get returns the value (if any) and a boolean representing whether the -// value was found or not. The value can be nil and the boolean can be true at -// the same time. -func (c *Cache) Get(key string) (any, bool) { - if c == nil || c.isClosed.Load() { - return nil, false - } - keyHash, conflictHash := c.keyToHash(key) - c.getBuf.Push(keyHash) - value, ok := c.store.Get(keyHash, conflictHash) - if ok { - c.Metrics.add(hit, keyHash, 1) - } else { - c.Metrics.add(miss, keyHash, 1) - } - return value, ok -} - -// Set attempts to add the key-value item to the cache. If it returns false, -// then the Set was dropped and the key-value item isn't added to the cache. If -// it returns true, there's still a chance it could be dropped by the policy if -// its determined that the key-value item isn't worth keeping, but otherwise the -// item will be added and other items will be evicted in order to make room. -// -// The cost of the entry will be evaluated lazily by the cache's Cost function. -func (c *Cache) Set(key string, value any) bool { - return c.SetWithCost(key, value, 0) -} - -// SetWithCost works like Set but adds a key-value pair to the cache with a specific -// cost. The built-in Cost function will not be called to evaluate the object's cost -// and instead the given value will be used. -func (c *Cache) SetWithCost(key string, value any, cost int64) bool { - if c == nil || c.isClosed.Load() { - return false - } - - keyHash, conflictHash := c.keyToHash(key) - i := &Item{ - flag: itemNew, - Key: keyHash, - Conflict: conflictHash, - Value: value, - Cost: cost, - } - // cost is eventually updated. The expiration must also be immediately updated - // to prevent items from being prematurely removed from the map. - if prev, ok := c.store.Update(i); ok { - c.onExit(prev) - i.flag = itemUpdate - } - // Attempt to send item to policy. - select { - case c.setBuf <- i: - return true - default: - if i.flag == itemUpdate { - // Return true if this was an update operation since we've already - // updated the store. For all the other operations (set/delete), we - // return false which means the item was not inserted. - return true - } - c.Metrics.add(dropSets, keyHash, 1) - return false - } -} - -// Delete deletes the key-value item from the cache if it exists. -func (c *Cache) Delete(key string) { - if c == nil || c.isClosed.Load() { - return - } - keyHash, conflictHash := c.keyToHash(key) - // Delete immediately. - _, prev := c.store.Del(keyHash, conflictHash) - c.onExit(prev) - // If we've set an item, it would be applied slightly later. - // So we must push the same item to `setBuf` with the deletion flag. - // This ensures that if a set is followed by a delete, it will be - // applied in the correct order. - c.setBuf <- &Item{ - flag: itemDelete, - Key: keyHash, - Conflict: conflictHash, - } -} - -// Close stops all goroutines and closes all channels. -func (c *Cache) Close() { - if c == nil { - return - } - wasClosed := c.isClosed.Swap(true) - if wasClosed { - return - } - c.Clear() - - // Block until processItems goroutine is returned. - c.stop <- struct{}{} - close(c.stop) - close(c.setBuf) - c.policy.Close() - c.isClosed.Store(true) -} - -// Clear empties the hashmap and zeroes all policy counters. Note that this is -// not an atomic operation (but that shouldn't be a problem as it's assumed that -// Set/Get calls won't be occurring until after this). -func (c *Cache) Clear() { - if c == nil || c.isClosed.Load() { - return - } - // Block until processItems goroutine is returned. - c.stop <- struct{}{} - - // Clear out the setBuf channel. -loop: - for { - select { - case i := <-c.setBuf: - if i.wg != nil { - i.wg.Done() - continue - } - if i.flag != itemUpdate { - // In itemUpdate, the value is already set in the store. So, no need to call - // onEvict here. - c.onEvict(i) - } - default: - break loop - } - } - - // Clear value hashmap and policy data. - c.policy.Clear() - c.store.Clear(c.onEvict) - // Only reset metrics if they're enabled. - if c.Metrics != nil { - c.Metrics.Clear() - } - // Restart processItems goroutine. - go c.processItems() -} - -// Len returns the size of the cache (in entries) -func (c *Cache) Len() int { - if c == nil { - return 0 - } - return c.store.Len() -} - -// UsedCapacity returns the size of the cache (in bytes) -func (c *Cache) UsedCapacity() int64 { - if c == nil { - return 0 - } - return c.policy.Used() -} - -// MaxCapacity returns the max cost of the cache (in bytes) -func (c *Cache) MaxCapacity() int64 { - if c == nil { - return 0 - } - return c.policy.MaxCost() -} - -// SetCapacity updates the maxCost of an existing cache. -func (c *Cache) SetCapacity(maxCost int64) { - if c == nil { - return - } - c.policy.UpdateMaxCost(maxCost) -} - -// Evictions returns the number of evictions -func (c *Cache) Evictions() int64 { - // TODO - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.KeysEvicted()) -} - -// Hits returns the number of cache hits -func (c *Cache) Hits() int64 { - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.Hits()) -} - -// Misses returns the number of cache misses -func (c *Cache) Misses() int64 { - if c == nil || c.Metrics == nil { - return 0 - } - return int64(c.Metrics.Misses()) -} - -// ForEach yields all the values currently stored in the cache to the given callback. -// The callback may return `false` to stop the iteration early. -func (c *Cache) ForEach(forEach func(any) bool) { - if c == nil { - return - } - c.store.ForEach(forEach) -} - -// processItems is ran by goroutines processing the Set buffer. -func (c *Cache) processItems() { - startTs := make(map[uint64]time.Time) - numToKeep := 100000 // TODO: Make this configurable via options. - - trackAdmission := func(key uint64) { - if c.Metrics == nil { - return - } - startTs[key] = time.Now() - if len(startTs) > numToKeep { - for k := range startTs { - if len(startTs) <= numToKeep { - break - } - delete(startTs, k) - } - } - } - onEvict := func(i *Item) { - delete(startTs, i.Key) - if c.onEvict != nil { - c.onEvict(i) - } - } - - for { - select { - case i := <-c.setBuf: - if i.wg != nil { - i.wg.Done() - continue - } - // Calculate item cost value if new or update. - if i.Cost == 0 && c.cost != nil && i.flag != itemDelete { - i.Cost = c.cost(i.Value) - } - if !c.ignoreInternalCost { - // Add the cost of internally storing the object. - i.Cost += CacheItemSize - } - - switch i.flag { - case itemNew: - victims, added := c.policy.Add(i.Key, i.Cost) - if added { - c.store.Set(i) - c.Metrics.add(keyAdd, i.Key, 1) - trackAdmission(i.Key) - } else { - c.onReject(i) - } - for _, victim := range victims { - victim.Conflict, victim.Value = c.store.Del(victim.Key, 0) - onEvict(victim) - } - - case itemUpdate: - c.policy.Update(i.Key, i.Cost) - - case itemDelete: - c.policy.Del(i.Key) // Deals with metrics updates. - _, val := c.store.Del(i.Key, i.Conflict) - c.onExit(val) - } - case <-c.stop: - return - } - } -} - -// collectMetrics just creates a new *Metrics instance and adds the pointers -// to the cache and policy instances. -func (c *Cache) collectMetrics() { - c.Metrics = newMetrics() - c.policy.CollectMetrics(c.Metrics) -} - -type metricType int - -const ( - // The following 2 keep track of hits and misses. - hit = iota - miss - // The following 3 keep track of number of keys added, updated and evicted. - keyAdd - keyUpdate - keyEvict - // The following 2 keep track of cost of keys added and evicted. - costAdd - costEvict - // The following keep track of how many sets were dropped or rejected later. - dropSets - rejectSets - // The following 2 keep track of how many gets were kept and dropped on the - // floor. - dropGets - keepGets - // This should be the final enum. Other enums should be set before this. - doNotUse -) - -func stringFor(t metricType) string { - switch t { - case hit: - return "hit" - case miss: - return "miss" - case keyAdd: - return "keys-added" - case keyUpdate: - return "keys-updated" - case keyEvict: - return "keys-evicted" - case costAdd: - return "cost-added" - case costEvict: - return "cost-evicted" - case dropSets: - return "sets-dropped" - case rejectSets: - return "sets-rejected" // by policy. - case dropGets: - return "gets-dropped" - case keepGets: - return "gets-kept" - default: - return "unidentified" - } -} - -// Metrics is a snapshot of performance statistics for the lifetime of a cache instance. -type Metrics struct { - all [doNotUse][]*uint64 -} - -func newMetrics() *Metrics { - s := &Metrics{} - for i := 0; i < doNotUse; i++ { - s.all[i] = make([]*uint64, 256) - slice := s.all[i] - for j := range slice { - slice[j] = new(uint64) - } - } - return s -} - -func (p *Metrics) add(t metricType, hash, delta uint64) { - if p == nil { - return - } - valp := p.all[t] - // Avoid false sharing by padding at least 64 bytes of space between two - // atomic counters which would be incremented. - idx := (hash % 25) * 10 - atomic.AddUint64(valp[idx], delta) -} - -func (p *Metrics) get(t metricType) uint64 { - if p == nil { - return 0 - } - valp := p.all[t] - var total uint64 - for i := range valp { - total += atomic.LoadUint64(valp[i]) - } - return total -} - -// Hits is the number of Get calls where a value was found for the corresponding key. -func (p *Metrics) Hits() uint64 { - return p.get(hit) -} - -// Misses is the number of Get calls where a value was not found for the corresponding key. -func (p *Metrics) Misses() uint64 { - return p.get(miss) -} - -// KeysAdded is the total number of Set calls where a new key-value item was added. -func (p *Metrics) KeysAdded() uint64 { - return p.get(keyAdd) -} - -// KeysUpdated is the total number of Set calls where the value was updated. -func (p *Metrics) KeysUpdated() uint64 { - return p.get(keyUpdate) -} - -// KeysEvicted is the total number of keys evicted. -func (p *Metrics) KeysEvicted() uint64 { - return p.get(keyEvict) -} - -// CostAdded is the sum of costs that have been added (successful Set calls). -func (p *Metrics) CostAdded() uint64 { - return p.get(costAdd) -} - -// CostEvicted is the sum of all costs that have been evicted. -func (p *Metrics) CostEvicted() uint64 { - return p.get(costEvict) -} - -// SetsDropped is the number of Set calls that don't make it into internal -// buffers (due to contention or some other reason). -func (p *Metrics) SetsDropped() uint64 { - return p.get(dropSets) -} - -// SetsRejected is the number of Set calls rejected by the policy (TinyLFU). -func (p *Metrics) SetsRejected() uint64 { - return p.get(rejectSets) -} - -// GetsDropped is the number of Get counter increments that are dropped -// internally. -func (p *Metrics) GetsDropped() uint64 { - return p.get(dropGets) -} - -// GetsKept is the number of Get counter increments that are kept. -func (p *Metrics) GetsKept() uint64 { - return p.get(keepGets) -} - -// Ratio is the number of Hits over all accesses (Hits + Misses). This is the -// percentage of successful Get calls. -func (p *Metrics) Ratio() float64 { - if p == nil { - return 0.0 - } - hits, misses := p.get(hit), p.get(miss) - if hits == 0 && misses == 0 { - return 0.0 - } - return float64(hits) / float64(hits+misses) -} - -// Clear resets all the metrics. -func (p *Metrics) Clear() { - if p == nil { - return - } - for i := 0; i < doNotUse; i++ { - for j := range p.all[i] { - atomic.StoreUint64(p.all[i][j], 0) - } - } -} - -// String returns a string representation of the metrics. -func (p *Metrics) String() string { - if p == nil { - return "" - } - var buf bytes.Buffer - for i := 0; i < doNotUse; i++ { - t := metricType(i) - fmt.Fprintf(&buf, "%s: %d ", stringFor(t), p.get(t)) - } - fmt.Fprintf(&buf, "gets-total: %d ", p.get(hit)+p.get(miss)) - fmt.Fprintf(&buf, "hit-ratio: %.2f", p.Ratio()) - return buf.String() -} diff --git a/go/cache/ristretto/cache_test.go b/go/cache/ristretto/cache_test.go deleted file mode 100644 index eda9f9109f3..00000000000 --- a/go/cache/ristretto/cache_test.go +++ /dev/null @@ -1,690 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "fmt" - "math/rand" - "strconv" - "strings" - "sync" - "testing" - "time" - - "vitess.io/vitess/go/vt/log" - - "github.com/stretchr/testify/require" -) - -var wait = time.Millisecond * 10 - -func TestCacheKeyToHash(t *testing.T) { - keyToHashCount := 0 - c, err := NewCache(&Config{ - NumCounters: 10, - MaxCost: 1000, - BufferItems: 64, - IgnoreInternalCost: true, - KeyToHash: func(key string) (uint64, uint64) { - keyToHashCount++ - return defaultStringHash(key) - }, - }) - require.NoError(t, err) - if c.SetWithCost("1", 1, 1) { - time.Sleep(wait) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - c.Delete("1") - } - require.Equal(t, 3, keyToHashCount) -} - -func TestCacheMaxCost(t *testing.T) { - charset := "abcdefghijklmnopqrstuvwxyz0123456789" - key := func() string { - k := make([]byte, 2) - for i := range k { - k[i] = charset[rand.Intn(len(charset))] - } - return string(k) - } - c, err := NewCache(&Config{ - NumCounters: 12960, // 36^2 * 10 - MaxCost: 1e6, // 1mb - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - stop := make(chan struct{}, 8) - for i := 0; i < 8; i++ { - go func() { - for { - select { - case <-stop: - return - default: - time.Sleep(time.Millisecond) - - k := key() - if _, ok := c.Get(k); !ok { - val := "" - if rand.Intn(100) < 10 { - val = "test" - } else { - val = strings.Repeat("a", 1000) - } - c.SetWithCost(key(), val, int64(2+len(val))) - } - } - } - }() - } - for i := 0; i < 20; i++ { - time.Sleep(time.Second) - cacheCost := c.Metrics.CostAdded() - c.Metrics.CostEvicted() - log.Infof("total cache cost: %d", cacheCost) - require.True(t, float64(cacheCost) <= float64(1e6*1.05)) - } - for i := 0; i < 8; i++ { - stop <- struct{}{} - } -} - -func TestUpdateMaxCost(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 10, - MaxCost: 10, - BufferItems: 64, - }) - require.NoError(t, err) - require.Equal(t, int64(10), c.MaxCapacity()) - require.True(t, c.SetWithCost("1", 1, 1)) - time.Sleep(wait) - _, ok := c.Get("1") - // Set is rejected because the cost of the entry is too high - // when accounting for the internal cost of storing the entry. - require.False(t, ok) - - // Update the max cost of the cache and retry. - c.SetCapacity(1000) - require.Equal(t, int64(1000), c.MaxCapacity()) - require.True(t, c.SetWithCost("1", 1, 1)) - time.Sleep(wait) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - c.Delete("1") -} - -func TestNewCache(t *testing.T) { - _, err := NewCache(&Config{ - NumCounters: 0, - }) - require.Error(t, err) - - _, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 0, - }) - require.Error(t, err) - - _, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 0, - }) - require.Error(t, err) - - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - require.NotNil(t, c) -} - -func TestNilCache(t *testing.T) { - var c *Cache - val, ok := c.Get("1") - require.False(t, ok) - require.Nil(t, val) - - require.False(t, c.SetWithCost("1", 1, 1)) - c.Delete("1") - c.Clear() - c.Close() -} - -func TestMultipleClose(t *testing.T) { - var c *Cache - c.Close() - - var err error - c, err = NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - c.Close() - c.Close() -} - -func TestSetAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - c.Close() - require.False(t, c.SetWithCost("1", 1, 1)) -} - -func TestClearAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - c.Close() - c.Clear() -} - -func TestGetAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - require.True(t, c.SetWithCost("1", 1, 1)) - c.Close() - - _, ok := c.Get("2") - require.False(t, ok) -} - -func TestDelAfterClose(t *testing.T) { - c, err := newTestCache() - require.NoError(t, err) - require.NotNil(t, c) - - require.True(t, c.SetWithCost("1", 1, 1)) - c.Close() - - c.Delete("1") -} - -func TestCacheProcessItems(t *testing.T) { - m := &sync.Mutex{} - evicted := make(map[uint64]struct{}) - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - IgnoreInternalCost: true, - Cost: func(value any) int64 { - return int64(value.(int)) - }, - OnEvict: func(item *Item) { - m.Lock() - defer m.Unlock() - evicted[item.Key] = struct{}{} - }, - }) - require.NoError(t, err) - - var key uint64 - var conflict uint64 - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 1, - Cost: 0, - } - time.Sleep(wait) - require.True(t, c.policy.Has(key)) - require.Equal(t, int64(1), c.policy.Cost(key)) - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemUpdate, - Key: key, - Conflict: conflict, - Value: 2, - Cost: 0, - } - time.Sleep(wait) - require.Equal(t, int64(2), c.policy.Cost(key)) - - key, conflict = defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemDelete, - Key: key, - Conflict: conflict, - } - time.Sleep(wait) - key, conflict = defaultStringHash("1") - val, ok := c.store.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - require.False(t, c.policy.Has(1)) - - key, conflict = defaultStringHash("2") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 2, - Cost: 3, - } - key, conflict = defaultStringHash("3") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - Cost: 3, - } - key, conflict = defaultStringHash("4") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - Cost: 3, - } - key, conflict = defaultStringHash("5") - c.setBuf <- &Item{ - flag: itemNew, - Key: key, - Conflict: conflict, - Value: 3, - Cost: 5, - } - time.Sleep(wait) - m.Lock() - require.NotEqual(t, 0, len(evicted)) - m.Unlock() - - defer func() { - require.NotNil(t, recover()) - }() - c.Close() - c.setBuf <- &Item{flag: itemNew} -} - -func TestCacheGet(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - IgnoreInternalCost: true, - Metrics: true, - }) - require.NoError(t, err) - - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - c.store.Set(&i) - val, ok := c.Get("1") - require.True(t, ok) - require.NotNil(t, val) - - val, ok = c.Get("2") - require.False(t, ok) - require.Nil(t, val) - - // 0.5 and not 1.0 because we tried Getting each item twice - require.Equal(t, 0.5, c.Metrics.Ratio()) - - c = nil - val, ok = c.Get("0") - require.False(t, ok) - require.Nil(t, val) -} - -// retrySet calls SetWithCost until the item is accepted by the cache. -func retrySet(t *testing.T, c *Cache, key string, value int, cost int64) { - for { - if set := c.SetWithCost(key, value, cost); !set { - time.Sleep(wait) - continue - } - - time.Sleep(wait) - val, ok := c.Get(key) - require.True(t, ok) - require.NotNil(t, val) - require.Equal(t, value, val.(int)) - return - } -} - -func TestCacheSet(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - retrySet(t, c, "1", 1, 1) - - c.SetWithCost("1", 2, 2) - val, ok := c.store.Get(defaultStringHash("1")) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - c.stop <- struct{}{} - for i := 0; i < setBufSize; i++ { - key, conflict := defaultStringHash("1") - c.setBuf <- &Item{ - flag: itemUpdate, - Key: key, - Conflict: conflict, - Value: 1, - Cost: 1, - } - } - require.False(t, c.SetWithCost("2", 2, 1)) - require.Equal(t, uint64(1), c.Metrics.SetsDropped()) - close(c.setBuf) - close(c.stop) - - c = nil - require.False(t, c.SetWithCost("1", 1, 1)) -} - -func TestCacheInternalCost(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - // Get should return false because the cache's cost is too small to store the item - // when accounting for the internal cost. - c.SetWithCost("1", 1, 1) - time.Sleep(wait) - _, ok := c.Get("1") - require.False(t, ok) -} - -func TestCacheDel(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - }) - require.NoError(t, err) - - c.SetWithCost("1", 1, 1) - c.Delete("1") - // The deletes and sets are pushed through the setbuf. It might be possible - // that the delete is not processed before the following get is called. So - // wait for a millisecond for things to be processed. - time.Sleep(time.Millisecond) - val, ok := c.Get("1") - require.False(t, ok) - require.Nil(t, val) - - c = nil - defer func() { - require.Nil(t, recover()) - }() - c.Delete("1") -} - -func TestCacheClear(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - c.SetWithCost(strconv.Itoa(i), i, 1) - } - time.Sleep(wait) - require.Equal(t, uint64(10), c.Metrics.KeysAdded()) - - c.Clear() - require.Equal(t, uint64(0), c.Metrics.KeysAdded()) - - for i := 0; i < 10; i++ { - val, ok := c.Get(strconv.Itoa(i)) - require.False(t, ok) - require.Nil(t, val) - } -} - -func TestCacheMetrics(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - IgnoreInternalCost: true, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - c.SetWithCost(strconv.Itoa(i), i, 1) - } - time.Sleep(wait) - m := c.Metrics - require.Equal(t, uint64(10), m.KeysAdded()) -} - -func TestMetrics(t *testing.T) { - newMetrics() -} - -func TestNilMetrics(t *testing.T) { - var m *Metrics - for _, f := range []func() uint64{ - m.Hits, - m.Misses, - m.KeysAdded, - m.KeysEvicted, - m.CostEvicted, - m.SetsDropped, - m.SetsRejected, - m.GetsDropped, - m.GetsKept, - } { - require.Equal(t, uint64(0), f()) - } -} - -func TestMetricsAddGet(t *testing.T) { - m := newMetrics() - m.add(hit, 1, 1) - m.add(hit, 2, 2) - m.add(hit, 3, 3) - require.Equal(t, uint64(6), m.Hits()) - - m = nil - m.add(hit, 1, 1) - require.Equal(t, uint64(0), m.Hits()) -} - -func TestMetricsRatio(t *testing.T) { - m := newMetrics() - require.Equal(t, float64(0), m.Ratio()) - - m.add(hit, 1, 1) - m.add(hit, 2, 2) - m.add(miss, 1, 1) - m.add(miss, 2, 2) - require.Equal(t, 0.5, m.Ratio()) - - m = nil - require.Equal(t, float64(0), m.Ratio()) -} - -func TestMetricsString(t *testing.T) { - m := newMetrics() - m.add(hit, 1, 1) - m.add(miss, 1, 1) - m.add(keyAdd, 1, 1) - m.add(keyUpdate, 1, 1) - m.add(keyEvict, 1, 1) - m.add(costAdd, 1, 1) - m.add(costEvict, 1, 1) - m.add(dropSets, 1, 1) - m.add(rejectSets, 1, 1) - m.add(dropGets, 1, 1) - m.add(keepGets, 1, 1) - require.Equal(t, uint64(1), m.Hits()) - require.Equal(t, uint64(1), m.Misses()) - require.Equal(t, 0.5, m.Ratio()) - require.Equal(t, uint64(1), m.KeysAdded()) - require.Equal(t, uint64(1), m.KeysUpdated()) - require.Equal(t, uint64(1), m.KeysEvicted()) - require.Equal(t, uint64(1), m.CostAdded()) - require.Equal(t, uint64(1), m.CostEvicted()) - require.Equal(t, uint64(1), m.SetsDropped()) - require.Equal(t, uint64(1), m.SetsRejected()) - require.Equal(t, uint64(1), m.GetsDropped()) - require.Equal(t, uint64(1), m.GetsKept()) - - require.NotEqual(t, 0, len(m.String())) - - m = nil - require.Equal(t, 0, len(m.String())) - - require.Equal(t, "unidentified", stringFor(doNotUse)) -} - -func TestCacheMetricsClear(t *testing.T) { - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) - require.NoError(t, err) - - c.SetWithCost("1", 1, 1) - stop := make(chan struct{}) - go func() { - for { - select { - case <-stop: - return - default: - c.Get("1") - } - } - }() - time.Sleep(wait) - c.Clear() - stop <- struct{}{} - c.Metrics = nil - c.Metrics.Clear() -} - -// Regression test for bug https://github.com/dgraph-io/ristretto/issues/167 -func TestDropUpdates(t *testing.T) { - originalSetBugSize := setBufSize - defer func() { setBufSize = originalSetBugSize }() - - test := func() { - // dropppedMap stores the items dropped from the cache. - droppedMap := make(map[int]struct{}) - lastEvictedSet := int64(-1) - - var err error - handler := func(_ any, value any) { - v := value.(string) - lastEvictedSet, err = strconv.ParseInt(string(v), 10, 32) - require.NoError(t, err) - - _, ok := droppedMap[int(lastEvictedSet)] - if ok { - panic(fmt.Sprintf("val = %+v was dropped but it got evicted. Dropped items: %+v\n", - lastEvictedSet, droppedMap)) - } - } - - // This is important. The race condition shows up only when the setBuf - // is full and that's why we reduce the buf size here. The test will - // try to fill up the setbuf to it's capacity and then perform an - // update on a key. - setBufSize = 10 - - c, err := NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - OnEvict: func(item *Item) { - if item.Value != nil { - handler(nil, item.Value) - } - }, - }) - require.NoError(t, err) - - for i := 0; i < 5*setBufSize; i++ { - v := fmt.Sprintf("%0100d", i) - // We're updating the same key. - if !c.SetWithCost("0", v, 1) { - // The race condition doesn't show up without this sleep. - time.Sleep(time.Microsecond) - droppedMap[i] = struct{}{} - } - } - // Wait for all the items to be processed. - c.Wait() - // This will cause eviction from the cache. - require.True(t, c.SetWithCost("1", nil, 10)) - c.Close() - } - - // Run the test 100 times since it's not reliable. - for i := 0; i < 100; i++ { - test() - } -} - -func newTestCache() (*Cache, error) { - return NewCache(&Config{ - NumCounters: 100, - MaxCost: 10, - BufferItems: 64, - Metrics: true, - }) -} diff --git a/go/cache/ristretto/policy.go b/go/cache/ristretto/policy.go deleted file mode 100644 index 84cc008cb99..00000000000 --- a/go/cache/ristretto/policy.go +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "math" - "sync" - "sync/atomic" - - "vitess.io/vitess/go/cache/ristretto/bloom" -) - -const ( - // lfuSample is the number of items to sample when looking at eviction - // candidates. 5 seems to be the most optimal number [citation needed]. - lfuSample = 5 -) - -// policy is the interface encapsulating eviction/admission behavior. -// -// TODO: remove this interface and just rename defaultPolicy to policy, as we -// -// are probably only going to use/implement/maintain one policy. -type policy interface { - ringConsumer - // Add attempts to Add the key-cost pair to the Policy. It returns a slice - // of evicted keys and a bool denoting whether or not the key-cost pair - // was added. If it returns true, the key should be stored in cache. - Add(uint64, int64) ([]*Item, bool) - // Has returns true if the key exists in the Policy. - Has(uint64) bool - // Del deletes the key from the Policy. - Del(uint64) - // Cap returns the amount of used capacity. - Used() int64 - // Close stops all goroutines and closes all channels. - Close() - // Update updates the cost value for the key. - Update(uint64, int64) - // Cost returns the cost value of a key or -1 if missing. - Cost(uint64) int64 - // Optionally, set stats object to track how policy is performing. - CollectMetrics(*Metrics) - // Clear zeroes out all counters and clears hashmaps. - Clear() - // MaxCost returns the current max cost of the cache policy. - MaxCost() int64 - // UpdateMaxCost updates the max cost of the cache policy. - UpdateMaxCost(int64) -} - -func newPolicy(numCounters, maxCost int64) policy { - return newDefaultPolicy(numCounters, maxCost) -} - -type defaultPolicy struct { - sync.Mutex - admit *tinyLFU - evict *sampledLFU - itemsCh chan []uint64 - stop chan struct{} - isClosed bool - metrics *Metrics - numCounters int64 - maxCost int64 -} - -func newDefaultPolicy(numCounters, maxCost int64) *defaultPolicy { - p := &defaultPolicy{ - admit: newTinyLFU(numCounters), - evict: newSampledLFU(maxCost), - itemsCh: make(chan []uint64, 3), - stop: make(chan struct{}), - numCounters: numCounters, - maxCost: maxCost, - } - go p.processItems() - return p -} - -func (p *defaultPolicy) CollectMetrics(metrics *Metrics) { - p.metrics = metrics - p.evict.metrics = metrics -} - -type policyPair struct { - key uint64 - cost int64 -} - -func (p *defaultPolicy) processItems() { - for { - select { - case items := <-p.itemsCh: - p.Lock() - p.admit.Push(items) - p.Unlock() - case <-p.stop: - return - } - } -} - -func (p *defaultPolicy) Push(keys []uint64) bool { - if p.isClosed { - return false - } - - if len(keys) == 0 { - return true - } - - select { - case p.itemsCh <- keys: - p.metrics.add(keepGets, keys[0], uint64(len(keys))) - return true - default: - p.metrics.add(dropGets, keys[0], uint64(len(keys))) - return false - } -} - -// Add decides whether the item with the given key and cost should be accepted by -// the policy. It returns the list of victims that have been evicted and a boolean -// indicating whether the incoming item should be accepted. -func (p *defaultPolicy) Add(key uint64, cost int64) ([]*Item, bool) { - p.Lock() - defer p.Unlock() - - // Cannot add an item bigger than entire cache. - if cost > p.evict.getMaxCost() { - return nil, false - } - - // No need to go any further if the item is already in the cache. - if has := p.evict.updateIfHas(key, cost); has { - // An update does not count as an addition, so return false. - return nil, false - } - - // If the execution reaches this point, the key doesn't exist in the cache. - // Calculate the remaining room in the cache (usually bytes). - room := p.evict.roomLeft(cost) - if room >= 0 { - // There's enough room in the cache to store the new item without - // overflowing. Do that now and stop here. - p.evict.add(key, cost) - p.metrics.add(costAdd, key, uint64(cost)) - return nil, true - } - - // incHits is the hit count for the incoming item. - incHits := p.admit.Estimate(key) - // sample is the eviction candidate pool to be filled via random sampling. - // TODO: perhaps we should use a min heap here. Right now our time - // complexity is N for finding the min. Min heap should bring it down to - // O(lg N). - sample := make([]*policyPair, 0, lfuSample) - // As items are evicted they will be appended to victims. - victims := make([]*Item, 0) - - // Delete victims until there's enough space or a minKey is found that has - // more hits than incoming item. - for ; room < 0; room = p.evict.roomLeft(cost) { - // Fill up empty slots in sample. - sample = p.evict.fillSample(sample) - - // Find minimally used item in sample. - minKey, minHits, minID, minCost := uint64(0), int64(math.MaxInt64), 0, int64(0) - for i, pair := range sample { - // Look up hit count for sample key. - if hits := p.admit.Estimate(pair.key); hits < minHits { - minKey, minHits, minID, minCost = pair.key, hits, i, pair.cost - } - } - - // If the incoming item isn't worth keeping in the policy, reject. - if incHits < minHits { - p.metrics.add(rejectSets, key, 1) - return victims, false - } - - // Delete the victim from metadata. - p.evict.del(minKey) - - // Delete the victim from sample. - sample[minID] = sample[len(sample)-1] - sample = sample[:len(sample)-1] - // Store victim in evicted victims slice. - victims = append(victims, &Item{ - Key: minKey, - Conflict: 0, - Cost: minCost, - }) - } - - p.evict.add(key, cost) - p.metrics.add(costAdd, key, uint64(cost)) - return victims, true -} - -func (p *defaultPolicy) Has(key uint64) bool { - p.Lock() - _, exists := p.evict.keyCosts[key] - p.Unlock() - return exists -} - -func (p *defaultPolicy) Del(key uint64) { - p.Lock() - p.evict.del(key) - p.Unlock() -} - -func (p *defaultPolicy) Used() int64 { - p.Lock() - used := p.evict.used - p.Unlock() - return used -} - -func (p *defaultPolicy) Update(key uint64, cost int64) { - p.Lock() - p.evict.updateIfHas(key, cost) - p.Unlock() -} - -func (p *defaultPolicy) Cost(key uint64) int64 { - p.Lock() - if cost, found := p.evict.keyCosts[key]; found { - p.Unlock() - return cost - } - p.Unlock() - return -1 -} - -func (p *defaultPolicy) Clear() { - p.Lock() - p.admit = newTinyLFU(p.numCounters) - p.evict = newSampledLFU(p.maxCost) - p.Unlock() -} - -func (p *defaultPolicy) Close() { - if p.isClosed { - return - } - - // Block until the p.processItems goroutine returns. - p.stop <- struct{}{} - close(p.stop) - close(p.itemsCh) - p.isClosed = true -} - -func (p *defaultPolicy) MaxCost() int64 { - if p == nil || p.evict == nil { - return 0 - } - return p.evict.getMaxCost() -} - -func (p *defaultPolicy) UpdateMaxCost(maxCost int64) { - if p == nil || p.evict == nil { - return - } - p.evict.updateMaxCost(maxCost) -} - -// sampledLFU is an eviction helper storing key-cost pairs. -type sampledLFU struct { - keyCosts map[uint64]int64 - maxCost int64 - used int64 - metrics *Metrics -} - -func newSampledLFU(maxCost int64) *sampledLFU { - return &sampledLFU{ - keyCosts: make(map[uint64]int64), - maxCost: maxCost, - } -} - -func (p *sampledLFU) getMaxCost() int64 { - return atomic.LoadInt64(&p.maxCost) -} - -func (p *sampledLFU) updateMaxCost(maxCost int64) { - atomic.StoreInt64(&p.maxCost, maxCost) -} - -func (p *sampledLFU) roomLeft(cost int64) int64 { - return p.getMaxCost() - (p.used + cost) -} - -func (p *sampledLFU) fillSample(in []*policyPair) []*policyPair { - if len(in) >= lfuSample { - return in - } - for key, cost := range p.keyCosts { - in = append(in, &policyPair{key, cost}) - if len(in) >= lfuSample { - return in - } - } - return in -} - -func (p *sampledLFU) del(key uint64) { - cost, ok := p.keyCosts[key] - if !ok { - return - } - p.used -= cost - delete(p.keyCosts, key) - p.metrics.add(costEvict, key, uint64(cost)) - p.metrics.add(keyEvict, key, 1) -} - -func (p *sampledLFU) add(key uint64, cost int64) { - p.keyCosts[key] = cost - p.used += cost -} - -func (p *sampledLFU) updateIfHas(key uint64, cost int64) bool { - if prev, found := p.keyCosts[key]; found { - // Update the cost of an existing key, but don't worry about evicting. - // Evictions will be handled the next time a new item is added. - p.metrics.add(keyUpdate, key, 1) - if prev > cost { - diff := prev - cost - p.metrics.add(costAdd, key, ^uint64(uint64(diff)-1)) - } else if cost > prev { - diff := cost - prev - p.metrics.add(costAdd, key, uint64(diff)) - } - p.used += cost - prev - p.keyCosts[key] = cost - return true - } - return false -} - -func (p *sampledLFU) clear() { - p.used = 0 - p.keyCosts = make(map[uint64]int64) -} - -// tinyLFU is an admission helper that keeps track of access frequency using -// tiny (4-bit) counters in the form of a count-min sketch. -// tinyLFU is NOT thread safe. -type tinyLFU struct { - freq *cmSketch - door *bloom.Bloom - incrs int64 - resetAt int64 -} - -func newTinyLFU(numCounters int64) *tinyLFU { - return &tinyLFU{ - freq: newCmSketch(numCounters), - door: bloom.NewBloomFilterWithErrorRate(uint64(numCounters), 0.01), - resetAt: numCounters, - } -} - -func (p *tinyLFU) Push(keys []uint64) { - for _, key := range keys { - p.Increment(key) - } -} - -func (p *tinyLFU) Estimate(key uint64) int64 { - hits := p.freq.Estimate(key) - if p.door.Has(key) { - hits++ - } - return hits -} - -func (p *tinyLFU) Increment(key uint64) { - // Flip doorkeeper bit if not already done. - if added := p.door.AddIfNotHas(key); !added { - // Increment count-min counter if doorkeeper bit is already set. - p.freq.Increment(key) - } - p.incrs++ - if p.incrs >= p.resetAt { - p.reset() - } -} - -func (p *tinyLFU) reset() { - // Zero out incrs. - p.incrs = 0 - // clears doorkeeper bits - p.door.Clear() - // halves count-min counters - p.freq.Reset() -} - -func (p *tinyLFU) clear() { - p.incrs = 0 - p.freq.Clear() - p.door.Clear() -} diff --git a/go/cache/ristretto/policy_test.go b/go/cache/ristretto/policy_test.go deleted file mode 100644 index c864b6c74d0..00000000000 --- a/go/cache/ristretto/policy_test.go +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestPolicy(t *testing.T) { - defer func() { - require.Nil(t, recover()) - }() - newPolicy(100, 10) -} - -func TestPolicyMetrics(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.CollectMetrics(newMetrics()) - require.NotNil(t, p.metrics) - require.NotNil(t, p.evict.metrics) -} - -func TestPolicyProcessItems(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.itemsCh <- []uint64{1, 2, 2} - time.Sleep(wait) - p.Lock() - require.Equal(t, int64(2), p.admit.Estimate(2)) - require.Equal(t, int64(1), p.admit.Estimate(1)) - p.Unlock() - - p.stop <- struct{}{} - p.itemsCh <- []uint64{3, 3, 3} - time.Sleep(wait) - p.Lock() - require.Equal(t, int64(0), p.admit.Estimate(3)) - p.Unlock() -} - -func TestPolicyPush(t *testing.T) { - p := newDefaultPolicy(100, 10) - require.True(t, p.Push([]uint64{})) - - keepCount := 0 - for i := 0; i < 10; i++ { - if p.Push([]uint64{1, 2, 3, 4, 5}) { - keepCount++ - } - } - require.NotEqual(t, 0, keepCount) -} - -func TestPolicyAdd(t *testing.T) { - p := newDefaultPolicy(1000, 100) - if victims, added := p.Add(1, 101); victims != nil || added { - t.Fatal("can't add an item bigger than entire cache") - } - p.Lock() - p.evict.add(1, 1) - p.admit.Increment(1) - p.admit.Increment(2) - p.admit.Increment(3) - p.Unlock() - - victims, added := p.Add(1, 1) - require.Nil(t, victims) - require.False(t, added) - - victims, added = p.Add(2, 20) - require.Nil(t, victims) - require.True(t, added) - - victims, added = p.Add(3, 90) - require.NotNil(t, victims) - require.True(t, added) - - victims, added = p.Add(4, 20) - require.NotNil(t, victims) - require.False(t, added) -} - -func TestPolicyHas(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - require.True(t, p.Has(1)) - require.False(t, p.Has(2)) -} - -func TestPolicyDel(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Del(1) - p.Del(2) - require.False(t, p.Has(1)) - require.False(t, p.Has(2)) -} - -func TestPolicyCap(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - require.Equal(t, int64(9), p.MaxCost()-p.Used()) -} - -func TestPolicyUpdate(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Update(1, 2) - p.Lock() - require.Equal(t, int64(2), p.evict.keyCosts[1]) - p.Unlock() -} - -func TestPolicyCost(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 2) - require.Equal(t, int64(2), p.Cost(1)) - require.Equal(t, int64(-1), p.Cost(2)) -} - -func TestPolicyClear(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Add(2, 2) - p.Add(3, 3) - p.Clear() - require.Equal(t, int64(10), p.MaxCost()-p.Used()) - require.False(t, p.Has(1)) - require.False(t, p.Has(2)) - require.False(t, p.Has(3)) -} - -func TestPolicyClose(t *testing.T) { - defer func() { - require.NotNil(t, recover()) - }() - - p := newDefaultPolicy(100, 10) - p.Add(1, 1) - p.Close() - p.itemsCh <- []uint64{1} -} - -func TestPushAfterClose(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Close() - require.False(t, p.Push([]uint64{1, 2})) -} - -func TestAddAfterClose(t *testing.T) { - p := newDefaultPolicy(100, 10) - p.Close() - p.Add(1, 1) -} - -func TestSampledLFUAdd(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.add(3, 1) - require.Equal(t, int64(4), e.used) - require.Equal(t, int64(2), e.keyCosts[2]) -} - -func TestSampledLFUDel(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.del(2) - require.Equal(t, int64(1), e.used) - _, ok := e.keyCosts[2] - require.False(t, ok) - e.del(4) -} - -func TestSampledLFUUpdate(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - require.True(t, e.updateIfHas(1, 2)) - require.Equal(t, int64(2), e.used) - require.False(t, e.updateIfHas(2, 2)) -} - -func TestSampledLFUClear(t *testing.T) { - e := newSampledLFU(4) - e.add(1, 1) - e.add(2, 2) - e.add(3, 1) - e.clear() - require.Equal(t, 0, len(e.keyCosts)) - require.Equal(t, int64(0), e.used) -} - -func TestSampledLFURoom(t *testing.T) { - e := newSampledLFU(16) - e.add(1, 1) - e.add(2, 2) - e.add(3, 3) - require.Equal(t, int64(6), e.roomLeft(4)) -} - -func TestSampledLFUSample(t *testing.T) { - e := newSampledLFU(16) - e.add(4, 4) - e.add(5, 5) - sample := e.fillSample([]*policyPair{ - {1, 1}, - {2, 2}, - {3, 3}, - }) - k := sample[len(sample)-1].key - require.Equal(t, 5, len(sample)) - require.NotEqual(t, 1, k) - require.NotEqual(t, 2, k) - require.NotEqual(t, 3, k) - require.Equal(t, len(sample), len(e.fillSample(sample))) - e.del(5) - sample = e.fillSample(sample[:len(sample)-2]) - require.Equal(t, 4, len(sample)) -} - -func TestTinyLFUIncrement(t *testing.T) { - a := newTinyLFU(4) - a.Increment(1) - a.Increment(1) - a.Increment(1) - require.True(t, a.door.Has(1)) - require.Equal(t, int64(2), a.freq.Estimate(1)) - - a.Increment(1) - require.False(t, a.door.Has(1)) - require.Equal(t, int64(1), a.freq.Estimate(1)) -} - -func TestTinyLFUEstimate(t *testing.T) { - a := newTinyLFU(8) - a.Increment(1) - a.Increment(1) - a.Increment(1) - require.Equal(t, int64(3), a.Estimate(1)) - require.Equal(t, int64(0), a.Estimate(2)) -} - -func TestTinyLFUPush(t *testing.T) { - a := newTinyLFU(16) - a.Push([]uint64{1, 2, 2, 3, 3, 3}) - require.Equal(t, int64(1), a.Estimate(1)) - require.Equal(t, int64(2), a.Estimate(2)) - require.Equal(t, int64(3), a.Estimate(3)) - require.Equal(t, int64(6), a.incrs) -} - -func TestTinyLFUClear(t *testing.T) { - a := newTinyLFU(16) - a.Push([]uint64{1, 3, 3, 3}) - a.clear() - require.Equal(t, int64(0), a.incrs) - require.Equal(t, int64(0), a.Estimate(3)) -} diff --git a/go/cache/ristretto/ring.go b/go/cache/ristretto/ring.go deleted file mode 100644 index 84d8689ee37..00000000000 --- a/go/cache/ristretto/ring.go +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "sync" -) - -// ringConsumer is the user-defined object responsible for receiving and -// processing items in batches when buffers are drained. -type ringConsumer interface { - Push([]uint64) bool -} - -// ringStripe is a singular ring buffer that is not concurrent safe. -type ringStripe struct { - cons ringConsumer - data []uint64 - capa int -} - -func newRingStripe(cons ringConsumer, capa int64) *ringStripe { - return &ringStripe{ - cons: cons, - data: make([]uint64, 0, capa), - capa: int(capa), - } -} - -// Push appends an item in the ring buffer and drains (copies items and -// sends to Consumer) if full. -func (s *ringStripe) Push(item uint64) { - s.data = append(s.data, item) - // Decide if the ring buffer should be drained. - if len(s.data) >= s.capa { - // Send elements to consumer and create a new ring stripe. - if s.cons.Push(s.data) { - s.data = make([]uint64, 0, s.capa) - } else { - s.data = s.data[:0] - } - } -} - -// ringBuffer stores multiple buffers (stripes) and distributes Pushed items -// between them to lower contention. -// -// This implements the "batching" process described in the BP-Wrapper paper -// (section III part A). -type ringBuffer struct { - pool *sync.Pool -} - -// newRingBuffer returns a striped ring buffer. The Consumer in ringConfig will -// be called when individual stripes are full and need to drain their elements. -func newRingBuffer(cons ringConsumer, capa int64) *ringBuffer { - // LOSSY buffers use a very simple sync.Pool for concurrently reusing - // stripes. We do lose some stripes due to GC (unheld items in sync.Pool - // are cleared), but the performance gains generally outweigh the small - // percentage of elements lost. The performance primarily comes from - // low-level runtime functions used in the standard library that aren't - // available to us (such as runtime_procPin()). - return &ringBuffer{ - pool: &sync.Pool{ - New: func() any { return newRingStripe(cons, capa) }, - }, - } -} - -// Push adds an element to one of the internal stripes and possibly drains if -// the stripe becomes full. -func (b *ringBuffer) Push(item uint64) { - // Reuse or create a new stripe. - stripe := b.pool.Get().(*ringStripe) - stripe.Push(item) - b.pool.Put(stripe) -} diff --git a/go/cache/ristretto/ring_test.go b/go/cache/ristretto/ring_test.go deleted file mode 100644 index 0dbe962ccc6..00000000000 --- a/go/cache/ristretto/ring_test.go +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -type testConsumer struct { - push func([]uint64) - save bool -} - -func (c *testConsumer) Push(items []uint64) bool { - if c.save { - c.push(items) - return true - } - return false -} - -func TestRingDrain(t *testing.T) { - drains := 0 - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - drains++ - }, - save: true, - }, 1) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - require.Equal(t, 100, drains, "buffers shouldn't be dropped with BufferItems == 1") -} - -func TestRingReset(t *testing.T) { - drains := 0 - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - drains++ - }, - save: false, - }, 4) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - require.Equal(t, 0, drains, "testConsumer shouldn't be draining") -} - -func TestRingConsumer(t *testing.T) { - mu := &sync.Mutex{} - drainItems := make(map[uint64]struct{}) - r := newRingBuffer(&testConsumer{ - push: func(items []uint64) { - mu.Lock() - defer mu.Unlock() - for i := range items { - drainItems[items[i]] = struct{}{} - } - }, - save: true, - }, 4) - for i := 0; i < 100; i++ { - r.Push(uint64(i)) - } - l := len(drainItems) - require.NotEqual(t, 0, l) - require.True(t, l <= 100) -} diff --git a/go/cache/ristretto/sketch.go b/go/cache/ristretto/sketch.go deleted file mode 100644 index c8ad31e8494..00000000000 --- a/go/cache/ristretto/sketch.go +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package ristretto includes multiple probabalistic data structures needed for -// admission/eviction metadata. Most are Counting Bloom Filter variations, but -// a caching-specific feature that is also required is a "freshness" mechanism, -// which basically serves as a "lifetime" process. This freshness mechanism -// was described in the original TinyLFU paper [1], but other mechanisms may -// be better suited for certain data distributions. -// -// [1]: https://arxiv.org/abs/1512.00727 -package ristretto - -import ( - "fmt" - "math/rand" - "time" -) - -// cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily -// based on Damian Gryski's CM4 [1]. -// -// [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go -type cmSketch struct { - rows [cmDepth]cmRow - seed [cmDepth]uint64 - mask uint64 -} - -const ( - // cmDepth is the number of counter copies to store (think of it as rows). - cmDepth = 4 -) - -func newCmSketch(numCounters int64) *cmSketch { - if numCounters == 0 { - panic("cmSketch: bad numCounters") - } - // Get the next power of 2 for better cache performance. - numCounters = next2Power(numCounters) - sketch := &cmSketch{mask: uint64(numCounters - 1)} - // Initialize rows of counters and seeds. - source := rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < cmDepth; i++ { - sketch.seed[i] = source.Uint64() - sketch.rows[i] = newCmRow(numCounters) - } - return sketch -} - -// Increment increments the count(ers) for the specified key. -func (s *cmSketch) Increment(hashed uint64) { - for i := range s.rows { - s.rows[i].increment((hashed ^ s.seed[i]) & s.mask) - } -} - -// Estimate returns the value of the specified key. -func (s *cmSketch) Estimate(hashed uint64) int64 { - min := byte(255) - for i := range s.rows { - val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask) - if val < min { - min = val - } - } - return int64(min) -} - -// Reset halves all counter values. -func (s *cmSketch) Reset() { - for _, r := range s.rows { - r.reset() - } -} - -// Clear zeroes all counters. -func (s *cmSketch) Clear() { - for _, r := range s.rows { - r.clear() - } -} - -// cmRow is a row of bytes, with each byte holding two counters. -type cmRow []byte - -func newCmRow(numCounters int64) cmRow { - return make(cmRow, numCounters/2) -} - -func (r cmRow) get(n uint64) byte { - return byte(r[n/2]>>((n&1)*4)) & 0x0f -} - -func (r cmRow) increment(n uint64) { - // Index of the counter. - i := n / 2 - // Shift distance (even 0, odd 4). - s := (n & 1) * 4 - // Counter value. - v := (r[i] >> s) & 0x0f - // Only increment if not max value (overflow wrap is bad for LFU). - if v < 15 { - r[i] += 1 << s - } -} - -func (r cmRow) reset() { - // Halve each counter. - for i := range r { - r[i] = (r[i] >> 1) & 0x77 - } -} - -func (r cmRow) clear() { - // Zero each counter. - clear(r) -} - -func (r cmRow) string() string { - s := "" - for i := uint64(0); i < uint64(len(r)*2); i++ { - s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f) - } - s = s[:len(s)-1] - return s -} - -// next2Power rounds x up to the next power of 2, if it's not already one. -func next2Power(x int64) int64 { - x-- - x |= x >> 1 - x |= x >> 2 - x |= x >> 4 - x |= x >> 8 - x |= x >> 16 - x |= x >> 32 - x++ - return x -} diff --git a/go/cache/ristretto/sketch_test.go b/go/cache/ristretto/sketch_test.go deleted file mode 100644 index 03804a6d599..00000000000 --- a/go/cache/ristretto/sketch_test.go +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "testing" - - "vitess.io/vitess/go/vt/log" - - "github.com/stretchr/testify/require" -) - -func TestSketch(t *testing.T) { - defer func() { - require.NotNil(t, recover()) - }() - - s := newCmSketch(5) - require.Equal(t, uint64(7), s.mask) - newCmSketch(0) -} - -func TestSketchIncrement(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(5) - s.Increment(9) - for i := 0; i < cmDepth; i++ { - if s.rows[i].string() != s.rows[0].string() { - break - } - require.False(t, i == cmDepth-1, "identical rows, bad seeding") - } -} - -func TestSketchEstimate(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(1) - require.Equal(t, int64(2), s.Estimate(1)) - require.Equal(t, int64(0), s.Estimate(0)) -} - -func TestSketchReset(t *testing.T) { - s := newCmSketch(16) - s.Increment(1) - s.Increment(1) - s.Increment(1) - s.Increment(1) - s.Reset() - require.Equal(t, int64(2), s.Estimate(1)) -} - -func TestSketchClear(t *testing.T) { - s := newCmSketch(16) - for i := 0; i < 16; i++ { - s.Increment(uint64(i)) - } - s.Clear() - for i := 0; i < 16; i++ { - require.Equal(t, int64(0), s.Estimate(uint64(i))) - } -} - -func TestNext2Power(t *testing.T) { - sz := 12 << 30 - szf := float64(sz) * 0.01 - val := int64(szf) - log.Infof("szf = %.2f val = %d\n", szf, val) - pow := next2Power(val) - log.Infof("pow = %d. mult 4 = %d\n", pow, pow*4) -} - -func BenchmarkSketchIncrement(b *testing.B) { - s := newCmSketch(16) - b.SetBytes(1) - for n := 0; n < b.N; n++ { - s.Increment(1) - } -} - -func BenchmarkSketchEstimate(b *testing.B) { - s := newCmSketch(16) - s.Increment(1) - b.SetBytes(1) - for n := 0; n < b.N; n++ { - s.Estimate(1) - } -} diff --git a/go/cache/ristretto/store.go b/go/cache/ristretto/store.go deleted file mode 100644 index 0e455e7052f..00000000000 --- a/go/cache/ristretto/store.go +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "sync" -) - -// TODO: Do we need this to be a separate struct from Item? -type storeItem struct { - key uint64 - conflict uint64 - value any -} - -// store is the interface fulfilled by all hash map implementations in this -// file. Some hash map implementations are better suited for certain data -// distributions than others, so this allows us to abstract that out for use -// in Ristretto. -// -// Every store is safe for concurrent usage. -type store interface { - // Get returns the value associated with the key parameter. - Get(uint64, uint64) (any, bool) - // Set adds the key-value pair to the Map or updates the value if it's - // already present. The key-value pair is passed as a pointer to an - // item object. - Set(*Item) - // Del deletes the key-value pair from the Map. - Del(uint64, uint64) (uint64, any) - // Update attempts to update the key with a new value and returns true if - // successful. - Update(*Item) (any, bool) - // Clear clears all contents of the store. - Clear(onEvict itemCallback) - // ForEach yields all the values in the store - ForEach(forEach func(any) bool) - // Len returns the number of entries in the store - Len() int -} - -// newStore returns the default store implementation. -func newStore() store { - return newShardedMap() -} - -const numShards uint64 = 256 - -type shardedMap struct { - shards []*lockedMap -} - -func newShardedMap() *shardedMap { - sm := &shardedMap{ - shards: make([]*lockedMap, int(numShards)), - } - for i := range sm.shards { - sm.shards[i] = newLockedMap() - } - return sm -} - -func (sm *shardedMap) Get(key, conflict uint64) (any, bool) { - return sm.shards[key%numShards].get(key, conflict) -} - -func (sm *shardedMap) Set(i *Item) { - if i == nil { - // If item is nil make this Set a no-op. - return - } - - sm.shards[i.Key%numShards].Set(i) -} - -func (sm *shardedMap) Del(key, conflict uint64) (uint64, any) { - return sm.shards[key%numShards].Del(key, conflict) -} - -func (sm *shardedMap) Update(newItem *Item) (any, bool) { - return sm.shards[newItem.Key%numShards].Update(newItem) -} - -func (sm *shardedMap) ForEach(forEach func(any) bool) { - for _, shard := range sm.shards { - if !shard.foreach(forEach) { - break - } - } -} - -func (sm *shardedMap) Len() int { - l := 0 - for _, shard := range sm.shards { - l += shard.Len() - } - return l -} - -func (sm *shardedMap) Clear(onEvict itemCallback) { - for i := uint64(0); i < numShards; i++ { - sm.shards[i].Clear(onEvict) - } -} - -type lockedMap struct { - sync.RWMutex - data map[uint64]storeItem -} - -func newLockedMap() *lockedMap { - return &lockedMap{ - data: make(map[uint64]storeItem), - } -} - -func (m *lockedMap) get(key, conflict uint64) (any, bool) { - m.RLock() - item, ok := m.data[key] - m.RUnlock() - if !ok { - return nil, false - } - if conflict != 0 && (conflict != item.conflict) { - return nil, false - } - return item.value, true -} - -func (m *lockedMap) Set(i *Item) { - if i == nil { - // If the item is nil make this Set a no-op. - return - } - - m.Lock() - defer m.Unlock() - item, ok := m.data[i.Key] - - if ok { - // The item existed already. We need to check the conflict key and reject the - // update if they do not match. Only after that the expiration map is updated. - if i.Conflict != 0 && (i.Conflict != item.conflict) { - return - } - } - - m.data[i.Key] = storeItem{ - key: i.Key, - conflict: i.Conflict, - value: i.Value, - } -} - -func (m *lockedMap) Del(key, conflict uint64) (uint64, any) { - m.Lock() - item, ok := m.data[key] - if !ok { - m.Unlock() - return 0, nil - } - if conflict != 0 && (conflict != item.conflict) { - m.Unlock() - return 0, nil - } - - delete(m.data, key) - m.Unlock() - return item.conflict, item.value -} - -func (m *lockedMap) Update(newItem *Item) (any, bool) { - m.Lock() - item, ok := m.data[newItem.Key] - if !ok { - m.Unlock() - return nil, false - } - if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) { - m.Unlock() - return nil, false - } - - m.data[newItem.Key] = storeItem{ - key: newItem.Key, - conflict: newItem.Conflict, - value: newItem.Value, - } - - m.Unlock() - return item.value, true -} - -func (m *lockedMap) Len() int { - m.RLock() - l := len(m.data) - m.RUnlock() - return l -} - -func (m *lockedMap) Clear(onEvict itemCallback) { - m.Lock() - i := &Item{} - if onEvict != nil { - for _, si := range m.data { - i.Key = si.key - i.Conflict = si.conflict - i.Value = si.value - onEvict(i) - } - } - m.data = make(map[uint64]storeItem) - m.Unlock() -} - -func (m *lockedMap) foreach(forEach func(any) bool) bool { - m.RLock() - defer m.RUnlock() - for _, si := range m.data { - if !forEach(si.value) { - return false - } - } - return true -} diff --git a/go/cache/ristretto/store_test.go b/go/cache/ristretto/store_test.go deleted file mode 100644 index 54634736a72..00000000000 --- a/go/cache/ristretto/store_test.go +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright 2020 Dgraph Labs, Inc. and Contributors - * Copyright 2021 The Vitess Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ristretto - -import ( - "strconv" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStoreSetGet(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - s.Set(&i) - val, ok := s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - i.Value = 3 - s.Set(&i) - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 3, val.(int)) - - key, conflict = defaultStringHash("2") - i = Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - s.Set(&i) - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) -} - -func TestStoreDel(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - s.Del(key, conflict) - val, ok := s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - - s.Del(2, 0) -} - -func TestStoreClear(t *testing.T) { - s := newStore() - for i := 0; i < 1000; i++ { - key, conflict := defaultStringHash(strconv.Itoa(i)) - it := Item{ - Key: key, - Conflict: conflict, - Value: i, - } - s.Set(&it) - } - s.Clear(nil) - for i := 0; i < 1000; i++ { - key, conflict := defaultStringHash(strconv.Itoa(i)) - val, ok := s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) - } -} - -func TestStoreUpdate(t *testing.T) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - i.Value = 2 - _, ok := s.Update(&i) - require.True(t, ok) - - val, ok := s.Get(key, conflict) - require.True(t, ok) - require.NotNil(t, val) - - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 2, val.(int)) - - i.Value = 3 - _, ok = s.Update(&i) - require.True(t, ok) - - val, ok = s.Get(key, conflict) - require.True(t, ok) - require.Equal(t, 3, val.(int)) - - key, conflict = defaultStringHash("2") - i = Item{ - Key: key, - Conflict: conflict, - Value: 2, - } - _, ok = s.Update(&i) - require.False(t, ok) - val, ok = s.Get(key, conflict) - require.False(t, ok) - require.Nil(t, val) -} - -func TestStoreCollision(t *testing.T) { - s := newShardedMap() - s.shards[1].Lock() - s.shards[1].data[1] = storeItem{ - key: 1, - conflict: 0, - value: 1, - } - s.shards[1].Unlock() - val, ok := s.Get(1, 1) - require.False(t, ok) - require.Nil(t, val) - - i := Item{ - Key: 1, - Conflict: 1, - Value: 2, - } - s.Set(&i) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotEqual(t, 2, val.(int)) - - _, ok = s.Update(&i) - require.False(t, ok) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotEqual(t, 2, val.(int)) - - s.Del(1, 1) - val, ok = s.Get(1, 0) - require.True(t, ok) - require.NotNil(t, val) -} - -func BenchmarkStoreGet(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - s.Get(key, conflict) - } - }) -} - -func BenchmarkStoreSet(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - } - }) -} - -func BenchmarkStoreUpdate(b *testing.B) { - s := newStore() - key, conflict := defaultStringHash("1") - i := Item{ - Key: key, - Conflict: conflict, - Value: 1, - } - s.Set(&i) - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - s.Update(&Item{ - Key: key, - Conflict: conflict, - Value: 2, - }) - } - }) -} diff --git a/go/cache/theine/LICENSE b/go/cache/theine/LICENSE new file mode 100644 index 00000000000..0161260b7b6 --- /dev/null +++ b/go/cache/theine/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Yiling-J + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/go/cache/theine/bf/bf.go b/go/cache/theine/bf/bf.go new file mode 100644 index 00000000000..f68e34d81e3 --- /dev/null +++ b/go/cache/theine/bf/bf.go @@ -0,0 +1,116 @@ +package bf + +import ( + "math" +) + +// doorkeeper is a small bloom-filter-based cache admission policy +type Bloomfilter struct { + Filter bitvector // our filter bit vector + M uint32 // size of bit vector in bits + K uint32 // distinct hash functions needed + FalsePositiveRate float64 + Capacity int +} + +func New(falsePositiveRate float64) *Bloomfilter { + d := &Bloomfilter{FalsePositiveRate: falsePositiveRate} + d.EnsureCapacity(320) + return d +} + +// create new bloomfilter with given size in bytes +func NewWithSize(size uint32) *Bloomfilter { + d := &Bloomfilter{} + bits := size * 8 + m := nextPowerOfTwo(uint32(bits)) + d.M = m + d.Filter = newbv(m) + return d +} + +func (d *Bloomfilter) EnsureCapacity(capacity int) { + if capacity <= d.Capacity { + return + } + capacity = int(nextPowerOfTwo(uint32(capacity))) + bits := float64(capacity) * -math.Log(d.FalsePositiveRate) / (math.Log(2.0) * math.Log(2.0)) // in bits + m := nextPowerOfTwo(uint32(bits)) + + if m < 1024 { + m = 1024 + } + + k := uint32(0.7 * float64(m) / float64(capacity)) + if k < 2 { + k = 2 + } + d.Capacity = capacity + d.M = m + d.Filter = newbv(m) + d.K = k +} + +func (d *Bloomfilter) Exist(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + var o uint = 1 + for i := uint32(0); i < d.K; i++ { + o &= d.Filter.get((h1 + (i * h2)) & (d.M - 1)) + } + return o == 1 +} + +// insert inserts the byte array b into the bloom filter. Returns true if the value +// was already considered to be in the bloom filter. +func (d *Bloomfilter) Insert(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + var o uint = 1 + for i := uint32(0); i < d.K; i++ { + o &= d.Filter.getset((h1 + (i * h2)) & (d.M - 1)) + } + return o == 1 +} + +// Reset clears the bloom filter +func (d *Bloomfilter) Reset() { + for i := range d.Filter { + d.Filter[i] = 0 + } +} + +// Internal routines for the bit vector +type bitvector []uint64 + +func newbv(size uint32) bitvector { + return make([]uint64, uint(size+63)/64) +} + +func (b bitvector) get(bit uint32) uint { + shift := bit % 64 + idx := bit / 64 + bb := b[idx] + m := uint64(1) << shift + return uint((bb & m) >> shift) +} + +// set bit 'bit' in the bitvector d and return previous value +func (b bitvector) getset(bit uint32) uint { + shift := bit % 64 + idx := bit / 64 + bb := b[idx] + m := uint64(1) << shift + b[idx] |= m + return uint((bb & m) >> shift) +} + +// return the integer >= i which is a power of two +func nextPowerOfTwo(i uint32) uint32 { + n := i - 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n++ + return n +} diff --git a/go/cache/theine/bf/bf_test.go b/go/cache/theine/bf/bf_test.go new file mode 100644 index 00000000000..f0e505766e7 --- /dev/null +++ b/go/cache/theine/bf/bf_test.go @@ -0,0 +1,24 @@ +package bf + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBloom(t *testing.T) { + bf := NewWithSize(5) + bf.FalsePositiveRate = 0.1 + bf.EnsureCapacity(5) + bf.EnsureCapacity(500) + bf.EnsureCapacity(200) + + exist := bf.Insert(123) + require.False(t, exist) + + exist = bf.Exist(123) + require.True(t, exist) + + exist = bf.Exist(456) + require.False(t, exist) +} diff --git a/go/cache/theine/entry.go b/go/cache/theine/entry.go new file mode 100644 index 00000000000..48e3bd5a09a --- /dev/null +++ b/go/cache/theine/entry.go @@ -0,0 +1,93 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import "sync/atomic" + +const ( + NEW int8 = iota + REMOVE + UPDATE +) + +type ReadBufItem[K cachekey, V any] struct { + entry *Entry[K, V] + hash uint64 +} +type WriteBufItem[K cachekey, V any] struct { + entry *Entry[K, V] + costChange int64 + code int8 +} + +type MetaData[K cachekey, V any] struct { + prev *Entry[K, V] + next *Entry[K, V] +} + +type Entry[K cachekey, V any] struct { + key K + value V + meta MetaData[K, V] + cost atomic.Int64 + frequency atomic.Int32 + epoch atomic.Uint32 + removed bool + deque bool + root bool + list uint8 // used in slru, probation or protected +} + +func NewEntry[K cachekey, V any](key K, value V, cost int64) *Entry[K, V] { + entry := &Entry[K, V]{ + key: key, + value: value, + } + entry.cost.Store(cost) + return entry +} + +func (e *Entry[K, V]) Next() *Entry[K, V] { + if p := e.meta.next; !p.root { + return e.meta.next + } + return nil +} + +func (e *Entry[K, V]) Prev() *Entry[K, V] { + if p := e.meta.prev; !p.root { + return e.meta.prev + } + return nil +} + +func (e *Entry[K, V]) prev() *Entry[K, V] { + return e.meta.prev +} + +func (e *Entry[K, V]) next() *Entry[K, V] { + return e.meta.next +} + +func (e *Entry[K, V]) setPrev(entry *Entry[K, V]) { + e.meta.prev = entry +} + +func (e *Entry[K, V]) setNext(entry *Entry[K, V]) { + e.meta.next = entry +} diff --git a/go/cache/theine/list.go b/go/cache/theine/list.go new file mode 100644 index 00000000000..19854190cba --- /dev/null +++ b/go/cache/theine/list.go @@ -0,0 +1,205 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "strings" +) + +const ( + LIST_PROBATION uint8 = 1 + LIST_PROTECTED uint8 = 2 +) + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[K cachekey, V any] struct { + root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length(sum of costs) excluding (this) sentinel element + count int // count of entries in list + capacity uint + bounded bool + listType uint8 // 1 tinylfu list, 2 timerwheel list +} + +// New returns an initialized list. +func NewList[K cachekey, V any](size uint, listType uint8) *List[K, V] { + l := &List[K, V]{listType: listType, capacity: size, root: Entry[K, V]{}} + l.root.root = true + l.root.setNext(&l.root) + l.root.setPrev(&l.root) + l.len = 0 + l.capacity = size + if size > 0 { + l.bounded = true + } + return l +} + +func (l *List[K, V]) Reset() { + l.root.setNext(&l.root) + l.root.setPrev(&l.root) + l.len = 0 +} + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *List[K, V]) Len() int { return l.len } + +func (l *List[K, V]) display() string { + var s []string + for e := l.Front(); e != nil; e = e.Next() { + s = append(s, fmt.Sprintf("%v", e.key)) + } + return strings.Join(s, "/") +} + +func (l *List[K, V]) displayReverse() string { + var s []string + for e := l.Back(); e != nil; e = e.Prev() { + s = append(s, fmt.Sprintf("%v", e.key)) + } + return strings.Join(s, "/") +} + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[K, V]) Front() *Entry[K, V] { + e := l.root.next() + if e != &l.root { + return e + } + return nil +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[K, V]) Back() *Entry[K, V] { + e := l.root.prev() + if e != &l.root { + return e + } + return nil +} + +// insert inserts e after at, increments l.len, and evicted entry if capacity exceed +func (l *List[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] { + var evicted *Entry[K, V] + if l.bounded && l.len >= int(l.capacity) { + evicted = l.PopTail() + } + e.list = l.listType + e.setPrev(at) + e.setNext(at.next()) + e.prev().setNext(e) + e.next().setPrev(e) + if l.bounded { + l.len += int(e.cost.Load()) + l.count += 1 + } + return evicted +} + +// PushFront push entry to list head +func (l *List[K, V]) PushFront(e *Entry[K, V]) *Entry[K, V] { + return l.insert(e, &l.root) +} + +// Push push entry to the back of list +func (l *List[K, V]) PushBack(e *Entry[K, V]) *Entry[K, V] { + return l.insert(e, l.root.prev()) +} + +// remove removes e from its list, decrements l.len +func (l *List[K, V]) remove(e *Entry[K, V]) { + e.prev().setNext(e.next()) + e.next().setPrev(e.prev()) + e.setNext(nil) + e.setPrev(nil) + e.list = 0 + if l.bounded { + l.len -= int(e.cost.Load()) + l.count -= 1 + } +} + +// move moves e to next to at. +func (l *List[K, V]) move(e, at *Entry[K, V]) { + if e == at { + return + } + e.prev().setNext(e.next()) + e.next().setPrev(e.prev()) + + e.setPrev(at) + e.setNext(at.next()) + e.prev().setNext(e) + e.next().setPrev(e) +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *List[K, V]) Remove(e *Entry[K, V]) { + l.remove(e) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[K, V]) MoveToFront(e *Entry[K, V]) { + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[K, V]) MoveToBack(e *Entry[K, V]) { + l.move(e, l.root.prev()) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[K, V]) MoveBefore(e, mark *Entry[K, V]) { + l.move(e, mark.prev()) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[K, V]) MoveAfter(e, mark *Entry[K, V]) { + l.move(e, mark) +} + +func (l *List[K, V]) PopTail() *Entry[K, V] { + entry := l.root.prev() + if entry != nil && entry != &l.root { + l.remove(entry) + return entry + } + return nil +} + +func (l *List[K, V]) Contains(entry *Entry[K, V]) bool { + for e := l.Front(); e != nil; e = e.Next() { + if e == entry { + return true + } + } + return false +} diff --git a/go/cache/theine/list_test.go b/go/cache/theine/list_test.go new file mode 100644 index 00000000000..aad68f5c142 --- /dev/null +++ b/go/cache/theine/list_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestList(t *testing.T) { + l := NewList[StringKey, string](5, LIST_PROBATION) + require.Equal(t, uint(5), l.capacity) + require.Equal(t, LIST_PROBATION, l.listType) + for i := 0; i < 5; i++ { + evicted := l.PushFront(NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1)) + require.Nil(t, evicted) + } + require.Equal(t, 5, l.len) + require.Equal(t, "4/3/2/1/0", l.display()) + require.Equal(t, "0/1/2/3/4", l.displayReverse()) + + evicted := l.PushFront(NewEntry(StringKey("5"), "", 1)) + require.Equal(t, StringKey("0"), evicted.key) + require.Equal(t, 5, l.len) + require.Equal(t, "5/4/3/2/1", l.display()) + require.Equal(t, "1/2/3/4/5", l.displayReverse()) + + for i := 0; i < 5; i++ { + entry := l.PopTail() + require.Equal(t, StringKey(fmt.Sprintf("%d", i+1)), entry.key) + } + entry := l.PopTail() + require.Nil(t, entry) + + var entries []*Entry[StringKey, string] + for i := 0; i < 5; i++ { + new := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + evicted := l.PushFront(new) + entries = append(entries, new) + require.Nil(t, evicted) + } + require.Equal(t, "4/3/2/1/0", l.display()) + l.MoveToBack(entries[2]) + require.Equal(t, "4/3/1/0/2", l.display()) + require.Equal(t, "2/0/1/3/4", l.displayReverse()) + l.MoveBefore(entries[1], entries[3]) + require.Equal(t, "4/1/3/0/2", l.display()) + require.Equal(t, "2/0/3/1/4", l.displayReverse()) + l.MoveAfter(entries[2], entries[4]) + require.Equal(t, "4/2/1/3/0", l.display()) + require.Equal(t, "0/3/1/2/4", l.displayReverse()) + l.Remove(entries[1]) + require.Equal(t, "4/2/3/0", l.display()) + require.Equal(t, "0/3/2/4", l.displayReverse()) + +} + +func TestListCountCost(t *testing.T) { + l := NewList[StringKey, string](100, LIST_PROBATION) + require.Equal(t, uint(100), l.capacity) + require.Equal(t, LIST_PROBATION, l.listType) + for i := 0; i < 5; i++ { + evicted := l.PushFront(NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 20)) + require.Nil(t, evicted) + } + require.Equal(t, 100, l.len) + require.Equal(t, 5, l.count) + for i := 0; i < 3; i++ { + entry := l.PopTail() + require.NotNil(t, entry) + } + require.Equal(t, 40, l.len) + require.Equal(t, 2, l.count) +} diff --git a/go/cache/theine/mpsc.go b/go/cache/theine/mpsc.go new file mode 100644 index 00000000000..c00e2ce5a26 --- /dev/null +++ b/go/cache/theine/mpsc.go @@ -0,0 +1,86 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +// This implementation is based on http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue + +import ( + "sync" + "sync/atomic" +) + +type node[V any] struct { + next atomic.Pointer[node[V]] + val V +} + +type Queue[V any] struct { + head, tail atomic.Pointer[node[V]] + nodePool sync.Pool +} + +func NewQueue[V any]() *Queue[V] { + q := &Queue[V]{nodePool: sync.Pool{New: func() any { + return new(node[V]) + }}} + stub := &node[V]{} + q.head.Store(stub) + q.tail.Store(stub) + return q +} + +// Push adds x to the back of the queue. +// +// Push can be safely called from multiple goroutines +func (q *Queue[V]) Push(x V) { + n := q.nodePool.Get().(*node[V]) + n.val = x + + // current producer acquires head node + prev := q.head.Swap(n) + + // release node to consumer + prev.next.Store(n) +} + +// Pop removes the item from the front of the queue or nil if the queue is empty +// +// Pop must be called from a single, consumer goroutine +func (q *Queue[V]) Pop() (V, bool) { + tail := q.tail.Load() + next := tail.next.Load() + if next != nil { + var null V + q.tail.Store(next) + v := next.val + next.val = null + tail.next.Store(nil) + q.nodePool.Put(tail) + return v, true + } + var null V + return null, false +} + +// Empty returns true if the queue is empty +// +// Empty must be called from a single, consumer goroutine +func (q *Queue[V]) Empty() bool { + tail := q.tail.Load() + return tail.next.Load() == nil +} diff --git a/go/cache/theine/mpsc_test.go b/go/cache/theine/mpsc_test.go new file mode 100644 index 00000000000..eca50efed3e --- /dev/null +++ b/go/cache/theine/mpsc_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQueue_PushPop(t *testing.T) { + q := NewQueue[int]() + + q.Push(1) + q.Push(2) + v, ok := q.Pop() + assert.True(t, ok) + assert.Equal(t, 1, v) + v, ok = q.Pop() + assert.True(t, ok) + assert.Equal(t, 2, v) + _, ok = q.Pop() + assert.False(t, ok) +} + +func TestQueue_Empty(t *testing.T) { + q := NewQueue[int]() + assert.True(t, q.Empty()) + q.Push(1) + assert.False(t, q.Empty()) +} diff --git a/go/cache/theine/singleflight.go b/go/cache/theine/singleflight.go new file mode 100644 index 00000000000..fde56670514 --- /dev/null +++ b/go/cache/theine/singleflight.go @@ -0,0 +1,196 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J +Copyright 2013 The Go Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package theine + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" + "sync/atomic" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[V any] struct { + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val V + err error + + wg sync.WaitGroup + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups atomic.Int32 +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[K comparable, V any] struct { + m map[K]*call[V] // lazily initialized + mu sync.Mutex // protects m + callPool sync.Pool +} + +func NewGroup[K comparable, V any]() *Group[K, V] { + return &Group[K, V]{ + callPool: sync.Pool{New: func() any { + return new(call[V]) + }}, + } +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result struct { + Val interface{} + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group[K, V]) Do(key K, fn func() (V, error)) (v V, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[K]*call[V]) + } + if c, ok := g.m[key]; ok { + _ = c.dups.Add(1) + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + // assign value/err before put back to pool to avoid race + v = c.val + err = c.err + n := c.dups.Add(-1) + if n == 0 { + g.callPool.Put(c) + } + return v, err, true + } + c := g.callPool.Get().(*call[V]) + defer func() { + n := c.dups.Add(-1) + if n == 0 { + g.callPool.Put(c) + } + }() + _ = c.dups.Add(1) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, true +} + +// doCall handles the single call for a key. +func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + panic(e) + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} diff --git a/go/cache/theine/singleflight_test.go b/go/cache/theine/singleflight_test.go new file mode 100644 index 00000000000..60b28e69b4e --- /dev/null +++ b/go/cache/theine/singleflight_test.go @@ -0,0 +1,211 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J +Copyright 2013 The Go Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package theine + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + g := NewGroup[string, string]() + v, err, _ := g.Do("key", func() (string, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + g := NewGroup[string, string]() + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (string, error) { + return "", someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != "" { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + g := NewGroup[string, string]() + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (string, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s := v; s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test singleflight behaves correctly after Do panic. +// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + g := NewGroup[string, string]() + fn := func() (string, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + _, _, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + g := NewGroup[string, int]() + fn := func() (int, error) { + runtime.Goexit() + return 0, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func BenchmarkDo(b *testing.B) { + keys := randKeys(b, 10240, 10) + benchDo(b, NewGroup[string, int](), keys) + +} + +func benchDo(b *testing.B, g *Group[string, int], keys []string) { + keyc := len(keys) + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + _, _, _ = g.Do(keys[i%keyc], func() (int, error) { + return 0, nil + }) + } + }) +} + +func randKeys(b *testing.B, count, length uint) []string { + keys := make([]string, 0, count) + key := make([]byte, length) + + for i := uint(0); i < count; i++ { + if _, err := io.ReadFull(rand.Reader, key); err != nil { + b.Fatalf("Failed to generate random key %d of %d of length %d: %s", i+1, count, length, err) + } + keys = append(keys, string(key)) + } + return keys +} diff --git a/go/cache/theine/sketch.go b/go/cache/theine/sketch.go new file mode 100644 index 00000000000..7d241d94fc8 --- /dev/null +++ b/go/cache/theine/sketch.go @@ -0,0 +1,137 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +type CountMinSketch struct { + Table []uint64 + Additions uint + SampleSize uint + BlockMask uint +} + +func NewCountMinSketch() *CountMinSketch { + new := &CountMinSketch{} + new.EnsureCapacity(16) + return new +} + +// indexOf return table index and counter index together +func (s *CountMinSketch) indexOf(h uint64, block uint64, offset uint8) (uint, uint) { + counterHash := h + uint64(1+offset)*(h>>32) + // max block + 7(8 * 8 bytes), fit 64 bytes cache line + index := block + counterHash&1 + uint64(offset<<1) + return uint(index), uint((counterHash & 0xF) << 2) +} + +func (s *CountMinSketch) inc(index uint, offset uint) bool { + mask := uint64(0xF << offset) + if s.Table[index]&mask != mask { + s.Table[index] += 1 << offset + return true + } + return false +} + +func (s *CountMinSketch) Add(h uint64) bool { + hn := spread(h) + block := (hn & uint64(s.BlockMask)) << 3 + hc := rehash(h) + index0, offset0 := s.indexOf(hc, block, 0) + index1, offset1 := s.indexOf(hc, block, 1) + index2, offset2 := s.indexOf(hc, block, 2) + index3, offset3 := s.indexOf(hc, block, 3) + + added := s.inc(index0, offset0) + added = s.inc(index1, offset1) || added + added = s.inc(index2, offset2) || added + added = s.inc(index3, offset3) || added + + if added { + s.Additions += 1 + if s.Additions == s.SampleSize { + s.reset() + return true + } + } + return false +} + +func (s *CountMinSketch) reset() { + for i := range s.Table { + s.Table[i] = s.Table[i] >> 1 + } + s.Additions = s.Additions >> 1 +} + +func (s *CountMinSketch) count(h uint64, block uint64, offset uint8) uint { + index, off := s.indexOf(h, block, offset) + count := (s.Table[index] >> off) & 0xF + return uint(count) +} + +func (s *CountMinSketch) Estimate(h uint64) uint { + hn := spread(h) + block := (hn & uint64(s.BlockMask)) << 3 + hc := rehash(h) + m := min(s.count(hc, block, 0), 100) + m = min(s.count(hc, block, 1), m) + m = min(s.count(hc, block, 2), m) + m = min(s.count(hc, block, 3), m) + return m +} + +func next2Power(x uint) uint { + x-- + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + x |= x >> 32 + x++ + return x +} + +func (s *CountMinSketch) EnsureCapacity(size uint) { + if len(s.Table) >= int(size) { + return + } + if size < 16 { + size = 16 + } + newSize := next2Power(size) + s.Table = make([]uint64, newSize) + s.SampleSize = 10 * size + s.BlockMask = uint((len(s.Table) >> 3) - 1) + s.Additions = 0 +} + +func spread(h uint64) uint64 { + h ^= h >> 17 + h *= 0xed5ad4bb + h ^= h >> 11 + h *= 0xac4c1b51 + h ^= h >> 15 + return h +} + +func rehash(h uint64) uint64 { + h *= 0x31848bab + h ^= h >> 14 + return h +} diff --git a/go/cache/theine/sketch_test.go b/go/cache/theine/sketch_test.go new file mode 100644 index 00000000000..3437f0cac3c --- /dev/null +++ b/go/cache/theine/sketch_test.go @@ -0,0 +1,54 @@ +package theine + +import ( + "fmt" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/require" +) + +func TestEnsureCapacity(t *testing.T) { + sketch := NewCountMinSketch() + sketch.EnsureCapacity(1) + require.Equal(t, 16, len(sketch.Table)) +} + +func TestSketch(t *testing.T) { + sketch := NewCountMinSketch() + sketch.EnsureCapacity(100) + require.Equal(t, 128, len(sketch.Table)) + require.Equal(t, uint(1000), sketch.SampleSize) + // override sampleSize so test won't reset + sketch.SampleSize = 5120 + + failed := 0 + for i := 0; i < 500; i++ { + key := fmt.Sprintf("key:%d", i) + keyh := xxhash.Sum64String(key) + sketch.Add(keyh) + sketch.Add(keyh) + sketch.Add(keyh) + sketch.Add(keyh) + sketch.Add(keyh) + key = fmt.Sprintf("key:%d:b", i) + keyh2 := xxhash.Sum64String(key) + sketch.Add(keyh2) + sketch.Add(keyh2) + sketch.Add(keyh2) + + es1 := sketch.Estimate(keyh) + es2 := sketch.Estimate(keyh2) + if es2 > es1 { + failed++ + } + require.True(t, es1 >= 5) + require.True(t, es2 >= 3) + + } + require.True(t, float32(failed)/4000 < 0.1) + require.True(t, sketch.Additions > 3500) + a := sketch.Additions + sketch.reset() + require.Equal(t, a>>1, sketch.Additions) +} diff --git a/go/cache/theine/slru.go b/go/cache/theine/slru.go new file mode 100644 index 00000000000..e3bcb2532b1 --- /dev/null +++ b/go/cache/theine/slru.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +type Slru[K cachekey, V any] struct { + probation *List[K, V] + protected *List[K, V] + maxsize uint +} + +func NewSlru[K cachekey, V any](size uint) *Slru[K, V] { + return &Slru[K, V]{ + maxsize: size, + probation: NewList[K, V](size, LIST_PROBATION), + protected: NewList[K, V](uint(float32(size)*0.8), LIST_PROTECTED), + } +} + +func (s *Slru[K, V]) insert(entry *Entry[K, V]) *Entry[K, V] { + var evicted *Entry[K, V] + if s.probation.Len()+s.protected.Len() >= int(s.maxsize) { + evicted = s.probation.PopTail() + } + s.probation.PushFront(entry) + return evicted +} + +func (s *Slru[K, V]) victim() *Entry[K, V] { + if s.probation.Len()+s.protected.Len() < int(s.maxsize) { + return nil + } + return s.probation.Back() +} + +func (s *Slru[K, V]) access(entry *Entry[K, V]) { + switch entry.list { + case LIST_PROBATION: + s.probation.remove(entry) + evicted := s.protected.PushFront(entry) + if evicted != nil { + s.probation.PushFront(evicted) + } + case LIST_PROTECTED: + s.protected.MoveToFront(entry) + } +} + +func (s *Slru[K, V]) remove(entry *Entry[K, V]) { + switch entry.list { + case LIST_PROBATION: + s.probation.remove(entry) + case LIST_PROTECTED: + s.protected.remove(entry) + } +} + +func (s *Slru[K, V]) updateCost(entry *Entry[K, V], delta int64) { + switch entry.list { + case LIST_PROBATION: + s.probation.len += int(delta) + case LIST_PROTECTED: + s.protected.len += int(delta) + } +} diff --git a/go/cache/theine/store.go b/go/cache/theine/store.go new file mode 100644 index 00000000000..3d86e549867 --- /dev/null +++ b/go/cache/theine/store.go @@ -0,0 +1,615 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/gammazero/deque" + + "vitess.io/vitess/go/cache/theine/bf" + "vitess.io/vitess/go/hack" +) + +const ( + MaxReadBuffSize = 64 + MinWriteBuffSize = 4 + MaxWriteBuffSize = 1024 +) + +type RemoveReason uint8 + +const ( + REMOVED RemoveReason = iota + EVICTED + EXPIRED +) + +type Shard[K cachekey, V any] struct { + hashmap map[K]*Entry[K, V] + dookeeper *bf.Bloomfilter + deque *deque.Deque[*Entry[K, V]] + group *Group[K, V] + qsize uint + qlen int + counter uint + mu sync.RWMutex +} + +func NewShard[K cachekey, V any](size uint, qsize uint, doorkeeper bool) *Shard[K, V] { + s := &Shard[K, V]{ + hashmap: make(map[K]*Entry[K, V]), + qsize: qsize, + deque: deque.New[*Entry[K, V]](), + group: NewGroup[K, V](), + } + if doorkeeper { + s.dookeeper = bf.New(0.01) + } + return s +} + +func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) { + s.hashmap[key] = entry + if s.dookeeper != nil { + ds := 20 * len(s.hashmap) + if ds > s.dookeeper.Capacity { + s.dookeeper.EnsureCapacity(ds) + } + } +} + +func (s *Shard[K, V]) get(key K) (entry *Entry[K, V], ok bool) { + entry, ok = s.hashmap[key] + return +} + +func (s *Shard[K, V]) delete(entry *Entry[K, V]) bool { + var deleted bool + exist, ok := s.hashmap[entry.key] + if ok && exist == entry { + delete(s.hashmap, exist.key) + deleted = true + } + return deleted +} + +func (s *Shard[K, V]) len() int { + return len(s.hashmap) +} + +type Metrics struct { + evicted atomic.Int64 + hits atomic.Int64 + misses atomic.Int64 +} + +func (m *Metrics) Evicted() int64 { + return m.evicted.Load() +} + +func (m *Metrics) Hits() int64 { + return m.hits.Load() +} + +func (m *Metrics) Misses() int64 { + return m.misses.Load() +} + +func (m *Metrics) Accesses() int64 { + return m.Hits() + m.Misses() +} + +type cachekey interface { + comparable + Hash() uint64 + Hash2() (uint64, uint64) +} + +type HashKey256 [32]byte + +func (h HashKey256) Hash() uint64 { + return uint64(h[0]) | uint64(h[1])<<8 | uint64(h[2])<<16 | uint64(h[3])<<24 | + uint64(h[4])<<32 | uint64(h[5])<<40 | uint64(h[6])<<48 | uint64(h[7])<<56 +} + +func (h HashKey256) Hash2() (uint64, uint64) { + h0 := h.Hash() + h1 := uint64(h[8]) | uint64(h[9])<<8 | uint64(h[10])<<16 | uint64(h[11])<<24 | + uint64(h[12])<<32 | uint64(h[13])<<40 | uint64(h[14])<<48 | uint64(h[15])<<56 + return h0, h1 +} + +type StringKey string + +func (h StringKey) Hash() uint64 { + return hack.RuntimeStrhash(string(h), 13850135847636357301) +} + +func (h StringKey) Hash2() (uint64, uint64) { + h0 := h.Hash() + h1 := ((h0 >> 16) ^ h0) * 0x45d9f3b + h1 = ((h1 >> 16) ^ h1) * 0x45d9f3b + h1 = (h1 >> 16) ^ h1 + return h0, h1 +} + +type cacheval interface { + CachedSize(alloc bool) int64 +} + +type Store[K cachekey, V cacheval] struct { + Metrics Metrics + OnRemoval func(K, V, RemoveReason) + + entryPool sync.Pool + writebuf chan WriteBufItem[K, V] + policy *TinyLfu[K, V] + readbuf *Queue[ReadBufItem[K, V]] + shards []*Shard[K, V] + cap uint + shardCount uint + writebufsize int64 + tailUpdate bool + doorkeeper bool + + mlock sync.Mutex + readCounter atomic.Uint32 + open atomic.Bool +} + +func NewStore[K cachekey, V cacheval](maxsize int64, doorkeeper bool) *Store[K, V] { + writeBufSize := maxsize / 100 + if writeBufSize < MinWriteBuffSize { + writeBufSize = MinWriteBuffSize + } + if writeBufSize > MaxWriteBuffSize { + writeBufSize = MaxWriteBuffSize + } + shardCount := 1 + for shardCount < runtime.GOMAXPROCS(0)*2 { + shardCount *= 2 + } + if shardCount < 16 { + shardCount = 16 + } + if shardCount > 128 { + shardCount = 128 + } + dequeSize := int(maxsize) / 100 / shardCount + shardSize := int(maxsize) / shardCount + if shardSize < 50 { + shardSize = 50 + } + policySize := int(maxsize) - (dequeSize * shardCount) + + s := &Store[K, V]{ + cap: uint(maxsize), + policy: NewTinyLfu[K, V](uint(policySize)), + readbuf: NewQueue[ReadBufItem[K, V]](), + writebuf: make(chan WriteBufItem[K, V], writeBufSize), + entryPool: sync.Pool{New: func() any { return &Entry[K, V]{} }}, + shardCount: uint(shardCount), + doorkeeper: doorkeeper, + writebufsize: writeBufSize, + } + s.shards = make([]*Shard[K, V], 0, s.shardCount) + for i := 0; i < int(s.shardCount); i++ { + s.shards = append(s.shards, NewShard[K, V](uint(shardSize), uint(dequeSize), doorkeeper)) + } + + go s.maintenance() + s.open.Store(true) + return s +} + +func (s *Store[K, V]) EnsureOpen() { + if s.open.Swap(true) { + return + } + s.writebuf = make(chan WriteBufItem[K, V], s.writebufsize) + go s.maintenance() +} + +func (s *Store[K, V]) getFromShard(key K, hash uint64, shard *Shard[K, V], epoch uint32) (V, bool) { + new := s.readCounter.Add(1) + shard.mu.RLock() + entry, ok := shard.get(key) + var value V + if ok { + if entry.epoch.Load() < epoch { + s.Metrics.misses.Add(1) + ok = false + } else { + s.Metrics.hits.Add(1) + s.policy.hit.Add(1) + value = entry.value + } + } else { + s.Metrics.misses.Add(1) + } + shard.mu.RUnlock() + switch { + case new < MaxReadBuffSize: + var send ReadBufItem[K, V] + send.hash = hash + if ok { + send.entry = entry + } + s.readbuf.Push(send) + case new == MaxReadBuffSize: + var send ReadBufItem[K, V] + send.hash = hash + if ok { + send.entry = entry + } + s.readbuf.Push(send) + s.drainRead() + } + return value, ok +} + +func (s *Store[K, V]) Get(key K, epoch uint32) (V, bool) { + h, index := s.index(key) + shard := s.shards[index] + return s.getFromShard(key, h, shard, epoch) +} + +func (s *Store[K, V]) GetOrLoad(key K, epoch uint32, load func() (V, error)) (V, bool, error) { + h, index := s.index(key) + shard := s.shards[index] + v, ok := s.getFromShard(key, h, shard, epoch) + if !ok { + loaded, err, _ := shard.group.Do(key, func() (V, error) { + loaded, err := load() + if err == nil { + s.Set(key, loaded, 0, epoch) + } + return loaded, err + }) + return loaded, false, err + } + return v, true, nil +} + +func (s *Store[K, V]) setEntry(shard *Shard[K, V], cost int64, epoch uint32, entry *Entry[K, V]) { + shard.set(entry.key, entry) + // cost larger than deque size, send to policy directly + if cost > int64(shard.qsize) { + shard.mu.Unlock() + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: NEW} + return + } + entry.deque = true + shard.deque.PushFront(entry) + shard.qlen += int(cost) + s.processDeque(shard, epoch) +} + +func (s *Store[K, V]) setInternal(key K, value V, cost int64, epoch uint32) (*Shard[K, V], *Entry[K, V], bool) { + h, index := s.index(key) + shard := s.shards[index] + shard.mu.Lock() + exist, ok := shard.get(key) + if ok { + var costChange int64 + exist.value = value + oldCost := exist.cost.Swap(cost) + if oldCost != cost { + costChange = cost - oldCost + if exist.deque { + shard.qlen += int(costChange) + } + } + shard.mu.Unlock() + exist.epoch.Store(epoch) + if costChange != 0 { + s.writebuf <- WriteBufItem[K, V]{ + entry: exist, code: UPDATE, costChange: costChange, + } + } + return shard, exist, true + } + if s.doorkeeper { + if shard.counter > uint(shard.dookeeper.Capacity) { + shard.dookeeper.Reset() + shard.counter = 0 + } + hit := shard.dookeeper.Insert(h) + if !hit { + shard.counter += 1 + shard.mu.Unlock() + return shard, nil, false + } + } + entry := s.entryPool.Get().(*Entry[K, V]) + entry.frequency.Store(-1) + entry.key = key + entry.value = value + entry.cost.Store(cost) + entry.epoch.Store(epoch) + s.setEntry(shard, cost, epoch, entry) + return shard, entry, true + +} + +func (s *Store[K, V]) Set(key K, value V, cost int64, epoch uint32) bool { + if cost == 0 { + cost = value.CachedSize(true) + } + if cost > int64(s.cap) { + return false + } + _, _, ok := s.setInternal(key, value, cost, epoch) + return ok +} + +type dequeKV[K cachekey, V cacheval] struct { + k K + v V +} + +func (s *Store[K, V]) processDeque(shard *Shard[K, V], epoch uint32) { + if shard.qlen <= int(shard.qsize) { + shard.mu.Unlock() + return + } + var evictedkv []dequeKV[K, V] + var expiredkv []dequeKV[K, V] + + // send to slru + send := make([]*Entry[K, V], 0, 2) + for shard.qlen > int(shard.qsize) { + evicted := shard.deque.PopBack() + evicted.deque = false + shard.qlen -= int(evicted.cost.Load()) + + if evicted.epoch.Load() < epoch { + deleted := shard.delete(evicted) + if deleted { + if s.OnRemoval != nil { + evictedkv = append(evictedkv, dequeKV[K, V]{evicted.key, evicted.value}) + } + s.postDelete(evicted) + s.Metrics.evicted.Add(1) + } + } else { + count := evicted.frequency.Load() + threshold := s.policy.threshold.Load() + if count == -1 { + send = append(send, evicted) + } else { + if int32(count) >= threshold { + send = append(send, evicted) + } else { + deleted := shard.delete(evicted) + // double check because entry maybe removed already by Delete API + if deleted { + if s.OnRemoval != nil { + evictedkv = append(evictedkv, dequeKV[K, V]{evicted.key, evicted.value}) + } + s.postDelete(evicted) + s.Metrics.evicted.Add(1) + } + } + } + } + } + + shard.mu.Unlock() + for _, entry := range send { + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: NEW} + } + if s.OnRemoval != nil { + for _, kv := range evictedkv { + s.OnRemoval(kv.k, kv.v, EVICTED) + } + for _, kv := range expiredkv { + s.OnRemoval(kv.k, kv.v, EXPIRED) + } + } +} + +func (s *Store[K, V]) Delete(key K) { + _, index := s.index(key) + shard := s.shards[index] + shard.mu.Lock() + entry, ok := shard.get(key) + if ok { + shard.delete(entry) + } + shard.mu.Unlock() + if ok { + s.writebuf <- WriteBufItem[K, V]{entry: entry, code: REMOVE} + } +} + +func (s *Store[K, V]) Len() int { + total := 0 + for _, s := range s.shards { + s.mu.RLock() + total += s.len() + s.mu.RUnlock() + } + return total +} + +func (s *Store[K, V]) UsedCapacity() int { + total := 0 + for _, s := range s.shards { + s.mu.RLock() + total += s.qlen + s.mu.RUnlock() + } + return total +} + +func (s *Store[K, V]) MaxCapacity() int { + return int(s.cap) +} + +// spread hash before get index +func (s *Store[K, V]) index(key K) (uint64, int) { + h0, h1 := key.Hash2() + return h0, int(h1 & uint64(s.shardCount-1)) +} + +func (s *Store[K, V]) postDelete(entry *Entry[K, V]) { + var zero V + entry.value = zero + s.entryPool.Put(entry) +} + +// remove entry from cache/policy/timingwheel and add back to pool +func (s *Store[K, V]) removeEntry(entry *Entry[K, V], reason RemoveReason) { + if prev := entry.meta.prev; prev != nil { + s.policy.Remove(entry) + } + switch reason { + case EVICTED, EXPIRED: + _, index := s.index(entry.key) + shard := s.shards[index] + shard.mu.Lock() + deleted := shard.delete(entry) + shard.mu.Unlock() + if deleted { + if s.OnRemoval != nil { + s.OnRemoval(entry.key, entry.value, reason) + } + s.postDelete(entry) + s.Metrics.evicted.Add(1) + } + case REMOVED: + // already removed from shard map + if s.OnRemoval != nil { + s.OnRemoval(entry.key, entry.value, reason) + } + } +} + +func (s *Store[K, V]) drainRead() { + s.policy.total.Add(MaxReadBuffSize) + s.mlock.Lock() + for { + v, ok := s.readbuf.Pop() + if !ok { + break + } + s.policy.Access(v) + } + s.mlock.Unlock() + s.readCounter.Store(0) +} + +func (s *Store[K, V]) maintenanceItem(item WriteBufItem[K, V]) { + s.mlock.Lock() + defer s.mlock.Unlock() + + entry := item.entry + if entry == nil { + return + } + + // lock free because store API never read/modify entry metadata + switch item.code { + case NEW: + if entry.removed { + return + } + evicted := s.policy.Set(entry) + if evicted != nil { + s.removeEntry(evicted, EVICTED) + s.tailUpdate = true + } + removed := s.policy.EvictEntries() + for _, e := range removed { + s.tailUpdate = true + s.removeEntry(e, EVICTED) + } + case REMOVE: + entry.removed = true + s.removeEntry(entry, REMOVED) + s.policy.threshold.Store(-1) + case UPDATE: + if item.costChange != 0 { + s.policy.UpdateCost(entry, item.costChange) + removed := s.policy.EvictEntries() + for _, e := range removed { + s.tailUpdate = true + s.removeEntry(e, EVICTED) + } + } + } + item.entry = nil + if s.tailUpdate { + s.policy.UpdateThreshold() + s.tailUpdate = false + } +} + +func (s *Store[K, V]) maintenance() { + tick := time.NewTicker(500 * time.Millisecond) + defer tick.Stop() + + for { + select { + case <-tick.C: + s.mlock.Lock() + s.policy.UpdateThreshold() + s.mlock.Unlock() + + case item, ok := <-s.writebuf: + if !ok { + return + } + s.maintenanceItem(item) + } + } +} + +func (s *Store[K, V]) Range(epoch uint32, f func(key K, value V) bool) { + for _, shard := range s.shards { + shard.mu.RLock() + for _, entry := range shard.hashmap { + if entry.epoch.Load() < epoch { + continue + } + if !f(entry.key, entry.value) { + shard.mu.RUnlock() + return + } + } + shard.mu.RUnlock() + } +} + +func (s *Store[K, V]) Close() { + if !s.open.Swap(false) { + panic("theine.Store: double close") + } + + for _, s := range s.shards { + s.mu.Lock() + clear(s.hashmap) + s.mu.Unlock() + } + close(s.writebuf) +} diff --git a/go/cache/theine/store_test.go b/go/cache/theine/store_test.go new file mode 100644 index 00000000000..880acf30193 --- /dev/null +++ b/go/cache/theine/store_test.go @@ -0,0 +1,82 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type cachedint int + +func (ci cachedint) CachedSize(bool) int64 { + return 1 +} + +type keyint int + +func (k keyint) Hash() uint64 { + return uint64(k) +} + +func (k keyint) Hash2() (uint64, uint64) { + return uint64(k), uint64(k) * 333 +} + +func TestProcessDeque(t *testing.T) { + store := NewStore[keyint, cachedint](20000, false) + + evicted := map[keyint]cachedint{} + store.OnRemoval = func(key keyint, value cachedint, reason RemoveReason) { + if reason == EVICTED { + evicted[key] = value + } + } + _, index := store.index(123) + shard := store.shards[index] + shard.qsize = 10 + + for i := keyint(0); i < 5; i++ { + entry := &Entry[keyint, cachedint]{key: i} + entry.cost.Store(1) + store.shards[index].deque.PushFront(entry) + store.shards[index].qlen += 1 + store.shards[index].hashmap[i] = entry + } + + // move 0,1,2 entries to slru + store.Set(123, 123, 8, 0) + require.Equal(t, store.shards[index].deque.Len(), 3) + var keys []keyint + for store.shards[index].deque.Len() != 0 { + e := store.shards[index].deque.PopBack() + keys = append(keys, e.key) + } + require.Equal(t, []keyint{3, 4, 123}, keys) +} + +func TestDoorKeeperDynamicSize(t *testing.T) { + store := NewStore[keyint, cachedint](200000, true) + shard := store.shards[0] + require.True(t, shard.dookeeper.Capacity == 512) + for i := keyint(0); i < 5000; i++ { + shard.set(i, &Entry[keyint, cachedint]{}) + } + require.True(t, shard.dookeeper.Capacity > 100000) +} diff --git a/go/cache/theine/tlfu.go b/go/cache/theine/tlfu.go new file mode 100644 index 00000000000..f7a4f8dec51 --- /dev/null +++ b/go/cache/theine/tlfu.go @@ -0,0 +1,197 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "sync/atomic" +) + +type TinyLfu[K cachekey, V any] struct { + slru *Slru[K, V] + sketch *CountMinSketch + size uint + counter uint + total atomic.Uint32 + hit atomic.Uint32 + hr float32 + threshold atomic.Int32 + lruFactor uint8 + step int8 +} + +func NewTinyLfu[K cachekey, V any](size uint) *TinyLfu[K, V] { + tlfu := &TinyLfu[K, V]{ + size: size, + slru: NewSlru[K, V](size), + sketch: NewCountMinSketch(), + step: 1, + } + // default threshold to -1 so all entries are admitted until cache is full + tlfu.threshold.Store(-1) + return tlfu +} + +func (t *TinyLfu[K, V]) climb() { + total := t.total.Load() + hit := t.hit.Load() + current := float32(hit) / float32(total) + delta := current - t.hr + var diff int8 + if delta > 0.0 { + if t.step < 0 { + t.step -= 1 + } else { + t.step += 1 + } + if t.step < -13 { + t.step = -13 + } else if t.step > 13 { + t.step = 13 + } + newFactor := int8(t.lruFactor) + t.step + if newFactor < 0 { + newFactor = 0 + } else if newFactor > 16 { + newFactor = 16 + } + diff = newFactor - int8(t.lruFactor) + t.lruFactor = uint8(newFactor) + } else if delta < 0.0 { + // reset + if t.step > 0 { + t.step = -1 + } else { + t.step = 1 + } + newFactor := int8(t.lruFactor) + t.step + if newFactor < 0 { + newFactor = 0 + } else if newFactor > 16 { + newFactor = 16 + } + diff = newFactor - int8(t.lruFactor) + t.lruFactor = uint8(newFactor) + } + t.threshold.Add(-int32(diff)) + t.hr = current + t.hit.Store(0) + t.total.Store(0) +} + +func (t *TinyLfu[K, V]) Set(entry *Entry[K, V]) *Entry[K, V] { + t.counter++ + if t.counter > 10*t.size { + t.climb() + t.counter = 0 + } + if entry.meta.prev == nil { + if victim := t.slru.victim(); victim != nil { + freq := int(entry.frequency.Load()) + if freq == -1 { + freq = int(t.sketch.Estimate(entry.key.Hash())) + } + evictedCount := uint(freq) + uint(t.lruFactor) + victimCount := t.sketch.Estimate(victim.key.Hash()) + if evictedCount <= uint(victimCount) { + return entry + } + } else { + count := t.slru.probation.count + t.slru.protected.count + t.sketch.EnsureCapacity(uint(count + count/100)) + } + evicted := t.slru.insert(entry) + return evicted + } + + return nil +} + +func (t *TinyLfu[K, V]) Access(item ReadBufItem[K, V]) { + t.counter++ + if t.counter > 10*t.size { + t.climb() + t.counter = 0 + } + if entry := item.entry; entry != nil { + reset := t.sketch.Add(item.hash) + if reset { + t.threshold.Store(t.threshold.Load() / 2) + } + if entry.meta.prev != nil { + var tail bool + if entry == t.slru.victim() { + tail = true + } + t.slru.access(entry) + if tail { + t.UpdateThreshold() + } + } else { + entry.frequency.Store(int32(t.sketch.Estimate(item.hash))) + } + } else { + reset := t.sketch.Add(item.hash) + if reset { + t.threshold.Store(t.threshold.Load() / 2) + } + } +} + +func (t *TinyLfu[K, V]) Remove(entry *Entry[K, V]) { + t.slru.remove(entry) +} + +func (t *TinyLfu[K, V]) UpdateCost(entry *Entry[K, V], delta int64) { + t.slru.updateCost(entry, delta) +} + +func (t *TinyLfu[K, V]) EvictEntries() []*Entry[K, V] { + removed := []*Entry[K, V]{} + + for t.slru.probation.Len()+t.slru.protected.Len() > int(t.slru.maxsize) { + entry := t.slru.probation.PopTail() + if entry == nil { + break + } + removed = append(removed, entry) + } + for t.slru.probation.Len()+t.slru.protected.Len() > int(t.slru.maxsize) { + entry := t.slru.protected.PopTail() + if entry == nil { + break + } + removed = append(removed, entry) + } + return removed +} + +func (t *TinyLfu[K, V]) UpdateThreshold() { + if t.slru.probation.Len()+t.slru.protected.Len() < int(t.slru.maxsize) { + t.threshold.Store(-1) + } else { + tail := t.slru.victim() + if tail != nil { + t.threshold.Store( + int32(t.sketch.Estimate(tail.key.Hash()) - uint(t.lruFactor)), + ) + } else { + // cache is not full + t.threshold.Store(-1) + } + } +} diff --git a/go/cache/theine/tlfu_test.go b/go/cache/theine/tlfu_test.go new file mode 100644 index 00000000000..ac6ddaabdb6 --- /dev/null +++ b/go/cache/theine/tlfu_test.go @@ -0,0 +1,156 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright 2023 Yiling-J + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package theine + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTlfu(t *testing.T) { + tlfu := NewTinyLfu[StringKey, string](1000) + require.Equal(t, uint(1000), tlfu.slru.probation.capacity) + require.Equal(t, uint(800), tlfu.slru.protected.capacity) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + var entries []*Entry[StringKey, string] + for i := 0; i < 200; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + evicted := tlfu.Set(e) + entries = append(entries, e) + require.Nil(t, evicted) + } + + require.Equal(t, 200, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // probation -> protected + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 199, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 199, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + for i := 200; i < 1000; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d", i)), "", 1) + entries = append(entries, e) + evicted := tlfu.Set(e) + require.Nil(t, evicted) + } + // access protected + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[11]}) + require.Equal(t, 999, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + evicted := tlfu.Set(NewEntry(StringKey("0a"), "", 1)) + require.Equal(t, StringKey("0a"), evicted.key) + require.Equal(t, 999, tlfu.slru.probation.len) + require.Equal(t, 1, tlfu.slru.protected.len) + + victim := tlfu.slru.victim() + require.Equal(t, StringKey("0"), victim.key) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + tlfu.Access(ReadBufItem[StringKey, string]{entry: entries[991]}) + evicted = tlfu.Set(NewEntry(StringKey("1a"), "", 1)) + require.Equal(t, StringKey("1a"), evicted.key) + require.Equal(t, 998, tlfu.slru.probation.len) + + var entries2 []*Entry[StringKey, string] + for i := 0; i < 1000; i++ { + e := NewEntry(StringKey(fmt.Sprintf("%d*", i)), "", 1) + tlfu.Set(e) + entries2 = append(entries2, e) + } + require.Equal(t, 998, tlfu.slru.probation.len) + require.Equal(t, 2, tlfu.slru.protected.len) + + for _, i := range []int{997, 998, 999} { + tlfu.Remove(entries2[i]) + tlfu.slru.probation.display() + tlfu.slru.probation.displayReverse() + tlfu.slru.protected.display() + tlfu.slru.protected.displayReverse() + } + +} + +func TestEvictEntries(t *testing.T) { + tlfu := NewTinyLfu[StringKey, string](500) + require.Equal(t, uint(500), tlfu.slru.probation.capacity) + require.Equal(t, uint(400), tlfu.slru.protected.capacity) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + for i := 0; i < 500; i++ { + tlfu.Set(NewEntry(StringKey(fmt.Sprintf("%d:1", i)), "", 1)) + } + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + new := NewEntry(StringKey("l:10"), "", 10) + new.frequency.Store(10) + tlfu.Set(new) + require.Equal(t, 509, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + // 2. probation length is 509, so remove 9 entries from probation + removed := tlfu.EvictEntries() + for _, rm := range removed { + require.True(t, strings.HasSuffix(string(rm.key), ":1")) + } + require.Equal(t, 9, len(removed)) + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // put l:450 to probation, this will remove 1 entry, probation len is 949 now + // remove 449 entries from probation + new = NewEntry(StringKey("l:450"), "", 450) + new.frequency.Store(10) + tlfu.Set(new) + removed = tlfu.EvictEntries() + require.Equal(t, 449, len(removed)) + require.Equal(t, 500, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // put l:460 to probation, this will remove 1 entry, probation len is 959 now + // remove all entries except the new l:460 one + new = NewEntry(StringKey("l:460"), "", 460) + new.frequency.Store(10) + tlfu.Set(new) + removed = tlfu.EvictEntries() + require.Equal(t, 41, len(removed)) + require.Equal(t, 460, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + + // access + tlfu.Access(ReadBufItem[StringKey, string]{entry: new}) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 460, tlfu.slru.protected.len) + new.cost.Store(600) + tlfu.UpdateCost(new, 140) + removed = tlfu.EvictEntries() + require.Equal(t, 1, len(removed)) + require.Equal(t, 0, tlfu.slru.probation.len) + require.Equal(t, 0, tlfu.slru.protected.len) + +} diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index 3f2752be084..1d2f665177b 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -35,9 +35,7 @@ Usage of vtgate: --enable_set_var This will enable the use of MySQL's SET_VAR query hint for certain system variables instead of using reserved connections (default true) --enable_system_settings This will enable the system settings to be changed per session at the database connection level (default true) --foreign_key_mode string This is to provide how to handle foreign key constraint in create/alter table. Valid values are: allow, disallow (default "allow") - --gate_query_cache_lfu gate server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries (default true) --gate_query_cache_memory int gate server query cache size in bytes, maximum amount of memory to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. (default 33554432) - --gate_query_cache_size int gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a cache. This config controls the expected amount of unique entries in the cache. (default 5000) --gateway_initial_tablet_timeout duration At startup, the tabletGateway will wait up to this duration to get at least one tablet per keyspace/shard/tablet type (default 30s) --grpc-use-effective-groups If set, and SSL is not used, will set the immediate caller's security groups from the effective caller id's groups. --grpc-use-static-authentication-callerid If set, will set the immediate caller id to the username authenticated by the static auth plugin. diff --git a/go/flags/endtoend/vttablet.txt b/go/flags/endtoend/vttablet.txt index 9b42ebb5644..b787ffbfe57 100644 --- a/go/flags/endtoend/vttablet.txt +++ b/go/flags/endtoend/vttablet.txt @@ -225,9 +225,7 @@ Usage of vttablet: --queryserver-config-passthrough-dmls query server pass through all dml statements without rewriting --queryserver-config-pool-conn-max-lifetime duration query server connection max lifetime (in seconds), vttablet manages various mysql connection pools. This config means if a connection has lived at least this long, it connection will be removed from pool upon the next time it is returned to the pool. (default 0s) --queryserver-config-pool-size int query server read pool size, connection pool is used by regular queries (non streaming, not in a transaction) (default 16) - --queryserver-config-query-cache-lfu query server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries (default true) --queryserver-config-query-cache-memory int query server query cache size in bytes, maximum amount of memory to be used for caching. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. (default 33554432) - --queryserver-config-query-cache-size int query server query cache size, maximum number of queries to be cached. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache. (default 5000) --queryserver-config-query-pool-timeout duration query server query pool timeout (in seconds), it is how long vttablet waits for a connection from the query pool. If set to 0 (default) then the overall query timeout is used instead. (default 0s) --queryserver-config-query-pool-waiter-cap int query server query pool waiter limit, this is the maximum number of queries that can be queued waiting to get a connection (default 5000) --queryserver-config-query-timeout duration query server query timeout (in seconds), this is the query timeout in vttablet side. If a query takes more than this timeout, it will be killed. (default 30s) diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index 5db2078d666..e111e4325f3 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -61,6 +61,7 @@ import ( const ( DefaultCell = "zone1" DefaultStartPort = 6700 + DefaultVttestEnv = "VTTEST=endtoend" ) var ( diff --git a/go/test/endtoend/cluster/mysqlctl_process.go b/go/test/endtoend/cluster/mysqlctl_process.go index c3ab377320a..b5e7cfb5a32 100644 --- a/go/test/endtoend/cluster/mysqlctl_process.go +++ b/go/test/endtoend/cluster/mysqlctl_process.go @@ -148,6 +148,8 @@ ssl_key={{.ServerKey}} } } tmpProcess.Args = append(tmpProcess.Args, "start") + tmpProcess.Env = append(tmpProcess.Env, os.Environ()...) + tmpProcess.Env = append(tmpProcess.Env, DefaultVttestEnv) log.Infof("Starting mysqlctl with command: %v", tmpProcess.Args) return tmpProcess, tmpProcess.Start() } diff --git a/go/test/endtoend/cluster/mysqlctld_process.go b/go/test/endtoend/cluster/mysqlctld_process.go index 4f51e8ca888..9a0f36e3918 100644 --- a/go/test/endtoend/cluster/mysqlctld_process.go +++ b/go/test/endtoend/cluster/mysqlctld_process.go @@ -95,6 +95,7 @@ func (mysqlctld *MysqlctldProcess) Start() error { tempProcess.Stderr = errFile tempProcess.Env = append(tempProcess.Env, os.Environ()...) + tempProcess.Env = append(tempProcess.Env, DefaultVttestEnv) tempProcess.Stdout = os.Stdout tempProcess.Stderr = os.Stderr diff --git a/go/test/endtoend/cluster/topo_process.go b/go/test/endtoend/cluster/topo_process.go index 8fd4bd1c74c..6a9ba1ec438 100644 --- a/go/test/endtoend/cluster/topo_process.go +++ b/go/test/endtoend/cluster/topo_process.go @@ -96,6 +96,7 @@ func (topo *TopoProcess) SetupEtcd() (err error) { topo.proc.Stderr = errFile topo.proc.Env = append(topo.proc.Env, os.Environ()...) + topo.proc.Env = append(topo.proc.Env, DefaultVttestEnv) log.Infof("Starting etcd with command: %v", strings.Join(topo.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtbackup_process.go b/go/test/endtoend/cluster/vtbackup_process.go index be75026bf0d..ba508e8d593 100644 --- a/go/test/endtoend/cluster/vtbackup_process.go +++ b/go/test/endtoend/cluster/vtbackup_process.go @@ -84,6 +84,7 @@ func (vtbackup *VtbackupProcess) Setup() (err error) { vtbackup.proc.Stdout = os.Stdout vtbackup.proc.Env = append(vtbackup.proc.Env, os.Environ()...) + vtbackup.proc.Env = append(vtbackup.proc.Env, DefaultVttestEnv) log.Infof("Running vtbackup with args: %v", strings.Join(vtbackup.proc.Args, " ")) err = vtbackup.proc.Run() diff --git a/go/test/endtoend/cluster/vtctld_process.go b/go/test/endtoend/cluster/vtctld_process.go index 5e85f172ce1..b8f6cf240fc 100644 --- a/go/test/endtoend/cluster/vtctld_process.go +++ b/go/test/endtoend/cluster/vtctld_process.go @@ -74,6 +74,7 @@ func (vtctld *VtctldProcess) Setup(cell string, extraArgs ...string) (err error) vtctld.proc.Stderr = errFile vtctld.proc.Env = append(vtctld.proc.Env, os.Environ()...) + vtctld.proc.Env = append(vtctld.proc.Env, DefaultVttestEnv) log.Infof("Starting vtctld with command: %v", strings.Join(vtctld.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtgate_process.go b/go/test/endtoend/cluster/vtgate_process.go index 3bf29c8af9b..48aecab7c1e 100644 --- a/go/test/endtoend/cluster/vtgate_process.go +++ b/go/test/endtoend/cluster/vtgate_process.go @@ -127,6 +127,7 @@ func (vtgate *VtgateProcess) Setup() (err error) { vtgate.proc.Stderr = errFile vtgate.proc.Env = append(vtgate.proc.Env, os.Environ()...) + vtgate.proc.Env = append(vtgate.proc.Env, DefaultVttestEnv) log.Infof("Running vtgate with command: %v", strings.Join(vtgate.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vtorc_process.go b/go/test/endtoend/cluster/vtorc_process.go index 34c4f3295ab..28c6355a175 100644 --- a/go/test/endtoend/cluster/vtorc_process.go +++ b/go/test/endtoend/cluster/vtorc_process.go @@ -133,6 +133,7 @@ func (orc *VTOrcProcess) Setup() (err error) { orc.proc.Stderr = errFile orc.proc.Env = append(orc.proc.Env, os.Environ()...) + orc.proc.Env = append(orc.proc.Env, DefaultVttestEnv) log.Infof("Running vtorc with command: %v", strings.Join(orc.proc.Args, " ")) diff --git a/go/test/endtoend/cluster/vttablet_process.go b/go/test/endtoend/cluster/vttablet_process.go index eb44a96f346..9f5e513c7a0 100644 --- a/go/test/endtoend/cluster/vttablet_process.go +++ b/go/test/endtoend/cluster/vttablet_process.go @@ -132,6 +132,7 @@ func (vttablet *VttabletProcess) Setup() (err error) { vttablet.proc.Stderr = errFile vttablet.proc.Env = append(vttablet.proc.Env, os.Environ()...) + vttablet.proc.Env = append(vttablet.proc.Env, DefaultVttestEnv) log.Infof("Running vttablet with command: %v", strings.Join(vttablet.proc.Args, " ")) diff --git a/go/test/endtoend/vreplication/helper_test.go b/go/test/endtoend/vreplication/helper_test.go index ad8fd5a96c7..6d032122430 100644 --- a/go/test/endtoend/vreplication/helper_test.go +++ b/go/test/endtoend/vreplication/helper_test.go @@ -20,19 +20,18 @@ import ( "context" "crypto/rand" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" "os/exec" "regexp" "sort" - "strconv" "strings" "sync/atomic" "testing" "time" - "github.com/PuerkitoBio/goquery" "github.com/buger/jsonparser" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -231,12 +230,28 @@ func waitForRowCountInTablet(t *testing.T, vttablet *cluster.VttabletProcess, da } } -func validateThatQueryExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) bool { - count := getQueryCount(tablet.QueryzURL, matchQuery) +func executeOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) (int, []byte, int, []byte) { + queryStatsURL := fmt.Sprintf("http://%s:%d/debug/query_stats", tablet.TabletHostname, tablet.Port) + + count0, body0 := getQueryCount(t, queryStatsURL, matchQuery) + qr := execVtgateQuery(t, conn, ksName, query) require.NotNil(t, qr) - newCount := getQueryCount(tablet.QueryzURL, matchQuery) - return newCount == count+1 + + count1, body1 := getQueryCount(t, queryStatsURL, matchQuery) + return count0, body0, count1, body1 +} + +func assertQueryExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) { + t.Helper() + count0, body0, count1, body1 := executeOnTablet(t, conn, tablet, ksName, query, matchQuery) + assert.Equalf(t, count0+1, count1, "query %q did not execute in target;\ntried to match %q\nbefore:\n%s\n\nafter:\n%s\n\n", query, matchQuery, body0, body1) +} + +func assertQueryDoesNotExecutesOnTablet(t *testing.T, conn *mysql.Conn, tablet *cluster.VttabletProcess, ksName string, query string, matchQuery string) { + t.Helper() + count0, body0, count1, body1 := executeOnTablet(t, conn, tablet, ksName, query, matchQuery) + assert.Equalf(t, count0, count1, "query %q executed in target;\ntried to match %q\nbefore:\n%s\n\nafter:\n%s\n\n", query, matchQuery, body0, body1) } // waitForWorkflowState waits for all of the given workflow's @@ -351,77 +366,36 @@ func confirmTablesHaveSecondaryKeys(t *testing.T, tablets []*cluster.VttabletPro } } -func getHTTPBody(url string) string { +func getHTTPBody(t *testing.T, url string) []byte { resp, err := http.Get(url) - if err != nil { - log.Infof("http Get returns %+v", err) - return "" - } - if resp.StatusCode != 200 { - log.Infof("http Get returns status %d", resp.StatusCode) - return "" - } - respByte, _ := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() - body := string(respByte) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) return body } -func getQueryCount(url string, query string) int { - var headings, row []string - var rows [][]string - body := getHTTPBody(url) - doc, err := goquery.NewDocumentFromReader(strings.NewReader(body)) - if err != nil { - log.Infof("goquery parsing returns %+v\n", err) - return 0 +func getQueryCount(t *testing.T, url string, query string) (int, []byte) { + body := getHTTPBody(t, url) + + var queryStats []struct { + Query string + QueryCount uint64 } - var queryIndex, countIndex, count int - queryIndex = -1 - countIndex = -1 + err := json.Unmarshal(body, &queryStats) + require.NoError(t, err) - doc.Find("table").Each(func(index int, tablehtml *goquery.Selection) { - tablehtml.Find("tr").Each(func(indextr int, rowhtml *goquery.Selection) { - rowhtml.Find("th").Each(func(indexth int, tableheading *goquery.Selection) { - heading := tableheading.Text() - if heading == "Query" { - queryIndex = indexth - } - if heading == "Count" { - countIndex = indexth - } - headings = append(headings, heading) - }) - rowhtml.Find("td").Each(func(indexth int, tablecell *goquery.Selection) { - row = append(row, tablecell.Text()) - }) - rows = append(rows, row) - row = nil - }) - }) - if queryIndex == -1 || countIndex == -1 { - log.Infof("Queryz response is incorrect") - return 0 - } - for _, row := range rows { - if len(row) != len(headings) { - continue - } - filterChars := []string{"_", "`"} - //Queries seem to include non-printable characters at times and hence equality fails unless these are removed - re := regexp.MustCompile("[[:^ascii:]]") - foundQuery := re.ReplaceAllLiteralString(row[queryIndex], "") - cleanQuery := re.ReplaceAllLiteralString(query, "") - for _, filterChar := range filterChars { - foundQuery = strings.ReplaceAll(foundQuery, filterChar, "") - cleanQuery = strings.ReplaceAll(cleanQuery, filterChar, "") - } - if foundQuery == cleanQuery || strings.Contains(foundQuery, cleanQuery) { - count, _ = strconv.Atoi(row[countIndex]) + for _, q := range queryStats { + if strings.Contains(q.Query, query) { + return int(q.QueryCount), body } } - return count + + return 0, body } func validateDryRunResults(t *testing.T, output string, want []string) { @@ -530,8 +504,8 @@ func getDebugVar(t *testing.T, port int, varPath []string) (string, error) { var err error url := fmt.Sprintf("http://localhost:%d/debug/vars", port) log.Infof("url: %s, varPath: %s", url, strings.Join(varPath, ":")) - body := getHTTPBody(url) - val, _, _, err = jsonparser.Get([]byte(body), varPath...) + body := getHTTPBody(t, url) + val, _, _, err = jsonparser.Get(body, varPath...) require.NoError(t, err) return string(val), nil } diff --git a/go/test/endtoend/vreplication/resharding_workflows_v2_test.go b/go/test/endtoend/vreplication/resharding_workflows_v2_test.go index ec20e1d92ca..993da344905 100644 --- a/go/test/endtoend/vreplication/resharding_workflows_v2_test.go +++ b/go/test/endtoend/vreplication/resharding_workflows_v2_test.go @@ -218,7 +218,7 @@ func validateReadsRoute(t *testing.T, tabletTypes string, tablet *cluster.Vttabl for _, tt := range []string{"replica", "rdonly"} { destination := fmt.Sprintf("%s:%s@%s", tablet.Keyspace, tablet.Shard, tt) if strings.Contains(tabletTypes, tt) { - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, tablet, destination, readQuery, readQuery)) + assertQueryExecutesOnTablet(t, vtgateConn, tablet, destination, readQuery, readQuery) } } } @@ -233,17 +233,17 @@ func validateReadsRouteToTarget(t *testing.T, tabletTypes string) { func validateWritesRouteToSource(t *testing.T) { insertQuery := "insert into customer(name, cid) values('tempCustomer2', 200)" - matchInsertQuery := "insert into customer(name, cid) values" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, sourceTab, "customer", insertQuery, matchInsertQuery)) + matchInsertQuery := "insert into customer(`name`, cid) values" + assertQueryExecutesOnTablet(t, vtgateConn, sourceTab, "customer", insertQuery, matchInsertQuery) execVtgateQuery(t, vtgateConn, "customer", "delete from customer where cid > 100") } func validateWritesRouteToTarget(t *testing.T) { insertQuery := "insert into customer(name, cid) values('tempCustomer3', 101)" - matchInsertQuery := "insert into customer(name, cid) values" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, targetTab2, "customer", insertQuery, matchInsertQuery)) + matchInsertQuery := "insert into customer(`name`, cid) values" + assertQueryExecutesOnTablet(t, vtgateConn, targetTab2, "customer", insertQuery, matchInsertQuery) insertQuery = "insert into customer(name, cid) values('tempCustomer3', 102)" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, targetTab1, "customer", insertQuery, matchInsertQuery)) + assertQueryExecutesOnTablet(t, vtgateConn, targetTab1, "customer", insertQuery, matchInsertQuery) execVtgateQuery(t, vtgateConn, "customer", "delete from customer where cid > 100") } diff --git a/go/test/endtoend/vreplication/vdiff_helper_test.go b/go/test/endtoend/vreplication/vdiff_helper_test.go index 9f5ce973c40..982ea04c957 100644 --- a/go/test/endtoend/vreplication/vdiff_helper_test.go +++ b/go/test/endtoend/vreplication/vdiff_helper_test.go @@ -70,7 +70,7 @@ func doVDiff1(t *testing.T, ksWorkflow, cells string) { diffReports := make(map[string]*wrangler.DiffReport) t.Logf("vdiff1 output: %s", output) err = json.Unmarshal([]byte(output), &diffReports) - require.NoError(t, err) + require.NoErrorf(t, err, "full output: %s", output) if len(diffReports) < 1 { t.Fatal("VDiff did not return a valid json response " + output + "\n") } diff --git a/go/test/endtoend/vreplication/vreplication_test.go b/go/test/endtoend/vreplication/vreplication_test.go index 12f8ee7dad6..38c7aa8faa3 100644 --- a/go/test/endtoend/vreplication/vreplication_test.go +++ b/go/test/endtoend/vreplication/vreplication_test.go @@ -787,10 +787,10 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl }) query := "select cid from customer" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", query, query)) + assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", query, query) insertQuery1 := "insert into customer(cid, name) values(1001, 'tempCustomer1')" matchInsertQuery1 := "insert into customer(cid, `name`) values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */)" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1)) + assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1) // FIXME for some reason, these inserts fails on mac, need to investigate, some // vreplication bug because of case insensitiveness of table names on mac? @@ -809,7 +809,7 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl vdiff1(t, ksWorkflow, "") switchReadsDryRun(t, workflowType, allCellNames, ksWorkflow, dryRunResultsReadCustomerShard) switchReads(t, workflowType, allCellNames, ksWorkflow, false) - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", query, query)) + assertQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", query, query) var commit func(t *testing.T) if withOpenTx { @@ -840,14 +840,14 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl ksShards := []string{"product/0", "customer/-80", "customer/80-"} printShardPositions(vc, ksShards) insertQuery2 := "insert into customer(name, cid) values('tempCustomer2', 100)" - matchInsertQuery2 := "insert into customer(`name`, cid) values (:vtg1 /* VARCHAR */, :_cid0)" - require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2)) + matchInsertQuery2 := "insert into customer(`name`, cid) values (:vtg1 /* VARCHAR */, :_cid_0)" + assertQueryDoesNotExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2) insertQuery2 = "insert into customer(name, cid) values('tempCustomer3', 101)" // ID 101, hence due to reverse_bits in shard 80- - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2)) + assertQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2) insertQuery2 = "insert into customer(name, cid) values('tempCustomer4', 102)" // ID 102, hence due to reverse_bits in shard -80 - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2)) + assertQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2) execVtgateQuery(t, vtgateConn, "customer", "update customer set meta = convert(x'7b7d' using utf8mb4) where cid = 1") if testReverse { @@ -862,12 +862,12 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl require.Contains(t, output, "'customer.bmd5'") insertQuery1 = "insert into customer(cid, name) values(1002, 'tempCustomer5')" - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1)) + assertQueryExecutesOnTablet(t, vtgateConn, productTab, "product", insertQuery1, matchInsertQuery1) // both inserts go into 80-, this tests the edge-case where a stream (-80) has no relevant new events after the previous switch insertQuery1 = "insert into customer(cid, name) values(1003, 'tempCustomer6')" - require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery1, matchInsertQuery1)) + assertQueryDoesNotExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery1, matchInsertQuery1) insertQuery1 = "insert into customer(cid, name) values(1004, 'tempCustomer7')" - require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery1, matchInsertQuery1)) + assertQueryDoesNotExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery1, matchInsertQuery1) waitForNoWorkflowLag(t, vc, targetKs, workflow) @@ -902,11 +902,11 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl require.True(t, found) insertQuery2 = "insert into customer(name, cid) values('tempCustomer8', 103)" // ID 103, hence due to reverse_bits in shard 80- - require.False(t, validateThatQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2)) + assertQueryDoesNotExecutesOnTablet(t, vtgateConn, productTab, "customer", insertQuery2, matchInsertQuery2) insertQuery2 = "insert into customer(name, cid) values('tempCustomer10', 104)" // ID 105, hence due to reverse_bits in shard -80 - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2)) + assertQueryExecutesOnTablet(t, vtgateConn, customerTab1, "customer", insertQuery2, matchInsertQuery2) insertQuery2 = "insert into customer(name, cid) values('tempCustomer9', 105)" // ID 104, hence due to reverse_bits in shard 80- - require.True(t, validateThatQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2)) + assertQueryExecutesOnTablet(t, vtgateConn, customerTab2, "customer", insertQuery2, matchInsertQuery2) execVtgateQuery(t, vtgateConn, "customer", "delete from customer where name like 'tempCustomer%'") waitForRowCountInTablet(t, customerTab1, "customer", "customer", 1) diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 4fa4313e76c..52e30accf03 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -25,11 +25,11 @@ import ( "testing" "time" - "vitess.io/vitess/go/test/endtoend/utils" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/mysql" ) @@ -46,38 +46,24 @@ func TestNormalizeAllFields(t *testing.T) { assert.Equal(t, 1, len(qr.Rows), "wrong number of table rows, expected 1 but had %d. Results: %v", len(qr.Rows), qr.Rows) // Now need to figure out the best way to check the normalized query in the planner cache... - results, err := getPlanCache(fmt.Sprintf("%s:%d", vtParams.Host, clusterInstance.VtgateProcess.Port)) - require.Nil(t, err) - found := false - for _, record := range results { - key := record["Key"].(string) - if key == normalizedInsertQuery { - found = true - break - } - } - assert.Truef(t, found, "correctly normalized record not found in planner cache %v", results) + results := getPlanCache(t, fmt.Sprintf("%s:%d", vtParams.Host, clusterInstance.VtgateProcess.Port)) + assert.Contains(t, results, normalizedInsertQuery) } -func getPlanCache(vtgateHostPort string) ([]map[string]any, error) { - var results []map[string]any +func getPlanCache(t *testing.T, vtgateHostPort string) map[string]any { + var results map[string]any client := http.Client{ Timeout: 10 * time.Second, } resp, err := client.Get(fmt.Sprintf("http://%s/debug/query_plans", vtgateHostPort)) - if err != nil { - return results, err - } + require.NoError(t, err) defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) - if err != nil { - return results, err - } + require.NoError(t, err) err = json.Unmarshal(body, &results) - if err != nil { - return results, err - } + require.NoErrorf(t, err, "failed to unmarshal results. contents:\n%s\n\n", body) - return results, nil + return results } diff --git a/go/vt/servenv/servenv.go b/go/vt/servenv/servenv.go index 06942d4f4b9..0c1deddeb10 100644 --- a/go/vt/servenv/servenv.go +++ b/go/vt/servenv/servenv.go @@ -516,3 +516,10 @@ func RegisterFlagsForTopoBinaries(registerFlags func(fs *pflag.FlagSet)) { OnParseFor(cmd, registerFlags) } } + +// TestingEndtoend is true when this Vitess binary is being ran as part of an endtoend test suite +var TestingEndtoend = false + +func init() { + TestingEndtoend = os.Getenv("VTTEST") == "endtoend" +} diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go index a98e6c3a724..bbeb99e0e36 100644 --- a/go/vt/vtexplain/vtexplain_vtgate.go +++ b/go/vt/vtexplain/vtexplain_vtgate.go @@ -25,10 +25,10 @@ import ( "sort" "strings" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/vt/vtgate/logstats" "vitess.io/vitess/go/vt/vtgate/vindexes" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" @@ -74,8 +74,9 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar streamSize := 10 var schemaTracker vtgate.SchemaInfo // no schema tracker for these tests queryLogBufferSize := 10 - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - vte.vtgateExecutor = vtgate.NewExecutor(ctx, vte.explainTopo, vtexplainCell, resolver, opts.Normalize, false, streamSize, plans, schemaTracker, false, opts.PlannerVersion, streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)) + plans := theine.NewStore[vtgate.PlanCacheKey, *engine.Plan](4*1024*1024, false) + vte.vtgateExecutor = vtgate.NewExecutor(ctx, vte.explainTopo, vtexplainCell, resolver, opts.Normalize, false, streamSize, plans, schemaTracker, false, opts.PlannerVersion) + vte.vtgateExecutor.SetQueryLogger(streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize)) return nil } @@ -207,29 +208,27 @@ func (vte *VTExplain) vtgateExecute(sql string) ([]*engine.Plan, map[string]*Tab // This will ensure that the commit/rollback order is predictable. vte.sortShardSession() - // use the plan cache to get the set of plans used for this query, then - // clear afterwards for the next run - planCache := vte.vtgateExecutor.Plans() - _, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", vtgate.NewSafeSession(vte.vtgateSession), sql, nil) if err != nil { for _, tc := range vte.explainTopo.TabletConns { tc.tabletQueries = nil tc.mysqlQueries = nil } - planCache.Clear() - + vte.vtgateExecutor.ClearPlans() return nil, nil, vterrors.Wrapf(err, "vtexplain execute error in '%s'", sql) } var plans []*engine.Plan - planCache.ForEach(func(value any) bool { - plan := value.(*engine.Plan) + + // use the plan cache to get the set of plans used for this query, then + // clear afterwards for the next run + vte.vtgateExecutor.ForEachPlan(func(plan *engine.Plan) bool { plan.ExecTime = 0 plans = append(plans, plan) return true }) - planCache.Clear() + + vte.vtgateExecutor.ClearPlans() tabletActions := make(map[string]*TabletActions) for shard, tc := range vte.explainTopo.TabletConns { diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 07b1ba79f42..5b94183a950 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -17,11 +17,8 @@ limitations under the License. package vtgate import ( - "bufio" "bytes" "context" - "crypto/sha256" - "encoding/hex" "encoding/json" "fmt" "io" @@ -33,10 +30,11 @@ import ( "github.com/spf13/pflag" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/streamlog" + "vitess.io/vitess/go/vt/vthash" "vitess.io/vitess/go/acl" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" @@ -105,9 +103,11 @@ type Executor struct { mu sync.Mutex vschema *vindexes.VSchema streamSize int - plans cache.Cache vschemaStats *VSchemaStats + plans *PlanCache + epoch atomic.Uint32 + normalize bool warnShardedOnly bool @@ -127,6 +127,15 @@ const pathQueryPlans = "/debug/query_plans" const pathScatterStats = "/debug/scatter_stats" const pathVSchema = "/debug/vschema" +type PlanCacheKey = theine.HashKey256 +type PlanCache = theine.Store[PlanCacheKey, *engine.Plan] + +func DefaultPlanCache() *PlanCache { + // when being endtoend tested, disable the doorkeeper to ensure reproducible results + doorkeeper := !servenv.TestingEndtoend + return theine.NewStore[PlanCacheKey, *engine.Plan](queryPlanCacheMemory, doorkeeper) +} + // NewExecutor creates a new Executor. func NewExecutor( ctx context.Context, @@ -135,11 +144,10 @@ func NewExecutor( resolver *Resolver, normalize, warnOnShardedOnly bool, streamSize int, - plans cache.Cache, + plans *PlanCache, schemaTracker SchemaInfo, noScatter bool, pv plancontext.PlannerVersion, - queryLogger *streamlog.StreamLogger[*logstats.LogStats], ) *Executor { e := &Executor{ serv: serv, @@ -147,14 +155,13 @@ func NewExecutor( resolver: resolver, scatterConn: resolver.scatterConn, txConn: resolver.scatterConn.txConn, - plans: plans, normalize: normalize, warnShardedOnly: warnOnShardedOnly, streamSize: streamSize, schemaTracker: schemaTracker, allowScatter: !noScatter, pv: pv, - queryLogger: queryLogger, + plans: plans, } vschemaacl.Init() @@ -172,19 +179,19 @@ func NewExecutor( return int64(e.plans.Len()) }) stats.NewGaugeFunc("QueryPlanCacheSize", "Query plan cache size", func() int64 { - return e.plans.UsedCapacity() + return int64(e.plans.UsedCapacity()) }) stats.NewGaugeFunc("QueryPlanCacheCapacity", "Query plan cache capacity", func() int64 { - return e.plans.MaxCapacity() + return int64(e.plans.MaxCapacity()) }) stats.NewCounterFunc("QueryPlanCacheEvictions", "Query plan cache evictions", func() int64 { - return e.plans.Evictions() + return e.plans.Metrics.Evicted() }) stats.NewCounterFunc("QueryPlanCacheHits", "Query plan cache hits", func() int64 { - return e.plans.Hits() + return e.plans.Metrics.Hits() }) stats.NewCounterFunc("QueryPlanCacheMisses", "Query plan cache misses", func() int64 { - return e.plans.Misses() + return e.plans.Metrics.Hits() }) servenv.HTTPHandle(pathQueryPlans, e) servenv.HTTPHandle(pathScatterStats, e) @@ -977,7 +984,7 @@ func (e *Executor) SaveVSchema(vschema *vindexes.VSchema, stats *VSchemaStats) { e.vschema = vschema } e.vschemaStats = stats - e.plans.Clear() + e.ClearPlans() if vschemaCounters != nil { vschemaCounters.Add("Reload", 1) @@ -1070,37 +1077,23 @@ func (e *Executor) getPlan( return e.cacheAndBuildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds, logStats) } -func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query string) string { - planHash := sha256.New() - - { - // use a bufio.Writer to accumulate writes instead of writing directly to the hasher - buf := bufio.NewWriter(planHash) - vcursor.keyForPlan(ctx, query, buf) - buf.Flush() - } +func (e *Executor) hashPlan(ctx context.Context, vcursor *vcursorImpl, query string) PlanCacheKey { + hasher := vthash.New256() + vcursor.keyForPlan(ctx, query, hasher) - return hex.EncodeToString(planHash.Sum(nil)) + var planKey PlanCacheKey + hasher.Sum(planKey[:0]) + return planKey } -func (e *Executor) cacheAndBuildStatement( +func (e *Executor) buildStatement( ctx context.Context, vcursor *vcursorImpl, query string, stmt sqlparser.Statement, reservedVars *sqlparser.ReservedVars, bindVarNeeds *sqlparser.BindVarNeeds, - logStats *logstats.LogStats, ) (*engine.Plan, error) { - planKey := e.hashPlan(ctx, vcursor, query) - planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan() - if planCachable { - if plan, ok := e.plans.Get(planKey); ok { - logStats.CachedPlan = true - return plan.(*engine.Plan), nil - } - } - plan, err := planbuilder.BuildFromStmt(ctx, query, stmt, reservedVars, vcursor, bindVarNeeds, enableOnlineDDL, enableDirectDDL) if err != nil { return nil, err @@ -1110,13 +1103,32 @@ func (e *Executor) cacheAndBuildStatement( vcursor.warnings = nil err = e.checkThatPlanIsValid(stmt, plan) - // Only cache the plan if it is valid (i.e. does not scatter) - if err == nil && planCachable { - e.plans.Set(planKey, plan) - } return plan, err } +func (e *Executor) cacheAndBuildStatement( + ctx context.Context, + vcursor *vcursorImpl, + query string, + stmt sqlparser.Statement, + reservedVars *sqlparser.ReservedVars, + bindVarNeeds *sqlparser.BindVarNeeds, + logStats *logstats.LogStats, +) (*engine.Plan, error) { + planCachable := sqlparser.CachePlan(stmt) && vcursor.safeSession.cachePlan() + if planCachable { + planKey := e.hashPlan(ctx, vcursor, query) + + var plan *engine.Plan + var err error + plan, logStats.CachedPlan, err = e.plans.GetOrLoad(planKey, e.epoch.Load(), func() (*engine.Plan, error) { + return e.buildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds) + }) + return plan, err + } + return e.buildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds) +} + func (e *Executor) canNormalizeStatement(stmt sqlparser.Statement, setVarComment string) bool { return sqlparser.CanNormalize(stmt) || setVarComment != "" } @@ -1149,18 +1161,10 @@ func prepareSetVarComment(vcursor *vcursorImpl, stmt sqlparser.Statement) (strin return strings.TrimSpace(res.String()), nil } -type cacheItem struct { - Key string - Value *engine.Plan -} - -func (e *Executor) debugCacheEntries() (items []cacheItem) { - e.plans.ForEach(func(value any) bool { - plan := value.(*engine.Plan) - items = append(items, cacheItem{ - Key: plan.Original, - Value: plan, - }) +func (e *Executor) debugCacheEntries() (items map[string]*engine.Plan) { + items = make(map[string]*engine.Plan) + e.ForEachPlan(func(plan *engine.Plan) bool { + items[plan.Original] = plan return true }) return @@ -1198,10 +1202,20 @@ func returnAsJSON(response http.ResponseWriter, stuff any) { } // Plans returns the LRU plan cache -func (e *Executor) Plans() cache.Cache { +func (e *Executor) Plans() *PlanCache { return e.plans } +func (e *Executor) ForEachPlan(each func(plan *engine.Plan) bool) { + e.plans.Range(e.epoch.Load(), func(_ PlanCacheKey, value *engine.Plan) bool { + return each(value) + }) +} + +func (e *Executor) ClearPlans() { + e.epoch.Add(1) +} + func (e *Executor) updateQueryCounts(planType, keyspace, tableName string, shardQueries int64) { queriesProcessed.Add(planType, 1) queriesRouted.Add(planType, shardQueries) diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index a1e1e91d53b..107215d6f4d 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -28,9 +28,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/test/utils" + "vitess.io/vitess/go/vt/vtgate/engine" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/constants/sidecar" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/streamlog" @@ -177,8 +178,15 @@ func createExecutorEnv(t testing.TB) (executor *Executor, sbc1, sbc2, sbclookup _ = hc.AddTestTablet(cell, "2", 3, KsTestUnsharded, "0", topodatapb.TabletType_REPLICA, true, 1, nil) queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + + // All these vtgate tests expect plans to be immediately cached after first use; + // this is not the actual behavior of the system in a production context because we use a doorkeeper + // that sometimes can cause a plan to not be cached the very first time it's seen, to prevent + // one-off queries from thrashing the cache. Disable the doorkeeper in the tests to prevent flakiness. + plans := theine.NewStore[PlanCacheKey, *engine.Plan](queryPlanCacheMemory, false) + + executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + executor.SetQueryLogger(queryLogger) key.AnyShardPicker = DestinationAnyShardPickerFirstShard{} @@ -209,8 +217,9 @@ func createCustomExecutor(t testing.TB, vschema string) (executor *Executor, sbc sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + plans := DefaultPlanCache() + executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + executor.SetQueryLogger(queryLogger) t.Cleanup(func() { defer utils.EnsureNoLeaks(t) @@ -246,8 +255,9 @@ func createCustomExecutorSetValues(t testing.TB, vschema string, values []*sqlty sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + plans := DefaultPlanCache() + executor = NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + executor.SetQueryLogger(queryLogger) t.Cleanup(func() { defer utils.EnsureNoLeaks(t) diff --git a/go/vt/vtgate/executor_scatter_stats.go b/go/vt/vtgate/executor_scatter_stats.go index cfe0b7435f2..beaa60d7012 100644 --- a/go/vt/vtgate/executor_scatter_stats.go +++ b/go/vt/vtgate/executor_scatter_stats.go @@ -62,8 +62,7 @@ func (e *Executor) gatherScatterStats() (statsResults, error) { plans := make([]*engine.Plan, 0) routes := make([]*engine.Route, 0) // First we go over all plans and collect statistics and all query plans for scatter queries - e.plans.ForEach(func(value any) bool { - plan := value.(*engine.Plan) + e.ForEachPlan(func(plan *engine.Plan) bool { scatter := engine.Find(findScatter, plan.Instructions) readOnly := !engine.Exists(isUpdating, plan.Instructions) isScatter := scatter != nil diff --git a/go/vt/vtgate/executor_scatter_stats_test.go b/go/vt/vtgate/executor_scatter_stats_test.go index 9a9e54517e7..84dd2744e8b 100644 --- a/go/vt/vtgate/executor_scatter_stats_test.go +++ b/go/vt/vtgate/executor_scatter_stats_test.go @@ -19,6 +19,7 @@ package vtgate import ( "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" @@ -67,7 +68,7 @@ func TestScatterStatsHttpWriting(t *testing.T) { _, err = executor.Execute(ctx, nil, "TestExecutorResultsExceeded", session, query4, nil) require.NoError(t, err) - executor.plans.Wait() + time.Sleep(500 * time.Millisecond) recorder := httptest.NewRecorder() executor.WriteScatterStats(recorder) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index b8d2625ec04..4b99eef6c33 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -37,7 +37,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/discovery" @@ -1561,8 +1560,10 @@ func TestStreamSelectIN(t *testing.T) { func createExecutor(ctx context.Context, serv *sandboxTopo, cell string, resolver *Resolver) *Executor { queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - return NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + plans := DefaultPlanCache() + ex := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + ex.SetQueryLogger(queryLogger) + return ex } func TestSelectScatter(t *testing.T) { @@ -3186,8 +3187,9 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { } queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - executor := NewExecutor(ctx, serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + plans := DefaultPlanCache() + executor := NewExecutor(ctx, serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + executor.SetQueryLogger(queryLogger) defer executor.Close() // some sleep for all goroutines to start time.Sleep(100 * time.Millisecond) @@ -4094,12 +4096,12 @@ func TestSelectCFC(t *testing.T) { for { select { case <-timeout: - t.Fatal("not able to cache a plan withing 10 seconds.") + t.Fatal("not able to cache a plan within 30 seconds.") case <-time.After(5 * time.Millisecond): // should be able to find cache entry before the timeout. cacheItems := executor.debugCacheEntries() for _, item := range cacheItems { - if strings.Contains(item.Key, "c2 from tbl_cfc where c1 like") { + if strings.Contains(item.Original, "c2 from tbl_cfc where c1 like") { return } } diff --git a/go/vt/vtgate/executor_stream_test.go b/go/vt/vtgate/executor_stream_test.go index 6f6dcd9f6b4..e5c730eb157 100644 --- a/go/vt/vtgate/executor_stream_test.go +++ b/go/vt/vtgate/executor_stream_test.go @@ -27,7 +27,6 @@ import ( vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/vtgate/logstats" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/vt/discovery" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -67,8 +66,11 @@ func TestStreamSQLSharded(t *testing.T) { _ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) } queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := cache.NewDefaultCacheImpl(cache.DefaultConfig) - executor := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, queryLogger) + plans := DefaultPlanCache() + + executor := NewExecutor(ctx, serv, cell, resolver, false, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4) + executor.SetQueryLogger(queryLogger) + defer executor.Close() sql := "stream * from sharded_user_msgs" diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 17d73997125..2ab45f1ef42 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -28,6 +28,7 @@ import ( "sort" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/safehtml/template" @@ -35,7 +36,6 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" @@ -1665,13 +1665,9 @@ func TestGetPlanUnnormalized(t *testing.T) { } } -func assertCacheSize(t *testing.T, c cache.Cache, expected int) { +func assertCacheSize(t *testing.T, c *PlanCache, expected int) { t.Helper() - var size int - c.ForEach(func(_ any) bool { - size++ - return true - }) + size := c.Len() if size != expected { t.Errorf("getPlan() expected cache to have size %d, but got: %d", expected, size) } @@ -1682,8 +1678,7 @@ func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string) var plan *engine.Plan if vc == nil { - e.plans.ForEach(func(x any) bool { - p := x.(*engine.Plan) + e.ForEachPlan(func(p *engine.Plan) bool { if p.Original == sql { plan = p } @@ -1691,9 +1686,7 @@ func assertCacheContains(t *testing.T, e *Executor, vc *vcursorImpl, sql string) }) } else { h := e.hashPlan(context.Background(), vc, sql) - if p, ok := e.plans.Get(h); ok { - plan = p.(*engine.Plan) - } + plan, _ = e.plans.Get(h, e.epoch.Load()) } require.Truef(t, plan != nil, "plan not found for query: %s", sql) return plan @@ -1712,7 +1705,7 @@ func getPlanCached(t *testing.T, ctx context.Context, e *Executor, vcursor *vcur require.NoError(t, err) // Wait for cache to settle - e.plans.Wait() + time.Sleep(100 * time.Millisecond) return plan, logStats } @@ -2147,7 +2140,7 @@ func TestServingKeyspaces(t *testing.T) { hc.BroadcastAll() // Clear plan cache, to force re-planning of the query. - executor.plans.Clear() + executor.ClearPlans() require.ElementsMatch(t, []string{"TestUnsharded"}, gw.GetServingKeyspaces()) result, err = executor.Execute(ctx, nil, "TestServingKeyspaces", NewSafeSession(&vtgatepb.Session{}), "select keyspace_name from dual", nil) require.NoError(t, err) diff --git a/go/vt/vtgate/querylog.go b/go/vt/vtgate/querylog.go index 94c1b23263b..7425f2feba9 100644 --- a/go/vt/vtgate/querylog.go +++ b/go/vt/vtgate/querylog.go @@ -19,7 +19,6 @@ package vtgate import ( "net/http" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vtgate/logstats" @@ -36,7 +35,7 @@ var ( QueryzHandler = "/debug/queryz" ) -func initQueryLogger(plans cache.Cache) (*streamlog.StreamLogger[*logstats.LogStats], error) { +func (e *Executor) defaultQueryLogger() error { queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) queryLogger.ServeLogs(QueryLogHandler, streamlog.GetFormatter(queryLogger)) @@ -47,15 +46,20 @@ func initQueryLogger(plans cache.Cache) (*streamlog.StreamLogger[*logstats.LogSt }) servenv.HTTPHandleFunc(QueryzHandler, func(w http.ResponseWriter, r *http.Request) { - queryzHandler(plans, w, r) + queryzHandler(e, w, r) }) if queryLogToFile != "" { _, err := queryLogger.LogToFile(queryLogToFile, streamlog.GetFormatter(queryLogger)) if err != nil { - return nil, err + return err } } - return queryLogger, nil + e.queryLogger = queryLogger + return nil +} + +func (e *Executor) SetQueryLogger(ql *streamlog.StreamLogger[*logstats.LogStats]) { + e.queryLogger = ql } diff --git a/go/vt/vtgate/queryz.go b/go/vt/vtgate/queryz.go index 6c5d2b89dee..e546fc68c6f 100644 --- a/go/vt/vtgate/queryz.go +++ b/go/vt/vtgate/queryz.go @@ -24,8 +24,6 @@ import ( "github.com/google/safehtml/template" - "vitess.io/vitess/go/cache" - "vitess.io/vitess/go/acl" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/logz" @@ -129,7 +127,7 @@ func (s *queryzSorter) Len() int { return len(s.rows) } func (s *queryzSorter) Swap(i, j int) { s.rows[i], s.rows[j] = s.rows[j], s.rows[i] } func (s *queryzSorter) Less(i, j int) bool { return s.less(s.rows[i], s.rows[j]) } -func queryzHandler(plans cache.Cache, w http.ResponseWriter, r *http.Request) { +func queryzHandler(e *Executor, w http.ResponseWriter, r *http.Request) { if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil { acl.SendError(w, err) return @@ -145,8 +143,7 @@ func queryzHandler(plans cache.Cache, w http.ResponseWriter, r *http.Request) { }, } - plans.ForEach(func(value any) bool { - plan := value.(*engine.Plan) + e.ForEachPlan(func(plan *engine.Plan) bool { Value := &queryzRow{ Query: logz.Wrappable(sqlparser.TruncateForUI(plan.Original)), } diff --git a/go/vt/vtgate/queryz_test.go b/go/vt/vtgate/queryz_test.go index 83fd064df7d..826cb8641d8 100644 --- a/go/vt/vtgate/queryz_test.go +++ b/go/vt/vtgate/queryz_test.go @@ -46,7 +46,7 @@ func TestQueryzHandler(t *testing.T) { sql := "select id from user where id = 1" _, err := executorExec(ctx, executor, session, sql, nil) require.NoError(t, err) - executor.plans.Wait() + time.Sleep(100 * time.Millisecond) plan1 := assertCacheContains(t, executor, nil, "select id from `user` where id = 1") plan1.ExecTime = uint64(1 * time.Millisecond) @@ -54,7 +54,7 @@ func TestQueryzHandler(t *testing.T) { sql = "select id from user" _, err = executorExec(ctx, executor, session, sql, nil) require.NoError(t, err) - executor.plans.Wait() + time.Sleep(100 * time.Millisecond) plan2 := assertCacheContains(t, executor, nil, "select id from `user`") plan2.ExecTime = uint64(1 * time.Second) @@ -64,7 +64,7 @@ func TestQueryzHandler(t *testing.T) { "name": sqltypes.BytesBindVariable([]byte("myname")), }) require.NoError(t, err) - executor.plans.Wait() + time.Sleep(100 * time.Millisecond) plan3 := assertCacheContains(t, executor, nil, "insert into `user`(id, `name`) values (:id, :name)") // vindex insert from above execution @@ -82,7 +82,7 @@ func TestQueryzHandler(t *testing.T) { plan3.ExecTime = uint64(100 * time.Millisecond) plan4.ExecTime = uint64(200 * time.Millisecond) - queryzHandler(executor.plans, resp, req) + queryzHandler(executor, resp, req) body, _ := io.ReadAll(resp.Body) planPattern1 := []string{ ``, diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 4f8cffceb12..8175161e37f 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -31,7 +31,6 @@ import ( "github.com/spf13/pflag" "vitess.io/vitess/go/acl" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/tb" @@ -65,9 +64,7 @@ var ( truncateErrorLen int // plan cache related flag - queryPlanCacheSize = cache.DefaultConfig.MaxEntries - queryPlanCacheMemory = cache.DefaultConfig.MaxMemoryUsage - queryPlanCacheLFU bool + queryPlanCacheMemory int64 = 32 * 1024 * 1024 // 32mb maxMemoryRows = 300000 warnMemoryRows = 30000 @@ -122,9 +119,7 @@ func registerFlags(fs *pflag.FlagSet) { fs.BoolVar(&terseErrors, "vtgate-config-terse-errors", terseErrors, "prevent bind vars from escaping in returned errors") fs.IntVar(&truncateErrorLen, "truncate-error-len", truncateErrorLen, "truncate errors sent to client if they are longer than this value (0 means do not truncate)") fs.IntVar(&streamBufferSize, "stream_buffer_size", streamBufferSize, "the number of bytes sent from vtgate for each stream call. It's recommended to keep this value in sync with vttablet's query-server-config-stream-buffer-size.") - fs.Int64Var(&queryPlanCacheSize, "gate_query_cache_size", queryPlanCacheSize, "gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a cache. This config controls the expected amount of unique entries in the cache.") fs.Int64Var(&queryPlanCacheMemory, "gate_query_cache_memory", queryPlanCacheMemory, "gate server query cache size in bytes, maximum amount of memory to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") - fs.BoolVar(&queryPlanCacheLFU, "gate_query_cache_lfu", cache.DefaultConfig.LFU, "gate server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries") fs.IntVar(&maxMemoryRows, "max_memory_rows", maxMemoryRows, "Maximum number of rows that will be held in memory for intermediate results as well as the final result.") fs.IntVar(&warnMemoryRows, "warn_memory_rows", warnMemoryRows, "Warning threshold for in-memory results. A row count higher than this amount will cause the VtGateWarnings.ResultsExceeded counter to be incremented.") fs.StringVar(&defaultDDLStrategy, "ddl_strategy", defaultDDLStrategy, "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") @@ -152,6 +147,8 @@ func registerFlags(fs *pflag.FlagSet) { _ = fs.String("schema_change_signal_user", "", "User to be used to send down query to vttablet to retrieve schema changes") _ = fs.MarkDeprecated("schema_change_signal_user", "schema tracking uses an internal api and does not require a user to be specified") + _ = fs.MarkDeprecated("gate_query_cache_lfu", "gate server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries") + _ = fs.MarkDeprecated("gate_query_cache_size", "gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a cache. This config controls the expected amount of unique entries in the cache.") } func init() { servenv.OnParseFor("vtgate", registerFlags) @@ -304,17 +301,7 @@ func Init( si = st } - cacheCfg := &cache.Config{ - MaxEntries: queryPlanCacheSize, - MaxMemoryUsage: queryPlanCacheMemory, - LFU: queryPlanCacheLFU, - } - - plans := cache.NewDefaultCacheImpl(cacheCfg) - queryLogger, err := initQueryLogger(plans) - if err != nil { - log.Fatalf("error initializing query logger: %v", err) - } + plans := DefaultPlanCache() executor := NewExecutor( ctx, @@ -328,9 +315,12 @@ func Init( si, noScatter, pv, - queryLogger, ) + if err := executor.defaultQueryLogger(); err != nil { + log.Fatalf("error initializing query logger: %v", err) + } + // connect the schema tracker with the vschema manager if enableSchemaChangeSignal { st.RegisterSignalReceiver(executor.vm.Rebuild) diff --git a/go/vt/vthash/hash.go b/go/vt/vthash/hash.go index 7b6a130dc08..3dbd85af6a3 100644 --- a/go/vt/vthash/hash.go +++ b/go/vt/vthash/hash.go @@ -17,6 +17,7 @@ limitations under the License. package vthash import ( + "vitess.io/vitess/go/vt/vthash/highway" "vitess.io/vitess/go/vt/vthash/metro" ) @@ -28,3 +29,12 @@ func New() Hasher { h.Reset() return h } + +type Hasher256 = highway.Digest +type Hash256 = [32]byte + +var defaultHash256Key = [32]byte{} + +func New256() *Hasher256 { + return highway.New(defaultHash256Key) +} diff --git a/go/vt/vthash/highway/LICENSE b/go/vt/vthash/highway/LICENSE new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/go/vt/vthash/highway/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/go/vt/vthash/highway/highwayhash.go b/go/vt/vthash/highway/highwayhash.go new file mode 100644 index 00000000000..a922b435d9d --- /dev/null +++ b/go/vt/vthash/highway/highwayhash.go @@ -0,0 +1,184 @@ +/* +Copyright 2023 The Vitess Authors. +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package highwayhash implements the pseudo-random-function (PRF) HighwayHash. +// HighwayHash is a fast hash function designed to defend hash-flooding attacks +// or to authenticate short-lived messages. +// +// HighwayHash is not a general purpose cryptographic hash function and does not +// provide (strong) collision resistance. +package highway + +import ( + "encoding/binary" + "errors" + "unsafe" +) + +const ( + // Size is the size of HighwayHash-256 checksum in bytes. + Size = 32 + // Size128 is the size of HighwayHash-128 checksum in bytes. + Size128 = 16 +) + +var errKeySize = errors.New("highwayhash: invalid key size") + +// New returns a hash.Hash computing the HighwayHash-256 checksum. +// It returns a non-nil error if the key is not 32 bytes long. +func New(key [Size]byte) *Digest { + h := &Digest{size: Size, key: key} + h.Reset() + return h +} + +// New128 returns a hash.Hash computing the HighwayHash-128 checksum. +// It returns a non-nil error if the key is not 32 bytes long. +func New128(key [Size]byte) *Digest { + h := &Digest{size: Size128, key: key} + h.Reset() + return h +} + +// Sum computes the HighwayHash-256 checksum of data. +// It panics if the key is not 32 bytes long. +func Sum(data, key []byte) [Size]byte { + if len(key) != Size { + panic(errKeySize) + } + var state [16]uint64 + initialize(&state, key) + if n := len(data) & (^(Size - 1)); n > 0 { + update(&state, data[:n]) + data = data[n:] + } + if len(data) > 0 { + var block [Size]byte + offset := copy(block[:], data) + hashBuffer(&state, &block, offset) + } + var hash [Size]byte + finalize(hash[:], &state) + return hash +} + +// Sum128 computes the HighwayHash-128 checksum of data. +// It panics if the key is not 32 bytes long. +func Sum128(data, key []byte) [Size128]byte { + if len(key) != Size { + panic(errKeySize) + } + var state [16]uint64 + initialize(&state, key) + if n := len(data) & (^(Size - 1)); n > 0 { + update(&state, data[:n]) + data = data[n:] + } + if len(data) > 0 { + var block [Size]byte + offset := copy(block[:], data) + hashBuffer(&state, &block, offset) + } + var hash [Size128]byte + finalize(hash[:], &state) + return hash +} + +type Digest struct { + state [16]uint64 // v0 | v1 | mul0 | mul1 + + key, buffer [Size]byte + offset int + size int +} + +func (d *Digest) Size() int { return d.size } + +func (d *Digest) BlockSize() int { return Size } + +func (d *Digest) Reset() { + initialize(&d.state, d.key[:]) + d.offset = 0 +} + +func (d *Digest) WriteString(str string) (int, error) { + return d.Write(unsafe.Slice(unsafe.StringData(str), len(str))) +} + +func (d *Digest) Write(p []byte) (n int, err error) { + n = len(p) + if d.offset > 0 { + remaining := Size - d.offset + if n < remaining { + d.offset += copy(d.buffer[d.offset:], p) + return + } + copy(d.buffer[d.offset:], p[:remaining]) + update(&d.state, d.buffer[:]) + p = p[remaining:] + d.offset = 0 + } + if nn := len(p) & (^(Size - 1)); nn > 0 { + update(&d.state, p[:nn]) + p = p[nn:] + } + if len(p) > 0 { + d.offset = copy(d.buffer[d.offset:], p) + } + return +} + +func (d *Digest) Sum(b []byte) []byte { + state := d.state + if d.offset > 0 { + hashBuffer(&state, &d.buffer, d.offset) + } + var hash [Size]byte + finalize(hash[:d.size], &state) + return append(b, hash[:d.size]...) +} + +func hashBuffer(state *[16]uint64, buffer *[32]byte, offset int) { + var block [Size]byte + mod32 := (uint64(offset) << 32) + uint64(offset) + for i := range state[:4] { + state[i] += mod32 + } + for i := range state[4:8] { + t0 := uint32(state[i+4]) + t0 = (t0 << uint(offset)) | (t0 >> uint(32-offset)) + + t1 := uint32(state[i+4] >> 32) + t1 = (t1 << uint(offset)) | (t1 >> uint(32-offset)) + + state[i+4] = (uint64(t1) << 32) | uint64(t0) + } + + mod4 := offset & 3 + remain := offset - mod4 + + copy(block[:], buffer[:remain]) + if offset >= 16 { + copy(block[28:], buffer[offset-4:]) + } else if mod4 != 0 { + last := uint32(buffer[remain]) + last += uint32(buffer[remain+mod4>>1]) << 8 + last += uint32(buffer[offset-1]) << 16 + binary.LittleEndian.PutUint32(block[16:], last) + } + update(state, block[:]) +} diff --git a/go/vt/vthash/highway/highwayhashAVX2_amd64.s b/go/vt/vthash/highway/highwayhashAVX2_amd64.s new file mode 100644 index 00000000000..761eac33dfe --- /dev/null +++ b/go/vt/vthash/highway/highwayhashAVX2_amd64.s @@ -0,0 +1,258 @@ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64,!gccgo,!appengine,!nacl,!noasm + +#include "textflag.h" + +DATA ·consAVX2<>+0x00(SB)/8, $0xdbe6d5d5fe4cce2f +DATA ·consAVX2<>+0x08(SB)/8, $0xa4093822299f31d0 +DATA ·consAVX2<>+0x10(SB)/8, $0x13198a2e03707344 +DATA ·consAVX2<>+0x18(SB)/8, $0x243f6a8885a308d3 +DATA ·consAVX2<>+0x20(SB)/8, $0x3bd39e10cb0ef593 +DATA ·consAVX2<>+0x28(SB)/8, $0xc0acf169b5f18a8c +DATA ·consAVX2<>+0x30(SB)/8, $0xbe5466cf34e90c6c +DATA ·consAVX2<>+0x38(SB)/8, $0x452821e638d01377 +GLOBL ·consAVX2<>(SB), (NOPTR+RODATA), $64 + +DATA ·zipperMergeAVX2<>+0x00(SB)/8, $0xf010e05020c03 +DATA ·zipperMergeAVX2<>+0x08(SB)/8, $0x70806090d0a040b +DATA ·zipperMergeAVX2<>+0x10(SB)/8, $0xf010e05020c03 +DATA ·zipperMergeAVX2<>+0x18(SB)/8, $0x70806090d0a040b +GLOBL ·zipperMergeAVX2<>(SB), (NOPTR+RODATA), $32 + +#define REDUCE_MOD(x0, x1, x2, x3, tmp0, tmp1, y0, y1) \ + MOVQ $0x3FFFFFFFFFFFFFFF, tmp0 \ + ANDQ tmp0, x3 \ + MOVQ x2, y0 \ + MOVQ x3, y1 \ + \ + MOVQ x2, tmp0 \ + MOVQ x3, tmp1 \ + SHLQ $1, tmp1 \ + SHRQ $63, tmp0 \ + MOVQ tmp1, x3 \ + ORQ tmp0, x3 \ + \ + SHLQ $1, x2 \ + \ + MOVQ y0, tmp0 \ + MOVQ y1, tmp1 \ + SHLQ $2, tmp1 \ + SHRQ $62, tmp0 \ + MOVQ tmp1, y1 \ + ORQ tmp0, y1 \ + \ + SHLQ $2, y0 \ + \ + XORQ x0, y0 \ + XORQ x2, y0 \ + XORQ x1, y1 \ + XORQ x3, y1 + +#define UPDATE(msg) \ + VPADDQ msg, Y2, Y2 \ + VPADDQ Y3, Y2, Y2 \ + \ + VPSRLQ $32, Y1, Y0 \ + BYTE $0xC5; BYTE $0xFD; BYTE $0xF4; BYTE $0xC2 \ // VPMULUDQ Y2, Y0, Y0 + VPXOR Y0, Y3, Y3 \ + \ + VPADDQ Y4, Y1, Y1 \ + \ + VPSRLQ $32, Y2, Y0 \ + BYTE $0xC5; BYTE $0xFD; BYTE $0xF4; BYTE $0xC1 \ // VPMULUDQ Y1, Y0, Y0 + VPXOR Y0, Y4, Y4 \ + \ + VPSHUFB Y5, Y2, Y0 \ + VPADDQ Y0, Y1, Y1 \ + \ + VPSHUFB Y5, Y1, Y0 \ + VPADDQ Y0, Y2, Y2 + +// func initializeAVX2(state *[16]uint64, key []byte) +TEXT ·initializeAVX2(SB), 4, $0-32 + MOVQ state+0(FP), AX + MOVQ key_base+8(FP), BX + MOVQ $·consAVX2<>(SB), CX + + VMOVDQU 0(BX), Y1 + VPSHUFD $177, Y1, Y2 + + VMOVDQU 0(CX), Y3 + VMOVDQU 32(CX), Y4 + + VPXOR Y3, Y1, Y1 + VPXOR Y4, Y2, Y2 + + VMOVDQU Y1, 0(AX) + VMOVDQU Y2, 32(AX) + VMOVDQU Y3, 64(AX) + VMOVDQU Y4, 96(AX) + VZEROUPPER + RET + +// func updateAVX2(state *[16]uint64, msg []byte) +TEXT ·updateAVX2(SB), 4, $0-32 + MOVQ state+0(FP), AX + MOVQ msg_base+8(FP), BX + MOVQ msg_len+16(FP), CX + + CMPQ CX, $32 + JB DONE + + VMOVDQU 0(AX), Y1 + VMOVDQU 32(AX), Y2 + VMOVDQU 64(AX), Y3 + VMOVDQU 96(AX), Y4 + + VMOVDQU ·zipperMergeAVX2<>(SB), Y5 + +LOOP: + VMOVDQU 0(BX), Y0 + UPDATE(Y0) + + ADDQ $32, BX + SUBQ $32, CX + JA LOOP + + VMOVDQU Y1, 0(AX) + VMOVDQU Y2, 32(AX) + VMOVDQU Y3, 64(AX) + VMOVDQU Y4, 96(AX) + VZEROUPPER + +DONE: + RET + +// func finalizeAVX2(out []byte, state *[16]uint64) +TEXT ·finalizeAVX2(SB), 4, $0-32 + MOVQ state+24(FP), AX + MOVQ out_base+0(FP), BX + MOVQ out_len+8(FP), CX + + VMOVDQU 0(AX), Y1 + VMOVDQU 32(AX), Y2 + VMOVDQU 64(AX), Y3 + VMOVDQU 96(AX), Y4 + + VMOVDQU ·zipperMergeAVX2<>(SB), Y5 + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + CMPQ CX, $8 + JE skipUpdate // Just 4 rounds for 64-bit checksum + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + CMPQ CX, $16 + JE skipUpdate // 6 rounds for 128-bit checksum + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + + VPERM2I128 $1, Y1, Y1, Y0 + VPSHUFD $177, Y0, Y0 + UPDATE(Y0) + +skipUpdate: + VMOVDQU Y1, 0(AX) + VMOVDQU Y2, 32(AX) + VMOVDQU Y3, 64(AX) + VMOVDQU Y4, 96(AX) + VZEROUPPER + + CMPQ CX, $8 + JE hash64 + CMPQ CX, $16 + JE hash128 + + // 256-bit checksum + MOVQ 0*8(AX), R8 + MOVQ 1*8(AX), R9 + MOVQ 4*8(AX), R10 + MOVQ 5*8(AX), R11 + ADDQ 8*8(AX), R8 + ADDQ 9*8(AX), R9 + ADDQ 12*8(AX), R10 + ADDQ 13*8(AX), R11 + + REDUCE_MOD(R8, R9, R10, R11, R12, R13, R14, R15) + MOVQ R14, 0(BX) + MOVQ R15, 8(BX) + + MOVQ 2*8(AX), R8 + MOVQ 3*8(AX), R9 + MOVQ 6*8(AX), R10 + MOVQ 7*8(AX), R11 + ADDQ 10*8(AX), R8 + ADDQ 11*8(AX), R9 + ADDQ 14*8(AX), R10 + ADDQ 15*8(AX), R11 + + REDUCE_MOD(R8, R9, R10, R11, R12, R13, R14, R15) + MOVQ R14, 16(BX) + MOVQ R15, 24(BX) + RET + +hash128: + MOVQ 0*8(AX), R8 + MOVQ 1*8(AX), R9 + ADDQ 6*8(AX), R8 + ADDQ 7*8(AX), R9 + ADDQ 8*8(AX), R8 + ADDQ 9*8(AX), R9 + ADDQ 14*8(AX), R8 + ADDQ 15*8(AX), R9 + MOVQ R8, 0(BX) + MOVQ R9, 8(BX) + RET + +hash64: + MOVQ 0*8(AX), DX + ADDQ 4*8(AX), DX + ADDQ 8*8(AX), DX + ADDQ 12*8(AX), DX + MOVQ DX, 0(BX) + RET + diff --git a/go/vt/vthash/highway/highwayhash_amd64.go b/go/vt/vthash/highway/highwayhash_amd64.go new file mode 100644 index 00000000000..f47a47fb1d3 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_amd64.go @@ -0,0 +1,80 @@ +//go:build amd64 && !gccgo && !appengine && !nacl && !noasm +// +build amd64,!gccgo,!appengine,!nacl,!noasm + +/* +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package highway + +import "golang.org/x/sys/cpu" + +var ( + useSSE4 = cpu.X86.HasSSE41 + useAVX2 = cpu.X86.HasAVX2 + useNEON = false + useVMX = false +) + +//go:noescape +func initializeSSE4(state *[16]uint64, key []byte) + +//go:noescape +func initializeAVX2(state *[16]uint64, key []byte) + +//go:noescape +func updateSSE4(state *[16]uint64, msg []byte) + +//go:noescape +func updateAVX2(state *[16]uint64, msg []byte) + +//go:noescape +func finalizeSSE4(out []byte, state *[16]uint64) + +//go:noescape +func finalizeAVX2(out []byte, state *[16]uint64) + +func initialize(state *[16]uint64, key []byte) { + switch { + case useAVX2: + initializeAVX2(state, key) + case useSSE4: + initializeSSE4(state, key) + default: + initializeGeneric(state, key) + } +} + +func update(state *[16]uint64, msg []byte) { + switch { + case useAVX2: + updateAVX2(state, msg) + case useSSE4: + updateSSE4(state, msg) + default: + updateGeneric(state, msg) + } +} + +func finalize(out []byte, state *[16]uint64) { + switch { + case useAVX2: + finalizeAVX2(out, state) + case useSSE4: + finalizeSSE4(out, state) + default: + finalizeGeneric(out, state) + } +} diff --git a/go/vt/vthash/highway/highwayhash_amd64.s b/go/vt/vthash/highway/highwayhash_amd64.s new file mode 100644 index 00000000000..5c0f87256f6 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_amd64.s @@ -0,0 +1,304 @@ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 !gccgo !appengine !nacl + +#include "textflag.h" + +DATA ·asmConstants<>+0x00(SB)/8, $0xdbe6d5d5fe4cce2f +DATA ·asmConstants<>+0x08(SB)/8, $0xa4093822299f31d0 +DATA ·asmConstants<>+0x10(SB)/8, $0x13198a2e03707344 +DATA ·asmConstants<>+0x18(SB)/8, $0x243f6a8885a308d3 +DATA ·asmConstants<>+0x20(SB)/8, $0x3bd39e10cb0ef593 +DATA ·asmConstants<>+0x28(SB)/8, $0xc0acf169b5f18a8c +DATA ·asmConstants<>+0x30(SB)/8, $0xbe5466cf34e90c6c +DATA ·asmConstants<>+0x38(SB)/8, $0x452821e638d01377 +GLOBL ·asmConstants<>(SB), (NOPTR+RODATA), $64 + +DATA ·asmZipperMerge<>+0x00(SB)/8, $0xf010e05020c03 +DATA ·asmZipperMerge<>+0x08(SB)/8, $0x70806090d0a040b +GLOBL ·asmZipperMerge<>(SB), (NOPTR+RODATA), $16 + +#define v00 X0 +#define v01 X1 +#define v10 X2 +#define v11 X3 +#define m00 X4 +#define m01 X5 +#define m10 X6 +#define m11 X7 + +#define t0 X8 +#define t1 X9 +#define t2 X10 + +#define REDUCE_MOD(x0, x1, x2, x3, tmp0, tmp1, y0, y1) \ + MOVQ $0x3FFFFFFFFFFFFFFF, tmp0 \ + ANDQ tmp0, x3 \ + MOVQ x2, y0 \ + MOVQ x3, y1 \ + \ + MOVQ x2, tmp0 \ + MOVQ x3, tmp1 \ + SHLQ $1, tmp1 \ + SHRQ $63, tmp0 \ + MOVQ tmp1, x3 \ + ORQ tmp0, x3 \ + \ + SHLQ $1, x2 \ + \ + MOVQ y0, tmp0 \ + MOVQ y1, tmp1 \ + SHLQ $2, tmp1 \ + SHRQ $62, tmp0 \ + MOVQ tmp1, y1 \ + ORQ tmp0, y1 \ + \ + SHLQ $2, y0 \ + \ + XORQ x0, y0 \ + XORQ x2, y0 \ + XORQ x1, y1 \ + XORQ x3, y1 + +#define UPDATE(msg0, msg1) \ + PADDQ msg0, v10 \ + PADDQ m00, v10 \ + PADDQ msg1, v11 \ + PADDQ m01, v11 \ + \ + MOVO v00, t0 \ + MOVO v01, t1 \ + PSRLQ $32, t0 \ + PSRLQ $32, t1 \ + PMULULQ v10, t0 \ + PMULULQ v11, t1 \ + PXOR t0, m00 \ + PXOR t1, m01 \ + \ + PADDQ m10, v00 \ + PADDQ m11, v01 \ + \ + MOVO v10, t0 \ + MOVO v11, t1 \ + PSRLQ $32, t0 \ + PSRLQ $32, t1 \ + PMULULQ v00, t0 \ + PMULULQ v01, t1 \ + PXOR t0, m10 \ + PXOR t1, m11 \ + \ + MOVO v10, t0 \ + PSHUFB t2, t0 \ + MOVO v11, t1 \ + PSHUFB t2, t1 \ + PADDQ t0, v00 \ + PADDQ t1, v01 \ + \ + MOVO v00, t0 \ + PSHUFB t2, t0 \ + MOVO v01, t1 \ + PSHUFB t2, t1 \ + PADDQ t0, v10 \ + PADDQ t1, v11 + +// func initializeSSE4(state *[16]uint64, key []byte) +TEXT ·initializeSSE4(SB), NOSPLIT, $0-32 + MOVQ state+0(FP), AX + MOVQ key_base+8(FP), BX + MOVQ $·asmConstants<>(SB), CX + + MOVOU 0(BX), v00 + MOVOU 16(BX), v01 + + PSHUFD $177, v00, v10 + PSHUFD $177, v01, v11 + + MOVOU 0(CX), m00 + MOVOU 16(CX), m01 + MOVOU 32(CX), m10 + MOVOU 48(CX), m11 + + PXOR m00, v00 + PXOR m01, v01 + PXOR m10, v10 + PXOR m11, v11 + + MOVOU v00, 0(AX) + MOVOU v01, 16(AX) + MOVOU v10, 32(AX) + MOVOU v11, 48(AX) + MOVOU m00, 64(AX) + MOVOU m01, 80(AX) + MOVOU m10, 96(AX) + MOVOU m11, 112(AX) + RET + +// func updateSSE4(state *[16]uint64, msg []byte) +TEXT ·updateSSE4(SB), NOSPLIT, $0-32 + MOVQ state+0(FP), AX + MOVQ msg_base+8(FP), BX + MOVQ msg_len+16(FP), CX + + CMPQ CX, $32 + JB DONE + + MOVOU 0(AX), v00 + MOVOU 16(AX), v01 + MOVOU 32(AX), v10 + MOVOU 48(AX), v11 + MOVOU 64(AX), m00 + MOVOU 80(AX), m01 + MOVOU 96(AX), m10 + MOVOU 112(AX), m11 + + MOVOU ·asmZipperMerge<>(SB), t2 + +LOOP: + MOVOU 0(BX), t0 + MOVOU 16(BX), t1 + + UPDATE(t0, t1) + + ADDQ $32, BX + SUBQ $32, CX + JA LOOP + + MOVOU v00, 0(AX) + MOVOU v01, 16(AX) + MOVOU v10, 32(AX) + MOVOU v11, 48(AX) + MOVOU m00, 64(AX) + MOVOU m01, 80(AX) + MOVOU m10, 96(AX) + MOVOU m11, 112(AX) + +DONE: + RET + +// func finalizeSSE4(out []byte, state *[16]uint64) +TEXT ·finalizeSSE4(SB), NOSPLIT, $0-32 + MOVQ state+24(FP), AX + MOVQ out_base+0(FP), BX + MOVQ out_len+8(FP), CX + + MOVOU 0(AX), v00 + MOVOU 16(AX), v01 + MOVOU 32(AX), v10 + MOVOU 48(AX), v11 + MOVOU 64(AX), m00 + MOVOU 80(AX), m01 + MOVOU 96(AX), m10 + MOVOU 112(AX), m11 + + MOVOU ·asmZipperMerge<>(SB), t2 + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + CMPQ CX, $8 + JE skipUpdate // Just 4 rounds for 64-bit checksum + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + CMPQ CX, $16 + JE skipUpdate // 6 rounds for 128-bit checksum + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + + PSHUFD $177, v01, t0 + PSHUFD $177, v00, t1 + UPDATE(t0, t1) + +skipUpdate: + MOVOU v00, 0(AX) + MOVOU v01, 16(AX) + MOVOU v10, 32(AX) + MOVOU v11, 48(AX) + MOVOU m00, 64(AX) + MOVOU m01, 80(AX) + MOVOU m10, 96(AX) + MOVOU m11, 112(AX) + + CMPQ CX, $8 + JE hash64 + CMPQ CX, $16 + JE hash128 + + // 256-bit checksum + PADDQ v00, m00 + PADDQ v10, m10 + PADDQ v01, m01 + PADDQ v11, m11 + + MOVQ m00, R8 + PEXTRQ $1, m00, R9 + MOVQ m10, R10 + PEXTRQ $1, m10, R11 + REDUCE_MOD(R8, R9, R10, R11, R12, R13, R14, R15) + MOVQ R14, 0(BX) + MOVQ R15, 8(BX) + + MOVQ m01, R8 + PEXTRQ $1, m01, R9 + MOVQ m11, R10 + PEXTRQ $1, m11, R11 + REDUCE_MOD(R8, R9, R10, R11, R12, R13, R14, R15) + MOVQ R14, 16(BX) + MOVQ R15, 24(BX) + RET + +hash128: + PADDQ v00, v11 + PADDQ m00, m11 + PADDQ v11, m11 + MOVOU m11, 0(BX) + RET + +hash64: + PADDQ v00, v10 + PADDQ m00, m10 + PADDQ v10, m10 + MOVQ m10, DX + MOVQ DX, 0(BX) + RET diff --git a/go/vt/vthash/highway/highwayhash_arm64.go b/go/vt/vthash/highway/highwayhash_arm64.go new file mode 100644 index 00000000000..2b22db7ff56 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_arm64.go @@ -0,0 +1,64 @@ +//go:build !noasm && !appengine +// +build !noasm,!appengine + +/* +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) 2017 Minio Inc. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package highway + +var ( + useSSE4 = false + useAVX2 = false + useNEON = true + useVMX = false +) + +//go:noescape +func initializeArm64(state *[16]uint64, key []byte) + +//go:noescape +func updateArm64(state *[16]uint64, msg []byte) + +//go:noescape +func finalizeArm64(out []byte, state *[16]uint64) + +func initialize(state *[16]uint64, key []byte) { + if useNEON { + initializeArm64(state, key) + } else { + initializeGeneric(state, key) + } +} + +func update(state *[16]uint64, msg []byte) { + if useNEON { + updateArm64(state, msg) + } else { + updateGeneric(state, msg) + } +} + +func finalize(out []byte, state *[16]uint64) { + if useNEON { + finalizeArm64(out, state) + } else { + finalizeGeneric(out, state) + } +} diff --git a/go/vt/vthash/highway/highwayhash_arm64.s b/go/vt/vthash/highway/highwayhash_arm64.s new file mode 100644 index 00000000000..bbf2f9822bd --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_arm64.s @@ -0,0 +1,322 @@ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build !noasm,!appengine + +// Use github.com/minio/asm2plan9s on this file to assemble ARM instructions to +// the opcodes of their Plan9 equivalents + +#include "textflag.h" + +#define REDUCE_MOD(x0, x1, x2, x3, tmp0, tmp1, y0, y1) \ + MOVD $0x3FFFFFFFFFFFFFFF, tmp0 \ + AND tmp0, x3 \ + MOVD x2, y0 \ + MOVD x3, y1 \ + \ + MOVD x2, tmp0 \ + MOVD x3, tmp1 \ + LSL $1, tmp1 \ + LSR $63, tmp0 \ + MOVD tmp1, x3 \ + ORR tmp0, x3 \ + \ + LSL $1, x2 \ + \ + MOVD y0, tmp0 \ + MOVD y1, tmp1 \ + LSL $2, tmp1 \ + LSR $62, tmp0 \ + MOVD tmp1, y1 \ + ORR tmp0, y1 \ + \ + LSL $2, y0 \ + \ + EOR x0, y0 \ + EOR x2, y0 \ + EOR x1, y1 \ + EOR x3, y1 + +#define UPDATE(MSG1, MSG2) \ + \ // Add message + VADD MSG1.D2, V2.D2, V2.D2 \ + VADD MSG2.D2, V3.D2, V3.D2 \ + \ + \ // v1 += mul0 + VADD V4.D2, V2.D2, V2.D2 \ + VADD V5.D2, V3.D2, V3.D2 \ + \ + \ // First pair of multiplies + VTBL V29.B16, [V0.B16, V1.B16], V10.B16 \ + VTBL V30.B16, [V2.B16, V3.B16], V11.B16 \ + \ + \ // VUMULL V10.S2, V11.S2, V12.D2 /* assembler support missing */ + \ // VUMULL2 V10.S4, V11.S4, V13.D2 /* assembler support missing */ + WORD $0x2eaac16c \ // umull v12.2d, v11.2s, v10.2s + WORD $0x6eaac16d \ // umull2 v13.2d, v11.4s, v10.4s + \ + \ // v0 += mul1 + VADD V6.D2, V0.D2, V0.D2 \ + VADD V7.D2, V1.D2, V1.D2 \ + \ + \ // Second pair of multiplies + VTBL V29.B16, [V2.B16, V3.B16], V15.B16 \ + VTBL V30.B16, [V0.B16, V1.B16], V14.B16 \ + \ + \ // EOR multiplication result in + VEOR V12.B16, V4.B16, V4.B16 \ + VEOR V13.B16, V5.B16, V5.B16 \ + \ + \ // VUMULL V14.S2, V15.S2, V16.D2 /* assembler support missing */ + \ // VUMULL2 V14.S4, V15.S4, V17.D2 /* assembler support missing */ + WORD $0x2eaec1f0 \ // umull v16.2d, v15.2s, v14.2s + WORD $0x6eaec1f1 \ // umull2 v17.2d, v15.4s, v14.4s + \ + \ // First pair of zipper-merges + VTBL V28.B16, [V2.B16], V18.B16 \ + VADD V18.D2, V0.D2, V0.D2 \ + VTBL V28.B16, [V3.B16], V19.B16 \ + VADD V19.D2, V1.D2, V1.D2 \ + \ + \ // Second pair of zipper-merges + VTBL V28.B16, [V0.B16], V20.B16 \ + VADD V20.D2, V2.D2, V2.D2 \ + VTBL V28.B16, [V1.B16], V21.B16 \ + VADD V21.D2, V3.D2, V3.D2 \ + \ + \ // EOR multiplication result in + VEOR V16.B16, V6.B16, V6.B16 \ + VEOR V17.B16, V7.B16, V7.B16 + +// func initializeArm64(state *[16]uint64, key []byte) +TEXT ·initializeArm64(SB), NOSPLIT, $0 + MOVD state+0(FP), R0 + MOVD key_base+8(FP), R1 + + VLD1 (R1), [V1.S4, V2.S4] + + VREV64 V1.S4, V3.S4 + VREV64 V2.S4, V4.S4 + + MOVD $·asmConstants(SB), R3 + VLD1 (R3), [V5.S4, V6.S4, V7.S4, V8.S4] + VEOR V5.B16, V1.B16, V1.B16 + VEOR V6.B16, V2.B16, V2.B16 + VEOR V7.B16, V3.B16, V3.B16 + VEOR V8.B16, V4.B16, V4.B16 + + VST1.P [V1.D2, V2.D2, V3.D2, V4.D2], 64(R0) + VST1 [V5.D2, V6.D2, V7.D2, V8.D2], (R0) + RET + +TEXT ·updateArm64(SB), NOSPLIT, $0 + MOVD state+0(FP), R0 + MOVD msg_base+8(FP), R1 + MOVD msg_len+16(FP), R2 // length of message + SUBS $32, R2 + BMI complete + + // Definition of registers + // v0 = v0.lo + // v1 = v0.hi + // v2 = v1.lo + // v3 = v1.hi + // v4 = mul0.lo + // v5 = mul0.hi + // v6 = mul1.lo + // v7 = mul1.hi + + // Load zipper merge constants table pointer + MOVD $·asmZipperMerge(SB), R3 + + // and load zipper merge constants into v28, v29, and v30 + VLD1 (R3), [V28.B16, V29.B16, V30.B16] + + VLD1.P 64(R0), [V0.D2, V1.D2, V2.D2, V3.D2] + VLD1 (R0), [V4.D2, V5.D2, V6.D2, V7.D2] + SUBS $64, R0 + +loop: + // Main loop + VLD1.P 32(R1), [V26.S4, V27.S4] + + UPDATE(V26, V27) + + SUBS $32, R2 + BPL loop + + // Store result + VST1.P [V0.D2, V1.D2, V2.D2, V3.D2], 64(R0) + VST1 [V4.D2, V5.D2, V6.D2, V7.D2], (R0) + +complete: + RET + +// func finalizeArm64(out []byte, state *[16]uint64) +TEXT ·finalizeArm64(SB), NOSPLIT, $0-32 + MOVD state+24(FP), R0 + MOVD out_base+0(FP), R1 + MOVD out_len+8(FP), R2 + + // Load zipper merge constants table pointer + MOVD $·asmZipperMerge(SB), R3 + + // and load zipper merge constants into v28, v29, and v30 + VLD1 (R3), [V28.B16, V29.B16, V30.B16] + + VLD1.P 64(R0), [V0.D2, V1.D2, V2.D2, V3.D2] + VLD1 (R0), [V4.D2, V5.D2, V6.D2, V7.D2] + SUB $64, R0 + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + CMP $8, R2 + BEQ skipUpdate // Just 4 rounds for 64-bit checksum + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + CMP $16, R2 + BEQ skipUpdate // 6 rounds for 128-bit checksum + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + + VREV64 V1.S4, V26.S4 + VREV64 V0.S4, V27.S4 + UPDATE(V26, V27) + +skipUpdate: + // Store result + VST1.P [V0.D2, V1.D2, V2.D2, V3.D2], 64(R0) + VST1 [V4.D2, V5.D2, V6.D2, V7.D2], (R0) + SUB $64, R0 + + CMP $8, R2 + BEQ hash64 + CMP $16, R2 + BEQ hash128 + + // 256-bit checksum + MOVD 0*8(R0), R8 + MOVD 1*8(R0), R9 + MOVD 4*8(R0), R10 + MOVD 5*8(R0), R11 + MOVD 8*8(R0), R4 + MOVD 9*8(R0), R5 + MOVD 12*8(R0), R6 + MOVD 13*8(R0), R7 + ADD R4, R8 + ADD R5, R9 + ADD R6, R10 + ADD R7, R11 + + REDUCE_MOD(R8, R9, R10, R11, R4, R5, R6, R7) + MOVD R6, 0(R1) + MOVD R7, 8(R1) + + MOVD 2*8(R0), R8 + MOVD 3*8(R0), R9 + MOVD 6*8(R0), R10 + MOVD 7*8(R0), R11 + MOVD 10*8(R0), R4 + MOVD 11*8(R0), R5 + MOVD 14*8(R0), R6 + MOVD 15*8(R0), R7 + ADD R4, R8 + ADD R5, R9 + ADD R6, R10 + ADD R7, R11 + + REDUCE_MOD(R8, R9, R10, R11, R4, R5, R6, R7) + MOVD R6, 16(R1) + MOVD R7, 24(R1) + RET + +hash128: + MOVD 0*8(R0), R8 + MOVD 1*8(R0), R9 + MOVD 6*8(R0), R10 + MOVD 7*8(R0), R11 + ADD R10, R8 + ADD R11, R9 + MOVD 8*8(R0), R10 + MOVD 9*8(R0), R11 + ADD R10, R8 + ADD R11, R9 + MOVD 14*8(R0), R10 + MOVD 15*8(R0), R11 + ADD R10, R8 + ADD R11, R9 + MOVD R8, 0(R1) + MOVD R9, 8(R1) + RET + +hash64: + MOVD 0*8(R0), R4 + MOVD 4*8(R0), R5 + MOVD 8*8(R0), R6 + MOVD 12*8(R0), R7 + ADD R5, R4 + ADD R7, R6 + ADD R6, R4 + MOVD R4, (R1) + RET + +DATA ·asmConstants+0x00(SB)/8, $0xdbe6d5d5fe4cce2f +DATA ·asmConstants+0x08(SB)/8, $0xa4093822299f31d0 +DATA ·asmConstants+0x10(SB)/8, $0x13198a2e03707344 +DATA ·asmConstants+0x18(SB)/8, $0x243f6a8885a308d3 +DATA ·asmConstants+0x20(SB)/8, $0x3bd39e10cb0ef593 +DATA ·asmConstants+0x28(SB)/8, $0xc0acf169b5f18a8c +DATA ·asmConstants+0x30(SB)/8, $0xbe5466cf34e90c6c +DATA ·asmConstants+0x38(SB)/8, $0x452821e638d01377 +GLOBL ·asmConstants(SB), 8, $64 + +// Constants for TBL instructions +DATA ·asmZipperMerge+0x0(SB)/8, $0x000f010e05020c03 // zipper merge constant +DATA ·asmZipperMerge+0x8(SB)/8, $0x070806090d0a040b +DATA ·asmZipperMerge+0x10(SB)/8, $0x0f0e0d0c07060504 // setup first register for multiply +DATA ·asmZipperMerge+0x18(SB)/8, $0x1f1e1d1c17161514 +DATA ·asmZipperMerge+0x20(SB)/8, $0x0b0a090803020100 // setup second register for multiply +DATA ·asmZipperMerge+0x28(SB)/8, $0x1b1a191813121110 +GLOBL ·asmZipperMerge(SB), 8, $48 diff --git a/go/vt/vthash/highway/highwayhash_generic.go b/go/vt/vthash/highway/highwayhash_generic.go new file mode 100644 index 00000000000..9ea17094843 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_generic.go @@ -0,0 +1,350 @@ +/* +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package highway + +import ( + "encoding/binary" +) + +const ( + v0 = 0 + v1 = 4 + mul0 = 8 + mul1 = 12 +) + +var ( + init0 = [4]uint64{0xdbe6d5d5fe4cce2f, 0xa4093822299f31d0, 0x13198a2e03707344, 0x243f6a8885a308d3} + init1 = [4]uint64{0x3bd39e10cb0ef593, 0xc0acf169b5f18a8c, 0xbe5466cf34e90c6c, 0x452821e638d01377} +) + +func initializeGeneric(state *[16]uint64, k []byte) { + var key [4]uint64 + + key[0] = binary.LittleEndian.Uint64(k[0:]) + key[1] = binary.LittleEndian.Uint64(k[8:]) + key[2] = binary.LittleEndian.Uint64(k[16:]) + key[3] = binary.LittleEndian.Uint64(k[24:]) + + copy(state[mul0:], init0[:]) + copy(state[mul1:], init1[:]) + + for i, k := range key { + state[v0+i] = init0[i] ^ k + } + + key[0] = key[0]>>32 | key[0]<<32 + key[1] = key[1]>>32 | key[1]<<32 + key[2] = key[2]>>32 | key[2]<<32 + key[3] = key[3]>>32 | key[3]<<32 + + for i, k := range key { + state[v1+i] = init1[i] ^ k + } +} + +func updateGeneric(state *[16]uint64, msg []byte) { + for len(msg) >= 32 { + m := msg[:32] + + // add message + mul0 + // Interleave operations to hide multiplication + state[v1+0] += binary.LittleEndian.Uint64(m) + state[mul0+0] + state[mul0+0] ^= uint64(uint32(state[v1+0])) * (state[v0+0] >> 32) + state[v0+0] += state[mul1+0] + state[mul1+0] ^= uint64(uint32(state[v0+0])) * (state[v1+0] >> 32) + + state[v1+1] += binary.LittleEndian.Uint64(m[8:]) + state[mul0+1] + state[mul0+1] ^= uint64(uint32(state[v1+1])) * (state[v0+1] >> 32) + state[v0+1] += state[mul1+1] + state[mul1+1] ^= uint64(uint32(state[v0+1])) * (state[v1+1] >> 32) + + state[v1+2] += binary.LittleEndian.Uint64(m[16:]) + state[mul0+2] + state[mul0+2] ^= uint64(uint32(state[v1+2])) * (state[v0+2] >> 32) + state[v0+2] += state[mul1+2] + state[mul1+2] ^= uint64(uint32(state[v0+2])) * (state[v1+2] >> 32) + + state[v1+3] += binary.LittleEndian.Uint64(m[24:]) + state[mul0+3] + state[mul0+3] ^= uint64(uint32(state[v1+3])) * (state[v0+3] >> 32) + state[v0+3] += state[mul1+3] + state[mul1+3] ^= uint64(uint32(state[v0+3])) * (state[v1+3] >> 32) + + // inlined: zipperMerge(state[v1+0], state[v1+1], &state[v0+0], &state[v0+1]) + { + val0 := state[v1+0] + val1 := state[v1+1] + res := val0 & (0xff << (2 * 8)) + res2 := (val0 & (0xff << (7 * 8))) + (val1 & (0xff << (2 * 8))) + res += (val1 & (0xff << (7 * 8))) >> 8 + res2 += (val0 & (0xff << (6 * 8))) >> 8 + res += ((val0 & (0xff << (5 * 8))) + (val1 & (0xff << (6 * 8)))) >> 16 + res2 += (val1 & (0xff << (5 * 8))) >> 16 + res += ((val0 & (0xff << (3 * 8))) + (val1 & (0xff << (4 * 8)))) >> 24 + res2 += ((val1 & (0xff << (3 * 8))) + (val0 & (0xff << (4 * 8)))) >> 24 + res += (val0 & (0xff << (1 * 8))) << 32 + res2 += (val1 & 0xff) << 48 + res += val0 << 56 + res2 += (val1 & (0xff << (1 * 8))) << 24 + + state[v0+0] += res + state[v0+1] += res2 + } + // zipperMerge(state[v1+2], state[v1+3], &state[v0+2], &state[v0+3]) + { + val0 := state[v1+2] + val1 := state[v1+3] + res := val0 & (0xff << (2 * 8)) + res2 := (val0 & (0xff << (7 * 8))) + (val1 & (0xff << (2 * 8))) + res += (val1 & (0xff << (7 * 8))) >> 8 + res2 += (val0 & (0xff << (6 * 8))) >> 8 + res += ((val0 & (0xff << (5 * 8))) + (val1 & (0xff << (6 * 8)))) >> 16 + res2 += (val1 & (0xff << (5 * 8))) >> 16 + res += ((val0 & (0xff << (3 * 8))) + (val1 & (0xff << (4 * 8)))) >> 24 + res2 += ((val1 & (0xff << (3 * 8))) + (val0 & (0xff << (4 * 8)))) >> 24 + res += (val0 & (0xff << (1 * 8))) << 32 + res2 += (val1 & 0xff) << 48 + res += val0 << 56 + res2 += (val1 & (0xff << (1 * 8))) << 24 + + state[v0+2] += res + state[v0+3] += res2 + } + + // inlined: zipperMerge(state[v0+0], state[v0+1], &state[v1+0], &state[v1+1]) + { + val0 := state[v0+0] + val1 := state[v0+1] + res := val0 & (0xff << (2 * 8)) + res2 := (val0 & (0xff << (7 * 8))) + (val1 & (0xff << (2 * 8))) + res += (val1 & (0xff << (7 * 8))) >> 8 + res2 += (val0 & (0xff << (6 * 8))) >> 8 + res += ((val0 & (0xff << (5 * 8))) + (val1 & (0xff << (6 * 8)))) >> 16 + res2 += (val1 & (0xff << (5 * 8))) >> 16 + res += ((val0 & (0xff << (3 * 8))) + (val1 & (0xff << (4 * 8)))) >> 24 + res2 += ((val1 & (0xff << (3 * 8))) + (val0 & (0xff << (4 * 8)))) >> 24 + res += (val0 & (0xff << (1 * 8))) << 32 + res2 += (val1 & 0xff) << 48 + res += val0 << 56 + res2 += (val1 & (0xff << (1 * 8))) << 24 + + state[v1+0] += res + state[v1+1] += res2 + } + + //inlined: zipperMerge(state[v0+2], state[v0+3], &state[v1+2], &state[v1+3]) + { + val0 := state[v0+2] + val1 := state[v0+3] + res := val0 & (0xff << (2 * 8)) + res2 := (val0 & (0xff << (7 * 8))) + (val1 & (0xff << (2 * 8))) + res += (val1 & (0xff << (7 * 8))) >> 8 + res2 += (val0 & (0xff << (6 * 8))) >> 8 + res += ((val0 & (0xff << (5 * 8))) + (val1 & (0xff << (6 * 8)))) >> 16 + res2 += (val1 & (0xff << (5 * 8))) >> 16 + res += ((val0 & (0xff << (3 * 8))) + (val1 & (0xff << (4 * 8)))) >> 24 + res2 += ((val1 & (0xff << (3 * 8))) + (val0 & (0xff << (4 * 8)))) >> 24 + res += (val0 & (0xff << (1 * 8))) << 32 + res2 += (val1 & 0xff) << 48 + res += val0 << 56 + res2 += (val1 & (0xff << (1 * 8))) << 24 + + state[v1+2] += res + state[v1+3] += res2 + } + msg = msg[32:] + } +} + +func finalizeGeneric(out []byte, state *[16]uint64) { + var perm [4]uint64 + var tmp [32]byte + runs := 4 + if len(out) == 16 { + runs = 6 + } else if len(out) == 32 { + runs = 10 + } + for i := 0; i < runs; i++ { + perm[0] = state[v0+2]>>32 | state[v0+2]<<32 + perm[1] = state[v0+3]>>32 | state[v0+3]<<32 + perm[2] = state[v0+0]>>32 | state[v0+0]<<32 + perm[3] = state[v0+1]>>32 | state[v0+1]<<32 + + binary.LittleEndian.PutUint64(tmp[0:], perm[0]) + binary.LittleEndian.PutUint64(tmp[8:], perm[1]) + binary.LittleEndian.PutUint64(tmp[16:], perm[2]) + binary.LittleEndian.PutUint64(tmp[24:], perm[3]) + + update(state, tmp[:]) + } + + switch len(out) { + case 8: + binary.LittleEndian.PutUint64(out, state[v0+0]+state[v1+0]+state[mul0+0]+state[mul1+0]) + case 16: + binary.LittleEndian.PutUint64(out, state[v0+0]+state[v1+2]+state[mul0+0]+state[mul1+2]) + binary.LittleEndian.PutUint64(out[8:], state[v0+1]+state[v1+3]+state[mul0+1]+state[mul1+3]) + case 32: + h0, h1 := reduceMod(state[v0+0]+state[mul0+0], state[v0+1]+state[mul0+1], state[v1+0]+state[mul1+0], state[v1+1]+state[mul1+1]) + binary.LittleEndian.PutUint64(out[0:], h0) + binary.LittleEndian.PutUint64(out[8:], h1) + + h0, h1 = reduceMod(state[v0+2]+state[mul0+2], state[v0+3]+state[mul0+3], state[v1+2]+state[mul1+2], state[v1+3]+state[mul1+3]) + binary.LittleEndian.PutUint64(out[16:], h0) + binary.LittleEndian.PutUint64(out[24:], h1) + } +} + +// Experiments on variations left for future reference... +/* +func zipperMerge(v0, v1 uint64, d0, d1 *uint64) { + if true { + // fastest. original interleaved... + res := v0 & (0xff << (2 * 8)) + res2 := (v0 & (0xff << (7 * 8))) + (v1 & (0xff << (2 * 8))) + res += (v1 & (0xff << (7 * 8))) >> 8 + res2 += (v0 & (0xff << (6 * 8))) >> 8 + res += ((v0 & (0xff << (5 * 8))) + (v1 & (0xff << (6 * 8)))) >> 16 + res2 += (v1 & (0xff << (5 * 8))) >> 16 + res += ((v0 & (0xff << (3 * 8))) + (v1 & (0xff << (4 * 8)))) >> 24 + res2 += ((v1 & (0xff << (3 * 8))) + (v0 & (0xff << (4 * 8)))) >> 24 + res += (v0 & (0xff << (1 * 8))) << 32 + res2 += (v1 & 0xff) << 48 + res += v0 << 56 + res2 += (v1 & (0xff << (1 * 8))) << 24 + + *d0 += res + *d1 += res2 + } else if false { + // Reading bytes and combining into uint64 + var v0b [8]byte + binary.LittleEndian.PutUint64(v0b[:], v0) + var v1b [8]byte + binary.LittleEndian.PutUint64(v1b[:], v1) + var res, res2 uint64 + + res = uint64(v0b[0]) << (7 * 8) + res2 = uint64(v1b[0]) << (6 * 8) + res |= uint64(v0b[1]) << (5 * 8) + res2 |= uint64(v1b[1]) << (4 * 8) + res |= uint64(v0b[2]) << (2 * 8) + res2 |= uint64(v1b[2]) << (2 * 8) + res |= uint64(v0b[3]) + res2 |= uint64(v0b[4]) << (1 * 8) + res |= uint64(v0b[5]) << (3 * 8) + res2 |= uint64(v0b[6]) << (5 * 8) + res |= uint64(v1b[4]) << (1 * 8) + res2 |= uint64(v0b[7]) << (7 * 8) + res |= uint64(v1b[6]) << (4 * 8) + res2 |= uint64(v1b[3]) + res |= uint64(v1b[7]) << (6 * 8) + res2 |= uint64(v1b[5]) << (3 * 8) + + *d0 += res + *d1 += res2 + + } else if false { + // bytes to bytes shuffle + var v0b [8]byte + binary.LittleEndian.PutUint64(v0b[:], v0) + var v1b [8]byte + binary.LittleEndian.PutUint64(v1b[:], v1) + var res [8]byte + + //res += ((v0 & (0xff << (3 * 8))) + (v1 & (0xff << (4 * 8)))) >> 24 + res[0] = v0b[3] + res[1] = v1b[4] + + // res := v0 & (0xff << (2 * 8)) + res[2] = v0b[2] + + //res += ((v0 & (0xff << (5 * 8))) + (v1 & (0xff << (6 * 8)))) >> 16 + res[3] = v0b[5] + res[4] = v1b[6] + + //res += (v0 & (0xff << (1 * 8))) << 32 + res[5] = v0b[1] + + //res += (v1 & (0xff << (7 * 8))) >> 8 + res[6] += v1b[7] + + //res += v0 << 56 + res[7] = v0b[0] + v0 = binary.LittleEndian.Uint64(res[:]) + *d0 += v0 + + //res += ((v1 & (0xff << (3 * 8))) + (v0 & (0xff << (4 * 8)))) >> 24 + res[0] = v1b[3] + res[1] = v0b[4] + + res[2] = v1b[2] + + // res += (v1 & (0xff << (5 * 8))) >> 16 + res[3] = v1b[5] + + //res += (v1 & (0xff << (1 * 8))) << 24 + res[4] = v1b[1] + + // res += (v0 & (0xff << (6 * 8))) >> 8 + res[5] = v0b[6] + + //res := (v0 & (0xff << (7 * 8))) + (v1 & (0xff << (2 * 8))) + res[7] = v0b[7] + + //res += (v1 & 0xff) << 48 + res[6] = v1b[0] + + v0 = binary.LittleEndian.Uint64(res[:]) + *d1 += v0 + } else { + // original. + res := v0 & (0xff << (2 * 8)) + res += (v1 & (0xff << (7 * 8))) >> 8 + res += ((v0 & (0xff << (5 * 8))) + (v1 & (0xff << (6 * 8)))) >> 16 + res += ((v0 & (0xff << (3 * 8))) + (v1 & (0xff << (4 * 8)))) >> 24 + res += (v0 & (0xff << (1 * 8))) << 32 + res += v0 << 56 + + *d0 += res + + res = (v0 & (0xff << (7 * 8))) + (v1 & (0xff << (2 * 8))) + res += (v0 & (0xff << (6 * 8))) >> 8 + res += (v1 & (0xff << (5 * 8))) >> 16 + res += ((v1 & (0xff << (3 * 8))) + (v0 & (0xff << (4 * 8)))) >> 24 + res += (v1 & 0xff) << 48 + res += (v1 & (0xff << (1 * 8))) << 24 + + *d1 += res + } +} +*/ + +// reduce v = [v0, v1, v2, v3] mod the irreducible polynomial x^128 + x^2 + x +func reduceMod(v0, v1, v2, v3 uint64) (r0, r1 uint64) { + v3 &= 0x3FFFFFFFFFFFFFFF + + r0, r1 = v2, v3 + + v3 = (v3 << 1) | (v2 >> (64 - 1)) + v2 <<= 1 + r1 = (r1 << 2) | (r0 >> (64 - 2)) + r0 <<= 2 + + r0 ^= v0 ^ v2 + r1 ^= v1 ^ v3 + return +} diff --git a/go/vt/vthash/highway/highwayhash_ppc64le.go b/go/vt/vthash/highway/highwayhash_ppc64le.go new file mode 100644 index 00000000000..f70e2a41473 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_ppc64le.go @@ -0,0 +1,49 @@ +//go:build !noasm && !appengine +// +build !noasm,!appengine + +/* +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package highway + +var ( + useSSE4 = false + useAVX2 = false + useNEON = false + useVMX = true +) + +//go:noescape +func updatePpc64Le(state *[16]uint64, msg []byte) + +func initialize(state *[16]uint64, key []byte) { + initializeGeneric(state, key) +} + +func update(state *[16]uint64, msg []byte) { + if useVMX { + updatePpc64Le(state, msg) + } else { + updateGeneric(state, msg) + } +} + +func finalize(out []byte, state *[16]uint64) { + finalizeGeneric(out, state) +} diff --git a/go/vt/vthash/highway/highwayhash_ppc64le.s b/go/vt/vthash/highway/highwayhash_ppc64le.s new file mode 100644 index 00000000000..957cebc4ddc --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_ppc64le.s @@ -0,0 +1,180 @@ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build !noasm,!appengine + +#include "textflag.h" + +// Definition of registers +#define V0_LO VS32 +#define V0_LO_ V0 +#define V0_HI VS33 +#define V0_HI_ V1 +#define V1_LO VS34 +#define V1_LO_ V2 +#define V1_HI VS35 +#define V1_HI_ V3 +#define MUL0_LO VS36 +#define MUL0_LO_ V4 +#define MUL0_HI VS37 +#define MUL0_HI_ V5 +#define MUL1_LO VS38 +#define MUL1_LO_ V6 +#define MUL1_HI VS39 +#define MUL1_HI_ V7 + +// Message +#define MSG_LO VS40 +#define MSG_LO_ V8 +#define MSG_HI VS41 + +// Constants +#define ROTATE VS42 +#define ROTATE_ V10 +#define MASK VS43 +#define MASK_ V11 + +// Temps +#define TEMP1 VS44 +#define TEMP1_ V12 +#define TEMP2 VS45 +#define TEMP2_ V13 +#define TEMP3 VS46 +#define TEMP3_ V14 +#define TEMP4_ V15 +#define TEMP5_ V16 +#define TEMP6_ V17 +#define TEMP7_ V18 + +// Regular registers +#define STATE R3 +#define MSG_BASE R4 +#define MSG_LEN R5 +#define CONSTANTS R6 +#define P1 R7 +#define P2 R8 +#define P3 R9 +#define P4 R10 +#define P5 R11 +#define P6 R12 +#define P7 R14 // avoid using R13 + +TEXT ·updatePpc64Le(SB), NOFRAME|NOSPLIT, $0-32 + MOVD state+0(FP), STATE + MOVD msg_base+8(FP), MSG_BASE + MOVD msg_len+16(FP), MSG_LEN // length of message + + // Sanity check for length + CMPU MSG_LEN, $31 + BLE complete + + // Setup offsets + MOVD $16, P1 + MOVD $32, P2 + MOVD $48, P3 + MOVD $64, P4 + MOVD $80, P5 + MOVD $96, P6 + MOVD $112, P7 + + // Load state + LXVD2X (STATE)(R0), V0_LO + LXVD2X (STATE)(P1), V0_HI + LXVD2X (STATE)(P2), V1_LO + LXVD2X (STATE)(P3), V1_HI + LXVD2X (STATE)(P4), MUL0_LO + LXVD2X (STATE)(P5), MUL0_HI + LXVD2X (STATE)(P6), MUL1_LO + LXVD2X (STATE)(P7), MUL1_HI + XXPERMDI V0_LO, V0_LO, $2, V0_LO + XXPERMDI V0_HI, V0_HI, $2, V0_HI + XXPERMDI V1_LO, V1_LO, $2, V1_LO + XXPERMDI V1_HI, V1_HI, $2, V1_HI + XXPERMDI MUL0_LO, MUL0_LO, $2, MUL0_LO + XXPERMDI MUL0_HI, MUL0_HI, $2, MUL0_HI + XXPERMDI MUL1_LO, MUL1_LO, $2, MUL1_LO + XXPERMDI MUL1_HI, MUL1_HI, $2, MUL1_HI + + // Load asmConstants table pointer + MOVD $·asmConstants(SB), CONSTANTS + LXVD2X (CONSTANTS)(R0), ROTATE + LXVD2X (CONSTANTS)(P1), MASK + XXLNAND MASK, MASK, MASK + +loop: + // Main highwayhash update loop + LXVD2X (MSG_BASE)(R0), MSG_LO + VADDUDM V0_LO_, MUL1_LO_, TEMP1_ + VRLD V0_LO_, ROTATE_, TEMP2_ + VADDUDM MUL1_HI_, V0_HI_, TEMP3_ + LXVD2X (MSG_BASE)(P1), MSG_HI + ADD $32, MSG_BASE, MSG_BASE + XXPERMDI MSG_LO, MSG_LO, $2, MSG_LO + XXPERMDI MSG_HI, MSG_HI, $2, V0_LO + VADDUDM MSG_LO_, MUL0_LO_, MSG_LO_ + VADDUDM V0_LO_, MUL0_HI_, V0_LO_ + VADDUDM MSG_LO_, V1_LO_, V1_LO_ + VSRD V0_HI_, ROTATE_, MSG_LO_ + VADDUDM V0_LO_, V1_HI_, V1_HI_ + VPERM V1_LO_, V1_LO_, MASK_, V0_LO_ + VMULOUW V1_LO_, TEMP2_, TEMP2_ + VPERM V1_HI_, V1_HI_, MASK_, TEMP7_ + VADDUDM V0_LO_, TEMP1_, V0_LO_ + VMULOUW V1_HI_, MSG_LO_, MSG_LO_ + VADDUDM TEMP7_, TEMP3_, V0_HI_ + VPERM V0_LO_, V0_LO_, MASK_, TEMP6_ + VRLD V1_LO_, ROTATE_, TEMP4_ + VSRD V1_HI_, ROTATE_, TEMP5_ + VPERM V0_HI_, V0_HI_, MASK_, TEMP7_ + XXLXOR MUL0_LO, TEMP2, MUL0_LO + VMULOUW TEMP1_, TEMP4_, TEMP1_ + VMULOUW TEMP3_, TEMP5_, TEMP3_ + XXLXOR MUL0_HI, MSG_LO, MUL0_HI + XXLXOR MUL1_LO, TEMP1, MUL1_LO + XXLXOR MUL1_HI, TEMP3, MUL1_HI + VADDUDM TEMP6_, V1_LO_, V1_LO_ + VADDUDM TEMP7_, V1_HI_, V1_HI_ + + SUB $32, MSG_LEN, MSG_LEN + CMPU MSG_LEN, $32 + BGE loop + + // Save state + XXPERMDI V0_LO, V0_LO, $2, V0_LO + XXPERMDI V0_HI, V0_HI, $2, V0_HI + XXPERMDI V1_LO, V1_LO, $2, V1_LO + XXPERMDI V1_HI, V1_HI, $2, V1_HI + XXPERMDI MUL0_LO, MUL0_LO, $2, MUL0_LO + XXPERMDI MUL0_HI, MUL0_HI, $2, MUL0_HI + XXPERMDI MUL1_LO, MUL1_LO, $2, MUL1_LO + XXPERMDI MUL1_HI, MUL1_HI, $2, MUL1_HI + STXVD2X V0_LO, (STATE)(R0) + STXVD2X V0_HI, (STATE)(P1) + STXVD2X V1_LO, (STATE)(P2) + STXVD2X V1_HI, (STATE)(P3) + STXVD2X MUL0_LO, (STATE)(P4) + STXVD2X MUL0_HI, (STATE)(P5) + STXVD2X MUL1_LO, (STATE)(P6) + STXVD2X MUL1_HI, (STATE)(P7) + +complete: + RET + +// Constants table +DATA ·asmConstants+0x0(SB)/8, $0x0000000000000020 +DATA ·asmConstants+0x8(SB)/8, $0x0000000000000020 +DATA ·asmConstants+0x10(SB)/8, $0x070806090d0a040b // zipper merge constant +DATA ·asmConstants+0x18(SB)/8, $0x000f010e05020c03 // zipper merge constant + +GLOBL ·asmConstants(SB), 8, $32 diff --git a/go/cache/perf_test.go b/go/vt/vthash/highway/highwayhash_ref.go similarity index 53% rename from go/cache/perf_test.go rename to go/vt/vthash/highway/highwayhash_ref.go index 693e55238a0..3ecb0e2f6ea 100644 --- a/go/cache/perf_test.go +++ b/go/vt/vthash/highway/highwayhash_ref.go @@ -1,5 +1,8 @@ +//go:build noasm || (!amd64 && !arm64 && !ppc64le) +// +build noasm !amd64,!arm64,!ppc64le + /* -Copyright 2019 The Vitess Authors. +Copyright (c) 2017 Minio Inc. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,23 +17,23 @@ See the License for the specific language governing permissions and limitations under the License. */ -package cache +package highway -import ( - "testing" +var ( + useSSE4 = false + useAVX2 = false + useNEON = false + useVMX = false ) -func BenchmarkGet(b *testing.B) { - cache := NewLRUCache(64*1024*1024, func(val any) int64 { - return int64(cap(val.([]byte))) - }) - value := make([]byte, 1000) - cache.Set("stuff", value) - for i := 0; i < b.N; i++ { - val, ok := cache.Get("stuff") - if !ok { - panic("error") - } - _ = val - } +func initialize(state *[16]uint64, k []byte) { + initializeGeneric(state, k) +} + +func update(state *[16]uint64, msg []byte) { + updateGeneric(state, msg) +} + +func finalize(out []byte, state *[16]uint64) { + finalizeGeneric(out, state) } diff --git a/go/vt/vthash/highway/highwayhash_test.go b/go/vt/vthash/highway/highwayhash_test.go new file mode 100644 index 00000000000..896b6d13763 --- /dev/null +++ b/go/vt/vthash/highway/highwayhash_test.go @@ -0,0 +1,228 @@ +/* +Copyright (c) 2017 Minio Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Copyright (c) 2017 Minio Inc. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package highway + +import ( + "bytes" + "encoding/hex" + "math/rand" + "runtime" + "sync/atomic" + "testing" +) + +func TestVectors(t *testing.T) { + defer func(sse4, avx2, neon, vmx bool) { + useSSE4, useAVX2, useNEON, useVMX = sse4, avx2, neon, vmx + }(useSSE4, useAVX2, useNEON, useVMX) + + if useAVX2 { + t.Run("AVX2 version", func(t *testing.T) { + testVectors(New128, testVectors128, t) + testVectors(New, testVectors256, t) + useAVX2 = false + }) + } + if useSSE4 { + t.Run("SSE4 version", func(t *testing.T) { + testVectors(New128, testVectors128, t) + testVectors(New, testVectors256, t) + useSSE4 = false + }) + } + if useNEON { + t.Run("NEON version", func(t *testing.T) { + testVectors(New128, testVectors128, t) + testVectors(New, testVectors256, t) + useNEON = false + }) + } + if useVMX { + t.Run("VMX version", func(t *testing.T) { + testVectors(New128, testVectors128, t) + testVectors(New, testVectors256, t) + useVMX = false + }) + } + t.Run("Generic version", func(t *testing.T) { + testVectors(New128, testVectors128, t) + testVectors(New, testVectors256, t) + }) +} + +func testVectors(NewFunc func([32]byte) *Digest, vectors []string, t *testing.T) { + key, err := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") + if err != nil { + t.Fatalf("Failed to decode key: %v", err) + } + input := make([]byte, len(vectors)) + + h := NewFunc([32]byte(key)) + for i, v := range vectors { + input[i] = byte(i) + + expected, err := hex.DecodeString(v) + if err != nil { + t.Fatalf("Failed to decode test vector: %v error: %v", v, err) + } + + _, _ = h.Write(input[:i]) + if sum := h.Sum(nil); !bytes.Equal(sum, expected[:]) { + t.Errorf("Test %d: hash mismatch: got: %v want: %v", i, hex.EncodeToString(sum), hex.EncodeToString(expected)) + } + h.Reset() + + switch h.Size() { + case Size: + if sum := Sum(input[:i], key); !bytes.Equal(sum[:], expected) { + t.Errorf("Test %d: Sum mismatch: got: %v want: %v", i, hex.EncodeToString(sum[:]), hex.EncodeToString(expected)) + } + case Size128: + if sum := Sum128(input[:i], key); !bytes.Equal(sum[:], expected) { + t.Errorf("Test %d: Sum mismatch: got: %v want: %v", i, hex.EncodeToString(sum[:]), hex.EncodeToString(expected)) + } + } + } +} + +var testVectors128 = []string{ + "c7fe8f9d8f26ed0f6f3e097f765e5633", "a8e7813689a8b0d6b4dc9cebf91d29dc", "04da165a26ad153d68e832dc38560878", "eb0b5f291b62070679ddced90f9ae6bf", + "9ee4ac6db49e392608923139d02a922e", "d82ed186c3bd50323ac2636c90103819", "476589cbb36a476f1910ed376f57de7c", "b4717169ca1f402a6c79029fff031fbe", + "e8520528846de9a1c20aec3bc6f15c69", "b2631ef302212a14cc00505b8cb9851a", "5bbcb6260eb7a1515955a42d3b1f9e92", "5b419a0562039988137d7bc4221fd2be", + "6695af1c5f1f1fcdd4c8f9e08cba18a8", "5761fe12415625a248b8ddb8784ce9b2", "1909ccd1eb2f49bda2415602bc1dcdce", "54afc42ba5372214d7bc266e0b6c79e0", + "ad01a4d5ff604441c8189f01d5a39e02", "62991cc5964b2ac5a05e9b16b178b8ec", "ceeafb118fca40d931d5f816d6463af9", "f5cbc0e50a9dc48a937c1df58dbffd3f", + "a8002d859b276dac46aaeba56b3acd7d", "568af093bd2116f1d5d93d1698c37331", "9ff88cf650e24c0ced981841da3c12b3", "ce519a3ded97ab150e0869914774e27c", + "b845488d191e00cd772daad88bd9d9d0", "793d49a017d6f334167e7f39f604d37d", "b6c6f4a99068b55c4f30676516290813", "c0d15b248b6fda308c74d93f7e8b826f", + "c0124c20490358e01c445fac0cdaf693", "453007a51b7348f67659b64f1197b85f", "06528a7354834f0291097eeb18499a50", "297ca5e865b4e70646d4f5073a5e4152", + "aa4a43c166df8419b9e4b3f95819fc16", "6cc3c6e0af7816119d84a2e59db558f9", "9004fb4084bc3f7736856543d2d56ec9", "41c9b60b71dce391e9aceec10b6a33ea", + "d4d97a5d81e3cf259ec58f828c4fe9f2", "f288c23cb838fbb904ec50f8c8c47974", "8c2b9825c5d5851df4db486fc1b1266e", "e7bd6060bd554e8ad03f8b0599d53421", + "368f7794f98f952a23641de61a2d05e8", "333245bee63a2389b9c0e8d7879ccf3a", "d5c8a97ee2f5584440512aca9bb48f41", "682ad17e83010309e661c83396f61710", + "9095d40447d80d33e4a64b3aadf19d33", "76c5f263a6639356f65ec9e3953d3b36", "3707b98685d0c8ace9284e7d08e8a02b", "20956dc8277ac2392e936051a420b68d", + "2d071a67eb4a6a8ee67ee4101a56d36e", "4ac7beb165d711002e84de6e656e0ed8", "4cc66a932bd615257d8a08d7948708ce", "af236ec152156291efcc23eb94004f26", + "803426970d88211e8610a3d3074865d8", "2d437f09af6ad7393947079de0e117a5", "145ac637f3a4170fd476f9695f21512f", "445e8912da5cfba0d13cf1d1c43d8c56", + "ce469cd800fcc893690e337e94dad5ba", "94561a1d50077c812bacbf2ce76e4d58", "bf53f073af68d691ede0c18376648ef9", "8bcf3c6befe18152d8836016dfc34cbc", + "b9eeaabe6d1bd6aa7b78160c009d96ff", "795847c04fd825432d1c5f90bd19b914", "d1a66baad176a179862b3aa5c520f7f1", "f03e2f021870bd74cb4b5fada894ea3a", + "f2c4d498711fbb98c88f91de7105bce0", +} + +var testVectors256 = []string{ + "f574c8c22a4844dd1f35c713730146d9ff1487b9ccbeaeb3f41d75453123da41", "54825fe4bc41b9ed0fc6ca3def440de2474a32cb9b1b657284e475b24c627320", + "54e4af24dff9df3f73e80a1b1abfc4117a592269cc6951112cb4330d59f60812", "5cd9d10dd7a00a48d0d111697c5e22895a86bb8b6b42a88e22c7e190c3fb3de2", + "dce42b2197c4cfc99b92d2aff69d5fa89e10f41d219fda1f9b4f4d377a27e407", "b385dca466f5b4b44201465eba634bbfe31ddccd688ef415c68580387d58740f", + "b4b9ad860ac74564b6ceb48427fb9ca913dbb2a0409de2da70119d9af26d52b6", "81ad8709a0b166d6376d8ceb38f8f1a430e063d4076e22e96c522c067dd65457", + "c08b76edb005b9f1453afffcf36f97e67897d0d98d51be4f330d1e37ebafa0d9", "81293c0dd7e4d880a1f12464d1bb0ff1d10c3f9dbe2d5ccff273b601f7e8bfc0", + "be62a2e5508ce4ade038fefdb192948e38b8e92f4bb78407cd6d65db74d5410e", "cf071853b977bea138971a6adea797ba1f268e9cef4c27afe8e84cc735b9393e", + "575840e30238ad15a053e839dccb119d25b2313c993eea232e21f4cae3e9d96c", "367cd7b15e6fc901a6951f53c1f967a3b8dcda7c42a3941fd3d53bbf0a00f197", + "418effee1ee915085ddf216efa280c0e745309ed628ead4ee6739d1cda01fd3f", "2e604278700519c146b1018501dbc362c10634fa17adf58547c3fed47bf884c8", + "1fcdb6a189d91af5d97b622ad675f0f7068af279f5d5017e9f4d176ac115d41a", "8e06a42ca8cff419b975923abd4a9d3bc610c0e9ddb000801356214909d58488", + "5d9fab817f6c6d12ee167709c5a3da4e493edda7731512af2dc380aa85ac0190", "fa559114f9beaa063d1ce744414f86dfda64bc60e8bcbafdb61c499247a52bde", + "db9f0735406bfcad656e488e32b787a0ea23465a93a9d14644ee3c0d445c89e3", "dfb3a3ee1dd3f9b533e1060ae224308f20e18f28c8384cf24997d69bcf1d3f70", + "e3ef9447850b3c2ba0ceda9b963f5d1c2eac63a5af6af1817530d0795a1c4423", "6237fd93c7f88a4124f9d761948e6bbc789e1a2a6af26f776eca17d4bfb7a03a", + "c1a355d22aea03cd2a1b9cb5e5fe8501e473974fd438f4d1e4763bf867dd69be", "fba0873887a851f9aee048a5d2317b2cfa6e18b638388044729f21bec78ec7a3", + "088c0dea51f18f958834f6b497897e4b6d38c55143078ec7faee206f557755d9", "0654b07f8017a9298c571f3584f81833faa7f6f66eea24ddffae975e469343e7", + "cb6c5e9380082498da979fb071d2d01f83b100274786e7561778749ff9491629", "56c554704f95d41beb6c597cff2edbff5b6bab1b9ac66a7c53c17f537076030f", + "9874599788e32588c13263afebf67c6417c928dc03d92b55abc5bf002c63d772", "4d641a6076e28068dab70fb1208b72b36ed110060612bdd0f22e4533ef14ef8a", + "fec3a139908ce3bc8912c1a32663d542a9aefc64f79555e3995a47c96b3cb0c9", "e5a634f0cb1501f6d046cebf75ea366c90597282d3c8173b357a0011eda2da7e", + "a2def9ed59e926130c729f73016877c42ff662d70f506951ab29250ad9d00d8a", "d442d403d549519344d1da0213b46bffec369dcd12b09c333022cc9e61531de6", + "96b650aa88c88b52fce18460a3ecaeb8763424c01e1558a144ec7c09ad4ac102", "27c31722a788d6be3f8760f71451e61ea602307db3265c3fb997156395e8f2dd", + "ad510b2bcf21dbe76cabb0f42463fcfa5b9c2dc2447285b09c84051e8d88adf0", "00cb4dcd93975105eb7d0663314a593c349e11cf1a0875ac94b05c809762c85a", + "9e77b5228c8d2209847e6b51b24d6419a04131f8abc8922b9193e125d75a787f", "4ba7d0465d2ec459646003ca653ca55eb4ae35b66b91a948d4e9543f14dfe6ba", + "e3d0036d6923b65e92a01db4bc783dd50db1f652dc4823fe118c2c6357248064", "8154b8c4b21bb643a1807e71258c31c67d689c6f4d7f4a8c7c1d4035e01702bd", + "374c824357ca517f3a701db15e4d4cb069f3f6cb1e1e514de2565421ea7567d6", "cc457ef8ee09b439b379fc59c4e8b852248c85d1180992444901ee5e647bf080", + "14d59abed19486cee73668522690a1bf7d2a90e4f6fda41efee196d658440c38", "a4a023f88be189d1d7a701e53b353b1f84282ee0b4774fa20c18f9746f64947e", + "48ec25d335c6f8af0b8d0314a40a2e2c6774441a617fd34e8914503be338ec39", "97f1835fadfd2b2acc74f2be6e3e3d0155617277043c56e17e0332e95d8a5af1", + "326312c81ef9d1d511ffb1f99b0b111032601c5426ab75a15215702857dcba87", "842808d82ca9b5c7fbee2e1bb62aa6dd2f73aefeec82988ffb4f1fc05cbd386b", + "f0323d7375f26ecf8b7dbfa22d82f0a36a4012f535744e302d17b3ebefe3280b", "dbe9b20107f898e628888a9a812aae66c9f2b8c92490ea14a4b53e52706141a7", + "b7ed07e3877e913ac15244e3dadeb41770cc11e762f189f60edd9c78fe6bce29", "8e5d15cbd83aff0ea244084cad9ecd47eb21fee60ee4c846510a34f05dc2f3de", + "4dd0822be686fd036d131707600dab32897a852b830e2b68b1393744f1e38c13", "02f9d7c454c7772feabfadd9a9e053100ae74a546863e658ca83dd729c828ac4", + "9fa066e419eb00f914d3c7a8019ebe3171f408cab8c6fe3afbe7ff870febc0b8", "fb8e3cbe8f7d27db7ba51ae17768ce537d7e9a0dd2949c71c93c459263b545b3", + "c9f2a4db3b9c6337c86d4636b3e795608ab8651e7949803ad57c92e5cd88c982", "e44a2314a7b11f6b7e46a65b252e562075d6f3402d892b3e68d71ee4fbe30cf4", + "2ac987b2b11ce18e6d263df6efaac28f039febe6873464667368d5e81da98a57", "67eb3a6a26f8b1f5dd1aec4dbe40b083aefb265b63c8e17f9fd7fede47a4a3f4", + "7524c16affe6d890f2c1da6e192a421a02b08e1ffe65379ebecf51c3c4d7bdc1", +} + +func benchmarkWrite(size int64, b *testing.B) { + var key [32]byte + data := make([]byte, size) + + h := New128(key) + b.SetBytes(size) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = h.Write(data) + } +} + +func BenchmarkWrite_8(b *testing.B) { benchmarkWrite(8, b) } +func BenchmarkWrite_16(b *testing.B) { benchmarkWrite(16, b) } +func BenchmarkWrite_64(b *testing.B) { benchmarkWrite(64, b) } +func BenchmarkWrite_1K(b *testing.B) { benchmarkWrite(1024, b) } +func BenchmarkWrite_8K(b *testing.B) { benchmarkWrite(8*1024, b) } + +func benchmarkSum256(size int64, b *testing.B) { + var key [32]byte + data := make([]byte, size) + + b.SetBytes(size) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Sum(data, key[:]) + } +} + +func BenchmarkSum256_8(b *testing.B) { benchmarkSum256(8, b) } +func BenchmarkSum256_16(b *testing.B) { benchmarkSum256(16, b) } +func BenchmarkSum256_64(b *testing.B) { benchmarkSum256(64, b) } +func BenchmarkSum256_1K(b *testing.B) { benchmarkSum256(1024, b) } +func BenchmarkSum256_8K(b *testing.B) { benchmarkSum256(8*1024, b) } +func BenchmarkSum256_1M(b *testing.B) { benchmarkSum256(1024*1024, b) } +func BenchmarkSum256_5M(b *testing.B) { benchmarkSum256(5*1024*1024, b) } +func BenchmarkSum256_10M(b *testing.B) { benchmarkSum256(10*1024*1024, b) } +func BenchmarkSum256_25M(b *testing.B) { benchmarkSum256(25*1024*1024, b) } + +func benchmarkParallel(b *testing.B, size int) { + + c := runtime.GOMAXPROCS(0) + + var key [32]byte + + rng := rand.New(rand.NewSource(0xabadc0cac01a)) + data := make([][]byte, c) + for i := range data { + data[i] = make([]byte, size) + rng.Read(data[i]) + } + + b.SetBytes(int64(size)) + b.ResetTimer() + + counter := uint64(0) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + index := atomic.AddUint64(&counter, 1) + Sum(data[int(index)%len(data)], key[:]) + } + }) +} + +func BenchmarkParallel_1M(b *testing.B) { benchmarkParallel(b, 1024*1024) } +func BenchmarkParallel_5M(b *testing.B) { benchmarkParallel(b, 5*1024*1024) } +func BenchmarkParallel_10M(b *testing.B) { benchmarkParallel(b, 10*1024*1024) } +func BenchmarkParallel_25M(b *testing.B) { benchmarkParallel(b, 25*1024*1024) } diff --git a/go/vt/vthash/metro/metro.go b/go/vt/vthash/metro/metro.go index 76482408fef..66214713604 100644 --- a/go/vt/vthash/metro/metro.go +++ b/go/vt/vthash/metro/metro.go @@ -21,6 +21,7 @@ package metro import ( "encoding/binary" "math/bits" + "unsafe" ) const k0 = 0xC83A91E1 @@ -69,6 +70,10 @@ func (m *Metro128) Write64(u uint64) { _, _ = m.Write(scratch[:8]) } +func (m *Metro128) WriteString(str string) (int, error) { + return m.Write(unsafe.Slice(unsafe.StringData(str), len(str))) +} + func (m *Metro128) Write(buffer []byte) (int, error) { ptr := buffer diff --git a/go/vt/vttablet/endtoend/config_test.go b/go/vt/vttablet/endtoend/config_test.go index 759deb87ba2..60303cf4bf5 100644 --- a/go/vt/vttablet/endtoend/config_test.go +++ b/go/vt/vttablet/endtoend/config_test.go @@ -73,14 +73,6 @@ func TestStreamPoolSize(t *testing.T) { verifyIntValue(t, vstart, "StreamConnPoolCapacity", 1) } -func TestQueryCacheCapacity(t *testing.T) { - revert := changeVar(t, "QueryCacheCapacity", "1") - defer revert() - - vstart := framework.DebugVars() - verifyIntValue(t, vstart, "QueryCacheCapacity", 1) -} - func TestDisableConsolidator(t *testing.T) { totalConsolidationsTag := "Waits/Histograms/Consolidations/Count" initial := framework.FetchInt(framework.DebugVars(), totalConsolidationsTag) @@ -182,8 +174,6 @@ func TestQueryPlanCache(t *testing.T) { //sleep to avoid race between SchemaChanged event clearing out the plans cache which breaks this test framework.Server.WaitForSchemaReset(2 * time.Second) - defer framework.Server.SetQueryPlanCacheCap(framework.Server.QueryPlanCacheCap()) - bindVars := map[string]*querypb.BindVariable{ "ival1": sqltypes.Int64BindVariable(1), "ival2": sqltypes.Int64BindVariable(1), @@ -197,21 +187,18 @@ func TestQueryPlanCache(t *testing.T) { assert.Equal(t, 1, framework.Server.QueryPlanCacheLen()) vend := framework.DebugVars() - assert.Equal(t, 1, framework.FetchInt(vend, "QueryCacheLength")) assert.GreaterOrEqual(t, framework.FetchInt(vend, "QueryCacheSize"), cachedPlanSize) _, _ = client.Execute("select * from vitess_test where intval=:ival2", bindVars) require.Equal(t, 2, framework.Server.QueryPlanCacheLen()) vend = framework.DebugVars() - assert.Equal(t, 2, framework.FetchInt(vend, "QueryCacheLength")) assert.GreaterOrEqual(t, framework.FetchInt(vend, "QueryCacheSize"), 2*cachedPlanSize) _, _ = client.Execute("select * from vitess_test where intval=1", bindVars) require.Equal(t, 3, framework.Server.QueryPlanCacheLen()) vend = framework.DebugVars() - assert.Equal(t, 3, framework.FetchInt(vend, "QueryCacheLength")) assert.GreaterOrEqual(t, framework.FetchInt(vend, "QueryCacheSize"), 3*cachedPlanSize) } diff --git a/go/vt/vttablet/endtoend/framework/server.go b/go/vt/vttablet/endtoend/framework/server.go index 5992ba494da..4f8043fba5a 100644 --- a/go/vt/vttablet/endtoend/framework/server.go +++ b/go/vt/vttablet/endtoend/framework/server.go @@ -124,6 +124,7 @@ func StartServer(ctx context.Context, connParams, connAppDebugParams mysql.ConnP _ = config.Oltp.TxTimeoutSeconds.Set("5s") _ = config.Olap.TxTimeoutSeconds.Set("5s") config.EnableViews = true + config.QueryCacheDoorkeeper = false gotBytes, _ := yaml2.Marshal(config) log.Infof("Config:\n%s", gotBytes) return StartCustomServer(ctx, connParams, connAppDebugParams, dbName, config) diff --git a/go/vt/vttablet/endtoend/framework/testcase.go b/go/vt/vttablet/endtoend/framework/testcase.go index 37808c5aa7a..e02227b4eb6 100644 --- a/go/vt/vttablet/endtoend/framework/testcase.go +++ b/go/vt/vttablet/endtoend/framework/testcase.go @@ -21,6 +21,7 @@ import ( "fmt" "reflect" "strings" + "time" "vitess.io/vitess/go/vt/vterrors" @@ -122,7 +123,7 @@ func (tc *TestCase) Test(name string, client *QueryClient) error { } // wait for all previous test cases to have been settled in cache - client.server.QueryPlanCacheWait() + time.Sleep(100 * time.Millisecond) catcher := NewQueryCatcher() defer catcher.Close() diff --git a/go/vt/vttablet/endtoend/misc_test.go b/go/vt/vttablet/endtoend/misc_test.go index e66a3fde064..ae47999a97e 100644 --- a/go/vt/vttablet/endtoend/misc_test.go +++ b/go/vt/vttablet/endtoend/misc_test.go @@ -185,8 +185,7 @@ func TestIntegrityError(t *testing.T) { } func TestTrailingComment(t *testing.T) { - vstart := framework.DebugVars() - v1 := framework.FetchInt(vstart, "QueryCacheLength") + v1 := framework.Server.QueryPlanCacheLen() bindVars := map[string]*querypb.BindVariable{"ival": sqltypes.Int64BindVariable(1)} client := framework.NewClient() @@ -201,7 +200,7 @@ func TestTrailingComment(t *testing.T) { t.Error(err) return } - v2 := framework.FetchInt(framework.DebugVars(), "QueryCacheLength") + v2 := framework.Server.QueryPlanCacheLen() if v2 != v1+1 { t.Errorf("QueryCacheLength(%s): %d, want %d", query, v2, v1+1) } diff --git a/go/vt/vttablet/tabletserver/debugenv.go b/go/vt/vttablet/tabletserver/debugenv.go index d7127176d3f..e229c46cadd 100644 --- a/go/vt/vttablet/tabletserver/debugenv.go +++ b/go/vt/vttablet/tabletserver/debugenv.go @@ -116,8 +116,6 @@ func debugEnvHandler(tsv *TabletServer, w http.ResponseWriter, r *http.Request) setIntVal(tsv.SetStreamPoolSize) case "TxPoolSize": setIntVal(tsv.SetTxPoolSize) - case "QueryCacheCapacity": - setIntVal(tsv.SetQueryPlanCacheCap) case "MaxResultSize": setIntVal(tsv.SetMaxResultSize) case "WarnResultSize": diff --git a/go/vt/vttablet/tabletserver/query_engine.go b/go/vt/vttablet/tabletserver/query_engine.go index 5ac26bdd7e5..2381d3a07fa 100644 --- a/go/vt/vttablet/tabletserver/query_engine.go +++ b/go/vt/vttablet/tabletserver/query_engine.go @@ -20,15 +20,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" - "strings" "sync" "sync/atomic" "time" "vitess.io/vitess/go/acl" - "vitess.io/vitess/go/cache" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/pools" "vitess.io/vitess/go/stats" @@ -44,6 +44,7 @@ import ( "vitess.io/vitess/go/vt/tableacl" tacl "vitess.io/vitess/go/vt/tableacl/acl" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vthash" "vitess.io/vitess/go/vt/vttablet/tabletserver/connpool" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" @@ -122,6 +123,17 @@ func isValid(planType planbuilder.PlanType, hasReservedCon bool, hasSysSettings // _______________________________________________ +type PlanCacheKey = theine.StringKey +type PlanCache = theine.Store[PlanCacheKey, *TabletPlan] + +type SettingsCacheKey = theine.HashKey256 +type SettingsCache = theine.Store[SettingsCacheKey, *pools.Setting] + +type currentSchema struct { + tables map[string]*schema.Table + epoch uint32 +} + // QueryEngine implements the core functionality of tabletserver. // It assumes that no requests will be sent to it before Open is // called and succeeds. @@ -130,14 +142,17 @@ func isValid(planType planbuilder.PlanType, hasReservedCon bool, hasSysSettings // Close: There should be no more pending queries when this // function is called. type QueryEngine struct { - isOpen bool + isOpen atomic.Bool env tabletenv.Env se *schema.Engine // mu protects the following fields. - mu sync.RWMutex - tables map[string]*schema.Table - plans cache.Cache + schemaMu sync.Mutex + epoch uint32 + schema atomic.Pointer[currentSchema] + + plans *PlanCache + settings *SettingsCache queryRuleSources *rules.Map // Pools @@ -186,21 +201,29 @@ type QueryEngine struct { // You must call this only once. func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine { config := env.Config() - cacheCfg := &cache.Config{ - MaxEntries: int64(config.QueryCacheSize), - MaxMemoryUsage: config.QueryCacheMemory, - LFU: config.QueryCacheLFU, - } qe := &QueryEngine{ env: env, se: se, - tables: make(map[string]*schema.Table), - plans: cache.NewDefaultCacheImpl(cacheCfg), queryRuleSources: rules.NewMap(), enablePerWorkloadTableMetrics: config.EnablePerWorkloadTableMetrics, } + // Cache for query plans: user configured size with a doorkeeper by default to prevent one-off queries + // from thrashing the cache. + qe.plans = theine.NewStore[PlanCacheKey, *TabletPlan](config.QueryCacheMemory, config.QueryCacheDoorkeeper) + + // cache for connection settings: default to 1/4th of the size for the query cache and do + // not use a doorkeeper because custom connection settings are rarely one-off and we always + // want to cache them + var settingsCacheMemory = config.QueryCacheMemory / 4 + qe.settings = theine.NewStore[SettingsCacheKey, *pools.Setting](settingsCacheMemory, false) + + qe.schema.Store(¤tSchema{ + tables: make(map[string]*schema.Table), + epoch: 0, + }) + qe.conns = connpool.NewPool(env, "ConnPool", config.OltpReadPool) qe.streamConns = connpool.NewPool(env, "StreamConnPool", config.OlapReadPool) qe.consolidatorMode.Store(config.Consolidator) @@ -248,9 +271,15 @@ func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine { env.Exporter().NewGaugeFunc("QueryCacheLength", "Query engine query cache length", func() int64 { return int64(qe.plans.Len()) }) - env.Exporter().NewGaugeFunc("QueryCacheSize", "Query engine query cache size", qe.plans.UsedCapacity) - env.Exporter().NewGaugeFunc("QueryCacheCapacity", "Query engine query cache capacity", qe.plans.MaxCapacity) - env.Exporter().NewCounterFunc("QueryCacheEvictions", "Query engine query cache evictions", qe.plans.Evictions) + env.Exporter().NewGaugeFunc("QueryCacheSize", "Query engine query cache size", func() int64 { + return int64(qe.plans.UsedCapacity()) + }) + env.Exporter().NewGaugeFunc("QueryCacheCapacity", "Query engine query cache capacity", func() int64 { + return int64(qe.plans.MaxCapacity()) + }) + env.Exporter().NewCounterFunc("QueryCacheEvictions", "Query engine query cache evictions", func() int64 { + return qe.plans.Metrics.Evicted() + }) labels := []string{"Table", "Plan"} if config.EnablePerWorkloadTableMetrics { @@ -277,12 +306,14 @@ func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine { // Open must be called before sending requests to QueryEngine. func (qe *QueryEngine) Open() error { - if qe.isOpen { + if qe.isOpen.Load() { return nil } log.Info("Query Engine: opening") - qe.conns.Open(qe.env.Config().DB.AppWithDB(), qe.env.Config().DB.DbaWithDB(), qe.env.Config().DB.AppDebugWithDB()) + config := qe.env.Config() + + qe.conns.Open(config.DB.AppWithDB(), config.DB.DbaWithDB(), config.DB.AppDebugWithDB()) conn, err := qe.conns.Get(tabletenv.LocalContext(), nil) if err != nil { @@ -299,9 +330,11 @@ func (qe *QueryEngine) Open() error { return err } - qe.streamConns.Open(qe.env.Config().DB.AppWithDB(), qe.env.Config().DB.DbaWithDB(), qe.env.Config().DB.AppDebugWithDB()) + qe.streamConns.Open(config.DB.AppWithDB(), config.DB.DbaWithDB(), config.DB.AppDebugWithDB()) qe.se.RegisterNotifier("qe", qe.schemaChanged, true) - qe.isOpen = true + qe.plans.EnsureOpen() + qe.settings.EnsureOpen() + qe.isOpen.Store(true) return nil } @@ -309,63 +342,69 @@ func (qe *QueryEngine) Open() error { // You must ensure that no more queries will be sent // before calling Close. func (qe *QueryEngine) Close() { - if !qe.isOpen { + if !qe.isOpen.Swap(false) { return } // Close in reverse order of Open. qe.se.UnregisterNotifier("qe") + qe.plans.Close() - qe.tables = make(map[string]*schema.Table) + qe.settings.Close() + qe.streamConns.Close() qe.conns.Close() - qe.isOpen = false log.Info("Query Engine: closed") } -// GetPlan returns the TabletPlan that for the query. Plans are cached in a cache.LRUCache. -func (qe *QueryEngine) GetPlan(ctx context.Context, logStats *tabletenv.LogStats, sql string, skipQueryPlanCache bool) (*TabletPlan, error) { - span, _ := trace.NewSpan(ctx, "QueryEngine.GetPlan") - defer span.Finish() - if !skipQueryPlanCache { - if plan := qe.getQuery(sql); plan != nil { - logStats.CachedPlan = true - return plan, nil - } - } - // Obtain read lock to prevent schema from changing while - // we build a plan. The read lock allows multiple identical - // queries to build the same plan. One of them will win by - // updating the query cache and prevent future races. Due to - // this, query stats reporting may not be accurate, but it's - // acceptable because those numbers are best effort. - qe.mu.RLock() - defer qe.mu.RUnlock() +var errNoCache = errors.New("plan should not be cached") + +func (qe *QueryEngine) getPlan(curSchema *currentSchema, sql string) (*TabletPlan, error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } - splan, err := planbuilder.Build(statement, qe.tables, qe.env.Config().DB.DBName, qe.env.Config().EnableViews) + splan, err := planbuilder.Build(statement, curSchema.tables, qe.env.Config().DB.DBName, qe.env.Config().EnableViews) if err != nil { return nil, err } plan := &TabletPlan{Plan: splan, Original: sql} plan.Rules = qe.queryRuleSources.FilterByPlan(sql, plan.PlanID, plan.TableNames()...) plan.buildAuthorized() - if plan.PlanID == planbuilder.PlanDDL || plan.PlanID == planbuilder.PlanSet { - return plan, nil - } - if !skipQueryPlanCache && !sqlparser.SkipQueryPlanCacheDirective(statement) { - qe.plans.Set(sql, plan) + if plan.PlanID == planbuilder.PlanDDL || plan.PlanID == planbuilder.PlanSet || sqlparser.SkipQueryPlanCacheDirective(statement) { + return plan, errNoCache } + return plan, nil } +// GetPlan returns the TabletPlan that for the query. Plans are cached in a cache.LRUCache. +func (qe *QueryEngine) GetPlan(ctx context.Context, logStats *tabletenv.LogStats, sql string, skipQueryPlanCache bool) (*TabletPlan, error) { + span, _ := trace.NewSpan(ctx, "QueryEngine.GetPlan") + defer span.Finish() + + var plan *TabletPlan + var err error + + curSchema := qe.schema.Load() + + if skipQueryPlanCache { + plan, err = qe.getPlan(curSchema, sql) + } else { + plan, logStats.CachedPlan, err = qe.plans.GetOrLoad(PlanCacheKey(sql), curSchema.epoch, func() (*TabletPlan, error) { + return qe.getPlan(curSchema, sql) + }) + } + + if errors.Is(err, errNoCache) { + err = nil + } + return plan, err +} + // GetStreamPlan is similar to GetPlan, but doesn't use the cache // and doesn't enforce a limit. It just returns the parsed query. func (qe *QueryEngine) GetStreamPlan(sql string) (*TabletPlan, error) { - qe.mu.RLock() - defer qe.mu.RUnlock() - splan, err := planbuilder.BuildStreaming(sql, qe.tables) + splan, err := planbuilder.BuildStreaming(sql, qe.schema.Load().tables) if err != nil { return nil, err } @@ -377,9 +416,7 @@ func (qe *QueryEngine) GetStreamPlan(sql string) (*TabletPlan, error) { // GetMessageStreamPlan builds a plan for Message streaming. func (qe *QueryEngine) GetMessageStreamPlan(name string) (*TabletPlan, error) { - qe.mu.RLock() - defer qe.mu.RUnlock() - splan, err := planbuilder.BuildMessageStreaming(name, qe.tables) + splan, err := planbuilder.BuildMessageStreaming(name, qe.schema.Load().tables) if err != nil { return nil, err } @@ -394,33 +431,44 @@ func (qe *QueryEngine) GetConnSetting(ctx context.Context, settings []string) (* span, _ := trace.NewSpan(ctx, "QueryEngine.GetConnSetting") defer span.Finish() - var keyBuilder strings.Builder + hasher := vthash.New256() for _, q := range settings { - keyBuilder.WriteString(q) + _, _ = hasher.WriteString(q) } - // try to get the connSetting from the cache - cacheKey := keyBuilder.String() - if plan := qe.getConnSetting(cacheKey); plan != nil { - return plan, nil - } + var cacheKey SettingsCacheKey + hasher.Sum(cacheKey[:0]) - // build the setting queries - query, resetQuery, err := planbuilder.BuildSettingQuery(settings) - if err != nil { - return nil, err - } - connSetting := pools.NewSetting(query, resetQuery) - - // store the connSetting in the cache - qe.plans.Set(cacheKey, connSetting) - - return connSetting, nil + connSetting, _, err := qe.settings.GetOrLoad(cacheKey, 0, func() (*pools.Setting, error) { + // build the setting queries + query, resetQuery, err := planbuilder.BuildSettingQuery(settings) + if err != nil { + return nil, err + } + return pools.NewSetting(query, resetQuery), nil + }) + return connSetting, err } // ClearQueryPlanCache should be called if query plan cache is potentially obsolete func (qe *QueryEngine) ClearQueryPlanCache() { - qe.plans.Clear() + qe.schemaMu.Lock() + defer qe.schemaMu.Unlock() + + qe.epoch++ + + current := qe.schema.Load() + qe.schema.Store(¤tSchema{ + tables: current.tables, + epoch: qe.epoch, + }) +} + +func (qe *QueryEngine) ForEachPlan(each func(plan *TabletPlan) bool) { + curSchema := qe.schema.Load() + qe.plans.Range(curSchema.epoch, func(_ PlanCacheKey, plan *TabletPlan) bool { + return each(plan) + }) } // IsMySQLReachable returns an error if it cannot connect to MySQL. @@ -438,56 +486,31 @@ func (qe *QueryEngine) IsMySQLReachable() error { } func (qe *QueryEngine) schemaChanged(tables map[string]*schema.Table, created, altered, dropped []*schema.Table) { - qe.mu.Lock() - defer qe.mu.Unlock() - qe.tables = tables - if len(altered) != 0 || len(dropped) != 0 { - qe.plans.Clear() - } -} + qe.schemaMu.Lock() + defer qe.schemaMu.Unlock() -// getQuery fetches the plan and makes it the most recent. -func (qe *QueryEngine) getQuery(sql string) *TabletPlan { - cacheResult, ok := qe.plans.Get(sql) - if !ok { - return nil - } - plan, ok := cacheResult.(*TabletPlan) - if ok { - return plan - } - return nil -} - -func (qe *QueryEngine) getConnSetting(key string) *pools.Setting { - cacheResult, ok := qe.plans.Get(key) - if !ok { - return nil - } - plan, ok := cacheResult.(*pools.Setting) - if ok { - return plan + if len(altered) != 0 || len(dropped) != 0 { + qe.epoch++ } - return nil -} -// SetQueryPlanCacheCap sets the query plan cache capacity. -func (qe *QueryEngine) SetQueryPlanCacheCap(size int) { - if size <= 0 { - size = 1 - } - qe.plans.SetCapacity(int64(size)) + qe.schema.Store(¤tSchema{ + tables: tables, + epoch: qe.epoch, + }) } // QueryPlanCacheCap returns the capacity of the query cache. func (qe *QueryEngine) QueryPlanCacheCap() int { - return int(qe.plans.MaxCapacity()) + return qe.plans.MaxCapacity() } // QueryPlanCacheLen returns the length (size in entries) of the query cache -func (qe *QueryEngine) QueryPlanCacheLen() int { - qe.plans.Wait() - return qe.plans.Len() +func (qe *QueryEngine) QueryPlanCacheLen() (count int) { + qe.ForEachPlan(func(plan *TabletPlan) bool { + count++ + return true + }) + return } // AddStats adds the given stats for the planName.tableName @@ -547,8 +570,7 @@ func (qe *QueryEngine) handleHTTPQueryPlans(response http.ResponseWriter, reques } response.Header().Set("Content-Type", "text/plain") - qe.plans.ForEach(func(value any) bool { - plan := value.(*TabletPlan) + qe.ForEachPlan(func(plan *TabletPlan) bool { response.Write([]byte(fmt.Sprintf("%#v\n", sqlparser.TruncateForUI(plan.Original)))) if b, err := json.MarshalIndent(plan.Plan, "", " "); err != nil { response.Write([]byte(err.Error())) @@ -567,9 +589,7 @@ func (qe *QueryEngine) handleHTTPQueryStats(response http.ResponseWriter, reques } response.Header().Set("Content-Type", "application/json; charset=utf-8") var qstats []perQueryStats - qe.plans.ForEach(func(value any) bool { - plan := value.(*TabletPlan) - + qe.ForEachPlan(func(plan *TabletPlan) bool { var pqstats perQueryStats pqstats.Query = unicoded(sqlparser.TruncateForUI(plan.Original)) pqstats.Table = plan.TableName().String() diff --git a/go/vt/vttablet/tabletserver/query_engine_test.go b/go/vt/vttablet/tabletserver/query_engine_test.go index e073b6f8c49..73ac1ca5e37 100644 --- a/go/vt/vttablet/tabletserver/query_engine_test.go +++ b/go/vt/vttablet/tabletserver/query_engine_test.go @@ -18,7 +18,6 @@ package tabletserver import ( "context" - "expvar" "fmt" "math/rand" "net/http" @@ -32,6 +31,7 @@ import ( "testing" "time" + "vitess.io/vitess/go/cache/theine" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" @@ -41,7 +41,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/mysql/fakesqldb" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/streamlog" @@ -148,7 +147,7 @@ func TestGetMessageStreamPlan(t *testing.T) { } wantPlan := &planbuilder.Plan{ PlanID: planbuilder.PlanMessageStream, - Table: qe.tables["msg"], + Table: qe.schema.Load().tables["msg"], Permissions: []planbuilder.Permission{{ TableName: "msg", Role: tableacl.WRITER, @@ -164,12 +163,8 @@ func TestGetMessageStreamPlan(t *testing.T) { func assertPlanCacheSize(t *testing.T, qe *QueryEngine, expected int) { t.Helper() - var size int - qe.plans.Wait() - qe.plans.ForEach(func(_ any) bool { - size++ - return true - }) + time.Sleep(100 * time.Millisecond) + size := qe.plans.Len() require.Equal(t, expected, size, "expected query plan cache to contain %d entries, found %d", expected, size) } @@ -179,7 +174,6 @@ func TestQueryPlanCache(t *testing.T) { schematest.AddDefaultQueries(db) firstQuery := "select * from test_table_01" - secondQuery := "select * from test_table_02" db.AddQuery("select * from test_table_01 where 1 != 1", &sqltypes.Result{}) db.AddQuery("select * from test_table_02 where 1 != 1", &sqltypes.Result{}) @@ -190,23 +184,11 @@ func TestQueryPlanCache(t *testing.T) { ctx := context.Background() logStats := tabletenv.NewLogStats(ctx, "GetPlanStats") - if cache.DefaultConfig.LFU { - // this cache capacity is in bytes - qe.SetQueryPlanCacheCap(528) - } else { - // this cache capacity is in number of elements - qe.SetQueryPlanCacheCap(1) - } + firstPlan, err := qe.GetPlan(ctx, logStats, firstQuery, false) require.NoError(t, err) require.NotNil(t, firstPlan, "plan should not be nil") - secondPlan, err := qe.GetPlan(ctx, logStats, secondQuery, false) - fmt.Println(secondPlan.CachedSize(true)) - require.NoError(t, err) - require.NotNil(t, secondPlan, "plan should not be nil") - expvar.Do(func(kv expvar.KeyValue) { - _ = kv.Value.String() - }) + assertPlanCacheSize(t, qe, 1) qe.ClearQueryPlanCache() } @@ -227,7 +209,7 @@ func TestNoQueryPlanCache(t *testing.T) { ctx := context.Background() logStats := tabletenv.NewLogStats(ctx, "GetPlanStats") - qe.SetQueryPlanCacheCap(1024) + firstPlan, err := qe.GetPlan(ctx, logStats, firstQuery, true) if err != nil { t.Fatal(err) @@ -256,7 +238,7 @@ func TestNoQueryPlanCacheDirective(t *testing.T) { ctx := context.Background() logStats := tabletenv.NewLogStats(ctx, "GetPlanStats") - qe.SetQueryPlanCacheCap(1024) + firstPlan, err := qe.GetPlan(ctx, logStats, firstQuery, false) if err != nil { t.Fatal(err) @@ -305,6 +287,8 @@ func newTestQueryEngine(idleTimeout time.Duration, strict bool, dbcfgs *dbconfig env := tabletenv.NewEnv(config, "TabletServerTest") se := schema.NewEngine(env) qe := NewQueryEngine(env, se) + // the integration tests that check cache behavior do not expect a doorkeeper; disable it + qe.plans = theine.NewStore[PlanCacheKey, *TabletPlan](4*1024*1024, false) se.InitDBConfig(dbcfgs.DbaWithDB()) return qe } @@ -393,13 +377,12 @@ func BenchmarkPlanCacheThroughput(b *testing.B) { } } -func benchmarkPlanCache(b *testing.B, db *fakesqldb.DB, lfu bool, par int) { +func benchmarkPlanCache(b *testing.B, db *fakesqldb.DB, par int) { b.Helper() dbcfgs := newDBConfigs(db) config := tabletenv.NewDefaultConfig() config.DB = dbcfgs - config.QueryCacheLFU = lfu env := tabletenv.NewEnv(config, "TabletServerTest") se := schema.NewEngine(env) @@ -432,12 +415,8 @@ func BenchmarkPlanCacheContention(b *testing.B) { db.AddQueryPattern(".*", &sqltypes.Result{}) for par := 1; par <= 8; par *= 2 { - b.Run(fmt.Sprintf("ContentionLRU-%d", par), func(b *testing.B) { - benchmarkPlanCache(b, db, false, par) - }) - b.Run(fmt.Sprintf("ContentionLFU-%d", par), func(b *testing.B) { - benchmarkPlanCache(b, db, true, par) + benchmarkPlanCache(b, db, par) }) } } @@ -483,16 +462,9 @@ func TestPlanCachePollution(t *testing.T) { var wg sync.WaitGroup go func() { - cacheMode := "lru" - if config.QueryCacheLFU { - cacheMode = "lfu" - } + cacheMode := "lfu" - out, err := os.Create(path.Join(plotPath, - fmt.Sprintf("cache_plot_%d_%d_%s.dat", - config.QueryCacheSize, config.QueryCacheMemory, cacheMode, - )), - ) + out, err := os.Create(path.Join(plotPath, fmt.Sprintf("cache_plot_%d_%s.dat", config.QueryCacheMemory, cacheMode))) require.NoError(t, err) defer out.Close() diff --git a/go/vt/vttablet/tabletserver/query_executor_test.go b/go/vt/vttablet/tabletserver/query_executor_test.go index cf0e87a6e27..2385dbc3a31 100644 --- a/go/vt/vttablet/tabletserver/query_executor_test.go +++ b/go/vt/vttablet/tabletserver/query_executor_test.go @@ -23,6 +23,7 @@ import ( "math/rand" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -319,7 +320,7 @@ func TestQueryExecutorPlans(t *testing.T) { assert.True(t, vterrors.Equals(err, tcase.errorWant)) } // Wait for the existing query to be processed by the cache - tsv.QueryPlanCacheWait() + time.Sleep(100 * time.Millisecond) // Test inside a transaction. target := tsv.sm.Target() @@ -412,7 +413,7 @@ func TestQueryExecutorQueryAnnotation(t *testing.T) { assert.Equal(t, tcase.logWant, qre.logStats.RewrittenSQL(), tcase.input) // Wait for the existing query to be processed by the cache - tsv.QueryPlanCacheWait() + time.Sleep(100 * time.Millisecond) // Test inside a transaction. target := tsv.sm.Target() diff --git a/go/vt/vttablet/tabletserver/queryz.go b/go/vt/vttablet/tabletserver/queryz.go index 7d976d6b6b7..151f028ca09 100644 --- a/go/vt/vttablet/tabletserver/queryz.go +++ b/go/vt/vttablet/tabletserver/queryz.go @@ -152,8 +152,7 @@ func queryzHandler(qe *QueryEngine, w http.ResponseWriter, r *http.Request) { return row1.timePQ() > row2.timePQ() }, } - qe.plans.ForEach(func(value any) bool { - plan := value.(*TabletPlan) + qe.ForEachPlan(func(plan *TabletPlan) bool { if plan == nil { return true } diff --git a/go/vt/vttablet/tabletserver/queryz_test.go b/go/vt/vttablet/tabletserver/queryz_test.go index a0bea742e04..8e1b7b38cfd 100644 --- a/go/vt/vttablet/tabletserver/queryz_test.go +++ b/go/vt/vttablet/tabletserver/queryz_test.go @@ -46,7 +46,7 @@ func TestQueryzHandler(t *testing.T) { }, } plan1.AddStats(10, 2*time.Second, 1*time.Second, 0, 2, 0) - qe.plans.Set(query1, plan1) + qe.plans.Set(query1, plan1, 0, 0) const query2 = "insert into test_table values 1" plan2 := &TabletPlan{ @@ -57,7 +57,7 @@ func TestQueryzHandler(t *testing.T) { }, } plan2.AddStats(1, 2*time.Millisecond, 1*time.Millisecond, 1, 0, 0) - qe.plans.Set(query2, plan2) + qe.plans.Set(query2, plan2, 0, 0) const query3 = "show tables" plan3 := &TabletPlan{ @@ -68,8 +68,8 @@ func TestQueryzHandler(t *testing.T) { }, } plan3.AddStats(1, 75*time.Millisecond, 50*time.Millisecond, 0, 1, 0) - qe.plans.Set(query3, plan3) - qe.plans.Set("", (*TabletPlan)(nil)) + qe.plans.Set(query3, plan3, 0, 0) + qe.plans.Set("", (*TabletPlan)(nil), 0, 0) hugeInsert := "insert into test_table values 0" for i := 1; i < 1000; i++ { @@ -83,11 +83,11 @@ func TestQueryzHandler(t *testing.T) { }, } plan4.AddStats(1, 1*time.Millisecond, 1*time.Millisecond, 1, 0, 0) - qe.plans.Set(hugeInsert, plan4) - qe.plans.Set("", (*TabletPlan)(nil)) + qe.plans.Set(PlanCacheKey(hugeInsert), plan4, 0, 0) + qe.plans.Set("", (*TabletPlan)(nil), 0, 0) // Wait for cache to settle - qe.plans.Wait() + time.Sleep(100 * time.Millisecond) queryzHandler(qe, resp, req) body, _ := io.ReadAll(resp.Body) diff --git a/go/vt/vttablet/tabletserver/schema/engine_test.go b/go/vt/vttablet/tabletserver/schema/engine_test.go index 78b43fd1e0e..4000795d9d0 100644 --- a/go/vt/vttablet/tabletserver/schema/engine_test.go +++ b/go/vt/vttablet/tabletserver/schema/engine_test.go @@ -82,7 +82,7 @@ func TestOpenAndReload(t *testing.T) { )) firstReadRowsValue := 12 AddFakeInnoDBReadRowsResult(db, firstReadRowsValue) - se := newEngine(10, 10*time.Second, 10*time.Second, 0, db) + se := newEngine(10*time.Second, 10*time.Second, 0, db) se.Open() defer se.Close() @@ -273,7 +273,7 @@ func TestReloadWithSwappedTables(t *testing.T) { firstReadRowsValue := 12 AddFakeInnoDBReadRowsResult(db, firstReadRowsValue) - se := newEngine(10, 10*time.Second, 10*time.Second, 0, db) + se := newEngine(10*time.Second, 10*time.Second, 0, db) se.Open() defer se.Close() want := initialSchema() @@ -423,7 +423,7 @@ func TestOpenFailedDueToExecErr(t *testing.T) { schematest.AddDefaultQueries(db) want := "injected error" db.RejectQueryPattern(baseShowTablesPattern, want) - se := newEngine(10, 1*time.Second, 1*time.Second, 0, db) + se := newEngine(1*time.Second, 1*time.Second, 0, db) err := se.Open() if err == nil || !strings.Contains(err.Error(), want) { t.Errorf("se.Open: %v, want %s", err, want) @@ -454,7 +454,7 @@ func TestOpenFailedDueToLoadTableErr(t *testing.T) { db.AddRejectedQuery("SELECT * FROM `fakesqldb`.`test_view` WHERE 1 != 1", sqlerror.NewSQLErrorFromError(errors.New("The user specified as a definer ('root'@'%') does not exist (errno 1449) (sqlstate HY000)"))) AddFakeInnoDBReadRowsResult(db, 0) - se := newEngine(10, 1*time.Second, 1*time.Second, 0, db) + se := newEngine(1*time.Second, 1*time.Second, 0, db) err := se.Open() // failed load should return an error because of test_table assert.ErrorContains(t, err, "Row count exceeded") @@ -489,7 +489,7 @@ func TestOpenNoErrorDueToInvalidViews(t *testing.T) { db.AddRejectedQuery("SELECT `col1`, `col2` FROM `fakesqldb`.`bar_view` WHERE 1 != 1", sqlerror.NewSQLError(sqlerror.ERWrongFieldWithGroup, sqlerror.SSClientError, "random error for table bar_view")) AddFakeInnoDBReadRowsResult(db, 0) - se := newEngine(10, 1*time.Second, 1*time.Second, 0, db) + se := newEngine(1*time.Second, 1*time.Second, 0, db) err := se.Open() require.NoError(t, err) @@ -505,7 +505,7 @@ func TestExportVars(t *testing.T) { db := fakesqldb.New(t) defer db.Close() schematest.AddDefaultQueries(db) - se := newEngine(10, 1*time.Second, 1*time.Second, 0, db) + se := newEngine(1*time.Second, 1*time.Second, 0, db) se.Open() defer se.Close() expvar.Do(func(kv expvar.KeyValue) { @@ -517,7 +517,7 @@ func TestStatsURL(t *testing.T) { db := fakesqldb.New(t) defer db.Close() schematest.AddDefaultQueries(db) - se := newEngine(10, 1*time.Second, 1*time.Second, 0, db) + se := newEngine(1*time.Second, 1*time.Second, 0, db) se.Open() defer se.Close() @@ -547,7 +547,7 @@ func TestSchemaEngineCloseTickRace(t *testing.T) { }) AddFakeInnoDBReadRowsResult(db, 12) // Start the engine with a small reload tick - se := newEngine(10, 100*time.Millisecond, 1*time.Second, 0, db) + se := newEngine(100*time.Millisecond, 1*time.Second, 0, db) err := se.Open() require.NoError(t, err) @@ -574,9 +574,8 @@ func TestSchemaEngineCloseTickRace(t *testing.T) { } } -func newEngine(queryCacheSize int, reloadTime time.Duration, idleTimeout time.Duration, schemaMaxAgeSeconds int64, db *fakesqldb.DB) *Engine { +func newEngine(reloadTime time.Duration, idleTimeout time.Duration, schemaMaxAgeSeconds int64, db *fakesqldb.DB) *Engine { config := tabletenv.NewDefaultConfig() - config.QueryCacheSize = queryCacheSize _ = config.SchemaReloadIntervalSeconds.Set(reloadTime.String()) _ = config.OltpReadPool.IdleTimeoutSeconds.Set(idleTimeout.String()) _ = config.OlapReadPool.IdleTimeoutSeconds.Set(idleTimeout.String()) @@ -1114,7 +1113,7 @@ func TestEngineReload(t *testing.T) { conn, err := connpool.NewDBConnNoPool(context.Background(), db.ConnParams(), nil, nil) require.NoError(t, err) - se := newEngine(10, 10*time.Second, 10*time.Second, 0, db) + se := newEngine(10*time.Second, 10*time.Second, 0, db) se.conns.Open(se.cp, se.cp, se.cp) se.isOpen = true se.notifiers = make(map[string]notifier) diff --git a/go/vt/vttablet/tabletserver/schema/main_test.go b/go/vt/vttablet/tabletserver/schema/main_test.go index 19fc66c36d1..0948c1313fc 100644 --- a/go/vt/vttablet/tabletserver/schema/main_test.go +++ b/go/vt/vttablet/tabletserver/schema/main_test.go @@ -37,7 +37,7 @@ func getTestSchemaEngine(t *testing.T, schemaMaxAgeSeconds int64) (*Engine, *fak db.AddQueryPattern(baseShowTablesPattern, &sqltypes.Result{}) db.AddQuery(mysql.BaseShowPrimary, &sqltypes.Result{}) AddFakeInnoDBReadRowsResult(db, 1) - se := newEngine(10, 10*time.Second, 10*time.Second, schemaMaxAgeSeconds, db) + se := newEngine(10*time.Second, 10*time.Second, schemaMaxAgeSeconds, db) require.NoError(t, se.Open()) cancel := func() { defer db.Close() diff --git a/go/vt/vttablet/tabletserver/tabletenv/config.go b/go/vt/vttablet/tabletserver/tabletenv/config.go index cb91cf271ac..ca7d3c1bdaf 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/config.go +++ b/go/vt/vttablet/tabletserver/tabletenv/config.go @@ -26,7 +26,6 @@ import ( "github.com/spf13/pflag" "google.golang.org/protobuf/encoding/prototext" - "vitess.io/vitess/go/cache" "vitess.io/vitess/go/flagutil" "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/vt/dbconfigs" @@ -138,9 +137,9 @@ func registerTabletEnvFlags(fs *pflag.FlagSet) { fs.BoolVar(¤tConfig.PassthroughDML, "queryserver-config-passthrough-dmls", defaultConfig.PassthroughDML, "query server pass through all dml statements without rewriting") fs.IntVar(¤tConfig.StreamBufferSize, "queryserver-config-stream-buffer-size", defaultConfig.StreamBufferSize, "query server stream buffer size, the maximum number of bytes sent from vttablet for each stream call. It's recommended to keep this value in sync with vtgate's stream_buffer_size.") - fs.IntVar(¤tConfig.QueryCacheSize, "queryserver-config-query-cache-size", defaultConfig.QueryCacheSize, "query server query cache size, maximum number of queries to be cached. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") + _ = fs.MarkDeprecated("queryserver-config-query-cache-size", "query server query cache size, maximum number of queries to be cached. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") fs.Int64Var(¤tConfig.QueryCacheMemory, "queryserver-config-query-cache-memory", defaultConfig.QueryCacheMemory, "query server query cache size in bytes, maximum amount of memory to be used for caching. vttablet analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") - fs.BoolVar(¤tConfig.QueryCacheLFU, "queryserver-config-query-cache-lfu", defaultConfig.QueryCacheLFU, "query server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries") + _ = fs.MarkDeprecated("queryserver-config-query-cache-lfu", "query server cache algorithm. when set to true, a new cache algorithm based on a TinyLFU admission policy will be used to improve cache behavior and prevent pollution from sparse queries") currentConfig.SchemaReloadIntervalSeconds = defaultConfig.SchemaReloadIntervalSeconds.Clone() fs.Var(¤tConfig.SchemaReloadIntervalSeconds, currentConfig.SchemaReloadIntervalSeconds.Name(), "query server schema reload time, how often vttablet reloads schemas from underlying MySQL instance in seconds. vttablet keeps table schemas in its own memory and periodically refreshes it from MySQL. This config controls the reload time.") @@ -336,9 +335,8 @@ type TabletConfig struct { StreamBufferSize int `json:"streamBufferSize,omitempty"` ConsolidatorStreamTotalSize int64 `json:"consolidatorStreamTotalSize,omitempty"` ConsolidatorStreamQuerySize int64 `json:"consolidatorStreamQuerySize,omitempty"` - QueryCacheSize int `json:"queryCacheSize,omitempty"` QueryCacheMemory int64 `json:"queryCacheMemory,omitempty"` - QueryCacheLFU bool `json:"queryCacheLFU,omitempty"` + QueryCacheDoorkeeper bool `json:"queryCacheDoorkeeper,omitempty"` SchemaReloadIntervalSeconds flagutil.DeprecatedFloat64Seconds `json:"schemaReloadIntervalSeconds,omitempty"` SignalSchemaChangeReloadIntervalSeconds flagutil.DeprecatedFloat64Seconds `json:"signalSchemaChangeReloadIntervalSeconds,omitempty"` SchemaChangeReloadTimeout time.Duration `json:"schemaChangeReloadTimeout,omitempty"` @@ -815,10 +813,11 @@ var defaultConfig = TabletConfig{ // memory copies. so with the encoding overhead, this seems to work // great (the overhead makes the final packets on the wire about twice // bigger than this). - StreamBufferSize: 32 * 1024, - QueryCacheSize: int(cache.DefaultConfig.MaxEntries), - QueryCacheMemory: cache.DefaultConfig.MaxMemoryUsage, - QueryCacheLFU: cache.DefaultConfig.LFU, + StreamBufferSize: 32 * 1024, + QueryCacheMemory: 32 * 1024 * 1024, // 32 mb for our query cache + // The doorkeeper for the plan cache is disabled by default in endtoend tests to ensure + // results are consistent between runs. + QueryCacheDoorkeeper: !servenv.TestingEndtoend, SchemaReloadIntervalSeconds: flagutil.NewDeprecatedFloat64Seconds("queryserver-config-schema-reload-time", 30*time.Minute), // SchemaChangeReloadTimeout is used for the signal reload operation where we have to query mysqld. // The queries during the signal reload operation are typically expected to have low load, diff --git a/go/vt/vttablet/tabletserver/tabletenv/config_test.go b/go/vt/vttablet/tabletserver/tabletenv/config_test.go index a1c5421332c..4e52a5c8919 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/config_test.go +++ b/go/vt/vttablet/tabletserver/tabletenv/config_test.go @@ -154,9 +154,8 @@ oltpReadPool: idleTimeoutSeconds: 30m0s maxWaiters: 5000 size: 16 -queryCacheLFU: true +queryCacheDoorkeeper: true queryCacheMemory: 33554432 -queryCacheSize: 5000 replicationTracker: heartbeatIntervalSeconds: 250ms mode: disable diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 8ebc2770fa3..fa789b7144f 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -1730,7 +1730,6 @@ func (tsv *TabletServer) HandlePanic(err *error) { // Close shuts down any remaining go routines func (tsv *TabletServer) Close(ctx context.Context) error { tsv.sm.closeAll() - tsv.qe.Close() tsv.stats.Stop() return nil } @@ -1961,11 +1960,6 @@ func (tsv *TabletServer) TxPoolSize() int { return tsv.te.txPool.scp.Capacity() } -// SetQueryPlanCacheCap changes the plan cache capacity to the specified value. -func (tsv *TabletServer) SetQueryPlanCacheCap(val int) { - tsv.qe.SetQueryPlanCacheCap(val) -} - // QueryPlanCacheCap returns the plan cache capacity func (tsv *TabletServer) QueryPlanCacheCap() int { return tsv.qe.QueryPlanCacheCap() @@ -1976,11 +1970,6 @@ func (tsv *TabletServer) QueryPlanCacheLen() int { return tsv.qe.QueryPlanCacheLen() } -// QueryPlanCacheWait waits until the query plan cache has processed all recent queries -func (tsv *TabletServer) QueryPlanCacheWait() { - tsv.qe.plans.Wait() -} - // SetMaxResultSize changes the max result size to the specified value. func (tsv *TabletServer) SetMaxResultSize(val int) { tsv.qe.maxResultSize.Store(int64(val)) diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index aca34f40a57..2f4988cc2b6 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -2051,14 +2051,6 @@ func TestConfigChanges(t *testing.T) { t.Errorf("tsv.te.Pool().Timeout: %v, want %v", val, newDuration) } - tsv.SetQueryPlanCacheCap(newSize) - if val := tsv.QueryPlanCacheCap(); val != newSize { - t.Errorf("QueryPlanCacheCap: %d, want %d", val, newSize) - } - if val := int(tsv.qe.QueryPlanCacheCap()); val != newSize { - t.Errorf("tsv.qe.QueryPlanCacheCap: %d, want %d", val, newSize) - } - tsv.SetMaxResultSize(newSize) if val := tsv.MaxResultSize(); val != newSize { t.Errorf("MaxResultSize: %d, want %d", val, newSize) diff --git a/go/vt/vttest/environment.go b/go/vt/vttest/environment.go index f856d0f038a..dac9d718827 100644 --- a/go/vt/vttest/environment.go +++ b/go/vt/vttest/environment.go @@ -292,6 +292,7 @@ func NewLocalTestEnvWithDirectory(flavor string, basePort int, directory string) Env: []string{ fmt.Sprintf("VTDATAROOT=%s", directory), fmt.Sprintf("MYSQL_FLAVOR=%s", flavor), + "VTTEST=endtoend", }, }, nil }