From 537f38a110f0437a68eaa021c88e47c68f2ed211 Mon Sep 17 00:00:00 2001 From: Damien Mathieu <42@dmathieu.com> Date: Tue, 3 Jan 2023 18:27:08 +0100 Subject: [PATCH] Refactor handler_test tests into subtests to reduce repetition (#3003) * refactor handler_test tests into subtests to reduce repetition * provide testing.T to the handler Co-authored-by: Chester Cheung Co-authored-by: Tyler Yahn --- .../net/http/otelhttp/handler_test.go | 117 ++++++++++-------- 1 file changed, 67 insertions(+), 50 deletions(-) diff --git a/instrumentation/net/http/otelhttp/handler_test.go b/instrumentation/net/http/otelhttp/handler_test.go index 1a2683d348f..78c23afd79e 100644 --- a/instrumentation/net/http/otelhttp/handler_test.go +++ b/instrumentation/net/http/otelhttp/handler_test.go @@ -18,6 +18,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -26,56 +27,72 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) -func TestResponseWriterImplementsFlusher(t *testing.T) { - h := otelhttp.NewHandler( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Implements(t, (*http.Flusher)(nil), w) - }), "test_handler", - ) +func TestHandler(t *testing.T) { + testCases := []struct { + name string + handler func(*testing.T) http.Handler + requestBody io.Reader + expectedStatusCode int + }{ + { + name: "implements flusher", + handler: func(t *testing.T) http.Handler { + return otelhttp.NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Implements(t, (*http.Flusher)(nil), w) + }), "test_handler", + ) + }, + expectedStatusCode: 200, + }, + { + name: "succeeds", + handler: func(t *testing.T) http.Handler { + return otelhttp.NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.NotNil(t, r.Body) - r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil) - require.NoError(t, err) + b, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, "hello world", string(b)) + }), "test_handler", + ) + }, + requestBody: strings.NewReader("hello world"), + expectedStatusCode: 200, + }, + { + name: "succeeds with a nil body", + handler: func(t *testing.T) http.Handler { + return otelhttp.NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Nil(t, r.Body) + }), "test_handler", + ) + }, + expectedStatusCode: 200, + }, + { + name: "succeeds with an http.NoBody", + handler: func(t *testing.T) http.Handler { + return otelhttp.NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.NoBody, r.Body) + }), "test_handler", + ) + }, + requestBody: http.NoBody, + expectedStatusCode: 200, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://localhost/", tc.requestBody) + require.NoError(t, err) - h.ServeHTTP(httptest.NewRecorder(), r) -} - -// This use case is important as we make sure the body isn't mutated -// when it is nil. This is a common use case for tests where the request -// is directly passed to the handler. -func TestHandlerReadingNilBodySuccess(t *testing.T) { - h := otelhttp.NewHandler( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Body != nil { - _, err := io.ReadAll(r.Body) - assert.NoError(t, err) - } - }), "test_handler", - ) - - r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil) - require.NoError(t, err) - - rr := httptest.NewRecorder() - h.ServeHTTP(rr, r) - assert.Equal(t, 200, rr.Result().StatusCode) -} - -// This use case is important as we make sure the body isn't mutated -// when it is NoBody. -func TestHandlerReadingNoBodySuccess(t *testing.T) { - h := otelhttp.NewHandler( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Body != http.NoBody { - _, err := io.ReadAll(r.Body) - assert.NoError(t, err) - } - }), "test_handler", - ) - - r, err := http.NewRequest(http.MethodGet, "http://localhost/", http.NoBody) - require.NoError(t, err) - - rr := httptest.NewRecorder() - h.ServeHTTP(rr, r) - assert.Equal(t, 200, rr.Result().StatusCode) + rr := httptest.NewRecorder() + tc.handler(t).ServeHTTP(rr, r) + assert.Equal(t, tc.expectedStatusCode, rr.Result().StatusCode) + }) + } }