Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent persistent cache data races #402

Merged
merged 10 commits into from
Apr 13, 2023
234 changes: 91 additions & 143 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
Expand All @@ -27,31 +28,21 @@ const (
)

// manager provides an internal cache. It is defined to allow faking the cache in tests.
// In all production use it is a *storage.Manager.
// In production it's a *storage.Manager or *storage.PartitionedManager.
type manager interface {
Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (storage.TokenResponse, error)
Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error)
cache.Serializer
Read(context.Context, authority.AuthParams) (storage.TokenResponse, error)
Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error)
}

// accountManager is a manager that also caches accounts. In production it's a *storage.Manager.
type accountManager interface {
manager
AllAccounts() []shared.Account
Account(homeAccountID string) shared.Account
RemoveAccount(account shared.Account, clientID string)
}

// partitionedManager provides an internal cache. It is defined to allow faking the cache in tests.
// In all production use it is a *storage.PartitionedManager.
type partitionedManager interface {
Read(ctx context.Context, authParameters authority.AuthParams) (storage.TokenResponse, error)
Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error)
}

type noopCacheAccessor struct{}

func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error {
return nil
}
func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error {
return nil
}

// AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache).
type AcquireTokenSilentParameters struct {
Scopes []string
Expand Down Expand Up @@ -137,12 +128,14 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
// Client is a base client that provides access to common methods and primatives that
// can be used by multiple clients.
type Client struct {
Token *oauth.Client
manager manager // *storage.Manager or fakeManager in tests
pmanager partitionedManager // *storage.PartitionedManager or fakeManager in tests

AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
cacheAccessor cache.ExportReplace
Token *oauth.Client
manager accountManager // *storage.Manager or fakeManager in tests
// pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests
pmanager manager

AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
cacheAccessor cache.ExportReplace
cacheAccessorMu *sync.RWMutex
}

// Option is an optional argument to the New constructor.
Expand Down Expand Up @@ -214,11 +207,11 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O
}
authParams := authority.NewAuthParams(clientID, authInfo)
client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go
Token: token,
AuthParams: authParams,
cacheAccessor: noopCacheAccessor{},
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
Token: token,
AuthParams: authParams,
cacheAccessorMu: &sync.RWMutex{},
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
}
for _, o := range options {
if err = o(&client); err != nil {
Expand Down Expand Up @@ -283,8 +276,9 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s
return baseURL.String(), nil
}

func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) {
// when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant
func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
ar := AuthResult{}
// when tenant == "", the caller didn't specify a tenant and WithTenant will choose the client's configured tenant
tenant := silent.TenantID
authParams, err := b.AuthParams.WithTenant(tenant)
if err != nil {
Expand All @@ -296,38 +290,23 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
authParams.Claims = silent.Claims
authParams.UserAssertion = silent.UserAssertion

var storageTokenResponse storage.TokenResponse
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
storageTokenResponse, err = b.pmanager.Read(ctx, authParams)
if err != nil {
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
m := b.pmanager
if authParams.AuthorizationType != authority.ATOnBehalfOf {
authParams.AuthorizationType = authority.ATRefreshToken
storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account)
if err != nil {
return ar, err
}
m = b.manager
}
if b.cacheAccessor != nil {
key := authParams.CacheKey(silent.IsAppCache)
b.cacheAccessorMu.RLock()
err = b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key})
b.cacheAccessorMu.RUnlock()
}
if err != nil {
return ar, err
}
storageTokenResponse, err := m.Read(ctx, authParams)
if err != nil {
return ar, err
}

// ignore cached access tokens when given claims
Expand All @@ -340,21 +319,17 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen

// redeem a cached refresh token, if available
if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() {
err = errors.New("no token found")
return ar, err
return ar, errors.New("no token found")
}
var cc *accesstokens.Credential
if silent.RequestType == accesstokens.ATConfidential {
cc = silent.Credential
}

token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken)
if err != nil {
return ar, err
}

ar, err = b.AuthResultFromToken(ctx, authParams, token, true)
return ar, err
return b.AuthResultFromToken(ctx, authParams, token, true)
}

func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) {
Expand Down Expand Up @@ -417,103 +392,76 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq
return ar, err
}

func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) {
func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) {
if !cacheWrite {
return NewAuthResult(token, shared.Account{})
}

var account shared.Account
var m manager = b.manager
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.pmanager.Write(authParams, token)
m = b.pmanager
}
key := token.CacheKey(authParams)
if b.cacheAccessor != nil {
b.cacheAccessorMu.Lock()
defer b.cacheAccessorMu.Unlock()
err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.manager.Write(authParams, token)
if err != nil {
return ar, err
return AuthResult{}, err
}
}
ar, err = NewAuthResult(token, account)
account, err := m.Write(authParams, token)
if err != nil {
return AuthResult{}, err
}
ar, err := NewAuthResult(token, account)
if err == nil && b.cacheAccessor != nil {
err = b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key})
}
return ar, err
}

func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
defer b.cacheAccessorMu.RUnlock()
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return accts, err
return nil, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
}()
}

accts = b.manager.AllAccounts()
return accts, err
return b.manager.AllAccounts(), nil
}

func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) {
authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeAccountID = homeAccountID
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
defer b.cacheAccessorMu.RUnlock()
authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeAccountID = homeAccountID
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return acct, err
return shared.Account{}, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
}()
}
acct = b.manager.Account(homeAccountID)
return acct, err
return b.manager.Account(homeAccountID), nil
}

// RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account.
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey})
if err != nil {
return err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error {
if b.cacheAccessor == nil {
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return nil
}
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return err
}

// export helps other methods defer exporting the cache after possibly updating its in-memory content.
// err is the error the calling method will return. If err isn't nil, export returns it without
// exporting the cache.
func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error {
b.cacheAccessorMu.Lock()
defer b.cacheAccessorMu.Unlock()
key := b.AuthParams.CacheKey(false)
err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key})
if err != nil {
return err
}
return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key})
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key})
}
Loading