From a4043c62cc2329bacda331d33fc908ab11ef0ec3 Mon Sep 17 00:00:00 2001 From: Mo Date: Tue, 23 May 2017 23:41:27 +0800 Subject: [PATCH] use http.StatusOK as initial value for responseLogger.status (#103) --- handlers.go | 8 ++------ handlers_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/handlers.go b/handlers.go index 551f8ff..75db7f8 100644 --- a/handlers.go +++ b/handlers.go @@ -79,9 +79,9 @@ func (h combinedLoggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } func makeLogger(w http.ResponseWriter) loggingResponseWriter { - var logger loggingResponseWriter = &responseLogger{w: w} + var logger loggingResponseWriter = &responseLogger{w: w, status: http.StatusOK} if _, ok := w.(http.Hijacker); ok { - logger = &hijackLogger{responseLogger{w: w}} + logger = &hijackLogger{responseLogger{w: w, status: http.StatusOK}} } h, ok1 := logger.(http.Hijacker) c, ok2 := w.(http.CloseNotifier) @@ -114,10 +114,6 @@ func (l *responseLogger) Header() http.Header { } func (l *responseLogger) Write(b []byte) (int, error) { - if l.status == 0 { - // The status will be StatusOK if WriteHeader has not been called yet - l.status = http.StatusOK - } size, err := l.w.Write(b) l.size += size return size, err diff --git a/handlers_test.go b/handlers_test.go index 6ea7c7f..04ee244 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -73,6 +73,30 @@ func TestMethodHandler(t *testing.T) { } } +func TestMakeLogger(t *testing.T) { + rec := httptest.NewRecorder() + logger := makeLogger(rec) + // initial status + if logger.Status() != http.StatusOK { + t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusOK) + } + // WriteHeader + logger.WriteHeader(http.StatusInternalServerError) + if logger.Status() != http.StatusInternalServerError { + t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusInternalServerError) + } + // Write + logger.Write([]byte(ok)) + if logger.Size() != len(ok) { + t.Fatalf("wrong size, got %d want %d", logger.Size(), len(ok)) + } + // Header + logger.Header().Set("key", "value") + if val := logger.Header().Get("key"); val != "value" { + t.Fatalf("wrong header, got %s want %s", val, "value") + } +} + func TestWriteLog(t *testing.T) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil {