Skip to content

Commit

Permalink
Client should verify response content-type (#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Feb 8, 2024
1 parent dc78d86 commit 2290ed2
Show file tree
Hide file tree
Showing 7 changed files with 469 additions and 8 deletions.
4 changes: 4 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ func TestUnavailableIfHostInvalid(t *testing.T) {
func TestBidiRequiresHTTP2(t *testing.T) {
t.Parallel()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/connect+proto")
_, err := io.WriteString(w, "hello world")
assert.Nil(t, err)
})
Expand Down Expand Up @@ -841,6 +842,7 @@ func TestCompressMinBytesClient(t *testing.T) {
tb.Helper()
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set("Content-Type", "application/proto")
assert.Equal(tb, request.Header.Get("Content-Encoding"), expect)
}))
server := memhttptest.NewServer(t, mux)
Expand Down Expand Up @@ -2231,6 +2233,7 @@ func TestStreamUnexpectedEOF(t *testing.T) {
name: "connect_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
responseWriter.Header().Set("Content-Type", "application/connect+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
Expand All @@ -2252,6 +2255,7 @@ func TestStreamUnexpectedEOF(t *testing.T) {
name: "grpc-web_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
responseWriter.Header().Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
Expand Down
5 changes: 5 additions & 0 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ func (d *duplexHTTPCall) URL() *url.URL {
return d.request.URL
}

// Method returns the HTTP method for the request (GET or POST).
func (d *duplexHTTPCall) Method() string {
return d.request.Method
}

// SetMethod changes the method of the request before it is sent.
func (d *duplexHTTPCall) SetMethod(method string) {
d.request.Method = method
Expand Down
2 changes: 1 addition & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func NewError(c Code, underlying error) *Error {
// This is useful for clients trying to propagate partial failures from
// streaming RPCs. Often, these RPCs include error information in their
// response messages (for example, [gRPC server reflection] and
// OpenTelemtetry's [OTLP]). Clients propagating these errors up the stack
// OpenTelemetry's [OTLP]). Clients propagating these errors up the stack
// should use NewWireError to clarify that the error code, message, and details
// (if any) were explicitly sent by the server rather than inferred from a
// lower-level networking error or timeout.
Expand Down
91 changes: 85 additions & 6 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,21 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
}
cc.responseTrailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v
}
err := connectValidateUnaryResponseContentType(
cc.marshaler.codec.Name(),
cc.duplexCall.Method(),
response.StatusCode,
response.Status,
getHeaderCanonical(response.Header, headerContentType),
)
if err != nil {
if IsNotModifiedError(err) {
// Allow access to response headers for this kind of error.
// RFC 9110 doesn't allow trailers on 304s, so we only need to include headers.
err.meta = cc.responseHeader.Clone()
}
return err
}
compression := getHeaderCanonical(response.Header, connectUnaryHeaderCompression)
if compression != "" &&
compression != compressionIdentity &&
Expand All @@ -522,12 +537,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
cc.compressionPools.CommaSeparatedNames(),
)
}
if response.StatusCode == http.StatusNotModified && cc.Spec().IdempotencyLevel == IdempotencyNoSideEffects {
serverErr := NewWireError(CodeUnknown, errNotModifiedClient)
// RFC 9110 doesn't allow trailers on 304s, so we only need to include headers.
serverErr.meta = cc.responseHeader.Clone()
return serverErr
} else if response.StatusCode != http.StatusOK {
if response.StatusCode != http.StatusOK {
unmarshaler := connectUnaryUnmarshaler{
reader: response.Body,
compressionPool: cc.compressionPools.Get(compression),
Expand Down Expand Up @@ -643,6 +653,14 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response)
if response.StatusCode != http.StatusOK {
return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status)
}
err := connectValidateStreamResponseContentType(
cc.codec.Name(),
cc.spec.StreamType,
getHeaderCanonical(response.Header, headerContentType),
)
if err != nil {
return err
}
compression := getHeaderCanonical(response.Header, connectStreamingHeaderCompression)
if compression != "" &&
compression != compressionIdentity &&
Expand Down Expand Up @@ -1324,3 +1342,64 @@ func queryValueReader(data string, base64Encoded bool) io.Reader {
}
return strings.NewReader(data)
}

func connectValidateUnaryResponseContentType(
requestCodecName string,
httpMethod string,
statusCode int,
statusMsg string,
responseContentType string,
) *Error {
if statusCode != http.StatusOK {
if statusCode == http.StatusNotModified && httpMethod == http.MethodGet {
return NewWireError(CodeUnknown, errNotModifiedClient)
}
// Error responses must be JSON-encoded.
if responseContentType == connectUnaryContentTypePrefix+codecNameJSON ||
responseContentType == connectUnaryContentTypePrefix+codecNameJSONCharsetUTF8 {
return nil
}
return NewError(
connectHTTPToCode(statusCode),
errors.New(statusMsg),
)
}
// Normal responses must have valid content-type that indicates same codec as the request.
responseCodecName := connectCodecFromContentType(
StreamTypeUnary,
responseContentType,
)
if responseCodecName == requestCodecName {
return nil
}
// HACK: We likely want a better way to handle the optional "charset" parameter
// for application/json, instead of hard-coding. But this suffices for now.
if (responseCodecName == codecNameJSON && requestCodecName == codecNameJSONCharsetUTF8) ||
(responseCodecName == codecNameJSONCharsetUTF8 && requestCodecName == codecNameJSON) {
// Both are JSON
return nil
}
return errorf(
CodeInternal,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}

func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error {
// Responses must have valid content-type that indicates same codec as the request.
responseCodecName := connectCodecFromContentType(
streamType,
responseContentType,
)
if responseCodecName != requestCodecName {
return errorf(
CodeInternal,
"invalid content-type: %q; expecting %q",
responseContentType,
connectStreamingContentTypePrefix+requestCodecName,
)
}
return nil
}
Loading

0 comments on commit 2290ed2

Please sign in to comment.