Skip to content

Commit

Permalink
gzhttp: Allow overriding decompression on transport (#892)
Browse files Browse the repository at this point in the history
This allows getting compressed data even if `Content-Encoding` is set.

Also allows decompression even if "Accept-Encoding" was not set by this client.
  • Loading branch information
klauspost authored Nov 29, 2023
1 parent c63a492 commit 98ff542
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 4 deletions.
25 changes: 21 additions & 4 deletions gzhttp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/klauspost/compress/zstd"
)

// Transport will wrap a transport with a custom handler
// Transport will wrap an HTTP transport with a custom handler
// that will request gzip and automatically decompress it.
// Using this is significantly faster than using the default transport.
func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
Expand Down Expand Up @@ -51,10 +51,21 @@ func TransportEnableGzip(b bool) transportOption {
}
}

// TransportCustomEval will send the header of a response to a custom function.
// If the function returns false, the response will be returned as-is,
// Otherwise it will be decompressed based on Content-Encoding field, regardless
// of whether the transport added the encoding.
func TransportCustomEval(fn func(header http.Header) bool) transportOption {
return func(c *gzRoundtripper) {
c.customEval = fn
}
}

type gzRoundtripper struct {
parent http.RoundTripper
acceptEncoding string
withZstd, withGzip bool
customEval func(header http.Header) bool
}

func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -82,16 +93,22 @@ func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil || !requestedComp {
return resp, err
}

decompress := false
if g.customEval != nil {
if !g.customEval(resp.Header) {
return resp, nil
}
decompress = true
}
// Decompress
if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
resp.Body = &gzipReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}
if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
if (decompress || g.withZstd) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
resp.Body = &zstdReader{body: resp.Body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
Expand Down
45 changes: 45 additions & 0 deletions gzhttp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,51 @@ func TestDefaultTransport(t *testing.T) {
}
}

func TestTransportCustomEval(t *testing.T) {
bin, err := os.ReadFile("testdata/benchmark.json")
if err != nil {
t.Fatal(err)
}

server := httptest.NewServer(newTestHandler(bin))
calledWith := ""
c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportCustomEval(func(h http.Header) bool {
calledWith = h.Get("Content-Encoding")
return true
}))}
resp, err := c.Get(server.URL)
if err != nil {
t.Fatal(err)
}
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, bin) {
t.Errorf("data mismatch")
}
if calledWith != "gzip" {
t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
}
// Test returning false
c = http.Client{Transport: Transport(http.DefaultTransport, TransportCustomEval(func(h http.Header) bool {
calledWith = h.Get("Content-Encoding")
return false
}))}
resp, err = c.Get(server.URL)
if err != nil {
t.Fatal(err)
}
// Check we got the compressed data
gotCE := resp.Header.Get("Content-Encoding")
if gotCE != "gzip" {
t.Fatalf("Expected encoding %q, got %q", "gzip", gotCE)
}
if calledWith != "gzip" {
t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
}
}

func BenchmarkTransport(b *testing.B) {
raw, err := os.ReadFile("testdata/benchmark.json")
if err != nil {
Expand Down

0 comments on commit 98ff542

Please sign in to comment.