From 56862b5638da936914e84e97459e047fbef65a1d Mon Sep 17 00:00:00 2001 From: Fred Carle Date: Sun, 8 May 2022 23:39:31 -0400 Subject: [PATCH] fix final suggestions --- api/http/errors.go | 2 +- api/http/handler.go | 6 +++--- api/http/handler_test.go | 2 +- api/http/{api.go => http.go} | 0 api/http/logger.go | 3 +++ api/http/server.go | 9 ++++++--- api/http/server_test.go | 20 +++++++++++++++++--- 7 files changed, 31 insertions(+), 11 deletions(-) rename api/http/{api.go => http.go} (100%) diff --git a/api/http/errors.go b/api/http/errors.go index 1f24141868..2d440890a8 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -28,7 +28,7 @@ type errorResponse struct { func handleErr(ctx context.Context, rw http.ResponseWriter, err error, status int) { if status == http.StatusInternalServerError { - log.ErrorE(context.Background(), http.StatusText(status), err) + log.ErrorE(ctx, http.StatusText(status), err) } sendJSON( diff --git a/api/http/handler.go b/api/http/handler.go index 147f0b9924..23cdda6d52 100644 --- a/api/http/handler.go +++ b/api/http/handler.go @@ -27,7 +27,7 @@ type handler struct { *chi.Mux } -type ctxKey string +type ctxDB struct{} // newHandler returns a handler with the router instantiated. func newHandler(db client.DB) *handler { @@ -36,7 +36,7 @@ func newHandler(db client.DB) *handler { func (h *handler) handle(f http.HandlerFunc) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { - ctx := context.WithValue(req.Context(), ctxKey("DB"), h.db) + ctx := context.WithValue(req.Context(), ctxDB{}, h.db) f(rw, req.WithContext(ctx)) } } @@ -70,7 +70,7 @@ func sendJSON(ctx context.Context, rw http.ResponseWriter, v interface{}, code i } func dbFromContext(ctx context.Context) (client.DB, error) { - db, ok := ctx.Value(ctxKey("DB")).(client.DB) + db, ok := ctx.Value(ctxDB{}).(client.DB) if !ok { return nil, errors.New("no database available") } diff --git a/api/http/handler_test.go b/api/http/handler_test.go index 663a267492..04d3a681e2 100644 --- a/api/http/handler_test.go +++ b/api/http/handler_test.go @@ -183,7 +183,7 @@ func TestDbFromContext(t *testing.T) { t.Fatal(err) } - reqCtx := context.WithValue(ctx, ctxKey("DB"), defra) + reqCtx := context.WithValue(ctx, ctxDB{}, defra) _, err = dbFromContext(reqCtx) assert.NoError(t, err) diff --git a/api/http/api.go b/api/http/http.go similarity index 100% rename from api/http/api.go rename to api/http/http.go diff --git a/api/http/logger.go b/api/http/logger.go index 56785a0231..29fdb6ddd3 100644 --- a/api/http/logger.go +++ b/api/http/logger.go @@ -39,9 +39,12 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { } func (lrw *loggingResponseWriter) Write(b []byte) (int, error) { + // used for chucked payloads. Content-Length should not be set + // for each chunk. if lrw.ResponseWriter.Header().Get("Content-Length") != "" { return lrw.ResponseWriter.Write(b) } + lrw.contentLength = len(b) lrw.ResponseWriter.Header().Set("Content-Length", strconv.Itoa(lrw.contentLength)) return lrw.ResponseWriter.Write(b) diff --git a/api/http/server.go b/api/http/server.go index bf838ef62e..f75fd28114 100644 --- a/api/http/server.go +++ b/api/http/server.go @@ -18,17 +18,20 @@ import ( // The Server struct holds the Handler for the HTTP API type Server struct { - Handler http.Handler + http.Server } // NewServer instantiated a new server with the given http.Handler. func NewServer(db client.DB) *Server { return &Server{ - Handler: newHandler(db), + http.Server{ + Handler: newHandler(db), + }, } } // Listen calls ListenAndServe with our router. func (s *Server) Listen(addr string) error { - return http.ListenAndServe(addr, s.Handler) + s.Addr = addr + return s.ListenAndServe() } diff --git a/api/http/server_test.go b/api/http/server_test.go index d4a4c48b7e..d8c1c5e4fc 100644 --- a/api/http/server_test.go +++ b/api/http/server_test.go @@ -11,17 +11,31 @@ package http import ( + "context" + "net/http" "testing" "github.com/stretchr/testify/assert" ) func TestNewServerAndListen(t *testing.T) { - // @TODO: maybe it would be worth doing something a bit more thorough - - // test with no config s := NewServer(nil) if ok := assert.NotNil(t, s); ok { assert.Error(t, s.Listen(":303000")) } + + serverRunning := make(chan struct{}) + serverDone := make(chan struct{}) + go func() { + close(serverRunning) + err := s.Listen(":3131") + assert.ErrorIs(t, http.ErrServerClosed, err) + defer close(serverDone) + }() + + <-serverRunning + + s.Shutdown(context.Background()) + + <-serverDone }