agent: add an inflight cache for better concurrent request handling #10705

Merged · merged 20 commits · Jan 26, 2021
111 changes: 77 additions & 34 deletions command/agent/cache/lease_cache.go
@@ -27,6 +27,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical"
gocache "github.com/patrickmn/go-cache"
"go.uber.org/atomic"
)

const (
@@ -78,6 +80,10 @@ type LeaseCache struct {
// idLocks is used during cache lookup to ensure that identical requests made
// in parallel won't trigger multiple renewal goroutines.
idLocks []*locksutil.LockEntry

// inflightCache keeps track of inflight requests
inflightCache *gocache.Cache
inflightCacheLock sync.Mutex
}

// LeaseCacheConfig is the configuration for initializing a new
@@ -89,6 +95,22 @@ type LeaseCacheConfig struct {
Logger hclog.Logger
}

type inflightRequest struct {
// ch is closed by the request that ends up processing the set of
// parallel requests
ch chan struct{}

// remaining is the number of remaining inflight requests that need to
// be processed before this object can be cleaned up
remaining atomic.Uint64
}

func newInflightRequest() *inflightRequest {
return &inflightRequest{
ch: make(chan struct{}),
}
}

// NewLeaseCache creates a new instance of a LeaseCache.
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
if conf == nil {
@@ -112,13 +134,14 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
baseCtxInfo := cachememdb.NewContextInfo(conf.BaseContext)

return &LeaseCache{
client: conf.Client,
proxier: conf.Proxier,
logger: conf.Logger,
db: db,
baseCtxInfo: baseCtxInfo,
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
client: conf.Client,
proxier: conf.Proxier,
logger: conf.Logger,
db: db,
baseCtxInfo: baseCtxInfo,
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
}, nil
}

@@ -170,40 +193,60 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
return nil, err
}

// Grab a read lock for this particular request
idLock := locksutil.LockForKey(c.idLocks, id)
// Check the inflight cache to see if there are other inflight requests
// of the same kind, based on the computed ID. If so, we increment a counter

idLock.RLock()
unlockFunc := idLock.RUnlock
defer func() { unlockFunc() }()
var inflight *inflightRequest

// Check if the response for this request is already in the cache
sendResp, err := c.checkCacheForRequest(id)
if err != nil {
return nil, err
}
if sendResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
return sendResp, nil
}
defer func() {
// Cleanup on the cache if there are no remaining inflight requests.
// This is the last step, so we defer the call first
if inflight != nil && inflight.remaining.Load() == 0 {
Reviewer (Contributor): When would inflight be nil at this point?

calvn (Contributor, Author), Jan 19, 2021: It shouldn't be nil, because the conditional further down always assigns this value to something, but I'm nil-checking just in case, since this is within a defer that's called right after the variable is declared but before the value is assigned.

c.inflightCache.Delete(id)
Reviewer (Contributor): This felt a little racy, since we can be here multiple times for the same id if a request comes in right after the condition Load() == 0 is checked. But I've gone through it and don't think there is any harmful behavior. The inflight object still exists even if it's deleted from the cache, so the final request can complete, and calling Delete() on an id not present is a no-op.

}
}()
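
For reference, the benign-race argument above leans on two go-cache properties that are easy to check in isolation: Delete on a missing key is a no-op, and a pointer that is already held stays usable after its cache entry is gone. A minimal standalone sketch (hypothetical names; not part of this diff):

package main

import (
	"fmt"

	gocache "github.com/patrickmn/go-cache"
)

type inflight struct{ done chan struct{} }

func main() {
	c := gocache.New(gocache.NoExpiration, gocache.NoExpiration)

	inf := &inflight{done: make(chan struct{})}
	c.Set("req-id", inf, gocache.NoExpiration)

	// The first cleanup removes the entry from the cache...
	c.Delete("req-id")
	// ...and a second Delete for the same id is a harmless no-op.
	c.Delete("req-id")

	// A goroutine that already holds a pointer to the inflight object can
	// still use it after the cache entry is gone.
	close(inf.done)
	<-inf.done
	fmt.Println("inflight object still usable after Delete")
}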

// Perform a lock upgrade
idLock.RUnlock()
idLock := locksutil.LockForKey(c.idLocks, id)

// Briefly grab an ID-based lock in here to emulate a load-or-store behavior
// and prevent concurrent cacheable requests from being proxied twice if
// they both miss the cache due to it being clean when peeking the cache
// entry.
idLock.Lock()
unlockFunc = idLock.Unlock
inflightRaw, found := c.inflightCache.Get(id)
if found {
idLock.Unlock()
inflight = inflightRaw.(*inflightRequest)
inflight.remaining.Inc()
defer inflight.remaining.Dec()

// If found it means that there's an inflight request being processed.
// We wait until that's finished before proceeding further.
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-inflight.ch:
briankassouf (Contributor), Jan 25, 2021: So right now, if we detect here that the thread processing the request has completed (the channel has been closed), then we simply continue. But once we get down to:

    cachedResp, err := c.checkCacheForRequest(id)

we'd see a nil cachedResp, since that is still going to only cache leased values. Then, I think, we'd simply re-send the request to the Vault server. I think this fix is missing a step where we store the resulting response in the inflightRequest object and access it here when the channel is closed. Thoughts?

calvn (Contributor, Author): The winner who closed the channel will have cached the response before this thread gets to call c.checkCacheForRequest(id), so it will result in a cache hit. In the case that the request resulted in a non-cacheable response, it would proxy to Vault, as it should.

The changes in the PR don't actually prevent identical non-cacheable requests from being proxied to Vault; they simply allow one of the requests to be processed first (since we don't know whether it's cacheable) before opening the floodgate to let the other identical requests be processed concurrently. I don't think there's a need to store the actual request/response object in the inflightRequest.

}
} else {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()

c.inflightCache.Set(id, inflight, gocache.NoExpiration)
idLock.Unlock()

// Signal that the processing request is done
defer close(inflight.ch)
}
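
The if/else above implements a load-or-store (singleflight-style) pattern: the first request for an id stores an inflightRequest and proceeds to proxy; later identical requests find the entry, wait on its channel, and then re-check the cache. A minimal standalone sketch of the same pattern using only the standard library (hypothetical names; not the PR's types):

package main

import (
	"fmt"
	"sync"
)

type inflight struct {
	ch chan struct{} // closed by the winner once its request completes
}

var (
	mu        sync.Mutex
	inflights = map[string]*inflight{}
)

// doOnce lets the first caller for an id perform the work; later callers
// with the same id block until the winner finishes, then return (in the
// real LeaseCache they would re-check the response cache at this point).
func doOnce(id string, work func()) {
	mu.Lock()
	inf, found := inflights[id]
	if found {
		mu.Unlock()
		<-inf.ch // wait for the winner
		return
	}
	inf = &inflight{ch: make(chan struct{})}
	inflights[id] = inf
	mu.Unlock()

	defer close(inf.ch) // open the floodgate for the waiters
	work()
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doOnce("GET http://example.com/v1/sample/api", func() {
				fmt.Println("proxied exactly once")
			})
		}()
	}
	wg.Wait()
}

The PR's version additionally tracks a remaining counter so the last request for an id can delete the inflight entry from the cache.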

// Check cache once more after upgrade
sendResp, err = c.checkCacheForRequest(id)
// Check if the response for this request is already in the cache
cachedResp, err := c.checkCacheForRequest(id)
if err != nil {
return nil, err
}

// If found, it means that some other parallel request already cached this response
// in between this upgrade so we can simply return that. Otherwise, this request
// will be the one performing the cache write.
if sendResp != nil {
c.logger.Debug("returning cached response", "method", req.Request.Method, "path", req.Request.URL.Path)
return sendResp, nil
if cachedResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
return cachedResp, nil
}

c.logger.Debug("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)
@@ -441,7 +484,7 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
func computeIndexID(req *SendRequest) (string, error) {
var b bytes.Buffer

// Serialze the request
// Serialize the request
if err := req.Request.Write(&b); err != nil {
return "", fmt.Errorf("failed to serialize request: %v", err)
}
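Although only the comment fix is visible in this hunk, computeIndexID is what the whole inflight scheme keys on: it serializes the request and derives a fingerprint from the bytes, so byte-identical requests map to the same cache and inflight id. A rough sketch of that general shape (assuming SHA-256 here; the hash the real function uses is not shown in this hunk):

package main

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http/httptest"
)

func main() {
	req := httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil)

	// Serialize the request (method, URL, headers, body) into a buffer, as
	// computeIndexID does with req.Request.Write(&b).
	var b bytes.Buffer
	if err := req.Write(&b); err != nil {
		panic(err)
	}

	// Hash the serialized form; identical requests yield identical IDs.
	sum := sha256.Sum256(b.Bytes())
	fmt.Println(hex.EncodeToString(sum[:]))
}
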
134 changes: 132 additions & 2 deletions command/agent/cache/lease_cache_test.go
@@ -8,15 +8,17 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"testing"

"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"time"

"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"go.uber.org/atomic"
)

func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
@@ -40,6 +42,27 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
return lc
}

func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseCache {
t.Helper()

client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}

lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client,
BaseContext: context.Background(),
Proxier: &mockDelayProxier{cacheable, delay},
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
})
if err != nil {
t.Fatal(err)
}

return lc
}

func TestCache_ComputeIndexID(t *testing.T) {
type args struct {
req *http.Request
@@ -509,3 +532,110 @@ func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
})
}
}

func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, false, 50)

// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5 seconds to finish. We
// use a ContextWithTimeout to tell us if this is the case, giving ample
// time to process them concurrently but timing out if they get processed
// serially.
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

wgDoneCh := make(chan struct{})

go func() {
var wg sync.WaitGroup
// 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)

go func() {
defer wg.Done()

// Send a request through the lease cache which is not cacheable (there is
// no lease information or auth information in the response)
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", "http://example.com", nil),
}

_, err := lc.Send(ctx, sendReq)
if err != nil {
t.Error(err)
return
}
}()
}

wg.Wait()
close(wgDoneCh)
}()

select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}

}

func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, true, 50)

if err := lc.RegisterAutoAuthToken("autoauthtoken"); err != nil {
t.Fatal(err)
}

// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5 seconds to finish, so we
// use a ContextWithTimeout to tell us if this is the case.
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

var cacheCount atomic.Uint32
wgDoneCh := make(chan struct{})

go func() {
var wg sync.WaitGroup
// Start 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)

go func() {
defer wg.Done()

// Send a request through the lease cache which is cacheable (the mock
// proxier returns lease information in the response)
sendReq := &SendRequest{
Token: "autoauthtoken",
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil),
}

resp, err := lc.Send(ctx, sendReq)
if err != nil {
t.Error(err)
return
}

if resp.CacheMeta != nil && resp.CacheMeta.Hit {
cacheCount.Inc()
}
}()
}

wg.Wait()
close(wgDoneCh)
}()

select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}

// Ensure that all but one request got proxied. The other 99 should be
// returned from the cache.
if cacheCount.Load() != 99 {
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
}
}
25 changes: 25 additions & 0 deletions command/agent/cache/testing.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"strings"
"time"
@@ -80,3 +81,27 @@ func (p *mockTokenVerifierProxier) Send(ctx context.Context, req *SendRequest) (
func (p *mockTokenVerifierProxier) GetCurrentRequestToken() string {
return p.currentToken
}

type mockDelayProxier struct {
cacheableResp bool
delay int
}

func (p *mockDelayProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
if p.delay > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(p.delay) * time.Millisecond):
}
}

// If this is a cacheable response, we return a unique response every time
if p.cacheableResp {
rand.Seed(time.Now().Unix())
s := fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`, rand.Int())
return newTestSendResponse(http.StatusOK, s), nil
}

return newTestSendResponse(http.StatusOK, `{"value": "output"}`), nil
}
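
One caveat on the mock: rand.Seed(time.Now().Unix()) re-seeds the shared math/rand source on every call, so two calls that land in the same second can produce the same "unique" lease_id. If that ever matters for a test, a monotonic counter sidesteps the problem. A possible alternative (a sketch only; the counter below is hypothetical and not part of the PR):

package cache

import (
	"fmt"

	"go.uber.org/atomic"
)

// leaseCounter is a hypothetical package-level counter (not in the PR).
var leaseCounter atomic.Int64

// uniqueLeaseBody returns a distinct lease body on every call without
// touching the global math/rand state.
func uniqueLeaseBody() string {
	return fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`,
		leaseCounter.Inc())
}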