mirror of https://github.com/grpc/grpc-go.git
rpc_util: Fix RecvBufferPool deactivation issues (#6766)
This commit is contained in:
parent
9d981b0eb0
commit
d076e14b48
|
@ -26,12 +26,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/encoding/gzip"
|
||||||
"google.golang.org/grpc/experimental"
|
"google.golang.org/grpc/experimental"
|
||||||
"google.golang.org/grpc/internal/grpctest"
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
"google.golang.org/grpc/internal/stubserver"
|
"google.golang.org/grpc/internal/stubserver"
|
||||||
|
|
||||||
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type s struct {
|
type s struct {
|
||||||
|
@ -44,59 +44,161 @@ func Test(t *testing.T) {
|
||||||
|
|
||||||
const defaultTestTimeout = 10 * time.Second
|
const defaultTestTimeout = 10 * time.Second
|
||||||
|
|
||||||
func (s) TestRecvBufferPool(t *testing.T) {
|
func (s) TestRecvBufferPoolStream(t *testing.T) {
|
||||||
ss := &stubserver.StubServer{
|
tcs := []struct {
|
||||||
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
|
name string
|
||||||
for i := 0; i < 10; i++ {
|
callOpts []grpc.CallOption
|
||||||
preparedMsg := &grpc.PreparedMsg{}
|
}{
|
||||||
err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
|
{
|
||||||
Payload: &testpb.Payload{
|
name: "default",
|
||||||
Body: []byte{'0' + uint8(i)},
|
},
|
||||||
},
|
{
|
||||||
})
|
name: "useCompressor",
|
||||||
if err != nil {
|
callOpts: []grpc.CallOption{
|
||||||
return err
|
grpc.UseCompressor(gzip.Name),
|
||||||
}
|
},
|
||||||
stream.SendMsg(preparedMsg)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
|
|
||||||
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
|
|
||||||
if err := ss.Start(sopts, dopts...); err != nil {
|
|
||||||
t.Fatalf("Error starting endpoint server: %v", err)
|
|
||||||
}
|
|
||||||
defer ss.Stop()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
for _, tc := range tcs {
|
||||||
defer cancel()
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
const reqCount = 10
|
||||||
|
|
||||||
stream, err := ss.Client.FullDuplexCall(ctx)
|
ss := &stubserver.StubServer{
|
||||||
if err != nil {
|
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
|
||||||
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
|
for i := 0; i < reqCount; i++ {
|
||||||
}
|
preparedMsg := &grpc.PreparedMsg{}
|
||||||
|
if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
|
||||||
|
Payload: &testgrpc.Payload{
|
||||||
|
Body: []byte{'0' + uint8(i)},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stream.SendMsg(preparedMsg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
var ngot int
|
pool := &checkBufferPool{}
|
||||||
var buf bytes.Buffer
|
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
|
||||||
for {
|
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
|
||||||
reply, err := stream.Recv()
|
if err := ss.Start(sopts, dopts...); err != nil {
|
||||||
if err == io.EOF {
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
break
|
}
|
||||||
}
|
defer ss.Stop()
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
}
|
defer cancel()
|
||||||
ngot++
|
|
||||||
if buf.Len() > 0 {
|
stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
|
||||||
buf.WriteByte(',')
|
if err != nil {
|
||||||
}
|
t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
|
||||||
buf.Write(reply.GetPayload().GetBody())
|
}
|
||||||
}
|
|
||||||
if want := 10; ngot != want {
|
var ngot int
|
||||||
t.Errorf("Got %d replies, want %d", ngot, want)
|
var buf bytes.Buffer
|
||||||
}
|
for {
|
||||||
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
|
reply, err := stream.Recv()
|
||||||
t.Errorf("Got replies %q; want %q", got, want)
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ngot++
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
buf.WriteByte(',')
|
||||||
|
}
|
||||||
|
buf.Write(reply.GetPayload().GetBody())
|
||||||
|
}
|
||||||
|
if want := 10; ngot != want {
|
||||||
|
t.Fatalf("Got %d replies, want %d", ngot, want)
|
||||||
|
}
|
||||||
|
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
|
||||||
|
t.Fatalf("Got replies %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pool.puts) != reqCount {
|
||||||
|
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s) TestRecvBufferPoolUnary(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
name string
|
||||||
|
callOpts []grpc.CallOption
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "useCompressor",
|
||||||
|
callOpts: []grpc.CallOption{
|
||||||
|
grpc.UseCompressor(gzip.Name),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
const largeSize = 1024
|
||||||
|
|
||||||
|
ss := &stubserver.StubServer{
|
||||||
|
UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
|
||||||
|
return &testgrpc.SimpleResponse{
|
||||||
|
Payload: &testgrpc.Payload{
|
||||||
|
Body: make([]byte, largeSize),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := &checkBufferPool{}
|
||||||
|
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
|
||||||
|
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
|
||||||
|
if err := ss.Start(sopts, dopts...); err != nil {
|
||||||
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
const reqCount = 10
|
||||||
|
for i := 0; i < reqCount; i++ {
|
||||||
|
if _, err := ss.Client.UnaryCall(
|
||||||
|
ctx,
|
||||||
|
&testgrpc.SimpleRequest{
|
||||||
|
Payload: &testgrpc.Payload{
|
||||||
|
Body: make([]byte, largeSize),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tc.callOpts...,
|
||||||
|
); err != nil {
|
||||||
|
t.Fatalf("ss.Client.UnaryCall failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const bufferCount = reqCount * 2 // req + resp
|
||||||
|
if len(pool.puts) != bufferCount {
|
||||||
|
t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type checkBufferPool struct {
|
||||||
|
puts [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *checkBufferPool) Get(size int) []byte {
|
||||||
|
return make([]byte, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *checkBufferPool) Put(bs *[]byte) {
|
||||||
|
p.puts = append(p.puts, *bs)
|
||||||
|
}
|
||||||
|
|
54
rpc_util.go
54
rpc_util.go
|
@ -744,17 +744,19 @@ type payloadInfo struct {
|
||||||
uncompressedBytes []byte
|
uncompressedBytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
|
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
|
||||||
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
|
//
|
||||||
|
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
|
||||||
|
// the buffer is no longer needed.
|
||||||
|
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
|
||||||
|
) (uncompressedBuf []byte, cancel func(), err error) {
|
||||||
|
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
|
||||||
if payInfo != nil {
|
|
||||||
payInfo.compressedLength = len(buf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
|
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
|
||||||
return nil, st.Err()
|
return nil, nil, st.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
var size int
|
var size int
|
||||||
|
@ -762,21 +764,35 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
|
||||||
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
||||||
// use this decompressor as the default.
|
// use this decompressor as the default.
|
||||||
if dc != nil {
|
if dc != nil {
|
||||||
buf, err = dc.Do(bytes.NewReader(buf))
|
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
|
||||||
size = len(buf)
|
size = len(uncompressedBuf)
|
||||||
} else {
|
} else {
|
||||||
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
|
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
||||||
}
|
}
|
||||||
if size > maxReceiveMessageSize {
|
if size > maxReceiveMessageSize {
|
||||||
// TODO: Revisit the error code. Currently keep it consistent with java
|
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||||
// implementation.
|
// implementation.
|
||||||
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
|
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
uncompressedBuf = compressedBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
if payInfo != nil {
|
||||||
|
payInfo.compressedLength = len(compressedBuf)
|
||||||
|
payInfo.uncompressedBytes = uncompressedBuf
|
||||||
|
|
||||||
|
cancel = func() {}
|
||||||
|
} else {
|
||||||
|
cancel = func() {
|
||||||
|
p.recvBufferPool.Put(&compressedBuf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return buf, nil
|
|
||||||
|
return uncompressedBuf, cancel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Using compressor, decompress d, returning data and size.
|
// Using compressor, decompress d, returning data and size.
|
||||||
|
@ -796,6 +812,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
|
||||||
// size is used as an estimate to size the buffer, but we
|
// size is used as an estimate to size the buffer, but we
|
||||||
// will read more data if available.
|
// will read more data if available.
|
||||||
// +MinRead so ReadFrom will not reallocate if size is correct.
|
// +MinRead so ReadFrom will not reallocate if size is correct.
|
||||||
|
//
|
||||||
|
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
|
||||||
|
// we can also utilize the recv buffer pool here.
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
|
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
|
||||||
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
||||||
return buf.Bytes(), int(bytesRead), err
|
return buf.Bytes(), int(bytesRead), err
|
||||||
|
@ -811,18 +830,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
|
||||||
// dc takes precedence over compressor.
|
// dc takes precedence over compressor.
|
||||||
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
||||||
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
|
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
|
||||||
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
|
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
if err := c.Unmarshal(buf, m); err != nil {
|
if err := c.Unmarshal(buf, m); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
|
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
|
||||||
}
|
}
|
||||||
if payInfo != nil {
|
|
||||||
payInfo.uncompressedBytes = buf
|
|
||||||
} else {
|
|
||||||
p.recvBufferPool.Put(&buf)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1342,7 +1342,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
|
||||||
if len(shs) != 0 || len(binlogs) != 0 {
|
if len(shs) != 0 || len(binlogs) != 0 {
|
||||||
payInfo = &payloadInfo{}
|
payInfo = &payloadInfo{}
|
||||||
}
|
}
|
||||||
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
|
|
||||||
|
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
|
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
|
||||||
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
|
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
|
||||||
|
@ -1353,6 +1354,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
|
||||||
t.IncrMsgRecv()
|
t.IncrMsgRecv()
|
||||||
}
|
}
|
||||||
df := func(v any) error {
|
df := func(v any) error {
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
|
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue