diff --git a/plugin/grpctrace/interceptor.go b/plugin/grpctrace/interceptor.go index a84aec9e12d..0981954c4af 100644 --- a/plugin/grpctrace/interceptor.go +++ b/plugin/grpctrace/interceptor.go @@ -43,9 +43,29 @@ var ( messageUncompressedSizeKey = kv.Key("message.uncompressed_size") ) +type messageType string + +// Event adds an event of the messageType to the span associated with the +// passed context with id and size (if message is a proto message). +func (m messageType) Event(ctx context.Context, id int, message interface{}) { + span := trace.SpanFromContext(ctx) + if p, ok := message.(proto.Message); ok { + span.AddEvent(ctx, "message", + messageTypeKey.String(string(m)), + messageIDKey.Int(id), + messageUncompressedSizeKey.Int(proto.Size(p)), + ) + } else { + span.AddEvent(ctx, "message", + messageTypeKey.String(string(m)), + messageIDKey.Int(id), + ) + } +} + const ( - messageTypeSent = "SENT" - messageTypeReceived = "RECEIVED" + messageSent messageType = "SENT" + messageReceived messageType = "RECEIVED" ) // UnaryClientInterceptor returns a grpc.UnaryClientInterceptor suitable @@ -80,11 +100,11 @@ func UnaryClientInterceptor(tracer trace.Tracer) grpc.UnaryClientInterceptor { Inject(ctx, &metadataCopy) ctx = metadata.NewOutgoingContext(ctx, metadataCopy) - addEventForMessageSent(ctx, 1, req) + messageSent.Event(ctx, 1, req) err := invoker(ctx, method, req, reply, cc, opts...) - addEventForMessageReceived(ctx, 1, reply) + messageReceived.Event(ctx, 1, reply) if err != nil { s, _ := status.FromError(err) @@ -134,7 +154,7 @@ func (w *clientStream) RecvMsg(m interface{}) error { w.events <- streamEvent{errorEvent, err} } else { w.receivedMessageID++ - addEventForMessageReceived(w.Context(), w.receivedMessageID, m) + messageReceived.Event(w.Context(), w.receivedMessageID, m) } return err @@ -144,7 +164,7 @@ func (w *clientStream) SendMsg(m interface{}) error { err := w.ClientStream.SendMsg(m) w.sentMessageID++ - addEventForMessageSent(w.Context(), w.sentMessageID, m) + messageSent.Event(w.Context(), w.sentMessageID, m) if err != nil { w.events <- streamEvent{errorEvent, err} @@ -297,15 +317,15 @@ func UnaryServerInterceptor(tracer trace.Tracer) grpc.UnaryServerInterceptor { ) defer span.End() - addEventForMessageReceived(ctx, 1, req) + messageReceived.Event(ctx, 1, req) resp, err := handler(ctx, req) - - addEventForMessageSent(ctx, 1, resp) - if err != nil { s, _ := status.FromError(err) span.SetStatus(s.Code(), s.Message()) + messageSent.Event(ctx, 1, s.Proto()) + } else { + messageSent.Event(ctx, 1, resp) } return resp, err @@ -331,7 +351,7 @@ func (w *serverStream) RecvMsg(m interface{}) error { if err == nil { w.receivedMessageID++ - addEventForMessageReceived(w.Context(), w.receivedMessageID, m) + messageReceived.Event(w.Context(), w.receivedMessageID, m) } return err @@ -341,7 +361,7 @@ func (w *serverStream) SendMsg(m interface{}) error { err := w.ServerStream.SendMsg(m) w.sentMessageID++ - addEventForMessageSent(w.Context(), w.sentMessageID, m) + messageSent.Event(w.Context(), w.sentMessageID, m) return err } @@ -435,25 +455,3 @@ func serviceFromFullMethod(method string) string { return match[1] } - -func addEventForMessageReceived(ctx context.Context, id int, m interface{}) { - size := proto.Size(m.(proto.Message)) - - span := trace.SpanFromContext(ctx) - span.AddEvent(ctx, "message", - messageTypeKey.String(messageTypeReceived), - messageIDKey.Int(id), - messageUncompressedSizeKey.Int(size), - ) -} - -func addEventForMessageSent(ctx context.Context, id int, m interface{}) { - size := proto.Size(m.(proto.Message)) - - span := trace.SpanFromContext(ctx) - span.AddEvent(ctx, "message", - messageTypeKey.String(messageTypeSent), - messageIDKey.Int(id), - messageUncompressedSizeKey.Int(size), - ) -} diff --git a/plugin/grpctrace/interceptor_test.go b/plugin/grpctrace/interceptor_test.go index a92c7177f67..211a9c36efc 100644 --- a/plugin/grpctrace/interceptor_test.go +++ b/plugin/grpctrace/interceptor_test.go @@ -20,8 +20,12 @@ import ( "time" "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "go.opentelemetry.io/otel/api/kv" "go.opentelemetry.io/otel/api/kv/value" @@ -373,3 +377,37 @@ func TestStreamClientInterceptor(t *testing.T) { validate("RECEIVED", events[i+1].Attributes) } } + +func TestServerInterceptorError(t *testing.T) { + exp := &testExporter{spanMap: make(map[string]*export.SpanData)} + tp, err := sdktrace.NewProvider( + sdktrace.WithSyncer(exp), + sdktrace.WithConfig(sdktrace.Config{ + DefaultSampler: sdktrace.AlwaysSample(), + }), + ) + require.NoError(t, err) + + tracer := tp.Tracer("grpctrace/Server") + usi := UnaryServerInterceptor(tracer) + deniedErr := status.Error(codes.PermissionDenied, "PERMISSION_DENIED_TEXT") + handler := func(_ context.Context, _ interface{}) (interface{}, error) { + return nil, deniedErr + } + _, err = usi(context.Background(), &mockProtoMessage{}, &grpc.UnaryServerInfo{}, handler) + require.Error(t, err) + assert.Equal(t, err, deniedErr) + + span, ok := exp.spanMap[""] + if !ok { + t.Fatalf("failed to export error span") + } + assert.Equal(t, span.StatusCode, codes.PermissionDenied) + assert.Contains(t, deniedErr.Error(), span.StatusMessage) + assert.Len(t, span.MessageEvents, 2) + assert.Equal(t, []kv.KeyValue{ + kv.String("message.type", "SENT"), + kv.Int("message.id", 1), + kv.Int("message.uncompressed_size", 26), + }, span.MessageEvents[1].Attributes) +}