Skip to content

Commit

Permalink
added unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle M. Tarplee <[email protected]>
  • Loading branch information
ktarplee committed Nov 17, 2023
1 parent 0afea81 commit d26ffea
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 11 deletions.
147 changes: 146 additions & 1 deletion copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,11 +1471,146 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
})

t.Run("WithMount_Mounted", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numOnMounted atomic.Int64
m := mounter(func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

// _, err := getContent()
// if !errors.Is(err, errdef.ErrUnsupported) {
// t.Fatalf("Expected error %v", errdef.ErrUnsupported)
// }
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
})
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) {
numOnMounted.Add(1)
})
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(onMounted()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("WithMount_Copied", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numOnMounted atomic.Int64
m := mounter(func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

_, err := getContent()
if !errors.Is(err, errdef.ErrUnsupported) {
t.Fatalf("Expected error %v", errdef.ErrUnsupported)
}
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
})
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.WithMount("source", m, func(ctx context.Context, d ocispec.Descriptor) {
numOnMounted.Add(1)
})
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(0); got != expected {
t.Errorf("count(onMounted()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(7); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(7); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})
}

// countingStorage counts the calls to its content.Storage methods
type countingStorage struct {
storage content.Storage
storage content.Storage

numExists, numFetch, numPush atomic.Int64
}

Expand All @@ -1494,6 +1629,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor,
return cs.storage.Push(ctx, target, r)
}

type mounter func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error

func (m mounter) Mount(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
return m(ctx, desc, fromRepo, getContent)
}

func TestCopyGraph_WithConcurrencyLimit(t *testing.T) {
src := cas.NewMemory()
// generate test content
Expand Down
32 changes: 32 additions & 0 deletions example_copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"oras.land/oras-go/v2/content/memory"
"oras.land/oras-go/v2/content/oci"
"oras.land/oras-go/v2/internal/spec"
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
)

Expand Down Expand Up @@ -214,6 +215,37 @@ func ExampleCopy_remoteToRemote() {
// sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1
}

func ExampleCopy_remoteToRemote_WithMount() {
reg, err := remote.NewRegistry(remoteHost)
if err != nil {
panic(err) // Handle error
}
ctx := context.Background()
src, err := reg.Repository(ctx, "source")
if err != nil {
panic(err) // Handle error
}
dst, err := reg.Repository(ctx, "target")
if err != nil {
panic(err) // Handle error
}

tagName := "latest"

opts := oras.CopyOptions{}
// Enable cross-repository blob mounting
opts.WithMount("source", dst.(registry.Mounter), nil)

desc, err := oras.Copy(ctx, src, tagName, dst, tagName, opts)
if err != nil {
panic(err) // Handle error
}
fmt.Println("Final", desc.Digest)

// Output:
// Final sha256:7cbb44b44e8ede5a89cf193db3f5f2fd019d89697e6b87e8ed2589e60649b0d1
}

func ExampleCopy_remoteToLocal() {
reg, err := remote.NewRegistry(remoteHost)
if err != nil {
Expand Down
41 changes: 31 additions & 10 deletions registry/remote/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,37 @@ func TestRepository_Mount_Fallback(t *testing.T) {
repo.PlainHTTP = true
ctx := context.Background()

err = repo.Mount(ctx, blobDesc, "test", nil)
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
t.Run("nil", func(t *testing.T) {
sequence = ""

err = repo.Mount(ctx, blobDesc, "test", nil)
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
})

t.Run("ErrUnsupported", func(t *testing.T) {
sequence = ""

err = repo.Mount(ctx, blobDesc, "test", func() (io.ReadCloser, error) {
return nil, errdef.ErrUnsupported
})
if err != nil {
t.Fatalf("Repository.Push() error = %v", err)
}
if !bytes.Equal(gotBlob, blob) {
t.Errorf("Repository.Mount() = %v, want %v", gotBlob, blob)
}
if got, want := sequence, "post get put "; got != want {
t.Errorf("unexpected request sequence; got %q want %q", got, want)
}
})
}

func TestRepository_Mount_Error(t *testing.T) {
Expand Down

0 comments on commit d26ffea

Please sign in to comment.