From b0bdbd9b5e2ac9e3007828ad8b86c60361ce288a Mon Sep 17 00:00:00 2001 From: Christopher Brown Date: Fri, 22 Jun 2018 16:34:46 -0500 Subject: [PATCH] requested refactor --- imageproxy.go | 35 +++++++++++++++-------------------- imageproxy_test.go | 36 ++++++++++++++---------------------- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/imageproxy.go b/imageproxy.go index 0cb04329b..0d27c2454 100644 --- a/imageproxy.go +++ b/imageproxy.go @@ -168,13 +168,13 @@ func (p *Proxy) serveImage(w http.ResponseWriter, r *http.Request) { return } - if contentType := p.allowedContentType(resp.Header.Get("Content-Type")); contentType != "" { - w.Header().Set("Content-Type", contentType) - w.Header().Set("X-Content-Type-Options", "nosniff") - } else { + contentType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if !validContentType(p.ContentTypes, contentType) { http.Error(w, "forbidden content-type", http.StatusForbidden) return } + w.Header().Set("Content-Type", contentType) + w.Header().Set("X-Content-Type-Options", "nosniff") copyHeader(w.Header(), resp.Header, "Content-Length") @@ -225,16 +225,11 @@ func (p *Proxy) allowed(r *Request) error { return fmt.Errorf("request does not contain an allowed host or valid signature: %v", r) } -// allowedContentType returns an allowed content type string to use in responses or "" if the -// content type cannot be used. -func (p *Proxy) allowedContentType(contentType string) string { - mediaType, _, _ := mime.ParseMediaType(contentType) - if mediaType == "" { - return "" - } - - if len(p.ContentTypes) == 0 { - switch mediaType { +// validContentType returns whether contentType matches one of the patterns. If no patterns are +// given, validContentType returns true if contentType matches one of the whitelisted image types. +func validContentType(patterns []string, contentType string) bool { + if len(patterns) == 0 { + switch contentType { case "image/bmp", "image/cgm", "image/g3fax", "image/gif", "image/ief", "image/jp2", "image/jpeg", "image/jpg", "image/pict", "image/png", "image/prs.btif", "image/svg+xml", "image/tiff", "image/vnd.adobe.photoshop", "image/vnd.djvu", "image/vnd.dwg", @@ -246,18 +241,18 @@ func (p *Proxy) allowedContentType(contentType string) string { "image/x-portable-bitmap", "image/x-portable-graymap", "image/x-portable-pixmap", "image/x-quicktime", "image/x-rgb", "image/x-xbitmap", "image/x-xpixmap", "image/x-xwindowdump": - return mediaType + return true } - return "" + return false } - for _, pattern := range p.ContentTypes { - if ok, err := filepath.Match(pattern, mediaType); ok && err == nil { - return mediaType + for _, pattern := range patterns { + if ok, err := filepath.Match(pattern, contentType); ok && err == nil { + return true } } - return "" + return false } // validHost returns whether the host in u matches one of hosts. diff --git a/imageproxy_test.go b/imageproxy_test.go index 84a3c516b..e084aebc1 100644 --- a/imageproxy_test.go +++ b/imageproxy_test.go @@ -412,36 +412,28 @@ func TestTransformingTransport(t *testing.T) { } } -func TestAllowedContentType(t *testing.T) { - p := &Proxy{} - - for contentType, expected := range map[string]string{ - "": "", - "image/png": "image/png", - "image/PNG": "image/png", - "image/PNG; foo=bar": "image/png", - "text/html": "", +func TestValidContentType(t *testing.T) { + for contentType, expected := range map[string]bool{ + "": false, + "image/png": true, + "text/html": false, } { - actual := p.allowedContentType(contentType) + actual := validContentType(nil, contentType) if actual != expected { t.Errorf("got %v, expected %v for content type: %v", actual, expected, contentType) } } } -func TestAllowedContentType_Whitelist(t *testing.T) { - p := &Proxy{ - ContentTypes: []string{"foo/*", "bar/baz"}, - } - - for contentType, expected := range map[string]string{ - "": "", - "image/png": "", - "foo/asdf": "foo/asdf", - "bar/baz": "bar/baz", - "bar/bazz": "", +func TestValidContentType_Patterns(t *testing.T) { + for contentType, expected := range map[string]bool{ + "": false, + "image/png": false, + "foo/asdf": true, + "bar/baz": true, + "bar/bazz": false, } { - actual := p.allowedContentType(contentType) + actual := validContentType([]string{"foo/*", "bar/baz"}, contentType) if actual != expected { t.Errorf("got %v, expected %v for content type: %v", actual, expected, contentType) }