From 6d0aaaec1d0f56508a80789692303dcb736e5c81 Mon Sep 17 00:00:00 2001 From: Gayathri625 Date: Tue, 6 Aug 2024 23:27:21 +0530 Subject: [PATCH] grpc: make client report `Internal` status when server response contains unsupported encoding (#7461) --- rpc_util.go | 18 ++++++---- server.go | 2 +- stream.go | 14 +++----- test/compressor_test.go | 77 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 16 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index 2d562d572..a206008bf 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -719,7 +719,7 @@ func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats. } } -func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { +func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status { switch pf { case compressionNone: case compressionMade: @@ -727,7 +727,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") } if !haveCompressor { - return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + if isServer { + return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + } else { + return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + } } default: return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) @@ -744,14 +748,16 @@ type payloadInfo struct { // // Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as // the buffer is no longer needed. -func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, +// TODO: Refactor this function to reduce the number of arguments. +// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists +func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, ) (uncompressedBuf []byte, cancel func(), err error) { pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return nil, nil, err } - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil { return nil, nil, st.Err() } @@ -825,8 +831,8 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error { - buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) +func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error { + buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer) if err != nil { return err } diff --git a/server.go b/server.go index 89f8e4792..41cf41ac2 100644 --- a/server.go +++ b/server.go @@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor payInfo = &payloadInfo{} } - d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) + d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true) if err != nil { if e := t.WriteStatus(stream, status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) diff --git a/stream.go b/stream.go index 8051ef5b5..24fea2024 100644 --- a/stream.go +++ b/stream.go @@ -1083,8 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { // Only initialize this state once per stream. a.decompSet = true } - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp) - if err != nil { + if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil { if err == io.EOF { if statusErr := a.s.Status().Err(); statusErr != nil { return statusErr @@ -1122,8 +1121,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp) - if err == nil { + if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } if err == io.EOF { @@ -1423,8 +1421,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Only initialize this state once per stream. as.decompSet = true } - err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) - if err != nil { + if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil { if err == io.EOF { if statusErr := as.s.Status().Err(); statusErr != nil { return statusErr @@ -1444,8 +1441,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) - if err == nil { + if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } if err == io.EOF { @@ -1715,7 +1711,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { payInfo = &payloadInfo{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil { if err == io.EOF { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} diff --git a/test/compressor_test.go b/test/compressor_test.go index 7f3abb908..4340c772b 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/encoding" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" @@ -39,6 +40,82 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) +// TestUnsupportedEncodingResponse validates gRPC status codes +// for different client-server compression setups +// ensuring the correct behavior when compression is enabled or disabled on either side. +func (s) TestUnsupportedEncodingResponse(t *testing.T) { + tests := []struct { + name string + clientCompress bool + serverCompress bool + wantStatus codes.Code + }{ + { + name: "client_server_compression", + clientCompress: true, + serverCompress: true, + wantStatus: codes.OK, + }, + { + name: "client_compression", + clientCompress: true, + serverCompress: false, + wantStatus: codes.Unimplemented, + }, + { + name: "server_compression", + clientCompress: false, + serverCompress: true, + wantStatus: codes.Internal, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{Payload: in.Payload}, nil + }, + } + sopts := []grpc.ServerOption{} + if test.serverCompress { + // Using deprecated methods to selectively apply compression + // only on the server side. With encoding.registerCompressor(), + // the compressor is applied globally, affecting client and server + sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor())) + } + if err := ss.StartServer(sopts...); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} + if test.clientCompress { + // UseCompressor() requires the compressor to be registered + // using encoding.RegisterCompressor() which applies compressor globally, + // Hence, using deprecated WithCompressor() and WithDecompressor() + // to apply compression only on client. + dopts = append(dopts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor())) + } + if err := ss.StartClient(dopts...); err != nil { + t.Fatalf("Error starting client: %v", err) + } + + payload := &testpb.SimpleRequest{ + Payload: &testpb.Payload{ + Body: []byte("test message"), + }, + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := ss.Client.UnaryCall(ctx, payload) + if got, want := status.Code(err), test.wantStatus; got != want { + t.Errorf("Client.UnaryCall() = %v, want %v", got, want) + } + }) + } +} + func (s) TestCompressServerHasNoSupport(t *testing.T) { for _, e := range listTestEnv() { testCompressServerHasNoSupport(t, e)