diff --git a/sdk/azcore/policy_bearer_token.go b/sdk/azcore/policy_bearer_token.go new file mode 100644 index 000000000000..3f264103c89f --- /dev/null +++ b/sdk/azcore/policy_bearer_token.go @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "errors" + "net/http" + "sync" + "time" +) + +const ( + bearerTokenPrefix = "Bearer " +) + +type bearerTokenPolicy struct { + // cond is used to synchronize token refresh. the locker + // must be locked when updating the following shared state. + cond *sync.Cond + + // renewing indicates that the token is in the process of being refreshed + renewing bool + + // header contains the authorization header value + header string + + // expiresOn is when the token will expire + expiresOn time.Time + + // the following fields are read-only + creds TokenCredential + options TokenRequestOptions +} + +func NewBearerTokenPolicy(cred TokenCredential, opts AuthenticationPolicyOptions) Policy { + return &bearerTokenPolicy{ + cond: sync.NewCond(&sync.Mutex{}), + creds: cred, + options: opts.Options, + } +} + +func (b *bearerTokenPolicy) Do(req *Request) (*Response, error) { + if req.URL.Scheme != "https" { + // HTTPS must be used, otherwise the tokens are at the risk of being exposed + return nil, errors.New("token credentials require a URL using the HTTPS protocol scheme") + } + // create a "refresh window" before the token's real expiration date. + // this allows callers to continue to use the old token while the + // refresh is in progress. + const window = 2 * time.Minute + now, getToken, header := time.Now(), false, "" + // acquire exclusive lock + b.cond.L.Lock() + for { + if b.expiresOn.IsZero() || b.expiresOn.Before(now) { + // token was never obtained or has expired + if !b.renewing { + // another go routine isn't refreshing the token so this one will + b.renewing = true + getToken = true + break + } + // getting here means this go routine will wait for the token to refresh + } else if b.expiresOn.Add(-window).Before(now) { + // token is within the expiration window + if !b.renewing { + // another go routine isn't refreshing the token so this one will + b.renewing = true + getToken = true + break + } + // this go routine will use the existing token while another refreshes it + header = b.header + break + } else { + // token is not expiring yet so use it as-is + header = b.header + break + } + // wait for the token to refresh + b.cond.Wait() + } + b.cond.L.Unlock() + if getToken { + // this go routine has been elected to refresh the token + tk, err := b.creds.GetToken(req.Context(), b.options) + // update shared state + b.cond.L.Lock() + // to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning + b.renewing = false + if err != nil { + b.unlock() + return nil, err + } + header = bearerTokenPrefix + tk.Token + b.header = header + b.expiresOn = tk.ExpiresOn + b.unlock() + } + req.Request.Header.Set(HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Request.Header.Set(HeaderAuthorization, header) + return req.Next() +} + +// signal any waiters that the token has been refreshed +func (b *bearerTokenPolicy) unlock() { + b.cond.Broadcast() + b.cond.L.Unlock() +} diff --git a/sdk/azcore/policy_bearer_token_test.go b/sdk/azcore/policy_bearer_token_test.go new file mode 100644 index 000000000000..416fc713ec3d --- /dev/null +++ b/sdk/azcore/policy_bearer_token_test.go @@ -0,0 +1,68 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +type dummyCredential struct{} + +func (c *dummyCredential) GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) { + return &AccessToken{ + Token: "success_token", + ExpiresOn: time.Date(2021, 06, 25, 3, 20, 0, 0, time.UTC), + }, nil +} + +func (c *dummyCredential) AuthenticationPolicy(options AuthenticationPolicyOptions) Policy { + return NewBearerTokenPolicy(c, options) +} + +func TestBearerTokenPolicyHTTPFail(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + cred := &dummyCredential{} + pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{})) + req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, err = pl.Do(req) + if err == nil { + t.Fatalf("expected an error but did not receive one") + } +} + +func TestBearerTokenPolicy(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.SetResponse(mock.WithStatusCode(http.StatusOK)) + cred := &dummyCredential{} + pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{})) + req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp, err := pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(req.Header, resp.Request.Header) { + t.Fatal("unexpected modification to request headers") + } + if resp.Request.Header.Get(HeaderAuthorization) != fmt.Sprintf("Bearer %s", "success_token") { + t.Fatalf("unexpected value in Authorization header: %v", resp.Request.Header.Get(HeaderAuthorization)) + } +}