diff --git a/copy.go b/copy.go index 79feda53..0e0d5d85 100644 --- a/copy.go +++ b/copy.go @@ -284,6 +284,8 @@ func mountOrCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst conte // Only care to mount blobs if descriptor.IsManifest(desc) { + // mountOrCopyNode might never be called with a Manifest based on how copyGraph() is currently implemented + // but it is safer to handle all cases so we keep this here return copyNode(ctx, src, dst, desc, opts) } diff --git a/copy_test.go b/copy_test.go index 2343d18f..9a067e52 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1456,7 +1456,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { }, } if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { - t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + t.Fatalf("CopyGraph() error = %v", err) } if got, expected := dst.numExists.Load(), int64(7); got != expected { @@ -1515,7 +1515,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { return []string{"source"}, nil } if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { - t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + t.Fatalf("CopyGraph() error = %v", err) } if got, expected := dst.numExists.Load(), int64(7); got != expected { @@ -1590,7 +1590,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { return []string{"source"}, nil } if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { - t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + t.Fatalf("CopyGraph() error = %v", err) } if got, expected := dst.numExists.Load(), int64(7); got != expected { @@ -1679,7 +1679,7 @@ func TestCopyGraph_WithOptions(t *testing.T) { return []string{"missing/the/data", "source"}, nil } if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { - t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + t.Fatalf("CopyGraph() error = %v", err) } if got, expected := dst.numExists.Load(), int64(7); got != expected { @@ -1709,6 +1709,134 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Errorf("count(PostCopy()) = %d, want %d", got, expected) } }) + + t.Run("MountFrom empty sourceRepositories", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + opts = oras.CopyGraphOptions{} + var numMountFrom atomic.Int64 + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return nil, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v", err) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(7); got != expected { + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom error", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + opts = oras.CopyGraphOptions{} + var numMountFrom atomic.Int64 + e := errors.New("mountFrom error") + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return nil, e + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); !errors.Is(err, e) { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(0); got != expected { + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + }) + + t.Run("MountFrom OnMounted error", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numMount atomic.Int64 + dst.mount = func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + numMount.Add(1) + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + } + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + e := errors.New("onMounted error") + opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return e + } + opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) { + numMountFrom.Add(1) + return []string{"source"}, nil + } + if err := oras.CopyGraph(ctx, src, dst, root, opts); !errors.Is(err, e) { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, e) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + if got, expected := dst.numPush.Load(), int64(0); got != expected { + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numMount.Load(), int64(4); got != expected { + t.Errorf("count(Mount()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(OnMounted()) = %d, want %d", got, expected) + } + if got, expected := numMountFrom.Load(), int64(4); got != expected { + t.Errorf("count(MountFrom()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(0); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(0); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) } // countingStorage counts the calls to its content.Storage methods