diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 33091e2d8d53..7cb45df01369 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -53,8 +53,13 @@ var ( errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names") ) -// TokenCachePersistenceOptions contains options for persistent token caching -type TokenCachePersistenceOptions = internal.TokenCachePersistenceOptions +// Cache represents a persistent cache that makes authentication data available across processes. +// Construct one with [github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache.New]. This package's +// [persistent user authentication example] shows how to use a persistent cache to reuse logins +// across application runs. +// +// [persistent user authentication example]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity@v1.8.0-beta.1#example-package-PersistentUserAuthentication +type Cache = internal.Cache // setAuthorityHost initializes the authority host for credentials. Precedence is: // 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index bdfe5f934868..27847576b3cd 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -18,7 +18,6 @@ import ( "os" "path/filepath" "reflect" - "runtime" "strings" "testing" "time" @@ -213,6 +212,17 @@ func TestTenantID(t *testing.T) { } } +type testCache []byte + +func (c *testCache) Export(_ context.Context, m cache.Marshaler, _ cache.ExportHints) (err error) { + *c, err = m.Marshal() + return +} + +func (c *testCache) Replace(_ context.Context, u cache.Unmarshaler, _ cache.ReplaceHints) error { + return u.Unmarshal(*c) +} + func TestUserAuthentication(t *testing.T) { type authenticater interface { azcore.TokenCredential @@ -221,30 +231,30 @@ func TestUserAuthentication(t *testing.T) { for _, credential := range []struct { name string interactive, recordable bool - new func(*TokenCachePersistenceOptions, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error) + new func(Cache, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error) }{ { name: credNameBrowser, - new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { + new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{ AdditionallyAllowedTenants: []string{"*"}, AuthenticationRecord: ar, + Cache: c, ClientOptions: co, DisableAutomaticAuthentication: disableAutoAuth, - TokenCachePersistenceOptions: tcpo, }) }, interactive: true, }, { name: credNameDeviceCode, - new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { + new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { o := DeviceCodeCredentialOptions{ AdditionallyAllowedTenants: []string{"*"}, AuthenticationRecord: ar, + Cache: c, ClientOptions: co, DisableAutomaticAuthentication: disableAutoAuth, - TokenCachePersistenceOptions: tcpo, } if recording.GetRecordMode() == recording.PlaybackMode { o.UserPrompt = func(context.Context, DeviceCodeMessage) error { return nil } @@ -256,12 +266,12 @@ func TestUserAuthentication(t *testing.T) { }, { name: credNameUserPassword, - new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { + new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { opts := UsernamePasswordCredentialOptions{ - AdditionallyAllowedTenants: []string{"*"}, - AuthenticationRecord: ar, - ClientOptions: co, - TokenCachePersistenceOptions: tcpo, + AdditionallyAllowedTenants: []string{"*"}, + AuthenticationRecord: ar, + Cache: c, + ClientOptions: co, } return NewUsernamePasswordCredential(liveUser.tenantID, developerSignOnClientID, liveUser.username, liveUser.password, &opts) }, @@ -286,13 +296,13 @@ func TestUserAuthentication(t *testing.T) { }} co := azcore.ClientOptions{Cloud: cc, Transport: &sts} - cred, err := credential.new(nil, co, AuthenticationRecord{}, false) + cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false) require.NoError(t, err) _, err = cred.Authenticate(context.Background(), nil) require.NoError(t, err) t.Setenv(azureAuthorityHost, cc.ActiveDirectoryAuthorityHost) - cred, err = credential.new(nil, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false) + cred, err = credential.new(Cache{}, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false) require.NoError(t, err) _, err = cred.Authenticate(context.Background(), nil) if cc.ActiveDirectoryAuthorityHost == customCloud.ActiveDirectoryAuthorityHost { @@ -320,14 +330,14 @@ func TestUserAuthentication(t *testing.T) { counter := tokenRequestCountingPolicy{} co.PerCallPolicies = append(co.PerCallPolicies, &counter) - cred, err := credential.new(nil, co, AuthenticationRecord{}, false) + cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false) require.NoError(t, err) ar, err := cred.Authenticate(context.Background(), &testTRO) require.NoError(t, err) // some fields of the returned AuthenticationRecord should have specific values - require.Equal(t, ar.ClientID, developerSignOnClientID) - require.Equal(t, ar.Version, supportedAuthRecordVersions[0]) + require.Equal(t, developerSignOnClientID, ar.ClientID) + require.Equal(t, supportedAuthRecordVersions[0], ar.Version) // all others should have nonempty values v := reflect.Indirect(reflect.ValueOf(&ar)) for _, f := range reflect.VisibleFields(reflect.TypeOf(ar)) { @@ -337,48 +347,47 @@ func TestUserAuthentication(t *testing.T) { require.Equal(t, 1, counter.count) }) - t.Run("PersistentCache_Live/"+credential.name, func(t *testing.T) { - switch recording.GetRecordMode() { - case recording.LiveMode: - if credential.interactive && !runManualTests { - t.Skipf("set %s to run this test", azidentityRunManualTests) - } - case recording.PlaybackMode, recording.RecordingMode: - if !credential.recordable { - t.Skip("this test can't be recorded") - } + t.Run("PersistentCache/"+credential.name, func(t *testing.T) { + if credential.name == credNameBrowser && !runManualTests { + t.Skipf("set %s to run this test", azidentityRunManualTests) } - if runtime.GOOS != "windows" { - t.Skip("this test runs only on Windows") - } - p, err := internal.CacheFilePath(t.Name()) - require.NoError(t, err) - os.Remove(p) - co, stop := initRecording(t) - defer stop() - counter := tokenRequestCountingPolicy{} - co.PerCallPolicies = append(co.PerCallPolicies, &counter) - tcpo := TokenCachePersistenceOptions{Name: t.Name()} + tokenReqs := 0 + c := internal.NewCache(func(bool) (cache.ExportReplace, error) { + return &testCache{}, nil + }) + co := azcore.ClientOptions{Transport: &mockSTS{ + tokenRequestCallback: func(*http.Request) *http.Response { + tokenReqs++ + return nil + }, + }} - cred, err := credential.new(&tcpo, co, AuthenticationRecord{}, true) + cred, err := credential.new(c, co, AuthenticationRecord{}, false) require.NoError(t, err) - record, err := cred.Authenticate(context.Background(), &testTRO) + record, err := cred.Authenticate(ctx, &testTRO) require.NoError(t, err) - defer os.Remove(p) - tk, err := cred.GetToken(context.Background(), testTRO) + _, err = cred.GetToken(ctx, testTRO) require.NoError(t, err) - require.Equal(t, 1, counter.count) + require.Equal(t, 1, tokenReqs) - cred2, err := credential.new(&tcpo, co, record, true) + // cred2 should return the token cached by cred + cred2, err := credential.new(c, co, record, true) require.NoError(t, err) - tk2, err := cred2.GetToken(context.Background(), testTRO) + _, err = cred2.GetToken(ctx, testTRO) require.NoError(t, err) - require.Equal(t, tk.Token, tk2.Token) + require.Equal(t, 1, tokenReqs) + + // cred should request a new token because the cached one isn't a CAE token + caeTRO := testTRO + caeTRO.EnableCAE = true + _, err = cred.GetToken(ctx, caeTRO) + require.NoError(t, err) + require.Equal(t, 2, tokenReqs) }) if credential.interactive { t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) { - cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true) + cred, err := credential.new(Cache{}, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true) require.NoError(t, err) expected := policy.TokenRequestOptions{ Claims: "claims", @@ -402,7 +411,7 @@ func TestUserAuthentication(t *testing.T) { } }) t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) { - cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true) + cred, err := credential.new(Cache{}, policy.ClientOptions{}, AuthenticationRecord{}, true) require.NoError(t, err) expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue} fake := NewFakeCredential() @@ -1103,107 +1112,90 @@ func TestResolveTenant(t *testing.T) { } } -func TestTokenCachePersistenceOptions(t *testing.T) { - af := filepath.Join(t.TempDir(), t.Name()+credNameWorkloadIdentity) - if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil { - t.Fatal(err) - } - before := internal.NewCache - t.Cleanup(func() { internal.NewCache = before }) - for _, test := range []struct { - desc string - options *TokenCachePersistenceOptions - err error +func TestConfidentialClientPersistentCache(t *testing.T) { + // for WorkloadIdentityCredential + tfp := filepath.Join(t.TempDir(), "tokenfile") + require.NoError(t, os.WriteFile(tfp, []byte("token"), 0600)) + for _, credential := range []struct { + name string + new func(azcore.ClientOptions, Cache) (azcore.TokenCredential, error) }{ { - desc: "nil options", + name: credNameAssertion, + new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) { + o := ClientAssertionCredentialOptions{Cache: c, ClientOptions: co} + return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o) + }, }, + // TODO: set SYSTEM_OIDC_REQUEST_URI, fake response + // { + // name: credNameAzurePipelines, + // new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) { + // o := AzurePipelinesCredentialOptions{Cache: c, ClientOptions: co} + // return NewAzurePipelinesCredential(fakeTenantID, fakeClientID, "service-connection", tokenValue, &o) + // }, + // }, { - desc: "default options", - options: &TokenCachePersistenceOptions{}, + name: credNameCert, + new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) { + o := ClientCertificateCredentialOptions{Cache: c, ClientOptions: co} + return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o) + }, }, { - desc: "all options set", - options: &TokenCachePersistenceOptions{AllowUnencryptedStorage: true, Name: "name"}, + name: credNameSecret, + new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) { + o := ClientSecretCredentialOptions{Cache: c, ClientOptions: co} + return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o) + }, }, - } { - internal.NewCache = func(o *internal.TokenCachePersistenceOptions, _ bool) (cache.ExportReplace, error) { - if (test.options == nil) != (o == nil) { - t.Fatalf("expected %v, got %v", test.options, o) - } - if test.options != nil { - if test.options.AllowUnencryptedStorage != o.AllowUnencryptedStorage { - t.Fatalf("expected AllowUnencryptedStorage %v, got %v", test.options.AllowUnencryptedStorage, o.AllowUnencryptedStorage) - } - if test.options.Name != o.Name { - t.Fatalf("expected Name %q, got %q", test.options.Name, o.Name) + { + name: credNameWorkloadIdentity, + new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) { + o := WorkloadIdentityCredentialOptions{ + Cache: c, + ClientID: fakeClientID, + ClientOptions: co, + TenantID: fakeTenantID, + TokenFilePath: tfp, } - } - return nil, nil - } - for _, subtest := range []struct { - ctor func(azcore.ClientOptions, *TokenCachePersistenceOptions) (azcore.TokenCredential, error) - env map[string]string - name string - }{ - { - name: credNameAssertion, - ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) { - o := ClientAssertionCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco} - return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o) - }, - }, - { - name: credNameCert, - ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) { - o := ClientCertificateCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco} - return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o) - }, - }, - { - name: credNameDeviceCode, - ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) { - o := DeviceCodeCredentialOptions{ - ClientOptions: co, - TokenCachePersistenceOptions: tco, - UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil }, - } - return NewDeviceCodeCredential(&o) - }, - }, - { - name: credNameSecret, - ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) { - o := ClientSecretCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco} - return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o) - }, - }, - { - name: credNameUserPassword, - ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) { - o := UsernamePasswordCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco} - return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o) - }, + return NewWorkloadIdentityCredential(&o) }, - } { - t.Run(fmt.Sprintf("%s/%s", subtest.name, test.desc), func(t *testing.T) { - for k, v := range subtest.env { - t.Setenv(k, v) - } - c, err := subtest.ctor(policy.ClientOptions{Transport: &mockSTS{}}, test.options) - if err != nil { - t.Fatal(err) - } - _, err = c.GetToken(context.Background(), testTRO) - if err != nil { - if !errors.Is(err, test.err) { - t.Fatalf("expected %v, got %v", test.err, err) - } - } else if test.err != nil { - t.Fatal("expected an error") - } + }, + } { + t.Run(credential.name, func(t *testing.T) { + tokenReqs := 0 + c := internal.NewCache(func(bool) (cache.ExportReplace, error) { + return &testCache{}, nil }) - } + sts := mockSTS{ + tokenRequestCallback: func(*http.Request) *http.Response { + tokenReqs++ + return nil + }, + } + cred, err := credential.new(policy.ClientOptions{Transport: &sts}, c) + require.NoError(t, err) + _, err = cred.GetToken(context.Background(), testTRO) + require.NoError(t, err) + _, err = cred.GetToken(ctx, testTRO) + require.NoError(t, err) + require.Equal(t, 1, tokenReqs) + + // cred2 should return the token cached by cred + cred2, err := credential.new(policy.ClientOptions{Transport: &sts}, c) + require.NoError(t, err) + _, err = cred2.GetToken(ctx, testTRO) + require.NoError(t, err) + require.Equal(t, 1, tokenReqs) + + // cred should request a new token because the cached one isn't a CAE token + caeTRO := testTRO + caeTRO.EnableCAE = true + _, err = cred.GetToken(ctx, caeTRO) + require.NoError(t, err) + require.Equal(t, 2, tokenReqs) + }) } } diff --git a/sdk/azidentity/azure_pipelines_credential.go b/sdk/azidentity/azure_pipelines_credential.go index 80c1806bb187..320551ffb769 100644 --- a/sdk/azidentity/azure_pipelines_credential.go +++ b/sdk/azidentity/azure_pipelines_credential.go @@ -40,6 +40,11 @@ type AzurePipelinesCredentialOptions struct { // application is registered. AdditionallyAllowedTenants []string + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or // private clouds such as Azure Stack. It determines whether the credential requests Microsoft Entra instance metadata // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making @@ -83,6 +88,7 @@ func NewAzurePipelinesCredential(tenantID, clientID, serviceConnectionID, system } caco := ClientAssertionCredentialOptions{ AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery, } diff --git a/sdk/azidentity/cache/cache.go b/sdk/azidentity/cache/cache.go index 56243fa9aeb1..dfd93809db8d 100644 --- a/sdk/azidentity/cache/cache.go +++ b/sdk/azidentity/cache/cache.go @@ -8,45 +8,101 @@ package cache import ( + "bytes" + "context" "fmt" "path/filepath" + "sync" + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" extcache "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + msal "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" ) -const defaultName = "msal.cache" - -func init() { - internal.NewCache = func(o *internal.TokenCachePersistenceOptions, enableCAE bool) (cache.ExportReplace, error) { - if o == nil { - return nil, nil +var ( + // once ensures New tests the storage implementation only once + once = &sync.Once{} + // storageError is the error from the storage test + storageError error + // tryStorage tests the storage implementation by round-tripping data + tryStorage = func() { + const errFmt = "persistent storage isn't available due to error %q" + s, err := storage("azidentity-test-cache") + if err != nil { + storageError = fmt.Errorf(errFmt, err) + return } - cp := *o - if cp.Name == "" { - cp.Name = defaultName + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + in := []byte("test") + err = s.Write(ctx, in) + if err != nil { + storageError = fmt.Errorf(errFmt, err) + return + } + out, err := s.Read(ctx) + if err != nil { + storageError = fmt.Errorf(errFmt, err) + return } - suffix := ".nocae" - if enableCAE { - suffix = ".cae" + if !bytes.Equal(in, out) { + storageError = fmt.Errorf(errFmt, "read doesn't match write") } - cp.Name += suffix - a, err := storage(cp) + err = s.Delete(ctx) if err != nil { - return nil, err + storageError = fmt.Errorf(errFmt, err) + } + } +) + +// Options for persistent token caches. +type Options struct { + // Name distinguishes caches. Set this to isolate data from other applications. + Name string +} + +// New constructs persistent token caches. See the [token caching guide] for details +// about the storage implementation. +// +// [token caching guide]: https://aka.ms/azsdk/go/identity/caching#Persistent-token-caching +func New(opts *Options) (azidentity.Cache, error) { + once.Do(tryStorage) + if storageError != nil { + return azidentity.Cache{}, storageError + } + o := Options{} + if opts != nil { + o = *opts + } + if o.Name == "" { + o.Name = "msal.cache" + } + factory := func(cae bool) (msal.ExportReplace, error) { + name := o.Name + if cae { + name += ".cae" } - p, err := internal.CacheFilePath(cp.Name) + p, err := cacheFilePath(name) if err != nil { return nil, err } - return extcache.New(a, p) - } - internal.CacheFilePath = func(name string) (string, error) { - dir, err := cacheDir() + s, err := storage(name) if err != nil { - return "", fmt.Errorf("couldn't create a cache file due to error %q", err) + return nil, err } - return filepath.Join(dir, ".IdentityService", name), nil + return extcache.New(s, p) + } + return internal.NewCache(factory), nil +} + +// cacheFilePath maps a cache name to a file path. This path is the base for a lockfile. +// Storage implementations may also use it directly to store cache data. +func cacheFilePath(name string) (string, error) { + dir, err := cacheDir() + if err != nil { + return "", fmt.Errorf("couldn't create a cache file due to error %q", err) } + return filepath.Join(dir, ".IdentityService", name), nil } diff --git a/sdk/azidentity/cache/windows_test.go b/sdk/azidentity/cache/cache_test.go similarity index 53% rename from sdk/azidentity/cache/windows_test.go rename to sdk/azidentity/cache/cache_test.go index 7ea953cbfcb0..5abcca9ae88d 100644 --- a/sdk/azidentity/cache/windows_test.go +++ b/sdk/azidentity/cache/cache_test.go @@ -1,5 +1,6 @@ -//go:build go1.18 && windows -// +build go1.18,windows +//go:build go1.18 && (linux || windows) +// +build go1.18 +// +build linux windows // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -20,26 +21,29 @@ import ( var ctx = context.Background() -func TestCaching(t *testing.T) { +func TestCache(t *testing.T) { + before := cacheDir + t.Cleanup(func() { cacheDir = before }) + cacheDir = func() (string, error) { return t.TempDir(), nil } for _, test := range []struct { - ctor func(azidentity.TokenCachePersistenceOptions) (azcore.TokenCredential, error) - name string + credential func(azidentity.Cache) (azcore.TokenCredential, error) + name string }{ { - func(tcpo azidentity.TokenCachePersistenceOptions) (azcore.TokenCredential, error) { + func(c azidentity.Cache) (azcore.TokenCredential, error) { opts := azidentity.ClientSecretCredentialOptions{ - ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, - TokenCachePersistenceOptions: &tcpo, + Cache: c, + ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, } return azidentity.NewClientSecretCredential("tenantID", "clientID", "secret", &opts) }, "confidential", }, { - func(tcpo azidentity.TokenCachePersistenceOptions) (azcore.TokenCredential, error) { + func(c azidentity.Cache) (azcore.TokenCredential, error) { opts := azidentity.DeviceCodeCredentialOptions{ - ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, - TokenCachePersistenceOptions: &tcpo, + Cache: c, + ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, } return azidentity.NewDeviceCodeCredential(&opts) }, @@ -47,19 +51,16 @@ func TestCaching(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - tcpo := azidentity.TokenCachePersistenceOptions{ - Name: strings.ReplaceAll(t.Name(), string(filepath.Separator), "_"), - } - if a, e := storage(azidentity.TokenCachePersistenceOptions{Name: tcpo.Name + ".nocae"}); e == nil { - defer func() { a.Delete(ctx) }() - } - cred, err := test.ctor(tcpo) + name := strings.ReplaceAll(t.Name(), string(filepath.Separator), "_") + cache, err := New(&Options{Name: name}) + require.NoError(t, err) + cred, err := test.credential(cache) require.NoError(t, err) tro := policy.TokenRequestOptions{Scopes: []string{"scope"}} tk, err := cred.GetToken(ctx, tro) require.NoError(t, err) - cred2, err := test.ctor(tcpo) + cred2, err := test.credential(cache) require.NoError(t, err) tk2, err := cred2.GetToken(ctx, tro) require.NoError(t, err) diff --git a/sdk/azidentity/cache/ci.azidentity.yml b/sdk/azidentity/cache/ci.azidentity.yml index 32df6a86e1f8..6088a12a73d0 100644 --- a/sdk/azidentity/cache/ci.azidentity.yml +++ b/sdk/azidentity/cache/ci.azidentity.yml @@ -24,4 +24,5 @@ pr: extends: template: /eng/pipelines/templates/jobs/archetype-sdk-client.yml parameters: + EnableRaceDetector: true ServiceDirectory: 'azidentity/cache' diff --git a/sdk/azidentity/cache/darwin.go b/sdk/azidentity/cache/darwin.go index c5ed0fd45454..d61c29149068 100644 --- a/sdk/azidentity/cache/darwin.go +++ b/sdk/azidentity/cache/darwin.go @@ -7,42 +7,14 @@ package cache import ( - "context" - "errors" "os" - "time" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache/accessor" ) -var cacheDir = os.UserHomeDir - -func storage(o internal.TokenCachePersistenceOptions) (accessor.Accessor, error) { - name := o.Name - if name == "" { - name = defaultName - } - if err := tryAccessor(); err != nil { - return nil, errors.New("cache encryption is impossible because the keychain isn't usable: " + err.Error()) - } - return accessor.New(name, accessor.WithAccount("MSALCache")) -} - -func tryAccessor() error { - a, err := accessor.New("azidentity-test-cache") - if err != nil { - return err +var ( + cacheDir = os.UserHomeDir + storage = func(name string) (accessor.Accessor, error) { + return accessor.New(name, accessor.WithAccount("MSALCache")) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - err = a.Write(ctx, []byte("test")) - if err != nil { - return err - } - _, err = a.Read(ctx) - if err != nil { - return err - } - return a.Delete(ctx) -} +) diff --git a/sdk/azidentity/cache/linux.go b/sdk/azidentity/cache/linux.go index 264d38db667f..3b7850ac51af 100644 --- a/sdk/azidentity/cache/linux.go +++ b/sdk/azidentity/cache/linux.go @@ -16,7 +16,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache/internal/aescbc" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache/internal/jwe" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache/accessor" "golang.org/x/sys/unix" ) @@ -27,38 +26,12 @@ const ( ) var ( - cacheDir = os.UserHomeDir - tryKeyring = func() error { - k, err := newKeyring("azidentity-test-cache") - if err != nil { - return err - } - // the Accessor interface requires contexts for these methods but this implementation - // doesn't use them, which is okay because these methods don't block on user interaction - ctx := context.Background() - err = k.Write(ctx, []byte("test")) - if err != nil { - return err - } - _, err = k.Read(ctx) - if err != nil { - return err - } - return k.Delete(ctx) + cacheDir = os.UserHomeDir + storage = func(name string) (accessor.Accessor, error) { + return newKeyring(name) } ) -func storage(o internal.TokenCachePersistenceOptions) (accessor.Accessor, error) { - name := o.Name - if name == "" { - name = defaultName - } - if err := tryKeyring(); err != nil { - return nil, errors.New("cache encryption is impossible because the kernel key retention facility isn't usable: " + err.Error()) - } - return newKeyring(name) -} - // keyring encrypts cache data with a key stored on the user keyring and writes the encrypted // data to a file. The encryption key, and thus the data, is lost when the system shuts down. type keyring struct { @@ -68,7 +41,7 @@ type keyring struct { } func newKeyring(name string) (*keyring, error) { - p, err := internal.CacheFilePath(name) + p, err := cacheFilePath(name) if err != nil { return nil, err } diff --git a/sdk/azidentity/cache/linux_test.go b/sdk/azidentity/cache/linux_test.go index ed72499aa73a..9927c6680db4 100644 --- a/sdk/azidentity/cache/linux_test.go +++ b/sdk/azidentity/cache/linux_test.go @@ -7,28 +7,23 @@ package cache import ( - "context" - "errors" "os" "path/filepath" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/google/uuid" "github.com/stretchr/testify/require" ) -var ctx = context.Background() - func TestKeyExistsButNotFile(t *testing.T) { expected := []byte(t.Name()) - a, err := storage(internal.TokenCachePersistenceOptions{Name: t.Name()}) + a, err := storage(t.Name()) require.NoError(t, err) err = a.Write(ctx, append([]byte("not"), expected...)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, a.Delete(ctx)) }) - p, err := internal.CacheFilePath(t.Name()) + p, err := cacheFilePath(t.Name()) require.NoError(t, err) require.NoError(t, os.Remove(p)) @@ -139,14 +134,3 @@ func TestTwoInstances(t *testing.T) { }) } } - -func TestKeyringUnusable(t *testing.T) { - before := tryKeyring - t.Cleanup(func() { tryKeyring = before }) - expected := errors.New("it didn't work") - tryKeyring = func() error { return expected } - - _, err := storage(internal.TokenCachePersistenceOptions{}) - require.Error(t, err) - require.Contains(t, err.Error(), expected.Error()) -} diff --git a/sdk/azidentity/cache/mock_test.go b/sdk/azidentity/cache/mock_test.go index d623cbaa6237..151b95f3ebe9 100644 --- a/sdk/azidentity/cache/mock_test.go +++ b/sdk/azidentity/cache/mock_test.go @@ -1,5 +1,6 @@ -//go:build go1.18 && windows -// +build go1.18,windows +//go:build go1.18 && (linux || windows) +// +build go1.18 +// +build linux windows // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. diff --git a/sdk/azidentity/cache/new_test.go b/sdk/azidentity/cache/new_test.go new file mode 100644 index 000000000000..7314ddafb955 --- /dev/null +++ b/sdk/azidentity/cache/new_test.go @@ -0,0 +1,75 @@ +//go:build go1.18 && (darwin || linux || windows) +// +build go1.18 +// +build darwin linux windows + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cache + +import ( + "errors" + "path/filepath" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache/accessor" + "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache/accessor/file" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + errBefore := storageError + onceBefore := once + storageBefore := storage + tryStorageBefore := tryStorage + t.Cleanup(func() { + once = onceBefore + storage = storageBefore + storageError = errBefore + tryStorage = tryStorageBefore + }) + for _, expectedErr := range []error{nil, errors.New("it didn't work")} { + name := "storage error" + if expectedErr == nil { + name = "no storage error" + } + t.Run(name, func(t *testing.T) { + once = &sync.Once{} + storage = func(string) (accessor.Accessor, error) { + p := filepath.Join(t.TempDir(), t.Name()) + return file.New(p) + } + storageError = nil + tries := 0 + tryStorage = func() { + tries++ + storageError = expectedErr + } + wg := &sync.WaitGroup{} + ch := make(chan error, 1) + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := New(&Options{Name: t.Name()}); err != nil { + select { + case ch <- err: + default: + } + } + }() + } + wg.Wait() + select { + case err := <-ch: + if expectedErr == nil { + t.Fatal(err) + } + require.EqualError(t, err, expectedErr.Error()) + default: + } + require.Equal(t, 1, tries, "tryStorage was called more than once") + }) + } +} diff --git a/sdk/azidentity/cache/windows.go b/sdk/azidentity/cache/windows.go index 7411c7aa015c..4dd041db978e 100644 --- a/sdk/azidentity/cache/windows.go +++ b/sdk/azidentity/cache/windows.go @@ -7,19 +7,19 @@ package cache import ( - "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/AzureAD/microsoft-authentication-extensions-for-go/cache/accessor" "golang.org/x/sys/windows" ) -func cacheDir() (string, error) { - return windows.KnownFolderPath(windows.FOLDERID_LocalAppData, 0) -} - -func storage(o internal.TokenCachePersistenceOptions) (accessor.Accessor, error) { - p, err := internal.CacheFilePath(o.Name) - if err != nil { - return nil, err +var ( + cacheDir = func() (string, error) { + return windows.KnownFolderPath(windows.FOLDERID_LocalAppData, 0) + } + storage = func(name string) (accessor.Accessor, error) { + p, err := cacheFilePath(name) + if err != nil { + return nil, err + } + return accessor.New(p) } - return accessor.New(p) -} +) diff --git a/sdk/azidentity/ci.yml b/sdk/azidentity/ci.yml index 80d1abb3d1e5..7a3870299580 100644 --- a/sdk/azidentity/ci.yml +++ b/sdk/azidentity/ci.yml @@ -30,6 +30,7 @@ extends: SubscriptionConfigurations: - $(sub-config-azure-cloud-test-resources) - $(sub-config-identity-test-resources) + EnableRaceDetector: true RunLiveTests: true ServiceDirectory: azidentity UseFederatedAuth: true diff --git a/sdk/azidentity/client_assertion_credential.go b/sdk/azidentity/client_assertion_credential.go index ea8a75c18d4e..e0f2a7a3793e 100644 --- a/sdk/azidentity/client_assertion_credential.go +++ b/sdk/azidentity/client_assertion_credential.go @@ -37,14 +37,16 @@ type ClientAssertionCredentialOptions struct { // application is registered. AdditionallyAllowedTenants []string + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or // private clouds such as Azure Stack. It determines whether the credential requests Microsoft Entra instance metadata // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions } // NewClientAssertionCredential constructs a ClientAssertionCredential. The getAssertion function must be thread safe. Pass nil for options to accept defaults. @@ -61,10 +63,10 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c }, ) msalOpts := confidentialClientOptions{ - AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - TokenCachePersistenceOptions: options.TokenCachePersistenceOptions, + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, } c, err := newConfidentialClient(tenantID, clientID, credNameAssertion, cred, msalOpts) if err != nil { diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index b7a2f62c48dc..aef0d4c13f19 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -31,6 +31,11 @@ type ClientCertificateCredentialOptions struct { // application is registered. AdditionallyAllowedTenants []string + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or // private clouds such as Azure Stack. It determines whether the credential requests Microsoft Entra instance metadata // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making @@ -41,9 +46,6 @@ type ClientCertificateCredentialOptions struct { // header of each token request's JWT. This is required for Subject Name/Issuer (SNI) authentication. // Defaults to False. SendCertificateChain bool - - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions } // ClientCertificateCredential authenticates a service principal with a certificate. @@ -65,11 +67,11 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x return nil, err } msalOpts := confidentialClientOptions{ - AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - SendX5C: options.SendCertificateChain, - TokenCachePersistenceOptions: options.TokenCachePersistenceOptions, + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + SendX5C: options.SendCertificateChain, } c, err := newConfidentialClient(tenantID, clientID, credNameCert, cred, msalOpts) if err != nil { diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index d4bff927831c..4459e3a6080c 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -32,8 +32,10 @@ type ClientSecretCredentialOptions struct { // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache } // ClientSecretCredential authenticates an application with a client secret. @@ -51,10 +53,10 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st return nil, err } msalOpts := confidentialClientOptions{ - AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - TokenCachePersistenceOptions: options.TokenCachePersistenceOptions, + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, } c, err := newConfidentialClient(tenantID, clientID, credNameSecret, cred, msalOpts) if err != nil { diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 76020112c7a7..4e3e5da4bc72 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -29,8 +29,8 @@ type confidentialClientOptions struct { AdditionallyAllowedTenants []string // Assertion for on-behalf-of authentication Assertion string + Cache Cache DisableInstanceDiscovery, SendX5C bool - TokenCachePersistenceOptions *TokenCachePersistenceOptions } // confidentialClient wraps the MSAL confidential client @@ -145,7 +145,7 @@ func (c *confidentialClient) client(tro policy.TokenRequestOptions) (msalConfide } func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) { - cache, err := internal.NewCache(c.opts.TokenCachePersistenceOptions, enableCAE) + cache, err := internal.ExportReplace(c.opts.Cache, enableCAE) if err != nil { return nil, err } diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index 29a73e96e842..dc18ef2b05ce 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -29,6 +29,11 @@ type DeviceCodeCredentialOptions struct { // to enable the credential to use data from a previous authentication. AuthenticationRecord AuthenticationRecord + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // ClientID is the ID of the application users will authenticate to. // Defaults to the ID of an Azure development application. ClientID string @@ -49,9 +54,6 @@ type DeviceCodeCredentialOptions struct { // applications. TenantID string - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions - // UserPrompt controls how the credential presents authentication instructions. The credential calls // this function with authentication details when it receives a device code. By default, the credential // prints these details to stdout. @@ -101,12 +103,12 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC cp.init() msalOpts := publicClientOptions{ AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants, + Cache: cp.Cache, ClientOptions: cp.ClientOptions, DeviceCodePrompt: cp.UserPrompt, DisableAutomaticAuthentication: cp.DisableAutomaticAuthentication, DisableInstanceDiscovery: cp.DisableInstanceDiscovery, Record: cp.AuthenticationRecord, - TokenCachePersistenceOptions: cp.TokenCachePersistenceOptions, } c, err := newPublicClient(cp.TenantID, cp.ClientID, credNameDeviceCode, msalOpts) if err != nil { diff --git a/sdk/azidentity/example_user_auth_test.go b/sdk/azidentity/example_user_auth_test.go index 4f29e3c62583..bd1bb19e0746 100644 --- a/sdk/azidentity/example_user_auth_test.go +++ b/sdk/azidentity/example_user_auth_test.go @@ -12,9 +12,7 @@ import ( "os" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - - // importing the cache module registers the cache implementation for the current platform - _ "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" ) // this example shows file storage but any form of byte storage would work @@ -39,20 +37,27 @@ func storeRecord(record azidentity.AuthenticationRecord) error { // interactively every time the application runs. The example uses [InteractiveBrowserCredential], however // [DeviceCodeCredential] has the same API. The key steps are: // -// 1. Enable persistent caching by importing "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" and -// setting [TokenCachePersistenceOptions] -// 2. Call Authenticate to acquire an [AuthenticationRecord] and store that for future use. An [AuthenticationRecord] +// 1. Call [github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache.New] to construct a persistent cache +// 2. Set the Cache field in the credential's options +// 3. Call Authenticate to acquire an [AuthenticationRecord] and store that for future use. An [AuthenticationRecord] // enables credentials to access data in the persistent cache. The record contains no authentication secrets. -// 3. Add the [AuthenticationRecord] to the credential's options +// 4. Add the [AuthenticationRecord] to the credential's options func Example_persistentUserAuthentication() { record, err := retrieveRecord() if err != nil { // TODO: handle error } + c, err := cache.New(nil) + if err != nil { + // TODO: handle error. An error here means persistent + // caching is impossible in the runtime environment. + } cred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ + // If record is zero, the credential will start with no user logged in AuthenticationRecord: record, - // Credentials cache in memory by default. Set TokenCachePersistenceOptions to enable persistent caching. - TokenCachePersistenceOptions: &azidentity.TokenCachePersistenceOptions{}, + // Credentials cache in memory by default. Setting Cache with a + // nonzero value from cache.New() enables persistent caching. + Cache: c, }) if err != nil { // TODO: handle error diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index ad6bdaf69189..2f349af016b0 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -28,6 +28,11 @@ type InteractiveBrowserCredentialOptions struct { // to enable the credential to use data from a previous authentication. AuthenticationRecord AuthenticationRecord + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // ClientID is the ID of the application users will authenticate to. // Defaults to the ID of an Azure development application. ClientID string @@ -54,9 +59,6 @@ type InteractiveBrowserCredentialOptions struct { // TenantID is the Microsoft Entra tenant the credential authenticates in. Defaults to the // "organizations" tenant, which can authenticate work and school accounts. TenantID string - - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions } func (o *InteractiveBrowserCredentialOptions) init() { @@ -82,13 +84,13 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption cp.init() msalOpts := publicClientOptions{ AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants, + Cache: cp.Cache, ClientOptions: cp.ClientOptions, DisableAutomaticAuthentication: cp.DisableAutomaticAuthentication, DisableInstanceDiscovery: cp.DisableInstanceDiscovery, LoginHint: cp.LoginHint, Record: cp.AuthenticationRecord, RedirectURL: cp.RedirectURL, - TokenCachePersistenceOptions: cp.TokenCachePersistenceOptions, } c, err := newPublicClient(cp.TenantID, cp.ClientID, credNameBrowser, msalOpts) if err != nil { diff --git a/sdk/azidentity/internal/cache.go b/sdk/azidentity/internal/cache.go new file mode 100644 index 000000000000..001f750c8692 --- /dev/null +++ b/sdk/azidentity/internal/cache.go @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "sync" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" +) + +// Cache represents a persistent cache that makes authentication data available across processes. +// Construct one with [github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache.New]. This module's +// [persistent user authentication example] shows how to use a persistent cache to reuse logins +// across application runs. +// +// [persistent user authentication example]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity@v1.8.0-beta.1#example-package-PersistentUserAuthentication +type Cache struct { + // impl is a pointer so a Cache can carry persistent state across copies + impl *impl +} + +// impl is a Cache's private implementation +type impl struct { + // factory constructs storage implementations + factory func(bool) (cache.ExportReplace, error) + // cae and noCAE are previously constructed storage implementations. CAE + // and non-CAE tokens must be stored separately because MSAL's cache doesn't + // observe token claims. If a single storage implementation held both kinds + // of tokens, it could create a reauthentication or error loop by returning + // a non-CAE token lacking a required claim. + cae, noCAE cache.ExportReplace + // mu synchronizes around cae and noCAE + mu *sync.RWMutex +} + +func (i *impl) exportReplace(cae bool) (cache.ExportReplace, error) { + if i == nil { + // zero-value Cache: return a nil ExportReplace and MSAL will cache in memory + return nil, nil + } + var ( + err error + xr cache.ExportReplace + ) + i.mu.RLock() + xr = i.cae + if !cae { + xr = i.noCAE + } + i.mu.RUnlock() + if xr != nil { + return xr, nil + } + i.mu.Lock() + defer i.mu.Unlock() + if cae { + if i.cae == nil { + if xr, err = i.factory(cae); err == nil { + i.cae = xr + } + } + return i.cae, err + } + if i.noCAE == nil { + if xr, err = i.factory(cae); err == nil { + i.noCAE = xr + } + } + return i.noCAE, err +} + +// NewCache is the constructor for Cache. It takes a factory instead of an instance +// because it doesn't know whether the Cache will store both CAE and non-CAE tokens. +func NewCache(factory func(cae bool) (cache.ExportReplace, error)) Cache { + return Cache{&impl{factory: factory, mu: &sync.RWMutex{}}} +} + +// ExportReplace returns an implementation satisfying MSAL's ExportReplace interface. +// It's a function instead of a method on Cache so packages in azidentity and +// azidentity/cache can call it while applications can't. "cae" declares whether the +// caller intends this implementation to store CAE tokens. +func ExportReplace(c Cache, cae bool) (cache.ExportReplace, error) { + return c.impl.exportReplace(cae) +} diff --git a/sdk/azidentity/internal/cache_test.go b/sdk/azidentity/internal/cache_test.go new file mode 100644 index 000000000000..774cdc9ef23a --- /dev/null +++ b/sdk/azidentity/internal/cache_test.go @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/stretchr/testify/require" +) + +type testCache []byte + +func (testCache) Export(context.Context, cache.Marshaler, cache.ExportHints) error { + return nil +} + +func (testCache) Replace(context.Context, cache.Unmarshaler, cache.ReplaceHints) error { + return nil +} + +func TestExportReplace(t *testing.T) { + countCAE, countNoCAE := 0, 0 + c := NewCache(func(cae bool) (cache.ExportReplace, error) { + if cae { + countCAE++ + } else { + countNoCAE++ + } + return (testCache)([]byte(fmt.Sprint(cae))), nil + }) + wg := &sync.WaitGroup{} + ch := make(chan error, 1) + for i := 0; i < 50; i++ { + wg.Add(1) + go func(cae bool) { + defer wg.Done() + if _, err := ExportReplace(c, cae); err != nil { + select { + case ch <- err: + // set error + default: + // already set + } + } + }(i%2 == 0) + } + wg.Wait() + select { + case err := <-ch: + t.Fatal(err) + default: + } + require.Equal(t, 1, countCAE) + require.Equal(t, 1, countNoCAE) + for _, b := range []bool{false, true} { + xr, err := ExportReplace(c, b) + require.NoError(t, err) + require.EqualValues(t, []byte(fmt.Sprint(b)), xr.(testCache)) + } +} diff --git a/sdk/azidentity/internal/exported.go b/sdk/azidentity/internal/exported.go deleted file mode 100644 index b1b4d5c8bd35..000000000000 --- a/sdk/azidentity/internal/exported.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package internal - -// TokenCachePersistenceOptions contains options for persistent token caching -type TokenCachePersistenceOptions struct { - // AllowUnencryptedStorage controls whether the cache should fall back to storing its data in plain text - // when encryption isn't possible. Setting this true doesn't disable encryption. The cache always attempts - // encryption before falling back to plaintext storage. - AllowUnencryptedStorage bool - - // Name identifies the cache. Set this to isolate data from other applications. - Name string -} diff --git a/sdk/azidentity/internal/internal.go b/sdk/azidentity/internal/internal.go deleted file mode 100644 index c1498b464471..000000000000 --- a/sdk/azidentity/internal/internal.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package internal - -import ( - "errors" - - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" -) - -var errMissingImport = errors.New("import github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache to enable persistent caching") - -// NewCache constructs a persistent token cache when "o" isn't nil. Applications that intend to -// use a persistent cache must first import the cache module, which will replace this function -// with a platform-specific implementation. -var NewCache = func(o *TokenCachePersistenceOptions, enableCAE bool) (cache.ExportReplace, error) { - if o == nil { - return nil, nil - } - return nil, errMissingImport -} - -// CacheFilePath returns the path to the cache file for the given name. -// Defining it in this package makes it available to azidentity tests. -var CacheFilePath = func(name string) (string, error) { - return "", errMissingImport -} diff --git a/sdk/azidentity/mock_test.go b/sdk/azidentity/mock_test.go index 2509f57778f1..c5aaf5f0813d 100644 --- a/sdk/azidentity/mock_test.go +++ b/sdk/azidentity/mock_test.go @@ -53,7 +53,7 @@ func (m *mockSTS) Do(req *http.Request) (*http.Response, error) { if grant := req.FormValue("grant_type"); grant == "device_code" || grant == "password" { // include account info because we're authenticating a user res.Body = io.NopCloser(bytes.NewReader( - []byte(fmt.Sprintf(`{"access_token":"at","expires_in": 3600,"refresh_token":"rt","client_info":%q,"id_token":%q,"token_type":"Bearer"}`, mockClientInfo, mockIDT)), + []byte(fmt.Sprintf(`{"access_token":%q,"expires_in": 3600,"refresh_token":"rt","client_info":%q,"id_token":%q,"token_type":"Bearer"}`, tokenValue, mockClientInfo, mockIDT)), )) } else { res.Body = io.NopCloser(bytes.NewReader(accessTokenRespSuccess)) diff --git a/sdk/azidentity/persistent_cache_live_test.go b/sdk/azidentity/persistent_cache_live_test.go new file mode 100644 index 000000000000..c90c0588e015 --- /dev/null +++ b/sdk/azidentity/persistent_cache_live_test.go @@ -0,0 +1,115 @@ +//go:build go1.18 && (darwin || linux || windows) + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// the test in this file must be defined in azidentity_test because it imports azidentity/cache + +package azidentity_test + +import ( + "context" + "os" + "runtime" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestPersistentCacheLive(t *testing.T) { + if recording.GetRecordMode() != recording.LiveMode { + t.Skip("this test runs only in live mode") + } + if runtime.GOOS == "darwin" && os.Getenv("AZIDENTITY_RUN_MANUAL_TESTS") == "" { + t.Skip("set AZIDENTITY_RUN_MANUAL_TESTS to run this test on macOS") + } + armURL := os.Getenv("RESOURCE_MANAGER_URL") + if armURL == "" { + t.Skip("set RESOURCE_MANAGER_URL to run this test") + } + tro := policy.TokenRequestOptions{Scopes: []string{armURL + "/.default"}} + for _, test := range []struct { + credential func(*testing.T, azidentity.AuthenticationRecord, azidentity.Cache) (azcore.TokenCredential, error) + name string + }{ + { + credential: func(t *testing.T, _ azidentity.AuthenticationRecord, c azidentity.Cache) (azcore.TokenCredential, error) { + t.Skip("https://github.com/Azure/azure-sdk-for-go/issues/22879") + clientID := os.Getenv("IDENTITY_SP_CLIENT_ID") + secret := os.Getenv("IDENTITY_SP_CLIENT_SECRET") + tenantID := os.Getenv("IDENTITY_SP_TENANT_ID") + if clientID == "" || secret == "" || tenantID == "" { + t.Skip("set IDENTITY_SP_* with service principal configuration to run this test") + } + return azidentity.NewClientSecretCredential(tenantID, clientID, secret, + &azidentity.ClientSecretCredentialOptions{Cache: c}, + ) + }, + name: "confidential", + }, + { + credential: func(t *testing.T, r azidentity.AuthenticationRecord, c azidentity.Cache) (azcore.TokenCredential, error) { + clientID := "04b07795-8ddb-461a-bbee-02f9e1bf7b46" + password := os.Getenv("AZURE_IDENTITY_TEST_PASSWORD") + tenantID := os.Getenv("AZURE_IDENTITY_TEST_TENANTID") + username := os.Getenv("AZURE_IDENTITY_TEST_USERNAME") + if password == "" || tenantID == "" || username == "" { + t.Skip("set AZURE_IDENTITY_TEST_* with user configuration to run this test") + } + return azidentity.NewUsernamePasswordCredential(tenantID, clientID, username, password, + &azidentity.UsernamePasswordCredentialOptions{ + AuthenticationRecord: r, + Cache: c, + }, + ) + }, + name: "public", + }, + } { + t.Run(test.name, func(t *testing.T) { + c, err := cache.New(&cache.Options{Name: strings.ReplaceAll(t.Name(), "/", "_")}) + require.NoError(t, err) + + rec := azidentity.AuthenticationRecord{} + cred, err := test.credential(t, rec, c) + require.NoError(t, err) + if test.name == "public" { + type authenticater interface { + Authenticate(context.Context, *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) + } + a, ok := cred.(authenticater) + require.True(t, ok, "test bug: public credential must implement Authenticate") + rec, err = a.Authenticate(ctx, &tro) + require.NoError(t, err) + } + tk, err := cred.GetToken(ctx, tro) + require.NoError(t, err) + + cred2, err := test.credential(t, rec, c) + require.NoError(t, err) + tk2, err := cred2.GetToken(ctx, tro) + require.NoError(t, err) + // require.Equal is more to the point but prints a value i.e. logs a token when expected != actual + require.True(t, tk.Token == tk2.Token, "expected a cached token") + + caeTRO := tro + caeTRO.EnableCAE = true + tk3, err := cred.GetToken(ctx, caeTRO) + require.NoError(t, err) + require.False(t, tk.Token == tk3.Token, "expected a new token because the cached one isn't a CAE token") + + tk4, err := cred2.GetToken(ctx, caeTRO) + require.NoError(t, err) + require.True(t, tk3.Token == tk4.Token, "expected a cached token") + require.False(t, tk.Token == tk4.Token, "expected a CAE token") + }) + } +} diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index e76cb3bab4e6..5669ee9b1e11 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -30,12 +30,12 @@ type publicClientOptions struct { azcore.ClientOptions AdditionallyAllowedTenants []string + Cache Cache DeviceCodePrompt func(context.Context, DeviceCodeMessage) error DisableAutomaticAuthentication bool DisableInstanceDiscovery bool LoginHint, RedirectURL string Record AuthenticationRecord - TokenCachePersistenceOptions *TokenCachePersistenceOptions Username, Password string } @@ -222,13 +222,13 @@ func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, } func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) { - cache, err := internal.NewCache(p.opts.TokenCachePersistenceOptions, enableCAE) + c, err := internal.ExportReplace(p.opts.Cache, enableCAE) if err != nil { return nil, err } o := []public.Option{ public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)), - public.WithCache(cache), + public.WithCache(c), public.WithHTTPClient(p), } if enableCAE { diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index 33401f2c3fc3..475f379980b6 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -29,14 +29,16 @@ type UsernamePasswordCredentialOptions struct { // to enable the credential to use data from a previous authentication. AuthenticationRecord AuthenticationRecord + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or // private clouds such as Azure Stack. It determines whether the credential requests Microsoft Entra instance metadata // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - - // TokenCachePersistenceOptions enables persistent token caching when not nil. - TokenCachePersistenceOptions *TokenCachePersistenceOptions } // UsernamePasswordCredential authenticates a user with a password. Microsoft doesn't recommend this kind of authentication, @@ -54,13 +56,13 @@ func NewUsernamePasswordCredential(tenantID string, clientID string, username st options = &UsernamePasswordCredentialOptions{} } opts := publicClientOptions{ - AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - Password: password, - Record: options.AuthenticationRecord, - TokenCachePersistenceOptions: options.TokenCachePersistenceOptions, - Username: username, + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + Password: password, + Record: options.AuthenticationRecord, + Username: username, } c, err := newPublicClient(tenantID, clientID, credNameUserPassword, opts) if err != nil { diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index 3e43e788e931..c4713c8e900f 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -39,15 +39,24 @@ type WorkloadIdentityCredentialOptions struct { // Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the // application is registered. AdditionallyAllowedTenants []string + + // Cache is a persistent cache the credential will use to store the tokens it acquires, making + // them available to other processes and credential instances. The default, zero value means the + // credential will store tokens in memory and not share them. + Cache Cache + // ClientID of the service principal. Defaults to the value of the environment variable AZURE_CLIENT_ID. ClientID string + // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or // private clouds such as Azure Stack. It determines whether the credential requests Microsoft Entra instance metadata // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool + // TenantID of the service principal. Defaults to the value of the environment variable AZURE_TENANT_ID. TenantID string + // TokenFilePath is the path of a file containing a Kubernetes service account token. Defaults to the value of the // environment variable AZURE_FEDERATED_TOKEN_FILE. TokenFilePath string @@ -81,6 +90,7 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}} caco := ClientAssertionCredentialOptions{ AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Cache: options.Cache, ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery, }