diff --git a/call.go b/call.go index 89d2782ae..3cc8f1fd0 100644 --- a/call.go +++ b/call.go @@ -34,6 +34,7 @@ package grpc import ( + "bytes" "io" "time" @@ -47,7 +48,7 @@ import ( // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. -func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { +func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. var err error c.headerMD, err = stream.Header() @@ -56,7 +57,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream } p := &parser{s: stream} for { - if err = recv(p, codec, reply); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil { if err == io.EOF { break } @@ -68,7 +69,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream } // sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { +func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { stream, err := t.NewStream(ctx, callHdr) if err != nil { return nil, err @@ -80,8 +81,7 @@ func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t } } }() - // TODO(zhaoq): Support compression. - outBuf, err := encode(codec, args, compressionNone) + outBuf, err := encode(codec, args, compressor, new(bytes.Buffer)) if err != nil { return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) } @@ -129,7 +129,11 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } var ( lastErr error // record the error that happened + cp Compressor ) + if cc.dopts.cg != nil { + cp = cc.dopts.cg() + } for { var ( err error @@ -144,6 +148,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Host: cc.authority, Method: method, } + if cp != nil { + callHdr.SendCompress = cp.Type() + } t, err = cc.dopts.picker.Pick(ctx) if err != nil { if lastErr != nil { @@ -155,7 +162,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } - stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) + stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts) if err != nil { if _, ok := err.(transport.ConnectionError); ok { lastErr = err @@ -167,7 +174,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return toRPCErr(err) } // Receive the response - lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) + lastErr = recvResponse(cc.dopts, t, &c, stream, reply) if _, ok := lastErr.(transport.ConnectionError); ok { continue } diff --git a/call_test.go b/call_test.go index 48d25e50f..22e42c274 100644 --- a/call_test.go +++ b/call_test.go @@ -98,7 +98,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - reply, err := encode(testCodec{}, &expectedResponse, compressionNone) + reply, err := encode(testCodec{}, &expectedResponse, nil, nil) if err != nil { t.Fatalf("Failed to encode the response: %v", err) } diff --git a/clientconn.go b/clientconn.go index 9c2e983b9..e81a48868 100644 --- a/clientconn.go +++ b/clientconn.go @@ -73,6 +73,8 @@ var ( // values passed to Dial. type dialOptions struct { codec Codec + cg CompressorGenerator + dg DecompressorGenerator picker Picker block bool insecure bool @@ -89,6 +91,18 @@ func WithCodec(c Codec) DialOption { } } +func WithCompressor(f CompressorGenerator) DialOption { + return func(o *dialOptions) { + o.cg = f + } +} + +func WithDecompressor(f DecompressorGenerator) DialOption { + return func(o *dialOptions) { + o.dg = f + } +} + // WithPicker returns a DialOption which sets a picker for connection selection. func WithPicker(p Picker) DialOption { return func(o *dialOptions) { diff --git a/rpc_util.go b/rpc_util.go index e6b223681..f48ad32c7 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -34,9 +34,12 @@ package grpc import ( + "bytes" + "compress/gzip" "encoding/binary" "fmt" "io" + "io/ioutil" "math" "math/rand" "os" @@ -75,6 +78,59 @@ func (protoCodec) String() string { return "proto" } +type Compressor interface { + Do(w io.Writer, p []byte) error + Type() string +} + +func NewGZIPCompressor() Compressor { + return &gzipCompressor{} +} + +type gzipCompressor struct { +} + +func (c *gzipCompressor) Do(w io.Writer, p []byte) error { + z := gzip.NewWriter(w) + if _, err := z.Write(p); err != nil { + return err + } + return z.Close() +} + +func (c *gzipCompressor) Type() string { + return "gzip" +} + +type Decompressor interface { + Do(r io.Reader) ([]byte, error) + Type() string +} + +type gzipDecompressor struct { +} + +func NewGZIPDecompressor() Decompressor { + return &gzipDecompressor{} +} + +func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) { + z, err := gzip.NewReader(r) + if err != nil { + return nil, err + } + defer z.Close() + return ioutil.ReadAll(z) +} + +func (d *gzipDecompressor) Type() string { + return "gzip" +} + +type CompressorGenerator func() Compressor + +type DecompressorGenerator func() Decompressor + // callInfo contains all related configuration and information about an RPC. type callInfo struct { failFast bool @@ -126,8 +182,7 @@ type payloadFormat uint8 const ( compressionNone payloadFormat = iota // no compression - compressionFlate - // More formats + compressionMade ) // parser reads complelete gRPC messages from the underlying reader. @@ -166,7 +221,7 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { // encode serializes msg and prepends the message header. If msg is nil, it // generates the message header of 0 message length. -func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { +func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) { var b []byte var length uint if msg != nil { @@ -176,6 +231,12 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { if err != nil { return nil, err } + if cp != nil { + if err := cp.Do(cbuf, b); err != nil { + return nil, err + } + b = cbuf.Bytes() + } length = uint(len(b)) } if length > math.MaxUint32 { @@ -190,7 +251,11 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { var buf = make([]byte, payloadLen+sizeLen+len(b)) // Write payload format - buf[0] = byte(pf) + if cp == nil { + buf[0] = byte(compressionNone) + } else { + buf[0] = byte(compressionMade) + } // Write length of b into buf binary.BigEndian.PutUint32(buf[1:], uint32(length)) // Copy encoded msg to buf @@ -199,22 +264,42 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { return buf, nil } -func recv(p *parser, c Codec, m interface{}) error { +func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error { + switch pf { + case compressionNone: + case compressionMade: + if recvCompress == "" { + return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) + } + if dc == nil || recvCompress != dc.Type() { + return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + } + default: + return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) + } + return nil +} + +func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error { pf, d, err := p.recvMsg() if err != nil { return err } - switch pf { - case compressionNone: - if err := c.Unmarshal(d, m); err != nil { - if rErr, ok := err.(rpcError); ok { - return rErr - } else { - return Errorf(codes.Internal, "grpc: %v", err) - } + var dc Decompressor + if pf == compressionMade && dg != nil { + dc = dg() + } + if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil { + return err + } + if pf == compressionMade { + d, err = dc.Do(bytes.NewReader(d)) + if err != nil { + return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } - default: - return Errorf(codes.Internal, "gprc: compression is not supported yet.") + } + if err := c.Unmarshal(d, m); err != nil { + return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } return nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 2673cd054..3f3749ae3 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -106,16 +106,40 @@ func TestEncode(t *testing.T) { for _, test := range []struct { // input msg proto.Message - pt payloadFormat + cp Compressor // outputs b []byte err error }{ - {nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil}, + {nil, nil, []byte{0, 0, 0, 0, 0}, nil}, } { - b, err := encode(protoCodec{}, test.msg, test.pt) + b, err := encode(protoCodec{}, test.msg, nil, nil) if err != test.err || !bytes.Equal(b, test.b) { - t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err) + t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err) + } + } +} + +func TestCompress(t *testing.T) { + for _, test := range []struct { + // input + data []byte + cp Compressor + dc Decompressor + // outputs + err error + }{ + {make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil}, + } { + b := new(bytes.Buffer) + if err := test.cp.Do(b, test.data); err != test.err { + t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err) + } + if b.Len() >= len(test.data) { + t.Fatalf("The compressor fails to compress data.") + } + if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) { + t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, ", b, p, err, test.data) } } } @@ -158,12 +182,12 @@ func TestContextErr(t *testing.T) { // bytes. func bmEncode(b *testing.B, mSize int) { msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encoded, _ := encode(protoCodec{}, msg, compressionNone) + encoded, _ := encode(protoCodec{}, msg, nil, nil) encodedSz := int64(len(encoded)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - encode(protoCodec{}, msg, compressionNone) + encode(protoCodec{}, msg, nil, nil) } b.SetBytes(encodedSz) } diff --git a/server.go b/server.go index 655e7d865..e846f3ba9 100644 --- a/server.go +++ b/server.go @@ -34,6 +34,7 @@ package grpc import ( + "bytes" "errors" "fmt" "io" @@ -92,6 +93,8 @@ type Server struct { type options struct { creds credentials.Credentials codec Codec + cg CompressorGenerator + dg DecompressorGenerator maxConcurrentStreams uint32 } @@ -105,6 +108,18 @@ func CustomCodec(codec Codec) ServerOption { } } +func CompressON(f CompressorGenerator) ServerOption { + return func(o *options) { + o.cg = f + } +} + +func DecompressON(f DecompressorGenerator) ServerOption { + return func(o *options) { + o.dg = f + } +} + // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -287,8 +302,8 @@ func (s *Server) Serve(lis net.Listener) error { } } -func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error { - p, err := encode(s.opts.codec, msg, pf) +func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { + p, err := encode(s.opts.codec, msg, cp, new(bytes.Buffer)) if err != nil { // This typically indicates a fatal issue (e.g., memory // corruption or hardware faults) the application program @@ -327,82 +342,119 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // Nothing to do here. case transport.StreamError: if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } default: panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) } return err } - switch pf { - case compressionNone: - statusCode := codes.OK - statusDesc := "" - df := func(v interface{}) error { - if err := s.opts.codec.Unmarshal(req, v); err != nil { - return err + + var dc Decompressor + if pf == compressionMade && s.opts.dg != nil { + dc = s.opts.dg() + } + if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil { + switch err := err.(type) { + case transport.StreamError: + if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } - if trInfo != nil { - trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) - } - return nil - } - reply, appErr := md.Handler(srv.server, stream.Context(), df) - if appErr != nil { - if err, ok := appErr.(rpcError); ok { - statusCode = err.code - statusDesc = err.desc - } else { - statusCode = convertCode(appErr) - statusDesc = appErr.Error() - } - if trInfo != nil && statusCode != codes.OK { - trInfo.tr.LazyLog(stringer(statusDesc), true) - trInfo.tr.SetError() + default: + if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } - if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + } + return err + } + statusCode := codes.OK + statusDesc := "" + df := func(v interface{}) error { + if pf == compressionMade { + var err error + req, err = dc.Do(bytes.NewReader(req)) + //req, err = ioutil.ReadAll(dc) + //defer dc.Close() + if err != nil { + if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } return err } - return nil } - if trInfo != nil { - trInfo.tr.LazyLog(stringer("OK"), false) - } - opts := &transport.Options{ - Last: true, - Delay: false, - } - if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil { - switch err := err.(type) { - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - statusCode = err.Code - statusDesc = err.Desc - default: - statusCode = codes.Unknown - statusDesc = err.Error() - } + if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } if trInfo != nil { - trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) + trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) } - return t.WriteStatus(stream, statusCode, statusDesc) - default: - panic(fmt.Sprintf("payload format to be supported: %d", pf)) + return nil } + reply, appErr := md.Handler(srv.server, stream.Context(), df) + if appErr != nil { + if err, ok := appErr.(rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() + } + if trInfo != nil && statusCode != codes.OK { + trInfo.tr.LazyLog(stringer(statusDesc), true) + trInfo.tr.SetError() + } + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + return err + } + return nil + } + if trInfo != nil { + trInfo.tr.LazyLog(stringer("OK"), false) + } + opts := &transport.Options{ + Last: true, + Delay: false, + } + var cp Compressor + if s.opts.cg != nil { + cp = s.opts.cg() + stream.SetSendCompress(cp.Type()) + } + if err := s.sendResponse(t, stream, reply, cp, opts); err != nil { + switch err := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + statusCode = err.Code + statusDesc = err.Desc + default: + statusCode = codes.Unknown + statusDesc = err.Error() + } + return err + } + if trInfo != nil { + trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) + } + return t.WriteStatus(stream, statusCode, statusDesc) } } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { + var cp Compressor + if s.opts.cg != nil { + cp = s.opts.cg() + stream.SetSendCompress(cp.Type()) + } ss := &serverStream{ t: t, s: stream, p: &parser{s: stream}, codec: s.opts.codec, + cp: cp, + dg: s.opts.dg, trInfo: trInfo, } if trInfo != nil { @@ -422,6 +474,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err, ok := appErr.(rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc + } else if err, ok := appErr.(transport.StreamError); ok { + ss.statusCode = err.Code + ss.statusDesc = err.Desc } else { ss.statusCode = convertCode(appErr) ss.statusDesc = appErr.Error() diff --git a/stream.go b/stream.go index d8bdc16b5..1a8d0e459 100644 --- a/stream.go +++ b/stream.go @@ -34,6 +34,7 @@ package grpc import ( + "bytes" "errors" "io" "sync" @@ -104,14 +105,23 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if err != nil { return nil, toRPCErr(err) } + var cp Compressor + if cc.dopts.cg != nil { + cp = cc.dopts.cg() + } // TODO(zhaoq): CallOption is omitted. Add support when it is needed. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, } + if cp != nil { + callHdr.SendCompress = cp.Type() + } cs := &clientStream{ desc: desc, codec: cc.dopts.codec, + cp: cp, + dg: cc.dopts.dg, tracing: EnableTracing, } if cs.tracing { @@ -153,6 +163,9 @@ type clientStream struct { p *parser desc *StreamDesc codec Codec + cp Compressor + cbuf bytes.Buffer + dg DecompressorGenerator tracing bool // set to EnableTracing when the clientStream is created. @@ -198,7 +211,8 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } err = toRPCErr(err) }() - out, err := encode(cs.codec, m, compressionNone) + out, err := encode(cs.codec, m, cs.cp, &cs.cbuf) + defer cs.cbuf.Reset() if err != nil { return transport.StreamErrorf(codes.Internal, "grpc: %v", err) } @@ -206,7 +220,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, m) + err = recv(cs.p, cs.codec, cs.s, cs.dg, m) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -225,7 +239,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return } // Special handling for client streaming rpc. - err = recv(cs.p, cs.codec, m) + err = recv(cs.p, cs.codec, cs.s, cs.dg, m) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -310,6 +324,9 @@ type serverStream struct { s *transport.Stream p *parser codec Codec + cp Compressor + dg DecompressorGenerator + cbuf bytes.Buffer statusCode codes.Code statusDesc string trInfo *traceInfo @@ -348,7 +365,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - out, err := encode(ss.codec, m, compressionNone) + out, err := encode(ss.codec, m, ss.cp, &ss.cbuf) + defer ss.cbuf.Reset() if err != nil { err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) return err @@ -371,5 +389,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, m) + return recv(ss.p, ss.codec, ss.s, ss.dg, m) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 93944f8d9..7823ba5ee 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -143,7 +143,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if err != nil { return nil, err } - return &testpb.SimpleResponse{ Payload: payload, }, nil @@ -328,8 +327,8 @@ func listTestEnv() []env { return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} } -func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc.Server, cc *grpc.ClientConn) { - sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)} +func serverSetUp(t *testing.T, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) { + sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)} la := ":0" switch e.network { case "unix": @@ -353,7 +352,7 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e } testpb.RegisterTestServiceServer(s, &testServer{security: e.security}) go s.Serve(lis) - addr := la + addr = la switch e.network { case "unix": default: @@ -363,17 +362,22 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e } addr = "localhost:" + port } + return +} + +func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) { + var derr error if e.security == "tls" { creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { t.Fatalf("Failed to create credentials %v", err) } - cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) + cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) } else { - cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua)) + cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg)) } - if err != nil { - t.Fatalf("Dial(%q) = %v", addr, err) + if derr != nil { + t.Fatalf("Dial(%q) = %v", addr, derr) } return } @@ -390,7 +394,8 @@ func TestTimeoutOnDeadServer(t *testing.T) { } func testTimeoutOnDeadServer(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) ctx, _ := context.WithTimeout(context.Background(), time.Second) if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil { @@ -443,7 +448,8 @@ func TestHealthCheckOnSuccess(t *testing.T) { func testHealthCheckOnSuccess(t *testing.T, e env) { hs := health.NewHealthServer() hs.SetServingStatus("grpc.health.v1alpha.Health", 1) - s, cc := setUp(t, hs, math.MaxUint32, "", e) + s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) defer tearDown(s, cc) if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.Health"); err != nil { t.Fatalf("Health/Check(_, _) = _, %v, want _, ", err) @@ -459,7 +465,8 @@ func TestHealthCheckOnFailure(t *testing.T) { func testHealthCheckOnFailure(t *testing.T, e env) { hs := health.NewHealthServer() hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1) - s, cc := setUp(t, hs, math.MaxUint32, "", e) + s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) defer tearDown(s, cc) if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) @@ -473,7 +480,8 @@ func TestHealthCheckOff(t *testing.T) { } func testHealthCheckOff(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) defer tearDown(s, cc) if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented) @@ -488,7 +496,8 @@ func TestHealthCheckServingStatus(t *testing.T) { func testHealthCheckServingStatus(t *testing.T, e env) { hs := health.NewHealthServer() - s, cc := setUp(t, hs, math.MaxUint32, "", e) + s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) defer tearDown(s, cc) out, err := healthCheck(1*time.Second, cc, "") if err != nil { @@ -526,7 +535,8 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) { } func testEmptyUnaryWithUserAgent(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, testAppUA, e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, testAppUA, e) // Wait until cc is connected. ctx, _ := context.WithTimeout(context.Background(), time.Second) if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil { @@ -569,7 +579,8 @@ func TestFailedEmptyUnary(t *testing.T) { } func testFailedEmptyUnary(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) @@ -585,7 +596,8 @@ func TestLargeUnary(t *testing.T) { } func testLargeUnary(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 271828 @@ -619,7 +631,8 @@ func TestMetadataUnaryRPC(t *testing.T) { } func testMetadataUnaryRPC(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -684,7 +697,8 @@ func TestRetry(t *testing.T) { // TODO(zhaoq): Refactor to make this clearer and add more cases to test racy // and error-prone paths. func testRetry(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) var wg sync.WaitGroup @@ -714,7 +728,8 @@ func TestRPCTimeout(t *testing.T) { // TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. func testRPCTimeout(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -746,7 +761,8 @@ func TestCancel(t *testing.T) { } func testCancel(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -778,7 +794,8 @@ func TestCancelNoIO(t *testing.T) { func testCancelNoIO(t *testing.T, e env) { // Only allows 1 live stream per server transport. - s, cc := setUp(t, nil, 1, "", e) + s, addr := serverSetUp(t, nil, 1, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx, cancel := context.WithCancel(context.Background()) @@ -829,7 +846,8 @@ func TestPingPong(t *testing.T) { } func testPingPong(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) stream, err := tc.FullDuplexCall(context.Background()) @@ -886,7 +904,8 @@ func TestMetadataStreamingRPC(t *testing.T) { } func testMetadataStreamingRPC(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) @@ -952,7 +971,8 @@ func TestServerStreaming(t *testing.T) { } func testServerStreaming(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) respParam := make([]*testpb.ResponseParameters, len(respSizes)) @@ -1004,7 +1024,8 @@ func TestFailedServerStreaming(t *testing.T) { } func testFailedServerStreaming(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) respParam := make([]*testpb.ResponseParameters, len(respSizes)) @@ -1034,7 +1055,8 @@ func TestClientStreaming(t *testing.T) { } func testClientStreaming(t *testing.T, e env) { - s, cc := setUp(t, nil, math.MaxUint32, "", e) + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) stream, err := tc.StreamingInputCall(context.Background()) @@ -1074,7 +1096,8 @@ func TestExceedMaxStreamsLimit(t *testing.T) { func testExceedMaxStreamsLimit(t *testing.T, e env) { // Only allows 1 live stream per server transport. - s, cc := setUp(t, nil, 1, "", e) + s, addr := serverSetUp(t, nil, 1, nil, nil, e) + cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) _, err := tc.StreamingInputCall(context.Background()) @@ -1095,3 +1118,109 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded) } } + +func TestCompressServerHasNoSupport(t *testing.T) { + for _, e := range listTestEnv() { + testCompressServerHasNoSupport(t, e) + } +} + +func testCompressServerHasNoSupport(t *testing.T, e env) { + s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e) + cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e) + // Unary call + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 271828 + respSize := 314159 + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: payload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %d", err, codes.InvalidArgument) + } + // Streaming RPC + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(31415), + }, + } + payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415)) + if err != nil { + t.Fatal(err) + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument { + t.Fatalf("%v.Recv() = %v, want error code %d", stream, err, codes.InvalidArgument) + } +} + +func TestCompressOK(t *testing.T) { + for _, e := range listTestEnv() { + testCompressOK(t, e) + } +} + +func testCompressOK(t *testing.T, e env) { + s, addr := serverSetUp(t, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e) + cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e) + // Unary call + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 271828 + respSize := 314159 + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: payload, + } + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) + } + // Streaming RPC + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(31415), + }, + } + payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415)) + if err != nil { + t.Fatal(err) + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 9eae37df7..a23f6a723 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -208,12 +208,13 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { } // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ - id: t.nextID, - method: callHdr.Method, - buf: newRecvBuffer(), - fc: fc, - sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), - headerChan: make(chan struct{}), + id: t.nextID, + method: callHdr.Method, + sendCompress: callHdr.SendCompress, + buf: newRecvBuffer(), + fc: fc, + sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), + headerChan: make(chan struct{}), } t.nextID += 2 s.windowHandler = func(n int) { @@ -322,6 +323,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) + if callHdr.SendCompress != "" { + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) + } if timeout > 0 { t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) } @@ -694,8 +698,10 @@ func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame header if !endHeaders { return s } - s.mu.Lock() + if !endStream { + s.recvCompress = hDec.state.encoding + } if !s.headerDone { if !endStream && len(hDec.state.mdata) > 0 { s.header = hDec.state.mdata diff --git a/transport/http2_server.go b/transport/http2_server.go index 98088d9a7..cce2e12d9 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -164,6 +164,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header if !endHeaders { return s } + s.recvCompress = hDec.state.encoding if hDec.state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout) } else { @@ -190,6 +191,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header ctx: s.ctx, recv: s.buf, } + s.recvCompress = hDec.state.encoding s.method = hDec.state.method t.mu.Lock() if t.state != reachable { @@ -446,6 +448,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { t.hBuf.Reset() t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + if s.sendCompress != "" { + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } for k, v := range md { for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) @@ -520,6 +525,9 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { t.hBuf.Reset() t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + if s.sendCompress != "" { + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } p := http2.HeadersFrameParam{ StreamID: s.id, BlockFragment: t.hBuf.Bytes(), diff --git a/transport/http_util.go b/transport/http_util.go index fec4e4755..f9d9fdf0a 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -89,6 +89,7 @@ var ( // Records the states during HPACK decoding. Must be reset once the // decoding of the entire headers are finished. type decodeState struct { + encoding string // statusCode caches the stream status received from the trailer // the server sent. Client side only. statusCode codes.Code @@ -145,6 +146,8 @@ func newHPACKDecoder() *hpackDecoder { d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header") return } + case "grpc-encoding": + d.state.encoding = f.Value case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { diff --git a/transport/transport.go b/transport/transport.go index e1e7f5761..9d6f4647a 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -171,6 +171,8 @@ type Stream struct { cancel context.CancelFunc // method records the associated RPC method of the stream. method string + recvCompress string + sendCompress string buf *recvBuffer dec io.Reader fc *inFlow @@ -201,6 +203,14 @@ type Stream struct { statusDesc string } +func (s *Stream) RecvCompress() string { + return s.recvCompress +} + +func (s *Stream) SetSendCompress(str string) { + s.sendCompress = str +} + // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is cancelled/expired. @@ -350,6 +360,8 @@ type Options struct { type CallHdr struct { Host string // peer host Method string // the operation to perform on the specified host + RecvCompress string + SendCompress string } // ClientTransport is the common interface for all gRPC client side transport