Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client should verify response content-type #679

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ func TestUnavailableIfHostInvalid(t *testing.T) {
func TestBidiRequiresHTTP2(t *testing.T) {
t.Parallel()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/connect+proto")
_, err := io.WriteString(w, "hello world")
assert.Nil(t, err)
})
Expand Down Expand Up @@ -841,6 +842,7 @@ func TestCompressMinBytesClient(t *testing.T) {
tb.Helper()
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set("Content-Type", "application/proto")
assert.Equal(tb, request.Header.Get("Content-Encoding"), expect)
}))
server := memhttptest.NewServer(t, mux)
Expand Down Expand Up @@ -2231,6 +2233,7 @@ func TestStreamUnexpectedEOF(t *testing.T) {
name: "connect_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
responseWriter.Header().Set("Content-Type", "application/connect+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
Expand All @@ -2252,6 +2255,7 @@ func TestStreamUnexpectedEOF(t *testing.T) {
name: "grpc-web_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
responseWriter.Header().Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
Expand Down
5 changes: 5 additions & 0 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ func (d *duplexHTTPCall) URL() *url.URL {
return d.request.URL
}

// Method returns the HTTP method for the request (GET or POST).
func (d *duplexHTTPCall) Method() string {
return d.request.Method
}

// SetMethod changes the method of the request before it is sent.
func (d *duplexHTTPCall) SetMethod(method string) {
d.request.Method = method
Expand Down
2 changes: 1 addition & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func NewError(c Code, underlying error) *Error {
// This is useful for clients trying to propagate partial failures from
// streaming RPCs. Often, these RPCs include error information in their
// response messages (for example, [gRPC server reflection] and
// OpenTelemtetry's [OTLP]). Clients propagating these errors up the stack
// OpenTelemetry's [OTLP]). Clients propagating these errors up the stack
// should use NewWireError to clarify that the error code, message, and details
// (if any) were explicitly sent by the server rather than inferred from a
// lower-level networking error or timeout.
Expand Down
91 changes: 85 additions & 6 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,21 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
}
cc.responseTrailer[strings.TrimPrefix(k, connectUnaryTrailerPrefix)] = v
}
err := connectValidateUnaryResponseContentType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be moved into the validate response function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is. Do you mean inlined? I had made it a separate function to make it easier to test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry misread, thought the method on the duplexcall was due to it not being part of this function, but thats for the request params.

cc.marshaler.codec.Name(),
cc.duplexCall.Method(),
response.StatusCode,
response.Status,
getHeaderCanonical(response.Header, headerContentType),
)
if err != nil {
Copy link
Member

@akshayjshah akshayjshah Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we scope err to the if block? It's visually a little ugly, but the functions in this portion of the code are long enough that I'd love to minimize scope where possible. Same below.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had not done so just because I find it a bit hard to read when the if keyword is far from the actual condition. But if that's accepted style, and even preferred in this case, I'll change it.

if IsNotModifiedError(err) {
// Allow access to response headers for this kind of error.
// RFC 9110 doesn't allow trailers on 304s, so we only need to include headers.
err.meta = cc.responseHeader.Clone()
}
return err
}
compression := getHeaderCanonical(response.Header, connectUnaryHeaderCompression)
if compression != "" &&
compression != compressionIdentity &&
Expand All @@ -522,12 +537,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err
cc.compressionPools.CommaSeparatedNames(),
)
}
if response.StatusCode == http.StatusNotModified && cc.Spec().IdempotencyLevel == IdempotencyNoSideEffects {
serverErr := NewWireError(CodeUnknown, errNotModifiedClient)
// RFC 9110 doesn't allow trailers on 304s, so we only need to include headers.
serverErr.meta = cc.responseHeader.Clone()
return serverErr
} else if response.StatusCode != http.StatusOK {
if response.StatusCode != http.StatusOK {
unmarshaler := connectUnaryUnmarshaler{
reader: response.Body,
compressionPool: cc.compressionPools.Get(compression),
Expand Down Expand Up @@ -643,6 +653,14 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response)
if response.StatusCode != http.StatusOK {
return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status)
}
err := connectValidateStreamResponseContentType(
cc.codec.Name(),
cc.spec.StreamType,
getHeaderCanonical(response.Header, headerContentType),
)
if err != nil {
return err
}
compression := getHeaderCanonical(response.Header, connectStreamingHeaderCompression)
if compression != "" &&
compression != compressionIdentity &&
Expand Down Expand Up @@ -1324,3 +1342,64 @@ func queryValueReader(data string, base64Encoded bool) io.Reader {
}
return strings.NewReader(data)
}

func connectValidateUnaryResponseContentType(
requestCodecName string,
httpMethod string,
statusCode int,
statusMsg string,
responseContentType string,
) *Error {
if statusCode != http.StatusOK {
if statusCode == http.StatusNotModified && httpMethod == http.MethodGet {
return NewWireError(CodeUnknown, errNotModifiedClient)
}
// Error responses must be JSON-encoded.
if responseContentType == connectUnaryContentTypePrefix+codecNameJSON ||
responseContentType == connectUnaryContentTypePrefix+codecNameJSONCharsetUTF8 {
return nil
}
return NewError(
connectHTTPToCode(statusCode),
errors.New(statusMsg),
)
}
// Normal responses must have valid content-type that indicates same codec as the request.
responseCodecName := connectCodecFromContentType(
StreamTypeUnary,
responseContentType,
)
if responseCodecName == requestCodecName {
return nil
}
// HACK: We likely want a better way to handle the optional "charset" parameter
// for application/json, instead of hard-coding. But this suffices for now.
if (responseCodecName == codecNameJSON && requestCodecName == codecNameJSONCharsetUTF8) ||
(responseCodecName == codecNameJSONCharsetUTF8 && requestCodecName == codecNameJSON) {
// Both are JSON
return nil
}
return errorf(
CodeInternal,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}

func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error {
// Responses must have valid content-type that indicates same codec as the request.
responseCodecName := connectCodecFromContentType(
streamType,
responseContentType,
)
if responseCodecName != requestCodecName {
return errorf(
CodeInternal,
"invalid content-type: %q; expecting %q",
responseContentType,
connectStreamingContentTypePrefix+requestCodecName,
)
}
return nil
}
220 changes: 220 additions & 0 deletions protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
Expand Down Expand Up @@ -93,3 +94,222 @@
assert.Equal(t, unmarshaler.Trailer().Values("Mixed-Canonical"), []string{"b", "b"})
assert.Equal(t, unmarshaler.Trailer().Values("Canonical-Header"), []string{"c"})
}

func TestConnectValidateUnaryResponseContentType(t *testing.T) {
t.Parallel()
testCases := []struct {
codecName string
get bool
statusCode int
responseContentType string
expectCode Code
expectBadContentType bool
expectNotModified bool
}{
// Allowed content-types for OK responses.
{
codecName: codecNameProto,
statusCode: http.StatusOK,
responseContentType: "application/proto",
},
{
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "application/json",
},
{
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "application/json; charset=utf-8",
},
{
codecName: codecNameJSONCharsetUTF8,
statusCode: http.StatusOK,
responseContentType: "application/json",
},
{
codecName: codecNameJSONCharsetUTF8,
statusCode: http.StatusOK,
responseContentType: "application/json; charset=utf-8",
},
// Allowed content-types for error responses.
{
codecName: codecNameProto,
statusCode: http.StatusNotFound,
responseContentType: "application/json",
},
{
codecName: codecNameProto,
statusCode: http.StatusBadRequest,
responseContentType: "application/json; charset=utf-8",
},
{
codecName: codecNameJSON,
statusCode: http.StatusInternalServerError,
responseContentType: "application/json",
},
{
codecName: codecNameJSON,
statusCode: http.StatusPreconditionFailed,
responseContentType: "application/json; charset=utf-8",
},
// 304 Not Modified for GET request gets a special error, regardless of content-type
{
codecName: codecNameProto,
get: true,
statusCode: http.StatusNotModified,
responseContentType: "application/json",
expectCode: CodeUnknown,
expectNotModified: true,
},
{
codecName: codecNameJSON,
get: true,
statusCode: http.StatusNotModified,
responseContentType: "application/json",
expectCode: CodeUnknown,
expectNotModified: true,
},
// OK status, invalid content-type
{
codecName: codecNameProto,
statusCode: http.StatusOK,
responseContentType: "application/proto; charset=utf-8",
expectCode: CodeInternal,
expectBadContentType: true,
},
{
codecName: codecNameProto,
statusCode: http.StatusOK,
responseContentType: "application/json",
expectCode: CodeInternal,
expectBadContentType: true,
},
{
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "application/proto",
expectCode: CodeInternal,
expectBadContentType: true,
},
{
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "some/garbage",
expectCode: CodeInternal,
expectBadContentType: true,
},
// Error status, invalid content-type, returns code based on HTTP status code
{
codecName: codecNameProto,
statusCode: http.StatusNotFound,
responseContentType: "application/proto",
expectCode: connectHTTPToCode(http.StatusNotFound),
},
{
codecName: codecNameJSON,
statusCode: http.StatusBadRequest,
responseContentType: "some/garbage",
expectCode: connectHTTPToCode(http.StatusBadRequest),
},
{
codecName: codecNameJSON,
statusCode: http.StatusTooManyRequests,
responseContentType: "some/garbage",
expectCode: connectHTTPToCode(http.StatusTooManyRequests),
},
}
for _, testCase := range testCases {
testCase := testCase
httpMethod := http.MethodPost
if testCase.get {
httpMethod = http.MethodGet
}
testCaseName := fmt.Sprintf("%s_%s->%d_%s", httpMethod, testCase.codecName, testCase.statusCode, testCase.responseContentType)
t.Run(testCaseName, func(t *testing.T) {
t.Parallel()
err := connectValidateUnaryResponseContentType(
testCase.codecName,
httpMethod,
testCase.statusCode,
http.StatusText(testCase.statusCode),
testCase.responseContentType,
)
if testCase.expectCode == 0 {
assert.Nil(t, err)
} else if assert.NotNil(t, err) {
assert.Equal(t, CodeOf(err), testCase.expectCode)
if testCase.expectNotModified {

Check failure on line 242 in protocol_connect_test.go

View workflow job for this annotation

GitHub Actions / ci (1.21.x)

ifElseChain: rewrite if-else to switch statement (gocritic)
assert.ErrorIs(t, err, errNotModified)
} else if testCase.expectBadContentType {
assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType)))
} else {
assert.Equal(t, err.Message(), http.StatusText(testCase.statusCode))
}
}
})
}
}

func TestConnectValidateStreamResponseContentType(t *testing.T) {
t.Parallel()
testCases := []struct {
codecName string
responseContentType string
expectErr bool
}{
// Allowed content-types
{
codecName: codecNameProto,
responseContentType: "application/connect+proto",
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+json",
},
// Disallowed content-types
{
codecName: codecNameProto,
responseContentType: "application/proto",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/json",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/json; charset=utf-8",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectErr: true,
},
{
codecName: codecNameProto,
responseContentType: "some/garbage",
expectErr: true,
},
}
for _, testCase := range testCases {
testCase := testCase
testCaseName := fmt.Sprintf("%s->%s", testCase.codecName, testCase.responseContentType)
t.Run(testCaseName, func(t *testing.T) {
t.Parallel()
err := connectValidateStreamResponseContentType(
testCase.codecName,
StreamTypeServer,
testCase.responseContentType,
)
if !testCase.expectErr {
assert.Nil(t, err)
} else if assert.NotNil(t, err) {
assert.Equal(t, CodeOf(err), CodeInternal)
assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType)))
}
})
}
}
Loading
Loading