From 3d611822493362ae42c7bcf36d1c2ff8675a1120 Mon Sep 17 00:00:00 2001
From: Mariana Dima <mariana@elastic.co>
Date: Mon, 17 Jan 2022 20:20:05 +0100
Subject: [PATCH] Update `fetchBeatsBinary` to be reused in elastic-agent-poc
 (#1984)

* update func

* fix path

* work on download

* small fix

* remove test

* add sha to google

* fix typo

* add comment

(cherry picked from commit d3365c99b9c4863a15a5f3063d08f01470465110)

# Conflicts:
#	pkg/downloads/versions.go
---
 internal/utils/utils.go        | 54 ++++++++++++++++-----------
 internal/utils/utils_test.go   |  9 ++++-
 pkg/downloads/versions.go      | 68 +++++++++++++++++++++++++---------
 pkg/downloads/versions_test.go | 31 ++++++++--------
 4 files changed, 105 insertions(+), 57 deletions(-)

diff --git a/internal/utils/utils.go b/internal/utils/utils.go
index f0adc405b2..13c28262c8 100644
--- a/internal/utils/utils.go
+++ b/internal/utils/utils.go
@@ -26,6 +26,13 @@ const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 //nolint:unused
 var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
 
+// DownloadRequest struct contains download details ad path and URL
+type DownloadRequest struct {
+	URL                 string
+	DownloadPath        string
+	UnsanitizedFilePath string
+}
+
 // GetArchitecture retrieves if the underlying system platform is arm64 or amd64
 func GetArchitecture() string {
 	arch, present := os.LookupEnv("GOARCH")
@@ -40,36 +47,41 @@ func GetArchitecture() string {
 // DownloadFile will download a url and store it in a temporary path.
 // It writes to the destination file as it downloads it, without
 // loading the entire file into memory.
-func DownloadFile(url string) (string, error) {
-	tempParentDir := filepath.Join(os.TempDir(), uuid.NewString())
-	internalio.MkdirAll(tempParentDir)
+func DownloadFile(downloadRequest *DownloadRequest) error {
+	var filePath string
+	if downloadRequest.DownloadPath == "" {
+		tempParentDir := filepath.Join(os.TempDir(), uuid.NewString())
+		internalio.MkdirAll(tempParentDir)
+		filePath = filepath.Join(tempParentDir, uuid.NewString())
+		downloadRequest.DownloadPath = filePath
+	} else {
+		filePath = filepath.Join(downloadRequest.DownloadPath, uuid.NewString())
+	}
 
-	tempFile, err := os.Create(filepath.Join(tempParentDir, uuid.NewString()))
+	tempFile, err := os.Create(filePath)
 	if err != nil {
 		log.WithFields(log.Fields{
 			"error": err,
-			"url":   url,
+			"url":   downloadRequest.URL,
 		}).Error("Error creating file")
-		return "", err
+		return err
 	}
 	defer tempFile.Close()
 
-	filepathFull := tempFile.Name()
-
+	downloadRequest.UnsanitizedFilePath = tempFile.Name()
 	exp := GetExponentialBackOff(3)
 
 	retryCount := 1
 	var fileReader io.ReadCloser
-
 	download := func() error {
-		resp, err := http.Get(url)
+		resp, err := http.Get(downloadRequest.URL)
 		if err != nil {
 			log.WithFields(log.Fields{
 				"elapsedTime": exp.GetElapsedTime(),
 				"error":       err,
-				"path":        filepathFull,
+				"path":        downloadRequest.UnsanitizedFilePath,
 				"retry":       retryCount,
-				"url":         url,
+				"url":         downloadRequest.URL,
 			}).Warn("Could not download the file")
 
 			retryCount++
@@ -80,8 +92,8 @@ func DownloadFile(url string) (string, error) {
 		log.WithFields(log.Fields{
 			"elapsedTime": exp.GetElapsedTime(),
 			"retries":     retryCount,
-			"path":        filepathFull,
-			"url":         url,
+			"path":        downloadRequest.UnsanitizedFilePath,
+			"url":         downloadRequest.URL,
 		}).Trace("File downloaded")
 
 		fileReader = resp.Body
@@ -90,13 +102,13 @@ func DownloadFile(url string) (string, error) {
 	}
 
 	log.WithFields(log.Fields{
-		"url":  url,
-		"path": filepathFull,
+		"url":  downloadRequest.URL,
+		"path": downloadRequest.UnsanitizedFilePath,
 	}).Trace("Downloading file")
 
 	err = backoff.Retry(download, exp)
 	if err != nil {
-		return "", err
+		return err
 	}
 	defer fileReader.Close()
 
@@ -104,16 +116,16 @@ func DownloadFile(url string) (string, error) {
 	if err != nil {
 		log.WithFields(log.Fields{
 			"error": err,
-			"url":   url,
-			"path":  filepathFull,
+			"url":   downloadRequest.URL,
+			"path":  downloadRequest.UnsanitizedFilePath,
 		}).Error("Could not write file")
 
-		return filepathFull, err
+		return err
 	}
 
 	_ = os.Chmod(tempFile.Name(), 0666)
 
-	return filepathFull, nil
+	return nil
 }
 
 // IsCommit returns true if the string matches commit format
diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go
index 82098c5459..44fbc18043 100644
--- a/internal/utils/utils_test.go
+++ b/internal/utils/utils_test.go
@@ -9,9 +9,14 @@ import (
 )
 
 func TestDownloadFile(t *testing.T) {
-	f, err := DownloadFile("https://www.elastic.co/robots.txt")
+	var dRequest = DownloadRequest{
+		URL:          "https://www.elastic.co/robots.txt",
+		DownloadPath: "",
+	}
+	err := DownloadFile(&dRequest)
 	assert.Nil(t, err)
-	defer os.Remove(filepath.Dir(f))
+	assert.NotEmpty(t, dRequest.UnsanitizedFilePath)
+	defer os.Remove(filepath.Dir(dRequest.UnsanitizedFilePath))
 }
 
 func TestGetArchitecture(t *testing.T) {
diff --git a/pkg/downloads/versions.go b/pkg/downloads/versions.go
index 08a5a14478..163396ca6a 100644
--- a/pkg/downloads/versions.go
+++ b/pkg/downloads/versions.go
@@ -83,7 +83,7 @@ func CheckPRVersion(version string, fallbackVersion string) string {
 // FetchElasticArtifact fetches an artifact from the right repository, returning binary name, path and error
 func FetchElasticArtifact(ctx context.Context, artifact string, version string, os string, arch string, extension string, isDocker bool, xpack bool) (string, string, error) {
 	binaryName := buildArtifactName(artifact, version, os, arch, extension, isDocker)
-	binaryPath, err := fetchBeatsBinary(ctx, binaryName, artifact, version, utils.TimeoutFactor, xpack)
+	binaryPath, err := FetchBeatsBinary(ctx, binaryName, artifact, version, utils.TimeoutFactor, xpack, "", false)
 	if err != nil {
 		log.WithFields(log.Fields{
 			"artifact":  artifact,
@@ -106,10 +106,11 @@ func GetCommitVersion(version string) string {
 
 // GetElasticArtifactURL returns the URL of a released artifact, which its full name is defined in the first argument,
 // from Elastic's artifact repository, building the JSON path query based on the full name
+// It also returns the URL of the sha512 file of the released artifact.
 // i.e. GetElasticArtifactURL("elastic-agent-$VERSION-$ARCH.deb", "elastic-agent", "$VERSION")
 // i.e. GetElasticArtifactURL("elastic-agent-$VERSION-x86_64.rpm", "elastic-agent","$VERSION")
 // i.e. GetElasticArtifactURL("elastic-agent-$VERSION-linux-$ARCH.tar.gz", "elastic-agent","$VERSION")
-func GetElasticArtifactURL(artifactName string, artifact string, version string) (string, error) {
+func GetElasticArtifactURL(artifactName string, artifact string, version string) (string, string, error) {
 	exp := utils.GetExponentialBackOff(time.Minute)
 
 	retryCount := 1
@@ -161,7 +162,7 @@ func GetElasticArtifactURL(artifactName string, artifact string, version string)
 
 	err := backoff.Retry(apiStatus, exp)
 	if err != nil {
-		return "", err
+		return "", "", err
 	}
 
 	jsonParsed, err := gabs.ParseJSON([]byte(body))
@@ -171,7 +172,7 @@ func GetElasticArtifactURL(artifactName string, artifact string, version string)
 			"artifactName": artifactName,
 			"version":      tmpVersion,
 		}).Error("Could not parse the response body for the artifact")
-		return "", err
+		return "", "", err
 	}
 
 	log.WithFields(log.Fields{
@@ -191,8 +192,9 @@ func GetElasticArtifactURL(artifactName string, artifact string, version string)
 	// we need to get keys with dots using Search instead of Path
 	downloadObject := packagesObject.Search(artifactName)
 	downloadURL := downloadObject.Path("url").Data().(string)
+	downloadshaURL := downloadObject.Path("sha_url").Data().(string)
 
-	return downloadURL, nil
+	return downloadURL, downloadshaURL, nil
 }
 
 // GetElasticArtifactVersion returns the current version:
@@ -348,14 +350,19 @@ func buildArtifactName(artifact string, artifactVersion string, OS string, arch
 
 }
 
-// fetchBeatsBinary it downloads the binary and returns the location of the downloaded file
+// FetchBeatsBinary it downloads the binary and returns the location of the downloaded file
 // If the environment variable BEATS_LOCAL_PATH is set, then the artifact
 // to be used will be defined by the local snapshot produced by the local build.
 // Else, if the environment variable GITHUB_CHECK_SHA1 is set, then the artifact
 // to be downloaded will be defined by the snapshot produced by the Beats CI for that commit.
+<<<<<<< HEAD
 func fetchBeatsBinary(ctx context.Context, artifactName string, artifact string, version string, timeoutFactor int, xpack bool) (string, error) {
 	beatsLocalPath := shell.GetEnv("BEATS_LOCAL_PATH", "")
 	if beatsLocalPath != "" {
+=======
+func FetchBeatsBinary(ctx context.Context, artifactName string, artifact string, version string, timeoutFactor int, xpack bool, downloadPath string, downloadSHAFile bool) (string, error) {
+	if BeatsLocalPath != "" {
+>>>>>>> d3365c99 (Update `fetchBeatsBinary` to be reused in elastic-agent-poc (#1984))
 		span, _ := apm.StartSpanOptions(ctx, "Fetching Beats binary", "beats.local.fetch-binary", apm.SpanOptions{
 			Parent: apm.SpanFromContext(ctx).TraceContext(),
 		})
@@ -378,6 +385,11 @@ func fetchBeatsBinary(ctx context.Context, artifactName string, artifact string,
 	}
 
 	handleDownload := func(URL string) (string, error) {
+		name := artifactName
+		downloadRequest := utils.DownloadRequest{
+			DownloadPath: downloadPath,
+			URL:          URL,
+		}
 		span, _ := apm.StartSpanOptions(ctx, "Fetching Beats binary", "beats.url.fetch-binary", apm.SpanOptions{
 			Parent: apm.SpanFromContext(ctx).TraceContext(),
 		})
@@ -391,20 +403,23 @@ func fetchBeatsBinary(ctx context.Context, artifactName string, artifact string,
 			return val, nil
 		}
 
-		filePathFull, err := utils.DownloadFile(URL)
+		err := utils.DownloadFile(&downloadRequest)
 		if err != nil {
-			return filePathFull, err
+			return downloadRequest.UnsanitizedFilePath, err
 		}
 
+		if strings.HasSuffix(URL, ".sha512") {
+			name = fmt.Sprintf("%s.sha512", name)
+		}
 		// use artifact name as file name to avoid having URL params in the name
-		sanitizedFilePath := filepath.Join(path.Dir(filePathFull), artifactName)
-		err = os.Rename(filePathFull, sanitizedFilePath)
+		sanitizedFilePath := filepath.Join(path.Dir(downloadRequest.UnsanitizedFilePath), name)
+		err = os.Rename(downloadRequest.UnsanitizedFilePath, sanitizedFilePath)
 		if err != nil {
 			log.WithFields(log.Fields{
-				"fileName":          filePathFull,
+				"fileName":          downloadRequest.UnsanitizedFilePath,
 				"sanitizedFileName": sanitizedFilePath,
 			}).Warn("Could not sanitize downloaded file name. Keeping old name")
-			sanitizedFilePath = filePathFull
+			sanitizedFilePath = downloadRequest.UnsanitizedFilePath
 		}
 
 		binariesCache[URL] = sanitizedFilePath
@@ -412,7 +427,7 @@ func fetchBeatsBinary(ctx context.Context, artifactName string, artifact string,
 		return sanitizedFilePath, nil
 	}
 
-	var downloadURL string
+	var downloadURL, downloadShaURL string
 	var err error
 
 	useCISnapshots := GithubCommitSha1 != ""
@@ -424,24 +439,41 @@ func fetchBeatsBinary(ctx context.Context, artifactName string, artifact string,
 
 		log.Debugf("Using CI snapshots for %s", artifact)
 
-		bucket, prefix, object := getGCPBucketCoordinates(artifactName, artifact)
-
 		maxTimeout := time.Duration(timeoutFactor) * time.Minute
 
+		bucket, prefix, object := getGCPBucketCoordinates(artifactName, artifact)
+
 		downloadURL, err = getObjectURLFromBucket(bucket, prefix, object, maxTimeout)
 		if err != nil {
 			return "", err
 		}
+		downloadLocation, err := handleDownload(downloadURL)
 
+		// check if sha file should be downloaded, else return
+		if downloadSHAFile == false {
+			return downloadLocation, err
+		}
+
+		bucket, prefix, object = getGCPBucketCoordinates(fmt.Sprintf("%s.sha512", artifactName), artifact)
+		downloadURL, err = getObjectURLFromBucket(bucket, prefix, object, maxTimeout)
+		if err != nil {
+			return "", err
+		}
 		return handleDownload(downloadURL)
 	}
 
-	downloadURL, err = GetElasticArtifactURL(artifactName, artifact, version)
+	downloadURL, downloadShaURL, err = GetElasticArtifactURL(artifactName, artifact, version)
 	if err != nil {
 		return "", err
 	}
-
-	return handleDownload(downloadURL)
+	downloadLocation, err := handleDownload(downloadURL)
+	if err != nil {
+		return "", err
+	}
+	if downloadSHAFile == true && downloadShaURL != "" {
+		downloadLocation, err = handleDownload(downloadShaURL)
+	}
+	return downloadLocation, err
 }
 
 func getBucketSearchNextPageParam(jsonParsed *gabs.Container) string {
diff --git a/pkg/downloads/versions_test.go b/pkg/downloads/versions_test.go
index be1065bca8..c545c3641d 100644
--- a/pkg/downloads/versions_test.go
+++ b/pkg/downloads/versions_test.go
@@ -6,15 +6,14 @@ package downloads
 
 import (
 	"context"
+	"github.com/Jeffail/gabs/v2"
+	"github.com/elastic/e2e-testing/internal/utils"
+	"github.com/stretchr/testify/assert"
 	"io/ioutil"
 	"os"
 	"path"
 	"path/filepath"
 	"testing"
-
-	"github.com/Jeffail/gabs/v2"
-	"github.com/elastic/e2e-testing/internal/utils"
-	"github.com/stretchr/testify/assert"
 )
 
 var artifact = "elastic-agent"
@@ -387,7 +386,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		defer os.Unsetenv("BEATS_LOCAL_PATH")
 		os.Setenv("BEATS_LOCAL_PATH", beatsDir)
 
-		_, err := fetchBeatsBinary(ctx, "foo_fileName", artifact, version, utils.TimeoutFactor, true)
+		_, err := FetchBeatsBinary(ctx, "foo_fileName", artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.NotNil(t, err)
 	})
 
@@ -398,7 +397,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-x86_64.rpm"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -409,7 +408,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-aarch64.rpm"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -421,7 +420,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-amd64.deb"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -432,7 +431,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-arm64.deb"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -444,7 +443,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-linux-amd64.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -455,7 +454,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-linux-x86_64.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -466,7 +465,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-linux-arm64.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -478,7 +477,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-linux-amd64.docker.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -489,7 +488,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := versionPrefix + "-linux-arm64.docker.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -501,7 +500,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := ubi8VersionPrefix + "-linux-amd64.docker.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})
@@ -512,7 +511,7 @@ func TestFetchBeatsBinaryFromLocalPath(t *testing.T) {
 		artifactName := ubi8VersionPrefix + "-linux-arm64.docker.tar.gz"
 		expectedFilePath := path.Join(distributionsDir, artifactName)
 
-		downloadedFilePath, err := fetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true)
+		downloadedFilePath, err := FetchBeatsBinary(ctx, artifactName, artifact, version, utils.TimeoutFactor, true, "", false)
 		assert.Nil(t, err)
 		assert.Equal(t, downloadedFilePath, expectedFilePath)
 	})