mirror of https://github.com/grpc/grpc-go.git
Reject over-sized requests on server
This commit is contained in:
parent
f78100723d
commit
8c908a8c1d
3
call.go
3
call.go
|
@ -36,6 +36,7 @@ package grpc
|
|||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
@ -57,7 +58,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
|
|||
}
|
||||
p := &parser{r: stream}
|
||||
for {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
|
11
rpc_util.go
11
rpc_util.go
|
@ -308,7 +308,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
|
||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
|
||||
pf, d, err := p.recvMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -319,11 +319,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
|
|||
if pf == compressionMade {
|
||||
d, err = dc.Do(bytes.NewReader(d))
|
||||
if err != nil {
|
||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
}
|
||||
}
|
||||
if len(d) > maxMsgSize {
|
||||
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||
// implementation.
|
||||
return Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
|
||||
}
|
||||
if err := c.Unmarshal(d, m); err != nil {
|
||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
||||
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
35
server.go
35
server.go
|
@ -105,14 +105,14 @@ type options struct {
|
|||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
msgLimit int
|
||||
maxMsgSize int
|
||||
unaryInt UnaryServerInterceptor
|
||||
streamInt StreamServerInterceptor
|
||||
maxConcurrentStreams uint32
|
||||
useHandlerImpl bool // use http.Handler-based server
|
||||
}
|
||||
|
||||
var defaultMsgLimit = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||
|
||||
// A ServerOption sets options.
|
||||
type ServerOption func(*options)
|
||||
|
@ -124,23 +124,25 @@ func CustomCodec(codec Codec) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
|
||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
|
||||
func RPCCompressor(cp Compressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.cp = cp
|
||||
}
|
||||
}
|
||||
|
||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
|
||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
|
||||
func RPCDecompressor(dc Decompressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.dc = dc
|
||||
}
|
||||
}
|
||||
|
||||
func MsgLimit(m int) ServerOption {
|
||||
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
|
||||
// If this is not set, gRPC uses the default 4MB.
|
||||
func MaxMsgSize(m int) ServerOption {
|
||||
return func(o *options) {
|
||||
o.msgLimit = m
|
||||
o.maxMsgSize = m
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,7 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
|
|||
// started to accept requests yet.
|
||||
func NewServer(opt ...ServerOption) *Server {
|
||||
var opts options
|
||||
opts.msgLimit = defaultMsgLimit
|
||||
opts.maxMsgSize = defaultMaxMsgSize
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
}
|
||||
|
@ -585,11 +587,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
return err
|
||||
}
|
||||
}
|
||||
if len(req) > s.opts.msgLimit {
|
||||
if len(req) > s.opts.maxMsgSize {
|
||||
// TODO: Revisit the error code. Currently keep it consistent with
|
||||
// java implementation.
|
||||
statusCode = codes.Internal
|
||||
statusDesc = fmt.Sprintf("server received a message of %d bytes exceeding %d limit", len(req), s.opts.msgLimit)
|
||||
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
|
||||
}
|
||||
if err := s.opts.codec.Unmarshal(req, v); err != nil {
|
||||
return err
|
||||
|
@ -650,13 +652,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
}
|
||||
ss := &serverStream{
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
codec: s.opts.codec,
|
||||
cp: s.opts.cp,
|
||||
dc: s.opts.dc,
|
||||
trInfo: trInfo,
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
codec: s.opts.codec,
|
||||
cp: s.opts.cp,
|
||||
dc: s.opts.dc,
|
||||
maxMsgSize: s.opts.maxMsgSize,
|
||||
trInfo: trInfo,
|
||||
}
|
||||
if ss.cp != nil {
|
||||
ss.cbuf = new(bytes.Buffer)
|
||||
|
|
|
@ -37,6 +37,7 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
}
|
||||
|
||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||
defer func() {
|
||||
// err != nil indicates the termination of the stream.
|
||||
if err != nil {
|
||||
|
@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
return
|
||||
}
|
||||
// Special handling for client streaming rpc.
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||
cs.closeTransportStream(err)
|
||||
if err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
|
@ -411,6 +412,7 @@ type serverStream struct {
|
|||
cp Compressor
|
||||
dc Decompressor
|
||||
cbuf *bytes.Buffer
|
||||
maxMsgSize int
|
||||
statusCode codes.Code
|
||||
statusDesc string
|
||||
trInfo *traceInfo
|
||||
|
@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||
ss.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
return recv(ss.p, ss.codec, ss.s, ss.dc, m)
|
||||
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
|
||||
}
|
||||
|
|
|
@ -373,6 +373,7 @@ type test struct {
|
|||
testServer testpb.TestServiceServer // nil means none
|
||||
healthServer *health.HealthServer // nil means disabled
|
||||
maxStream uint32
|
||||
maxMsgSize int
|
||||
userAgent string
|
||||
clientCompression bool
|
||||
serverCompression bool
|
||||
|
@ -423,6 +424,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||
e := te.e
|
||||
te.t.Logf("Running test in %s environment...", e.name)
|
||||
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
|
||||
if te.maxMsgSize > 0 {
|
||||
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
|
||||
}
|
||||
if te.serverCompression {
|
||||
sopts = append(sopts,
|
||||
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
|
||||
|
@ -956,6 +960,65 @@ func testLargeUnary(t *testing.T, e env) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExceedMsgLimit(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testExceedMsgLimit(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testExceedMsgLimit(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.maxMsgSize = 1024
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
argSize := int32(te.maxMsgSize + 1)
|
||||
const respSize = 1
|
||||
|
||||
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := &testpb.SimpleRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseSize: proto.Int32(respSize),
|
||||
Payload: payload,
|
||||
}
|
||||
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: ", err, codes.Internal)
|
||||
}
|
||||
|
||||
stream, err := tc.FullDuplexCall(te.ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
}
|
||||
respParam := []*testpb.ResponseParameters{
|
||||
{
|
||||
Size: proto.Int32(1),
|
||||
},
|
||||
}
|
||||
|
||||
spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sreq := &testpb.StreamingOutputCallRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseParameters: respParam,
|
||||
Payload: spayload,
|
||||
}
|
||||
if err := stream.Send(sreq); err != nil {
|
||||
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
|
||||
}
|
||||
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("%v.Recv() = _, %v, want _, error code: ", stream, err, codes.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataUnaryRPC(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
|
|
Loading…
Reference in New Issue