Skip to content

Commit

Permalink
feat: added blob mounting support for oras Copy functions
Browse files Browse the repository at this point in the history
Adds MountFrom and OnMounted to CopyGraphOptions.
Allows for trying to mount from multiple repositories.

Signed-off-by: Kyle M. Tarplee <[email protected]>
  • Loading branch information
ktarplee committed Jan 3, 2024
1 parent 48f0943 commit 75b9a9f
Show file tree
Hide file tree
Showing 4 changed files with 410 additions and 12 deletions.
89 changes: 88 additions & 1 deletion copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ type CopyGraphOptions struct {
// OnCopySkipped will be called when the sub-DAG rooted by the current node
// is skipped.
OnCopySkipped func(ctx context.Context, desc ocispec.Descriptor) error
// MountFrom returns the candidate repositories that desc may be mounted from.
// The OCI references will be tried in turn. If mounting fails on all of them, then it falls back to a copy.
MountFrom func(ctx context.Context, desc ocispec.Descriptor) ([]string, error)
// OnMounted will be invoked when desc is mounted.
OnMounted func(ctx context.Context, desc ocispec.Descriptor) error
// FindSuccessors finds the successors of the current node.
// fetcher provides cached access to the source storage, and is suitable
// for fetching non-leaf nodes like manifests. Since anything fetched from
Expand Down Expand Up @@ -259,12 +264,94 @@ func copyGraph(ctx context.Context, src content.ReadOnlyStorage, dst content.Sto
if exists {
return copyNode(ctx, proxy.Cache, dst, desc, opts)
}
return copyNode(ctx, src, dst, desc, opts)
return mountOrCopyNode(ctx, src, dst, desc, opts)
}

return syncutil.Go(ctx, limiter, fn, root)
}

// mountOrCopyNode enabled cross repository blob mounting.
// sourceReference is the repository to use for mounting (the mount point).
// mounter is the destination for the mount (a well-known implementation of this is *registry.Repository representing the target).
// OnMounted is called (if provided) when the blob is mounted.
// The original PreCopy hook is called only on copy, and therefore not when the blob is mounted.
func mountOrCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor, opts CopyGraphOptions) error {
mounter, ok := dst.(registry.Mounter)
if !ok {
// mounting is not supported by the destination
return copyNode(ctx, src, dst, desc, opts)
}

// Only care to mount blobs
if descriptor.IsManifest(desc) {
return copyNode(ctx, src, dst, desc, opts)
}

Check warning on line 288 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L287-L288

Added lines #L287 - L288 were not covered by tests

if opts.MountFrom == nil {
return copyNode(ctx, src, dst, desc, opts)
}

sourceRepositories, err := opts.MountFrom(ctx, desc)
if err != nil {
// Technically this error is not fatal, we can still attempt to copy the node
// But for consistency with the other callbacks we bail out.
return err
}

Check warning on line 299 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L296-L299

Added lines #L296 - L299 were not covered by tests

if len(sourceRepositories) == 0 {
return copyNode(ctx, src, dst, desc, opts)
}

Check warning on line 303 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L302-L303

Added lines #L302 - L303 were not covered by tests

skipContent := errors.New("skip content")
for i, sourceRepository := range sourceRepositories {
// try mounting this source repository
var mountFailed bool
getContent := func() (io.ReadCloser, error) {
// the invocation of getContent indicates that mounting has failed
mountFailed = true

if len(sourceRepositories)-1 == i {
// this is the last iteration so we need to actually get the content and do the copy

// call the original PreCopy function if it exists
if opts.PreCopy != nil {
if err := opts.PreCopy(ctx, desc); err != nil {
return nil, err
}

Check warning on line 320 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L319-L320

Added lines #L319 - L320 were not covered by tests
}
return src.Fetch(ctx, desc)
}

// We want to return an error that we will test for from mounter.Mount()
return nil, skipContent
}

// Mount or copy
if err := mounter.Mount(ctx, desc, sourceRepository, getContent); err != nil && !errors.Is(err, skipContent) {
return err
}

Check warning on line 332 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L331-L332

Added lines #L331 - L332 were not covered by tests

if !mountFailed {
// mounted, success
if opts.OnMounted != nil {
if err := opts.OnMounted(ctx, desc); err != nil {
return err
}

Check warning on line 339 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L338-L339

Added lines #L338 - L339 were not covered by tests
}
return nil
}
}

// we copied it
if opts.PostCopy != nil {
if err := opts.PostCopy(ctx, desc); err != nil {
return err
}

Check warning on line 349 in copy.go

View check run for this annotation

Codecov / codecov/patch

copy.go#L348-L349

Added lines #L348 - L349 were not covered by tests
}

return nil
}

// doCopyNode copies a single content from the source CAS to the destination CAS.
func doCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor) error {
rc, err := src.Fetch(ctx, desc)
Expand Down
252 changes: 251 additions & 1 deletion copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,11 +1471,251 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Mounted", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Copied", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

rc, err := getContent()
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(0); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(7); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(7); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Mounted_Second_Try", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
switch fromRepo {
case "source":
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
case "missing/the/data":
// simulate a registry mount will fail, so it will request the content to start the copy.
rc, err := getContent()
if err != nil {
return fmt.Errorf("getContent failed: %w", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
default:
t.Fatalf("fromRepo = %v, want either %v or %v", fromRepo, "missing/the/data", "source")
return nil
}
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"missing/the/data", "source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4*2); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})
}

// countingStorage counts the calls to its content.Storage methods
type countingStorage struct {
storage content.Storage
storage content.Storage
mount mountFunc

numExists, numFetch, numPush atomic.Int64
}

Expand All @@ -1494,6 +1734,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor,
return cs.storage.Push(ctx, target, r)
}

type mountFunc func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error

func (cs *countingStorage) Mount(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
return cs.mount(ctx, desc, fromRepo, getContent)
}

func TestCopyGraph_WithConcurrencyLimit(t *testing.T) {
src := cas.NewMemory()
// generate test content
Expand Down
Loading

0 comments on commit 75b9a9f

Please sign in to comment.