diff --git a/pkg/common/accessio/downloader/downloader.go b/pkg/common/accessio/downloader/downloader.go new file mode 100644 index 000000000..46f67f54b --- /dev/null +++ b/pkg/common/accessio/downloader/downloader.go @@ -0,0 +1,9 @@ +package downloader + +import "io" + +// Downloader defines a downloader for various objects using a WriterAt to +// transfer data to. +type Downloader interface { + Download(w io.WriterAt) error +} diff --git a/pkg/common/accessio/downloader/http/downloader.go b/pkg/common/accessio/downloader/http/downloader.go new file mode 100644 index 000000000..46d65c0e5 --- /dev/null +++ b/pkg/common/accessio/downloader/http/downloader.go @@ -0,0 +1,36 @@ +package http + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +// Downloader simply uses the default HTTP client to download the contents of a URL. +type Downloader struct { + link string +} + +func NewDownloader(link string) *Downloader { + return &Downloader{ + link: link, + } +} + +func (h *Downloader) Download(w io.WriterAt) error { + resp, err := http.Get(h.link) + if err != nil { + return fmt.Errorf("failed to get link: %w", err) + } + defer resp.Body.Close() + var blob []byte + buf := bytes.NewBuffer(blob) + if _, err := io.Copy(buf, resp.Body); err != nil { + return fmt.Errorf("failed to copy response body: %w", err) + } + if _, err := w.WriteAt(buf.Bytes(), 0); err != nil { + return fmt.Errorf("failed to WriteAt to the writer: %w", err) + } + return nil +} diff --git a/pkg/contexts/ocm/accessmethods/s3/downloader.go b/pkg/common/accessio/downloader/s3/downloader.go similarity index 85% rename from pkg/contexts/ocm/accessmethods/s3/downloader.go rename to pkg/common/accessio/downloader/s3/downloader.go index b0fee5f58..8577072bf 100644 --- a/pkg/contexts/ocm/accessmethods/s3/downloader.go +++ b/pkg/common/accessio/downloader/s3/downloader.go @@ -14,19 +14,14 @@ import ( const defaultRegion = "us-west-1" -// Downloader defines a downloader for AWS S3 objects. -type Downloader interface { - Download(w io.WriterAt) error -} - -// S3Downloader is a downloader capable of downloading S3 Objects. -type S3Downloader struct { +// Downloader is a downloader capable of downloading S3 Objects. +type Downloader struct { region, bucket, key, version string creds *AWSCreds } -func NewS3Downloader(region, bucket, key, version string, creds *AWSCreds) *S3Downloader { - return &S3Downloader{ +func NewDownloader(region, bucket, key, version string, creds *AWSCreds) *Downloader { + return &Downloader{ region: region, bucket: bucket, key: key, @@ -42,7 +37,7 @@ type AWSCreds struct { SessionToken string } -func (s *S3Downloader) Download(w io.WriterAt) error { +func (s *Downloader) Download(w io.WriterAt) error { ctx := context.Background() opts := []func(*config.LoadOptions) error{ config.WithRegion(s.region), @@ -79,6 +74,7 @@ func (s *S3Downloader) Download(w io.WriterAt) error { } client := s3.NewFromConfig(cfg, func(o *s3.Options) { + // Pass in creds because of https://github.com/aws/aws-sdk-go-v2/issues/1797 o.Credentials = awsCred o.Region = s.region }) diff --git a/pkg/contexts/ocm/accessmethods/github/downloader.go b/pkg/contexts/ocm/accessmethods/github/downloader.go deleted file mode 100644 index 8cb7a0406..000000000 --- a/pkg/contexts/ocm/accessmethods/github/downloader.go +++ /dev/null @@ -1,35 +0,0 @@ -package github - -import ( - "bytes" - "fmt" - "io" - "net/http" -) - -// Downloader defines an abstraction for downloading an archive from GitHub. -type Downloader interface { - Download(link string) ([]byte, error) -} - -// HTTPDownloader simply uses the default HTTP client to download the contents of a URL. -type HTTPDownloader struct{} - -func (h *HTTPDownloader) Download(link string) ([]byte, error) { - httpResp, err := http.Get(link) - if err != nil { - return nil, err - } - defer func() { - if err := httpResp.Body.Close(); err != nil { - fmt.Println("failed to close body: ", err) - } - }() - - var blob []byte - buf := bytes.NewBuffer(blob) - if _, err := io.Copy(buf, httpResp.Body); err != nil { - return nil, err - } - return buf.Bytes(), nil -} diff --git a/pkg/contexts/ocm/accessmethods/github/method.go b/pkg/contexts/ocm/accessmethods/github/method.go index d90505c4b..3a89dc346 100644 --- a/pkg/contexts/ocm/accessmethods/github/method.go +++ b/pkg/contexts/ocm/accessmethods/github/method.go @@ -17,21 +17,21 @@ package github import ( "context" "fmt" - "io" "net/http" "net/url" "strings" - "sync" "unicode" "github.com/google/go-github/v45/github" "golang.org/x/oauth2" "github.com/open-component-model/ocm/pkg/common/accessio" + "github.com/open-component-model/ocm/pkg/common/accessio/downloader" + hd "github.com/open-component-model/ocm/pkg/common/accessio/downloader/http" + "github.com/open-component-model/ocm/pkg/common/accessobj" "github.com/open-component-model/ocm/pkg/contexts/credentials" "github.com/open-component-model/ocm/pkg/contexts/credentials/identity/hostpath" "github.com/open-component-model/ocm/pkg/contexts/oci/identity" - "github.com/open-component-model/ocm/pkg/contexts/oci/repositories/artefactset" "github.com/open-component-model/ocm/pkg/contexts/ocm/cpi" "github.com/open-component-model/ocm/pkg/errors" "github.com/open-component-model/ocm/pkg/mime" @@ -67,17 +67,15 @@ type AccessSpec struct { // RepoUrl is the repository URL, with host, owner and repository RepoURL string `json:"repoUrl"` - // APIHostname is an optional different hostname for accessing the github REST API + // APIHostname is an optional different hostname for accessing the GitHub REST API // for enterprise installations APIHostname string `json:"apiHostname,omitempty"` - // Ref - Ref string `json:"ref,omitempty"` - // Commit defines the hash of the commit. + // Commit defines the hash of the commit Commit string `json:"commit"` client *http.Client - downloader Downloader + downloader downloader.Downloader } var _ cpi.AccessSpec = (*AccessSpec)(nil) @@ -85,13 +83,6 @@ var _ cpi.AccessSpec = (*AccessSpec)(nil) // AccessSpecOptions defines a set of options which can be applied to the access spec. type AccessSpecOptions func(s *AccessSpec) -// WithRef creates an access spec with a specified reference field -func WithRef(ref string) AccessSpecOptions { - return func(s *AccessSpec) { - s.Ref = ref - } -} - // WithClient creates an access spec with a custom http client. func WithClient(client *http.Client) AccessSpecOptions { return func(s *AccessSpec) { @@ -100,25 +91,18 @@ func WithClient(client *http.Client) AccessSpecOptions { } // WithDownloader defines a client with a custom downloader. -func WithDownloader(downloader Downloader) AccessSpecOptions { +func WithDownloader(downloader downloader.Downloader) AccessSpecOptions { return func(s *AccessSpec) { s.downloader = downloader } } -// New creates a new GitHub registry access spec version v1 -func New(hostname string, port int, repo, owner, commit string, opts ...AccessSpecOptions) *AccessSpec { - if hostname == "" { - hostname = "github.com" - } - p := "" - if port != 0 { - p = fmt.Sprintf(":%d", port) - } - url := fmt.Sprintf("%s%s/%s/%s", hostname, p, owner, repo) +// New creates a new GitHub registry access spec version v1. +func New(repoURL, apiHostname, commit string, opts ...AccessSpecOptions) *AccessSpec { s := &AccessSpec{ ObjectVersionedType: runtime.NewVersionedObjectType(Type), - RepoURL: url, + RepoURL: repoURL, + APIHostname: apiHostname, Commit: commit, } for _, o := range opts { @@ -139,7 +123,20 @@ func (a *AccessSpec) AccessMethod(c cpi.ComponentVersionAccess) (cpi.AccessMetho return newMethod(c, a) } -//////////////////////////////////////////////////////////////////////////////// +func (a *AccessSpec) createHTTPClient(token string) *http.Client { + if token != "" { + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + ctx := context.Background() + // set up the test client if we have one + if a.client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, a.client) + } + return oauth2.NewClient(ctx, ts) + } + return a.client +} // RepositoryService defines capabilities of a GitHub repository. type RepositoryService interface { @@ -147,26 +144,20 @@ type RepositoryService interface { } type accessMethod struct { - lock sync.Mutex - blob artefactset.ArtefactBlob + accessio.BlobAccess + compvers cpi.ComponentVersionAccess spec *AccessSpec repositoryService RepositoryService owner string repo string - downloader Downloader } var _ cpi.AccessMethod = (*accessMethod)(nil) func newMethod(c cpi.ComponentVersionAccess, a *AccessSpec) (cpi.AccessMethod, error) { - if len(a.Commit) != ShaLength { - return nil, fmt.Errorf("commit is not a SHA") - } - for _, c := range a.Commit { - if !unicode.IsOneOf([]*unicode.RangeTable{unicode.Letter, unicode.Digit}, c) { - return nil, fmt.Errorf("commit contains invalid characters for a SHA") - } + if err := validateCommit(a.Commit); err != nil { + return nil, fmt.Errorf("failed to validate commit: %w", err) } unparsed := a.RepoURL @@ -190,15 +181,8 @@ func newMethod(c cpi.ComponentVersionAccess, a *AccessSpec) (cpi.AccessMethod, e } var client *github.Client + httpclient := a.createHTTPClient(token) - httpclient := a.client - - if token != "" && httpclient == nil { - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, - ) - httpclient = oauth2.NewClient(context.Background(), ts) - } if u.Hostname() == "github.com" { client = github.NewClient(httpclient) } else { @@ -214,18 +198,40 @@ func newMethod(c cpi.ComponentVersionAccess, a *AccessSpec) (cpi.AccessMethod, e } } - var downloader Downloader = &HTTPDownloader{} - if a.downloader != nil { - downloader = a.downloader - } - return &accessMethod{ + method := &accessMethod{ spec: a, compvers: c, owner: pathcomps[0], repo: pathcomps[1], repositoryService: client.Repositories, - downloader: downloader, - }, nil + } + + link, err := method.getDownloadLink() + if err != nil { + return nil, fmt.Errorf("failed to get download link: %w", err) + } + + var d downloader.Downloader = hd.NewDownloader(link) + if a.downloader != nil { + d = a.downloader + } + + w := accessio.NewWriteAtWriter(d.Download) + cacheBlobAccess := accessobj.CachedBlobAccessForWriter(c.GetContext(), method.MimeType(), w) + method.BlobAccess = cacheBlobAccess + return method, nil +} + +func validateCommit(commit string) error { + if len(commit) != ShaLength { + return fmt.Errorf("commit is not a SHA") + } + for _, c := range commit { + if !unicode.IsOneOf([]*unicode.RangeTable{unicode.Letter, unicode.Digit}, c) { + return fmt.Errorf("commit contains invalid characters for a SHA") + } + } + return nil } func getCreds(hostname, port, path string, cctx credentials.Context) (string, error) { @@ -258,79 +264,18 @@ func (m *accessMethod) GetKind() string { return Type } -// Close should clean up all cached data if present. -// Exp.: Cache the blob data. -func (m *accessMethod) Close() error { - m.lock.Lock() - defer m.lock.Unlock() - if m.blob != nil { - tmp := m.blob - m.blob = nil - return tmp.Close() - } - return nil -} - -func (m *accessMethod) Get() ([]byte, error) { - blob, err := m.getBlob() - if err != nil { - return nil, err - } - return blob.Get() -} - -func (m *accessMethod) Reader() (io.ReadCloser, error) { - b, err := m.getBlob() - if err != nil { - return nil, err - } - r, err := b.Reader() - if err != nil { - return nil, err - } - return r, nil -} - func (m *accessMethod) MimeType() string { return mime.MIME_TGZ } -// TODO: Implement caching based on the SHA of the blob. If it is detected that that SHA already exists -// return it. ( Use the virtual filesystem implementation so it can be in memory or via file system ). -func (m *accessMethod) getBlob() (accessio.BlobAccess, error) { - m.lock.Lock() - defer m.lock.Unlock() - if m.blob != nil { - return m.blob, nil - } - blob, err := m.downloadArchive() - if err != nil { - return nil, err - } - - return accessio.BlobAccessForData(mime.MIME_TGZ, blob), nil -} - -func (m *accessMethod) downloadArchive() ([]byte, error) { - if len(m.spec.Commit) != ShaLength { - return nil, fmt.Errorf("commit is not a SHA") - } - for _, c := range m.spec.Commit { - if !unicode.IsOneOf([]*unicode.RangeTable{unicode.Letter, unicode.Digit}, c) { - return nil, fmt.Errorf("commit contains invalid characters for a SHA") - } - } - +func (m *accessMethod) getDownloadLink() (string, error) { link, resp, err := m.repositoryService.GetArchiveLink(context.Background(), m.owner, m.repo, github.Tarball, &github.RepositoryContentGetOptions{ Ref: m.spec.Commit, }, true) if err != nil { - return nil, err + return "", err } - defer func() { - if err := resp.Body.Close(); err != nil { - fmt.Println("failed to close body: ", err) - } - }() - return m.downloader.Download(link.String()) + defer resp.Body.Close() + + return link.String(), nil } diff --git a/pkg/contexts/ocm/accessmethods/github/method_test.go b/pkg/contexts/ocm/accessmethods/github/method_test.go index 38e5b72cb..21b2bb3d4 100644 --- a/pkg/contexts/ocm/accessmethods/github/method_test.go +++ b/pkg/contexts/ocm/accessmethods/github/method_test.go @@ -22,26 +22,35 @@ import ( "os" "path/filepath" - _ "github.com/open-component-model/ocm/pkg/contexts/datacontext/config" - "k8s.io/apimachinery/pkg/util/sets" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/mandelsoft/vfs/pkg/osfs" + "github.com/mandelsoft/vfs/pkg/vfs" "github.com/open-component-model/ocm/pkg/common" "github.com/open-component-model/ocm/pkg/contexts/credentials" "github.com/open-component-model/ocm/pkg/contexts/credentials/core" + "github.com/open-component-model/ocm/pkg/contexts/datacontext" + "github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/tmpcache" + "github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/vfsattr" + _ "github.com/open-component-model/ocm/pkg/contexts/datacontext/config" "github.com/open-component-model/ocm/pkg/contexts/ocm" me "github.com/open-component-model/ocm/pkg/contexts/ocm/accessmethods/github" "github.com/open-component-model/ocm/pkg/contexts/ocm/cpi" ) -const doPrivate = false - type mockDownloader struct { - expected []byte - shouldMatchLink string + expected []byte + err error +} + +func (m *mockDownloader) Download(w io.WriterAt) error { + if _, err := w.WriteAt(m.expected, 0); err != nil { + return fmt.Errorf("failed to write to mock writer: %w", err) + } + return m.err } // RoundTripFunc . @@ -60,32 +69,17 @@ func NewTestClient(fn RoundTripFunc) *http.Client { } -func (m *mockDownloader) Download(link string) ([]byte, error) { - if link != m.shouldMatchLink { - return nil, fmt.Errorf("link mismatch; got: %s want: %s", link, m.shouldMatchLink) - } - - return m.expected, nil -} - -func Configure(ctx ocm.Context) { - data, err := os.ReadFile(filepath.Join(os.Getenv("HOME"), ".ocmconfig")) - if err != nil { - return - } - _, err = ctx.ConfigContext().ApplyData(data, nil, ".ocmconfig") - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - -} - var _ = Describe("Method", func() { var ( ctx ocm.Context expectedBlobContent []byte err error - testClient *http.Client defaultLink string accessSpec *me.AccessSpec + dctx datacontext.Context + fs vfs.FileSystem + expectedURL string + clientFn func(url string) *http.Client ) BeforeEach(func() { @@ -93,70 +87,44 @@ var _ = Describe("Method", func() { expectedBlobContent, err = os.ReadFile(filepath.Join("testdata", "repo.tar.gz")) Expect(err).ToNot(HaveOccurred()) defaultLink = "https://github.com/test/test/sha?token=token" + expectedURL = "https://api.github.com/repos/test/test/tarball/7b1445755ee2527f0bf80ef9eeb59a5d2e6e3e1f" + + clientFn = func(url string) *http.Client { + return NewTestClient(func(req *http.Request) *http.Response { + if req.URL.String() != url { + Fail(fmt.Sprintf("failed to match url to expected url. want: %s; got: %s", expectedURL, req.URL.String())) + } + return &http.Response{ + StatusCode: 302, + Status: http.StatusText(http.StatusFound), + Body: io.NopCloser(bytes.NewBufferString(`{}`)), + Header: http.Header{ + "Location": []string{defaultLink}, + }, + } + }) + } - testClient = NewTestClient(func(req *http.Request) *http.Response { - return &http.Response{ - StatusCode: 302, - Status: http.StatusText(http.StatusFound), - Body: io.NopCloser(bytes.NewBufferString(`{}`)), - // Must be set to non-nil value or it panics - Header: http.Header{ - "Location": []string{defaultLink}, - }, - } - }) accessSpec = me.New( - "hostname", - 1234, - "repo", - "owner", + "https://github.com/test/test", + "", "7b1445755ee2527f0bf80ef9eeb59a5d2e6e3e1f", - me.WithClient(testClient), + me.WithClient(clientFn(expectedURL)), me.WithDownloader(&mockDownloader{ - expected: expectedBlobContent, - shouldMatchLink: defaultLink, + expected: expectedBlobContent, }), ) + fs, err = osfs.NewTempFileSystem() + Expect(err).To(Succeed()) + dctx = datacontext.New(nil) + vfsattr.Set(ctx, fs) + tmpcache.Set(ctx, &tmpcache.Attribute{Path: "/tmp"}) }) - It("downloads public spiff commit", func() { - spec := me.New("github.com", 0, "spiff", "mandelsoft", "25d9a3f0031c0b42e9ef7ab0117c35378040ef82") - - m, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) - Expect(err).ToNot(HaveOccurred()) - content, err := m.Get() - Expect(err).ToNot(HaveOccurred()) - Expect(len(content)).To(Equal(281655)) + AfterEach(func() { + vfs.Cleanup(fs) }) - if doPrivate { - Context("private access", func() { - It("downloads private commit", func() { - Configure(ctx) - - spec := me.New("github.com", 0, "cnudie-pause", "mandelsoft", "76eaae596ba24e401240654c4ad19ae66ba1e1a2") - - m, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) - Expect(err).ToNot(HaveOccurred()) - content, err := m.Get() - Expect(err).ToNot(HaveOccurred()) - Expect(len(content)).To(Equal(3764)) - }) - - It("downloads enterprise commit", func() { - Configure(ctx) - - spec := me.New("github.tools.sap", 0, "dummy", "D021770", "d17e2c594f0ab71f2c0f050b9d7fb485af4d6850") - - m, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) - Expect(err).ToNot(HaveOccurred()) - content, err := m.Get() - Expect(err).ToNot(HaveOccurred()) - Expect(len(content)).To(Equal(284)) - }) - }) - } - It("downloads artifacts", func() { m, err := accessSpec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) Expect(err).ToNot(HaveOccurred()) @@ -169,14 +137,11 @@ var _ = Describe("Method", func() { It("errors", func() { accessSpec := me.New( "hostname", - 1234, - "repo", - "owner", + "", "not-a-sha", - me.WithClient(testClient), + me.WithClient(clientFn(expectedURL)), me.WithDownloader(&mockDownloader{ - expected: expectedBlobContent, - shouldMatchLink: defaultLink, + expected: expectedBlobContent, }), ) m, err := accessSpec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) @@ -191,14 +156,11 @@ var _ = Describe("Method", func() { It("errors", func() { accessSpec := me.New( "hostname", - 1234, - "repo", - "owner", + "1234", "refs/heads/veryinteresting_branch_namess", - me.WithClient(testClient), + me.WithClient(clientFn(expectedURL)), me.WithDownloader(&mockDownloader{ - expected: expectedBlobContent, - shouldMatchLink: defaultLink, + expected: expectedBlobContent, }), ) m, err := accessSpec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) @@ -210,9 +172,42 @@ var _ = Describe("Method", func() { }) When("credentials are provided", func() { + BeforeEach(func() { + clientFn = func(url string) *http.Client { + return NewTestClient(func(req *http.Request) *http.Response { + if v, ok := req.Header["Authorization"]; ok { + Expect(v).To(ContainElement("Bearer test")) + } else { + Fail("Authorization header not found in request") + } + if req.URL.String() != url { + Fail(fmt.Sprintf("failed to match url to expected url. want: %s; got: %s", expectedURL, req.URL.String())) + } + return &http.Response{ + StatusCode: 302, + Status: http.StatusText(http.StatusFound), + // Must be set to non-nil value or it panics + Body: io.NopCloser(bytes.NewBufferString(`{}`)), + Header: http.Header{ + "Location": []string{defaultLink}, + }, + } + }) + } + accessSpec = me.New( + "https://github.com/test/test", + "", + "7b1445755ee2527f0bf80ef9eeb59a5d2e6e3e1f", + me.WithClient(clientFn(expectedURL)), + me.WithDownloader(&mockDownloader{ + expected: expectedBlobContent, + }), + ) + }) It("can use those to access private repos", func() { called := false - mcc := &mockCredContext{ + mcc := &mockContext{ + dataContext: dctx, creds: &mockCredSource{ cred: &mockCredentials{ value: func() string { @@ -235,7 +230,7 @@ var _ = Describe("Method", func() { When("GetCredentialsForConsumer returns an error", func() { It("errors", func() { called := false - mcc := &mockCredContext{ + mcc := &mockContext{ creds: &mockCredSource{ cred: &mockCredentials{ value: func() string { @@ -253,6 +248,33 @@ var _ = Describe("Method", func() { Expect(called).To(BeFalse()) }) }) + + When("an enterprise repo URL is provided", func() { + It("uses that domain and includes api/v3 in the request URL", func() { + expectedURL = "https://github.tools.sap/api/v3/repos/test/test/tarball/25d9a3f0031c0b42e9ef7ab0117c35378040ef82" + spec := me.New("https://github.tools.sap/test/test", "", "25d9a3f0031c0b42e9ef7ab0117c35378040ef82", me.WithClient(clientFn(expectedURL))) + _, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + When("hostname is different from github.com", func() { + It("will use an enterprise client", func() { + expectedURL = "https://custom/api/v3/repos/test/test/tarball/25d9a3f0031c0b42e9ef7ab0117c35378040ef82" + spec := me.New("https://github.tools.sap/test/test", "custom", "25d9a3f0031c0b42e9ef7ab0117c35378040ef82", me.WithClient(clientFn(expectedURL))) + _, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + When("repoURL doesn't have an https prefix", func() { + It("will add one", func() { + expectedURL = "https://api.github.com/repos/test/test/tarball/25d9a3f0031c0b42e9ef7ab0117c35378040ef82" + spec := me.New("github.com/test/test", "", "25d9a3f0031c0b42e9ef7ab0117c35378040ef82", me.WithClient(clientFn(expectedURL))) + _, err := spec.AccessMethod(&cpi.DummyComponentVersionAccess{Context: ctx}) + Expect(err).ToNot(HaveOccurred()) + }) + }) }) type mockComponentVersionAccess struct { @@ -264,15 +286,20 @@ func (m *mockComponentVersionAccess) GetContext() ocm.Context { return m.credContext } -type mockCredContext struct { +type mockContext struct { ocm.Context - creds credentials.Context + creds credentials.Context + dataContext datacontext.Context } -func (m *mockCredContext) CredentialsContext() credentials.Context { +func (m *mockContext) CredentialsContext() credentials.Context { return m.creds } +func (m *mockContext) GetAttributes() datacontext.Attributes { + return m.dataContext.GetAttributes() +} + type mockCredSource struct { credentials.Context cred credentials.Credentials diff --git a/pkg/contexts/ocm/accessmethods/s3/method.go b/pkg/contexts/ocm/accessmethods/s3/method.go index 76504cc17..f8e3f4d9c 100644 --- a/pkg/contexts/ocm/accessmethods/s3/method.go +++ b/pkg/contexts/ocm/accessmethods/s3/method.go @@ -16,11 +16,11 @@ package s3 import ( "fmt" - "io" "path" - "sync" "github.com/open-component-model/ocm/pkg/common/accessio" + "github.com/open-component-model/ocm/pkg/common/accessio/downloader" + "github.com/open-component-model/ocm/pkg/common/accessio/downloader/s3" "github.com/open-component-model/ocm/pkg/common/accessobj" "github.com/open-component-model/ocm/pkg/contexts/credentials" "github.com/open-component-model/ocm/pkg/contexts/credentials/identity/hostpath" @@ -59,13 +59,13 @@ type AccessSpec struct { // MediaType defines the mime type of the object to download. // +optional MediaType string `json:"mediaType,omitempty"` - downloader Downloader + downloader downloader.Downloader } var _ cpi.AccessSpec = (*AccessSpec)(nil) // New creates a new GitHub registry access spec version v1 -func New(region, bucket, key, version, mediaType string, downloader Downloader) *AccessSpec { +func New(region, bucket, key, version, mediaType string, downloader downloader.Downloader) *AccessSpec { return &AccessSpec{ ObjectVersionedType: runtime.NewVersionedObjectType(Type), Region: region, @@ -92,10 +92,10 @@ func (a *AccessSpec) AccessMethod(c cpi.ComponentVersionAccess) (cpi.AccessMetho //////////////////////////////////////////////////////////////////////////////// type accessMethod struct { - lock sync.Mutex - comp cpi.ComponentVersionAccess - spec *AccessSpec - cacheBlobAccess accessio.BlobAccess + accessio.BlobAccess + + comp cpi.ComponentVersionAccess + spec *AccessSpec } var _ cpi.AccessMethod = (*accessMethod)(nil) @@ -114,14 +114,14 @@ func newMethod(c cpi.ComponentVersionAccess, a *AccessSpec) (*accessMethod, erro accessKeyID = creds.GetProperty(credentials.ATTR_AWS_ACCESS_KEY_ID) accessSecret = creds.GetProperty(credentials.ATTR_AWS_SECRET_ACCESS_KEY) } - var awsCreds *AWSCreds + var awsCreds *s3.AWSCreds if accessKeyID != "" { - awsCreds = &AWSCreds{ + awsCreds = &s3.AWSCreds{ AccessKeyID: accessKeyID, AccessSecret: accessSecret, } } - var d Downloader = NewS3Downloader(a.Region, a.Bucket, a.Key, a.Version, awsCreds) + var d downloader.Downloader = s3.NewDownloader(a.Region, a.Bucket, a.Key, a.Version, awsCreds) if a.downloader != nil { d = a.downloader } @@ -133,9 +133,9 @@ func newMethod(c cpi.ComponentVersionAccess, a *AccessSpec) (*accessMethod, erro } cacheBlobAccess := accessobj.CachedBlobAccessForWriter(c.GetContext(), mediaType, w) return &accessMethod{ - spec: a, - comp: c, - cacheBlobAccess: cacheBlobAccess, + spec: a, + comp: c, + BlobAccess: cacheBlobAccess, }, nil } @@ -168,22 +168,3 @@ func getCreds(a *AccessSpec, cctx credentials.Context) (credentials.Credentials, func (m *accessMethod) GetKind() string { return Type } - -func (m *accessMethod) Close() error { - m.lock.Lock() - defer m.lock.Unlock() - - return m.cacheBlobAccess.Close() -} - -func (m *accessMethod) Get() ([]byte, error) { - return m.cacheBlobAccess.Get() -} - -func (m *accessMethod) Reader() (io.ReadCloser, error) { - return m.cacheBlobAccess.Reader() -} - -func (m *accessMethod) MimeType() string { - return m.cacheBlobAccess.MimeType() -} diff --git a/pkg/contexts/ocm/accessmethods/s3/method_test.go b/pkg/contexts/ocm/accessmethods/s3/method_test.go index 698781c88..98fbcac38 100644 --- a/pkg/contexts/ocm/accessmethods/s3/method_test.go +++ b/pkg/contexts/ocm/accessmethods/s3/method_test.go @@ -24,6 +24,7 @@ import ( "github.com/mandelsoft/vfs/pkg/vfs" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/open-component-model/ocm/pkg/common/accessio/downloader" "github.com/open-component-model/ocm/pkg/contexts/datacontext" "github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/tmpcache" "github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/vfsattr" @@ -54,7 +55,7 @@ var _ = Describe("Method", func() { var ( env *Builder accessSpec *s3.AccessSpec - downloader s3.Downloader + downloader downloader.Downloader expectedContent []byte err error mcc ocm.Context