From 8c908a8c1def2d9eeecce2cf671bc6480d26578e Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 26 Jul 2016 16:44:49 -0700 Subject: [PATCH] Reject over-sized requests on server --- call.go | 3 ++- rpc_util.go | 11 +++++--- server.go | 35 +++++++++++++----------- stream.go | 8 +++--- test/end2end_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 97 insertions(+), 23 deletions(-) diff --git a/call.go b/call.go index 27cf64115..0df314d81 100644 --- a/call.go +++ b/call.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "time" "golang.org/x/net/context" @@ -57,7 +58,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s } p := &parser{r: stream} for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil { if err == io.EOF { break } diff --git a/rpc_util.go b/rpc_util.go index d62871756..431eb3eca 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -308,7 +308,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error { pf, d, err := p.recvMsg() if err != nil { return err @@ -319,11 +319,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if pf == compressionMade { d, err = dc.Do(bytes.NewReader(d)) if err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } + if len(d) > maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with java + // implementation. + return Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) + } if err := c.Unmarshal(d, m); err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } return nil } diff --git a/server.go b/server.go index b167bd057..d2f0936b3 100644 --- a/server.go +++ b/server.go @@ -105,14 +105,14 @@ type options struct { codec Codec cp Compressor dc Decompressor - msgLimit int + maxMsgSize int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } -var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit +var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit // A ServerOption sets options. type ServerOption func(*options) @@ -124,23 +124,25 @@ func CustomCodec(codec Codec) ServerOption { } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound message. +// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc } } -func MsgLimit(m int) ServerOption { +// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages. +// If this is not set, gRPC uses the default 4MB. +func MaxMsgSize(m int) ServerOption { return func(o *options) { - o.msgLimit = m + o.maxMsgSize = m } } @@ -186,7 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { var opts options - opts.msgLimit = defaultMsgLimit + opts.maxMsgSize = defaultMaxMsgSize for _, o := range opt { o(&opts) } @@ -585,11 +587,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } } - if len(req) > s.opts.msgLimit { + if len(req) > s.opts.maxMsgSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. statusCode = codes.Internal - statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit) + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err @@ -650,13 +652,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - trInfo: trInfo, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxMsgSize: s.opts.maxMsgSize, + trInfo: trInfo, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) diff --git a/stream.go b/stream.go index fb7e50f9c..f06f137da 100644 --- a/stream.go +++ b/stream.go @@ -37,6 +37,7 @@ import ( "bytes" "errors" "io" + "math" "sync" "time" @@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return } // Special handling for client streaming rpc. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -411,6 +412,7 @@ type serverStream struct { cp Compressor dc Decompressor cbuf *bytes.Buffer + maxMsgSize int statusCode codes.Code statusDesc string trInfo *traceInfo @@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, ss.s, ss.dc, m) + return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize) } diff --git a/test/end2end_test.go b/test/end2end_test.go index cdbc4c555..b98c2c6d6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -373,6 +373,7 @@ type test struct { testServer testpb.TestServiceServer // nil means none healthServer *health.HealthServer // nil means disabled maxStream uint32 + maxMsgSize int userAgent string clientCompression bool serverCompression bool @@ -423,6 +424,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) { e := te.e te.t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} + if te.maxMsgSize > 0 { + sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) + } if te.serverCompression { sopts = append(sopts, grpc.RPCCompressor(grpc.NewGZIPCompressor()), @@ -956,6 +960,65 @@ func testLargeUnary(t *testing.T, e env) { } } +func TestExceedMsgLimit(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testExceedMsgLimit(t, e) + } +} + +func testExceedMsgLimit(t *testing.T, e env) { + te := newTest(t, e) + te.maxMsgSize = 1024 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + argSize := int32(te.maxMsgSize + 1) + const respSize = 1 + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: ", err, codes.Internal) + } + + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + + spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1)) + if err != nil { + t.Fatal(err) + } + + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: spayload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("%v.Recv() = _, %v, want _, error code: ", stream, err, codes.Internal) + } +} + func TestMetadataUnaryRPC(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() {