diff --git a/util/downloader/downloader.go b/util/downloader/downloader.go new file mode 100644 index 000000000..bae58fe09 --- /dev/null +++ b/util/downloader/downloader.go @@ -0,0 +1,284 @@ +package downloader + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "hash" + "io" + "net/http" + "os" + "path/filepath" +) + +var ( + ErrHeaderRequest = errors.New("request header error") + ErrSHA256Mismatch = errors.New("sha256 mismatch") + ErrCreateDir = errors.New("create dir error") + ErrInvalidFilePath = errors.New("file path is a directory, not a file") + ErrGetFileInfo = errors.New("get file info error") + ErrCopyExistsFileData = errors.New("error copying existing file data") + ErrDoRequest = errors.New("error doing request") + ErrFileWriting = errors.New("error writing file") + ErrNewRequest = errors.New("error creating request") + ErrOpenFileExists = errors.New("error opening existing file") +) + +type Downloader struct { + client *http.Client + url string + filePath string + sha256Sum string + fileType string + fileName string + statsCh chan Stats + errCh chan error +} + +type Stats struct { + Downloaded int64 + TotalSize int64 + Percent float64 + Completed bool +} + +func New(url, filePath, sha256Sum string, opts ...Option) *Downloader { + opt := defaultOptions() + + for _, o := range opts { + o(opt) + } + + return &Downloader{ + client: opt.client, + url: url, + filePath: filePath, + sha256Sum: sha256Sum, + statsCh: make(chan Stats), + errCh: make(chan error, 1), + } +} + +func (d *Downloader) Start(ctx context.Context) { + go d.download(ctx) +} + +func (d *Downloader) Stats() <-chan Stats { + return d.statsCh +} + +func (d *Downloader) FileType() string { + return d.fileType +} + +func (d *Downloader) FileName() string { + return d.fileName +} + +func (d *Downloader) Errors() <-chan error { + return d.errCh +} + +func (d *Downloader) download(ctx context.Context) { + stats, err := d.getHeader(ctx) + if err != nil { + d.handleError(err) + + return + } + + d.fileName = filepath.Base(d.filePath) + if err := d.createDir(); err != nil { + d.handleError(err) + + return + } + + out, err := d.openFile() + if err != nil { + d.handleError(err) + + return + } + defer func() { + _ = out.Close() + }() + + if err := d.validateExistingFile(out, &stats); err != nil { + d.handleError(err) + + return + } + + if err := d.downloadFile(ctx, out, &stats); err != nil { + d.handleError(err) + } +} + +func (d *Downloader) getHeader(ctx context.Context) (Stats, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, d.url, http.NoBody) + if err != nil { + return Stats{}, ErrHeaderRequest + } + + resp, err := d.client.Do(req) + if err != nil { + return Stats{}, ErrHeaderRequest + } + + defer func() { + _ = resp.Body.Close() + }() + + d.fileType = resp.Header.Get("Content-Type") + + return Stats{ + TotalSize: resp.ContentLength, + }, nil +} + +func (d *Downloader) createDir() error { + dir := filepath.Dir(d.filePath) + if err := os.MkdirAll(dir, 0o750); err != nil { + return ErrCreateDir + } + + return nil +} + +func (d *Downloader) openFile() (*os.File, error) { + fileInfo, err := os.Stat(d.filePath) + if err == nil && fileInfo.IsDir() { + return nil, ErrInvalidFilePath + } + + return os.OpenFile(d.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) +} + +func (*Downloader) validateExistingFile(out *os.File, stats *Stats) error { + fileInfo, err := out.Stat() + if err != nil { + return ErrGetFileInfo + } + stats.Downloaded = fileInfo.Size() + + return nil +} + +func (d *Downloader) downloadFile(ctx context.Context, out *os.File, stats *Stats) error { + req, err := d.createRequest(ctx, stats.Downloaded) + if err != nil { + return err + } + + resp, err := d.client.Do(req) + if err != nil { + return ErrDoRequest + } + + defer func() { + _ = resp.Body.Close() + }() + + buffer := make([]byte, 32*1024) + hasher := sha256.New() + + if err := d.updateHasherWithExistingData(stats.Downloaded, hasher); err != nil { + return err + } + + return d.writeToFile(ctx, resp, out, buffer, hasher, stats) +} + +func (d *Downloader) createRequest(ctx context.Context, downloaded int64) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody) + if err != nil { + return nil, ErrNewRequest + } + if downloaded > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloaded)) + } + + return req, nil +} + +func (d *Downloader) updateHasherWithExistingData(downloaded int64, hasher io.Writer) error { + if downloaded > 0 { + existingFile, err := os.Open(d.filePath) + if err != nil { + return ErrOpenFileExists + } + defer func() { + _ = existingFile.Close() + }() + + if _, err := io.CopyN(hasher, existingFile, downloaded); err != nil { + return ErrCopyExistsFileData + } + } + + return nil +} + +func (d *Downloader) writeToFile(ctx context.Context, resp *http.Response, out *os.File, buffer []byte, + hasher hash.Hash, stats *Stats, +) error { + for { + select { + case <-ctx.Done(): + d.stop() + + return ctx.Err() + default: + n, err := resp.Body.Read(buffer) + if n > 0 { + if _, err := out.Write(buffer[:n]); err != nil { + return ErrFileWriting + } + + if _, err := hasher.Write(buffer[:n]); err != nil { + return ErrFileWriting + } + + stats.Downloaded += int64(n) + stats.Percent = float64(stats.Downloaded) / float64(stats.TotalSize) * 100 + d.statsCh <- *stats + } + if err != nil { + if err == io.EOF { + return d.finalizeDownload(hasher, stats) + } + + return fmt.Errorf("error reading response body: %w", err) + } + } + } +} + +func (d *Downloader) finalizeDownload(hasher hash.Hash, stats *Stats) error { + stats.Completed = true + sum := hex.EncodeToString(hasher.Sum(nil)) + if sum != d.sha256Sum { + return ErrSHA256Mismatch + } + d.statsCh <- *stats + + d.stop() + + return nil +} + +func (d *Downloader) stop() { + close(d.statsCh) + close(d.errCh) +} + +func (d *Downloader) handleError(err error) { + select { + case d.errCh <- err: + default: + d.stop() + } +} diff --git a/util/downloader/downloader_test.go b/util/downloader/downloader_test.go new file mode 100644 index 000000000..881584885 --- /dev/null +++ b/util/downloader/downloader_test.go @@ -0,0 +1,90 @@ +package downloader + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "log" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/pactus-project/pactus/util" + "github.com/stretchr/testify/assert" +) + +func TestDownloader(t *testing.T) { + fileContent := []byte("This is a test file content") + fileURL := "/testfile" + expectedSHA256 := sha256.Sum256(fileContent) + expectedSHA256Hex := hex.EncodeToString(expectedSHA256[:]) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == fileURL { + _, err := w.Write(fileContent) + assert.NoError(t, err) + } else { + http.NotFound(w, r) + } + })) + defer server.Close() + + filePath := util.TempFilePath() + + defer func() { + assert.NoError(t, os.RemoveAll("./testdata")) + }() + + dl := New(server.URL+fileURL, filePath, expectedSHA256Hex, WithCustomClient(server.Client())) + + assrt := assert.New(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + go func() { + dl.Start(ctx) + }() + + done := make(chan bool) + + go func() { + for stat := range dl.Stats() { + log.Printf("Downloaded: %d / %d (%.2f%%)\n", stat.Downloaded, stat.TotalSize, stat.Percent) + assrt.True(stat.Downloaded <= stat.TotalSize, "Downloaded size should not exceed total size") + assrt.True(stat.Percent <= 100, "Download percentage should not exceed 100") + + if stat.Completed { + log.Println("Download completed successfully") + assrt.Equal(float64(100), stat.Percent, "Download should be 100% complete") + done <- true + + return + } + } + }() + + go func() { + for err := range dl.Errors() { + assrt.Fail("Download encountered an error", err) + done <- true + + return + } + }() + + select { + case <-done: + case <-time.After(2 * time.Minute): + cancel() + assrt.Fail("Download test timed out") + } + + t.Log(dl.FileName()) + t.Log(dl.FileType()) + + downloadedContent, err := os.ReadFile(filePath) + assrt.NoError(err, "Failed to read the downloaded file") + assrt.Equal(fileContent, downloadedContent, "Downloaded file content does not match expected content") +} diff --git a/util/downloader/options.go b/util/downloader/options.go new file mode 100644 index 000000000..27c0042f4 --- /dev/null +++ b/util/downloader/options.go @@ -0,0 +1,21 @@ +package downloader + +import "net/http" + +type options struct { + client *http.Client +} + +type Option func(*options) + +func defaultOptions() *options { + return &options{ + client: http.DefaultClient, + } +} + +func WithCustomClient(client *http.Client) Option { + return func(o *options) { + o.client = client + } +}