From 1e662018291004a81463983f22bab23276148625 Mon Sep 17 00:00:00 2001 From: Hidde Beydals Date: Wed, 13 Dec 2023 09:31:11 +0100 Subject: [PATCH] loader: allow overwrite of URL hostname again This adds back the support for overwriting the host name a chart is downloaded from (again) using the `SOURCE_CONTROLLER_LOCALHOST` environment variable. Signed-off-by: Hidde Beydals --- internal/loader/artifact_url.go | 29 +++++++++++++++ internal/loader/artifact_url_test.go | 55 ++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/internal/loader/artifact_url.go b/internal/loader/artifact_url.go index 6b76416a4..b100ef703 100644 --- a/internal/loader/artifact_url.go +++ b/internal/loader/artifact_url.go @@ -24,6 +24,8 @@ import ( "fmt" "io" "net/http" + "net/url" + "os" "github.com/hashicorp/go-retryablehttp" digestlib "github.com/opencontainers/go-digest" @@ -32,6 +34,13 @@ import ( "helm.sh/helm/v3/pkg/chart/loader" ) +const ( + // envSourceControllerLocalhost is the name of the environment variable + // used to override the hostname of the source-controller from which + // the chart is usually downloaded. + envSourceControllerLocalhost = "SOURCE_CONTROLLER_LOCALHOST" +) + var ( // ErrFileNotFound is an error type used to signal 404 HTTP status code responses. ErrFileNotFound = errors.New("file not found") @@ -45,6 +54,11 @@ var ( // digest before loading the chart. It returns the loaded chart.Chart, or an // error. The error may be of type ErrIntegrity if the integrity check fails. func SecureLoadChartFromURL(client *retryablehttp.Client, URL, digest string) (*chart.Chart, error) { + URL, err := overwriteHostname(URL, os.Getenv(envSourceControllerLocalhost)) + if err != nil { + return nil, err + } + req, err := retryablehttp.NewRequest(http.MethodGet, URL, nil) if err != nil { return nil, err @@ -94,3 +108,18 @@ func copyAndVerify(digest string, reader io.Reader, writer io.Writer) error { } return nil } + +// overwriteHostname overwrites the hostname of the given URL with the given +// hostname. If the hostname is empty, the URL is returned unmodified. +func overwriteHostname(URL, hostname string) (string, error) { + if hostname == "" { + return URL, nil + } + + u, err := url.Parse(URL) + if err != nil { + return "", fmt.Errorf("failed to parse URL to overwrite hostname: %w", err) + } + u.Host = hostname + return u.String(), nil +} diff --git a/internal/loader/artifact_url_test.go b/internal/loader/artifact_url_test.go index 20f60ac22..9f4a3b291 100644 --- a/internal/loader/artifact_url_test.go +++ b/internal/loader/artifact_url_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "github.com/hashicorp/go-retryablehttp" @@ -72,6 +73,19 @@ func TestSecureLoadChartFromURL(t *testing.T) { g.Expect(got.Metadata.Version).To(Equal("0.1.0")) }) + t.Run("overwrites hostname", func(t *testing.T) { + g := NewWithT(t) + + t.Setenv(envSourceControllerLocalhost, strings.TrimPrefix(server.URL, "http://")) + wrongHostnameURL := "http://invalid.com" + chartPath + + got, err := SecureLoadChartFromURL(client, wrongHostnameURL, digest.String()) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(got).ToNot(BeNil()) + g.Expect(got.Name()).To(Equal("chart")) + g.Expect(got.Metadata.Version).To(Equal("0.1.0")) + }) + t.Run("error on chart data digest mismatch", func(t *testing.T) { g := NewWithT(t) @@ -162,3 +176,44 @@ func Test_copyAndVerify(t *testing.T) { }) } } + +func Test_overwriteHostname(t *testing.T) { + tests := []struct { + name string + URL string + hostname string + want string + wantErr bool + }{ + { + name: "overwrite hostname", + URL: "http://example.com", + hostname: "localhost", + want: "http://localhost", + }, + { + name: "overwrite hostname with port", + URL: "http://example.com", + hostname: "localhost:9090", + want: "http://localhost:9090", + }, + { + name: "no hostname", + URL: "http://example.com", + hostname: "", + want: "http://example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := overwriteHostname(tt.URL, tt.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("overwriteHostname() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("overwriteHostname() got = %v, want %v", got, tt.want) + } + }) + } +}