diff --git a/copy_test.go b/copy_test.go index 0b6e6c20..c94aad2f 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1471,11 +1471,146 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Errorf("count(Push()) = %d, want %d", got, expected) } }) + + t.Run("WithMount_Mounted", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numOnMounted atomic.Int64 + m := mounter(func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + + // _, err := getContent() + // if !errors.Is(err, errdef.ErrUnsupported) { + // t.Fatalf("Expected error %v", errdef.ErrUnsupported) + // } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + }) + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) { + numOnMounted.Add(1) + }) + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 1 (skipped) = 6 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(onMounted()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(3); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(3); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("WithMount_Copied", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numOnMounted atomic.Int64 + m := mounter(func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + + _, err := getContent() + if !errors.Is(err, errdef.ErrUnsupported) { + t.Fatalf("Expected error %v", errdef.ErrUnsupported) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + }) + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) { + numOnMounted.Add(1) + }) + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 1 (skipped) = 6 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(0); got != expected { + t.Errorf("count(onMounted()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(7); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(7); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) } // countingStorage counts the calls to its content.Storage methods type countingStorage struct { - storage content.Storage + storage content.Storage + numExists, numFetch, numPush atomic.Int64 } @@ -1494,6 +1629,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor, return cs.storage.Push(ctx, target, r) } +type mounter func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error + +func (m mounter) Mount(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), +) error { + return m(ctx, desc, fromRepo, getContent) +} + func TestCopyGraph_WithConcurrencyLimit(t *testing.T) { src := cas.NewMemory() // generate test content diff --git a/example_copy_test.go b/example_copy_test.go index b7e524a9..8ef6b0b5 100644 --- a/example_copy_test.go +++ b/example_copy_test.go @@ -35,6 +35,7 @@ import ( "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/content/oci" "oras.land/oras-go/v2/internal/spec" + "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote" ) @@ -214,6 +215,37 @@ func ExampleCopy_remoteToRemote() { // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 } +func ExampleCopy_remoteToRemote_WithMount() { + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + ctx := context.Background() + src, err := reg.Repository(ctx, "source") + if err != nil { + panic(err) // Handle error + } + dst, err := reg.Repository(ctx, "target") + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + + opts := oras.CopyOptions{} + // Enable cross-repository blob mounting + opts.WithMount("source", dst.(registry.Mounter), nil) + + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, opts) + if err != nil { + panic(err) // Handle error + } + fmt.Println("Final", desc.Digest) + + // Output: + // Final sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + func ExampleCopy_remoteToLocal() { reg, err := remote.NewRegistry(remoteHost) if err != nil { diff --git a/registry/remote/repository_test.go b/registry/remote/repository_test.go index b6772cbd..4e03f2cb 100644 --- a/registry/remote/repository_test.go +++ b/registry/remote/repository_test.go @@ -421,16 +421,37 @@ func TestRepository_Mount_Fallback(t *testing.T) { repo.PlainHTTP = true ctx := context.Background() - err = repo.Mount(ctx, blobDesc, "test", nil) - if err != nil { - t.Fatalf("Repository.Push() error = %v", err) - } - if !bytes.Equal(gotBlob, blob) { - t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) - } - if got, want := sequence, "post get put "; got != want { - t.Errorf("unexpected request sequence; got %q want %q", got, want) - } + t.Run("nil", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", nil) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post get put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) + + t.Run("ErrUnsupported", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", func() (io.ReadCloser, error) { + return nil, errdef.ErrUnsupported + }) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post get put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) } func TestRepository_Mount_Error(t *testing.T) {