Skip to content

Commit

Permalink
Merge pull request #2431 from hashicorp/b-git-ssh
Browse files Browse the repository at this point in the history
Handle git ssh artifacts
  • Loading branch information
dadgar authored Mar 13, 2017
2 parents 7ce7abb + 637aff7 commit ef22127
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 4 deletions.
29 changes: 26 additions & 3 deletions client/getter/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net/url"
"path/filepath"
"strings"
"sync"

gg "github.com/hashicorp/go-getter"
Expand All @@ -21,6 +22,11 @@ var (
supported = []string{"http", "https", "s3", "hg", "git"}
)

const (
// gitSSHPrefix is the prefix for dowwnloading via git using ssh
gitSSHPrefix = "[email protected]:"
)

// getClient returns a client that is suitable for Nomad downloading artifacts.
func getClient(src, dst string) *gg.Client {
lock.Lock()
Expand All @@ -47,7 +53,17 @@ func getClient(src, dst string) *gg.Client {
// getGetterUrl returns the go-getter URL to download the artifact.
func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) (string, error) {
taskEnv.Build()
u, err := url.Parse(taskEnv.ReplaceEnv(artifact.GetterSource))
source := taskEnv.ReplaceEnv(artifact.GetterSource)

// Handle an invalid URL when given a go-getter url such as
// [email protected]:hashicorp/nomad.git
gitSSH := false
if strings.HasPrefix(source, gitSSHPrefix) {
gitSSH = true
source = source[len(gitSSHPrefix):]
}

u, err := url.Parse(source)
if err != nil {
return "", fmt.Errorf("failed to parse source URL %q: %v", artifact.GetterSource, err)
}
Expand All @@ -58,7 +74,14 @@ func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact)
q.Add(k, taskEnv.ReplaceEnv(v))
}
u.RawQuery = q.Encode()
return u.String(), nil

// Add the prefix back
url := u.String()
if gitSSH {
url = fmt.Sprintf("%s%s", gitSSHPrefix, url)
}

return url, nil
}

// GetArtifact downloads an artifact into the specified task directory.
Expand All @@ -71,7 +94,7 @@ func GetArtifact(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact, t
// Download the artifact
dest := filepath.Join(taskDir, artifact.RelativeDest)
if err := getClient(url, dest).Get(); err != nil {
return fmt.Errorf("GET error: %v", err)
return structs.NewRecoverableError(fmt.Errorf("GET error: %v", err), true)
}

return nil
Expand Down
99 changes: 99 additions & 0 deletions client/getter/getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,102 @@ func TestGetArtifact_Archive(t *testing.T) {
}
checkContents(taskDir, expected, t)
}

func TestGetGetterUrl_Queries(t *testing.T) {
taskEnv := env.NewTaskEnvironment(mock.Node())
cases := []struct {
name string
artifact *structs.TaskArtifact
output string
}{
{
name: "adds query parameters",
artifact: &structs.TaskArtifact{
GetterSource: "https://foo.com?test=1",
GetterOptions: map[string]string{
"foo": "bar",
"bam": "boom",
},
},
output: "https://foo.com?bam=boom&foo=bar&test=1",
},
{
name: "git without http",
artifact: &structs.TaskArtifact{
GetterSource: "github.com/hashicorp/nomad",
GetterOptions: map[string]string{
"ref": "abcd1234",
},
},
output: "github.com/hashicorp/nomad?ref=abcd1234",
},
{
name: "git using ssh",
artifact: &structs.TaskArtifact{
GetterSource: "[email protected]:hashicorp/nomad?sshkey=1",
GetterOptions: map[string]string{
"ref": "abcd1234",
},
},
output: "[email protected]:hashicorp/nomad?ref=abcd1234&sshkey=1",
},
{
name: "s3 scheme 1",
artifact: &structs.TaskArtifact{
GetterSource: "s3::https://s3.amazonaws.com/bucket/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "s3::https://s3.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 2",
artifact: &structs.TaskArtifact{
GetterSource: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 3",
artifact: &structs.TaskArtifact{
GetterSource: "bucket.s3.amazonaws.com/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "bucket.s3.amazonaws.com/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 4",
artifact: &structs.TaskArtifact{
GetterSource: "bucket.s3-eu-west-1.amazonaws.com/foo/bar",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "bucket.s3-eu-west-1.amazonaws.com/foo/bar?aws_access_key_id=abcd1234",
},
{
name: "local file",
artifact: &structs.TaskArtifact{
GetterSource: "/foo/bar",
},
output: "/foo/bar",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
act, err := getGetterUrl(taskEnv, c.artifact)
if err != nil {
t.Fatalf("want %q; got err %v", c.output, err)
} else if act != c.output {
t.Fatalf("want %q; got %q", c.output, act)
}
})
}
}
2 changes: 1 addition & 1 deletion client/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
r.logger.Printf("[DEBUG] client: %v", wrapped)
r.setState(structs.TaskStatePending,
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, true))
r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, structs.IsRecoverable(err)))
goto RESTART
}
}
Expand Down

0 comments on commit ef22127

Please sign in to comment.