agent: add an inflight cache for better concurrent request handling (#10705)
* agent: do not grab idLock write lock until caching entry

* agent: inflight cache using sync.Map

* agent: implement an inflight caching mechanism

* agent/lease: add lock for inflight cache to prevent simultaneous Set calls

* agent/lease: lock on a per-ID basis so unique requests can be processed independently

* agent/lease: add some concurrency tests

* test: use lease_id for uniqueness

* agent: remove env flags, add comments around locks

* agent: clean up test comment

* agent: clean up test comment

* agent: remove commented debug code

* agent/lease: word-smithing

* Update command/agent/cache/lease_cache.go

Co-authored-by: Nick Cabatoff <[email protected]>

* agent/lease: return the context error if the Done ch got closed

* agent/lease: fix data race in concurrency tests

* agent/lease: mockDelayProxier: return ctx.Err() if context got canceled

* agent/lease: remove unused inflightCacheLock

* agent/lease: test: bump context timeout to 3s

Co-authored-by: Nick Cabatoff <[email protected]>
calvn and ncabatoff authored Jan 26, 2021
1 parent 316ccea commit df51db7
Showing 3 changed files with 231 additions and 36 deletions.
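
The change, in short: the computed request ID acts as a singleflight key. The first request to miss the cache registers an inflightRequest and proxies upstream; identical concurrent requests wait on its channel and then read the now-populated cache. Below is a minimal, self-contained sketch of that load-or-store pattern. The names (demo, fetch) are illustrative only, not Vault code, and it omits the remaining counter the real LeaseCache uses to delete the inflight entry only after the last waiter leaves.

// Illustrative only: a condensed version of the inflight "load or store"
// pattern, without the remaining counter used by the real LeaseCache.
package main

import (
	"fmt"
	"sync"
)

type inflight struct {
	ch     chan struct{} // closed by the request that does the work
	result string
}

type demo struct {
	mu  sync.Mutex
	ids map[string]*inflight
}

func (d *demo) get(id string, fetch func() string) string {
	d.mu.Lock()
	if in, ok := d.ids[id]; ok {
		// An identical request is already in flight: wait for it to
		// finish, then reuse its result (the real code re-checks the cache).
		d.mu.Unlock()
		<-in.ch
		return in.result
	}
	in := &inflight{ch: make(chan struct{})}
	d.ids[id] = in
	d.mu.Unlock()

	in.result = fetch() // only one goroutine per ID gets here
	close(in.ch)        // release the waiters

	d.mu.Lock()
	delete(d.ids, id)
	d.mu.Unlock()
	return in.result
}

func main() {
	d := &demo{ids: make(map[string]*inflight)}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(d.get("same-id", func() string { return "fetched once" }))
		}()
	}
	wg.Wait()
}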
110 changes: 76 additions & 34 deletions command/agent/cache/lease_cache.go
@@ -27,6 +27,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical"
gocache "github.com/patrickmn/go-cache"
"go.uber.org/atomic"
)

const (
@@ -78,6 +80,9 @@ type LeaseCache struct {
// idLocks is used during cache lookup to ensure that identical requests made
// in parallel won't trigger multiple renewal goroutines.
idLocks []*locksutil.LockEntry

// inflightCache keeps track of inflight requests
inflightCache *gocache.Cache
}

// LeaseCacheConfig is the configuration for initializing a new
@@ -89,6 +94,22 @@ type LeaseCacheConfig struct {
Logger hclog.Logger
}

type inflightRequest struct {
// ch is closed by the request that ends up processing the set of
// parallel requests
ch chan struct{}

// remaining is the number of remaining inflight requests that need to
// be processed before this object can be cleaned up
remaining atomic.Uint64
}

func newInflightRequest() *inflightRequest {
return &inflightRequest{
ch: make(chan struct{}),
}
}

// NewLeaseCache creates a new instance of a LeaseCache.
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
if conf == nil {
@@ -112,13 +133,14 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
baseCtxInfo := cachememdb.NewContextInfo(conf.BaseContext)

return &LeaseCache{
client: conf.Client,
proxier: conf.Proxier,
logger: conf.Logger,
db: db,
baseCtxInfo: baseCtxInfo,
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
client: conf.Client,
proxier: conf.Proxier,
logger: conf.Logger,
db: db,
baseCtxInfo: baseCtxInfo,
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
}, nil
}

@@ -170,40 +192,60 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
return nil, err
}

// Grab a read lock for this particular request
idLock := locksutil.LockForKey(c.idLocks, id)
// Check the inflight cache to see if there are other inflight requests
// of the same kind, based on the computed ID. If so, we increment a counter
// and wait for the in-flight request to finish before proceeding.

idLock.RLock()
unlockFunc := idLock.RUnlock
defer func() { unlockFunc() }()
var inflight *inflightRequest

// Check if the response for this request is already in the cache
sendResp, err := c.checkCacheForRequest(id)
if err != nil {
return nil, err
}
if sendResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
return sendResp, nil
}
defer func() {
// Clean up the inflight cache entry if there are no remaining inflight
// requests. This is the last step, so we defer the call first
if inflight != nil && inflight.remaining.Load() == 0 {
c.inflightCache.Delete(id)
}
}()

// Perform a lock upgrade
idLock.RUnlock()
idLock := locksutil.LockForKey(c.idLocks, id)

// Briefly grab an ID-based lock in here to emulate a load-or-store behavior
// and prevent concurrent cacheable requests from being proxied twice if
// they both miss the cache due to it being clean when peeking the cache
// entry.
idLock.Lock()
unlockFunc = idLock.Unlock
inflightRaw, found := c.inflightCache.Get(id)
if found {
idLock.Unlock()
inflight = inflightRaw.(*inflightRequest)
inflight.remaining.Inc()
defer inflight.remaining.Dec()

// If found, it means that there's an inflight request being processed.
// We wait until that's finished before proceeding further.
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-inflight.ch:
}
} else {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()

c.inflightCache.Set(id, inflight, gocache.NoExpiration)
idLock.Unlock()

// Signal that the processing request is done
defer close(inflight.ch)
}
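// At this point, either this request won the race and will proxy and cache
// the response below, or an identical inflight request has already finished,
// in which case the cache lookup below will hit for cacheable responses.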

// Check cache once more after upgrade
sendResp, err = c.checkCacheForRequest(id)
// Check if the response for this request is already in the cache
cachedResp, err := c.checkCacheForRequest(id)
if err != nil {
return nil, err
}

// If found, it means that some other parallel request already cached this response
// in between this upgrade so we can simply return that. Otherwise, this request
// will be the one performing the cache write.
if sendResp != nil {
c.logger.Debug("returning cached response", "method", req.Request.Method, "path", req.Request.URL.Path)
return sendResp, nil
if cachedResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
return cachedResp, nil
}

c.logger.Debug("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)
@@ -441,7 +483,7 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
func computeIndexID(req *SendRequest) (string, error) {
var b bytes.Buffer

// Serialze the request
// Serialize the request
if err := req.Request.Write(&b); err != nil {
return "", fmt.Errorf("failed to serialize request: %v", err)
}
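
The diff above pulls in patrickmn/go-cache as a concurrency-safe map for the inflight entries. For reference, here is a minimal usage sketch of the calls the lease cache relies on (New, Set, Get, Delete), written against the upstream import path; this is an illustration, not part of the commit.

package main

import (
	"fmt"

	gocache "github.com/patrickmn/go-cache"
)

func main() {
	// The lease cache manages entry lifetimes itself, so both the default
	// expiration and the janitor cleanup interval are disabled.
	c := gocache.New(gocache.NoExpiration, gocache.NoExpiration)

	c.Set("index-id", "inflight entry", gocache.NoExpiration)

	if raw, found := c.Get("index-id"); found {
		fmt.Println(raw.(string)) // values are interface{}, so type-assert on read
	}

	c.Delete("index-id") // remove the entry once no requests remain inflight
}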
132 changes: 130 additions & 2 deletions command/agent/cache/lease_cache_test.go
@@ -8,15 +8,17 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"testing"

"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"time"

"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"go.uber.org/atomic"
)

func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
@@ -40,6 +42,27 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
return lc
}

func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseCache {
t.Helper()

client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}

lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client,
BaseContext: context.Background(),
Proxier: &mockDelayProxier{cacheable, delay},
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
})
if err != nil {
t.Fatal(err)
}

return lc
}

func TestCache_ComputeIndexID(t *testing.T) {
type args struct {
req *http.Request
Expand Down Expand Up @@ -509,3 +532,108 @@ func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
})
}
}

func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, false, 50)

// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5 seconds to finish. We
// use a context timeout to tell us if this is the case, giving ample time
// to process them concurrently but timing out if they are processed
// serially.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

wgDoneCh := make(chan struct{})

go func() {
var wg sync.WaitGroup
// 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)

go func() {
defer wg.Done()

// Send a request through the lease cache which is not cacheable (there is
// no lease information or auth information in the response)
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", "http://example.com", nil),
}

_, err := lc.Send(ctx, sendReq)
if err != nil {
t.Fatal(err)
}
}()
}

wg.Wait()
close(wgDoneCh)
}()

select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}

}

func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, true, 50)

if err := lc.RegisterAutoAuthToken("autoauthtoken"); err != nil {
t.Fatal(err)
}

// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5 seconds to finish, so we
// use a context timeout to tell us if this is the case.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

var cacheCount atomic.Uint32
wgDoneCh := make(chan struct{})

go func() {
var wg sync.WaitGroup
// Start 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)

go func() {
defer wg.Done()

sendReq := &SendRequest{
Token: "autoauthtoken",
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil),
}

resp, err := lc.Send(ctx, sendReq)
if err != nil {
t.Fatal(err)
}

if resp.CacheMeta != nil && resp.CacheMeta.Hit {
cacheCount.Inc()
}
}()
}

wg.Wait()
close(wgDoneCh)
}()

select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}

// Ensure that all but one request got proxied. The other 99 should be
// returned from the cache.
if cacheCount.Load() != 99 {
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
}
}
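
Since these tests exercise concurrent code paths (and an earlier revision in this PR fixed a data race in them, per the commit list above), they are most informative when run with the race detector enabled, e.g. go test -race -run TestLeaseCache_Concurrent ./command/agent/cache.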
25 changes: 25 additions & 0 deletions command/agent/cache/testing.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"strings"
"time"
@@ -80,3 +81,27 @@ func (p *mockTokenVerifierProxier) Send(ctx context.Context, req *SendRequest) (
func (p *mockTokenVerifierProxier) GetCurrentRequestToken() string {
return p.currentToken
}

type mockDelayProxier struct {
cacheableResp bool
delay int
}

func (p *mockDelayProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
if p.delay > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(p.delay) * time.Millisecond):
}
}

// If this is a cacheable response, we return a unique response every time
if p.cacheableResp {
rand.Seed(time.Now().Unix())
s := fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`, rand.Int())
return newTestSendResponse(http.StatusOK, s), nil
}

return newTestSendResponse(http.StatusOK, `{"value": "output"}`), nil
}
