Skip to content

Commit

Permalink
feat: added blob mounting support for oras Copy functions
Browse files Browse the repository at this point in the history
Implements a WithMount method on CopyGraphOptions

Also allows for getContent to return ErrUnsupported to fall back to default behavior.

Signed-off-by: Kyle M. Tarplee <[email protected]>
  • Loading branch information
ktarplee committed Nov 22, 2023
1 parent 79a08b4 commit 45a1cb8
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 11 deletions.
58 changes: 58 additions & 0 deletions copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,64 @@ type CopyGraphOptions struct {
FindSuccessors func(ctx context.Context, fetcher content.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error)
}

// WithMount enabled cross repository blob mounting.
// sourceReference is the repository to use for mounting (the mount point).
// mounter is the destination for the mount (a well-known implementation of this is *registry.Repository representing the target).
// onMounted is called (if provided) when the blob is mounted.
// The original PreCopy hook is called only on copy, and therefore not when the blob is mounted.
func (opts *CopyGraphOptions) WithMount(sourceRepository string, mounter registry.Mounter, onMounted func(context.Context, ocispec.Descriptor)) {
preCopy := opts.PreCopy
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
// Only care to mount blobs
if descriptor.IsManifest(desc) {
// still want to call PreCopy if it is a manifest
if preCopy != nil {
return preCopy(ctx, desc)
}
return nil
}

var mountFailed bool
getContent := func() (io.ReadCloser, error) {
// call the original PreCopy function if it exists
if preCopy != nil {
if err := preCopy(ctx, desc); err != nil {
return nil, err
}
}
// the invocation of getContent indicates that mounting has failed
mountFailed = true

// To avoid needing a content.Fetcher as an input argument we simply fall back to the default behavior
// as if getContent was nil
return nil, errdef.ErrUnsupported
}

// Mount or copy
if err := mounter.Mount(ctx, desc, sourceRepository, getContent); err != nil {
return err
}

if !mountFailed {
// mounted
if onMounted != nil {
onMounted(ctx, desc)
}
// signal that the descriptor now exists
return SkipNode
}

// we copied it
if opts.PostCopy != nil {
if err := opts.PostCopy(ctx, desc); err != nil {
return err
}
}
// signal that the descriptor now exists
return SkipNode
}
}

// Copy copies a rooted directed acyclic graph (DAG) with the tagged root node
// in the source Target to the destination Target.
// The destination reference will be the same as the source reference if the
Expand Down
147 changes: 146 additions & 1 deletion copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,11 +1471,146 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
})

t.Run("WithMount_Mounted", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numOnMounted atomic.Int64
m := mounter(func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

// _, err := getContent()
// if !errors.Is(err, errdef.ErrUnsupported) {
// t.Fatalf("Expected error %v", errdef.ErrUnsupported)
// }
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
})
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) {
numOnMounted.Add(1)
})
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(onMounted()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("WithMount_Copied", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numOnMounted atomic.Int64
m := mounter(func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

_, err := getContent()
if !errors.Is(err, errdef.ErrUnsupported) {
t.Fatalf("Expected error %v", errdef.ErrUnsupported)
}
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
})
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) {
numOnMounted.Add(1)
})
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(0); got != expected {
t.Errorf("count(onMounted()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(7); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(7); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})
}

// countingStorage counts the calls to its content.Storage methods
type countingStorage struct {
storage content.Storage
storage content.Storage

numExists, numFetch, numPush atomic.Int64
}

Expand All @@ -1494,6 +1629,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor,
return cs.storage.Push(ctx, target, r)
}

type mounter func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error

func (m mounter) Mount(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
return m(ctx, desc, fromRepo, getContent)
}

func TestCopyGraph_WithConcurrencyLimit(t *testing.T) {
src := cas.NewMemory()
// generate test content
Expand Down
32 changes: 32 additions & 0 deletions example_copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"oras.land/oras-go/v2/content/memory"
"oras.land/oras-go/v2/content/oci"
"oras.land/oras-go/v2/internal/spec"
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
)

Expand Down Expand Up @@ -215,6 +216,37 @@ func ExampleCopy_remoteToRemote() {
// sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1
}

func ExampleCopy_remoteToRemoteWithMount() {
reg, err := remote.NewRegistry(remoteHost)
if err != nil {
panic(err) // Handle error
}
ctx := context.Background()
src, err := reg.Repository(ctx, "source")
if err != nil {
panic(err) // Handle error
}
dst, err := reg.Repository(ctx, "target")
if err != nil {
panic(err) // Handle error
}

tagName := "latest"

opts := oras.CopyOptions{}
// Enable cross-repository blob mounting
opts.WithMount("source", dst.(registry.Mounter), nil)

desc, err := oras.Copy(ctx, src, tagName, dst, tagName, opts)
if err != nil {
panic(err) // Handle error
}
fmt.Println("Final", desc.Digest)

// Output:
// Final sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1
}

func ExampleCopy_remoteToLocal() {
reg, err := remote.NewRegistry(remoteHost)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions registry/remote/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,10 @@ func (s *blobStore) Mount(ctx context.Context, desc ocispec.Descriptor, fromRepo
var r io.ReadCloser
if getContent != nil {
r, err = getContent()
if errors.Is(err, errdef.ErrUnsupported) {
// getContent can return a ErrUnsupported to fallback to the default copy operation
r, err = s.sibling(fromRepo).Fetch(ctx, desc)
}
} else {
r, err = s.sibling(fromRepo).Fetch(ctx, desc)
}
Expand Down
41 changes: 31 additions & 10 deletions registry/remote/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,37 @@ func TestRepository_Mount_Fallback(t *testing.T) {
repo.PlainHTTP = true
ctx := context.Background()

err = repo.Mount(ctx, blobDesc, "test", nil)
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
t.Run("nil", func(t *testing.T) {
sequence = ""

err = repo.Mount(ctx, blobDesc, "test", nil)
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
})

t.Run("ErrUnsupported", func(t *testing.T) {
sequence = ""

err = repo.Mount(ctx, blobDesc, "test", func() (io.ReadCloser, error) {
return nil, errdef.ErrUnsupported
})
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
})
}

func TestRepository_Mount_Error(t *testing.T) {
Expand Down

0 comments on commit 45a1cb8

Please sign in to comment.