Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: make client report Internal status when server response contains unsupported encoding #7461

Merged
merged 15 commits into from
Aug 6, 2024
Merged
16 changes: 10 additions & 6 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,15 +719,19 @@
}
}

func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" || recvCompress == encoding.Identity {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)

Check warning on line 731 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L731

Added line #L731 was not covered by tests
} else {
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
easwars marked this conversation as resolved.
Show resolved Hide resolved
}
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
Expand All @@ -744,14 +748,14 @@
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
easwars marked this conversation as resolved.
Show resolved Hide resolved
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, nil, err
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
return nil, nil, st.Err()
}

Expand Down Expand Up @@ -825,8 +829,8 @@
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
payInfo = &payloadInfo{}
}

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand Down
10 changes: 5 additions & 5 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream.
a.decompSet = true
}
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false)
easwars marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil {
Expand Down Expand Up @@ -1122,7 +1122,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
Expand Down Expand Up @@ -1423,7 +1423,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream.
as.decompSet = true
}
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false)
if err != nil {
if err == io.EOF {
if statusErr := as.s.Status().Err(); statusErr != nil {
Expand All @@ -1444,7 +1444,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {

// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
Expand Down Expand Up @@ -1715,7 +1715,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{}
}
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
if err == io.EOF {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
Expand Down
78 changes: 78 additions & 0 deletions test/compressor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
Expand All @@ -39,6 +40,83 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)

// TestUnsupportedEncodingResponse validates gRPC status codes for different client-server compression setups
// ensuring the correct behavior when compression is enabled or disabled on either side.
func (s) TestUnsupportedEncodingResponse(t *testing.T) {
tests := []struct {
name string
clientUseNop bool
serverUseNop bool
easwars marked this conversation as resolved.
Show resolved Hide resolved
expectedStatus codes.Code
easwars marked this conversation as resolved.
Show resolved Hide resolved
}{
{
name: "client_server_nop_compression",
easwars marked this conversation as resolved.
Show resolved Hide resolved
clientUseNop: true,
serverUseNop: true,
expectedStatus: codes.OK,
},
{
name: "client_nop_compression",
clientUseNop: true,
serverUseNop: false,
expectedStatus: codes.Unimplemented,
},
{
name: "server_nop_compression",
clientUseNop: false,
serverUseNop: true,
expectedStatus: codes.Internal,
},
}
Gayathri625 marked this conversation as resolved.
Show resolved Hide resolved

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{
Payload: in.Payload,
}, nil
easwars marked this conversation as resolved.
Show resolved Hide resolved
},
}
sopts := []grpc.ServerOption{}
if test.serverUseNop {
// Using deprecated methods to selectively apply compression only on the server side.
// with encoding.registerCompressor(), the compressor is applied globally, affecting both the client and server.
easwars marked this conversation as resolved.
Show resolved Hide resolved
sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor()))
Gayathri625 marked this conversation as resolved.
Show resolved Hide resolved
}
if err := ss.Start(sopts); err != nil {
t.Fatalf("Error starting server: %v", err)
}

easwars marked this conversation as resolved.
Show resolved Hide resolved
defer ss.Stop()
dOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
easwars marked this conversation as resolved.
Show resolved Hide resolved
if test.clientUseNop {
// UseCompressor() requires the compressor to be registered using encoding.RegisterCompressor() which applies compressor globally,
// Hence, using deprecated WithCompressor() and WithDecompressor() to apply compression only on client.
dOpts = append(dOpts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor()))
}
cc, err := grpc.NewClient(ss.Address, dOpts...)
if err != nil {
t.Fatalf("grpc.NewClient() returned unexpected error: %v", err)
}
defer cc.Close()
easwars marked this conversation as resolved.
Show resolved Hide resolved
ss.Client = testpb.NewTestServiceClient(cc)
easwars marked this conversation as resolved.
Show resolved Hide resolved

payload := &testpb.SimpleRequest{
Payload: &testpb.Payload{
Body: []byte("test message"),
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err = ss.Client.UnaryCall(ctx, payload)
if got, want := status.Code(err), test.expectedStatus; got != want {
t.Errorf("Client.UnaryCall() = %v, want %v", got, want)
}
})
}
}

func (s) TestCompressServerHasNoSupport(t *testing.T) {
for _, e := range listTestEnv() {
testCompressServerHasNoSupport(t, e)
Expand Down