diff --git a/protocol_connect.go b/protocol_connect.go index 8e6294c7..699440fe 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -864,6 +864,13 @@ func (u *connectStreamingUnmarshaler) Unmarshal(message any) *Error { if err := json.Unmarshal(env.Data.Bytes(), &end); err != nil { return errorf(CodeInternal, "unmarshal end stream message: %w", err) } + for name, value := range end.Trailer { + canonical := http.CanonicalHeaderKey(name) + if name != canonical { + delete(end.Trailer, name) + end.Trailer[canonical] = append(end.Trailer[canonical], value...) + } + } u.trailer = end.Trailer u.endStreamErr = end.Error.asError() return errSpecialEnvelope diff --git a/protocol_connect_test.go b/protocol_connect_test.go index 2a9f587b..7951daf1 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -15,7 +15,9 @@ package connect import ( + "bytes" "encoding/json" + "net/http" "strings" "testing" "time" @@ -54,3 +56,40 @@ func TestConnectErrorDetailMarshalingNoDescriptor(t *testing.T) { assert.Nil(t, err) assert.Equal(t, string(encoded), raw) } + +func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { + t.Parallel() + + buffer := bytes.Buffer{} + bufferPool := newBufferPool() + + endStreamMessage := connectEndStreamMessage{Trailer: make(http.Header)} + endStreamMessage.Trailer["not-canonical-header"] = []string{"a"} + endStreamMessage.Trailer["mixed-Canonical"] = []string{"b"} + endStreamMessage.Trailer["Mixed-Canonical"] = []string{"b"} + endStreamMessage.Trailer["Canonical-Header"] = []string{"c"} + endStreamData, err := json.Marshal(endStreamMessage) + assert.Nil(t, err) + + writer := envelopeWriter{ + writer: &buffer, + bufferPool: bufferPool, + } + err = writer.Write(&envelope{ + Flags: connectFlagEnvelopeEndStream, + Data: bytes.NewBuffer(endStreamData), + }) + assert.Nil(t, err) + + unmarshaler := connectStreamingUnmarshaler{ + envelopeReader: envelopeReader{ + reader: &buffer, + bufferPool: bufferPool, + }, + } + err = unmarshaler.Unmarshal(nil) // parameter won't be used + assert.ErrorIs(t, err, errSpecialEnvelope) + assert.Equal(t, unmarshaler.Trailer().Values("Not-Canonical-Header"), []string{"a"}) + assert.Equal(t, unmarshaler.Trailer().Values("Mixed-Canonical"), []string{"b", "b"}) + assert.Equal(t, unmarshaler.Trailer().Values("Canonical-Header"), []string{"c"}) +}