Skip to content

Commit

Permalink
Merge pull request #861 from ajeddeloh/openstack-multipart
Browse files Browse the repository at this point in the history
Clean up url handling
  • Loading branch information
Andrew Jeddeloh authored Sep 25, 2019
2 parents e5cbb6a + 27c3a93 commit 15b5b04
Showing 1 changed file with 55 additions and 25 deletions.
80 changes: 55 additions & 25 deletions internal/resource/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,50 @@ type FetchOptions struct {
// in the contents of the file and delete it. It will return the downloaded
// contents, or an error if one was encountered.
func (f *Fetcher) FetchToBuffer(u url.URL, opts FetchOptions) ([]byte, error) {
file, err := ioutil.TempFile("", "ignition")
if err != nil {
return nil, err
}
defer os.Remove(file.Name())
defer file.Close()
err = f.Fetch(u, file, opts)
if err != nil {
return nil, err
var err error
dest := new(bytes.Buffer)
switch u.Scheme {
case "http", "https":
err = f.fetchFromHTTP(u, dest, opts)
case "tftp":
err = f.fetchFromTFTP(u, dest, opts)
case "data":
err = f.fetchFromDataURL(u, dest, opts)
case "s3":
buf := &s3buf{
WriteAtBuffer: aws.NewWriteAtBuffer([]byte{}),
}
err = f.fetchFromS3(u, buf, opts)
return buf.Bytes(), err
case "":
return nil, nil
default:
return nil, ErrSchemeUnsupported
}
_, err = file.Seek(0, os.SEEK_SET)
if err != nil {
return nil, err
return dest.Bytes(), err
}

// s3buf is a wrapper around the aws.WriteAtBuffer that also allows reading and seeking.
// Read() and Seek() are only safe to call after the download call is made. This is only for
// use with fetchFromS3* functions.
type s3buf struct {
*aws.WriteAtBuffer
// only safe to call read/seek after finishing writing. Not safe for parallel use
reader io.ReadSeeker
}

func (s *s3buf) Read(p []byte) (int, error) {
if s.reader == nil {
s.reader = bytes.NewReader(s.Bytes())
}
res, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
return s.reader.Read(p)
}

func (s *s3buf) Seek(offset int64, whence int) (int64, error) {
if s.reader == nil {
s.reader = bytes.NewReader(s.Bytes())
}
return res, nil
return s.reader.Seek(offset, whence)
}

// Fetch calls the appropriate FetchFrom* function based on the scheme of the
Expand All @@ -134,13 +159,13 @@ func (f *Fetcher) FetchToBuffer(u url.URL, opts FetchOptions) ([]byte, error) {
func (f *Fetcher) Fetch(u url.URL, dest *os.File, opts FetchOptions) error {
switch u.Scheme {
case "http", "https":
return f.FetchFromHTTP(u, dest, opts)
return f.fetchFromHTTP(u, dest, opts)
case "tftp":
return f.FetchFromTFTP(u, dest, opts)
return f.fetchFromTFTP(u, dest, opts)
case "data":
return f.FetchFromDataURL(u, dest, opts)
return f.fetchFromDataURL(u, dest, opts)
case "s3":
return f.FetchFromS3(u, dest, opts)
return f.fetchFromS3(u, dest, opts)
case "":
return nil
default:
Expand All @@ -150,7 +175,7 @@ func (f *Fetcher) Fetch(u url.URL, dest *os.File, opts FetchOptions) error {

// FetchFromTFTP fetches a resource from u via TFTP into dest, returning an
// error if one is encountered.
func (f *Fetcher) FetchFromTFTP(u url.URL, dest *os.File, opts FetchOptions) error {
func (f *Fetcher) fetchFromTFTP(u url.URL, dest io.Writer, opts FetchOptions) error {
if !strings.ContainsRune(u.Host, ':') {
u.Host = u.Host + ":69"
}
Expand Down Expand Up @@ -212,7 +237,7 @@ func (f *Fetcher) FetchFromTFTP(u url.URL, dest *os.File, opts FetchOptions) err

// FetchFromHTTP fetches a resource from u via HTTP(S) into dest, returning an
// error if one is encountered.
func (f *Fetcher) FetchFromHTTP(u url.URL, dest *os.File, opts FetchOptions) error {
func (f *Fetcher) fetchFromHTTP(u url.URL, dest io.Writer, opts FetchOptions) error {
// for the case when "config is not valid"
// this if necessary if not spawned through kola (e.g. Packet Dashboard)
if f.client == nil {
Expand Down Expand Up @@ -248,7 +273,7 @@ func (f *Fetcher) FetchFromHTTP(u url.URL, dest *os.File, opts FetchOptions) err

// FetchFromDataURL writes the data stored in the dataurl u into dest, returning
// an error if one is encountered.
func (f *Fetcher) FetchFromDataURL(u url.URL, dest *os.File, opts FetchOptions) error {
func (f *Fetcher) fetchFromDataURL(u url.URL, dest io.Writer, opts FetchOptions) error {
if opts.Compression != "" {
return ErrCompressionUnsupported
}
Expand All @@ -260,11 +285,16 @@ func (f *Fetcher) FetchFromDataURL(u url.URL, dest *os.File, opts FetchOptions)
return f.decompressCopyHashAndVerify(dest, bytes.NewBuffer(url.Data), opts)
}

type s3target interface {
io.WriterAt
io.ReadSeeker
}

// FetchFromS3 gets data from an S3 bucket as described by u and writes it into
// dest, returning an error if one is encountered. It will attempt to acquire
// IAM credentials from the EC2 metadata service, and if this fails will attempt
// to fetch the object with anonymous credentials.
func (f *Fetcher) FetchFromS3(u url.URL, dest *os.File, opts FetchOptions) error {
func (f *Fetcher) fetchFromS3(u url.URL, dest s3target, opts FetchOptions) error {
if opts.Compression != "" {
return ErrCompressionUnsupported
}
Expand Down Expand Up @@ -337,7 +367,7 @@ func (f *Fetcher) FetchFromS3(u url.URL, dest *os.File, opts FetchOptions) error
return nil
}

func (f *Fetcher) fetchFromS3WithCreds(ctx context.Context, dest *os.File, input *s3.GetObjectInput, sess *session.Session) error {
func (f *Fetcher) fetchFromS3WithCreds(ctx context.Context, dest s3target, input *s3.GetObjectInput, sess *session.Session) error {
httpClient, err := defaultHTTPClient()
if err != nil {
return err
Expand Down

0 comments on commit 15b5b04

Please sign in to comment.