diff --git a/component/http/handler.go b/component/http/handler.go index a3c1ff031..d8fe783cb 100644 --- a/component/http/handler.go +++ b/component/http/handler.go @@ -30,7 +30,7 @@ func handler(hnd ProcessorFunc) http.HandlerFunc { // TODO : for cached responses this becomes inconsistent, to be fixed in #160 // the corID will be passed to all consecutive responses // if it was missing from the initial request - corID := correlation.GetOrSetHeaderID(r.Header) + corID := getOrSetCorrelationID(r.Header) ctx := correlation.ContextWithID(r.Context(), corID) logger := log.Sub(map[string]interface{}{correlation.ID: corID}) ctx = log.WithContext(ctx, logger) @@ -52,6 +52,26 @@ func handler(hnd ProcessorFunc) http.HandlerFunc { } } +func getOrSetCorrelationID(h http.Header) string { + cor, ok := h[correlation.HeaderID] + if !ok { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + if len(cor) == 0 { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + if cor[0] == "" { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + return cor[0] +} + func determineEncoding(h http.Header) (string, encoding.DecodeFunc, encoding.EncodeFunc, error) { cth, cok := h[encoding.ContentTypeHeader] ach, aok := h[encoding.AcceptHeader] diff --git a/component/http/handler_test.go b/component/http/handler_test.go index 1c2155f23..00edbaeb4 100644 --- a/component/http/handler_test.go +++ b/component/http/handler_test.go @@ -9,6 +9,7 @@ import ( "reflect" "testing" + "github.com/beatlabs/patron/correlation" "github.com/beatlabs/patron/encoding" "github.com/beatlabs/patron/encoding/json" "github.com/beatlabs/patron/encoding/protobuf" @@ -204,6 +205,33 @@ func Test_handleError(t *testing.T) { } } +func Test_getOrSetCorrelationID(t *testing.T) { + t.Parallel() + withID := http.Header{correlation.HeaderID: []string{"123"}} + withoutID := http.Header{correlation.HeaderID: []string{}} + withEmptyID := http.Header{correlation.HeaderID: []string{""}} + missingHeader := http.Header{} + type args struct { + hdr http.Header + } + tests := map[string]struct { + args args + }{ + "with id": {args: args{hdr: withID}}, + "without id": {args: args{hdr: withoutID}}, + "with empty id": {args: args{hdr: withEmptyID}}, + "missing Header": {args: args{hdr: missingHeader}}, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, getOrSetCorrelationID(tt.args.hdr)) + assert.NotEmpty(t, tt.args.hdr[correlation.HeaderID][0]) + }) + } +} + type testHandler struct { err bool resp interface{} diff --git a/component/http/middleware/middleware.go b/component/http/middleware/middleware.go index 8de10c45a..b168dbc04 100644 --- a/component/http/middleware/middleware.go +++ b/component/http/middleware/middleware.go @@ -163,7 +163,7 @@ func NewAuth(auth auth.Authenticator) Func { func NewLoggingTracing(path string, statusCodeLogger StatusCodeLoggerHandler) Func { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - corID := correlation.GetOrSetHeaderID(r.Header) + corID := getOrSetCorrelationID(r.Header) sp, r := span(path, corID, r) lw := newResponseWriter(w, true) next.ServeHTTP(lw, r) @@ -180,7 +180,7 @@ func NewLoggingTracing(path string, statusCodeLogger StatusCodeLoggerHandler) Fu func NewInjectObservability() Func { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - corID := correlation.GetOrSetHeaderID(r.Header) + corID := getOrSetCorrelationID(r.Header) ctx := correlation.ContextWithID(r.Context(), corID) logger := log.Sub(map[string]interface{}{correlation.ID: corID}) ctx = log.WithContext(ctx, logger) @@ -189,6 +189,26 @@ func NewInjectObservability() Func { } } +func getOrSetCorrelationID(h http.Header) string { + cor, ok := h[correlation.HeaderID] + if !ok { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + if len(cor) == 0 { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + if cor[0] == "" { + corID := correlation.New() + h.Set(correlation.HeaderID, corID) + return corID + } + return cor[0] +} + func initHTTPServerMetrics() { httpStatusTracingHandledMetric = prometheus.NewCounterVec( prometheus.CounterOpts{ diff --git a/component/http/middleware/middleware_test.go b/component/http/middleware/middleware_test.go index 7a3537ca3..fe425f3b4 100644 --- a/component/http/middleware/middleware_test.go +++ b/component/http/middleware/middleware_test.go @@ -8,6 +8,7 @@ import ( "testing" httpcache "github.com/beatlabs/patron/component/http/cache" + "github.com/beatlabs/patron/correlation" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" @@ -829,3 +830,30 @@ func TestNewAppVersion(t *testing.T) { }) } } + +func Test_getOrSetCorrelationID(t *testing.T) { + t.Parallel() + withID := http.Header{correlation.HeaderID: []string{"123"}} + withoutID := http.Header{correlation.HeaderID: []string{}} + withEmptyID := http.Header{correlation.HeaderID: []string{""}} + missingHeader := http.Header{} + type args struct { + hdr http.Header + } + tests := map[string]struct { + args args + }{ + "with id": {args: args{hdr: withID}}, + "without id": {args: args{hdr: withoutID}}, + "with empty id": {args: args{hdr: withEmptyID}}, + "missing Header": {args: args{hdr: missingHeader}}, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + assert.NotEmpty(t, getOrSetCorrelationID(tt.args.hdr)) + assert.NotEmpty(t, tt.args.hdr[correlation.HeaderID][0]) + }) + } +} diff --git a/correlation/correlation.go b/correlation/correlation.go index 65e614ece..20619294b 100644 --- a/correlation/correlation.go +++ b/correlation/correlation.go @@ -3,7 +3,6 @@ package correlation import ( "context" - "net/http" "github.com/google/uuid" ) @@ -19,36 +18,21 @@ type idContextKey struct{} var idKey = idContextKey{} +// New correlation ID. +func New() string { + return uuid.New().String() +} + // IDFromContext returns the correlation ID from the context. // If no ID is set a new one is generated. func IDFromContext(ctx context.Context) string { if id, ok := ctx.Value(idKey).(string); ok { return id } - return uuid.New().String() + return New() } // ContextWithID sets a correlation ID to a context. func ContextWithID(ctx context.Context, correlationID string) context.Context { return context.WithValue(ctx, idKey, correlationID) } - -func GetOrSetHeaderID(h http.Header) string { - cor, ok := h[HeaderID] - if !ok { - corID := uuid.New().String() - h.Set(HeaderID, corID) - return corID - } - if len(cor) == 0 { - corID := uuid.New().String() - h.Set(HeaderID, corID) - return corID - } - if cor[0] == "" { - corID := uuid.New().String() - h.Set(HeaderID, corID) - return corID - } - return cor[0] -} diff --git a/correlation/correlation_test.go b/correlation/correlation_test.go index 1f1f8261a..1c7a729f2 100644 --- a/correlation/correlation_test.go +++ b/correlation/correlation_test.go @@ -2,7 +2,6 @@ package correlation import ( "context" - "net/http" "testing" "github.com/stretchr/testify/assert" @@ -36,30 +35,3 @@ func TestContextWithID(t *testing.T) { assert.True(t, ok) assert.Equal(t, "123", val) } - -func TestGetOrSetHeaderID(t *testing.T) { - t.Parallel() - withID := http.Header{HeaderID: []string{"123"}} - withoutID := http.Header{HeaderID: []string{}} - withEmptyID := http.Header{HeaderID: []string{""}} - missingHeader := http.Header{} - type args struct { - hdr http.Header - } - tests := map[string]struct { - args args - }{ - "with id": {args: args{hdr: withID}}, - "without id": {args: args{hdr: withoutID}}, - "with empty id": {args: args{hdr: withEmptyID}}, - "missing Header": {args: args{hdr: missingHeader}}, - } - for name, tt := range tests { - tt := tt - t.Run(name, func(t *testing.T) { - t.Parallel() - assert.NotEmpty(t, GetOrSetHeaderID(tt.args.hdr)) - assert.NotEmpty(t, tt.args.hdr[HeaderID][0]) - }) - } -}