diff --git a/README.md b/README.md index c7935dac..893b1568 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Acquiring tokens with MSAL Go follows this general three step pattern. There mig * Initializing a public client: ```go - publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoftonline.com/Enter_The_Tenant_Name_Here")) + publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here")) ``` * Initializing a confidential client: @@ -54,7 +54,7 @@ Acquiring tokens with MSAL Go follows this general three step pattern. There mig if err != nil { return nil, fmt.Errorf("could not create a cred from a secret: %w", err) } - confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoftonline.com/Enter_The_Tenant_Name_Here")) + confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here")) ``` 1. MSAL comes packaged with an in-memory cache. Utilizing the cache is optional, but we would highly recommend it. diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 00617abf..5f68384f 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -10,6 +10,7 @@ import ( "net/url" "reflect" "strings" + "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" @@ -27,31 +28,21 @@ const ( ) // manager provides an internal cache. It is defined to allow faking the cache in tests. -// In all production use it is a *storage.Manager. +// In production it's a *storage.Manager or *storage.PartitionedManager. type manager interface { - Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (storage.TokenResponse, error) - Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) + cache.Serializer + Read(context.Context, authority.AuthParams) (storage.TokenResponse, error) + Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error) +} + +// accountManager is a manager that also caches accounts. In production it's a *storage.Manager. +type accountManager interface { + manager AllAccounts() []shared.Account Account(homeAccountID string) shared.Account RemoveAccount(account shared.Account, clientID string) } -// partitionedManager provides an internal cache. It is defined to allow faking the cache in tests. -// In all production use it is a *storage.PartitionedManager. -type partitionedManager interface { - Read(ctx context.Context, authParameters authority.AuthParams) (storage.TokenResponse, error) - Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) -} - -type noopCacheAccessor struct{} - -func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { - return nil -} -func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { - return nil -} - // AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache). type AcquireTokenSilentParameters struct { Scopes []string @@ -137,12 +128,14 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco // Client is a base client that provides access to common methods and primatives that // can be used by multiple clients. type Client struct { - Token *oauth.Client - manager manager // *storage.Manager or fakeManager in tests - pmanager partitionedManager // *storage.PartitionedManager or fakeManager in tests - - AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). - cacheAccessor cache.ExportReplace + Token *oauth.Client + manager accountManager // *storage.Manager or fakeManager in tests + // pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests + pmanager manager + + AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). + cacheAccessor cache.ExportReplace + cacheAccessorMu *sync.RWMutex } // Option is an optional argument to the New constructor. @@ -214,11 +207,11 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O } authParams := authority.NewAuthParams(clientID, authInfo) client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go - Token: token, - AuthParams: authParams, - cacheAccessor: noopCacheAccessor{}, - manager: storage.New(token), - pmanager: storage.NewPartitionedManager(token), + Token: token, + AuthParams: authParams, + cacheAccessorMu: &sync.RWMutex{}, + manager: storage.New(token), + pmanager: storage.NewPartitionedManager(token), } for _, o := range options { if err = o(&client); err != nil { @@ -283,8 +276,9 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s return baseURL.String(), nil } -func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) { - // when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant +func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) { + ar := AuthResult{} + // when tenant == "", the caller didn't specify a tenant and WithTenant will choose the client's configured tenant tenant := silent.TenantID authParams, err := b.AuthParams.WithTenant(tenant) if err != nil { @@ -296,38 +290,23 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen authParams.Claims = silent.Claims authParams.UserAssertion = silent.UserAssertion - var storageTokenResponse storage.TokenResponse - if authParams.AuthorizationType == authority.ATOnBehalfOf { - if s, ok := b.pmanager.(cache.Serializer); ok { - suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - storageTokenResponse, err = b.pmanager.Read(ctx, authParams) - if err != nil { - return ar, err - } - } else { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } + m := b.pmanager + if authParams.AuthorizationType != authority.ATOnBehalfOf { authParams.AuthorizationType = authority.ATRefreshToken - storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account) - if err != nil { - return ar, err - } + m = b.manager + } + if b.cacheAccessor != nil { + key := authParams.CacheKey(silent.IsAppCache) + b.cacheAccessorMu.RLock() + err = b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) + b.cacheAccessorMu.RUnlock() + } + if err != nil { + return ar, err + } + storageTokenResponse, err := m.Read(ctx, authParams) + if err != nil { + return ar, err } // ignore cached access tokens when given claims @@ -340,21 +319,17 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen // redeem a cached refresh token, if available if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() { - err = errors.New("no token found") - return ar, err + return ar, errors.New("no token found") } var cc *accesstokens.Credential if silent.RequestType == accesstokens.ATConfidential { cc = silent.Credential } - token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken) if err != nil { return ar, err } - - ar, err = b.AuthResultFromToken(ctx, authParams, token, true) - return ar, err + return b.AuthResultFromToken(ctx, authParams, token, true) } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -417,103 +392,76 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq return ar, err } -func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) { +func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { if !cacheWrite { return NewAuthResult(token, shared.Account{}) } - - var account shared.Account + var m manager = b.manager if authParams.AuthorizationType == authority.ATOnBehalfOf { - if s, ok := b.pmanager.(cache.Serializer); ok { - suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - account, err = b.pmanager.Write(authParams, token) + m = b.pmanager + } + key := token.CacheKey(authParams) + if b.cacheAccessor != nil { + b.cacheAccessorMu.Lock() + defer b.cacheAccessorMu.Unlock() + err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return ar, err - } - } else { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - account, err = b.manager.Write(authParams, token) - if err != nil { - return ar, err + return AuthResult{}, err } } - ar, err = NewAuthResult(token, account) + account, err := m.Write(authParams, token) + if err != nil { + return AuthResult{}, err + } + ar, err := NewAuthResult(token, account) + if err == nil && b.cacheAccessor != nil { + err = b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) + } return ar, err } -func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) +func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { + if b.cacheAccessor != nil { + b.cacheAccessorMu.RLock() + defer b.cacheAccessorMu.RUnlock() + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return accts, err + return nil, err } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() } - - accts = b.manager.AllAccounts() - return accts, err + return b.manager.AllAccounts(), nil } -func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) { - authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer. - authParams.AuthorizationType = authority.AccountByID - authParams.HomeAccountID = homeAccountID - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) +func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) { + if b.cacheAccessor != nil { + b.cacheAccessorMu.RLock() + defer b.cacheAccessorMu.RUnlock() + authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer. + authParams.AuthorizationType = authority.AccountByID + authParams.HomeAccountID = homeAccountID + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return acct, err + return shared.Account{}, err } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() } - acct = b.manager.Account(homeAccountID) - return acct, err + return b.manager.Account(homeAccountID), nil } // RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account. -func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() +func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error { + if b.cacheAccessor == nil { + b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return nil } - b.manager.RemoveAccount(account, b.AuthParams.ClientID) - return err -} - -// export helps other methods defer exporting the cache after possibly updating its in-memory content. -// err is the error the calling method will return. If err isn't nil, export returns it without -// exporting the cache. -func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error { + b.cacheAccessorMu.Lock() + defer b.cacheAccessorMu.Unlock() + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { return err } - return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key}) + b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 7aac8102..f2278981 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -133,6 +133,7 @@ func TestAcquireTokenSilentScopes(t *testing.T) { }, accesstokens.TokenResponse{ AccessToken: fakeAccessToken, + ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, IDToken: fakeIDToken, @@ -223,23 +224,27 @@ func TestCacheIOErrors(t *testing.T) { } t.Run(name, func(t *testing.T) { client := fakeClient(t, WithCacheAccessor(&cache)) - _, actual := client.Account(ctx, "...") - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) - } - _, actual = client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) - } - _, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}}) - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) + if !export { + // Account and AllAccounts don't export the cache, and AcquireTokenSilent does so + // only after redeeming a refresh token, so we test them only for replace errors + _, actual := client.Account(ctx, "...") + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AllAccounts(ctx) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{Scopes: testScopes}) + if cache.replaceErr != nil && !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } } - _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + _, actual := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential, Scopes: testScopes}) if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } - _, actual = client.AllAccounts(ctx) + _, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}, Scopes: testScopes}) if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } @@ -254,6 +259,45 @@ func TestCacheIOErrors(t *testing.T) { }) } + // testing that AcquireTokenSilent propagates errors from Export requires more elaborate + // setup because that method exports the cache only after acquiring a new access token + t.Run("silent auth export error", func(t *testing.T) { + cache := failCache{} + hid := "uid.utid" + client := fakeClient(t, WithCacheAccessor(&cache)) + // cache fake tokens and app metadata + _, err := client.AuthResultFromToken(ctx, + authority.AuthParams{ + AuthorityInfo: authority.Info{Host: fakeAuthority}, + ClientID: fakeClientID, + HomeAccountID: hid, + Scopes: testScopes, + }, + accesstokens.TokenResponse{ + AccessToken: "at", + ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, + GrantedScopes: accesstokens.Scopes{Slice: testScopes}, + IDToken: fakeIDToken, + RefreshToken: "rt", + }, + true, + ) + if err != nil { + t.Fatal(err) + } + // AcquireTokenSilent should return this error after redeeming a refresh token + cache.exportErr = expected + _, actual := client.AcquireTokenSilent(ctx, + AcquireTokenSilentParameters{ + Account: shared.NewAccount(hid, fakeAuthority, "realm", "id", authority.AAD, "upn"), + Scopes: []string{"not-" + testScopes[0]}, + }, + ) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + }) + // when the client fails to acquire a token, it should return an error instead of exporting the cache t.Run("auth error", func(t *testing.T) { cache := failCache{} diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/internal/storage/storage.go index 1c0471bb..add75192 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/internal/storage/storage.go @@ -83,7 +83,7 @@ func isMatchingScopes(scopesOne []string, scopesTwo string) bool { } // Read reads a storage token from the cache if it exists. -func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (TokenResponse, error) { +func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { tr := TokenResponse{} homeAccountID := authParameters.HomeAccountID realm := authParameters.AuthorityInfo.Tenant @@ -103,7 +103,8 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes) tr.AccessToken = accessToken - if account.IsZero() { + if homeAccountID == "" { + // caller didn't specify a user, so there's no reason to search for an ID or refresh token return tr, nil } // errors returned by read* methods indicate a cache miss and are therefore non-fatal. We continue populating @@ -122,7 +123,7 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, } } - account, err = m.readAccount(homeAccountID, aliases, realm) + account, err := m.readAccount(homeAccountID, aliases, realm) if err == nil { tr.Account = account } @@ -493,6 +494,8 @@ func (m *Manager) update(cache *Contract) { // Marshal implements cache.Marshaler. func (m *Manager) Marshal() ([]byte, error) { + m.contractMu.RLock() + defer m.contractMu.RUnlock() return json.Marshal(m.contract) } diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index 0f1f1dd5..cc39c3bc 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -792,7 +792,7 @@ func TestRead(t *testing.T) { manager := newForTest(responder) manager.update(contract) - got, err := manager.Read(context.Background(), authParameters, testAccount) + got, err := manager.Read(context.Background(), authParameters) switch { case err == nil && test.err: t.Errorf("TestRead(%s): got err == nil, want err != nil", test.desc) diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index 5f136933..ebd86e2b 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -76,12 +76,17 @@ func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.I return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName) } +// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint). +// This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com). func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) { return t.Authority.AADInstanceDiscovery(ctx, authorityInfo) } // AuthCode returns a token based on an authorization code. func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) { + if err := scopeError(req.AuthParams); err != nil { + return accesstokens.TokenResponse{}, err + } if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil { return accesstokens.TokenResponse{}, err } @@ -107,6 +112,10 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams } tr, err := cred.TokenProvider(ctx, params) if err != nil { + if len(scopes) == 0 { + err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err) + return accesstokens.TokenResponse{}, err + } return accesstokens.TokenResponse{}, err } return accesstokens.TokenResponse{ @@ -134,6 +143,9 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams // Credential acquires a token from the authority using a client credentials grant. func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) { + if err := scopeError(authParams); err != nil { + return accesstokens.TokenResponse{}, err + } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } @@ -145,20 +157,35 @@ func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams if err != nil { return accesstokens.TokenResponse{}, err } - return t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt) + tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt) + if err != nil { + return accesstokens.TokenResponse{}, err + } + return tr, nil } func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) { + if err := scopeError(authParams); err != nil { + return accesstokens.TokenResponse{}, err + } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return accesstokens.TokenResponse{}, err } - return t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret) + tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret) + if err != nil { + return accesstokens.TokenResponse{}, err + } + return tr, nil } // UsernamePassword retrieves a token where a username and password is used. However, if this is // a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password. func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) { + if err := scopeError(authParams); err != nil { + return accesstokens.TokenResponse{}, err + } + if authParams.AuthorityInfo.AuthorityType == authority.ADFS { if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil { return accesstokens.TokenResponse{}, err @@ -178,15 +205,25 @@ func (t *Client) UsernamePassword(ctx context.Context, authParams authority.Auth case authority.Federated: mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL) if err != nil { - return accesstokens.TokenResponse{}, fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err) + err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err) + return accesstokens.TokenResponse{}, err } saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint) if err != nil { - return accesstokens.TokenResponse{}, fmt.Errorf("problem getting SAML token info: %w", err) + err = fmt.Errorf("problem getting SAML token info: %w", err) + return accesstokens.TokenResponse{}, err + } + tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml) + if err != nil { + return accesstokens.TokenResponse{}, err } - return t.AccessTokens.FromSamlGrant(ctx, authParams, saml) + return tr, nil case authority.Managed: + if len(authParams.Scopes) == 0 { + err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err) + return accesstokens.TokenResponse{}, err + } return t.AccessTokens.FromUsernamePassword(ctx, authParams) } return accesstokens.TokenResponse{}, errors.New("unknown account type") @@ -274,6 +311,10 @@ func isWaitDeviceCodeErr(err error) bool { // DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second // device and optionally the token once the code has been entered on the second device. func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) { + if err := scopeError(authParams); err != nil { + return DeviceCode{}, err + } + if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { return DeviceCode{}, err } @@ -294,3 +335,19 @@ func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.Auth authParams.Endpoints = endpoints return nil } + +// scopeError takes an authority.AuthParams and returns an error +// if len(AuthParams.Scope) == 0. +func scopeError(a authority.AuthParams) error { + // TODO(someone): we could look deeper at the message to determine if + // it's a scope error, but this is a good start. + /* + {error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope + openid offline_access profile is not valid. Client credential flows must have a scope value + with /.default suffixed to the resource identifier (application ID URI)...} + */ + if len(a.Scopes) == 0 { + return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid") + } + return nil +} diff --git a/apps/internal/oauth/oauth_test.go b/apps/internal/oauth/oauth_test.go index b3ebacab..c0d6e074 100644 --- a/apps/internal/oauth/oauth_test.go +++ b/apps/internal/oauth/oauth_test.go @@ -25,6 +25,8 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" ) +var testScopes = []string{"scope"} + func TestAuthCode(t *testing.T) { tests := []struct { desc string @@ -56,7 +58,7 @@ func TestAuthCode(t *testing.T) { token.AccessTokens = test.at token.Resolver = test.re - _, err := token.AuthCode(context.Background(), accesstokens.AuthCodeRequest{}) + _, err := token.AuthCode(context.Background(), accesstokens.AuthCodeRequest{AuthParams: authority.AuthParams{Scopes: testScopes}}) switch { case err == nil && test.err: t.Errorf("TestAuthCode(%s): got err == nil, want err != nil", test.desc) @@ -183,7 +185,7 @@ func TestRefresh(t *testing.T) { _, err := token.Refresh( context.Background(), accesstokens.ATPublic, - authority.AuthParams{}, + authority.AuthParams{Scopes: testScopes}, &accesstokens.Credential{}, accesstokens.RefreshToken{}, ) @@ -263,7 +265,7 @@ func TestUsernamePassword(t *testing.T) { token.Resolver = test.re token.WSTrust = test.ws - _, err := token.UsernamePassword(context.Background(), authority.AuthParams{}) + _, err := token.UsernamePassword(context.Background(), authority.AuthParams{Scopes: testScopes}) switch { case err == nil && test.err: t.Errorf("TestUsernamePassword(%s): got err == nil, want err != nil", test.desc) @@ -382,7 +384,7 @@ func TestDeviceCodeToken(t *testing.T) { token.AccessTokens = test.at token.Resolver = test.re - dc, err := token.DeviceCode(context.Background(), authority.AuthParams{}) + dc, err := token.DeviceCode(context.Background(), authority.AuthParams{Scopes: testScopes}) switch { case err == nil && test.err: t.Errorf("TestDeviceCodeToken(%s): got err == nil, want err != nil", test.desc) diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index 5bebdb8e..7b2ccb4f 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -28,10 +28,19 @@ const ( regionName = "REGION_NAME" defaultAPIVersion = "2021-10-01" imdsEndpoint = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=" + defaultAPIVersion - defaultHost = "login.microsoftonline.com" autoDetectRegion = "TryAutoDetect" ) +// These are various hosts that host AAD Instance discovery endpoints. +const ( + defaultHost = "login.microsoftonline.com" + loginMicrosoft = "login.microsoft.com" + loginWindows = "login.windows.net" + loginSTSWindows = "sts.windows.net" + loginMicrosoftOnline = defaultHost +) + +// jsonCaller is an interface that allows us to mock the JSONCall method. type jsonCaller interface { JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error } @@ -54,6 +63,8 @@ func TrustedHost(host string) bool { return false } +// OAuthResponseBase is the base JSON return message for an OAuth call. +// This is embedded in other calls to get the base fields from every response. type OAuthResponseBase struct { Error string `json:"error"` SubError string `json:"suberror"` @@ -442,6 +453,8 @@ func (c Client) GetTenantDiscoveryResponse(ctx context.Context, openIDConfigurat return resp, err } +// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint). +// This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com). func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (InstanceDiscoveryResponse, error) { region := "" var err error @@ -454,9 +467,10 @@ func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (I if region != "" { environment := authorityInfo.Host switch environment { - case "login.microsoft.com", "login.windows.net", "sts.windows.net", defaultHost: - environment = "r." + defaultHost + case loginMicrosoft, loginWindows, loginSTSWindows, defaultHost: + environment = loginMicrosoft } + resp.TenantDiscoveryEndpoint = fmt.Sprintf(tenantDiscoveryEndpointWithRegion, region, environment, authorityInfo.Tenant) metadata := InstanceDiscoveryMetadata{ PreferredNetwork: fmt.Sprintf("%v.%v", region, authorityInfo.Host), diff --git a/apps/internal/oauth/ops/authority/authority_test.go b/apps/internal/oauth/ops/authority/authority_test.go index d33b3677..0ce103fc 100644 --- a/apps/internal/oauth/ops/authority/authority_test.go +++ b/apps/internal/oauth/ops/authority/authority_test.go @@ -267,7 +267,7 @@ func TestAADInstanceDiscoveryWithRegion(t *testing.T) { client := Client{&fakeJSONCaller{}} region := "region" discoveryPath := "tenant/v2.0/.well-known/openid-configuration" - publicCloudEndpoint := fmt.Sprintf("https://%s.r.login.microsoftonline.com/%s", region, discoveryPath) + publicCloudEndpoint := fmt.Sprintf("https://%s.login.microsoft.com/%s", region, discoveryPath) for _, test := range []struct{ host, expectedEndpoint string }{ {"login.chinacloudapi.cn", fmt.Sprintf("https://%s.login.chinacloudapi.cn/%s", region, discoveryPath)}, {"login.microsoft.com", publicCloudEndpoint}, diff --git a/apps/internal/version/version.go b/apps/internal/version/version.go index 4a20fef3..b76c0c56 100644 --- a/apps/internal/version/version.go +++ b/apps/internal/version/version.go @@ -5,4 +5,4 @@ package version // Version is the version of this client package that is communicated to the server. -const Version = "0.9.0" +const Version = "1.0.0"