-
Notifications
You must be signed in to change notification settings - Fork 849
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6c64de6
commit bde34d4
Showing
2 changed files
with
179 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package azcore | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"sync" | ||
"time" | ||
) | ||
|
||
const ( | ||
bearerTokenPrefix = "Bearer " | ||
) | ||
|
||
type bearerTokenPolicy struct { | ||
// cond is used to synchronize token refresh. the locker | ||
// must be locked when updating the following shared state. | ||
cond *sync.Cond | ||
|
||
// renewing indicates that the token is in the process of being refreshed | ||
renewing bool | ||
|
||
// header contains the authorization header value | ||
header string | ||
|
||
// expiresOn is when the token will expire | ||
expiresOn time.Time | ||
|
||
// the following fields are read-only | ||
creds TokenCredential | ||
options TokenRequestOptions | ||
} | ||
|
||
func NewBearerTokenPolicy(cred TokenCredential, opts AuthenticationPolicyOptions) Policy { | ||
return &bearerTokenPolicy{ | ||
cond: sync.NewCond(&sync.Mutex{}), | ||
creds: cred, | ||
options: opts.Options, | ||
} | ||
} | ||
|
||
func (b *bearerTokenPolicy) Do(req *Request) (*Response, error) { | ||
if req.URL.Scheme != "https" { | ||
// HTTPS must be used, otherwise the tokens are at the risk of being exposed | ||
return nil, errors.New("token credentials require a URL using the HTTPS protocol scheme") | ||
} | ||
// create a "refresh window" before the token's real expiration date. | ||
// this allows callers to continue to use the old token while the | ||
// refresh is in progress. | ||
const window = 2 * time.Minute | ||
now, getToken, header := time.Now(), false, "" | ||
// acquire exclusive lock | ||
b.cond.L.Lock() | ||
for { | ||
if b.expiresOn.IsZero() || b.expiresOn.Before(now) { | ||
// token was never obtained or has expired | ||
if !b.renewing { | ||
// another go routine isn't refreshing the token so this one will | ||
b.renewing = true | ||
getToken = true | ||
break | ||
} | ||
// getting here means this go routine will wait for the token to refresh | ||
} else if b.expiresOn.Add(-window).Before(now) { | ||
// token is within the expiration window | ||
if !b.renewing { | ||
// another go routine isn't refreshing the token so this one will | ||
b.renewing = true | ||
getToken = true | ||
break | ||
} | ||
// this go routine will use the existing token while another refreshes it | ||
header = b.header | ||
break | ||
} else { | ||
// token is not expiring yet so use it as-is | ||
header = b.header | ||
break | ||
} | ||
// wait for the token to refresh | ||
b.cond.Wait() | ||
} | ||
b.cond.L.Unlock() | ||
if getToken { | ||
// this go routine has been elected to refresh the token | ||
tk, err := b.creds.GetToken(req.Context(), b.options) | ||
// update shared state | ||
b.cond.L.Lock() | ||
// to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning | ||
b.renewing = false | ||
if err != nil { | ||
b.unlock() | ||
return nil, err | ||
} | ||
header = bearerTokenPrefix + tk.Token | ||
b.header = header | ||
b.expiresOn = tk.ExpiresOn | ||
b.unlock() | ||
} | ||
req.Request.Header.Set(HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) | ||
req.Request.Header.Set(HeaderAuthorization, header) | ||
return req.Next() | ||
} | ||
|
||
// signal any waiters that the token has been refreshed | ||
func (b *bearerTokenPolicy) unlock() { | ||
b.cond.Broadcast() | ||
b.cond.L.Unlock() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// +build go1.13 | ||
|
||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package azcore | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock" | ||
) | ||
|
||
type dummyCredential struct{} | ||
|
||
func (c *dummyCredential) GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) { | ||
return &AccessToken{ | ||
Token: "success_token", | ||
ExpiresOn: time.Date(2021, 06, 25, 3, 20, 0, 0, time.UTC), | ||
}, nil | ||
} | ||
|
||
func (c *dummyCredential) AuthenticationPolicy(options AuthenticationPolicyOptions) Policy { | ||
return NewBearerTokenPolicy(c, options) | ||
} | ||
|
||
func TestBearerTokenPolicyHTTPFail(t *testing.T) { | ||
srv, close := mock.NewServer() | ||
defer close() | ||
srv.SetResponse(mock.WithStatusCode(http.StatusOK)) | ||
cred := &dummyCredential{} | ||
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{})) | ||
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
_, err = pl.Do(req) | ||
if err == nil { | ||
t.Fatalf("expected an error but did not receive one") | ||
} | ||
} | ||
|
||
func TestBearerTokenPolicy(t *testing.T) { | ||
srv, close := mock.NewTLSServer() | ||
defer close() | ||
srv.SetResponse(mock.WithStatusCode(http.StatusOK)) | ||
cred := &dummyCredential{} | ||
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{})) | ||
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
resp, err := pl.Do(req) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %v", err) | ||
} | ||
if !reflect.DeepEqual(req.Header, resp.Request.Header) { | ||
t.Fatal("unexpected modification to request headers") | ||
} | ||
if resp.Request.Header.Get(HeaderAuthorization) != fmt.Sprintf("Bearer %s", "success_token") { | ||
t.Fatalf("unexpected value in Authorization header: %v", resp.Request.Header.Get(HeaderAuthorization)) | ||
} | ||
} |