diff --git a/physical/cache.go b/physical/cache.go index 6a109a32bd03..654941139bb5 100644 --- a/physical/cache.go +++ b/physical/cache.go @@ -3,6 +3,7 @@ package physical import ( "strings" + iradix "github.com/hashicorp/go-immutable-radix" "github.com/hashicorp/golang-lru" "github.com/hashicorp/vault/helper/locksutil" log "github.com/mgutz/logxi/v1" @@ -18,10 +19,11 @@ const ( // Vault are for policy objects so there is a large read reduction // by using a simple write-through cache. type Cache struct { - backend Backend - lru *lru.TwoQueueCache - locks []*locksutil.LockEntry - logger log.Logger + backend Backend + lru *lru.TwoQueueCache + locks []*locksutil.LockEntry + exceptions *iradix.Tree + logger log.Logger } // TransactionalCache is a Cache that wraps the physical that is transactional @@ -32,27 +34,37 @@ type TransactionalCache struct { // NewCache returns a physical cache of the given size. // If no size is provided, the default size is used. -func NewCache(b Backend, size int, logger log.Logger) *Cache { +func NewCache(b Backend, size int, coreExceptions []string, logger log.Logger) *Cache { + if logger.IsTrace() { + logger.Trace("physical/cache: creating LRU cache", "size", size) + } if size <= 0 { size = DefaultCacheSize } - if logger.IsTrace() { - logger.Trace("physical/cache: creating LRU cache", "size", size) + cacheExceptions := iradix.New() + for _, key := range coreExceptions { + cacheValue := true + if strings.HasPrefix(key, "!") { + key = strings.TrimPrefix(key, "!") + cacheValue = false + } + cacheExceptions, _, _ = cacheExceptions.Insert([]byte(key), cacheValue) } + cache, _ := lru.New2Q(size) c := &Cache{ - backend: b, - lru: cache, - locks: locksutil.CreateLocks(), - logger: logger, + backend: b, + lru: cache, + locks: locksutil.CreateLocks(), + exceptions: cacheExceptions, + logger: logger, } - return c } -func NewTransactionalCache(b Backend, size int, logger log.Logger) *TransactionalCache { +func NewTransactionalCache(b Backend, size int, coreExceptions []string, logger log.Logger) *TransactionalCache { c := &TransactionalCache{ - Cache: NewCache(b, size, logger), + Cache: NewCache(b, size, coreExceptions, logger), Transactional: b.(Transactional), } return c @@ -75,7 +87,7 @@ func (c *Cache) Put(entry *Entry) error { defer lock.Unlock() err := c.backend.Put(entry) - if err == nil && !strings.HasPrefix(entry.Key, "core/") { + if err == nil && c.shouldCache(entry.Key) { c.lru.Add(entry.Key, entry) } return err @@ -90,7 +102,7 @@ func (c *Cache) Get(key string) (*Entry, error) { // otherwise we risk certain race conditions upstream. The primary issue is // with the HA mode, we could potentially negatively cache the leader entry // and cause leader discovery to fail. - if strings.HasPrefix(key, "core/") { + if !c.shouldCache(key) { return c.backend.Get(key) } @@ -98,9 +110,8 @@ func (c *Cache) Get(key string) (*Entry, error) { if raw, ok := c.lru.Get(key); ok { if raw == nil { return nil, nil - } else { - return raw.(*Entry), nil } + return raw.(*Entry), nil } // Read from the underlying backend @@ -123,7 +134,7 @@ func (c *Cache) Delete(key string) error { defer lock.Unlock() err := c.backend.Delete(key) - if err == nil && !strings.HasPrefix(key, "core/") { + if err == nil && c.shouldCache(key) { c.lru.Remove(key) } return err @@ -137,10 +148,15 @@ func (c *Cache) List(prefix string) ([]string, error) { } func (c *TransactionalCache) Transaction(txns []*TxnEntry) error { - // Lock the world - for _, lock := range c.locks { - lock.Lock() - defer lock.Unlock() + // Collect keys that need to be locked + var keys []string + for _, curr := range txns { + keys = append(keys, curr.Entry.Key) + } + // Lock the keys + for _, l := range locksutil.LocksForKeys(c.locks, keys) { + l.Lock() + defer l.Unlock() } if err := c.Transactional.Transaction(txns); err != nil { @@ -148,13 +164,32 @@ func (c *TransactionalCache) Transaction(txns []*TxnEntry) error { } for _, txn := range txns { - switch txn.Operation { - case PutOperation: - c.lru.Add(txn.Entry.Key, txn.Entry) - case DeleteOperation: - c.lru.Remove(txn.Entry.Key) + if c.shouldCache(txn.Entry.Key) { + switch txn.Operation { + case PutOperation: + c.lru.Add(txn.Entry.Key, txn.Entry) + case DeleteOperation: + c.lru.Remove(txn.Entry.Key) + } } } return nil } + +// shouldCache checks for any cache exceptions +func (c *Cache) shouldCache(key string) bool { + // prefix match if nested under core/ + if strings.HasPrefix(key, "core/") { + if prefix, val, found := c.exceptions.Root().LongestPrefix([]byte(key)); found { + strPrefix := string(prefix) + if strings.HasSuffix(strPrefix, "/") || strPrefix == key { + return val.(bool) + } + } + // default for core/ values is false + return false + } + // default is true + return true +} diff --git a/physical/inmem/cache_test.go b/physical/inmem/cache_test.go index c771f03920ce..cd9de9332d84 100644 --- a/physical/inmem/cache_test.go +++ b/physical/inmem/cache_test.go @@ -15,7 +15,7 @@ func TestCache(t *testing.T) { if err != nil { t.Fatal(err) } - cache := physical.NewCache(inm, 0, logger) + cache := physical.NewCache(inm, 0, nil, logger) physical.ExerciseBackend(t, cache) physical.ExerciseBackend_ListPrefix(t, cache) } @@ -27,7 +27,7 @@ func TestCache_Purge(t *testing.T) { if err != nil { t.Fatal(err) } - cache := physical.NewCache(inm, 0, logger) + cache := physical.NewCache(inm, 0, nil, logger) ent := &physical.Entry{ Key: "foo", @@ -63,7 +63,7 @@ func TestCache_Purge(t *testing.T) { } } -func TestCache_IgnoreCore(t *testing.T) { +func TestCache_ExcludeCore(t *testing.T) { logger := logformat.NewVaultLogger(log.LevelTrace) inm, err := NewInmem(nil, logger) @@ -71,7 +71,7 @@ func TestCache_IgnoreCore(t *testing.T) { t.Fatal(err) } - cache := physical.NewCache(inm, 0, logger) + cache := physical.NewCache(inm, 0, nil, logger) var ent *physical.Entry @@ -151,3 +151,180 @@ func TestCache_IgnoreCore(t *testing.T) { t.Fatal("expected non-cached value") } } + +func TestCache_ExcludeCoreTransactional(t *testing.T) { + logger := logformat.NewVaultLogger(log.LevelTrace) + + inm, err := NewTransactionalInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + cache := physical.NewTransactionalCache(inm, 0, nil, logger) + + var ent *physical.TxnEntry + var entry *physical.Entry + + // First try normal handling + ent = &physical.TxnEntry{ + Operation: physical.PutOperation, + Entry: &physical.Entry{ + Key: "foo", + Value: []byte("bar"), + }, + } + if err := cache.Transaction([]*physical.TxnEntry{ent}); err != nil { + t.Fatal(err) + } + ent = &physical.TxnEntry{ + Operation: physical.PutOperation, + Entry: &physical.Entry{ + Key: "foo", + Value: []byte("foobar"), + }, + } + if err := inm.(physical.Transactional).Transaction([]*physical.TxnEntry{ent}); err != nil { + t.Fatal(err) + } + entry, err = cache.Get("foo") + if err != nil { + t.Fatal(err) + } + if string(entry.Value) != "bar" { + t.Fatal("expected cached value") + } + + // Now try core path + ent = &physical.TxnEntry{ + Operation: physical.PutOperation, + Entry: &physical.Entry{ + Key: "core/foo", + Value: []byte("bar"), + }, + } + if err := cache.Transaction([]*physical.TxnEntry{ent}); err != nil { + t.Fatal(err) + } + ent = &physical.TxnEntry{ + Operation: physical.PutOperation, + Entry: &physical.Entry{ + Key: "core/foo", + Value: []byte("foobar"), + }, + } + if err := inm.(physical.Transactional).Transaction([]*physical.TxnEntry{ent}); err != nil { + t.Fatal(err) + } + entry, err = cache.Get("core/foo") + if err != nil { + t.Fatal(err) + } + if string(entry.Value) != "foobar" { + t.Fatal("expected cached value") + } +} + +func TestCache_CoreExceptions(t *testing.T) { + logger := logformat.NewVaultLogger(log.LevelTrace) + + inm, err := NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + + cache := physical.NewCache(inm, 0, []string{"core/bar", "!core/baz/", "core/baz/zzz"}, logger) + + var ent *physical.Entry + + // Now try core path + ent = &physical.Entry{ + Key: "core/foo", + Value: []byte("bar"), + } + if err := cache.Put(ent); err != nil { + t.Fatal(err) + } + ent = &physical.Entry{ + Key: "core/foo", + Value: []byte("foobar"), + } + if err := inm.Put(ent); err != nil { + t.Fatal(err) + } + ent, err = cache.Get("core/foo") + if err != nil { + t.Fatal(err) + } + if string(ent.Value) != "foobar" { + t.Fatal("expected cached value") + } + + // Now try an exception + ent = &physical.Entry{ + Key: "core/bar", + Value: []byte("bar"), + } + if err := cache.Put(ent); err != nil { + t.Fatal(err) + } + ent = &physical.Entry{ + Key: "core/bar", + Value: []byte("foobar"), + } + if err := inm.Put(ent); err != nil { + t.Fatal(err) + } + ent, err = cache.Get("core/bar") + if err != nil { + t.Fatal(err) + } + if string(ent.Value) != "bar" { + t.Fatal("expected cached value") + } + + // another one + ent = &physical.Entry{ + Key: "core/baz/aaa", + Value: []byte("bar"), + } + if err := cache.Put(ent); err != nil { + t.Fatal(err) + } + ent = &physical.Entry{ + Key: "core/baz/aaa", + Value: []byte("foobar"), + } + if err := inm.Put(ent); err != nil { + t.Fatal(err) + } + ent, err = cache.Get("core/baz/aaa") + if err != nil { + t.Fatal(err) + } + if string(ent.Value) != "foobar" { + t.Fatal("expected cached value") + } + + // another one + ent = &physical.Entry{ + Key: "core/baz/zzz", + Value: []byte("bar"), + } + if err := cache.Put(ent); err != nil { + t.Fatal(err) + } + ent = &physical.Entry{ + Key: "core/baz/zzz", + Value: []byte("foobar"), + } + if err := inm.Put(ent); err != nil { + t.Fatal(err) + } + ent, err = cache.Get("core/baz/zzz") + if err != nil { + t.Fatal(err) + } + if string(ent.Value) != "bar" { + t.Fatal("expected cached value") + } + +} diff --git a/vault/core.go b/vault/core.go index 4df0cbdc76f9..ba3a6898486b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -517,9 +517,9 @@ func NewCore(conf *CoreConfig) (*Core, error) { // Wrap the physical backend in a cache layer if enabled if !conf.DisableCache { if txnOK { - c.physical = physical.NewTransactionalCache(phys, conf.CacheSize, conf.Logger) + c.physical = physical.NewTransactionalCache(phys, conf.CacheSize, nil, conf.Logger) } else { - c.physical = physical.NewCache(phys, conf.CacheSize, conf.Logger) + c.physical = physical.NewCache(phys, conf.CacheSize, nil, conf.Logger) } }