mirror of https://github.com/grpc/grpc-go.git
add the mem alloc guard on server
This commit is contained in:
parent
a4c08780d5
commit
b13920a0cf
|
@ -81,7 +81,7 @@ type testStreamHandler struct {
|
||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
||||||
p := &parser{r: s}
|
p := &parser{r: s}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(math.MaxInt32)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
@ -227,7 +227,7 @@ type parser struct {
|
||||||
// No other error values or types must be returned, which also means
|
// No other error values or types must be returned, which also means
|
||||||
// that the underlying io.Reader must not return an incompatible
|
// that the underlying io.Reader must not return an incompatible
|
||||||
// error.
|
// error.
|
||||||
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
|
||||||
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
||||||
if length == 0 {
|
if length == 0 {
|
||||||
return pf, nil, nil
|
return pf, nil, nil
|
||||||
}
|
}
|
||||||
|
if length > uint32(maxMsgSize) {
|
||||||
|
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
|
||||||
|
}
|
||||||
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
||||||
// of making it for each message:
|
// of making it for each message:
|
||||||
msg = make([]byte, int(length))
|
msg = make([]byte, int(length))
|
||||||
|
@ -309,7 +312,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
||||||
}
|
}
|
||||||
|
|
||||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
|
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
|
||||||
pf, d, err := p.recvMsg()
|
pf, d, err := p.recvMsg(maxMsgSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,6 +36,7 @@ package grpc
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
buf := bytes.NewReader(test.p)
|
buf := bytes.NewReader(test.p)
|
||||||
parser := &parser{r: buf}
|
parser := &parser{r: buf}
|
||||||
pt, b, err := parser.recvMsg()
|
pt, b, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
||||||
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) {
|
||||||
{compressionNone, []byte("d")},
|
{compressionNone, []byte("d")},
|
||||||
}
|
}
|
||||||
for i, want := range wantRecvs {
|
for i, want := range wantRecvs {
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
||||||
t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, <nil>",
|
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
||||||
i, p, pt, data, err, want.pt, want.data)
|
i, p, pt, data, err, want.pt, want.data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v",
|
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
|
||||||
len(wantRecvs), p, pt, data, err, io.EOF)
|
len(wantRecvs), p, pt, data, err, io.EOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -538,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(s.opts.maxMsgSize)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// The entire stream is done (for unary RPC only).
|
// The entire stream is done (for unary RPC only).
|
||||||
return err
|
return err
|
||||||
|
@ -548,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
|
case *rpcError:
|
||||||
|
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
|
||||||
|
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
|
||||||
|
}
|
||||||
case transport.ConnectionError:
|
case transport.ConnectionError:
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
case transport.StreamError:
|
case transport.StreamError:
|
||||||
|
|
Loading…
Reference in New Issue