mirror of https://github.com/grpc/grpc-go.git
Introduce new Compressor/Decompressor API (#1428)
This commit is contained in:
parent
246b2f7081
commit
5db344a40a
17
call.go
17
call.go
|
@ -19,7 +19,6 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
|
@ -27,6 +26,7 @@ import (
|
|||
"golang.org/x/net/trace"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/encoding"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
|
@ -62,7 +62,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
|
|||
if c.maxReceiveMessageSize == nil {
|
||||
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
|
||||
}
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(c.compressorType)); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
@ -89,18 +89,17 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
|
|||
}
|
||||
}()
|
||||
var (
|
||||
cbuf *bytes.Buffer
|
||||
outPayload *stats.OutPayload
|
||||
)
|
||||
if compressor != nil {
|
||||
cbuf = new(bytes.Buffer)
|
||||
}
|
||||
if dopts.copts.StatsHandler != nil {
|
||||
outPayload = &stats.OutPayload{
|
||||
Client: true,
|
||||
}
|
||||
}
|
||||
hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
|
||||
if c.compressorType != "" && encoding.GetCompressor(c.compressorType) == nil {
|
||||
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", c.compressorType)
|
||||
}
|
||||
hdr, data, err := encode(dopts.codec, args, compressor, outPayload, encoding.GetCompressor(c.compressorType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -223,7 +222,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||
Host: cc.authority,
|
||||
Method: method,
|
||||
}
|
||||
if cc.dopts.cp != nil {
|
||||
if c.compressorType != "" {
|
||||
callHdr.SendCompress = c.compressorType
|
||||
} else if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
}
|
||||
if c.creds != nil {
|
||||
|
|
|
@ -104,6 +104,16 @@ const (
|
|||
// DialOption configures how we set up the connection.
|
||||
type DialOption func(*dialOptions)
|
||||
|
||||
// UseCompressor returns a CallOption which sets the compressor used when sending the request.
|
||||
// If WithCompressor is set, UseCompressor has higher priority.
|
||||
// This API is EXPERIMENTAL.
|
||||
func UseCompressor(name string) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.compressorType = name
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched
|
||||
// before doing a write on the wire.
|
||||
func WithWriteBufferSize(s int) DialOption {
|
||||
|
@ -156,7 +166,8 @@ func WithCodec(c Codec) DialOption {
|
|||
}
|
||||
|
||||
// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message
|
||||
// compressor.
|
||||
// compressor. It has lower priority than the compressor set by RegisterCompressor.
|
||||
// This function is deprecated.
|
||||
func WithCompressor(cp Compressor) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.cp = cp
|
||||
|
@ -164,7 +175,8 @@ func WithCompressor(cp Compressor) DialOption {
|
|||
}
|
||||
|
||||
// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating
|
||||
// message decompressor.
|
||||
// message decompressor. It has higher priority than the decompressor set by RegisterCompressor.
|
||||
// This function is deprecated.
|
||||
func WithDecompressor(dc Decompressor) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.dc = dc
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2017 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package encoding defines the interface for the compressor and the functions
|
||||
// to register and get the compossor.
|
||||
// This package is EXPERIMENTAL.
|
||||
package encoding
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
var registerCompressor = make(map[string]Compressor)
|
||||
|
||||
// Compressor is used for compressing and decompressing when sending or receiving messages.
|
||||
type Compressor interface {
|
||||
// Compress writes the data written to wc to w after compressing it. If an error
|
||||
// occurs while initializing the compressor, that error is returned instead.
|
||||
Compress(w io.Writer) (io.WriteCloser, error)
|
||||
// Decompress reads data from r, decompresses it, and provides the uncompressed data
|
||||
// via the returned io.Reader. If an error occurs while initializing the decompressor, that error
|
||||
// is returned instead.
|
||||
Decompress(r io.Reader) (io.Reader, error)
|
||||
// Name is the name of the compression codec and is used to set the content coding header.
|
||||
Name() string
|
||||
}
|
||||
|
||||
// RegisterCompressor registers the compressor with gRPC by its name. It can be activated when
|
||||
// sending an RPC via grpc.UseCompressor(). It will be automatically accessed when receiving a
|
||||
// message based on the content coding header. Servers also use it to send a response with the
|
||||
// same encoding as the request.
|
||||
//
|
||||
// NOTE: this function must only be called during initialization time (i.e. in an init() function). If
|
||||
// multiple Compressors are registered with the same name, the one registered last will take effect.
|
||||
func RegisterCompressor(c Compressor) {
|
||||
registerCompressor[c.Name()] = c
|
||||
}
|
||||
|
||||
// GetCompressor returns Compressor for the given compressor name.
|
||||
func GetCompressor(name string) Compressor {
|
||||
return registerCompressor[name]
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2017 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package gzip implements and registers the gzip compressor
|
||||
// during the initialization.
|
||||
// This package is EXPERIMENTAL.
|
||||
package gzip
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/encoding"
|
||||
)
|
||||
|
||||
func init() {
|
||||
c := &compressor{}
|
||||
c.poolCompressor.New = func() interface{} {
|
||||
return &writer{Writer: gzip.NewWriter(ioutil.Discard), pool: &c.poolCompressor}
|
||||
}
|
||||
encoding.RegisterCompressor(c)
|
||||
}
|
||||
|
||||
type writer struct {
|
||||
*gzip.Writer
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
|
||||
z := c.poolCompressor.Get().(*writer)
|
||||
z.Writer.Reset(w)
|
||||
return z, nil
|
||||
}
|
||||
|
||||
func (z *writer) Close() error {
|
||||
defer z.pool.Put(z)
|
||||
return z.Writer.Close()
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
*gzip.Reader
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
func (c *compressor) Decompress(r io.Reader) (io.Reader, error) {
|
||||
z, inPool := c.poolDecompressor.Get().(*reader)
|
||||
if !inPool {
|
||||
newZ, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &reader{Reader: newZ, pool: &c.poolDecompressor}, nil
|
||||
}
|
||||
if err := z.Reset(r); err != nil {
|
||||
c.poolDecompressor.Put(z)
|
||||
return nil, err
|
||||
}
|
||||
return z, nil
|
||||
}
|
||||
|
||||
func (z *reader) Read(p []byte) (n int, err error) {
|
||||
n, err = z.Reader.Read(p)
|
||||
if err == io.EOF {
|
||||
z.pool.Put(z)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *compressor) Name() string {
|
||||
return "gzip"
|
||||
}
|
||||
|
||||
type compressor struct {
|
||||
poolCompressor sync.Pool
|
||||
poolDecompressor sync.Pool
|
||||
}
|
61
rpc_util.go
61
rpc_util.go
|
@ -31,6 +31,7 @@ import (
|
|||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/encoding"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
|
@ -122,6 +123,7 @@ func (d *gzipDecompressor) Type() string {
|
|||
|
||||
// callInfo contains all related configuration and information about an RPC.
|
||||
type callInfo struct {
|
||||
compressorType string
|
||||
failFast bool
|
||||
headerMD metadata.MD
|
||||
trailerMD metadata.MD
|
||||
|
@ -294,13 +296,16 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
|
|||
|
||||
// encode serializes msg and returns a buffer of message header and a buffer of msg.
|
||||
// If msg is nil, it generates the message header and an empty msg buffer.
|
||||
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) {
|
||||
var b []byte
|
||||
// TODO(ddyihai): eliminate extra Compressor parameter.
|
||||
func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
|
||||
var (
|
||||
b []byte
|
||||
cbuf *bytes.Buffer
|
||||
)
|
||||
const (
|
||||
payloadLen = 1
|
||||
sizeLen = 4
|
||||
)
|
||||
|
||||
if msg != nil {
|
||||
var err error
|
||||
b, err = c.Marshal(msg)
|
||||
|
@ -313,24 +318,35 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl
|
|||
outPayload.Data = b
|
||||
outPayload.Length = len(b)
|
||||
}
|
||||
if cp != nil {
|
||||
if err := cp.Do(cbuf, b); err != nil {
|
||||
return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||
if compressor != nil || cp != nil {
|
||||
cbuf = new(bytes.Buffer)
|
||||
// Has compressor, check Compressor is set by UseCompressor first.
|
||||
if compressor != nil {
|
||||
z, _ := compressor.Compress(cbuf)
|
||||
if _, err := z.Write(b); err != nil {
|
||||
return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||
}
|
||||
z.Close()
|
||||
} else {
|
||||
// If Compressor is not set by UseCompressor, use default Compressor
|
||||
if err := cp.Do(cbuf, b); err != nil {
|
||||
return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||
}
|
||||
}
|
||||
b = cbuf.Bytes()
|
||||
}
|
||||
}
|
||||
|
||||
if uint(len(b)) > math.MaxUint32 {
|
||||
return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
||||
}
|
||||
|
||||
bufHeader := make([]byte, payloadLen+sizeLen)
|
||||
if cp == nil {
|
||||
bufHeader[0] = byte(compressionNone)
|
||||
} else {
|
||||
if compressor != nil || cp != nil {
|
||||
bufHeader[0] = byte(compressionMade)
|
||||
} else {
|
||||
bufHeader[0] = byte(compressionNone)
|
||||
}
|
||||
|
||||
// Write length of b into buf
|
||||
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
|
||||
if outPayload != nil {
|
||||
|
@ -343,7 +359,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||
switch pf {
|
||||
case compressionNone:
|
||||
case compressionMade:
|
||||
if dc == nil || recvCompress != dc.Type() {
|
||||
if (dc == nil || recvCompress != dc.Type()) && encoding.GetCompressor(recvCompress) == nil {
|
||||
return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
}
|
||||
default:
|
||||
|
@ -352,7 +368,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error {
|
||||
// TODO(ddyihai): eliminate extra Compressor parameter.
|
||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int,
|
||||
inPayload *stats.InPayload, compressor encoding.Compressor) error {
|
||||
pf, d, err := p.recvMsg(maxReceiveMessageSize)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -364,9 +382,22 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
|
|||
return err
|
||||
}
|
||||
if pf == compressionMade {
|
||||
d, err = dc.Do(bytes.NewReader(d))
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
||||
// use this decompressor as the default.
|
||||
if dc != nil {
|
||||
d, err = dc.Do(bytes.NewReader(d))
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
}
|
||||
} else {
|
||||
dcReader, err := compressor.Decompress(bytes.NewReader(d))
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
}
|
||||
d, err = ioutil.ReadAll(dcReader)
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(d) > maxReceiveMessageSize {
|
||||
|
|
56
server.go
56
server.go
|
@ -32,11 +32,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"io/ioutil"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/trace"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/encoding"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/internal"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -187,6 +190,8 @@ func CustomCodec(codec Codec) ServerOption {
|
|||
}
|
||||
|
||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
|
||||
// It has lower priority than the compressor set by RegisterCompressor.
|
||||
// This function is deprecated.
|
||||
func RPCCompressor(cp Compressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.cp = cp
|
||||
|
@ -194,6 +199,8 @@ func RPCCompressor(cp Compressor) ServerOption {
|
|||
}
|
||||
|
||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
|
||||
// It has higher priority than the decompressor set by RegisterCompressor.
|
||||
// This function is deprecated.
|
||||
func RPCDecompressor(dc Decompressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.dc = dc
|
||||
|
@ -701,16 +708,18 @@ func (s *Server) removeConn(c io.Closer) {
|
|||
|
||||
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
|
||||
var (
|
||||
cbuf *bytes.Buffer
|
||||
outPayload *stats.OutPayload
|
||||
)
|
||||
if cp != nil {
|
||||
cbuf = new(bytes.Buffer)
|
||||
}
|
||||
if s.opts.statsHandler != nil {
|
||||
outPayload = &stats.OutPayload{}
|
||||
}
|
||||
hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
|
||||
if stream.RecvCompress() != "" {
|
||||
// Server receives compressor, check compressor set by register and default.
|
||||
if encoding.GetCompressor(stream.RecvCompress()) == nil && (cp == nil || cp != nil && cp.Type() != stream.RecvCompress()) {
|
||||
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", stream.RecvCompress())
|
||||
}
|
||||
}
|
||||
hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, encoding.GetCompressor(stream.RecvCompress()))
|
||||
if err != nil {
|
||||
grpclog.Errorln("grpc: server failed to encode response: ", err)
|
||||
return err
|
||||
|
@ -754,7 +763,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
}()
|
||||
}
|
||||
if s.opts.cp != nil {
|
||||
if stream.RecvCompress() != "" {
|
||||
stream.SetSendCompress(stream.RecvCompress())
|
||||
} else if s.opts.cp != nil {
|
||||
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
|
||||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
}
|
||||
|
@ -786,7 +797,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
|
||||
if st, ok := status.FromError(err); ok {
|
||||
if e := t.WriteStatus(stream, st); e != nil {
|
||||
|
@ -812,9 +822,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
if pf == compressionMade {
|
||||
var err error
|
||||
req, err = s.opts.dc.Do(bytes.NewReader(req))
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, err.Error())
|
||||
if s.opts.dc != nil {
|
||||
req, err = s.opts.dc.Do(bytes.NewReader(req))
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, err.Error())
|
||||
}
|
||||
} else {
|
||||
dcReader := encoding.GetCompressor(stream.RecvCompress())
|
||||
tmp, _ := dcReader.Decompress(bytes.NewReader(req))
|
||||
req, err = ioutil.ReadAll(tmp)
|
||||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(req) > s.opts.maxReceiveMessageSize {
|
||||
|
@ -909,16 +928,19 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
sh.HandleRPC(stream.Context(), end)
|
||||
}()
|
||||
}
|
||||
if s.opts.cp != nil {
|
||||
if stream.RecvCompress() != "" {
|
||||
stream.SetSendCompress(stream.RecvCompress())
|
||||
} else if s.opts.cp != nil {
|
||||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
}
|
||||
ss := &serverStream{
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
codec: s.opts.codec,
|
||||
cp: s.opts.cp,
|
||||
dc: s.opts.dc,
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
codec: s.opts.codec,
|
||||
cpType: stream.RecvCompress(),
|
||||
cp: s.opts.cp,
|
||||
dc: s.opts.dc,
|
||||
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
|
||||
maxSendMessageSize: s.opts.maxSendMessageSize,
|
||||
trInfo: trInfo,
|
||||
|
|
27
stream.go
27
stream.go
|
@ -19,7 +19,6 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
@ -29,6 +28,7 @@ import (
|
|||
"golang.org/x/net/trace"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/encoding"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
|
@ -151,7 +151,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
// time soon, so we ask the transport to flush the header.
|
||||
Flush: desc.ClientStreams,
|
||||
}
|
||||
if cc.dopts.cp != nil {
|
||||
if c.compressorType != "" {
|
||||
callHdr.SendCompress = c.compressorType
|
||||
} else if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
}
|
||||
if c.creds != nil {
|
||||
|
@ -242,6 +244,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
c: c,
|
||||
desc: desc,
|
||||
codec: cc.dopts.codec,
|
||||
cpType: c.compressorType,
|
||||
cp: cc.dopts.cp,
|
||||
dc: cc.dopts.dc,
|
||||
cancel: cancel,
|
||||
|
@ -292,6 +295,7 @@ type clientStream struct {
|
|||
p *parser
|
||||
desc *StreamDesc
|
||||
codec Codec
|
||||
cpType string
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
cancel context.CancelFunc
|
||||
|
@ -369,7 +373,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
Client: true,
|
||||
}
|
||||
}
|
||||
hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
|
||||
if cs.cpType != "" && encoding.GetCompressor(cs.cpType) == nil {
|
||||
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", cs.cpType)
|
||||
}
|
||||
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, encoding.GetCompressor(cs.cpType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -397,7 +404,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
if cs.c.maxReceiveMessageSize == nil {
|
||||
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
|
||||
}
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(cs.cpType))
|
||||
defer func() {
|
||||
// err != nil indicates the termination of the stream.
|
||||
if err != nil {
|
||||
|
@ -423,7 +430,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
if cs.c.maxReceiveMessageSize == nil {
|
||||
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
|
||||
}
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, encoding.GetCompressor(cs.cpType))
|
||||
cs.closeTransportStream(err)
|
||||
if err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
|
@ -552,6 +559,7 @@ type serverStream struct {
|
|||
s *transport.Stream
|
||||
p *parser
|
||||
codec Codec
|
||||
cpType string
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
maxReceiveMessageSize int
|
||||
|
@ -609,7 +617,12 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||
if ss.statsHandler != nil {
|
||||
outPayload = &stats.OutPayload{}
|
||||
}
|
||||
hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
|
||||
if ss.cpType != "" {
|
||||
if encoding.GetCompressor(ss.cpType) == nil && (ss.cp == nil || ss.cp != nil && ss.cp.Type() != ss.cpType) {
|
||||
return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ss.cpType)
|
||||
}
|
||||
}
|
||||
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, encoding.GetCompressor(ss.cpType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -649,7 +662,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||
if ss.statsHandler != nil {
|
||||
inPayload = &stats.InPayload{}
|
||||
}
|
||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
|
||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, encoding.GetCompressor(ss.cpType)); err != nil {
|
||||
if err == io.EOF {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -51,6 +51,7 @@ import (
|
|||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
_ "google.golang.org/grpc/encoding/gzip"
|
||||
_ "google.golang.org/grpc/grpclog/glogger"
|
||||
"google.golang.org/grpc/health"
|
||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||
|
@ -437,18 +438,24 @@ type test struct {
|
|||
cancel context.CancelFunc
|
||||
|
||||
// Configurable knobs, after newTest returns:
|
||||
testServer testpb.TestServiceServer // nil means none
|
||||
healthServer *health.Server // nil means disabled
|
||||
maxStream uint32
|
||||
tapHandle tap.ServerInHandle
|
||||
maxMsgSize *int
|
||||
maxClientReceiveMsgSize *int
|
||||
maxClientSendMsgSize *int
|
||||
maxServerReceiveMsgSize *int
|
||||
maxServerSendMsgSize *int
|
||||
userAgent string
|
||||
clientCompression bool
|
||||
serverCompression bool
|
||||
testServer testpb.TestServiceServer // nil means none
|
||||
healthServer *health.Server // nil means disabled
|
||||
maxStream uint32
|
||||
tapHandle tap.ServerInHandle
|
||||
maxMsgSize *int
|
||||
maxClientReceiveMsgSize *int
|
||||
maxClientSendMsgSize *int
|
||||
maxServerReceiveMsgSize *int
|
||||
maxServerSendMsgSize *int
|
||||
userAgent string
|
||||
// clientCompression and serverCompression are set to test the deprecated API
|
||||
// WithCompressor and WithDecompressor.
|
||||
clientCompression bool
|
||||
serverCompression bool
|
||||
// clientUseCompression is set to test the new compressor registration API UseCompressor.
|
||||
clientUseCompression bool
|
||||
// clientNopCompression is set to create a compressor whose type is not supported.
|
||||
clientNopCompression bool
|
||||
unaryClientInt grpc.UnaryClientInterceptor
|
||||
streamClientInt grpc.StreamClientInterceptor
|
||||
unaryServerInt grpc.UnaryServerInterceptor
|
||||
|
@ -594,6 +601,32 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||
te.srvAddr = addr
|
||||
}
|
||||
|
||||
type nopCompressor struct {
|
||||
grpc.Compressor
|
||||
}
|
||||
|
||||
// NewNopCompressor creates a compressor to test the case that type is not supported.
|
||||
func NewNopCompressor() grpc.Compressor {
|
||||
return &nopCompressor{grpc.NewGZIPCompressor()}
|
||||
}
|
||||
|
||||
func (c *nopCompressor) Type() string {
|
||||
return "nop"
|
||||
}
|
||||
|
||||
type nopDecompressor struct {
|
||||
grpc.Decompressor
|
||||
}
|
||||
|
||||
// NewNopDecompressor creates a decompressor to test the case that type is not supported.
|
||||
func NewNopDecompressor() grpc.Decompressor {
|
||||
return &nopDecompressor{grpc.NewGZIPDecompressor()}
|
||||
}
|
||||
|
||||
func (d *nopDecompressor) Type() string {
|
||||
return "nop"
|
||||
}
|
||||
|
||||
func (te *test) clientConn() *grpc.ClientConn {
|
||||
if te.cc != nil {
|
||||
return te.cc
|
||||
|
@ -613,6 +646,15 @@ func (te *test) clientConn() *grpc.ClientConn {
|
|||
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
||||
)
|
||||
}
|
||||
if te.clientUseCompression {
|
||||
opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
|
||||
}
|
||||
if te.clientNopCompression {
|
||||
opts = append(opts,
|
||||
grpc.WithCompressor(NewNopCompressor()),
|
||||
grpc.WithDecompressor(NewNopDecompressor()),
|
||||
)
|
||||
}
|
||||
if te.unaryClientInt != nil {
|
||||
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
|
||||
}
|
||||
|
@ -3749,7 +3791,8 @@ func TestCompressServerHasNoSupport(t *testing.T) {
|
|||
func testCompressServerHasNoSupport(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.serverCompression = false
|
||||
te.clientCompression = true
|
||||
te.clientCompression = false
|
||||
te.clientNopCompression = true
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
@ -5572,3 +5615,65 @@ func TestMethodFromServerStream(t *testing.T) {
|
|||
t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressorRegister(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
for _, e := range listTestEnv() {
|
||||
testCompressorRegister(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testCompressorRegister(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.clientCompression = false
|
||||
te.serverCompression = false
|
||||
te.clientUseCompression = true
|
||||
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
// Unary call
|
||||
const argSize = 271828
|
||||
const respSize = 314159
|
||||
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := &testpb.SimpleRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE,
|
||||
ResponseSize: respSize,
|
||||
Payload: payload,
|
||||
}
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something"))
|
||||
if _, err := tc.UnaryCall(ctx, req); err != nil {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
|
||||
}
|
||||
// Streaming RPC
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
}
|
||||
respParam := []*testpb.ResponseParameters{
|
||||
{
|
||||
Size: 31415,
|
||||
},
|
||||
}
|
||||
payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sreq := &testpb.StreamingOutputCallRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE,
|
||||
ResponseParameters: respParam,
|
||||
Payload: payload,
|
||||
}
|
||||
if err := stream.Send(sreq); err != nil {
|
||||
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
|
||||
}
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue