From 6f0c1f0b90805054aeefd163f568aa5914ddbad3 Mon Sep 17 00:00:00 2001 From: "Kyle M. Tarplee" Date: Mon, 27 Nov 2023 07:55:54 -0500 Subject: [PATCH] feat: added blob mounting support for oras Copy functions Implements a WithMount method on CopyGraphOptions Also allows for getContent to return ErrUnsupported to fall back to default behavior. Signed-off-by: Kyle M. Tarplee --- copy.go | 60 ++++++++++++ copy_test.go | 144 ++++++++++++++++++++++++++++- example_copy_test.go | 32 +++++++ registry/remote/repository.go | 4 + registry/remote/repository_test.go | 41 ++++++-- 5 files changed, 270 insertions(+), 11 deletions(-) diff --git a/copy.go b/copy.go index 9caed980..7d80e55e 100644 --- a/copy.go +++ b/copy.go @@ -114,6 +114,66 @@ type CopyGraphOptions struct { FindSuccessors func(ctx context.Context, fetcher content.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) } +// WithMount enabled cross repository blob mounting. +// sourceReference is the repository to use for mounting (the mount point). +// mounter is the destination for the mount (a well-known implementation of this is *registry.Repository representing the target). +// onMounted is called (if provided) when the blob is mounted. +// The original PreCopy hook is called only on copy, and therefore not when the blob is mounted. +func (opts *CopyGraphOptions) WithMount(sourceRepository string, mounter registry.Mounter, onMounted func(context.Context, ocispec.Descriptor) error) { + preCopy := opts.PreCopy + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + // Only care to mount blobs + if descriptor.IsManifest(desc) { + // still want to call PreCopy if it is a manifest + if preCopy != nil { + return preCopy(ctx, desc) + } + return nil + } + + var mountFailed bool + getContent := func() (io.ReadCloser, error) { + // call the original PreCopy function if it exists + if preCopy != nil { + if err := preCopy(ctx, desc); err != nil { + return nil, err + } + } + // the invocation of getContent indicates that mounting has failed + mountFailed = true + + // To avoid needing a content.Fetcher as an input argument we simply fall back to the default behavior + // as if getContent was nil + return nil, errdef.ErrUnsupported + } + + // Mount or copy + if err := mounter.Mount(ctx, desc, sourceRepository, getContent); err != nil { + return err + } + + if !mountFailed { + // mounted + if onMounted != nil { + if err := onMounted(ctx, desc); err != nil { + return err + } + } + // signal that the descriptor now exists + return SkipNode + } + + // we copied it + if opts.PostCopy != nil { + if err := opts.PostCopy(ctx, desc); err != nil { + return err + } + } + // signal that the descriptor now exists + return SkipNode + } +} + // Copy copies a rooted directed acyclic graph (DAG) with the tagged root node // in the source Target to the destination Target. // The destination reference will be the same as the source reference if the diff --git a/copy_test.go b/copy_test.go index 89ac7ed1..4efffccc 100644 --- a/copy_test.go +++ b/copy_test.go @@ -1471,11 +1471,143 @@ func TestCopyGraph_WithOptions(t *testing.T) { t.Errorf("count(Push()) = %d, want %d", got, expected) } }) + + t.Run("WithMount_Mounted", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numOnMounted atomic.Int64 + m := mounter(func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + }) + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + }) + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 1 (skipped) = 6 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(4); got != expected { + t.Errorf("count(onMounted()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(3); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(3); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) + + t.Run("WithMount_Copied", func(t *testing.T) { + root = descs[6] + dst := &countingStorage{storage: cas.NewMemory()} + var numOnMounted atomic.Int64 + m := mounter(func(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), + ) error { + if expected := "source"; fromRepo != expected { + t.Fatalf("fromRepo = %v, want %v", fromRepo, expected) + } + + _, err := getContent() + if !errors.Is(err, errdef.ErrUnsupported) { + t.Fatalf("Expected error %v", errdef.ErrUnsupported) + } + rc, err := src.Fetch(ctx, desc) + if err != nil { + t.Fatalf("Failed to fetch content: %v", err) + } + defer rc.Close() + err = dst.storage.Push(ctx, desc, rc) // bypass the counters + if err != nil { + t.Fatalf("Failed to push content: %v", err) + } + return nil + }) + opts = oras.CopyGraphOptions{} + var numPreCopy, numPostCopy atomic.Int64 + opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPreCopy.Add(1) + return nil + } + opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error { + numPostCopy.Add(1) + return nil + } + opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) error { + numOnMounted.Add(1) + return nil + }) + if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil { + t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit) + } + + if got, expected := dst.numExists.Load(), int64(7); got != expected { + t.Errorf("count(Exists()) = %d, want %d", got, expected) + } + if got, expected := dst.numFetch.Load(), int64(0); got != expected { + t.Errorf("count(Fetch()) = %d, want %d", got, expected) + } + // 7 (exists) - 1 (skipped) = 6 pushes expected + if got, expected := dst.numPush.Load(), int64(3); got != expected { + // If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do. + t.Errorf("count(Push()) = %d, want %d", got, expected) + } + if got, expected := numOnMounted.Load(), int64(0); got != expected { + t.Errorf("count(onMounted()) = %d, want %d", got, expected) + } + if got, expected := numPreCopy.Load(), int64(7); got != expected { + t.Errorf("count(PreCopy()) = %d, want %d", got, expected) + } + if got, expected := numPostCopy.Load(), int64(7); got != expected { + t.Errorf("count(PostCopy()) = %d, want %d", got, expected) + } + }) } // countingStorage counts the calls to its content.Storage methods type countingStorage struct { - storage content.Storage + storage content.Storage + numExists, numFetch, numPush atomic.Int64 } @@ -1494,6 +1626,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor, return cs.storage.Push(ctx, target, r) } +type mounter func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error + +func (m mounter) Mount(ctx context.Context, + desc ocispec.Descriptor, + fromRepo string, + getContent func() (io.ReadCloser, error), +) error { + return m(ctx, desc, fromRepo, getContent) +} + func TestCopyGraph_WithConcurrencyLimit(t *testing.T) { src := cas.NewMemory() // generate test content diff --git a/example_copy_test.go b/example_copy_test.go index 58ee9f56..ad5d4125 100644 --- a/example_copy_test.go +++ b/example_copy_test.go @@ -35,6 +35,7 @@ import ( "oras.land/oras-go/v2/content/memory" "oras.land/oras-go/v2/content/oci" "oras.land/oras-go/v2/internal/spec" + "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote" ) @@ -215,6 +216,37 @@ func ExampleCopy_remoteToRemote() { // sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 } +func ExampleCopy_remoteToRemoteWithMount() { + reg, err := remote.NewRegistry(remoteHost) + if err != nil { + panic(err) // Handle error + } + ctx := context.Background() + src, err := reg.Repository(ctx, "source") + if err != nil { + panic(err) // Handle error + } + dst, err := reg.Repository(ctx, "target") + if err != nil { + panic(err) // Handle error + } + + tagName := "latest" + + opts := oras.CopyOptions{} + // Enable cross-repository blob mounting + opts.WithMount("source", dst.(registry.Mounter), nil) + + desc, err := oras.Copy(ctx, src, tagName, dst, tagName, opts) + if err != nil { + panic(err) // Handle error + } + fmt.Println("Final", desc.Digest) + + // Output: + // Final sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1 +} + func ExampleCopy_remoteToLocal() { reg, err := remote.NewRegistry(remoteHost) if err != nil { diff --git a/registry/remote/repository.go b/registry/remote/repository.go index d67240f2..80b57eb1 100644 --- a/registry/remote/repository.go +++ b/registry/remote/repository.go @@ -806,6 +806,10 @@ func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo var r io.ReadCloser if getContent != nil { r, err = getContent() + if errors.Is(err, errdef.ErrUnsupported) { + // getContent can return a ErrUnsupported to fallback to the default copy operation + r, err = s.sibling(fromRepo).Fetch(ctx, desc) + } } else { r, err = s.sibling(fromRepo).Fetch(ctx, desc) } diff --git a/registry/remote/repository_test.go b/registry/remote/repository_test.go index b66aec46..4578a840 100644 --- a/registry/remote/repository_test.go +++ b/registry/remote/repository_test.go @@ -421,16 +421,37 @@ func TestRepository_Mount_Fallback(t *testing.T) { repo.PlainHTTP = true ctx := context.Background() - err = repo.Mount(ctx, blobDesc, "test", nil) - if err != nil { - t.Fatalf("Repository.Push() error = %v", err) - } - if !bytes.Equal(gotBlob, blob) { - t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) - } - if got, want := sequence, "post get put "; got != want { - t.Errorf("unexpected request sequence; got %q want %q", got, want) - } + t.Run("getContent is nil", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", nil) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post get put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) + + t.Run("getContent is ErrUnsupported", func(t *testing.T) { + sequence = "" + + err = repo.Mount(ctx, blobDesc, "test", func() (io.ReadCloser, error) { + return nil, errdef.ErrUnsupported + }) + if err != nil { + t.Fatalf("Repository.Push() error = %v", err) + } + if !bytes.Equal(gotBlob, blob) { + t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob) + } + if got, want := sequence, "post get put "; got != want { + t.Errorf("unexpected request sequence; got %q want %q", got, want) + } + }) } func TestRepository_Mount_Error(t *testing.T) {