From 8f9f505b78b619b5ab09c4d108997f6033cc542a Mon Sep 17 00:00:00 2001 From: Shiwei Zhang Date: Thu, 21 Mar 2024 11:08:27 +0800 Subject: [PATCH] fix(DynamicStore): retry `setCredsStore` on next `PUT` (#728) Fix #727 Signed-off-by: Shiwei Zhang --- internal/syncutil/once.go | 36 ++++++++++- internal/syncutil/once_test.go | 95 ++++++++++++++++++++++++++++ registry/remote/credentials/store.go | 12 ++-- 3 files changed, 135 insertions(+), 8 deletions(-) diff --git a/internal/syncutil/once.go b/internal/syncutil/once.go index 1d5571980..e44970530 100644 --- a/internal/syncutil/once.go +++ b/internal/syncutil/once.go @@ -15,10 +15,14 @@ limitations under the License. package syncutil -import "context" +import ( + "context" + "sync" + "sync/atomic" +) // Once is an object that will perform exactly one action. -// Unlike sync.Once, this Once allowes the action to have return values. +// Unlike sync.Once, this Once allows the action to have return values. type Once struct { result interface{} err error @@ -68,3 +72,31 @@ func (o *Once) Do(ctx context.Context, f func() (interface{}, error)) (bool, int } } } + +// OnceOrRetry is an object that will perform exactly one success action. +type OnceOrRetry struct { + done atomic.Bool + lock sync.Mutex +} + +// OnceOrRetry calls the function f if and only if Do is being called for the +// first time for this instance of Once or all previous calls to Do are failed. +func (o *OnceOrRetry) Do(f func() error) error { + // fast path + if o.done.Load() { + return nil + } + + // slow path + o.lock.Lock() + defer o.lock.Unlock() + + if o.done.Load() { + return nil + } + if err := f(); err != nil { + return err + } + o.done.Store(true) + return nil +} diff --git a/internal/syncutil/once_test.go b/internal/syncutil/once_test.go index 2bcba2f12..461c1e689 100644 --- a/internal/syncutil/once_test.go +++ b/internal/syncutil/once_test.go @@ -22,6 +22,7 @@ import ( "reflect" "strconv" "sync" + "sync/atomic" "testing" "time" ) @@ -191,3 +192,97 @@ func TestOnce_Do_Cancel_Panic(t *testing.T) { t.Fatalf("Once.Do() result = %v, want %v", result, wantResult) } } + +func TestOnceOrRetry_Do(t *testing.T) { + var once OnceOrRetry + var count atomic.Int32 + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := once.Do(func() error { + count.Add(1) + return nil + }) + if err != nil { + t.Errorf("OnceOrRetry.Do() error = %v, wantErr %v", err, nil) + } + }() + } + wg.Wait() + + if got := count.Load(); got != 1 { + t.Fatal("OnceOrRetry.Do() called more than once") + } +} + +func TestOnceOrRetry_Do_Fail(t *testing.T) { + var once OnceOrRetry + var wg sync.WaitGroup + + // test failure + for i := 0; i < 100; i++ { + wg.Add(1) + go func(wantErr error) { + defer wg.Done() + err := once.Do(func() error { + return wantErr + }) + if err != wantErr { + t.Errorf("OnceOrRetry.Do() error = %v, wantErr %v", err, wantErr) + } + }(errors.New(strconv.Itoa(i))) + } + wg.Wait() + + // retry after failure + err := once.Do(func() error { + return nil + }) + if err != nil { + t.Fatalf("OnceOrRetry.Do() error = %v, wantErr %v", err, nil) + } + + // no retry after success + err = once.Do(func() error { + t.Fatal("OnceOrRetry.Do() called twice") + return nil + }) + if err != nil { + t.Fatalf("OnceOrRetry.Do() error = %v, wantErr %v", err, nil) + } +} + +func TestOnceOrRetry_Do_Panic(t *testing.T) { + var once OnceOrRetry + + // test panic + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("OnceOrRetry.Do() did not panic") + } + }() + _ = once.Do(func() error { + panic("failed") + }) + }() + + // retry after panic + err := once.Do(func() error { + return nil + }) + if err != nil { + t.Fatalf("OnceOrRetry.Do() error = %v, wantErr %v", err, nil) + } + + // no retry after success + err = once.Do(func() error { + t.Fatal("OnceOrRetry.Do() called twice") + return nil + }) + if err != nil { + t.Fatalf("OnceOrRetry.Do() error = %v, wantErr %v", err, nil) + } +} diff --git a/registry/remote/credentials/store.go b/registry/remote/credentials/store.go index ae8ce5be3..e26a98ae7 100644 --- a/registry/remote/credentials/store.go +++ b/registry/remote/credentials/store.go @@ -25,8 +25,8 @@ import ( "fmt" "os" "path/filepath" - "sync" + "oras.land/oras-go/v2/internal/syncutil" "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/credentials/internal/config" ) @@ -53,7 +53,7 @@ type DynamicStore struct { config *config.Config options StoreOptions detectedCredsStore string - setCredsStoreOnce sync.Once + setCredsStoreOnce syncutil.OnceOrRetry } // StoreOptions provides options for NewStore. @@ -136,19 +136,19 @@ func (ds *DynamicStore) Get(ctx context.Context, serverAddress string) (auth.Cre // Put saves credentials into the store for the given server address. // Put returns ErrPlaintextPutDisabled if native store is not available and // [StoreOptions].AllowPlaintextPut is set to false. -func (ds *DynamicStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) (returnErr error) { +func (ds *DynamicStore) Put(ctx context.Context, serverAddress string, cred auth.Credential) error { if err := ds.getStore(serverAddress).Put(ctx, serverAddress, cred); err != nil { return err } // save the detected creds store back to the config file on first put - ds.setCredsStoreOnce.Do(func() { + return ds.setCredsStoreOnce.Do(func() error { if ds.detectedCredsStore != "" { if err := ds.config.SetCredentialsStore(ds.detectedCredsStore); err != nil { - returnErr = fmt.Errorf("failed to set credsStore: %w", err) + return fmt.Errorf("failed to set credsStore: %w", err) } } + return nil }) - return returnErr } // Delete removes credentials from the store for the given server address.