Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Concurrent Download Support for artifacts #11531

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 71 additions & 20 deletions client/allocrunner/taskrunner/artifact_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package taskrunner
import (
"context"
"fmt"
"sync"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
Expand All @@ -25,32 +26,19 @@ func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook {
return h
}

// Name returns the name of the artifact hook.
//
// The same string is copied in client/state for upgrades from <0.9 schemas,
// so a change here must be mirrored there.
func (*artifactHook) Name() string {
	const hookName = "artifacts"
	return hookName
}

func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
if len(req.Task.Artifacts) == 0 {
resp.Done = true
return nil
}

// Initialize hook state to store download progress
resp.State = make(map[string]string, len(req.Task.Artifacts))

h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))

for _, artifact := range req.Task.Artifacts {
func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) {
defer wg.Done()
for artifact := range jobs {
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved
aid := artifact.Hash()
if req.PreviousState[aid] != "" {
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
responseStateMutex.Lock()
resp.State[aid] = req.PreviousState[aid]
responseStateMutex.Unlock()
continue
}

h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource)
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
//XXX add ctx to GetArtifact to allow cancelling long downloads
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {

Expand All @@ -60,13 +48,76 @@ func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar
)
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))

return herr
errorChannel <- herr
continue
}

// Mark artifact as downloaded to avoid re-downloading due to
// retries caused by subsequent artifacts failing. Any
// non-empty value works.
responseStateMutex.Lock()
resp.State[aid] = "1"
responseStateMutex.Unlock()
}
}

// Name returns the name of the artifact hook.
//
// The same string is copied in client/state for upgrades from <0.9 schemas,
// so a change here must be mirrored there.
func (*artifactHook) Name() string {
	const hookName = "artifacts"
	return hookName
}

func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
if len(req.Task.Artifacts) == 0 {
resp.Done = true
return nil
}

// Initialize hook state to store download progress
resp.State = make(map[string]string, len(req.Task.Artifacts))

// responseStateMutex is a lock used to guard against concurrent writes to the above resp.State map
responseStateMutex := &sync.Mutex{}

h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))

// maxConcurrency denotes the number of workers that will download artifacts in parallel
maxConcurrency := 3

// jobsChannel is a buffered channel that carries every artifact that needs to be processed
jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency)

// errorChannel is also a buffered channel that will be used to signal errors
errorChannel := make(chan error, maxConcurrency)

// create workers and process artifacts
go func() {
defer close(errorChannel)
var wg sync.WaitGroup
for i := 0; i < maxConcurrency; i++ {
wg.Add(1)
go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex)
}
wg.Wait()
}()

// Push all artifact requests to job channel
go func() {
defer close(jobsChannel)
for _, artifact := range req.Task.Artifacts {
jobsChannel <- artifact
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved
}
}()

// Iterate over the errorChannel and if there is an error, store it to a variable for future return
var err error
for e := range errorChannel {
err = e
}

// once error channel is closed, we can check and return the error
if err != nil {
return err
}

resp.Done = true
Expand Down
228 changes: 228 additions & 0 deletions client/allocrunner/taskrunner/artifact_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package taskrunner

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -158,3 +159,230 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
require.True(t, resp.Done)
require.Len(t, resp.State, 2)
}

// TestTaskRunner_ArtifactHook_ConcurrentDownload asserts that the artifact hook
// download multiple files concurrently. this is a successful test without any errors.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) {
t.Parallel()

me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))

// Create a source directory all 7 artifacts
srcdir, err := ioutil.TempDir("", "nomadtest-src")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(srcdir))
}()
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved

numOfFiles := 7
for i := 0; i < numOfFiles; i++ {
file := filepath.Join(srcdir, fmt.Sprintf("file%d.txt", i))
require.NoError(t, ioutil.WriteFile(file, []byte{byte(i)}, 0644))
}

// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()

// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()

req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file4.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file5.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file6.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}

resp := interfaces.TaskPrestartResponse{}

// start the hook
err = artifactHook.Prestart(context.Background(), req, &resp)

require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 7)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)

// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
require.Len(t, files, 7)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")
require.Contains(t, files[4], "file4.txt")
require.Contains(t, files[5], "file5.txt")
require.Contains(t, files[6], "file6.txt")

// Stop the test server entirely and assert that re-running works
ts.Close()
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved
}

// TestTaskRunner_ArtifactHook_ConcurrentDownload asserts that the artifact hook
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved
// download multiple files concurrently. first iteration will result in failure and
// second iteration should succeed without downloading already downloaded files.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) {
t.Parallel()

me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))

// Create a source directory with 3 of the 4 artifacts
srcdir, err := ioutil.TempDir("", "nomadtest-src")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(srcdir))
}()

file1 := filepath.Join(srcdir, "file1.txt")
require.NoError(t, ioutil.WriteFile(file1, []byte{'1'}, 0644))

file2 := filepath.Join(srcdir, "file2.txt")
require.NoError(t, ioutil.WriteFile(file2, []byte{'2'}, 0644))

file3 := filepath.Join(srcdir, "file3.txt")
require.NoError(t, ioutil.WriteFile(file3, []byte{'3'}, 0644))

// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()

// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()

req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt", // this request will fail
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}

resp := interfaces.TaskPrestartResponse{}

// On first run file1 (foo) should download but file2 (bar) should
// fail.
gowthamgts marked this conversation as resolved.
Show resolved Hide resolved
err = artifactHook.Prestart(context.Background(), req, &resp)

require.Error(t, err)
require.True(t, structs.IsRecoverable(err))
require.Len(t, resp.State, 3)
require.False(t, resp.Done)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)

// delete the downloaded files so that it'll error if it's downloaded again
require.NoError(t, os.Remove(file1))
require.NoError(t, os.Remove(file2))
require.NoError(t, os.Remove(file3))

// create the missing file
file0 := filepath.Join(srcdir, "file0.txt")
require.NoError(t, ioutil.WriteFile(file0, []byte{'0'}, 0644))

// Mock TaskRunner by copying state from resp to req and reset resp.
req.PreviousState = helper.CopyMapStringString(resp.State)

resp = interfaces.TaskPrestartResponse{}

// Retry the download and assert it succeeds
err = artifactHook.Prestart(context.Background(), req, &resp)
require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 4)

// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")

// verify the file contents too, since files will also be created for failed downloads
data0, err := ioutil.ReadFile(files[0])
require.NoError(t, err)
require.Equal(t, data0, []byte{'0'})

data1, err := ioutil.ReadFile(files[1])
require.NoError(t, err)
require.Equal(t, data1, []byte{'1'})

data2, err := ioutil.ReadFile(files[2])
require.NoError(t, err)
require.Equal(t, data2, []byte{'2'})

data3, err := ioutil.ReadFile(files[3])
require.NoError(t, err)
require.Equal(t, data3, []byte{'3'})

// Stop the test server entirely and assert that re-running works
ts.Close()
req.PreviousState = helper.CopyMapStringString(resp.State)
resp = interfaces.TaskPrestartResponse{}
err = artifactHook.Prestart(context.Background(), req, &resp)
require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 4)
}