diff --git a/detect_gcs.go b/detect_gcs.go index 32399cd9..044df932 100644 --- a/detect_gcs.go +++ b/detect_gcs.go @@ -18,7 +18,7 @@ func (d *GCSDetector) Detect(src, _ string) (string, bool, error) { return "", false, nil } - if strings.Contains(src, "googleapis.com/") { + if strings.Contains(src, ".googleapis.com/") { return d.detectHTTP(src) } diff --git a/detect_gcs_test.go b/detect_gcs_test.go index 1713b1d3..bf01e5a1 100644 --- a/detect_gcs_test.go +++ b/detect_gcs_test.go @@ -24,6 +24,10 @@ func TestGCSDetector(t *testing.T) { "www.googleapis.com/storage/v1/foo/bar.baz", "gcs::https://www.googleapis.com/storage/v1/foo/bar.baz", }, + { + "www.googleapis.com/storage/v2/foo/bar/toor.baz", + "gcs::https://www.googleapis.com/storage/v2/foo/bar/toor.baz", + }, } pwd := "/pwd" @@ -42,3 +46,58 @@ func TestGCSDetector(t *testing.T) { } } } + +func TestGCSDetector_MalformedDetectHTTP(t *testing.T) { + cases := []struct { + Name string + Input string + Expected string + Output string + }{ + { + "valid url", + "www.googleapis.com/storage/v1/my-bucket/foo/bar", + "", + "gcs::https://www.googleapis.com/storage/v1/my-bucket/foo/bar", + }, + { + "empty url", + "", + "", + "", + }, + { + "not valid url", + "storage/v1/my-bucket/foo/bar", + "error parsing GCS URL", + "", + }, + { + "not valid url domain", + "www.googleapis.com.invalid/storage/v1/", + "URL is not a valid GCS URL", + "", + }, + { + "not valid url length", + "http://www.googleapis.com/storage", + "URL is not a valid GCS URL", + "", + }, + } + + pwd := "/pwd" + f := new(GCSDetector) + for _, tc := range cases { + output, _, err := f.Detect(tc.Input, pwd) + if err != nil { + if err.Error() != tc.Expected { + t.Fatalf("expected error %s, got %s for %s", tc.Expected, err.Error(), tc.Name) + } + } + + if output != tc.Output { + t.Fatalf("expected %s, got %s for %s", tc.Output, output, tc.Name) + } + } +} diff --git a/detect_s3_test.go b/detect_s3_test.go index 883c7c8b..237f7332 100644 --- a/detect_s3_test.go +++ b/detect_s3_test.go @@ -90,3 +90,58 @@ func TestS3Detector(t *testing.T) { } } } + +func TestS3Detector_MalformedDetectHTTP(t *testing.T) { + cases := []struct { + Name string + Input string + Expected string + Output string + }{ + { + "valid url", + "s3.amazonaws.com/bucket/foo/bar", + "", + "s3::https://s3.amazonaws.com/bucket/foo/bar", + }, + { + "empty url", + "", + "", + "", + }, + { + "not valid url", + "bucket/foo/bar", + "error parsing S3 URL", + "", + }, + { + "not valid url domain", + "s3.amazonaws.com.invalid/bucket/foo/bar", + "error parsing S3 URL", + "", + }, + { + "not valid url lenght", + "http://s3.amazonaws.com", + "URL is not a valid S3 URL", + "", + }, + } + + pwd := "/pwd" + f := new(S3Detector) + for _, tc := range cases { + output, _, err := f.Detect(tc.Input, pwd) + if err != nil { + if err.Error() != tc.Expected { + t.Fatalf("expected error %s, got %s for %s", tc.Expected, err.Error(), tc.Name) + } + } + + if output != tc.Output { + t.Fatalf("expected %s, got %s for %s", tc.Output, output, tc.Name) + } + } +} diff --git a/get_gcs.go b/get_gcs.go index 1cc7ec3b..fa1ffc25 100644 --- a/get_gcs.go +++ b/get_gcs.go @@ -193,7 +193,7 @@ func (g *GCSGetter) getObject(ctx context.Context, client *storage.Client, dst, } func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err error) { - if strings.Contains(u.Host, "googleapis.com") { + if strings.HasSuffix(u.Host, ".googleapis.com") { hostParts := strings.Split(u.Host, ".") if len(hostParts) != 3 { err = fmt.Errorf("URL is not a valid GCS URL") @@ -208,6 +208,8 @@ func (g *GCSGetter) parseURL(u *url.URL) (bucket, path, fragment string, err err bucket = pathParts[3] path = pathParts[4] fragment = u.Fragment + } else { + err = fmt.Errorf("URL is not a valid GCS URL") } return } diff --git a/get_gcs_test.go b/get_gcs_test.go index 8f1e3c65..340c5a60 100644 --- a/get_gcs_test.go +++ b/get_gcs_test.go @@ -233,3 +233,59 @@ func TestGCSGetter_GetFile_OAuthAccessToken(t *testing.T) { } assertContents(t, dst, "# Main\n") } + +func Test_GCSGetter_ParseUrl(t *testing.T) { + tests := []struct { + name string + url string + }{ + { + name: "valid host", + url: "https://www.googleapis.com/storage/v1/hc-go-getter-test/go-getter/foobar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := new(GCSGetter) + u, err := url.Parse(tt.url) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + _, _, _, err = g.parseURL(u) + if err != nil { + t.Fatalf("wasn't expecting error, got %s", err) + } + }) + } +} +func Test_GCSGetter_ParseUrl_Malformed(t *testing.T) { + tests := []struct { + name string + url string + }{ + { + name: "invalid host suffix", + url: "https://www.googleapis.com.invalid", + }, + { + name: "host suffix with a typo", + url: "https://www.googleapi.com.", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := new(GCSGetter) + u, err := url.Parse(tt.url) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + _, _, _, err = g.parseURL(u) + if err == nil { + t.Fatalf("expected error, got none") + } + if err.Error() != "URL is not a valid GCS URL" { + t.Fatalf("expected error 'URL is not a valid GCS URL', got %s", err.Error()) + } + }) + } +} diff --git a/get_s3.go b/get_s3.go index 1da3e2a0..a4148af0 100644 --- a/get_s3.go +++ b/get_s3.go @@ -252,7 +252,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c // This just check whether we are dealing with S3 or // any other S3 compliant service. S3 has a predictable // url as others do not - if strings.Contains(u.Host, "amazonaws.com") { + if strings.HasSuffix(u.Host, ".amazonaws.com") { // Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated // In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use. // The same bucket could be reached with: @@ -304,7 +304,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c path = pathParts[1] } - if len(hostParts) < 3 && len(hostParts) > 5 { + if len(hostParts) < 3 || len(hostParts) > 5 { err = fmt.Errorf("URL is not a valid S3 URL") return } diff --git a/get_s3_test.go b/get_s3_test.go index c342d250..aa17fed8 100644 --- a/get_s3_test.go +++ b/get_s3_test.go @@ -278,26 +278,40 @@ func TestS3Getter_Url(t *testing.T) { func Test_S3Getter_ParseUrl_Malformed(t *testing.T) { tests := []struct { - name string - url string + name string + input string + expected string }{ { - name: "path style", - url: "https://s3.amazonaws.com/bucket", + name: "path style", + input: "https://s3.amazonaws.com/bucket", + expected: "URL is not a valid S3 URL", }, { - name: "vhost-style, dash region indication", - url: "https://bucket.s3-us-east-1.amazonaws.com", + name: "vhost-style, dash region indication", + input: "https://bucket.s3-us-east-1.amazonaws.com", + expected: "URL is not a valid S3 URL", }, { - name: "vhost-style, dot region indication", - url: "https://bucket.s3.us-east-1.amazonaws.com", + name: "vhost-style, dot region indication", + input: "https://bucket.s3.us-east-1.amazonaws.com", + expected: "URL is not a valid S3 URL", + }, + { + name: "invalid host parts", + input: "https://invalid.host.parts.lenght.s3.us-east-1.amazonaws.com", + expected: "URL is not a valid S3 URL", + }, + { + name: "invalid host suffix", + input: "https://bucket.s3.amazonaws.com.invalid", + expected: "URL is not a valid S3 compliant URL", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { g := new(S3Getter) - u, err := url.Parse(tt.url) + u, err := url.Parse(tt.input) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -305,8 +319,8 @@ func Test_S3Getter_ParseUrl_Malformed(t *testing.T) { if err == nil { t.Fatalf("expected error, got none") } - if err.Error() != "URL is not a valid S3 URL" { - t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error()) + if err.Error() != tt.expected { + t.Fatalf("expected error '%s', got %s for %s", tt.expected, err.Error(), tt.name) } }) }