Skip to content

Commit

Permalink
feat(apmgrpc): wrap the server-stream with transaction-ctx (#1151)
Browse files Browse the repository at this point in the history
* feat(apmgrpc): wrap the server-stream with transaction-ctx

This allows the handler to retrieve the transaction from it's context, to e.g. create spans.

* define wrappedServerStream instead of importing from go-grpc-middleware
  • Loading branch information
bendiktv2 authored Nov 16, 2021
1 parent ee25156 commit 98763b8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
2 changes: 0 additions & 2 deletions module/apmgrpc/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHi
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.1.27 h1:nqDD4MMMQA0lmWq03Z2/myGPYLQoXtmi0rGVs95ntbo=
Expand Down
22 changes: 21 additions & 1 deletion module/apmgrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ func NewStreamServerInterceptor(o ...ServerOption) grpc.StreamServerInterceptor
tx, ctx := startTransaction(ctx, opts.tracer, info.FullMethod)
defer tx.End()

wrapped := wrapServerStream(stream)
wrapped.wrappedContext = ctx

// TODO(axw) define span context schema for RPC,
// including at least the peer address.

Expand All @@ -153,7 +156,7 @@ func NewStreamServerInterceptor(o ...ServerOption) grpc.StreamServerInterceptor
}
setTransactionResult(tx, err)
}()
return handler(srv, stream)
return handler(srv, wrapped)
}
}

Expand Down Expand Up @@ -311,3 +314,20 @@ func WithServerStreamIgnorer(s StreamIgnorerFunc) ServerOption {
o.streamIgnorer = s
}
}

// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
type wrappedServerStream struct {
grpc.ServerStream
// wrappedContext is the wrapper's own Context. You can assign it.
wrappedContext context.Context
}

// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
func (w *wrappedServerStream) Context() context.Context {
return w.wrappedContext
}

// wrapServerStream returns a ServerStream that has the ability to overwrite context.
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
}
12 changes: 11 additions & 1 deletion module/apmgrpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestServerStream(t *testing.T) {
tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

s, _, addr := newAccumulatorServer(t, tracer, apmgrpc.WithRecovery())
s, accumulatorServer, addr := newAccumulatorServer(t, tracer, apmgrpc.WithRecovery())
defer s.GracefulStop()

conn, client := newAccumulatorClient(t, addr)
Expand Down Expand Up @@ -290,6 +290,13 @@ func TestServerStream(t *testing.T) {
tracer.Flush(nil)
transactions := transport.Payloads().Transactions
require.Len(t, transactions, 1)

// The transaction should have propagated into the accumulatorServer
require.NotNil(t, accumulatorServer.transactionFromContext)
expectedTraceID := fmt.Sprintf("%x", transactions[0].TraceID)
actualTraceID := accumulatorServer.transactionFromContext.TraceContext().Trace.String()
require.NotEmpty(t, expectedTraceID)
require.Equal(t, expectedTraceID, actualTraceID)
}

func TestServerTLS(t *testing.T) {
Expand Down Expand Up @@ -457,9 +464,12 @@ func (s *helloworldServer) SayHello(ctx context.Context, req *pb.HelloRequest) (
type accumulator struct {
panic bool
err error

transactionFromContext *apm.Transaction
}

func (a *accumulator) Accumulate(srv testservice.Accumulator_AccumulateServer) error {
a.transactionFromContext = apm.TransactionFromContext(srv.Context())
if a.panic {
panic(a.err)
}
Expand Down

0 comments on commit 98763b8

Please sign in to comment.