diff --git a/module/apmelasticsearch/client.go b/module/apmelasticsearch/client.go index e5f49876c..3c9a24f59 100644 --- a/module/apmelasticsearch/client.go +++ b/module/apmelasticsearch/client.go @@ -63,17 +63,26 @@ type roundTripper struct { func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() tx := apm.TransactionFromContext(ctx) - if tx == nil || !tx.Sampled() { + if tx == nil { + return r.r.RoundTrip(req) + } + traceContext := tx.TraceContext() + if !tx.Sampled() { + apmhttp.SetHeaders(req, traceContext, false) return r.r.RoundTrip(req) } + propagateLegacyHeader := tx.ShouldPropagateLegacyHeader() name := requestName(req) span := tx.StartSpan(name, "db.elasticsearch", apm.SpanFromContext(ctx)) + if span.Dropped() { span.End() + apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader) return r.r.RoundTrip(req) } + traceContext = span.TraceContext() statement, req := captureSearchStatement(req) username, _, _ := req.BasicAuth() ctx = apm.ContextWithSpan(ctx, span) @@ -89,6 +98,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { User: username, }) + apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader) resp, err := r.r.RoundTrip(req) if err != nil { span.End() diff --git a/module/apmelasticsearch/client_test.go b/module/apmelasticsearch/client_test.go index 593e09d36..aa15dcf4d 100644 --- a/module/apmelasticsearch/client_test.go +++ b/module/apmelasticsearch/client_test.go @@ -34,9 +34,12 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/context/ctxhttp" + "go.elastic.co/apm" "go.elastic.co/apm/apmtest" "go.elastic.co/apm/model" "go.elastic.co/apm/module/apmelasticsearch" + "go.elastic.co/apm/module/apmhttp" + "go.elastic.co/apm/transport/transporttest" ) func TestWrapRoundTripper(t *testing.T) { @@ -303,6 +306,116 @@ func TestDestination(t *testing.T) { test("http://[2001:db8::1]:80/_search", "2001:db8::1", 80) } +func TestTraceHeaders(t *testing.T) { + headers := make(map[string]string) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + for k, vs := range req.Header { + headers[k] = strings.Join(vs, " ") + } + })) + defer server.Close() + client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)} + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + + _, _, _ = apmtest.WithTransaction(func(ctx context.Context) { + _, err := client.Do(req.WithContext(ctx)) + assert.NoError(t, err) + }) + + assert.Contains(t, headers, apmhttp.ElasticTraceparentHeader) + assert.Contains(t, headers, apmhttp.W3CTraceparentHeader) + assert.Contains(t, headers, apmhttp.TracestateHeader) +} + +func TestClientSpanDropped(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte(req.Header.Get("Traceparent"))) + })) + defer server.Close() + + tracer, transport := transporttest.NewRecorderTracer() + defer tracer.Close() + + tracer.SetMaxSpans(1) + tx := tracer.StartTransaction("name", "type") + ctx := apm.ContextWithTransaction(context.Background(), tx) + + var responseBodies []string + for i := 0; i < 2; i++ { + body, err := doGET(ctx, server.URL) + require.NoError(t, err) + responseBodies = append(responseBodies, body) + } + + tx.End() + tracer.Flush(nil) + payloads := transport.Payloads() + require.Len(t, payloads.Spans, 1) + transaction := payloads.Transactions[0] + span := payloads.Spans[0] // for first request + + clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(responseBodies[0])) + require.NoError(t, err) + assert.Equal(t, span.TraceID, model.TraceID(clientTraceContext.Trace)) + assert.Equal(t, span.ID, model.SpanID(clientTraceContext.Span)) + + clientTraceContext, err = apmhttp.ParseTraceparentHeader(string(responseBodies[1])) + require.NoError(t, err) + assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace)) + assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span)) +} + +func TestClientTransactionUnsampled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte(req.Header.Get("Traceparent"))) + })) + defer server.Close() + + tracer, transport := transporttest.NewRecorderTracer() + defer tracer.Close() + tracer.SetSampler(apm.NewRatioSampler(0)) // sample nothing + + tx := tracer.StartTransaction("name", "type") + ctx := apm.ContextWithTransaction(context.Background(), tx) + body, err := doGET(ctx, server.URL) + require.NoError(t, err) + + tx.End() + tracer.Flush(nil) + + payloads := transport.Payloads() + require.Len(t, payloads.Transactions, 1) + require.Len(t, payloads.Spans, 0) + transaction := payloads.Transactions[0] + + clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(body)) + require.NoError(t, err) + assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace)) + assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span)) +} + +func doGET(ctx context.Context, url string) (string, error) { + client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)} + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return "", err + } + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + defer resp.Body.Close() + + return string(body), nil +} + type readCloser struct { io.Reader closed bool diff --git a/module/apmhttp/client.go b/module/apmhttp/client.go index e29846502..f48ca4877 100644 --- a/module/apmhttp/client.go +++ b/module/apmhttp/client.go @@ -97,7 +97,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { propagateLegacyHeader := tx.ShouldPropagateLegacyHeader() traceContext := tx.TraceContext() if !traceContext.Options.Recorded() { - r.setHeaders(req, traceContext, propagateLegacyHeader) + SetHeaders(req, traceContext, propagateLegacyHeader) return r.r.RoundTrip(req) } @@ -117,7 +117,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { span = nil } - r.setHeaders(req, traceContext, propagateLegacyHeader) + SetHeaders(req, traceContext, propagateLegacyHeader) resp, err := r.r.RoundTrip(req) if span != nil { if err != nil { @@ -133,7 +133,8 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return resp, err } -func (r *roundTripper) setHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) { +// SetHeaders sets traceparent and tracestate headers on an http request. +func SetHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) { headerValue := FormatTraceparentHeader(traceContext) if propagateLegacyHeader { req.Header.Set(ElasticTraceparentHeader, headerValue)