mirror of https://github.com/grpc/grpc-go.git

rpc_util: Fix RecvBufferPool deactivation issues (#6766)

parent 76a23bf37a
commit 5ccf176a08
@@ -26,12 +26,12 @@ import (
 	"time"
 
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/encoding/gzip"
 	"google.golang.org/grpc/experimental"
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/stubserver"
 
 	testgrpc "google.golang.org/grpc/interop/grpc_testing"
-	testpb "google.golang.org/grpc/interop/grpc_testing"
 )
 
 type s struct {
@@ -44,17 +44,35 @@ func Test(t *testing.T) {
 
 const defaultTestTimeout = 10 * time.Second
 
-func (s) TestRecvBufferPool(t *testing.T) {
+func (s) TestRecvBufferPoolStream(t *testing.T) {
+	tcs := []struct {
+		name     string
+		callOpts []grpc.CallOption
+	}{
+		{
+			name: "default",
+		},
+		{
+			name: "useCompressor",
+			callOpts: []grpc.CallOption{
+				grpc.UseCompressor(gzip.Name),
+			},
+		},
+	}
+
+	for _, tc := range tcs {
+		t.Run(tc.name, func(t *testing.T) {
+			const reqCount = 10
+
 			ss := &stubserver.StubServer{
 				FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
-			for i := 0; i < 10; i++ {
+					for i := 0; i < reqCount; i++ {
 						preparedMsg := &grpc.PreparedMsg{}
-				err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
-					Payload: &testpb.Payload{
+						if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
+							Payload: &testgrpc.Payload{
 								Body: []byte{'0' + uint8(i)},
 							},
-				})
-				if err != nil {
+						}); err != nil {
 							return err
 						}
 						stream.SendMsg(preparedMsg)
@@ -62,8 +80,10 @@ func (s) TestRecvBufferPool(t *testing.T) {
 					return nil
 				},
 			}
-	sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
-	dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
+
+			pool := &checkBufferPool{}
+			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
+			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
 			if err := ss.Start(sopts, dopts...); err != nil {
 				t.Fatalf("Error starting endpoint server: %v", err)
 			}
@@ -72,9 +92,9 @@ func (s) TestRecvBufferPool(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 			defer cancel()
 
-	stream, err := ss.Client.FullDuplexCall(ctx)
+			stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
 			if err != nil {
-		t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
+				t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
 			}
 
 			var ngot int
@@ -94,9 +114,91 @@ func (s) TestRecvBufferPool(t *testing.T) {
 				buf.Write(reply.GetPayload().GetBody())
 			}
 			if want := 10; ngot != want {
-		t.Errorf("Got %d replies, want %d", ngot, want)
+				t.Fatalf("Got %d replies, want %d", ngot, want)
 			}
 			if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
-		t.Errorf("Got replies %q; want %q", got, want)
+				t.Fatalf("Got replies %q; want %q", got, want)
 			}
+
+			if len(pool.puts) != reqCount {
+				t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
+			}
+		})
+	}
+}
+
+func (s) TestRecvBufferPoolUnary(t *testing.T) {
+	tcs := []struct {
+		name     string
+		callOpts []grpc.CallOption
+	}{
+		{
+			name: "default",
+		},
+		{
+			name: "useCompressor",
+			callOpts: []grpc.CallOption{
+				grpc.UseCompressor(gzip.Name),
+			},
+		},
+	}
+
+	for _, tc := range tcs {
+		t.Run(tc.name, func(t *testing.T) {
+			const largeSize = 1024
+
+			ss := &stubserver.StubServer{
+				UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
+					return &testgrpc.SimpleResponse{
+						Payload: &testgrpc.Payload{
+							Body: make([]byte, largeSize),
+						},
+					}, nil
+				},
+			}
+
+			pool := &checkBufferPool{}
+			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
+			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
+			if err := ss.Start(sopts, dopts...); err != nil {
+				t.Fatalf("Error starting endpoint server: %v", err)
+			}
+			defer ss.Stop()
+
+			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+			defer cancel()
+
+			const reqCount = 10
+			for i := 0; i < reqCount; i++ {
+				if _, err := ss.Client.UnaryCall(
+					ctx,
+					&testgrpc.SimpleRequest{
+						Payload: &testgrpc.Payload{
+							Body: make([]byte, largeSize),
+						},
+					},
+					tc.callOpts...,
+				); err != nil {
+					t.Fatalf("ss.Client.UnaryCall failed: %v", err)
+				}
+			}
+
+			const bufferCount = reqCount * 2 // req + resp
+			if len(pool.puts) != bufferCount {
+				t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
+			}
+		})
+	}
+}
+
+type checkBufferPool struct {
+	puts [][]byte
+}
+
+func (p *checkBufferPool) Get(size int) []byte {
+	return make([]byte, size)
+}
+
+func (p *checkBufferPool) Put(bs *[]byte) {
+	p.puts = append(p.puts, *bs)
+}
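
The checkBufferPool stub above records every Put so the tests can assert that each request (and, for unary, each response) buffer actually flows back to the pool. This works because the experimental options accept any implementation of grpc-go's shared buffer pool interface, whose shape is implied by the stub's two methods. A sketch of that shape follows; the exact name and doc comments are assumptions, so verify them against the grpc-go version this commit targets:

// Interface shape implied by checkBufferPool; grpc.SharedBufferPool in
// grpc-go matches these two methods (assumption: check the pinned version).
type SharedBufferPool interface {
	// Get returns a buffer with the specified length.
	Get(length int) []byte
	// Put returns a buffer to the pool.
	Put(buf *[]byte)
}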

rpc_util.go (54 changed lines)
							|  | @ -744,17 +744,19 @@ type payloadInfo struct { | |||
| 	uncompressedBytes []byte | ||||
| } | ||||
| 
 | ||||
| func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) { | ||||
| 	pf, buf, err := p.recvMsg(maxReceiveMessageSize) | ||||
| // recvAndDecompress reads a message from the stream, decompressing it if necessary.
 | ||||
| //
 | ||||
| // Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
 | ||||
| // the buffer is no longer needed.
 | ||||
| func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, | ||||
| ) (uncompressedBuf []byte, cancel func(), err error) { | ||||
| 	pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if payInfo != nil { | ||||
| 		payInfo.compressedLength = len(buf) | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { | ||||
| 		return nil, st.Err() | ||||
| 		return nil, nil, st.Err() | ||||
| 	} | ||||
| 
 | ||||
| 	var size int | ||||
@@ -762,21 +764,35 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize
 		// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
 		// use this decompressor as the default.
 		if dc != nil {
-			buf, err = dc.Do(bytes.NewReader(buf))
-			size = len(buf)
+			uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
+			size = len(uncompressedBuf)
 		} else {
-			buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
+			uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
 		}
 		if err != nil {
-			return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
+			return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
 		}
 		if size > maxReceiveMessageSize {
 			// TODO: Revisit the error code. Currently keep it consistent with java
 			// implementation.
-			return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
+			return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
 		}
+	} else {
+		uncompressedBuf = compressedBuf
 	}
 
-	return buf, nil
+	if payInfo != nil {
+		payInfo.compressedLength = len(compressedBuf)
+		payInfo.uncompressedBytes = uncompressedBuf
+
+		cancel = func() {}
+	} else {
+		cancel = func() {
+			p.recvBufferPool.Put(&compressedBuf)
+		}
+	}
+
+	return uncompressedBuf, cancel, nil
 }
 
 // Using compressor, decompress d, returning data and size.
@@ -796,6 +812,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
 			// size is used as an estimate to size the buffer, but we
 			// will read more data if available.
 			// +MinRead so ReadFrom will not reallocate if size is correct.
+			//
+			// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
+			// we can also utilize the recv buffer pool here.
 			buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
 			bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
 			return buf.Bytes(), int(bytesRead), err
@@ -811,18 +830,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
 // dc takes precedence over compressor.
 // TODO(dfawley): wrap the old compressor/decompressor using the new API?
 func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
-	buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
+	buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
 	if err != nil {
 		return err
 	}
+	defer cancel()
 
 	if err := c.Unmarshal(buf, m); err != nil {
 		return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
 	}
-	if payInfo != nil {
-		payInfo.uncompressedBytes = buf
-	} else {
-		p.recvBufferPool.Put(&buf)
-	}
 	return nil
 }
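
The new signature replaces recv's hard-coded "put it back unless payInfo needs it" branch with a cancel closure that owns the recycling decision, so each caller can hold the buffer exactly as long as it needs it. A self-contained sketch of that pattern (toy pool and names, not grpc-go internals):

package main

import "fmt"

// pool is a toy free list standing in for the recv buffer pool.
type pool struct{ free [][]byte }

func (p *pool) get(n int) []byte { return make([]byte, n) } // recycling elided for brevity
func (p *pool) put(b *[]byte)    { p.free = append(p.free, *b) }

// read hands back data plus a cancel func that recycles the backing buffer;
// callers defer cancel until they have finished consuming the data.
func read(p *pool) (buf []byte, cancel func()) {
	buf = p.get(8)
	return buf, func() { p.put(&buf) }
}

func main() {
	p := &pool{}
	buf, cancel := read(p)
	defer cancel() // release only after buf is no longer needed
	fmt.Println(len(buf), "bytes available")
}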
server.go
@@ -1340,7 +1340,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport
 	if len(shs) != 0 || len(binlogs) != 0 {
 		payInfo = &payloadInfo{}
 	}
-	d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
+
+	d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
 	if err != nil {
 		if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
 			channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
@@ -1351,6 +1352,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport
 		t.IncrMsgRecv()
 	}
 	df := func(v any) error {
+		defer cancel()
+
 		if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
 			return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
 		}
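
Deferring cancel inside df, rather than at the recvAndDecompress call site, matters because the request bytes in d are unmarshalled only when the method handler (possibly behind interceptors) finally invokes df; recycling earlier would let another RPC's receive path reuse the buffer while d still points at it. A toy illustration of that hazard (simplified free list, illustrative only):

package main

import "fmt"

func main() {
	var freeList [][]byte
	get := func() []byte {
		if n := len(freeList); n > 0 {
			b := freeList[n-1]
			freeList = freeList[:n-1]
			return b
		}
		return make([]byte, 4)
	}
	put := func(b []byte) { freeList = append(freeList, b) }

	req := get()
	copy(req, "ping")
	put(req) // recycled while req is still pending decode

	next := get() // pool hands out the same backing array
	copy(next, "pong")
	fmt.Printf("%s\n", req) // prints "pong": the pending request was clobbered
}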