mirror of https://github.com/grpc/grpc-go.git
grpc: make client report `Internal` status when server response contains unsupported encoding (#7461)
This commit is contained in:
parent
338595ca57
commit
6d0aaaec1d
18
rpc_util.go
18
rpc_util.go
|
@ -719,7 +719,7 @@ func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.
|
|||
}
|
||||
}
|
||||
|
||||
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
|
||||
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
|
||||
switch pf {
|
||||
case compressionNone:
|
||||
case compressionMade:
|
||||
|
@ -727,7 +727,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
|
|||
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
|
||||
}
|
||||
if !haveCompressor {
|
||||
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
if isServer {
|
||||
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
} else {
|
||||
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
|
||||
|
@ -744,14 +748,16 @@ type payloadInfo struct {
|
|||
//
|
||||
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
|
||||
// the buffer is no longer needed.
|
||||
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
|
||||
// TODO: Refactor this function to reduce the number of arguments.
|
||||
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
|
||||
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
|
||||
) (uncompressedBuf []byte, cancel func(), err error) {
|
||||
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
|
||||
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
|
||||
return nil, nil, st.Err()
|
||||
}
|
||||
|
||||
|
@ -825,8 +831,8 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
|
|||
// For the two compressor parameters, both should not be set, but if they are,
|
||||
// dc takes precedence over compressor.
|
||||
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
||||
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
|
||||
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
|
||||
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
|
||||
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
|
|||
payInfo = &payloadInfo{}
|
||||
}
|
||||
|
||||
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
|
||||
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
|
||||
if err != nil {
|
||||
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
|
||||
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
|
||||
|
|
14
stream.go
14
stream.go
|
@ -1083,8 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
|
|||
// Only initialize this state once per stream.
|
||||
a.decompSet = true
|
||||
}
|
||||
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
|
||||
if err != nil {
|
||||
if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil {
|
||||
if err == io.EOF {
|
||||
if statusErr := a.s.Status().Err(); statusErr != nil {
|
||||
return statusErr
|
||||
|
@ -1122,8 +1121,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
|
|||
}
|
||||
// Special handling for non-server-stream rpcs.
|
||||
// This recv expects EOF or errors, so we don't collect inPayload.
|
||||
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
|
||||
if err == nil {
|
||||
if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
}
|
||||
if err == io.EOF {
|
||||
|
@ -1423,8 +1421,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
|
|||
// Only initialize this state once per stream.
|
||||
as.decompSet = true
|
||||
}
|
||||
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
|
||||
if err != nil {
|
||||
if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil {
|
||||
if err == io.EOF {
|
||||
if statusErr := as.s.Status().Err(); statusErr != nil {
|
||||
return statusErr
|
||||
|
@ -1444,8 +1441,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
|
|||
|
||||
// Special handling for non-server-stream rpcs.
|
||||
// This recv expects EOF or errors, so we don't collect inPayload.
|
||||
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
|
||||
if err == nil {
|
||||
if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
}
|
||||
if err == io.EOF {
|
||||
|
@ -1715,7 +1711,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
|
|||
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
|
||||
payInfo = &payloadInfo{}
|
||||
}
|
||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
|
||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
|
||||
if err == io.EOF {
|
||||
if len(ss.binlogs) != 0 {
|
||||
chc := &binarylog.ClientHalfClose{}
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/encoding"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -39,6 +40,82 @@ import (
|
|||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||
)
|
||||
|
||||
// TestUnsupportedEncodingResponse validates gRPC status codes
|
||||
// for different client-server compression setups
|
||||
// ensuring the correct behavior when compression is enabled or disabled on either side.
|
||||
func (s) TestUnsupportedEncodingResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientCompress bool
|
||||
serverCompress bool
|
||||
wantStatus codes.Code
|
||||
}{
|
||||
{
|
||||
name: "client_server_compression",
|
||||
clientCompress: true,
|
||||
serverCompress: true,
|
||||
wantStatus: codes.OK,
|
||||
},
|
||||
{
|
||||
name: "client_compression",
|
||||
clientCompress: true,
|
||||
serverCompress: false,
|
||||
wantStatus: codes.Unimplemented,
|
||||
},
|
||||
{
|
||||
name: "server_compression",
|
||||
clientCompress: false,
|
||||
serverCompress: true,
|
||||
wantStatus: codes.Internal,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ss := &stubserver.StubServer{
|
||||
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
return &testpb.SimpleResponse{Payload: in.Payload}, nil
|
||||
},
|
||||
}
|
||||
sopts := []grpc.ServerOption{}
|
||||
if test.serverCompress {
|
||||
// Using deprecated methods to selectively apply compression
|
||||
// only on the server side. With encoding.registerCompressor(),
|
||||
// the compressor is applied globally, affecting client and server
|
||||
sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor()))
|
||||
}
|
||||
if err := ss.StartServer(sopts...); err != nil {
|
||||
t.Fatalf("Error starting server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
|
||||
if test.clientCompress {
|
||||
// UseCompressor() requires the compressor to be registered
|
||||
// using encoding.RegisterCompressor() which applies compressor globally,
|
||||
// Hence, using deprecated WithCompressor() and WithDecompressor()
|
||||
// to apply compression only on client.
|
||||
dopts = append(dopts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor()))
|
||||
}
|
||||
if err := ss.StartClient(dopts...); err != nil {
|
||||
t.Fatalf("Error starting client: %v", err)
|
||||
}
|
||||
|
||||
payload := &testpb.SimpleRequest{
|
||||
Payload: &testpb.Payload{
|
||||
Body: []byte("test message"),
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
_, err := ss.Client.UnaryCall(ctx, payload)
|
||||
if got, want := status.Code(err), test.wantStatus; got != want {
|
||||
t.Errorf("Client.UnaryCall() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestCompressServerHasNoSupport(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testCompressServerHasNoSupport(t, e)
|
||||
|
|
Loading…
Reference in New Issue