Support compression

This commit is contained in:
iamqizhao 2016-01-22 18:21:41 -08:00
parent 5da22b92e9
commit da3bb0c9f7
12 changed files with 481 additions and 120 deletions

21
call.go
View File

@ -34,6 +34,7 @@
package grpc
import (
"bytes"
"io"
"time"
@ -47,7 +48,7 @@ import (
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
c.headerMD, err = stream.Header()
@ -56,7 +57,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
}
p := &parser{s: stream}
for {
if err = recv(p, codec, reply); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil {
if err == io.EOF {
break
}
@ -68,7 +69,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
}
// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
@ -80,8 +81,7 @@ func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t
}
}
}()
// TODO(zhaoq): Support compression.
outBuf, err := encode(codec, args, compressionNone)
outBuf, err := encode(codec, args, compressor, new(bytes.Buffer))
if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@ -129,7 +129,11 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
var (
lastErr error // record the error that happened
cp Compressor
)
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
for {
var (
err error
@ -144,6 +148,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Host: cc.authority,
Method: method,
}
if cp != nil {
callHdr.SendCompress = cp.Type()
}
t, err = cc.dopts.picker.Pick(ctx)
if err != nil {
if lastErr != nil {
@ -155,7 +162,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
@ -167,7 +174,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return toRPCErr(err)
}
// Receive the response
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok {
continue
}

View File

@ -98,7 +98,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
reply, err := encode(testCodec{}, &expectedResponse, compressionNone)
reply, err := encode(testCodec{}, &expectedResponse, nil, nil)
if err != nil {
t.Fatalf("Failed to encode the response: %v", err)
}

View File

@ -73,6 +73,8 @@ var (
// values passed to Dial.
type dialOptions struct {
codec Codec
cg CompressorGenerator
dg DecompressorGenerator
picker Picker
block bool
insecure bool
@ -89,6 +91,18 @@ func WithCodec(c Codec) DialOption {
}
}
func WithCompressor(f CompressorGenerator) DialOption {
return func(o *dialOptions) {
o.cg = f
}
}
func WithDecompressor(f DecompressorGenerator) DialOption {
return func(o *dialOptions) {
o.dg = f
}
}
// WithPicker returns a DialOption which sets a picker for connection selection.
func WithPicker(p Picker) DialOption {
return func(o *dialOptions) {

View File

@ -34,9 +34,12 @@
package grpc
import (
"bytes"
"compress/gzip"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"math"
"math/rand"
"os"
@ -75,6 +78,59 @@ func (protoCodec) String() string {
return "proto"
}
type Compressor interface {
Do(w io.Writer, p []byte) error
Type() string
}
func NewGZIPCompressor() Compressor {
return &gzipCompressor{}
}
type gzipCompressor struct {
}
func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
z := gzip.NewWriter(w)
if _, err := z.Write(p); err != nil {
return err
}
return z.Close()
}
func (c *gzipCompressor) Type() string {
return "gzip"
}
type Decompressor interface {
Do(r io.Reader) ([]byte, error)
Type() string
}
type gzipDecompressor struct {
}
func NewGZIPDecompressor() Decompressor {
return &gzipDecompressor{}
}
func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
z, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
defer z.Close()
return ioutil.ReadAll(z)
}
func (d *gzipDecompressor) Type() string {
return "gzip"
}
type CompressorGenerator func() Compressor
type DecompressorGenerator func() Decompressor
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
failFast bool
@ -126,8 +182,7 @@ type payloadFormat uint8
const (
compressionNone payloadFormat = iota // no compression
compressionFlate
// More formats
compressionMade
)
// parser reads complelete gRPC messages from the underlying reader.
@ -166,7 +221,7 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
// encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length.
func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
var b []byte
var length uint
if msg != nil {
@ -176,6 +231,12 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
if err != nil {
return nil, err
}
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
return nil, err
}
b = cbuf.Bytes()
}
length = uint(len(b))
}
if length > math.MaxUint32 {
@ -190,7 +251,11 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
var buf = make([]byte, payloadLen+sizeLen+len(b))
// Write payload format
buf[0] = byte(pf)
if cp == nil {
buf[0] = byte(compressionNone)
} else {
buf[0] = byte(compressionMade)
}
// Write length of b into buf
binary.BigEndian.PutUint32(buf[1:], uint32(length))
// Copy encoded msg to buf
@ -199,22 +264,42 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
return buf, nil
}
func recv(p *parser, c Codec, m interface{}) error {
func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
return nil
}
func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error {
pf, d, err := p.recvMsg()
if err != nil {
return err
}
switch pf {
case compressionNone:
var dc Decompressor
if pf == compressionMade && dg != nil {
dc = dg()
}
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err
}
if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
if err := c.Unmarshal(d, m); err != nil {
if rErr, ok := err.(rpcError); ok {
return rErr
} else {
return Errorf(codes.Internal, "grpc: %v", err)
}
}
default:
return Errorf(codes.Internal, "gprc: compression is not supported yet.")
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}

View File

@ -106,16 +106,40 @@ func TestEncode(t *testing.T) {
for _, test := range []struct {
// input
msg proto.Message
pt payloadFormat
cp Compressor
// outputs
b []byte
err error
}{
{nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil},
{nil, nil, []byte{0, 0, 0, 0, 0}, nil},
} {
b, err := encode(protoCodec{}, test.msg, test.pt)
b, err := encode(protoCodec{}, test.msg, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) {
t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err)
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
}
}
}
func TestCompress(t *testing.T) {
for _, test := range []struct {
// input
data []byte
cp Compressor
dc Decompressor
// outputs
err error
}{
{make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil},
} {
b := new(bytes.Buffer)
if err := test.cp.Do(b, test.data); err != test.err {
t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err)
}
if b.Len() >= len(test.data) {
t.Fatalf("The compressor fails to compress data.")
}
if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) {
t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data)
}
}
}
@ -158,12 +182,12 @@ func TestContextErr(t *testing.T) {
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encoded, _ := encode(protoCodec{}, msg, compressionNone)
encoded, _ := encode(protoCodec{}, msg, nil, nil)
encodedSz := int64(len(encoded))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encode(protoCodec{}, msg, compressionNone)
encode(protoCodec{}, msg, nil, nil)
}
b.SetBytes(encodedSz)
}

View File

@ -34,6 +34,7 @@
package grpc
import (
"bytes"
"errors"
"fmt"
"io"
@ -92,6 +93,8 @@ type Server struct {
type options struct {
creds credentials.Credentials
codec Codec
cg CompressorGenerator
dg DecompressorGenerator
maxConcurrentStreams uint32
}
@ -105,6 +108,18 @@ func CustomCodec(codec Codec) ServerOption {
}
}
func CompressON(f CompressorGenerator) ServerOption {
return func(o *options) {
o.cg = f
}
}
func DecompressON(f DecompressorGenerator) ServerOption {
return func(o *options) {
o.dg = f
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@ -287,8 +302,8 @@ func (s *Server) Serve(lis net.Listener) error {
}
}
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
p, err := encode(s.opts.codec, msg, pf)
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
p, err := encode(s.opts.codec, msg, cp, new(bytes.Buffer))
if err != nil {
// This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program
@ -327,18 +342,47 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// Nothing to do here.
case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
}
return err
}
switch pf {
case compressionNone:
var dc Decompressor
if pf == compressionMade && s.opts.dg != nil {
dc = s.opts.dg()
}
if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil {
switch err := err.(type) {
case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
default:
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
}
return err
}
statusCode := codes.OK
statusDesc := ""
df := func(v interface{}) error {
if pf == compressionMade {
var err error
req, err = dc.Do(bytes.NewReader(req))
//req, err = ioutil.ReadAll(dc)
//defer dc.Close()
if err != nil {
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
return err
}
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
@ -360,7 +404,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
trInfo.tr.LazyLog(stringer(statusDesc), true)
trInfo.tr.SetError()
}
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
return err
@ -374,7 +417,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
Last: true,
Delay: false,
}
if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil {
var cp Compressor
if s.opts.cg != nil {
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
}
if err := s.sendResponse(t, stream, reply, cp, opts); err != nil {
switch err := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
@ -391,18 +439,22 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
return t.WriteStatus(stream, statusCode, statusDesc)
default:
panic(fmt.Sprintf("payload format to be supported: %d", pf))
}
}
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
var cp Compressor
if s.opts.cg != nil {
cp = s.opts.cg()
stream.SetSendCompress(cp.Type())
}
ss := &serverStream{
t: t,
s: stream,
p: &parser{s: stream},
codec: s.opts.codec,
cp: cp,
dg: s.opts.dg,
trInfo: trInfo,
}
if trInfo != nil {
@ -422,6 +474,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if err, ok := appErr.(rpcError); ok {
ss.statusCode = err.code
ss.statusDesc = err.desc
} else if err, ok := appErr.(transport.StreamError); ok {
ss.statusCode = err.Code
ss.statusDesc = err.Desc
} else {
ss.statusCode = convertCode(appErr)
ss.statusDesc = appErr.Error()

View File

@ -34,6 +34,7 @@
package grpc
import (
"bytes"
"errors"
"io"
"sync"
@ -104,14 +105,23 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if err != nil {
return nil, toRPCErr(err)
}
var cp Compressor
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
}
if cp != nil {
callHdr.SendCompress = cp.Type()
}
cs := &clientStream{
desc: desc,
codec: cc.dopts.codec,
cp: cp,
dg: cc.dopts.dg,
tracing: EnableTracing,
}
if cs.tracing {
@ -153,6 +163,9 @@ type clientStream struct {
p *parser
desc *StreamDesc
codec Codec
cp Compressor
cbuf bytes.Buffer
dg DecompressorGenerator
tracing bool // set to EnableTracing when the clientStream is created.
@ -198,7 +211,8 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
err = toRPCErr(err)
}()
out, err := encode(cs.codec, m, compressionNone)
out, err := encode(cs.codec, m, cs.cp, &cs.cbuf)
defer cs.cbuf.Reset()
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@ -206,7 +220,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, m)
err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@ -225,7 +239,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return
}
// Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, m)
err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -310,6 +324,9 @@ type serverStream struct {
s *transport.Stream
p *parser
codec Codec
cp Compressor
dg DecompressorGenerator
cbuf bytes.Buffer
statusCode codes.Code
statusDesc string
trInfo *traceInfo
@ -348,7 +365,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
out, err := encode(ss.codec, m, compressionNone)
out, err := encode(ss.codec, m, ss.cp, &ss.cbuf)
defer ss.cbuf.Reset()
if err != nil {
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
return err
@ -371,5 +389,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
return recv(ss.p, ss.codec, m)
return recv(ss.p, ss.codec, ss.s, ss.dg, m)
}

View File

@ -143,7 +143,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
if err != nil {
return nil, err
}
return &testpb.SimpleResponse{
Payload: payload,
}, nil
@ -328,8 +327,8 @@ func listTestEnv() []env {
return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
}
func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc.Server, cc *grpc.ClientConn) {
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)}
func serverSetUp(t *testing.T, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) {
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)}
la := ":0"
switch e.network {
case "unix":
@ -353,7 +352,7 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e
}
testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
go s.Serve(lis)
addr := la
addr = la
switch e.network {
case "unix":
default:
@ -363,17 +362,22 @@ func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e
}
addr = "localhost:" + port
}
return
}
func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) {
var derr error
if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua))
cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
} else {
cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua))
cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
}
if err != nil {
t.Fatalf("Dial(%q) = %v", addr, err)
if derr != nil {
t.Fatalf("Dial(%q) = %v", addr, derr)
}
return
}
@ -390,7 +394,8 @@ func TestTimeoutOnDeadServer(t *testing.T) {
}
func testTimeoutOnDeadServer(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@ -443,7 +448,8 @@ func TestHealthCheckOnSuccess(t *testing.T) {
func testHealthCheckOnSuccess(t *testing.T, e env) {
hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.Health", 1)
s, cc := setUp(t, hs, math.MaxUint32, "", e)
s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.Health"); err != nil {
t.Fatalf("Health/Check(_, _) = _, %v, want _, <nil>", err)
@ -459,7 +465,8 @@ func TestHealthCheckOnFailure(t *testing.T) {
func testHealthCheckOnFailure(t *testing.T, e env) {
hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1)
s, cc := setUp(t, hs, math.MaxUint32, "", e)
s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded)
@ -473,7 +480,8 @@ func TestHealthCheckOff(t *testing.T) {
}
func testHealthCheckOff(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented)
@ -488,7 +496,8 @@ func TestHealthCheckServingStatus(t *testing.T) {
func testHealthCheckServingStatus(t *testing.T, e env) {
hs := health.NewHealthServer()
s, cc := setUp(t, hs, math.MaxUint32, "", e)
s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
out, err := healthCheck(1*time.Second, cc, "")
if err != nil {
@ -526,7 +535,8 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) {
}
func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, testAppUA, e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, testAppUA, e)
// Wait until cc is connected.
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@ -569,7 +579,8 @@ func TestFailedEmptyUnary(t *testing.T) {
}
func testFailedEmptyUnary(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@ -585,7 +596,8 @@ func TestLargeUnary(t *testing.T) {
}
func testLargeUnary(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
@ -619,7 +631,8 @@ func TestMetadataUnaryRPC(t *testing.T) {
}
func testMetadataUnaryRPC(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -684,7 +697,8 @@ func TestRetry(t *testing.T) {
// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy
// and error-prone paths.
func testRetry(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
var wg sync.WaitGroup
@ -714,7 +728,8 @@ func TestRPCTimeout(t *testing.T) {
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func testRPCTimeout(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -746,7 +761,8 @@ func TestCancel(t *testing.T) {
}
func testCancel(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -778,7 +794,8 @@ func TestCancelNoIO(t *testing.T) {
func testCancelNoIO(t *testing.T, e env) {
// Only allows 1 live stream per server transport.
s, cc := setUp(t, nil, 1, "", e)
s, addr := serverSetUp(t, nil, 1, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx, cancel := context.WithCancel(context.Background())
@ -829,7 +846,8 @@ func TestPingPong(t *testing.T) {
}
func testPingPong(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.FullDuplexCall(context.Background())
@ -886,7 +904,8 @@ func TestMetadataStreamingRPC(t *testing.T) {
}
func testMetadataStreamingRPC(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@ -952,7 +971,8 @@ func TestServerStreaming(t *testing.T) {
}
func testServerStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -1004,7 +1024,8 @@ func TestFailedServerStreaming(t *testing.T) {
}
func testFailedServerStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -1034,7 +1055,8 @@ func TestClientStreaming(t *testing.T) {
}
func testClientStreaming(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.StreamingInputCall(context.Background())
@ -1074,7 +1096,8 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
func testExceedMaxStreamsLimit(t *testing.T, e env) {
// Only allows 1 live stream per server transport.
s, cc := setUp(t, nil, 1, "", e)
s, addr := serverSetUp(t, nil, 1, nil, nil, e)
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
_, err := tc.StreamingInputCall(context.Background())
@ -1095,3 +1118,109 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
}
}
func TestCompressServerHasNoSupport(t *testing.T) {
for _, e := range listTestEnv() {
testCompressServerHasNoSupport(t, e)
}
}
func testCompressServerHasNoSupport(t *testing.T, e env) {
s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
respSize := 314159
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(int32(respSize)),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %d", err, codes.InvalidArgument)
}
// Streaming RPC
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(31415),
},
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
t.Fatalf("%v.Recv() = %v, want error code %d", stream, err, codes.InvalidArgument)
}
}
func TestCompressOK(t *testing.T) {
for _, e := range listTestEnv() {
testCompressOK(t, e)
}
}
func testCompressOK(t *testing.T, e env) {
s, addr := serverSetUp(t, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e)
cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
respSize := 314159
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(int32(respSize)),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err != nil {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
}
// Streaming RPC
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(31415),
},
}
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
if err != nil {
t.Fatal(err)
}
sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
}
}

View File

@ -210,6 +210,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s := &Stream{
id: t.nextID,
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
fc: fc,
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
@ -322,6 +323,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.SendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
}
@ -694,8 +698,10 @@ func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
if !endHeaders {
return s
}
s.mu.Lock()
if !endStream {
s.recvCompress = hDec.state.encoding
}
if !s.headerDone {
if !endStream && len(hDec.state.mdata) > 0 {
s.header = hDec.state.mdata

View File

@ -164,6 +164,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
if !endHeaders {
return s
}
s.recvCompress = hDec.state.encoding
if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else {
@ -190,6 +191,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
ctx: s.ctx,
recv: s.buf,
}
s.recvCompress = hDec.state.encoding
s.method = hDec.state.method
t.mu.Lock()
if t.state != reachable {
@ -446,6 +448,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
for k, v := range md {
for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
@ -520,6 +525,9 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
p := http2.HeadersFrameParam{
StreamID: s.id,
BlockFragment: t.hBuf.Bytes(),

View File

@ -89,6 +89,7 @@ var (
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
encoding string
// statusCode caches the stream status received from the trailer
// the server sent. Client side only.
statusCode codes.Code
@ -145,6 +146,8 @@ func newHPACKDecoder() *hpackDecoder {
d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header")
return
}
case "grpc-encoding":
d.state.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {

View File

@ -171,6 +171,8 @@ type Stream struct {
cancel context.CancelFunc
// method records the associated RPC method of the stream.
method string
recvCompress string
sendCompress string
buf *recvBuffer
dec io.Reader
fc *inFlow
@ -201,6 +203,14 @@ type Stream struct {
statusDesc string
}
func (s *Stream) RecvCompress() string {
return s.recvCompress
}
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired.
@ -350,6 +360,8 @@ type Options struct {
type CallHdr struct {
Host string // peer host
Method string // the operation to perform on the specified host
RecvCompress string
SendCompress string
}
// ClientTransport is the common interface for all gRPC client side transport