Skip to content

Commit

Permalink
feature: 增加分片并发下载和断续下载 (#78)
Browse files Browse the repository at this point in the history
* feature: 增加分片并发下载和断续下载

- 增加多线程分片下载文件
- 增加断点续传:下载过程中进程退出后,重新下载时恢复之前的进度

* fix: 修复上传的文件是 URL 链接时的判断错误
  • Loading branch information
arrebole authored Apr 6, 2023
1 parent 0f03e46 commit dc6b159
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ jobs:
- name: Test
run: |
go build .
go test -v .
go test -v ./...
20 changes: 16 additions & 4 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,19 @@ func NewGetCommand() cli.Command {
PrintErrorAndExit("get %s: parse mtime: %v", upPath, err)
}
}
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
if mc.Start != "" || mc.End != "" {
session.GetStartBetweenEndFiles(upPath, localPath, mc, c.Int("w"))
} else {
session.Get(upPath, localPath, mc, c.Int("w"))
session.Get(upPath, localPath, mc, c.Int("w"), c.Bool("c"))
}
return nil
},
Flags: []cli.Flag{
cli.IntFlag{Name: "w", Usage: "max concurrent threads", Value: 5},
cli.IntFlag{Name: "w", Usage: "max concurrent threads (1-10)", Value: 5},
cli.BoolFlag{Name: "c", Usage: "continue download, Resume Broken Download"},
cli.StringFlag{Name: "mtime", Usage: "file's data was last modified n*24 hours ago, same as linux find command."},
cli.StringFlag{Name: "start", Usage: "file download range starting location"},
cli.StringFlag{Name: "end", Usage: "file download range ending location"},
Expand All @@ -315,7 +319,9 @@ func NewPutCommand() cli.Command {
if c.NArg() > 1 {
upPath = c.Args().Get(1)
}

if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Put(
localPath,
upPath,
Expand All @@ -332,9 +338,12 @@ func NewPutCommand() cli.Command {
func NewUploadCommand() cli.Command {
return cli.Command{
Name: "upload",
Usage: "upload multiple directory or file",
Usage: "upload multiple directory or file or http url",
Action: func(c *cli.Context) error {
InitAndCheck(LOGIN, CHECK, c)
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Upload(
c.Args(),
c.String("remote"),
Expand Down Expand Up @@ -422,6 +431,9 @@ func NewSyncCommand() cli.Command {
if c.NArg() > 1 {
upPath = c.Args().Get(1)
}
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Sync(localPath, upPath, c.Int("w"), c.Bool("delete"), c.Bool("strong"))
return nil
},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
github.com/syndtr/goleveldb v1.0.0
github.com/upyun/go-sdk/v3 v3.0.3
github.com/upyun/go-sdk/v3 v3.0.4
github.com/urfave/cli v1.22.4
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFd
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/upyun/go-sdk/v3 v3.0.3 h1:2wUkNk2fyJReMYHMvJyav050D83rYwSjN7mEPR0Pp8Q=
github.com/upyun/go-sdk/v3 v3.0.3/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E=
github.com/upyun/go-sdk/v3 v3.0.4 h1:2DCJa/Yi7/3ZybT9UCPATSzvU3wpPPxhXinNlb1Hi8Q=
github.com/upyun/go-sdk/v3 v3.0.4/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E=
github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
18 changes: 15 additions & 3 deletions io.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ func (w *WrappedWriter) Close() error {
return w.w.Close()
}

func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar) (*WrappedWriter, error) {
fd, err := os.Create(localPath)
// NewFileWrappedWriter opens localPath for writing and wraps the file in a
// WrappedWriter that is associated with the progress bar bar.
//
// When resume is true the file is opened in append mode and created if it
// does not exist, so bytes already on disk are kept. Otherwise the file is
// created, or truncated if it already exists.
//
// The returned writer's Copyed counter starts at the size the file has right
// after opening: 0 for a fresh or truncated file, the existing length when
// resuming.
func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar, resume bool) (*WrappedWriter, error) {
	var (
		fd  *os.File
		err error
	)

	if resume {
		// 0666 (before umask) is the mode os.Create uses; downloaded data
		// should not be created with executable bits set.
		fd, err = os.OpenFile(localPath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
	} else {
		fd, err = os.Create(localPath)
	}
	if err != nil {
		return nil, err
	}

	fileinfo, err := fd.Stat()
	if err != nil {
		// The writer is never handed to the caller on this path, so the
		// descriptor must be released here. The Stat error is the one worth
		// reporting; a secondary Close error is deliberately dropped.
		_ = fd.Close()
		return nil, err
	}

	return &WrappedWriter{
		w:      fd,
		Copyed: int(fileinfo.Size()),
		bar:    bar,
	}, nil
}
Expand Down
97 changes: 97 additions & 0 deletions partial/chunk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package partial

import (
"sync/atomic"
)

// Chunk describes one contiguous byte range of the source file together with
// the outcome of downloading it.
type Chunk struct {
// index is the position of the chunk in the sequence of chunks.
index int64

// start is the offset in the source file at which the chunk's content begins.
start int64

// end is the offset in the source file at which the chunk's content ends.
end int64

// err is the error recorded for this chunk's download task, if any.
err error

// buffer holds the content of the chunk once it has been downloaded.
buffer []byte
}

// NewChunk returns a Chunk with the given sequence index and source-file
// offsets. Its error and data are left unset.
func NewChunk(index, start, end int64) *Chunk {
	return &Chunk{index: index, start: start, end: end}
}

// SetData stores bytes as the chunk's downloaded content. The slice is kept
// as-is; no copy is made.
func (c *Chunk) SetData(bytes []byte) {
	c.buffer = bytes
}

// SetError records err as the result of the chunk's download task. Passing
// nil clears a previously recorded error.
func (c *Chunk) SetError(err error) {
	c.err = err
}

// Error returns the error recorded for the chunk, or nil if none was set.
func (c *Chunk) Error() error {
	return c.err
}

// Data returns the content stored in the chunk. The returned slice is the
// stored one, not a copy.
func (c *Chunk) Data() []byte {
	return c.buffer
}

// ChunksSorter lets chunks be written out of order and read back in order.
type ChunksSorter struct {
// readCount is the number of reads performed so far.
readCount int64

// chunkCount is the total number of chunks.
chunkCount int64

// works is the number of workers; writes and reads pick a channel by
// taking the chunk position modulo this value.
works int64

// chunks holds the channels through which chunks are passed from
// writers to the reader.
chunks []chan *Chunk
}

// NewChunksSorter returns a ChunksSorter for chunkCount chunks that owns
// works unbuffered channels, one per worker.
func NewChunksSorter(chunkCount int64, works int) *ChunksSorter {
	sorter := &ChunksSorter{
		chunkCount: chunkCount,
		works:      int64(works),
		chunks:     make([]chan *Chunk, works),
	}
	for i := range sorter.chunks {
		sorter.chunks[i] = make(chan *Chunk)
	}
	return sorter
}

// Write sends chunk on the channel selected by the chunk's index modulo the
// worker count. The send blocks until the channel can accept the chunk.
func (s *ChunksSorter) Write(chunk *Chunk) {
	slot := chunk.index % s.works
	s.chunks[slot] <- chunk
}

// Close closes the channel that belongs to worker workId, which makes pending
// and future receives on it return immediately.
//
// A workId outside the range of existing channels (negative or too large) is
// ignored. Closing the same worker's channel twice panics, as closing any
// already closed channel does.
func (s *ChunksSorter) Close(workId int) {
	if workId < 0 || workId >= len(s.chunks) {
		return
	}
	close(s.chunks[workId])
}

// Read returns the next chunk in sequence order, blocking until that chunk
// has been written.
//
// It returns nil when there is nothing (more) to read: the sorter was created
// for no chunks, all chunkCount chunks have already been returned, or the
// channel the next chunk would arrive on has been closed.
func (s *ChunksSorter) Read() *Chunk {
	if s.chunkCount <= 0 {
		return nil
	}
	i := atomic.AddInt64(&s.readCount, 1)
	if i > s.chunkCount {
		// Every chunk has been handed out already. Return right away rather
		// than waiting on a channel that only unblocks once it is closed.
		return nil
	}
	return <-s.chunks[(i-1)%s.works]
}
141 changes: 141 additions & 0 deletions partial/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package partial

import (
"context"
"errors"
"io"
"os"
"sync"
)

// DefaultChunkSize is a chunk size of 10 MiB, expressed in bytes.
const DefaultChunkSize = 10 << 20

// ChunkDownFunc downloads one byte range of the source file and returns its
// content. Download calls it with a chunk's start offset and that chunk's end
// offset minus one, so end is the offset of the last byte wanted.
type ChunkDownFunc func(start, end int64) ([]byte, error)

// MultiPartialDownloader downloads the missing tail of a file in chunks using
// several workers and writes the chunks to a writer in order.
type MultiPartialDownloader struct {

// filePath is the path of the local file; Download stats it to learn how
// much is already present.
filePath string

// finalSize is the size of the complete file.
finalSize int64

// localSize is the size of the local file; Download sets it from the stat
// result and leaves it untouched when the file does not exist.
localSize int64

// chunkSize is the size of one chunk.
chunkSize int64

// writer receives the content of each chunk, in chunk order.
writer io.Writer
// works is the number of worker goroutines Download starts.
works int
// downFunc downloads the content of one byte range.
downFunc ChunkDownFunc
}

// NewMultiPartialDownloader returns a downloader for the file at filePath.
//
// finalSize is the size of the complete file, chunkSize the size of one
// chunk, writer the destination for the downloaded bytes, works the number of
// workers and fn the function that downloads a single byte range. The local
// size is left at zero; Download determines it.
func NewMultiPartialDownloader(filePath string, finalSize, chunkSize int64, writer io.Writer, works int, fn ChunkDownFunc) *MultiPartialDownloader {
	downloader := new(MultiPartialDownloader)
	downloader.filePath = filePath
	downloader.finalSize = finalSize
	downloader.chunkSize = chunkSize
	downloader.writer = writer
	downloader.works = works
	downloader.downFunc = fn
	return downloader
}

// Download fetches the part of the file that is not yet present at
// p.filePath and writes it to p.writer in order.
//
// The size of the local file is taken as the resume offset (a missing file
// counts as size zero). The remaining range is split into chunks of
// p.chunkSize bytes; p.works goroutines download them concurrently, worker w
// taking chunks w, w+works, w+2*works and so on, while the calling goroutine
// writes the chunks in sequence.
//
// It returns an error when the chunk size or worker count is not positive,
// when the local file cannot be inspected, when the local file is larger than
// finalSize, when a chunk still fails after three attempts, when a chunk
// comes back empty, or when writing to p.writer fails. Download does not
// return before all of its goroutines have finished.
func (p *MultiPartialDownloader) Download() error {
	if p.chunkSize <= 0 {
		return errors.New("chunk size must be greater than 0")
	}
	if p.works <= 0 {
		return errors.New("works must be greater than 0")
	}

	fileinfo, err := os.Stat(p.filePath)
	// A missing file only means nothing has been downloaded yet, so localSize
	// keeps its value; any other stat failure is reported.
	if err != nil && !os.IsNotExist(err) {
		return err
	}
	if err == nil {
		p.localSize = fileinfo.Size()
	}

	// Work out how many chunks are still missing.
	needDownSize := p.finalSize - p.localSize
	if needDownSize < 0 {
		return errors.New("local file is larger than the remote file")
	}
	if needDownSize == 0 {
		return nil
	}
	chunkCount := needDownSize / p.chunkSize
	if needDownSize%p.chunkSize != 0 {
		chunkCount++
	}

	chunksSorter := NewChunksSorter(chunkCount, p.works)

	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		// Tell the workers to stop, then wait for them.
		cancel()

		// A worker may be parked on an unbuffered send that nobody will read
		// any more (early return below). Drain every channel until its
		// worker closes it, otherwise wg.Wait would never return.
		var drained sync.WaitGroup
		for _, ch := range chunksSorter.chunks {
			drained.Add(1)
			go func(ch chan *Chunk) {
				defer drained.Done()
				for range ch {
				}
			}(ch)
		}
		wg.Wait()
		drained.Wait()
	}()

	for i := 0; i < p.works; i++ {
		wg.Add(1)
		go func(workId int) {
			// Deferred calls run last-in first-out: the channel is closed
			// before the worker is reported as done.
			defer wg.Done()
			defer chunksSorter.Close(workId)

			// Each worker takes the chunks whose index is congruent to its id.
			for j := int64(workId); j < chunkCount; j += int64(p.works) {
				if ctx.Err() != nil {
					return
				}
				chunk := p.fetchChunk(ctx, j)
				chunksSorter.Write(chunk)
				if chunk.Error() != nil {
					return
				}
			}
		}(i)
	}

	// Write the chunks to the destination in sequence order.
	for {
		chunk := chunksSorter.Read()
		if chunk == nil {
			break
		}
		if err := chunk.Error(); err != nil {
			return err
		}
		data := chunk.Data()
		if len(data) == 0 {
			return errors.New("chunk buffer download but size is 0")
		}
		n, err := p.writer.Write(data)
		if err != nil {
			return err
		}
		if n < len(data) {
			return io.ErrShortWrite
		}
	}
	return nil
}

// fetchChunk downloads the chunk with the given index, trying up to three
// times, and returns it with either its data or the last error set.
func (p *MultiPartialDownloader) fetchChunk(ctx context.Context, index int64) *Chunk {
	start := p.localSize + index*p.chunkSize
	end := start + p.chunkSize
	if end > p.finalSize {
		end = p.finalSize
	}
	chunk := NewChunk(index, start, end)

	var (
		buffer []byte
		err    error
	)
	for attempt := 0; attempt < 3; attempt++ {
		// end is one past the last byte of the chunk, while the download
		// function takes the offset of the last byte, hence end-1.
		buffer, err = p.downFunc(chunk.start, chunk.end-1)
		if err == nil || ctx.Err() != nil {
			// Success, or the download was cancelled and retrying is moot.
			break
		}
	}
	chunk.SetData(buffer)
	chunk.SetError(err)
	return chunk
}
32 changes: 32 additions & 0 deletions partial/downloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package partial

import (
"bytes"
"crypto/md5"
"strings"
"testing"
)

// TestDownload checks that a download split over several workers is written
// back in the original order.
func TestDownload(t *testing.T) {
	var buffer bytes.Buffer

	filedata := []byte(strings.Repeat("hello world", 1024*100))

	// Download stats the target path to decide where to resume from. Use a
	// path inside a fresh temporary directory so that a stray file in the
	// working directory cannot change the outcome of the test.
	target := t.TempDir() + "/myTestfile"

	download := NewMultiPartialDownloader(
		target,
		int64(len(filedata)),
		1024,
		&buffer,
		3,
		// The end offset passed in is inclusive, hence end+1 when slicing.
		func(start, end int64) ([]byte, error) {
			return filedata[start : end+1], nil
		},
	)

	if err := download.Download(); err != nil {
		t.Fatal(err)
	}
	if md5.Sum(buffer.Bytes()) != md5.Sum(filedata) {
		t.Fatal("download file has diff MD5")
	}
}
Loading

0 comments on commit dc6b159

Please sign in to comment.