Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix http.Flusher and io.ReaderFrom implementation #923

113 changes: 25 additions & 88 deletions http/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,22 @@
return i.w.Header()
}

var _ http.ResponseWriter = (*rwInterceptor)(nil)
func (i *rwInterceptor) ReadFrom(r io.Reader) (n int64, err error) {
return io.Copy(i, r)
}

func (i *rwInterceptor) Flush() {
// coraza middleware always needs to buffer the entire request, response cycle
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment made me take a closer look at http.Flusher and there is one thing that can and should be in this implementation.

WriteHeader must be called if it wasn't called before with a status code of 200.

Thank you for pointing out that the was room for improvement here.

// we can not flush early
}

type responseWriter interface {
http.ResponseWriter
io.ReaderFrom
http.Flusher
}

var _ responseWriter = (*rwInterceptor)(nil)

// wrap wraps the interceptor into a response writer that also preserves
// the http interfaces implemented by the original response writer to avoid
Expand Down Expand Up @@ -168,110 +183,32 @@
var (
hijacker, isHijacker = i.w.(http.Hijacker)
pusher, isPusher = i.w.(http.Pusher)
flusher, isFlusher = i.w.(http.Flusher)
reader, isReader = i.w.(io.ReaderFrom)
)

switch {
case !isHijacker && !isPusher && !isFlusher && !isReader:
case !isHijacker && !isPusher:
return struct {
http.ResponseWriter
responseWriter
}{i}, responseProcessor
case !isHijacker && !isPusher && !isFlusher && isReader:
return struct {
http.ResponseWriter
io.ReaderFrom
}{i, reader}, responseProcessor
case !isHijacker && !isPusher && isFlusher && !isReader:
return struct {
http.ResponseWriter
http.Flusher
}{i, flusher}, responseProcessor
case !isHijacker && !isPusher && isFlusher && isReader:
case !isHijacker && isPusher:
return struct {
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{i, flusher, reader}, responseProcessor
case !isHijacker && isPusher && !isFlusher && !isReader:
return struct {
http.ResponseWriter
responseWriter
http.Pusher
}{i, pusher}, responseProcessor
case !isHijacker && isPusher && !isFlusher && isReader:
return struct {
http.ResponseWriter
http.Pusher
io.ReaderFrom
}{i, pusher, reader}, responseProcessor
case !isHijacker && isPusher && isFlusher && !isReader:
return struct {
http.ResponseWriter
http.Pusher
http.Flusher
}{i, pusher, flusher}, responseProcessor
case !isHijacker && isPusher && isFlusher && isReader:
return struct {
http.ResponseWriter
http.Pusher
http.Flusher
io.ReaderFrom
}{i, pusher, flusher, reader}, responseProcessor
case isHijacker && !isPusher && !isFlusher && !isReader:
case isHijacker && !isPusher:
return struct {
http.ResponseWriter
responseWriter
http.Hijacker
}{i, hijacker}, responseProcessor
case isHijacker && !isPusher && !isFlusher && isReader:
case isHijacker && isPusher:

Check warning on line 203 in http/interceptor.go

View check run for this annotation

Codecov / codecov/patch

http/interceptor.go#L203

Added line #L203 was not covered by tests
return struct {
http.ResponseWriter
http.Hijacker
io.ReaderFrom
}{i, hijacker, reader}, responseProcessor
case isHijacker && !isPusher && isFlusher && !isReader:
return struct {
http.ResponseWriter
http.Hijacker
http.Flusher
}{i, hijacker, flusher}, responseProcessor
case isHijacker && !isPusher && isFlusher && isReader:
return struct {
http.ResponseWriter
http.Hijacker
http.Flusher
io.ReaderFrom
}{i, hijacker, flusher, reader}, responseProcessor
case isHijacker && isPusher && !isFlusher && !isReader:
return struct {
http.ResponseWriter
responseWriter

Check warning on line 205 in http/interceptor.go

View check run for this annotation

Codecov / codecov/patch

http/interceptor.go#L205

Added line #L205 was not covered by tests
http.Hijacker
http.Pusher
}{i, hijacker, pusher}, responseProcessor
case isHijacker && isPusher && !isFlusher && isReader:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
io.ReaderFrom
}{i, hijacker, pusher, reader}, responseProcessor
case isHijacker && isPusher && isFlusher && !isReader:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
}{i, hijacker, pusher, flusher}, responseProcessor
case isHijacker && isPusher && isFlusher && isReader:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
io.ReaderFrom
}{i, hijacker, pusher, flusher, reader}, responseProcessor
default:
return struct {
http.ResponseWriter
responseWriter

Check warning on line 211 in http/interceptor.go

View check run for this annotation

Codecov / codecov/patch

http/interceptor.go#L211

Added line #L211 was not covered by tests
}{i}, responseProcessor
}
}
205 changes: 205 additions & 0 deletions http/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
package http

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -44,3 +46,206 @@ func TestWriteHeader(t *testing.T) {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}

func TestWrite(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()

rw, responseProcessor := wrap(res, req, tx)
_, err = rw.Write([]byte("hello"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}

_, err = rw.Write([]byte("world"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if want, have := 200, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}

func TestWriteWithWriteHeader(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()

rw, responseProcessor := wrap(res, req, tx)
rw.WriteHeader(204)
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
if unwanted, have := 204, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}

_, err = rw.Write([]byte("hello"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}

_, err = rw.Write([]byte("world"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if want, have := 204, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}

func TestFlush(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()
rw, responseProcessor := wrap(res, req, tx)
rw.WriteHeader(204)
rw.(http.Flusher).Flush()
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
if unwanted, have := 204, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if want, have := 204, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}

type testReaderFrom struct {
io.Writer
}

func (x *testReaderFrom) ReadFrom(r io.Reader) (n int64, err error) {
return io.Copy(x, r)
}

func TestReadFrom(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()

type responseWriter interface {
http.ResponseWriter
http.Flusher
}

resWithReaderFrom := struct {
responseWriter
io.ReaderFrom
}{
res,
&testReaderFrom{res},
}

rw, responseProcessor := wrap(resWithReaderFrom, req, tx)
rw.WriteHeader(204)
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
if unwanted, have := 204, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}

_, err = rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world")))
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if want, have := 204, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}

type testPusher struct{}

func (x *testPusher) Push(target string, opts *http.PushOptions) error {
romainmenke marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

func TestPusher(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()

type responseWriter interface {
http.ResponseWriter
http.Flusher
}

resWithPush := struct {
responseWriter
http.Pusher
}{
res,
&testPusher{},
}

rw, responseProcessor := wrap(resWithPush, req, tx)
rw.WriteHeader(204)
err = rw.(http.Pusher).Push("http://example.com", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
if unwanted, have := 204, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if want, have := 204, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}
Loading