diff --git a/protocol_grpc.go b/protocol_grpc.go index 4b5e4492..a102dcdd 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -582,6 +582,17 @@ type grpcMarshaler struct { func (m *grpcMarshaler) MarshalWebTrailers(trailer http.Header) *Error { raw := m.envelopeWriter.bufferPool.Get() defer m.envelopeWriter.bufferPool.Put(raw) + for key, values := range trailer { + // Per the Go specification, keys inserted during iteration may be produced + // later in the iteration or may be skipped. For safety, avoid mutating the + // map if the key is already lower-cased. + lower := strings.ToLower(key) + if key == lower { + continue + } + delete(trailer, key) + trailer[lower] = values + } if err := trailer.Write(raw); err != nil { return errorf(CodeInternal, "format trailers: %w", err) } diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 726e9e24..8391a1f0 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -174,3 +174,23 @@ func TestGRPCPercentEncoding(t *testing.T) { roundtrip(`foo%bar`) roundtrip("fiancée") } + +func TestGRPCWebTrailerMarshalling(t *testing.T) { + t.Parallel() + responseWriter := httptest.NewRecorder() + marshaler := grpcMarshaler{ + envelopeWriter: envelopeWriter{ + writer: responseWriter, + bufferPool: newBufferPool(), + }, + } + trailer := http.Header{} + trailer.Add("grpc-status", "0") + trailer.Add("Grpc-Message", "Foo") + trailer.Add("User-Provided", "bar") + err := marshaler.MarshalWebTrailers(trailer) + assert.Nil(t, err) + responseWriter.Body.Next(5) // skip flags and message length + marshalled := responseWriter.Body.String() + assert.Equal(t, marshalled, "grpc-message: Foo\r\ngrpc-status: 0\r\nuser-provided: bar\r\n") +}