Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bearer token policy to azcore #14889

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions sdk/azcore/policy_bearer_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcore

import (
"errors"
"net/http"
"sync"
"time"
)

const (
bearerTokenPrefix = "Bearer "
)

type bearerTokenPolicy struct {
// cond is used to synchronize token refresh. the locker
// must be locked when updating the following shared state.
cond *sync.Cond

// renewing indicates that the token is in the process of being refreshed
renewing bool

// header contains the authorization header value
header string

// expiresOn is when the token will expire
expiresOn time.Time

// the following fields are read-only
creds TokenCredential
options TokenRequestOptions
}

func NewBearerTokenPolicy(cred TokenCredential, opts AuthenticationPolicyOptions) Policy {
return &bearerTokenPolicy{
cond: sync.NewCond(&sync.Mutex{}),
creds: cred,
options: opts.Options,
}
}

func (b *bearerTokenPolicy) Do(req *Request) (*Response, error) {
if req.URL.Scheme != "https" {
// HTTPS must be used, otherwise the tokens are at the risk of being exposed
return nil, errors.New("token credentials require a URL using the HTTPS protocol scheme")
}
// create a "refresh window" before the token's real expiration date.
// this allows callers to continue to use the old token while the
// refresh is in progress.
const window = 2 * time.Minute
now, getToken, header := time.Now(), false, ""
// acquire exclusive lock
b.cond.L.Lock()
for {
if b.expiresOn.IsZero() || b.expiresOn.Before(now) {
// token was never obtained or has expired
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
break
}
// getting here means this go routine will wait for the token to refresh
} else if b.expiresOn.Add(-window).Before(now) {
// token is within the expiration window
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
break
}
// this go routine will use the existing token while another refreshes it
header = b.header
break
} else {
// token is not expiring yet so use it as-is
header = b.header
break
}
// wait for the token to refresh
b.cond.Wait()
}
b.cond.L.Unlock()
if getToken {
// this go routine has been elected to refresh the token
tk, err := b.creds.GetToken(req.Context(), b.options)
// update shared state
b.cond.L.Lock()
// to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning
b.renewing = false
if err != nil {
b.unlock()
return nil, err
}
header = bearerTokenPrefix + tk.Token
b.header = header
b.expiresOn = tk.ExpiresOn
b.unlock()
}
req.Request.Header.Set(HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(HeaderAuthorization, header)
return req.Next()
}

// signal any waiters that the token has been refreshed
func (b *bearerTokenPolicy) unlock() {
b.cond.Broadcast()
b.cond.L.Unlock()
}
68 changes: 68 additions & 0 deletions sdk/azcore/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// +build go1.13

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcore

import (
"context"
"fmt"
"net/http"
"reflect"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

type dummyCredential struct{}

func (c *dummyCredential) GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) {
return &AccessToken{
Token: "success_token",
ExpiresOn: time.Date(2021, 06, 25, 3, 20, 0, 0, time.UTC),
}, nil
}

func (c *dummyCredential) AuthenticationPolicy(options AuthenticationPolicyOptions) Policy {
return NewBearerTokenPolicy(c, options)
}

func TestBearerTokenPolicyHTTPFail(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
cred := &dummyCredential{}
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{}))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err = pl.Do(req)
if err == nil {
t.Fatalf("expected an error but did not receive one")
}
}

func TestBearerTokenPolicy(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
cred := &dummyCredential{}
pl := NewPipeline(srv, NewBearerTokenPolicy(cred, AuthenticationPolicyOptions{}))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(req.Header, resp.Request.Header) {
t.Fatal("unexpected modification to request headers")
}
if resp.Request.Header.Get(HeaderAuthorization) != fmt.Sprintf("Bearer %s", "success_token") {
t.Fatalf("unexpected value in Authorization header: %v", resp.Request.Header.Get(HeaderAuthorization))
}
}