diff --git a/test/end2end_test.go b/test/end2end_test.go index ba1d672ed910..6574a64fc5c1 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -6143,6 +6143,73 @@ func TestServeExitsWhenListenerClosed(t *testing.T) { } } +// Service handler returns status with invalid utf8 message. +func TestStatusInvalidUTF8Message(t *testing.T) { + defer leakcheck.Check(t) + + var ( + origMsg = string([]byte{0xff, 0xfe, 0xfd}) + wantMsg = "���" + ) + + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return nil, status.Errorf(codes.Internal, origMsg) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMsg { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, status.Convert(err).Message(), wantMsg) + } +} + +// Service handler returns status with details and invalid utf8 message. Proto +// will fail to marshal the status because of the invalid utf8 message. Details +// will be dropped when sending. +func TestStatusInvalidUTF8Details(t *testing.T) { + defer leakcheck.Check(t) + + var ( + origMsg = string([]byte{0xff, 0xfe, 0xfd}) + wantMsg = "���" + ) + + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + st := status.New(codes.Internal, origMsg) + st, err := st.WithDetails(&testpb.Empty{}) + if err != nil { + return nil, err + } + return nil, st.Err() + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) + st := status.Convert(err) + if st.Message() != wantMsg { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, st.Message(), wantMsg) + } + if len(st.Details()) != 0 { + // Details should be dropped on the server side. + t.Fatalf("RPC status contain details: %v, want no details", st.Details()) + } +} + func TestClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T) { defer leakcheck.Check(t) for _, e := range listTestEnv() { diff --git a/transport/http2_server.go b/transport/http2_server.go index 3643e823d97f..3303a9b15f74 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -38,6 +38,7 @@ import ( "google.golang.org/grpc/channelz" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" @@ -769,10 +770,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { stBytes, err := proto.Marshal(p) if err != nil { // TODO: return error instead, when callers are able to handle it. - panic(err) + grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err) + } else { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) } - - headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) } // Attach the trailer metadata. diff --git a/transport/http_util.go b/transport/http_util.go index 835c81269467..fe5554737ec5 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -28,6 +28,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" "github.com/golang/protobuf/proto" "golang.org/x/net/http2" @@ -442,11 +443,12 @@ const ( ) // encodeGrpcMessage is used to encode status code in header field -// "grpc-message". -// It checks to see if each individual byte in msg is an -// allowable byte, and then either percent encoding or passing it through. -// When percent encoding, the byte is converted into hexadecimal notation -// with a '%' prepended. +// "grpc-message". It does percent encoding and also replaces invalid utf-8 +// characters with Unicode replacement character. +// +// It checks to see if each individual byte in msg is an allowable byte, and +// then either percent encoding or passing it through. When percent encoding, +// the byte is converted into hexadecimal notation with a '%' prepended. func encodeGrpcMessage(msg string) string { if msg == "" { return "" @@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string { func encodeGrpcMessageUnchecked(msg string) string { var buf bytes.Buffer - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - c := msg[i] - if c >= spaceByte && c < tildaByte && c != percentByte { - buf.WriteByte(c) - } else { - buf.WriteString(fmt.Sprintf("%%%02X", c)) + for len(msg) > 0 { + r, size := utf8.DecodeRuneInString(msg) + for _, b := range []byte(string(r)) { + if size > 1 { + // If size > 1, r is not ascii. Always do percent encoding. + buf.WriteString(fmt.Sprintf("%%%02X", b)) + continue + } + + // The for loop is necessary even if size == 1. r could be + // utf8.RuneError. + // + // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD". + if b >= spaceByte && b < tildaByte && b != percentByte { + buf.WriteByte(b) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", b)) + } } + msg = msg[size:] } return buf.String() } diff --git a/transport/http_util_test.go b/transport/http_util_test.go index c3754781df9f..1295a2f60430 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -102,12 +102,14 @@ func TestEncodeGrpcMessage(t *testing.T) { }{ {"", ""}, {"Hello", "Hello"}, - {"my favorite character is \u0000", "my favorite character is %00"}, - {"my favorite character is %", "my favorite character is %25"}, + {"\u0000", "%00"}, + {"%", "%25"}, + {"系统", "%E7%B3%BB%E7%BB%9F"}, + {string([]byte{0xff, 0xfe, 0xfd}), "%EF%BF%BD%EF%BF%BD%EF%BF%BD"}, } { actual := encodeGrpcMessage(tt.input) if tt.expected != actual { - t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected) + t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected) } } } @@ -123,10 +125,36 @@ func TestDecodeGrpcMessage(t *testing.T) { {"H%6", "H%6"}, {"%G0", "%G0"}, {"%E7%B3%BB%E7%BB%9F", "系统"}, + {"%EF%BF%BD", "�"}, } { actual := decodeGrpcMessage(tt.input) if tt.expected != actual { - t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected) + t.Errorf("dncodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected) + } + } +} + +// Decode an encoded string should get the same thing back, except for invalid +// utf8 chars. +func TestDecodeEncodeGrpcMessage(t *testing.T) { + testCases := []struct { + orig string + want string + }{ + {"", ""}, + {"hello", "hello"}, + {"h%6", "h%6"}, + {"%G0", "%G0"}, + {"系统", "系统"}, + {"Hello, 世界", "Hello, 世界"}, + + {string([]byte{0xff, 0xfe, 0xfd}), "���"}, + {string([]byte{0xff}) + "Hello" + string([]byte{0xfe}) + "世界" + string([]byte{0xfd}), "�Hello�世界�"}, + } + for _, tC := range testCases { + got := decodeGrpcMessage(encodeGrpcMessage(tC.orig)) + if got != tC.want { + t.Errorf("decodeGrpcMessage(encodeGrpcMessage(%q)) = %q, want %q", tC.orig, got, tC.want) } } }