Skip to content

Commit

Permalink
send traceparent header to ES (#1002)
Browse files Browse the repository at this point in the history
* send traceparent header to ES

* propagate tracestate header

* propagate headers for non-sampled tx/dropped span

Propagate the parent transaction's traceparent and
tracestate headers when there's a non-sampled
transaction in the context, or if the span is
dropped

* fix nil pointer error

* do not set traceheaders when transaction is nil
  • Loading branch information
stuartnelson3 authored Aug 11, 2021
1 parent 2045b7b commit fbf6fb2
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 4 deletions.
12 changes: 11 additions & 1 deletion module/apmelasticsearch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,26 @@ type roundTripper struct {
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
tx := apm.TransactionFromContext(ctx)
if tx == nil || !tx.Sampled() {
if tx == nil {
return r.r.RoundTrip(req)
}
traceContext := tx.TraceContext()
if !tx.Sampled() {
apmhttp.SetHeaders(req, traceContext, false)
return r.r.RoundTrip(req)
}

propagateLegacyHeader := tx.ShouldPropagateLegacyHeader()
name := requestName(req)
span := tx.StartSpan(name, "db.elasticsearch", apm.SpanFromContext(ctx))

if span.Dropped() {
span.End()
apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader)
return r.r.RoundTrip(req)
}

traceContext = span.TraceContext()
statement, req := captureSearchStatement(req)
username, _, _ := req.BasicAuth()
ctx = apm.ContextWithSpan(ctx, span)
Expand All @@ -89,6 +98,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
User: username,
})

apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader)
resp, err := r.r.RoundTrip(req)
if err != nil {
span.End()
Expand Down
113 changes: 113 additions & 0 deletions module/apmelasticsearch/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/context/ctxhttp"

"go.elastic.co/apm"
"go.elastic.co/apm/apmtest"
"go.elastic.co/apm/model"
"go.elastic.co/apm/module/apmelasticsearch"
"go.elastic.co/apm/module/apmhttp"
"go.elastic.co/apm/transport/transporttest"
)

func TestWrapRoundTripper(t *testing.T) {
Expand Down Expand Up @@ -303,6 +306,116 @@ func TestDestination(t *testing.T) {
test("http://[2001:db8::1]:80/_search", "2001:db8::1", 80)
}

func TestTraceHeaders(t *testing.T) {
headers := make(map[string]string)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
for k, vs := range req.Header {
headers[k] = strings.Join(vs, " ")
}
}))
defer server.Close()
client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)}

req, err := http.NewRequest("GET", server.URL, nil)
require.NoError(t, err)

_, _, _ = apmtest.WithTransaction(func(ctx context.Context) {
_, err := client.Do(req.WithContext(ctx))
assert.NoError(t, err)
})

assert.Contains(t, headers, apmhttp.ElasticTraceparentHeader)
assert.Contains(t, headers, apmhttp.W3CTraceparentHeader)
assert.Contains(t, headers, apmhttp.TracestateHeader)
}

func TestClientSpanDropped(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("Traceparent")))
}))
defer server.Close()

tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

tracer.SetMaxSpans(1)
tx := tracer.StartTransaction("name", "type")
ctx := apm.ContextWithTransaction(context.Background(), tx)

var responseBodies []string
for i := 0; i < 2; i++ {
body, err := doGET(ctx, server.URL)
require.NoError(t, err)
responseBodies = append(responseBodies, body)
}

tx.End()
tracer.Flush(nil)
payloads := transport.Payloads()
require.Len(t, payloads.Spans, 1)
transaction := payloads.Transactions[0]
span := payloads.Spans[0] // for first request

clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(responseBodies[0]))
require.NoError(t, err)
assert.Equal(t, span.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, span.ID, model.SpanID(clientTraceContext.Span))

clientTraceContext, err = apmhttp.ParseTraceparentHeader(string(responseBodies[1]))
require.NoError(t, err)
assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span))
}

func TestClientTransactionUnsampled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("Traceparent")))
}))
defer server.Close()

tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()
tracer.SetSampler(apm.NewRatioSampler(0)) // sample nothing

tx := tracer.StartTransaction("name", "type")
ctx := apm.ContextWithTransaction(context.Background(), tx)
body, err := doGET(ctx, server.URL)
require.NoError(t, err)

tx.End()
tracer.Flush(nil)

payloads := transport.Payloads()
require.Len(t, payloads.Transactions, 1)
require.Len(t, payloads.Spans, 0)
transaction := payloads.Transactions[0]

clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(body))
require.NoError(t, err)
assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span))
}

func doGET(ctx context.Context, url string) (string, error) {
client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
defer resp.Body.Close()

return string(body), nil
}

type readCloser struct {
io.Reader
closed bool
Expand Down
7 changes: 4 additions & 3 deletions module/apmhttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
propagateLegacyHeader := tx.ShouldPropagateLegacyHeader()
traceContext := tx.TraceContext()
if !traceContext.Options.Recorded() {
r.setHeaders(req, traceContext, propagateLegacyHeader)
SetHeaders(req, traceContext, propagateLegacyHeader)
return r.r.RoundTrip(req)
}

Expand All @@ -117,7 +117,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
span = nil
}

r.setHeaders(req, traceContext, propagateLegacyHeader)
SetHeaders(req, traceContext, propagateLegacyHeader)
resp, err := r.r.RoundTrip(req)
if span != nil {
if err != nil {
Expand All @@ -133,7 +133,8 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, err
}

func (r *roundTripper) setHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) {
// SetHeaders sets traceparent and tracestate headers on an http request.
func SetHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) {
headerValue := FormatTraceparentHeader(traceContext)
if propagateLegacyHeader {
req.Header.Set(ElasticTraceparentHeader, headerValue)
Expand Down

0 comments on commit fbf6fb2

Please sign in to comment.