From b13920a0cf5309612439bdd807feaf399acbd41c Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Fri, 29 Jul 2016 16:19:20 -0700 Subject: [PATCH] add the mem alloc guard on server --- call_test.go | 2 +- rpc_util.go | 7 +++++-- rpc_util_test.go | 13 +++++++------ server.go | 6 +++++- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/call_test.go b/call_test.go index 493498586..97eb9c002 100644 --- a/call_test.go +++ b/call_test.go @@ -81,7 +81,7 @@ type testStreamHandler struct { func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { p := &parser{r: s} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(math.MaxInt32) if err == io.EOF { break } diff --git a/rpc_util.go b/rpc_util.go index 431eb3eca..a6a11369c 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -227,7 +227,7 @@ type parser struct { // No other error values or types must be returned, which also means // that the underlying io.Reader must not return an incompatible // error. -func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { +func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } @@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { if length == 0 { return pf, nil, nil } + if length > uint32(maxMsgSize) { + return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) + } // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // of making it for each message: msg = make([]byte, int(length)) @@ -309,7 +312,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er } func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error { - pf, d, err := p.recvMsg() + pf, d, err := p.recvMsg(maxMsgSize) if err != nil { return err } diff --git a/rpc_util_test.go b/rpc_util_test.go index 830546c62..8a813c626 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "reflect" "testing" @@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) { } { buf := bytes.NewReader(test.p) parser := &parser{r: buf} - pt, b, err := parser.recvMsg() + pt, b, err := parser.recvMsg(math.MaxInt32) if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { - t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) + t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) } } } @@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) { {compressionNone, []byte("d")}, } for i, want := range wantRecvs { - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { - t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, ", + t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, ", i, p, pt, data, err, want.pt, want.data) } } - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != io.EOF { - t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v", + t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v", len(wantRecvs), p, pt, data, err, io.EOF) } } diff --git a/server.go b/server.go index 90c265fd0..cfee9db64 100644 --- a/server.go +++ b/server.go @@ -538,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } p := &parser{r: stream} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -548,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err != nil { switch err := err.(type) { + case *rpcError: + if err := t.WriteStatus(stream, err.code, err.desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: