From 3940a38aba7e2685f5c7c637f51ec3554b026c15 Mon Sep 17 00:00:00 2001 From: Julien Robert Date: Wed, 15 Nov 2023 09:12:53 +0100 Subject: [PATCH] refactor: configure go-getter --- x/upgrade/plan/downloader.go | 39 ++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/x/upgrade/plan/downloader.go b/x/upgrade/plan/downloader.go index 1af03d7a2db2..9ae9e604e68b 100644 --- a/x/upgrade/plan/downloader.go +++ b/x/upgrade/plan/downloader.go @@ -1,6 +1,7 @@ package plan import ( + "context" "errors" "fmt" neturl "net/url" @@ -8,6 +9,7 @@ import ( "path/filepath" "strings" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-getter" ) @@ -24,8 +26,9 @@ import ( // NOTE: This functions does not check the provided url for validity. func DownloadUpgrade(dstRoot, url, daemonName string) error { target := filepath.Join(dstRoot, "bin", daemonName) + // First try to download it as a single file. If there's no error, it's okay and we're done. - if err := getter.GetFile(target, url); err != nil { + if err := getFile(url, target); err != nil { // If it was a checksum error, no need to try as directory. if _, ok := err.(*getter.ChecksumError); ok { return err @@ -109,7 +112,8 @@ func DownloadURL(url string) (string, error) { } defer os.RemoveAll(tempDir) tempFile := filepath.Join(tempDir, "content") - if err = getter.GetFile(tempFile, url); err != nil { + + if err := getFile(url, tempFile); err != nil { return "", fmt.Errorf("could not download url \"%s\": %w", url, err) } tempFileBz, rerr := os.ReadFile(tempFile) @@ -136,3 +140,34 @@ func ValidateURL(urlStr string, mustChecksum bool) error { return nil } + +// getFile downloads the given url into the provided directory. +func getFile(url, dst string) error { + httpGetter := &getter.HttpGetter{ + Client: cleanhttp.DefaultClient(), + XTerraformGetDisabled: true, + } + + goGetterGetters := map[string]getter.Getter{ + "file": new(getter.FileGetter), + "gcs": new(getter.GCSGetter), + "git": new(getter.GitGetter), + "hg": new(getter.HgGetter), + "s3": new(getter.S3Getter), + "http": httpGetter, + "https": httpGetter, + } + + // https://github.com/hashicorp/go-getter#security-options + getterClient := &getter.Client{ + Ctx: context.Background(), + DisableSymlinks: true, + Src: url, + Dst: dst, + Pwd: dst, + Mode: getter.ClientModeAny, + Getters: goGetterGetters, + } + + return getterClient.Get() +}