Skip to content

Commit

Permalink
adding bearer token policy
Browse files Browse the repository at this point in the history
  • Loading branch information
catalinaperalta committed Jun 23, 2021
1 parent e429076 commit ebe82ab
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
111 changes: 111 additions & 0 deletions sdk/azcore/policy_bearer_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcore

import (
"errors"
"net/http"
"sync"
"time"
)

const (
bearerTokenPrefix = "Bearer "
)

type bearerTokenPolicy struct {
// cond is used to synchronize token refresh. the locker
// must be locked when updating the following shared state.
cond *sync.Cond

// renewing indicates that the token is in the process of being refreshed
renewing bool

// header contains the authorization header value
header string

// expiresOn is when the token will expire
expiresOn time.Time

// the following fields are read-only
creds TokenCredential
options TokenRequestOptions
}

func NewBearerTokenPolicy(cred TokenCredential, opts AuthenticationPolicyOptions) Policy {
return &bearerTokenPolicy{
cond: sync.NewCond(&sync.Mutex{}),
creds: cred,
options: opts.Options,
}
}

func (b *bearerTokenPolicy) Do(req *Request) (*Response, error) {
if req.URL.Scheme != "https" {
// HTTPS must be used, otherwise the tokens are at the risk of being exposed
return nil, errors.New("token credentials require a URL using the HTTPS protocol scheme")
}
// create a "refresh window" before the token's real expiration date.
// this allows callers to continue to use the old token while the
// refresh is in progress.
const window = 2 * time.Minute
now, getToken, header := time.Now(), false, ""
// acquire exclusive lock
b.cond.L.Lock()
for {
if b.expiresOn.IsZero() || b.expiresOn.Before(now) {
// token was never obtained or has expired
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
break
}
// getting here means this go routine will wait for the token to refresh
} else if b.expiresOn.Add(-window).Before(now) {
// token is within the expiration window
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
break
}
// this go routine will use the existing token while another refreshes it
header = b.header
break
} else {
// token is not expiring yet so use it as-is
header = b.header
break
}
// wait for the token to refresh
b.cond.Wait()
}
b.cond.L.Unlock()
if getToken {
// this go routine has been elected to refresh the token
tk, err := b.creds.GetToken(req.Context(), b.options)
// update shared state
b.cond.L.Lock()
// to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning
b.renewing = false
if err != nil {
b.unlock()
return nil, err
}
header = bearerTokenPrefix + tk.Token
b.header = header
b.expiresOn = tk.ExpiresOn
b.unlock()
}
req.Request.Header.Set(HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(HeaderAuthorization, header)
return req.Next()
}

// signal any waiters that the token has been refreshed
func (b *bearerTokenPolicy) unlock() {
b.cond.Broadcast()
b.cond.L.Unlock()
}
68 changes: 68 additions & 0 deletions sdk/azcore/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// +build go1.13

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcore

import (
"context"
"fmt"
"net/http"
"reflect"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

type dummyCredential struct{}

func (c *dummyCredential) GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) {
return &AccessToken{
Token: "success_token",
ExpiresOn: time.Date(2021, 06, 25, 3, 20, 0, 0, time.UTC),
}, nil
}

func (c *dummyCredential) AuthenticationPolicy(options AuthenticationPolicyOptions) Policy {
return NewBearerTokenPolicy(c, options)
}

func TestBearerTokenPolicyHTTPFail(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
cred := &dummyCredential{}
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{}))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err = pl.Do(req)
if err == nil {
t.Fatalf("expected an error but did not receive one")
}
}

func TestBearerTokenPolicy(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
cred := &dummyCredential{}
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{}))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(req.Header, resp.Request.Header) {
t.Fatal("unexpected modification to request headers")
}
if resp.Request.Header.Get(HeaderAuthorization) != fmt.Sprintf("Bearer %s", "success_token") {
t.Fatalf("unexpected value in Authorization header: %v", resp.Request.Header.Get(HeaderAuthorization))
}
}

0 comments on commit ebe82ab

Please sign in to comment.