diff --git a/compress.go b/compress.go index 9c5b5d5..98c8832 100644 --- a/compress.go +++ b/compress.go @@ -12,6 +12,8 @@ import ( "strings" ) +const acceptEncoding string = "Accept-Encoding" + type compressResponseWriter struct { io.Writer http.ResponseWriter @@ -82,7 +84,7 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // detect what encoding to use var encoding string - for _, curEnc := range strings.Split(r.Header.Get("Accept-Encoding"), ",") { + for _, curEnc := range strings.Split(r.Header.Get(acceptEncoding), ",") { curEnc = strings.TrimSpace(curEnc) if curEnc == gzipEncoding || curEnc == flateEncoding { encoding = curEnc @@ -90,6 +92,9 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { } } + // always add Accept-Encoding to Vary to prevent intermediate caches corruption + w.Header().Add("Vary", acceptEncoding) + // if we weren't able to identify an encoding we're familiar with, pass on the // request to the handler and return if encoding == "" { @@ -107,8 +112,7 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { defer encWriter.Close() w.Header().Set("Content-Encoding", encoding) - r.Header.Del("Accept-Encoding") - w.Header().Add("Vary", "Accept-Encoding") + r.Header.Del(acceptEncoding) hijacker, ok := w.(http.Hijacker) if !ok { /* w is not Hijacker... oh well... */ diff --git a/compress_test.go b/compress_test.go index e42ff54..dbce929 100644 --- a/compress_test.go +++ b/compress_test.go @@ -26,7 +26,7 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) { })).ServeHTTP(w, &http.Request{ Method: "GET", Header: http.Header{ - "Accept-Encoding": []string{compression}, + acceptEncoding: []string{compression}, }, }) @@ -47,6 +47,9 @@ func TestCompressHandlerNoCompression(t *testing.T) { if l := w.HeaderMap.Get("Content-Length"); l != "9216" { t.Errorf("wrong content-length. got %q expected %d", l, 1024*9) } + if v := w.HeaderMap.Get("Vary"); v != acceptEncoding { + t.Errorf("wrong vary. got %s expected %s", v, acceptEncoding) + } } func TestAcceptEncodingIsDropped(t *testing.T) { @@ -90,7 +93,7 @@ func TestAcceptEncodingIsDropped(t *testing.T) { for _, tCase := range tCases { ch := CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptEnc := r.Header.Get("Accept-Encoding") + acceptEnc := r.Header.Get(acceptEncoding) if acceptEnc == "" && tCase.isPresent { t.Fatalf("%s: expected 'Accept-Encoding' header to be present but was not", tCase.name) } @@ -108,7 +111,7 @@ func TestAcceptEncodingIsDropped(t *testing.T) { ch.ServeHTTP(w, &http.Request{ Method: "GET", Header: http.Header{ - "Accept-Encoding": []string{tCase.compression}, + acceptEncoding: []string{tCase.compression}, }, }) } @@ -189,7 +192,7 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) { _ http.Hijacker = fullyFeaturedResponseWriter{} ) var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - comp := r.Header.Get("Accept-Encoding") + comp := r.Header.Get(acceptEncoding) if _, ok := rw.(*compressResponseWriter); !ok { t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw) } @@ -211,9 +214,9 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) { if err != nil { t.Fatalf("Failed to create test request: %v", err) } - r.Header.Set("Accept-Encoding", "gzip") + r.Header.Set(acceptEncoding, "gzip") h.ServeHTTP(rw, r) - r.Header.Set("Accept-Encoding", "deflate") + r.Header.Set(acceptEncoding, "deflate") h.ServeHTTP(rw, r) }