Skip to content

Commit

Permalink
feat: support to skip tls when download file (#348)
Browse files Browse the repository at this point in the history
Co-authored-by: rick <[email protected]>
  • Loading branch information
LinuxSuRen and LinuxSuRen authored Feb 7, 2023
1 parent b80f4f5 commit ba52e6c
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 155 deletions.
26 changes: 24 additions & 2 deletions cmd/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func newGetCmd(ctx context.Context) (cmd *cobra.Command) {
"Same with option --accept-preRelease")
flags.BoolVarP(&opt.Force, "force", "f", false, "Overwrite the exist file if this is true")
flags.IntVarP(&opt.Mod, "mod", "", -1, "The file permission, -1 means using the system default")
flags.BoolVarP(&opt.SkipTLS, "skip-tls", "k", false, "Skip the TLS")

flags.IntVarP(&opt.Timeout, "time", "", 10,
`The default timeout in seconds with the HTTP request`)
Expand Down Expand Up @@ -106,6 +107,7 @@ type downloadOption struct {
Magnet bool
Force bool
Mod int
SkipTLS bool

ContinueAt int64

Expand Down Expand Up @@ -297,24 +299,44 @@ func (o *downloadOption) runE(cmd *cobra.Command, args []string) (err error) {
targetURL = strings.Replace(targetURL, "raw.githubusercontent.com", fmt.Sprintf("%s/https://raw.githubusercontent.com", o.ProxyGitHub), 1)
}
logger.Printf("start to download from %s\n", targetURL)
var suggestedFilenameAware net.SuggestedFilenameAware
if o.Thread <= 1 {
downloader := &net.ContinueDownloader{}
suggestedFilenameAware = downloader
downloader.WithoutProxy(o.NoProxy).
WithRoundTripper(o.RoundTripper)
WithRoundTripper(o.RoundTripper).
WithInsecureSkipVerify(o.SkipTLS)
err = downloader.DownloadWithContinue(targetURL, o.Output, o.ContinueAt, -1, 0, o.ShowProgress)
} else {
downloader := &net.MultiThreadDownloader{}
suggestedFilenameAware = downloader
downloader.WithKeepParts(o.KeepPart).
WithShowProgress(o.ShowProgress).
WithoutProxy(o.NoProxy).
WithRoundTripper(o.RoundTripper)
WithRoundTripper(o.RoundTripper).
WithInsecureSkipVerify(o.SkipTLS)
err = downloader.Download(targetURL, o.Output, o.Thread)
}

// set file permission
if o.Mod != -1 {
err = sysos.Chmod(o.Output, fs.FileMode(o.Mod))
}

if err == nil {
logger.Printf("downloaded: %s\n", o.Output)
}

if suggested := suggestedFilenameAware.GetSuggestedFilename(); suggested != "" {
confirm := &survey.Confirm{
Message: fmt.Sprintf("Do you want to rename filename from '%s' to '%s'?", o.Output, suggested),
}
var yes bool
if confirmErr := survey.AskOne(confirm, &yes); confirmErr == nil && yes {
fmt.Println("rename")
err = sysos.Rename(o.Output, suggested)
}
}
return
}

Expand Down
187 changes: 47 additions & 140 deletions pkg/net/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"os"
"path"
"strconv"
"sync"
"strings"
"time"

"github.com/linuxsuren/http-downloader/pkg/common"
Expand Down Expand Up @@ -50,6 +50,7 @@ type HTTPDownloader struct {
Debug bool
RoundTripper http.RoundTripper
progressIndicator *ProgressIndicator
suggestedFilename string
}

// SetProxy set the proxy for a http
Expand Down Expand Up @@ -150,6 +151,14 @@ func (h *HTTPDownloader) DownloadFile() error {
}
}

if disposition, ok := resp.Header["Content-Disposition"]; ok && len(disposition) >= 1 {
h.suggestedFilename = strings.TrimPrefix(disposition[0], `filename="`)
h.suggestedFilename = strings.TrimSuffix(h.suggestedFilename, `"`)
if h.suggestedFilename == filepath {
h.suggestedFilename = ""
}
}

// pre-hook before get started to download file
if h.PreStart != nil && !h.PreStart(resp) {
return nil
Expand Down Expand Up @@ -192,127 +201,15 @@ func (h *HTTPDownloader) DownloadFile() error {
return err
}

// DownloadFileWithMultipleThread downloads the files with multiple threads
func DownloadFileWithMultipleThread(targetURL, targetFilePath string, thread int, showProgress bool) (err error) {
return DownloadFileWithMultipleThreadKeepParts(targetURL, targetFilePath, thread, false, showProgress)
// GetSuggestedFilename returns the suggested filename which comes from the HTTP response header.
// Returns empty string if the filename is same with the given name.
func (h *HTTPDownloader) GetSuggestedFilename() string {
return h.suggestedFilename
}

// MultiThreadDownloader is a download with multi-thread
type MultiThreadDownloader struct {
noProxy bool
keepParts, showProgress bool

roundTripper http.RoundTripper
}

// WithoutProxy indicates not use HTTP proxy
func (d *MultiThreadDownloader) WithoutProxy(noProxy bool) *MultiThreadDownloader {
d.noProxy = noProxy
return d
}

// WithShowProgress indicate if show the download progress
func (d *MultiThreadDownloader) WithShowProgress(showProgress bool) *MultiThreadDownloader {
d.showProgress = showProgress
return d
}

// WithKeepParts indicates if keeping the part files
func (d *MultiThreadDownloader) WithKeepParts(keepParts bool) *MultiThreadDownloader {
d.keepParts = keepParts
return d
}

// WithRoundTripper sets RoundTripper
func (d *MultiThreadDownloader) WithRoundTripper(roundTripper http.RoundTripper) *MultiThreadDownloader {
d.roundTripper = roundTripper
return d
}

// Download starts to download the target URL
func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) {
// get the total size of the target file
var total int64
var rangeSupport bool
if total, rangeSupport, err = DetectSizeWithRoundTripper(targetURL, targetFilePath, true, d.noProxy, d.roundTripper); err != nil {
return
}

if rangeSupport {
unit := total / int64(thread)
offset := total - unit*int64(thread)
var wg sync.WaitGroup
var partItems []string
var m sync.Mutex

defer func() {
// remove all partial files
for _, part := range partItems {
_ = os.RemoveAll(part)
}
}()

fmt.Printf("start to download with %d threads, size: %d, unit: %d\n", thread, total, unit)
for i := 0; i < thread; i++ {
wg.Add(1)
go func(index int, wg *sync.WaitGroup) {
defer wg.Done()
output := fmt.Sprintf("%s-%d", targetFilePath, index)

m.Lock()
partItems = append(partItems, output)
m.Unlock()

end := unit*int64(index+1) - 1
if index == thread-1 {
// this is the last part
end += offset
}
start := unit * int64(index)

downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy).
WithRoundTripper(d.roundTripper)
if downloadErr := downloader.DownloadWithContinue(targetURL, output,
int64(index), start, end, d.showProgress); downloadErr != nil {
fmt.Println(downloadErr)
}
}(i, &wg)
}

wg.Wait()
ProgressIndicator{}.Close()

// concat all these partial files
var f *os.File
if f, err = os.OpenFile(targetFilePath, os.O_CREATE|os.O_WRONLY, 0600); err == nil {
defer func() {
_ = f.Close()
}()

for i := 0; i < thread; i++ {
partFile := fmt.Sprintf("%s-%d", targetFilePath, i)
if data, ferr := os.ReadFile(partFile); ferr == nil {
if _, err = f.Write(data); err != nil {
err = fmt.Errorf("failed to write file: '%s'", partFile)
break
} else if !d.keepParts {
_ = os.RemoveAll(partFile)
}
} else {
err = fmt.Errorf("failed to read file: '%s'", partFile)
break
}
}
}
} else {
fmt.Println("cannot download it using multiple threads, failed to one")
downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy)
downloader.WithRoundTripper(d.roundTripper)
err = downloader.DownloadWithContinue(targetURL, targetFilePath, -1, 0, 0, true)
}
return
// SuggestedFilenameAware is the interface for getting suggested filename
type SuggestedFilenameAware interface {
GetSuggestedFilename() string
}

// DownloadFileWithMultipleThreadKeepParts downloads the files with multiple threads
Expand All @@ -326,8 +223,14 @@ func DownloadFileWithMultipleThreadKeepParts(targetURL, targetFilePath string, t
type ContinueDownloader struct {
downloader *HTTPDownloader

roundTripper http.RoundTripper
noProxy bool
roundTripper http.RoundTripper
noProxy bool
insecureSkipVerify bool
}

// GetSuggestedFilename returns the suggested filename
func (c *ContinueDownloader) GetSuggestedFilename() string {
return c.downloader.GetSuggestedFilename()
}

// WithRoundTripper set WithRoundTripper
Expand All @@ -342,14 +245,21 @@ func (c *ContinueDownloader) WithoutProxy(noProxy bool) *ContinueDownloader {
return c
}

// WithInsecureSkipVerify set if skip the insecure verify
func (c *ContinueDownloader) WithInsecureSkipVerify(insecureSkipVerify bool) *ContinueDownloader {
c.insecureSkipVerify = insecureSkipVerify
return c
}

// DownloadWithContinue downloads the files continuously
func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) {
c.downloader = &HTTPDownloader{
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
NoProxy: c.noProxy,
RoundTripper: c.roundTripper,
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
NoProxy: c.noProxy,
RoundTripper: c.roundTripper,
InsecureSkipVerify: c.insecureSkipVerify,
}
if index >= 0 {
c.downloader.Title = fmt.Sprintf("Downloading part %d", index)
Expand All @@ -371,21 +281,16 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde
return
}

// DetectSize returns the size of target resource
//
// Deprecated, use DetectSizeWithRoundTripper instead
func DetectSize(targetURL, output string, showProgress bool) (int64, bool, error) {
return DetectSizeWithRoundTripper(targetURL, output, showProgress, false, nil)
}

// DetectSizeWithRoundTripper returns the size of target resource
func DetectSizeWithRoundTripper(targetURL, output string, showProgress bool, noProxy bool, roundTripper http.RoundTripper) (total int64, rangeSupport bool, err error) {
func DetectSizeWithRoundTripper(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool,
roundTripper http.RoundTripper) (total int64, rangeSupport bool, err error) {
downloader := HTTPDownloader{
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
RoundTripper: roundTripper,
NoProxy: false, // below HTTP request does not need proxy
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
RoundTripper: roundTripper,
NoProxy: false, // below HTTP request does not need proxy
InsecureSkipVerify: insecureSkipVerify,
}

var detectOffset int64
Expand All @@ -400,6 +305,8 @@ func DetectSizeWithRoundTripper(targetURL, output string, showProgress bool, noP
contentLen := resp.Header.Get("Content-Length")
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
total += detectOffset
} else {
rangeSupport = false
}
// always return false because we just want to get the header from response
return false
Expand Down
Loading

0 comments on commit ba52e6c

Please sign in to comment.