diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index a67cfce1..43d13143 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -14,6 +14,7 @@ package main import ( "encoding/json" "io/ioutil" + "net/http" "os" "os/exec" "testing" @@ -202,7 +203,7 @@ func testCreateProjectFromTemplate(t *testing.T) { out, err := cmd.CombinedOutput() assert.NotNil(t, err) assert.Contains(t, string(out), "invalid_git_credentials") - assert.Contains(t, string(out), "401 Unauthorized") + assert.Contains(t, string(out), http.StatusText(http.StatusUnauthorized)) }) } @@ -418,7 +419,7 @@ func testSuccessfulAddAndRemoveTemplateRepos(t *testing.T) { createOut, createErr := createCmd.CombinedOutput() assert.NotNil(t, createErr) assert.Contains(t, string(createOut), "invalid_git_credentials") - assert.Contains(t, string(createOut), "401 Unauthorized") + assert.Contains(t, string(createOut), http.StatusText(http.StatusUnauthorized)) removeCmd := exec.Command(cwctl, "templates", "repos", "remove", "--url="+test.GHEDevfileURL, @@ -454,7 +455,7 @@ func testSuccessfulAddAndRemoveTemplateRepos(t *testing.T) { createOut, createErr := createCmd.CombinedOutput() assert.NotNil(t, createErr) assert.Contains(t, string(createOut), "invalid_git_credentials") - assert.Contains(t, string(createOut), "401 Unauthorized") + assert.Contains(t, string(createOut), http.StatusText(http.StatusUnauthorized)) removeCmd := exec.Command(cwctl, "templates", "repos", "remove", "--url="+test.GHEDevfileURL, diff --git a/pkg/project/create.go b/pkg/project/create.go index f17bef11..679e071c 100644 --- a/pkg/project/create.go +++ b/pkg/project/create.go @@ -74,7 +74,7 @@ func DownloadTemplate(destination, url string, gitCredentials *utils.GitCredenti if err != nil { errOp := errOpCreateProject // if 401 error, use invalid credentials error code - if strings.Contains(err.Error(), "401 Unauthorized") { + if err.Error() == http.StatusText(http.StatusUnauthorized) { errOp = errOpInvalidCredentials } return nil, &ProjectError{errOp, err, err.Error()} diff --git a/pkg/project/create_test.go b/pkg/project/create_test.go index 737f6e3d..0999bb17 100644 --- a/pkg/project/create_test.go +++ b/pkg/project/create_test.go @@ -99,7 +99,7 @@ func TestDownloadTemplate(t *testing.T) { assert.Nil(t, out) assert.Equal(t, errOpInvalidCredentials, err.Op) - assert.Equal(t, "unexpected status code: 401 Unauthorized", err.Desc) + assert.Equal(t, http.StatusText(http.StatusUnauthorized), err.Desc) }) t.Run("fail case: download GHE template using bad personalAccessToken)", func(t *testing.T) { os.RemoveAll(testDir) @@ -116,7 +116,7 @@ func TestDownloadTemplate(t *testing.T) { assert.Nil(t, out) assert.Equal(t, errOpInvalidCredentials, err.Op) - assert.Equal(t, "unexpected status code: 401 Unauthorized", err.Desc) + assert.Equal(t, http.StatusText(http.StatusUnauthorized), err.Desc) }) } diff --git a/pkg/utils/download.go b/pkg/utils/download.go index 1ac12204..b9dd879a 100644 --- a/pkg/utils/download.go +++ b/pkg/utils/download.go @@ -13,6 +13,7 @@ package utils import ( "context" + "errors" "fmt" "io" "net/http" @@ -79,7 +80,6 @@ func DownloadFromTarGzURL(URL *url.URL, destination string, gitCredentials *GitC func getURLToDownloadReleaseAsset(URL *url.URL, gitCredentials *GitCredentials) (*url.URL, error) { URLPathSlice := strings.Split(URL.Path, "/") - if !strings.Contains(URL.Host, "github") || len(URLPathSlice) < 6 { return nil, fmt.Errorf("URL must point to a GitHub repository release asset: %v", URL) } @@ -152,7 +152,14 @@ func DownloadFromRepoURL(URL *url.URL, destination string, gitCredentials *GitCr owner := URLPathSlice[1] repo := URLPathSlice[2] - zipURL, err := GetZipURL(owner, repo, "master", client) + + // Get the default branch rather than assuming a name. + branch, err := getDefaultBranch(owner, repo, client) + if err != nil { + return err + } + + zipURL, err := GetZipURL(owner, repo, branch, client) if err != nil { return err } @@ -197,6 +204,18 @@ func GetZipURL(owner, repo, branch string, client *github.Client) (*url.URL, err return URL, nil } +func getDefaultBranch(owner, repo string, client *github.Client) (string, error) { + ctx := context.Background() + repository, response, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + if response != nil && response.StatusCode == http.StatusUnauthorized { + return "", errors.New(http.StatusText(http.StatusUnauthorized)) + } + return "", err + } + return *repository.DefaultBranch, nil +} + // DownloadAndExtractZip downloads a zip file from a URL // and extracts it to a destination func DownloadAndExtractZip(zipURL *url.URL, destination string) error { diff --git a/pkg/utils/download_test.go b/pkg/utils/download_test.go index 2dad77d2..0faec16d 100644 --- a/pkg/utils/download_test.go +++ b/pkg/utils/download_test.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "net/url" "os" "path/filepath" @@ -136,7 +137,7 @@ func TestDownloadFromURLThenExtract(t *testing.T) { inDestination: filepath.Join(testDir, "failCase"), inGitCredentials: &GitCredentials{Username: test.GHEUsername, Password: "bad password"}, wantedType: errors.New(""), - wantedErrMsg: "401 Unauthorized", + wantedErrMsg: http.StatusText(http.StatusUnauthorized), wantedNumFiles: 0, }, "fail case: input good GHE tar.gz URL and credentials but no matching repo found": {