mirror of https://github.com/grpc/grpc-go.git
Implement new Codec that uses `mem.BufferSlice` instead of `[]byte` (#7356)
This commit is contained in:
parent
7e12068baf
commit
9ab8b62505
|
|
@ -55,6 +55,10 @@ jobs:
|
||||||
goversion: '1.22'
|
goversion: '1.22'
|
||||||
testflags: -race
|
testflags: -race
|
||||||
|
|
||||||
|
- type: tests
|
||||||
|
goversion: '1.22'
|
||||||
|
testflags: '-race -tags=buffer_pooling'
|
||||||
|
|
||||||
- type: tests
|
- type: tests
|
||||||
goversion: '1.22'
|
goversion: '1.22'
|
||||||
goarch: 386
|
goarch: 386
|
||||||
|
|
|
||||||
|
|
@ -66,11 +66,11 @@ import (
|
||||||
"google.golang.org/grpc/benchmark/stats"
|
"google.golang.org/grpc/benchmark/stats"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/encoding/gzip"
|
"google.golang.org/grpc/encoding/gzip"
|
||||||
"google.golang.org/grpc/experimental"
|
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
"google.golang.org/grpc/internal"
|
"google.golang.org/grpc/internal"
|
||||||
"google.golang.org/grpc/internal/channelz"
|
"google.golang.org/grpc/internal/channelz"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/test/bufconn"
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
|
||||||
|
|
@ -153,6 +153,33 @@ const (
|
||||||
warmuptime = time.Second
|
warmuptime = time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var useNopBufferPool atomic.Bool
|
||||||
|
|
||||||
|
type swappableBufferPool struct {
|
||||||
|
mem.BufferPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p swappableBufferPool) Get(length int) *[]byte {
|
||||||
|
var pool mem.BufferPool
|
||||||
|
if useNopBufferPool.Load() {
|
||||||
|
pool = mem.NopBufferPool{}
|
||||||
|
} else {
|
||||||
|
pool = p.BufferPool
|
||||||
|
}
|
||||||
|
return pool.Get(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p swappableBufferPool) Put(i *[]byte) {
|
||||||
|
if useNopBufferPool.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.BufferPool.Put(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
internal.SetDefaultBufferPoolForTesting.(func(mem.BufferPool))(swappableBufferPool{mem.DefaultBufferPool()})
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsUnconstrained, workloadsAll}
|
allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsUnconstrained, workloadsAll}
|
||||||
allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll}
|
allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll}
|
||||||
|
|
@ -343,10 +370,9 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
|
||||||
}
|
}
|
||||||
switch bf.RecvBufferPool {
|
switch bf.RecvBufferPool {
|
||||||
case recvBufferPoolNil:
|
case recvBufferPoolNil:
|
||||||
// Do nothing.
|
useNopBufferPool.Store(true)
|
||||||
case recvBufferPoolSimple:
|
case recvBufferPoolSimple:
|
||||||
opts = append(opts, experimental.WithRecvBufferPool(grpc.NewSharedBufferPool()))
|
// Do nothing as buffering is enabled by default.
|
||||||
sopts = append(sopts, experimental.RecvBufferPool(grpc.NewSharedBufferPool()))
|
|
||||||
default:
|
default:
|
||||||
logger.Fatalf("Unknown shared recv buffer pool type: %v", bf.RecvBufferPool)
|
logger.Fatalf("Unknown shared recv buffer pool type: %v", bf.RecvBufferPool)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
75
codec.go
75
codec.go
|
|
@ -21,18 +21,79 @@ package grpc
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc/encoding"
|
"google.golang.org/grpc/encoding"
|
||||||
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
|
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
// baseCodec contains the functionality of both Codec and encoding.Codec, but
|
// baseCodec captures the new encoding.CodecV2 interface without the Name
|
||||||
// omits the name/string, which vary between the two and are not needed for
|
// function, allowing it to be implemented by older Codec and encoding.Codec
|
||||||
// anything besides the registry in the encoding package.
|
// implementations. The omitted Name function is only needed for the register in
|
||||||
|
// the encoding package and is not part of the core functionality.
|
||||||
type baseCodec interface {
|
type baseCodec interface {
|
||||||
Marshal(v any) ([]byte, error)
|
Marshal(v any) (mem.BufferSlice, error)
|
||||||
Unmarshal(data []byte, v any) error
|
Unmarshal(data mem.BufferSlice, v any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ baseCodec = Codec(nil)
|
// getCodec returns an encoding.CodecV2 for the codec of the given name (if
|
||||||
var _ baseCodec = encoding.Codec(nil)
|
// registered). Initially checks the V2 registry with encoding.GetCodecV2 and
|
||||||
|
// returns the V2 codec if it is registered. Otherwise, it checks the V1 registry
|
||||||
|
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
|
||||||
|
// to turn it into an encoding.CodecV2. Returns nil otherwise.
|
||||||
|
func getCodec(name string) encoding.CodecV2 {
|
||||||
|
codecV2 := encoding.GetCodecV2(name)
|
||||||
|
if codecV2 != nil {
|
||||||
|
return codecV2
|
||||||
|
}
|
||||||
|
|
||||||
|
codecV1 := encoding.GetCodec(name)
|
||||||
|
if codecV1 != nil {
|
||||||
|
return newCodecV1Bridge(codecV1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCodecV0Bridge(c Codec) baseCodec {
|
||||||
|
return codecV0Bridge{codec: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCodecV1Bridge(c encoding.Codec) encoding.CodecV2 {
|
||||||
|
return codecV1Bridge{
|
||||||
|
codecV0Bridge: codecV0Bridge{codec: c},
|
||||||
|
name: c.Name(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ baseCodec = codecV0Bridge{}
|
||||||
|
|
||||||
|
type codecV0Bridge struct {
|
||||||
|
codec interface {
|
||||||
|
Marshal(v any) ([]byte, error)
|
||||||
|
Unmarshal(data []byte, v any) error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) {
|
||||||
|
data, err := c.codec.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mem.BufferSlice{mem.NewBuffer(&data, nil)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) {
|
||||||
|
return c.codec.Unmarshal(data.Materialize(), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ encoding.CodecV2 = codecV1Bridge{}
|
||||||
|
|
||||||
|
type codecV1Bridge struct {
|
||||||
|
codecV0Bridge
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c codecV1Bridge) Name() string {
|
||||||
|
return c.name
|
||||||
|
}
|
||||||
|
|
||||||
// Codec defines the interface gRPC uses to encode and decode messages.
|
// Codec defines the interface gRPC uses to encode and decode messages.
|
||||||
// Note that implementations of this interface must be thread safe;
|
// Note that implementations of this interface must be thread safe;
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/binarylog"
|
"google.golang.org/grpc/internal/binarylog"
|
||||||
"google.golang.org/grpc/internal/transport"
|
"google.golang.org/grpc/internal/transport"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
)
|
)
|
||||||
|
|
@ -60,7 +61,7 @@ func init() {
|
||||||
internal.WithBinaryLogger = withBinaryLogger
|
internal.WithBinaryLogger = withBinaryLogger
|
||||||
internal.JoinDialOptions = newJoinDialOption
|
internal.JoinDialOptions = newJoinDialOption
|
||||||
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
|
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
|
||||||
internal.WithRecvBufferPool = withRecvBufferPool
|
internal.WithBufferPool = withBufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||||
|
|
@ -92,7 +93,6 @@ type dialOptions struct {
|
||||||
defaultServiceConfigRawJSON *string
|
defaultServiceConfigRawJSON *string
|
||||||
resolvers []resolver.Builder
|
resolvers []resolver.Builder
|
||||||
idleTimeout time.Duration
|
idleTimeout time.Duration
|
||||||
recvBufferPool SharedBufferPool
|
|
||||||
defaultScheme string
|
defaultScheme string
|
||||||
maxCallAttempts int
|
maxCallAttempts int
|
||||||
}
|
}
|
||||||
|
|
@ -679,11 +679,11 @@ func defaultDialOptions() dialOptions {
|
||||||
WriteBufferSize: defaultWriteBufSize,
|
WriteBufferSize: defaultWriteBufSize,
|
||||||
UseProxy: true,
|
UseProxy: true,
|
||||||
UserAgent: grpcUA,
|
UserAgent: grpcUA,
|
||||||
|
BufferPool: mem.DefaultBufferPool(),
|
||||||
},
|
},
|
||||||
bs: internalbackoff.DefaultExponential,
|
bs: internalbackoff.DefaultExponential,
|
||||||
healthCheckFunc: internal.HealthCheckFunc,
|
healthCheckFunc: internal.HealthCheckFunc,
|
||||||
idleTimeout: 30 * time.Minute,
|
idleTimeout: 30 * time.Minute,
|
||||||
recvBufferPool: nopBufferPool{},
|
|
||||||
defaultScheme: "dns",
|
defaultScheme: "dns",
|
||||||
maxCallAttempts: defaultMaxCallAttempts,
|
maxCallAttempts: defaultMaxCallAttempts,
|
||||||
}
|
}
|
||||||
|
|
@ -760,25 +760,8 @@ func WithMaxCallAttempts(n int) DialOption {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithRecvBufferPool returns a DialOption that configures the ClientConn
|
func withBufferPool(bufferPool mem.BufferPool) DialOption {
|
||||||
// to use the provided shared buffer pool for parsing incoming messages. Depending
|
|
||||||
// on the application's workload, this could result in reduced memory allocation.
|
|
||||||
//
|
|
||||||
// If you are unsure about how to implement a memory pool but want to utilize one,
|
|
||||||
// begin with grpc.NewSharedBufferPool.
|
|
||||||
//
|
|
||||||
// Note: The shared buffer pool feature will not be active if any of the following
|
|
||||||
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such
|
|
||||||
// cases, the shared buffer pool will be ignored.
|
|
||||||
//
|
|
||||||
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
|
|
||||||
// v1.60.0 or later.
|
|
||||||
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
|
|
||||||
return withRecvBufferPool(bufferPool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
|
|
||||||
return newFuncDialOption(func(o *dialOptions) {
|
return newFuncDialOption(func(o *dialOptions) {
|
||||||
o.recvBufferPool = bufferPool
|
o.copts.BufferPool = bufferPool
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2024 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package encoding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CodecV2 defines the interface gRPC uses to encode and decode messages. Note
|
||||||
|
// that implementations of this interface must be thread safe; a CodecV2's
|
||||||
|
// methods can be called from concurrent goroutines.
|
||||||
|
type CodecV2 interface {
|
||||||
|
// Marshal returns the wire format of v. The buffers in the returned
|
||||||
|
// [mem.BufferSlice] must have at least one reference each, which will be freed
|
||||||
|
// by gRPC when they are no longer needed.
|
||||||
|
Marshal(v any) (out mem.BufferSlice, err error)
|
||||||
|
// Unmarshal parses the wire format into v. Note that data will be freed as soon
|
||||||
|
// as this function returns. If the codec wishes to guarantee access to the data
|
||||||
|
// after this function, it must take its own reference that it frees when it is
|
||||||
|
// no longer needed.
|
||||||
|
Unmarshal(data mem.BufferSlice, v any) error
|
||||||
|
// Name returns the name of the Codec implementation. The returned string
|
||||||
|
// will be used as part of content type in transmission. The result must be
|
||||||
|
// static; the result cannot change between calls.
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
var registeredV2Codecs = make(map[string]CodecV2)
|
||||||
|
|
||||||
|
// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
|
||||||
|
// servers.
|
||||||
|
//
|
||||||
|
// The CodecV2 will be stored and looked up by result of its Name() method, which
|
||||||
|
// should match the content-subtype of the encoding handled by the CodecV2. This
|
||||||
|
// is case-insensitive, and is stored and looked up as lowercase. If the
|
||||||
|
// result of calling Name() is an empty string, RegisterCodecV2 will panic. See
|
||||||
|
// Content-Type on
|
||||||
|
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
||||||
|
// more details.
|
||||||
|
//
|
||||||
|
// If both a Codec and CodecV2 are registered with the same name, the CodecV2
|
||||||
|
// will be used.
|
||||||
|
//
|
||||||
|
// NOTE: this function must only be called during initialization time (i.e. in
|
||||||
|
// an init() function), and is not thread-safe. If multiple Codecs are
|
||||||
|
// registered with the same name, the one registered last will take effect.
|
||||||
|
func RegisterCodecV2(codec CodecV2) {
|
||||||
|
if codec == nil {
|
||||||
|
panic("cannot register a nil CodecV2")
|
||||||
|
}
|
||||||
|
if codec.Name() == "" {
|
||||||
|
panic("cannot register CodecV2 with empty string result for Name()")
|
||||||
|
}
|
||||||
|
contentSubtype := strings.ToLower(codec.Name())
|
||||||
|
registeredV2Codecs[contentSubtype] = codec
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
|
||||||
|
// registered for the content-subtype.
|
||||||
|
//
|
||||||
|
// The content-subtype is expected to be lowercase.
|
||||||
|
func GetCodecV2(contentSubtype string) CodecV2 {
|
||||||
|
return registeredV2Codecs[contentSubtype]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2024 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/encoding"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
encoding.RegisterCodecV2(&codecV2{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// codec is a CodecV2 implementation with protobuf. It is the default codec for
|
||||||
|
// gRPC.
|
||||||
|
type codecV2 struct{}
|
||||||
|
|
||||||
|
var _ encoding.CodecV2 = (*codecV2)(nil)
|
||||||
|
|
||||||
|
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
|
||||||
|
vv := messageV2Of(v)
|
||||||
|
if vv == nil {
|
||||||
|
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := proto.Size(vv)
|
||||||
|
if mem.IsBelowBufferPoolingThreshold(size) {
|
||||||
|
buf, err := proto.Marshal(vv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data = append(data, mem.SliceBuffer(buf))
|
||||||
|
} else {
|
||||||
|
pool := mem.DefaultBufferPool()
|
||||||
|
buf := pool.Get(size)
|
||||||
|
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
|
||||||
|
pool.Put(buf)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data = append(data, mem.NewBuffer(buf, pool))
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
|
||||||
|
vv := messageV2Of(v)
|
||||||
|
if vv == nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
|
||||||
|
defer buf.Free()
|
||||||
|
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
|
||||||
|
// really possible without a major overhaul of the proto package, but the
|
||||||
|
// vtprotobuf library may be able to support this.
|
||||||
|
return proto.Unmarshal(buf.ReadOnlyData(), vv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codecV2) Name() string {
|
||||||
|
return Name
|
||||||
|
}
|
||||||
|
|
@ -28,38 +28,37 @@ package experimental
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/internal"
|
"google.golang.org/grpc/internal"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WithRecvBufferPool returns a grpc.DialOption that configures the use of
|
// WithBufferPool returns a grpc.DialOption that configures the use of bufferPool
|
||||||
// bufferPool for parsing incoming messages on a grpc.ClientConn. Depending on
|
// for parsing incoming messages on a grpc.ClientConn, and for temporary buffers
|
||||||
// the application's workload, this could result in reduced memory allocation.
|
// when marshaling outgoing messages. By default, mem.DefaultBufferPool is used,
|
||||||
|
// and this option only exists to provide alternative buffer pool implementations
|
||||||
|
// to the client, such as more optimized size allocations etc. However, the
|
||||||
|
// default buffer pool is already tuned to account for many different use-cases.
|
||||||
//
|
//
|
||||||
// If you are unsure about how to implement a memory pool but want to utilize
|
// Note: The following options will interfere with the buffer pool because they
|
||||||
// one, begin with grpc.NewSharedBufferPool.
|
// require a fully materialized buffer instead of a sequence of buffers:
|
||||||
//
|
// EnableTracing, and binary logging. In such cases, materializing the buffer
|
||||||
// Note: The shared buffer pool feature will not be active if any of the
|
// will generate a lot of garbage, reducing the overall benefit from using a
|
||||||
// following options are used: WithStatsHandler, EnableTracing, or binary
|
// pool.
|
||||||
// logging. In such cases, the shared buffer pool will be ignored.
|
func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption {
|
||||||
//
|
return internal.WithBufferPool.(func(mem.BufferPool) grpc.DialOption)(bufferPool)
|
||||||
// Note: It is not recommended to use the shared buffer pool when compression is
|
|
||||||
// enabled.
|
|
||||||
func WithRecvBufferPool(bufferPool grpc.SharedBufferPool) grpc.DialOption {
|
|
||||||
return internal.WithRecvBufferPool.(func(grpc.SharedBufferPool) grpc.DialOption)(bufferPool)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecvBufferPool returns a grpc.ServerOption that configures the server to use
|
// BufferPool returns a grpc.ServerOption that configures the server to use the
|
||||||
// the provided shared buffer pool for parsing incoming messages. Depending on
|
// provided buffer pool for parsing incoming messages and for temporary buffers
|
||||||
// the application's workload, this could result in reduced memory allocation.
|
// when marshaling outgoing messages. By default, mem.DefaultBufferPool is used,
|
||||||
|
// and this option only exists to provide alternative buffer pool implementations
|
||||||
|
// to the server, such as more optimized size allocations etc. However, the
|
||||||
|
// default buffer pool is already tuned to account for many different use-cases.
|
||||||
//
|
//
|
||||||
// If you are unsure about how to implement a memory pool but want to utilize
|
// Note: The following options will interfere with the buffer pool because they
|
||||||
// one, begin with grpc.NewSharedBufferPool.
|
// require a fully materialized buffer instead of a sequence of buffers:
|
||||||
//
|
// EnableTracing, and binary logging. In such cases, materializing the buffer
|
||||||
// Note: The shared buffer pool feature will not be active if any of the
|
// will generate a lot of garbage, reducing the overall benefit from using a
|
||||||
// following options are used: StatsHandler, EnableTracing, or binary logging.
|
// pool.
|
||||||
// In such cases, the shared buffer pool will be ignored.
|
func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
|
||||||
//
|
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
|
||||||
// Note: It is not recommended to use the shared buffer pool when compression is
|
|
||||||
// enabled.
|
|
||||||
func RecvBufferPool(bufferPool grpc.SharedBufferPool) grpc.ServerOption {
|
|
||||||
return internal.RecvBufferPool.(func(grpc.SharedBufferPool) grpc.ServerOption)(bufferPool)
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,9 @@ func Test(t *testing.T) {
|
||||||
const defaultTestTimeout = 10 * time.Second
|
const defaultTestTimeout = 10 * time.Second
|
||||||
|
|
||||||
func (s) TestRecvBufferPoolStream(t *testing.T) {
|
func (s) TestRecvBufferPoolStream(t *testing.T) {
|
||||||
|
// TODO: How much of this test can be preserved now that buffer reuse happens at
|
||||||
|
// the codec and HTTP/2 level?
|
||||||
|
t.SkipNow()
|
||||||
tcs := []struct {
|
tcs := []struct {
|
||||||
name string
|
name string
|
||||||
callOpts []grpc.CallOption
|
callOpts []grpc.CallOption
|
||||||
|
|
@ -83,8 +86,8 @@ func (s) TestRecvBufferPoolStream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pool := &checkBufferPool{}
|
pool := &checkBufferPool{}
|
||||||
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
|
sopts := []grpc.ServerOption{experimental.BufferPool(pool)}
|
||||||
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
|
dopts := []grpc.DialOption{experimental.WithBufferPool(pool)}
|
||||||
if err := ss.Start(sopts, dopts...); err != nil {
|
if err := ss.Start(sopts, dopts...); err != nil {
|
||||||
t.Fatalf("Error starting endpoint server: %v", err)
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -129,6 +132,8 @@ func (s) TestRecvBufferPoolStream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestRecvBufferPoolUnary(t *testing.T) {
|
func (s) TestRecvBufferPoolUnary(t *testing.T) {
|
||||||
|
// TODO: See above
|
||||||
|
t.SkipNow()
|
||||||
tcs := []struct {
|
tcs := []struct {
|
||||||
name string
|
name string
|
||||||
callOpts []grpc.CallOption
|
callOpts []grpc.CallOption
|
||||||
|
|
@ -159,8 +164,8 @@ func (s) TestRecvBufferPoolUnary(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pool := &checkBufferPool{}
|
pool := &checkBufferPool{}
|
||||||
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
|
sopts := []grpc.ServerOption{experimental.BufferPool(pool)}
|
||||||
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
|
dopts := []grpc.DialOption{experimental.WithBufferPool(pool)}
|
||||||
if err := ss.Start(sopts, dopts...); err != nil {
|
if err := ss.Start(sopts, dopts...); err != nil {
|
||||||
t.Fatalf("Error starting endpoint server: %v", err)
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -196,8 +201,9 @@ type checkBufferPool struct {
|
||||||
puts [][]byte
|
puts [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *checkBufferPool) Get(size int) []byte {
|
func (p *checkBufferPool) Get(size int) *[]byte {
|
||||||
return make([]byte, size)
|
b := make([]byte, size)
|
||||||
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *checkBufferPool) Put(bs *[]byte) {
|
func (p *checkBufferPool) Put(bs *[]byte) {
|
||||||
|
|
|
||||||
|
|
@ -204,7 +204,7 @@ func (s) TestClientRPCEventsLogAll(t *testing.T) {
|
||||||
SequenceID: 2,
|
SequenceID: 2,
|
||||||
Authority: ss.Address,
|
Authority: ss.Address,
|
||||||
Payload: payload{
|
Payload: payload{
|
||||||
Message: []uint8{},
|
Message: nil,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -285,7 +285,7 @@ func (s) TestClientRPCEventsLogAll(t *testing.T) {
|
||||||
SequenceID: 2,
|
SequenceID: 2,
|
||||||
Authority: ss.Address,
|
Authority: ss.Address,
|
||||||
Payload: payload{
|
Payload: payload{
|
||||||
Message: []uint8{},
|
Message: nil,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -512,7 +512,7 @@ func (s) TestServerRPCEventsLogAll(t *testing.T) {
|
||||||
SequenceID: 4,
|
SequenceID: 4,
|
||||||
Authority: ss.Address,
|
Authority: ss.Address,
|
||||||
Payload: payload{
|
Payload: payload{
|
||||||
Message: []uint8{},
|
Message: nil,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -870,7 +870,7 @@ func (s) TestPrecedenceOrderingInConfiguration(t *testing.T) {
|
||||||
SequenceID: 2,
|
SequenceID: 2,
|
||||||
Authority: ss.Address,
|
Authority: ss.Address,
|
||||||
Payload: payload{
|
Payload: payload{
|
||||||
Message: []uint8{},
|
Message: nil,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,11 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// WithRecvBufferPool is implemented by the grpc package and returns a dial
|
// WithBufferPool is implemented by the grpc package and returns a dial
|
||||||
// option to configure a shared buffer pool for a grpc.ClientConn.
|
// option to configure a shared buffer pool for a grpc.ClientConn.
|
||||||
WithRecvBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
|
WithBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
|
||||||
|
|
||||||
// RecvBufferPool is implemented by the grpc package and returns a server
|
// BufferPool is implemented by the grpc package and returns a server
|
||||||
// option to configure a shared buffer pool for a grpc.Server.
|
// option to configure a shared buffer pool for a grpc.Server.
|
||||||
RecvBufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
|
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,17 +24,22 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/internal/leakcheck"
|
"google.golang.org/grpc/internal/leakcheck"
|
||||||
)
|
)
|
||||||
|
|
||||||
var lcFailed uint32
|
var lcFailed uint32
|
||||||
|
|
||||||
type errorer struct {
|
type logger struct {
|
||||||
t *testing.T
|
t *testing.T
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e errorer) Errorf(format string, args ...any) {
|
func (e logger) Logf(format string, args ...any) {
|
||||||
|
e.t.Logf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e logger) Errorf(format string, args ...any) {
|
||||||
atomic.StoreUint32(&lcFailed, 1)
|
atomic.StoreUint32(&lcFailed, 1)
|
||||||
e.t.Errorf(format, args...)
|
e.t.Errorf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
@ -48,16 +53,22 @@ type Tester struct{}
|
||||||
// Setup updates the tlogger.
|
// Setup updates the tlogger.
|
||||||
func (Tester) Setup(t *testing.T) {
|
func (Tester) Setup(t *testing.T) {
|
||||||
TLogger.Update(t)
|
TLogger.Update(t)
|
||||||
|
// TODO: There is one final leak around closing connections without completely
|
||||||
|
// draining the recvBuffer that has yet to be resolved. All other leaks have been
|
||||||
|
// completely addressed, and this can be turned back on as soon as this issue is
|
||||||
|
// fixed.
|
||||||
|
leakcheck.SetTrackingBufferPool(logger{t: t})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Teardown performs a leak check.
|
// Teardown performs a leak check.
|
||||||
func (Tester) Teardown(t *testing.T) {
|
func (Tester) Teardown(t *testing.T) {
|
||||||
|
leakcheck.CheckTrackingBufferPool()
|
||||||
if atomic.LoadUint32(&lcFailed) == 1 {
|
if atomic.LoadUint32(&lcFailed) == 1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
leakcheck.Check(errorer{t: t})
|
leakcheck.CheckGoroutines(logger{t: t}, 10*time.Second)
|
||||||
if atomic.LoadUint32(&lcFailed) == 1 {
|
if atomic.LoadUint32(&lcFailed) == 1 {
|
||||||
t.Log("Leak check disabled for future tests")
|
t.Log("Goroutine leak check disabled for future tests")
|
||||||
}
|
}
|
||||||
TLogger.EndTest(t)
|
TLogger.EndTest(t)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -225,6 +225,10 @@ var (
|
||||||
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
|
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
|
||||||
// testing purposes.
|
// testing purposes.
|
||||||
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
|
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
|
||||||
|
|
||||||
|
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
|
||||||
|
// testing purposes.
|
||||||
|
SetBufferPoolingThresholdForTesting any // func(int)
|
||||||
)
|
)
|
||||||
|
|
||||||
// HealthChecker defines the signature of the client-side LB channel health
|
// HealthChecker defines the signature of the client-side LB channel health
|
||||||
|
|
|
||||||
|
|
@ -16,18 +16,171 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package leakcheck contains functions to check leaked goroutines.
|
// Package leakcheck contains functions to check leaked goroutines and buffers.
|
||||||
//
|
//
|
||||||
// Call "defer leakcheck.Check(t)" at the beginning of tests.
|
// Call the following at the beginning of test:
|
||||||
|
//
|
||||||
|
// defer leakcheck.NewLeakChecker(t).Check()
|
||||||
package leakcheck
|
package leakcheck
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/internal"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// failTestsOnLeakedBuffers is a special flag that will cause tests to fail if
|
||||||
|
// leaked buffers are detected, instead of simply logging them as an
|
||||||
|
// informational failure. This can be enabled with the "checkbuffers" compile
|
||||||
|
// flag, e.g.:
|
||||||
|
//
|
||||||
|
// go test -tags=checkbuffers
|
||||||
|
var failTestsOnLeakedBuffers = false
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
defaultPool := mem.DefaultBufferPool()
|
||||||
|
globalPool.Store(&defaultPool)
|
||||||
|
(internal.SetDefaultBufferPoolForTesting.(func(mem.BufferPool)))(&globalPool)
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalPool swappableBufferPool
|
||||||
|
|
||||||
|
type swappableBufferPool struct {
|
||||||
|
atomic.Pointer[mem.BufferPool]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *swappableBufferPool) Get(length int) *[]byte {
|
||||||
|
return (*b.Load()).Get(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *swappableBufferPool) Put(buf *[]byte) {
|
||||||
|
(*b.Load()).Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTrackingBufferPool replaces the default buffer pool in the mem package to
|
||||||
|
// one that tracks where buffers are allocated. CheckTrackingBufferPool should
|
||||||
|
// then be invoked at the end of the test to validate that all buffers pulled
|
||||||
|
// from the pool were returned.
|
||||||
|
func SetTrackingBufferPool(logger Logger) {
|
||||||
|
newPool := mem.BufferPool(&trackingBufferPool{
|
||||||
|
pool: *globalPool.Load(),
|
||||||
|
logger: logger,
|
||||||
|
allocatedBuffers: make(map[*[]byte][]uintptr),
|
||||||
|
})
|
||||||
|
globalPool.Store(&newPool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
|
||||||
|
// unit tests if not all buffers were returned. It is invalid to invoke this
|
||||||
|
// method without previously having invoked SetTrackingBufferPool.
|
||||||
|
func CheckTrackingBufferPool() {
|
||||||
|
p := (*globalPool.Load()).(*trackingBufferPool)
|
||||||
|
p.lock.Lock()
|
||||||
|
defer p.lock.Unlock()
|
||||||
|
|
||||||
|
globalPool.Store(&p.pool)
|
||||||
|
|
||||||
|
type uniqueTrace struct {
|
||||||
|
stack []uintptr
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalLeakedBuffers int
|
||||||
|
var uniqueTraces []uniqueTrace
|
||||||
|
for _, stack := range p.allocatedBuffers {
|
||||||
|
idx, ok := slices.BinarySearchFunc(uniqueTraces, stack, func(trace uniqueTrace, stack []uintptr) int {
|
||||||
|
return slices.Compare(trace.stack, stack)
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
uniqueTraces = slices.Insert(uniqueTraces, idx, uniqueTrace{stack: stack})
|
||||||
|
}
|
||||||
|
uniqueTraces[idx].count++
|
||||||
|
totalLeakedBuffers++
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ut := range uniqueTraces {
|
||||||
|
frames := runtime.CallersFrames(ut.stack)
|
||||||
|
var trace strings.Builder
|
||||||
|
for {
|
||||||
|
f, ok := frames.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
trace.WriteString(f.Function)
|
||||||
|
trace.WriteString("\n\t")
|
||||||
|
trace.WriteString(f.File)
|
||||||
|
trace.WriteString(":")
|
||||||
|
trace.WriteString(strconv.Itoa(f.Line))
|
||||||
|
trace.WriteString("\n")
|
||||||
|
}
|
||||||
|
format := "%d allocated buffers never freed:\n%s"
|
||||||
|
args := []any{ut.count, trace.String()}
|
||||||
|
if failTestsOnLeakedBuffers {
|
||||||
|
p.logger.Errorf(format, args...)
|
||||||
|
} else {
|
||||||
|
p.logger.Logf("WARNING "+format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalLeakedBuffers > 0 {
|
||||||
|
p.logger.Logf("%g%% of buffers never freed", float64(totalLeakedBuffers)/float64(p.bufferCount))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackingBufferPool struct {
|
||||||
|
pool mem.BufferPool
|
||||||
|
logger Logger
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
bufferCount int
|
||||||
|
allocatedBuffers map[*[]byte][]uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *trackingBufferPool) Get(length int) *[]byte {
|
||||||
|
p.lock.Lock()
|
||||||
|
defer p.lock.Unlock()
|
||||||
|
|
||||||
|
p.bufferCount++
|
||||||
|
|
||||||
|
buf := p.pool.Get(length)
|
||||||
|
|
||||||
|
var stackBuf [16]uintptr
|
||||||
|
var stack []uintptr
|
||||||
|
skip := 2
|
||||||
|
for {
|
||||||
|
n := runtime.Callers(skip, stackBuf[:])
|
||||||
|
stack = append(stack, stackBuf[:n]...)
|
||||||
|
if n < len(stackBuf) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
skip += len(stackBuf)
|
||||||
|
}
|
||||||
|
p.allocatedBuffers[buf] = stack
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *trackingBufferPool) Put(buf *[]byte) {
|
||||||
|
p.lock.Lock()
|
||||||
|
defer p.lock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := p.allocatedBuffers[buf]; !ok {
|
||||||
|
p.logger.Errorf("Unknown buffer freed:\n%s", string(debug.Stack()))
|
||||||
|
} else {
|
||||||
|
delete(p.allocatedBuffers, buf)
|
||||||
|
}
|
||||||
|
p.pool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
var goroutinesToIgnore = []string{
|
var goroutinesToIgnore = []string{
|
||||||
"testing.Main(",
|
"testing.Main(",
|
||||||
"testing.tRunner(",
|
"testing.tRunner(",
|
||||||
|
|
@ -94,13 +247,17 @@ func interestingGoroutines() (gs []string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errorfer is the interface that wraps the Errorf method. It's a subset of
|
// Logger is the interface that wraps the Logf and Errorf method. It's a subset
|
||||||
// testing.TB to make it easy to use Check.
|
// of testing.TB to make it easy to use this package.
|
||||||
type Errorfer interface {
|
type Logger interface {
|
||||||
|
Logf(format string, args ...any)
|
||||||
Errorf(format string, args ...any)
|
Errorf(format string, args ...any)
|
||||||
}
|
}
|
||||||
|
|
||||||
func check(efer Errorfer, timeout time.Duration) {
|
// CheckGoroutines looks at the currently-running goroutines and checks if there
|
||||||
|
// are any interesting (created by gRPC) goroutines leaked. It waits up to 10
|
||||||
|
// seconds in the error cases.
|
||||||
|
func CheckGoroutines(logger Logger, timeout time.Duration) {
|
||||||
// Loop, waiting for goroutines to shut down.
|
// Loop, waiting for goroutines to shut down.
|
||||||
// Wait up to timeout, but finish as quickly as possible.
|
// Wait up to timeout, but finish as quickly as possible.
|
||||||
deadline := time.Now().Add(timeout)
|
deadline := time.Now().Add(timeout)
|
||||||
|
|
@ -112,13 +269,32 @@ func check(efer Errorfer, timeout time.Duration) {
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
}
|
}
|
||||||
for _, g := range leaked {
|
for _, g := range leaked {
|
||||||
efer.Errorf("Leaked goroutine: %v", g)
|
logger.Errorf("Leaked goroutine: %v", g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check looks at the currently-running goroutines and checks if there are any
|
// LeakChecker captures an Logger and is returned by NewLeakChecker as a
|
||||||
// interesting (created by gRPC) goroutines leaked. It waits up to 10 seconds
|
// convenient method to set up leak check tests in a unit test.
|
||||||
// in the error cases.
|
type LeakChecker struct {
|
||||||
func Check(efer Errorfer) {
|
logger Logger
|
||||||
check(efer, 10*time.Second)
|
}
|
||||||
|
|
||||||
|
// Check executes the leak check tests, failing the unit test if any buffer or
|
||||||
|
// goroutine leaks are detected.
|
||||||
|
func (lc *LeakChecker) Check() {
|
||||||
|
CheckTrackingBufferPool()
|
||||||
|
CheckGoroutines(lc.logger, 10*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLeakChecker offers a convenient way to set up the leak checks for a
|
||||||
|
// specific unit test. It can be used as follows, at the beginning of tests:
|
||||||
|
//
|
||||||
|
// defer leakcheck.NewLeakChecker(t).Check()
|
||||||
|
//
|
||||||
|
// It initially invokes SetTrackingBufferPool to set up buffer tracking, then the
|
||||||
|
// deferred LeakChecker.Check call will invoke CheckTrackingBufferPool and
|
||||||
|
// CheckGoroutines with a default timeout of 10 seconds.
|
||||||
|
func NewLeakChecker(logger Logger) *LeakChecker {
|
||||||
|
SetTrackingBufferPool(logger)
|
||||||
|
return &LeakChecker{logger: logger}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
//go:build checkbuffers
|
||||||
|
|
||||||
/*
|
/*
|
||||||
*
|
*
|
||||||
* Copyright 2023 gRPC authors.
|
* Copyright 2017 gRPC authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
@ -16,33 +18,8 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package grpc
|
package leakcheck
|
||||||
|
|
||||||
import "testing"
|
func init() {
|
||||||
|
failTestsOnLeakedBuffers = true
|
||||||
func (s) TestSharedBufferPool(t *testing.T) {
|
|
||||||
pools := []SharedBufferPool{
|
|
||||||
nopBufferPool{},
|
|
||||||
NewSharedBufferPool(),
|
|
||||||
}
|
|
||||||
|
|
||||||
lengths := []int{
|
|
||||||
level4PoolMaxSize + 1,
|
|
||||||
level4PoolMaxSize,
|
|
||||||
level3PoolMaxSize,
|
|
||||||
level2PoolMaxSize,
|
|
||||||
level1PoolMaxSize,
|
|
||||||
level0PoolMaxSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range pools {
|
|
||||||
for _, l := range lengths {
|
|
||||||
bs := p.Get(l)
|
|
||||||
if len(bs) != l {
|
|
||||||
t.Fatalf("Expected buffer of length %d, got %d", l, len(bs))
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Put(&bs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
@ -25,12 +25,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testErrorfer struct {
|
type testLogger struct {
|
||||||
errorCount int
|
errorCount int
|
||||||
errors []string
|
errors []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *testErrorfer) Errorf(format string, args ...any) {
|
func (e *testLogger) Logf(format string, args ...any) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *testLogger) Errorf(format string, args ...any) {
|
||||||
e.errors = append(e.errors, fmt.Sprintf(format, args...))
|
e.errors = append(e.errors, fmt.Sprintf(format, args...))
|
||||||
e.errorCount++
|
e.errorCount++
|
||||||
}
|
}
|
||||||
|
|
@ -43,13 +46,13 @@ func TestCheck(t *testing.T) {
|
||||||
if ig := interestingGoroutines(); len(ig) == 0 {
|
if ig := interestingGoroutines(); len(ig) == 0 {
|
||||||
t.Error("blah")
|
t.Error("blah")
|
||||||
}
|
}
|
||||||
e := &testErrorfer{}
|
e := &testLogger{}
|
||||||
check(e, time.Second)
|
CheckGoroutines(e, time.Second)
|
||||||
if e.errorCount != leakCount {
|
if e.errorCount != leakCount {
|
||||||
t.Errorf("check found %v leaks, want %v leaks", e.errorCount, leakCount)
|
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
|
||||||
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
|
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
|
||||||
}
|
}
|
||||||
check(t, 3*time.Second)
|
CheckGoroutines(t, 3*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ignoredTestingLeak(d time.Duration) {
|
func ignoredTestingLeak(d time.Duration) {
|
||||||
|
|
@ -66,11 +69,11 @@ func TestCheckRegisterIgnore(t *testing.T) {
|
||||||
if ig := interestingGoroutines(); len(ig) == 0 {
|
if ig := interestingGoroutines(); len(ig) == 0 {
|
||||||
t.Error("blah")
|
t.Error("blah")
|
||||||
}
|
}
|
||||||
e := &testErrorfer{}
|
e := &testLogger{}
|
||||||
check(e, time.Second)
|
CheckGoroutines(e, time.Second)
|
||||||
if e.errorCount != leakCount {
|
if e.errorCount != leakCount {
|
||||||
t.Errorf("check found %v leaks, want %v leaks", e.errorCount, leakCount)
|
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
|
||||||
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
|
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
|
||||||
}
|
}
|
||||||
check(t, 3*time.Second)
|
CheckGoroutines(t, 3*time.Second)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ import (
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
"google.golang.org/grpc/internal/grpclog"
|
"google.golang.org/grpc/internal/grpclog"
|
||||||
"google.golang.org/grpc/internal/grpcutil"
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -148,9 +149,9 @@ type dataFrame struct {
|
||||||
streamID uint32
|
streamID uint32
|
||||||
endStream bool
|
endStream bool
|
||||||
h []byte
|
h []byte
|
||||||
d []byte
|
reader mem.Reader
|
||||||
// onEachWrite is called every time
|
// onEachWrite is called every time
|
||||||
// a part of d is written out.
|
// a part of data is written out.
|
||||||
onEachWrite func()
|
onEachWrite func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -454,12 +455,13 @@ func (c *controlBuffer) finish() {
|
||||||
// These streams need to be cleaned out since the transport
|
// These streams need to be cleaned out since the transport
|
||||||
// is still not aware of these yet.
|
// is still not aware of these yet.
|
||||||
for head := c.list.dequeueAll(); head != nil; head = head.next {
|
for head := c.list.dequeueAll(); head != nil; head = head.next {
|
||||||
hdr, ok := head.it.(*headerFrame)
|
switch v := head.it.(type) {
|
||||||
if !ok {
|
case *headerFrame:
|
||||||
continue
|
if v.onOrphaned != nil { // It will be nil on the server-side.
|
||||||
}
|
v.onOrphaned(ErrConnClosing)
|
||||||
if hdr.onOrphaned != nil { // It will be nil on the server-side.
|
}
|
||||||
hdr.onOrphaned(ErrConnClosing)
|
case *dataFrame:
|
||||||
|
_ = v.reader.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -509,12 +511,13 @@ type loopyWriter struct {
|
||||||
draining bool
|
draining bool
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
logger *grpclog.PrefixLogger
|
logger *grpclog.PrefixLogger
|
||||||
|
bufferPool mem.BufferPool
|
||||||
|
|
||||||
// Side-specific handlers
|
// Side-specific handlers
|
||||||
ssGoAwayHandler func(*goAway) (bool, error)
|
ssGoAwayHandler func(*goAway) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error)) *loopyWriter {
|
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
l := &loopyWriter{
|
l := &loopyWriter{
|
||||||
side: s,
|
side: s,
|
||||||
|
|
@ -530,6 +533,7 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato
|
||||||
conn: conn,
|
conn: conn,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
ssGoAwayHandler: goAwayHandler,
|
ssGoAwayHandler: goAwayHandler,
|
||||||
|
bufferPool: bufferPool,
|
||||||
}
|
}
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
@ -787,6 +791,11 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
|
||||||
// not be established yet.
|
// not be established yet.
|
||||||
delete(l.estdStreams, c.streamID)
|
delete(l.estdStreams, c.streamID)
|
||||||
str.deleteSelf()
|
str.deleteSelf()
|
||||||
|
for head := str.itl.dequeueAll(); head != nil; head = head.next {
|
||||||
|
if df, ok := head.it.(*dataFrame); ok {
|
||||||
|
_ = df.reader.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if c.rst { // If RST_STREAM needs to be sent.
|
if c.rst { // If RST_STREAM needs to be sent.
|
||||||
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
|
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
|
||||||
|
|
@ -922,16 +931,18 @@ func (l *loopyWriter) processData() (bool, error) {
|
||||||
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
|
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
|
||||||
// A data item is represented by a dataFrame, since it later translates into
|
// A data item is represented by a dataFrame, since it later translates into
|
||||||
// multiple HTTP2 data frames.
|
// multiple HTTP2 data frames.
|
||||||
// Every dataFrame has two buffers; h that keeps grpc-message header and d that is actual data.
|
// Every dataFrame has two buffers; h that keeps grpc-message header and data
|
||||||
// As an optimization to keep wire traffic low, data from d is copied to h to make as big as the
|
// that is the actual message. As an optimization to keep wire traffic low, data
|
||||||
// maximum possible HTTP2 frame size.
|
// from data is copied to h to make as big as the maximum possible HTTP2 frame
|
||||||
|
// size.
|
||||||
|
|
||||||
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame
|
if len(dataItem.h) == 0 && dataItem.reader.Remaining() == 0 { // Empty data frame
|
||||||
// Client sends out empty data frame with endStream = true
|
// Client sends out empty data frame with endStream = true
|
||||||
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
|
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
str.itl.dequeue() // remove the empty data item from stream
|
str.itl.dequeue() // remove the empty data item from stream
|
||||||
|
_ = dataItem.reader.Close()
|
||||||
if str.itl.isEmpty() {
|
if str.itl.isEmpty() {
|
||||||
str.state = empty
|
str.state = empty
|
||||||
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
|
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
|
||||||
|
|
@ -946,9 +957,7 @@ func (l *loopyWriter) processData() (bool, error) {
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
var (
|
|
||||||
buf []byte
|
|
||||||
)
|
|
||||||
// Figure out the maximum size we can send
|
// Figure out the maximum size we can send
|
||||||
maxSize := http2MaxFrameLen
|
maxSize := http2MaxFrameLen
|
||||||
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
|
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
|
||||||
|
|
@ -962,43 +971,50 @@ func (l *loopyWriter) processData() (bool, error) {
|
||||||
}
|
}
|
||||||
// Compute how much of the header and data we can send within quota and max frame length
|
// Compute how much of the header and data we can send within quota and max frame length
|
||||||
hSize := min(maxSize, len(dataItem.h))
|
hSize := min(maxSize, len(dataItem.h))
|
||||||
dSize := min(maxSize-hSize, len(dataItem.d))
|
dSize := min(maxSize-hSize, dataItem.reader.Remaining())
|
||||||
if hSize != 0 {
|
remainingBytes := len(dataItem.h) + dataItem.reader.Remaining() - hSize - dSize
|
||||||
if dSize == 0 {
|
|
||||||
buf = dataItem.h
|
|
||||||
} else {
|
|
||||||
// We can add some data to grpc message header to distribute bytes more equally across frames.
|
|
||||||
// Copy on the stack to avoid generating garbage
|
|
||||||
var localBuf [http2MaxFrameLen]byte
|
|
||||||
copy(localBuf[:hSize], dataItem.h)
|
|
||||||
copy(localBuf[hSize:], dataItem.d[:dSize])
|
|
||||||
buf = localBuf[:hSize+dSize]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
buf = dataItem.d
|
|
||||||
}
|
|
||||||
|
|
||||||
size := hSize + dSize
|
size := hSize + dSize
|
||||||
|
|
||||||
|
var buf *[]byte
|
||||||
|
|
||||||
|
if hSize != 0 && dSize == 0 {
|
||||||
|
buf = &dataItem.h
|
||||||
|
} else {
|
||||||
|
// Note: this is only necessary because the http2.Framer does not support
|
||||||
|
// partially writing a frame, so the sequence must be materialized into a buffer.
|
||||||
|
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed.
|
||||||
|
pool := l.bufferPool
|
||||||
|
if pool == nil {
|
||||||
|
// Note that this is only supposed to be nil in tests. Otherwise, stream is
|
||||||
|
// always initialized with a BufferPool.
|
||||||
|
pool = mem.DefaultBufferPool()
|
||||||
|
}
|
||||||
|
buf = pool.Get(size)
|
||||||
|
defer pool.Put(buf)
|
||||||
|
|
||||||
|
copy((*buf)[:hSize], dataItem.h)
|
||||||
|
_, _ = dataItem.reader.Read((*buf)[hSize:])
|
||||||
|
}
|
||||||
|
|
||||||
// Now that outgoing flow controls are checked we can replenish str's write quota
|
// Now that outgoing flow controls are checked we can replenish str's write quota
|
||||||
str.wq.replenish(size)
|
str.wq.replenish(size)
|
||||||
var endStream bool
|
var endStream bool
|
||||||
// If this is the last data message on this stream and all of it can be written in this iteration.
|
// If this is the last data message on this stream and all of it can be written in this iteration.
|
||||||
if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size {
|
if dataItem.endStream && remainingBytes == 0 {
|
||||||
endStream = true
|
endStream = true
|
||||||
}
|
}
|
||||||
if dataItem.onEachWrite != nil {
|
if dataItem.onEachWrite != nil {
|
||||||
dataItem.onEachWrite()
|
dataItem.onEachWrite()
|
||||||
}
|
}
|
||||||
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil {
|
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
str.bytesOutStanding += size
|
str.bytesOutStanding += size
|
||||||
l.sendQuota -= uint32(size)
|
l.sendQuota -= uint32(size)
|
||||||
dataItem.h = dataItem.h[hSize:]
|
dataItem.h = dataItem.h[hSize:]
|
||||||
dataItem.d = dataItem.d[dSize:]
|
|
||||||
|
|
||||||
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out.
|
if remainingBytes == 0 { // All the data from that message was written out.
|
||||||
|
_ = dataItem.reader.Close()
|
||||||
str.itl.dequeue()
|
str.itl.dequeue()
|
||||||
}
|
}
|
||||||
if str.itl.isEmpty() {
|
if str.itl.isEmpty() {
|
||||||
|
|
|
||||||
|
|
@ -85,10 +85,10 @@ func (fr *FramerBridge) ReadFrame() (Frame, error) {
|
||||||
switch f := f.(type) {
|
switch f := f.(type) {
|
||||||
case *http2.DataFrame:
|
case *http2.DataFrame:
|
||||||
buf := fr.pool.Get(int(hdr.Size))
|
buf := fr.pool.Get(int(hdr.Size))
|
||||||
copy(buf, f.Data())
|
copy(*buf, f.Data())
|
||||||
return &DataFrame{
|
return &DataFrame{
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
Data: buf,
|
Data: *buf,
|
||||||
free: func() { fr.pool.Put(buf) },
|
free: func() { fr.pool.Put(buf) },
|
||||||
}, nil
|
}, nil
|
||||||
case *http2.RSTStreamFrame:
|
case *http2.RSTStreamFrame:
|
||||||
|
|
@ -111,21 +111,21 @@ func (fr *FramerBridge) ReadFrame() (Frame, error) {
|
||||||
}, nil
|
}, nil
|
||||||
case *http2.PingFrame:
|
case *http2.PingFrame:
|
||||||
buf := fr.pool.Get(int(hdr.Size))
|
buf := fr.pool.Get(int(hdr.Size))
|
||||||
copy(buf, f.Data[:])
|
copy(*buf, f.Data[:])
|
||||||
return &PingFrame{
|
return &PingFrame{
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
Data: buf,
|
Data: *buf,
|
||||||
free: func() { fr.pool.Put(buf) },
|
free: func() { fr.pool.Put(buf) },
|
||||||
}, nil
|
}, nil
|
||||||
case *http2.GoAwayFrame:
|
case *http2.GoAwayFrame:
|
||||||
// Size of the frame minus the code and lastStreamID
|
// Size of the frame minus the code and lastStreamID
|
||||||
buf := fr.pool.Get(int(hdr.Size) - 8)
|
buf := fr.pool.Get(int(hdr.Size) - 8)
|
||||||
copy(buf, f.DebugData())
|
copy(*buf, f.DebugData())
|
||||||
return &GoAwayFrame{
|
return &GoAwayFrame{
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
LastStreamID: f.LastStreamID,
|
LastStreamID: f.LastStreamID,
|
||||||
Code: ErrCode(f.ErrCode),
|
Code: ErrCode(f.ErrCode),
|
||||||
DebugData: buf,
|
DebugData: *buf,
|
||||||
free: func() { fr.pool.Put(buf) },
|
free: func() { fr.pool.Put(buf) },
|
||||||
}, nil
|
}, nil
|
||||||
case *http2.WindowUpdateFrame:
|
case *http2.WindowUpdateFrame:
|
||||||
|
|
@ -141,10 +141,10 @@ func (fr *FramerBridge) ReadFrame() (Frame, error) {
|
||||||
default:
|
default:
|
||||||
buf := fr.pool.Get(int(hdr.Size))
|
buf := fr.pool.Get(int(hdr.Size))
|
||||||
uf := f.(*http2.UnknownFrame)
|
uf := f.(*http2.UnknownFrame)
|
||||||
copy(buf, uf.Payload())
|
copy(*buf, uf.Payload())
|
||||||
return &UnknownFrame{
|
return &UnknownFrame{
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
Payload: buf,
|
Payload: *buf,
|
||||||
free: func() { fr.pool.Put(buf) },
|
free: func() { fr.pool.Put(buf) },
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
@ -156,19 +156,19 @@ func (fr *FramerBridge) WriteData(streamID uint32, endStream bool, data ...[]byt
|
||||||
return fr.framer.WriteData(streamID, endStream, data[0])
|
return fr.framer.WriteData(streamID, endStream, data[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf []byte
|
|
||||||
tl := 0
|
tl := 0
|
||||||
for _, s := range data {
|
for _, s := range data {
|
||||||
tl += len(s)
|
tl += len(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf = fr.pool.Get(tl)[:0]
|
buf := fr.pool.Get(tl)
|
||||||
|
*buf = (*buf)[:0]
|
||||||
defer fr.pool.Put(buf)
|
defer fr.pool.Put(buf)
|
||||||
for _, s := range data {
|
for _, s := range data {
|
||||||
buf = append(buf, s...)
|
*buf = append(*buf, s...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fr.framer.WriteData(streamID, endStream, buf)
|
return fr.framer.WriteData(streamID, endStream, *buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteHeaders writes a Headers Frame into the underlying writer.
|
// WriteHeaders writes a Headers Frame into the underlying writer.
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -40,6 +39,7 @@ import (
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/internal/grpclog"
|
"google.golang.org/grpc/internal/grpclog"
|
||||||
"google.golang.org/grpc/internal/grpcutil"
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
|
|
@ -50,7 +50,7 @@ import (
|
||||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
|
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
|
||||||
// inside an http.Handler, or writes an HTTP error to w and returns an error.
|
// inside an http.Handler, or writes an HTTP error to w and returns an error.
|
||||||
// It requires that the http Server supports HTTP/2.
|
// It requires that the http Server supports HTTP/2.
|
||||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
|
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
w.Header().Set("Allow", http.MethodPost)
|
w.Header().Set("Allow", http.MethodPost)
|
||||||
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
|
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
|
||||||
|
|
@ -98,6 +98,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
|
||||||
contentType: contentType,
|
contentType: contentType,
|
||||||
contentSubtype: contentSubtype,
|
contentSubtype: contentSubtype,
|
||||||
stats: stats,
|
stats: stats,
|
||||||
|
bufferPool: bufferPool,
|
||||||
}
|
}
|
||||||
st.logger = prefixLoggerForServerHandlerTransport(st)
|
st.logger = prefixLoggerForServerHandlerTransport(st)
|
||||||
|
|
||||||
|
|
@ -171,6 +172,8 @@ type serverHandlerTransport struct {
|
||||||
|
|
||||||
stats []stats.Handler
|
stats []stats.Handler
|
||||||
logger *grpclog.PrefixLogger
|
logger *grpclog.PrefixLogger
|
||||||
|
|
||||||
|
bufferPool mem.BufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Close(err error) {
|
func (ht *serverHandlerTransport) Close(err error) {
|
||||||
|
|
@ -330,16 +333,28 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
|
||||||
s.hdrMu.Unlock()
|
s.hdrMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
|
||||||
|
// Always take a reference because otherwise there is no guarantee the data will
|
||||||
|
// be available after this function returns. This is what callers to Write
|
||||||
|
// expect.
|
||||||
|
data.Ref()
|
||||||
headersWritten := s.updateHeaderSent()
|
headersWritten := s.updateHeaderSent()
|
||||||
return ht.do(func() {
|
err := ht.do(func() {
|
||||||
|
defer data.Free()
|
||||||
if !headersWritten {
|
if !headersWritten {
|
||||||
ht.writePendingHeaders(s)
|
ht.writePendingHeaders(s)
|
||||||
}
|
}
|
||||||
ht.rw.Write(hdr)
|
ht.rw.Write(hdr)
|
||||||
ht.rw.Write(data)
|
for _, b := range data {
|
||||||
|
_, _ = ht.rw.Write(b.ReadOnlyData())
|
||||||
|
}
|
||||||
ht.rw.(http.Flusher).Flush()
|
ht.rw.(http.Flusher).Flush()
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
data.Free()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||||
|
|
@ -406,7 +421,7 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
|
||||||
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
|
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
|
||||||
}
|
}
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
|
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
|
||||||
windowHandler: func(int) {},
|
windowHandler: func(int) {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -415,21 +430,19 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
|
||||||
go func() {
|
go func() {
|
||||||
defer close(readerDone)
|
defer close(readerDone)
|
||||||
|
|
||||||
// TODO: minimize garbage, optimize recvBuffer code/ownership
|
for {
|
||||||
const readSize = 8196
|
buf := ht.bufferPool.Get(http2MaxFrameLen)
|
||||||
for buf := make([]byte, readSize); ; {
|
n, err := req.Body.Read(*buf)
|
||||||
n, err := req.Body.Read(buf)
|
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
|
*buf = (*buf)[:n]
|
||||||
buf = buf[n:]
|
s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
|
||||||
|
} else {
|
||||||
|
ht.bufferPool.Put(buf)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
|
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(buf) == 0 {
|
|
||||||
buf = make([]byte, readSize)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ import (
|
||||||
|
|
||||||
epb "google.golang.org/genproto/googleapis/rpc/errdetails"
|
epb "google.golang.org/genproto/googleapis/rpc/errdetails"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
@ -203,7 +204,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
|
||||||
if tt.modrw != nil {
|
if tt.modrw != nil {
|
||||||
rw = tt.modrw(rw)
|
rw = tt.modrw(rw)
|
||||||
}
|
}
|
||||||
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
|
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil, mem.DefaultBufferPool())
|
||||||
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
|
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
|
||||||
t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
|
t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
|
||||||
continue
|
continue
|
||||||
|
|
@ -259,7 +260,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
||||||
Body: bodyr,
|
Body: bodyr,
|
||||||
}
|
}
|
||||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||||
ht, err := NewServerHandlerTransport(rw, req, nil)
|
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -374,7 +375,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
||||||
Body: bodyr,
|
Body: bodyr,
|
||||||
}
|
}
|
||||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||||
ht, err := NewServerHandlerTransport(rw, req, nil)
|
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -439,7 +440,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
|
||||||
st.bodyw.Close() // no body
|
st.bodyw.Close() // no body
|
||||||
|
|
||||||
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
|
st.ht.Write(s, []byte("hdr"), newBufferSlice([]byte("data")), &Options{})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ import (
|
||||||
isyscall "google.golang.org/grpc/internal/syscall"
|
isyscall "google.golang.org/grpc/internal/syscall"
|
||||||
"google.golang.org/grpc/internal/transport/networktype"
|
"google.golang.org/grpc/internal/transport/networktype"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
|
|
@ -146,7 +147,7 @@ type http2Client struct {
|
||||||
|
|
||||||
onClose func(GoAwayReason)
|
onClose func(GoAwayReason)
|
||||||
|
|
||||||
bufferPool *bufferPool
|
bufferPool mem.BufferPool
|
||||||
|
|
||||||
connectionID uint64
|
connectionID uint64
|
||||||
logger *grpclog.PrefixLogger
|
logger *grpclog.PrefixLogger
|
||||||
|
|
@ -348,7 +349,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
|
||||||
streamQuota: defaultMaxStreamsClient,
|
streamQuota: defaultMaxStreamsClient,
|
||||||
streamsQuotaAvailable: make(chan struct{}, 1),
|
streamsQuotaAvailable: make(chan struct{}, 1),
|
||||||
keepaliveEnabled: keepaliveEnabled,
|
keepaliveEnabled: keepaliveEnabled,
|
||||||
bufferPool: newBufferPool(),
|
bufferPool: opts.BufferPool,
|
||||||
onClose: onClose,
|
onClose: onClose,
|
||||||
}
|
}
|
||||||
var czSecurity credentials.ChannelzSecurityValue
|
var czSecurity credentials.ChannelzSecurityValue
|
||||||
|
|
@ -465,7 +466,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
|
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
|
||||||
if err := t.loopy.run(); !isIOError(err) {
|
if err := t.loopy.run(); !isIOError(err) {
|
||||||
// Immediately close the connection, as the loopy writer returns
|
// Immediately close the connection, as the loopy writer returns
|
||||||
// when there are no more active streams and we were draining (the
|
// when there are no more active streams and we were draining (the
|
||||||
|
|
@ -506,7 +507,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
||||||
closeStream: func(err error) {
|
closeStream: func(err error) {
|
||||||
t.CloseStream(s, err)
|
t.CloseStream(s, err)
|
||||||
},
|
},
|
||||||
freeBuffer: t.bufferPool.put,
|
|
||||||
},
|
},
|
||||||
windowHandler: func(n int) {
|
windowHandler: func(n int) {
|
||||||
t.updateWindow(s, uint32(n))
|
t.updateWindow(s, uint32(n))
|
||||||
|
|
@ -1078,27 +1078,36 @@ func (t *http2Client) GracefulClose() {
|
||||||
|
|
||||||
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
|
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
|
||||||
// should proceed only if Write returns nil.
|
// should proceed only if Write returns nil.
|
||||||
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (t *http2Client) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
|
||||||
|
reader := data.Reader()
|
||||||
|
|
||||||
if opts.Last {
|
if opts.Last {
|
||||||
// If it's the last message, update stream state.
|
// If it's the last message, update stream state.
|
||||||
if !s.compareAndSwapState(streamActive, streamWriteDone) {
|
if !s.compareAndSwapState(streamActive, streamWriteDone) {
|
||||||
|
_ = reader.Close()
|
||||||
return errStreamDone
|
return errStreamDone
|
||||||
}
|
}
|
||||||
} else if s.getState() != streamActive {
|
} else if s.getState() != streamActive {
|
||||||
|
_ = reader.Close()
|
||||||
return errStreamDone
|
return errStreamDone
|
||||||
}
|
}
|
||||||
df := &dataFrame{
|
df := &dataFrame{
|
||||||
streamID: s.id,
|
streamID: s.id,
|
||||||
endStream: opts.Last,
|
endStream: opts.Last,
|
||||||
h: hdr,
|
h: hdr,
|
||||||
d: data,
|
reader: reader,
|
||||||
}
|
}
|
||||||
if hdr != nil || data != nil { // If it's not an empty data frame, check quota.
|
if hdr != nil || df.reader.Remaining() != 0 { // If it's not an empty data frame, check quota.
|
||||||
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
|
if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
|
||||||
|
_ = reader.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return t.controlBuf.put(df)
|
if err := t.controlBuf.put(df); err != nil {
|
||||||
|
_ = reader.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) getStream(f http2.Frame) *Stream {
|
func (t *http2Client) getStream(f http2.Frame) *Stream {
|
||||||
|
|
@ -1203,10 +1212,13 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
// guarantee f.Data() is consumed before the arrival of next frame.
|
||||||
// Can this copy be eliminated?
|
// Can this copy be eliminated?
|
||||||
if len(f.Data()) > 0 {
|
if len(f.Data()) > 0 {
|
||||||
buffer := t.bufferPool.get()
|
pool := t.bufferPool
|
||||||
buffer.Reset()
|
if pool == nil {
|
||||||
buffer.Write(f.Data())
|
// Note that this is only supposed to be nil in tests. Otherwise, stream is
|
||||||
s.write(recvMsg{buffer: buffer})
|
// always initialized with a BufferPool.
|
||||||
|
pool = mem.DefaultBufferPool()
|
||||||
|
}
|
||||||
|
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// The server has closed the stream without sending trailers. Record that
|
// The server has closed the stream without sending trailers. Record that
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/grpcutil"
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
"google.golang.org/grpc/internal/pretty"
|
"google.golang.org/grpc/internal/pretty"
|
||||||
"google.golang.org/grpc/internal/syscall"
|
"google.golang.org/grpc/internal/syscall"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
|
@ -119,7 +120,7 @@ type http2Server struct {
|
||||||
|
|
||||||
// Fields below are for channelz metric collection.
|
// Fields below are for channelz metric collection.
|
||||||
channelz *channelz.Socket
|
channelz *channelz.Socket
|
||||||
bufferPool *bufferPool
|
bufferPool mem.BufferPool
|
||||||
|
|
||||||
connectionID uint64
|
connectionID uint64
|
||||||
|
|
||||||
|
|
@ -261,7 +262,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
idle: time.Now(),
|
idle: time.Now(),
|
||||||
kep: kep,
|
kep: kep,
|
||||||
initialWindowSize: iwz,
|
initialWindowSize: iwz,
|
||||||
bufferPool: newBufferPool(),
|
bufferPool: config.BufferPool,
|
||||||
}
|
}
|
||||||
var czSecurity credentials.ChannelzSecurityValue
|
var czSecurity credentials.ChannelzSecurityValue
|
||||||
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
|
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
|
||||||
|
|
@ -330,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
t.handleSettings(sf)
|
t.handleSettings(sf)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
|
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
|
||||||
err := t.loopy.run()
|
err := t.loopy.run()
|
||||||
close(t.loopyWriterDone)
|
close(t.loopyWriterDone)
|
||||||
if !isIOError(err) {
|
if !isIOError(err) {
|
||||||
|
|
@ -613,10 +614,9 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
|
||||||
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
reader: &recvBufferReader{
|
reader: &recvBufferReader{
|
||||||
ctx: s.ctx,
|
ctx: s.ctx,
|
||||||
ctxDone: s.ctxDone,
|
ctxDone: s.ctxDone,
|
||||||
recv: s.buf,
|
recv: s.buf,
|
||||||
freeBuffer: t.bufferPool.put,
|
|
||||||
},
|
},
|
||||||
windowHandler: func(n int) {
|
windowHandler: func(n int) {
|
||||||
t.updateWindow(s, uint32(n))
|
t.updateWindow(s, uint32(n))
|
||||||
|
|
@ -813,10 +813,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
|
||||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
// guarantee f.Data() is consumed before the arrival of next frame.
|
||||||
// Can this copy be eliminated?
|
// Can this copy be eliminated?
|
||||||
if len(f.Data()) > 0 {
|
if len(f.Data()) > 0 {
|
||||||
buffer := t.bufferPool.get()
|
pool := t.bufferPool
|
||||||
buffer.Reset()
|
if pool == nil {
|
||||||
buffer.Write(f.Data())
|
// Note that this is only supposed to be nil in tests. Otherwise, stream is
|
||||||
s.write(recvMsg{buffer: buffer})
|
// always initialized with a BufferPool.
|
||||||
|
pool = mem.DefaultBufferPool()
|
||||||
|
}
|
||||||
|
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if f.StreamEnded() {
|
if f.StreamEnded() {
|
||||||
|
|
@ -1114,27 +1117,37 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
|
||||||
|
|
||||||
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
|
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
|
||||||
// is returns if it fails (e.g., framing error, transport error).
|
// is returns if it fails (e.g., framing error, transport error).
|
||||||
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (t *http2Server) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
|
||||||
|
reader := data.Reader()
|
||||||
|
|
||||||
if !s.isHeaderSent() { // Headers haven't been written yet.
|
if !s.isHeaderSent() { // Headers haven't been written yet.
|
||||||
if err := t.WriteHeader(s, nil); err != nil {
|
if err := t.WriteHeader(s, nil); err != nil {
|
||||||
|
_ = reader.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Writing headers checks for this condition.
|
// Writing headers checks for this condition.
|
||||||
if s.getState() == streamDone {
|
if s.getState() == streamDone {
|
||||||
|
_ = reader.Close()
|
||||||
return t.streamContextErr(s)
|
return t.streamContextErr(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
df := &dataFrame{
|
df := &dataFrame{
|
||||||
streamID: s.id,
|
streamID: s.id,
|
||||||
h: hdr,
|
h: hdr,
|
||||||
d: data,
|
reader: reader,
|
||||||
onEachWrite: t.setResetPingStrikes,
|
onEachWrite: t.setResetPingStrikes,
|
||||||
}
|
}
|
||||||
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
|
if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
|
||||||
|
_ = reader.Close()
|
||||||
return t.streamContextErr(s)
|
return t.streamContextErr(s)
|
||||||
}
|
}
|
||||||
return t.controlBuf.put(df)
|
if err := t.controlBuf.put(df); err != nil {
|
||||||
|
_ = reader.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// keepalive running in a separate goroutine does the following:
|
// keepalive running in a separate goroutine does the following:
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -37,6 +36,7 @@ import (
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/internal/channelz"
|
"google.golang.org/grpc/internal/channelz"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
|
|
@ -47,32 +47,10 @@ import (
|
||||||
|
|
||||||
const logLevel = 2
|
const logLevel = 2
|
||||||
|
|
||||||
type bufferPool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBufferPool() *bufferPool {
|
|
||||||
return &bufferPool{
|
|
||||||
pool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return new(bytes.Buffer)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *bufferPool) get() *bytes.Buffer {
|
|
||||||
return p.pool.Get().(*bytes.Buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *bufferPool) put(b *bytes.Buffer) {
|
|
||||||
p.pool.Put(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// recvMsg represents the received msg from the transport. All transport
|
// recvMsg represents the received msg from the transport. All transport
|
||||||
// protocol specific info has been removed.
|
// protocol specific info has been removed.
|
||||||
type recvMsg struct {
|
type recvMsg struct {
|
||||||
buffer *bytes.Buffer
|
buffer mem.Buffer
|
||||||
// nil: received some data
|
// nil: received some data
|
||||||
// io.EOF: stream is completed. data is nil.
|
// io.EOF: stream is completed. data is nil.
|
||||||
// other non-nil error: transport failure. data is nil.
|
// other non-nil error: transport failure. data is nil.
|
||||||
|
|
@ -102,6 +80,9 @@ func newRecvBuffer() *recvBuffer {
|
||||||
func (b *recvBuffer) put(r recvMsg) {
|
func (b *recvBuffer) put(r recvMsg) {
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
if b.err != nil {
|
if b.err != nil {
|
||||||
|
// drop the buffer on the floor. Since b.err is not nil, any subsequent reads
|
||||||
|
// will always return an error, making this buffer inaccessible.
|
||||||
|
r.buffer.Free()
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
// An error had occurred earlier, don't accept more
|
// An error had occurred earlier, don't accept more
|
||||||
// data or errors.
|
// data or errors.
|
||||||
|
|
@ -148,45 +129,70 @@ type recvBufferReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
|
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
|
||||||
recv *recvBuffer
|
recv *recvBuffer
|
||||||
last *bytes.Buffer // Stores the remaining data in the previous calls.
|
last mem.Buffer // Stores the remaining data in the previous calls.
|
||||||
err error
|
err error
|
||||||
freeBuffer func(*bytes.Buffer)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads the next len(p) bytes from last. If last is drained, it tries to
|
func (r *recvBufferReader) ReadHeader(header []byte) (n int, err error) {
|
||||||
// read additional data from recv. It blocks if there no additional data available
|
|
||||||
// in recv. If Read returns any non-nil error, it will continue to return that error.
|
|
||||||
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return 0, r.err
|
return 0, r.err
|
||||||
}
|
}
|
||||||
if r.last != nil {
|
if r.last != nil {
|
||||||
// Read remaining data left in last call.
|
n, r.last = mem.ReadUnsafe(header, r.last)
|
||||||
copied, _ := r.last.Read(p)
|
return n, nil
|
||||||
if r.last.Len() == 0 {
|
|
||||||
r.freeBuffer(r.last)
|
|
||||||
r.last = nil
|
|
||||||
}
|
|
||||||
return copied, nil
|
|
||||||
}
|
}
|
||||||
if r.closeStream != nil {
|
if r.closeStream != nil {
|
||||||
n, r.err = r.readClient(p)
|
n, r.err = r.readHeaderClient(header)
|
||||||
} else {
|
} else {
|
||||||
n, r.err = r.read(p)
|
n, r.err = r.readHeader(header)
|
||||||
}
|
}
|
||||||
return n, r.err
|
return n, r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *recvBufferReader) read(p []byte) (n int, err error) {
|
// Read reads the next n bytes from last. If last is drained, it tries to read
|
||||||
|
// additional data from recv. It blocks if there no additional data available in
|
||||||
|
// recv. If Read returns any non-nil error, it will continue to return that
|
||||||
|
// error.
|
||||||
|
func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
if r.last != nil {
|
||||||
|
buf = r.last
|
||||||
|
if r.last.Len() > n {
|
||||||
|
buf, r.last = mem.SplitUnsafe(buf, n)
|
||||||
|
} else {
|
||||||
|
r.last = nil
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
if r.closeStream != nil {
|
||||||
|
buf, r.err = r.readClient(n)
|
||||||
|
} else {
|
||||||
|
buf, r.err = r.read(n)
|
||||||
|
}
|
||||||
|
return buf, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recvBufferReader) readHeader(header []byte) (n int, err error) {
|
||||||
select {
|
select {
|
||||||
case <-r.ctxDone:
|
case <-r.ctxDone:
|
||||||
return 0, ContextErr(r.ctx.Err())
|
return 0, ContextErr(r.ctx.Err())
|
||||||
case m := <-r.recv.get():
|
case m := <-r.recv.get():
|
||||||
return r.readAdditional(m, p)
|
return r.readHeaderAdditional(m, header)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
|
func (r *recvBufferReader) read(n int) (buf mem.Buffer, err error) {
|
||||||
|
select {
|
||||||
|
case <-r.ctxDone:
|
||||||
|
return nil, ContextErr(r.ctx.Err())
|
||||||
|
case m := <-r.recv.get():
|
||||||
|
return r.readAdditional(m, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recvBufferReader) readHeaderClient(header []byte) (n int, err error) {
|
||||||
// If the context is canceled, then closes the stream with nil metadata.
|
// If the context is canceled, then closes the stream with nil metadata.
|
||||||
// closeStream writes its error parameter to r.recv as a recvMsg.
|
// closeStream writes its error parameter to r.recv as a recvMsg.
|
||||||
// r.readAdditional acts on that message and returns the necessary error.
|
// r.readAdditional acts on that message and returns the necessary error.
|
||||||
|
|
@ -207,25 +213,67 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
|
||||||
// faster.
|
// faster.
|
||||||
r.closeStream(ContextErr(r.ctx.Err()))
|
r.closeStream(ContextErr(r.ctx.Err()))
|
||||||
m := <-r.recv.get()
|
m := <-r.recv.get()
|
||||||
return r.readAdditional(m, p)
|
return r.readHeaderAdditional(m, header)
|
||||||
case m := <-r.recv.get():
|
case m := <-r.recv.get():
|
||||||
return r.readAdditional(m, p)
|
return r.readHeaderAdditional(m, header)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) {
|
func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
|
||||||
|
// If the context is canceled, then closes the stream with nil metadata.
|
||||||
|
// closeStream writes its error parameter to r.recv as a recvMsg.
|
||||||
|
// r.readAdditional acts on that message and returns the necessary error.
|
||||||
|
select {
|
||||||
|
case <-r.ctxDone:
|
||||||
|
// Note that this adds the ctx error to the end of recv buffer, and
|
||||||
|
// reads from the head. This will delay the error until recv buffer is
|
||||||
|
// empty, thus will delay ctx cancellation in Recv().
|
||||||
|
//
|
||||||
|
// It's done this way to fix a race between ctx cancel and trailer. The
|
||||||
|
// race was, stream.Recv() may return ctx error if ctxDone wins the
|
||||||
|
// race, but stream.Trailer() may return a non-nil md because the stream
|
||||||
|
// was not marked as done when trailer is received. This closeStream
|
||||||
|
// call will mark stream as done, thus fix the race.
|
||||||
|
//
|
||||||
|
// TODO: delaying ctx error seems like a unnecessary side effect. What
|
||||||
|
// we really want is to mark the stream as done, and return ctx error
|
||||||
|
// faster.
|
||||||
|
r.closeStream(ContextErr(r.ctx.Err()))
|
||||||
|
m := <-r.recv.get()
|
||||||
|
return r.readAdditional(m, n)
|
||||||
|
case m := <-r.recv.get():
|
||||||
|
return r.readAdditional(m, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recvBufferReader) readHeaderAdditional(m recvMsg, header []byte) (n int, err error) {
|
||||||
r.recv.load()
|
r.recv.load()
|
||||||
if m.err != nil {
|
if m.err != nil {
|
||||||
|
if m.buffer != nil {
|
||||||
|
m.buffer.Free()
|
||||||
|
}
|
||||||
return 0, m.err
|
return 0, m.err
|
||||||
}
|
}
|
||||||
copied, _ := m.buffer.Read(p)
|
|
||||||
if m.buffer.Len() == 0 {
|
n, r.last = mem.ReadUnsafe(header, m.buffer)
|
||||||
r.freeBuffer(m.buffer)
|
|
||||||
r.last = nil
|
return n, nil
|
||||||
} else {
|
}
|
||||||
r.last = m.buffer
|
|
||||||
|
func (r *recvBufferReader) readAdditional(m recvMsg, n int) (b mem.Buffer, err error) {
|
||||||
|
r.recv.load()
|
||||||
|
if m.err != nil {
|
||||||
|
if m.buffer != nil {
|
||||||
|
m.buffer.Free()
|
||||||
|
}
|
||||||
|
return nil, m.err
|
||||||
}
|
}
|
||||||
return copied, nil
|
|
||||||
|
if m.buffer.Len() > n {
|
||||||
|
m.buffer, r.last = mem.SplitUnsafe(m.buffer, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.buffer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamState uint32
|
type streamState uint32
|
||||||
|
|
@ -251,7 +299,7 @@ type Stream struct {
|
||||||
recvCompress string
|
recvCompress string
|
||||||
sendCompress string
|
sendCompress string
|
||||||
buf *recvBuffer
|
buf *recvBuffer
|
||||||
trReader io.Reader
|
trReader *transportReader
|
||||||
fc *inFlow
|
fc *inFlow
|
||||||
wq *writeQuota
|
wq *writeQuota
|
||||||
|
|
||||||
|
|
@ -499,14 +547,55 @@ func (s *Stream) write(m recvMsg) {
|
||||||
s.buf.put(m)
|
s.buf.put(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads all p bytes from the wire for this stream.
|
func (s *Stream) ReadHeader(header []byte) (err error) {
|
||||||
func (s *Stream) Read(p []byte) (n int, err error) {
|
|
||||||
// Don't request a read if there was an error earlier
|
// Don't request a read if there was an error earlier
|
||||||
if er := s.trReader.(*transportReader).er; er != nil {
|
if er := s.trReader.er; er != nil {
|
||||||
return 0, er
|
return er
|
||||||
}
|
}
|
||||||
s.requestRead(len(p))
|
s.requestRead(len(header))
|
||||||
return io.ReadFull(s.trReader, p)
|
for len(header) != 0 {
|
||||||
|
n, err := s.trReader.ReadHeader(header)
|
||||||
|
header = header[n:]
|
||||||
|
if len(header) == 0 {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if n > 0 && err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads n bytes from the wire for this stream.
|
||||||
|
func (s *Stream) Read(n int) (data mem.BufferSlice, err error) {
|
||||||
|
// Don't request a read if there was an error earlier
|
||||||
|
if er := s.trReader.er; er != nil {
|
||||||
|
return nil, er
|
||||||
|
}
|
||||||
|
s.requestRead(n)
|
||||||
|
for n != 0 {
|
||||||
|
buf, err := s.trReader.Read(n)
|
||||||
|
var bufLen int
|
||||||
|
if buf != nil {
|
||||||
|
bufLen = buf.Len()
|
||||||
|
}
|
||||||
|
n -= bufLen
|
||||||
|
if n == 0 {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if bufLen > 0 && err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
data.Free()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data = append(data, buf)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// transportReader reads all the data available for this Stream from the transport and
|
// transportReader reads all the data available for this Stream from the transport and
|
||||||
|
|
@ -514,21 +603,31 @@ func (s *Stream) Read(p []byte) (n int, err error) {
|
||||||
// The error is io.EOF when the stream is done or another non-nil error if
|
// The error is io.EOF when the stream is done or another non-nil error if
|
||||||
// the stream broke.
|
// the stream broke.
|
||||||
type transportReader struct {
|
type transportReader struct {
|
||||||
reader io.Reader
|
reader *recvBufferReader
|
||||||
// The handler to control the window update procedure for both this
|
// The handler to control the window update procedure for both this
|
||||||
// particular stream and the associated transport.
|
// particular stream and the associated transport.
|
||||||
windowHandler func(int)
|
windowHandler func(int)
|
||||||
er error
|
er error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *transportReader) Read(p []byte) (n int, err error) {
|
func (t *transportReader) ReadHeader(header []byte) (int, error) {
|
||||||
n, err = t.reader.Read(p)
|
n, err := t.reader.ReadHeader(header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.er = err
|
t.er = err
|
||||||
return
|
return 0, err
|
||||||
}
|
}
|
||||||
t.windowHandler(n)
|
t.windowHandler(len(header))
|
||||||
return
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transportReader) Read(n int) (mem.Buffer, error) {
|
||||||
|
buf, err := t.reader.Read(n)
|
||||||
|
if err != nil {
|
||||||
|
t.er = err
|
||||||
|
return buf, err
|
||||||
|
}
|
||||||
|
t.windowHandler(buf.Len())
|
||||||
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BytesReceived indicates whether any bytes have been received on this stream.
|
// BytesReceived indicates whether any bytes have been received on this stream.
|
||||||
|
|
@ -574,6 +673,7 @@ type ServerConfig struct {
|
||||||
ChannelzParent *channelz.Server
|
ChannelzParent *channelz.Server
|
||||||
MaxHeaderListSize *uint32
|
MaxHeaderListSize *uint32
|
||||||
HeaderTableSize *uint32
|
HeaderTableSize *uint32
|
||||||
|
BufferPool mem.BufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectOptions covers all relevant options for communicating with the server.
|
// ConnectOptions covers all relevant options for communicating with the server.
|
||||||
|
|
@ -612,6 +712,8 @@ type ConnectOptions struct {
|
||||||
MaxHeaderListSize *uint32
|
MaxHeaderListSize *uint32
|
||||||
// UseProxy specifies if a proxy should be used.
|
// UseProxy specifies if a proxy should be used.
|
||||||
UseProxy bool
|
UseProxy bool
|
||||||
|
// The mem.BufferPool to use when reading/writing to the wire.
|
||||||
|
BufferPool mem.BufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||||
|
|
@ -673,7 +775,7 @@ type ClientTransport interface {
|
||||||
|
|
||||||
// Write sends the data for the given stream. A nil stream indicates
|
// Write sends the data for the given stream. A nil stream indicates
|
||||||
// the write is to be performed on the transport as a whole.
|
// the write is to be performed on the transport as a whole.
|
||||||
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
|
Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
|
||||||
|
|
||||||
// NewStream creates a Stream for an RPC.
|
// NewStream creates a Stream for an RPC.
|
||||||
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
|
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
|
||||||
|
|
@ -725,7 +827,7 @@ type ServerTransport interface {
|
||||||
|
|
||||||
// Write sends the data for the given stream.
|
// Write sends the data for the given stream.
|
||||||
// Write may not be called on all streams.
|
// Write may not be called on all streams.
|
||||||
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
|
Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
|
||||||
|
|
||||||
// WriteStatus sends the status of a stream to the client. WriteStatus is
|
// WriteStatus sends the status of a stream to the client. WriteStatus is
|
||||||
// the final call made on a stream and always occurs.
|
// the final call made on a stream and always occurs.
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/grpctest"
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
"google.golang.org/grpc/internal/leakcheck"
|
"google.golang.org/grpc/internal/leakcheck"
|
||||||
"google.golang.org/grpc/internal/testutils"
|
"google.golang.org/grpc/internal/testutils"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
@ -74,6 +75,29 @@ func init() {
|
||||||
expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
|
expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newBufferSlice(b []byte) mem.BufferSlice {
|
||||||
|
return mem.BufferSlice{mem.NewBuffer(&b, nil)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) readTo(p []byte) (int, error) {
|
||||||
|
data, err := s.Read(len(p))
|
||||||
|
defer data.Free()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Len() != len(p) {
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data.CopyTo(p)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
type testStreamHandler struct {
|
type testStreamHandler struct {
|
||||||
t *http2Server
|
t *http2Server
|
||||||
notify chan struct{}
|
notify chan struct{}
|
||||||
|
|
@ -114,7 +138,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||||
resp = expectedResponseLarge
|
resp = expectedResponseLarge
|
||||||
}
|
}
|
||||||
p := make([]byte, len(req))
|
p := make([]byte, len(req))
|
||||||
_, err := s.Read(p)
|
_, err := s.readTo(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -124,7 +148,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// send a response back to the client.
|
// send a response back to the client.
|
||||||
h.t.Write(s, nil, resp, &Options{})
|
h.t.Write(s, nil, newBufferSlice(resp), &Options{})
|
||||||
// send the trailer to end the stream.
|
// send the trailer to end the stream.
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +156,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||||
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
||||||
header := make([]byte, 5)
|
header := make([]byte, 5)
|
||||||
for {
|
for {
|
||||||
if _, err := s.Read(header); err != nil {
|
if _, err := s.readTo(header); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
return
|
return
|
||||||
|
|
@ -143,7 +167,7 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
||||||
}
|
}
|
||||||
sz := binary.BigEndian.Uint32(header[1:])
|
sz := binary.BigEndian.Uint32(header[1:])
|
||||||
msg := make([]byte, int(sz))
|
msg := make([]byte, int(sz))
|
||||||
if _, err := s.Read(msg); err != nil {
|
if _, err := s.readTo(msg); err != nil {
|
||||||
t.Errorf("Error on server while reading message: %v", err)
|
t.Errorf("Error on server while reading message: %v", err)
|
||||||
h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
|
h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
|
||||||
return
|
return
|
||||||
|
|
@ -152,7 +176,7 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
||||||
buf[0] = byte(0)
|
buf[0] = byte(0)
|
||||||
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
|
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
|
||||||
copy(buf[5:], msg)
|
copy(buf[5:], msg)
|
||||||
h.t.Write(s, nil, buf, &Options{})
|
h.t.Write(s, nil, newBufferSlice(buf), &Options{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -178,10 +202,11 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
||||||
p = make([]byte, n+1)
|
p = make([]byte, n+1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
data := newBufferSlice(p)
|
||||||
conn.controlBuf.put(&dataFrame{
|
conn.controlBuf.put(&dataFrame{
|
||||||
streamID: s.id,
|
streamID: s.id,
|
||||||
h: nil,
|
h: nil,
|
||||||
d: p,
|
reader: data.Reader(),
|
||||||
onEachWrite: func() {},
|
onEachWrite: func() {},
|
||||||
})
|
})
|
||||||
sent += len(p)
|
sent += len(p)
|
||||||
|
|
@ -191,6 +216,8 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
||||||
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) {
|
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) {
|
||||||
// raw newline is not accepted by http2 framer so it must be encoded.
|
// raw newline is not accepted by http2 framer so it must be encoded.
|
||||||
h.t.WriteStatus(s, encodingTestStatus)
|
h.t.WriteStatus(s, encodingTestStatus)
|
||||||
|
// Drain any remaining buffers from the stream since it was closed early.
|
||||||
|
s.Read(math.MaxInt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
|
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
|
||||||
|
|
@ -260,7 +287,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
|
||||||
t.Errorf("Server timed-out.")
|
t.Errorf("Server timed-out.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err := s.Read(p)
|
_, err := s.readTo(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
||||||
return
|
return
|
||||||
|
|
@ -273,14 +300,14 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
|
||||||
// This write will cause server to run out of stream level,
|
// This write will cause server to run out of stream level,
|
||||||
// flow control and the other side won't send a window update
|
// flow control and the other side won't send a window update
|
||||||
// until that happens.
|
// until that happens.
|
||||||
if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
|
if err := h.t.Write(s, nil, newBufferSlice(resp), &Options{}); err != nil {
|
||||||
t.Errorf("server Write got %v, want <nil>", err)
|
t.Errorf("server Write got %v, want <nil>", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Read one more time to ensure that everything remains fine and
|
// Read one more time to ensure that everything remains fine and
|
||||||
// that the goroutine, that we launched earlier to signal client
|
// that the goroutine, that we launched earlier to signal client
|
||||||
// to read, gets enough time to process.
|
// to read, gets enough time to process.
|
||||||
_, err = s.Read(p)
|
_, err = s.readTo(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("s.Read(_) = _, %v, want _, nil", err)
|
t.Errorf("s.Read(_) = _, %v, want _, nil", err)
|
||||||
return
|
return
|
||||||
|
|
@ -502,7 +529,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
|
||||||
serr := status.Error(codes.Internal, "client connection is closing")
|
serr := status.Error(codes.Internal, "client connection is closing")
|
||||||
go func() {
|
go func() {
|
||||||
defer close(donec)
|
defer close(donec)
|
||||||
if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr {
|
if _, err := stream.readTo(make([]byte, defaultWindowSize)); err != serr {
|
||||||
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
|
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -592,15 +619,15 @@ func (s) TestClientSendAndReceive(t *testing.T) {
|
||||||
t.Fatalf("wrong stream id: %d", s2.id)
|
t.Fatalf("wrong stream id: %d", s2.id)
|
||||||
}
|
}
|
||||||
opts := Options{Last: true}
|
opts := Options{Last: true}
|
||||||
if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF {
|
if err := ct.Write(s1, nil, newBufferSlice(expectedRequest), &opts); err != nil && err != io.EOF {
|
||||||
t.Fatalf("failed to send data: %v", err)
|
t.Fatalf("failed to send data: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponse))
|
p := make([]byte, len(expectedResponse))
|
||||||
_, recvErr := s1.Read(p)
|
_, recvErr := s1.readTo(p)
|
||||||
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
|
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
|
||||||
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
|
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
|
||||||
}
|
}
|
||||||
_, recvErr = s1.Read(p)
|
_, recvErr = s1.readTo(p)
|
||||||
if recvErr != io.EOF {
|
if recvErr != io.EOF {
|
||||||
t.Fatalf("Error: %v; want <EOF>", recvErr)
|
t.Fatalf("Error: %v; want <EOF>", recvErr)
|
||||||
}
|
}
|
||||||
|
|
@ -629,16 +656,16 @@ func performOneRPC(ct ClientTransport) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts := Options{Last: true}
|
opts := Options{Last: true}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF {
|
if err := ct.Write(s, []byte{}, newBufferSlice(expectedRequest), &opts); err == nil || err == io.EOF {
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
// The following s.Recv()'s could error out because the
|
// The following s.Recv()'s could error out because the
|
||||||
// underlying transport is gone.
|
// underlying transport is gone.
|
||||||
//
|
//
|
||||||
// Read response
|
// Read response
|
||||||
p := make([]byte, len(expectedResponse))
|
p := make([]byte, len(expectedResponse))
|
||||||
s.Read(p)
|
s.readTo(p)
|
||||||
// Read io.EOF
|
// Read io.EOF
|
||||||
s.Read(p)
|
s.readTo(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -674,14 +701,14 @@ func (s) TestLargeMessage(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil && err != io.EOF {
|
if err := ct.Write(s, []byte{}, newBufferSlice(expectedRequestLarge), &Options{Last: true}); err != nil && err != io.EOF {
|
||||||
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
p := make([]byte, len(expectedResponseLarge))
|
||||||
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
||||||
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
|
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
|
||||||
}
|
}
|
||||||
if _, err = s.Read(p); err != io.EOF {
|
if _, err = s.readTo(p); err != io.EOF {
|
||||||
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -765,7 +792,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
|
||||||
// This write will cause client to run out of stream level,
|
// This write will cause client to run out of stream level,
|
||||||
// flow control and the other side won't send a window update
|
// flow control and the other side won't send a window update
|
||||||
// until that happens.
|
// until that happens.
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{}); err != nil {
|
if err := ct.Write(s, []byte{}, newBufferSlice(expectedRequestLarge), &Options{}); err != nil {
|
||||||
t.Fatalf("write(_, _, _) = %v, want <nil>", err)
|
t.Fatalf("write(_, _, _) = %v, want <nil>", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
p := make([]byte, len(expectedResponseLarge))
|
||||||
|
|
@ -777,13 +804,13 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatalf("Client timed out")
|
t.Fatalf("Client timed out")
|
||||||
}
|
}
|
||||||
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
if _, err := s.readTo(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
||||||
t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
|
t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil {
|
if err := ct.Write(s, []byte{}, newBufferSlice(expectedRequestLarge), &Options{Last: true}); err != nil {
|
||||||
t.Fatalf("Write(_, _, _) = %v, want <nil>", err)
|
t.Fatalf("Write(_, _, _) = %v, want <nil>", err)
|
||||||
}
|
}
|
||||||
if _, err = s.Read(p); err != io.EOF {
|
if _, err = s.readTo(p); err != io.EOF {
|
||||||
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
|
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -792,6 +819,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
|
||||||
// proceed until they complete naturally, while not allowing creation of new
|
// proceed until they complete naturally, while not allowing creation of new
|
||||||
// streams during this window.
|
// streams during this window.
|
||||||
func (s) TestGracefulClose(t *testing.T) {
|
func (s) TestGracefulClose(t *testing.T) {
|
||||||
|
leakcheck.SetTrackingBufferPool(t)
|
||||||
server, ct, cancel := setUp(t, 0, pingpong)
|
server, ct, cancel := setUp(t, 0, pingpong)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
@ -800,7 +828,8 @@ func (s) TestGracefulClose(t *testing.T) {
|
||||||
server.lis.Close()
|
server.lis.Close()
|
||||||
// Check for goroutine leaks (i.e. GracefulClose with an active stream
|
// Check for goroutine leaks (i.e. GracefulClose with an active stream
|
||||||
// doesn't eventually close the connection when that stream completes).
|
// doesn't eventually close the connection when that stream completes).
|
||||||
leakcheck.Check(t)
|
leakcheck.CheckGoroutines(t, 10*time.Second)
|
||||||
|
leakcheck.CheckTrackingBufferPool()
|
||||||
// Correctly clean up the server
|
// Correctly clean up the server
|
||||||
server.stop()
|
server.stop()
|
||||||
}()
|
}()
|
||||||
|
|
@ -818,15 +847,15 @@ func (s) TestGracefulClose(t *testing.T) {
|
||||||
outgoingHeader[0] = byte(0)
|
outgoingHeader[0] = byte(0)
|
||||||
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
|
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
|
||||||
incomingHeader := make([]byte, 5)
|
incomingHeader := make([]byte, 5)
|
||||||
if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil {
|
if err := ct.Write(s, outgoingHeader, newBufferSlice(msg), &Options{}); err != nil {
|
||||||
t.Fatalf("Error while writing: %v", err)
|
t.Fatalf("Error while writing: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := s.Read(incomingHeader); err != nil {
|
if _, err := s.readTo(incomingHeader); err != nil {
|
||||||
t.Fatalf("Error while reading: %v", err)
|
t.Fatalf("Error while reading: %v", err)
|
||||||
}
|
}
|
||||||
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
||||||
recvMsg := make([]byte, int(sz))
|
recvMsg := make([]byte, int(sz))
|
||||||
if _, err := s.Read(recvMsg); err != nil {
|
if _, err := s.readTo(recvMsg); err != nil {
|
||||||
t.Fatalf("Error while reading: %v", err)
|
t.Fatalf("Error while reading: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -851,7 +880,7 @@ func (s) TestGracefulClose(t *testing.T) {
|
||||||
|
|
||||||
// Confirm the existing stream still functions as expected.
|
// Confirm the existing stream still functions as expected.
|
||||||
ct.Write(s, nil, nil, &Options{Last: true})
|
ct.Write(s, nil, nil, &Options{Last: true})
|
||||||
if _, err := s.Read(incomingHeader); err != io.EOF {
|
if _, err := s.readTo(incomingHeader); err != io.EOF {
|
||||||
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
@ -879,13 +908,13 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
// Write should not be done successfully due to flow control.
|
// Write should not be done successfully due to flow control.
|
||||||
msg := make([]byte, initialWindowSize*8)
|
msg := make([]byte, initialWindowSize*8)
|
||||||
ct.Write(s, nil, msg, &Options{})
|
ct.Write(s, nil, newBufferSlice(msg), &Options{})
|
||||||
err = ct.Write(s, nil, msg, &Options{Last: true})
|
err = ct.Write(s, nil, newBufferSlice(msg), &Options{Last: true})
|
||||||
if err != errStreamDone {
|
if err != errStreamDone {
|
||||||
t.Fatalf("Write got %v, want io.EOF", err)
|
t.Fatalf("Write got %v, want io.EOF", err)
|
||||||
}
|
}
|
||||||
expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
|
expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
|
||||||
if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() {
|
if _, err := s.readTo(make([]byte, 8)); err.Error() != expectedErr.Error() {
|
||||||
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
|
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
|
||||||
}
|
}
|
||||||
ct.Close(fmt.Errorf("closed manually by test"))
|
ct.Close(fmt.Errorf("closed manually by test"))
|
||||||
|
|
@ -997,11 +1026,12 @@ func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open stream: %v", err)
|
t.Fatalf("Failed to open stream: %v", err)
|
||||||
}
|
}
|
||||||
|
d := newBufferSlice(make([]byte, http2MaxFrameLen))
|
||||||
ct.controlBuf.put(&dataFrame{
|
ct.controlBuf.put(&dataFrame{
|
||||||
streamID: s.id,
|
streamID: s.id,
|
||||||
endStream: false,
|
endStream: false,
|
||||||
h: nil,
|
h: nil,
|
||||||
d: make([]byte, http2MaxFrameLen),
|
reader: d.Reader(),
|
||||||
onEachWrite: func() {},
|
onEachWrite: func() {},
|
||||||
})
|
})
|
||||||
// Loop until the server side stream is created.
|
// Loop until the server side stream is created.
|
||||||
|
|
@ -1078,7 +1108,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
|
||||||
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
|
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
|
||||||
}
|
}
|
||||||
// Exhaust client's connection window.
|
// Exhaust client's connection window.
|
||||||
if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := st.Write(sstream1, []byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &Options{}); err != nil {
|
||||||
t.Fatalf("Server failed to write data. Err: %v", err)
|
t.Fatalf("Server failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
notifyChan = make(chan struct{})
|
notifyChan = make(chan struct{})
|
||||||
|
|
@ -1103,17 +1133,17 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
|
||||||
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
|
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
|
||||||
}
|
}
|
||||||
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
|
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
|
||||||
if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := st.Write(sstream2, []byte{}, newBufferSlice(make([]byte, defaultWindowSize)), &Options{}); err != nil {
|
||||||
t.Fatalf("Server failed to write data. Err: %v", err)
|
t.Fatalf("Server failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client should be able to read data on second stream.
|
// Client should be able to read data on second stream.
|
||||||
if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, err := cstream2.readTo(make([]byte, defaultWindowSize)); err != nil {
|
||||||
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client should be able to read data on first stream.
|
// Client should be able to read data on first stream.
|
||||||
if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, err := cstream1.readTo(make([]byte, defaultWindowSize)); err != nil {
|
||||||
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1149,7 +1179,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
||||||
t.Fatalf("Failed to create 1st stream. Err: %v", err)
|
t.Fatalf("Failed to create 1st stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
// Exhaust server's connection window.
|
// Exhaust server's connection window.
|
||||||
if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
|
if err := client.Write(cstream1, nil, newBufferSlice(make([]byte, defaultWindowSize)), &Options{Last: true}); err != nil {
|
||||||
t.Fatalf("Client failed to write data. Err: %v", err)
|
t.Fatalf("Client failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
//Client should be able to create another stream and send data on it.
|
//Client should be able to create another stream and send data on it.
|
||||||
|
|
@ -1157,7 +1187,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
|
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := client.Write(cstream2, nil, newBufferSlice(make([]byte, defaultWindowSize)), &Options{}); err != nil {
|
||||||
t.Fatalf("Client failed to write data. Err: %v", err)
|
t.Fatalf("Client failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
// Get the streams on server.
|
// Get the streams on server.
|
||||||
|
|
@ -1179,11 +1209,11 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
||||||
}
|
}
|
||||||
st.mu.Unlock()
|
st.mu.Unlock()
|
||||||
// Reading from the stream on server should succeed.
|
// Reading from the stream on server should succeed.
|
||||||
if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, err := sstream1.readTo(make([]byte, defaultWindowSize)); err != nil {
|
||||||
t.Fatalf("_.Read(_) = %v, want <nil>", err)
|
t.Fatalf("_.Read(_) = %v, want <nil>", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
|
if _, err := sstream1.readTo(make([]byte, 1)); err != io.EOF {
|
||||||
t.Fatalf("_.Read(_) = %v, want io.EOF", err)
|
t.Fatalf("_.Read(_) = %v, want io.EOF", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1435,6 +1465,9 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
|
||||||
t.Fatalf("Test timed-out.")
|
t.Fatalf("Test timed-out.")
|
||||||
case <-success:
|
case <-success:
|
||||||
}
|
}
|
||||||
|
// Drain the remaining buffers in the stream by reading until an error is
|
||||||
|
// encountered.
|
||||||
|
str.Read(math.MaxInt)
|
||||||
}
|
}
|
||||||
|
|
||||||
var encodingTestStatus = status.New(codes.Internal, "\n")
|
var encodingTestStatus = status.New(codes.Internal, "\n")
|
||||||
|
|
@ -1453,11 +1486,11 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts := Options{Last: true}
|
opts := Options{Last: true}
|
||||||
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
|
if err := ct.Write(s, nil, newBufferSlice(expectedRequest), &opts); err != nil && err != errStreamDone {
|
||||||
t.Fatalf("Failed to write the request: %v", err)
|
t.Fatalf("Failed to write the request: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, http2MaxFrameLen)
|
p := make([]byte, http2MaxFrameLen)
|
||||||
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
|
if _, err := s.readTo(p); err != io.EOF {
|
||||||
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
||||||
}
|
}
|
||||||
if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) {
|
if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) {
|
||||||
|
|
@ -1465,6 +1498,8 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
|
||||||
}
|
}
|
||||||
ct.Close(fmt.Errorf("closed manually by test"))
|
ct.Close(fmt.Errorf("closed manually by test"))
|
||||||
server.stop()
|
server.stop()
|
||||||
|
// Drain any remaining buffers from the stream since it was closed early.
|
||||||
|
s.Read(math.MaxInt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestInvalidHeaderField(t *testing.T) {
|
func (s) TestInvalidHeaderField(t *testing.T) {
|
||||||
|
|
@ -1481,7 +1516,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p := make([]byte, http2MaxFrameLen)
|
p := make([]byte, http2MaxFrameLen)
|
||||||
_, err = s.trReader.(*transportReader).Read(p)
|
_, err = s.readTo(p)
|
||||||
if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
||||||
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
|
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
|
||||||
}
|
}
|
||||||
|
|
@ -1639,17 +1674,17 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
header := make([]byte, 5)
|
header := make([]byte, 5)
|
||||||
for i := 1; i <= 5; i++ {
|
for i := 1; i <= 5; i++ {
|
||||||
if err := client.Write(stream, nil, buf, &opts); err != nil {
|
if err := client.Write(stream, nil, newBufferSlice(buf), &opts); err != nil {
|
||||||
t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err)
|
t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := stream.Read(header); err != nil {
|
if _, err := stream.readTo(header); err != nil {
|
||||||
t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err)
|
t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sz := binary.BigEndian.Uint32(header[1:])
|
sz := binary.BigEndian.Uint32(header[1:])
|
||||||
recvMsg := make([]byte, int(sz))
|
recvMsg := make([]byte, int(sz))
|
||||||
if _, err := stream.Read(recvMsg); err != nil {
|
if _, err := stream.readTo(recvMsg); err != nil {
|
||||||
t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err)
|
t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -1680,7 +1715,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
|
||||||
// Close all streams
|
// Close all streams
|
||||||
for _, stream := range clientStreams {
|
for _, stream := range clientStreams {
|
||||||
client.Write(stream, nil, nil, &Options{Last: true})
|
client.Write(stream, nil, nil, &Options{Last: true})
|
||||||
if _, err := stream.Read(make([]byte, 5)); err != io.EOF {
|
if _, err := stream.readTo(make([]byte, 5)); err != io.EOF {
|
||||||
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1752,21 +1787,19 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||||
}
|
}
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
reader: &recvBufferReader{
|
reader: &recvBufferReader{
|
||||||
ctx: s.ctx,
|
ctx: s.ctx,
|
||||||
ctxDone: s.ctx.Done(),
|
ctxDone: s.ctx.Done(),
|
||||||
recv: s.buf,
|
recv: s.buf,
|
||||||
freeBuffer: func(*bytes.Buffer) {},
|
|
||||||
},
|
},
|
||||||
windowHandler: func(int) {},
|
windowHandler: func(int) {},
|
||||||
}
|
}
|
||||||
testData := make([]byte, 1)
|
testData := make([]byte, 1)
|
||||||
testData[0] = 5
|
testData[0] = 5
|
||||||
testBuffer := bytes.NewBuffer(testData)
|
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
s.write(recvMsg{buffer: testBuffer, err: testErr})
|
s.write(recvMsg{buffer: mem.NewBuffer(&testData, nil), err: testErr})
|
||||||
|
|
||||||
inBuf := make([]byte, 1)
|
inBuf := make([]byte, 1)
|
||||||
actualCount, actualErr := s.Read(inBuf)
|
actualCount, actualErr := s.readTo(inBuf)
|
||||||
if actualCount != 0 {
|
if actualCount != 0 {
|
||||||
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
|
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
|
||||||
}
|
}
|
||||||
|
|
@ -1774,12 +1807,12 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||||
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
|
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
s.write(recvMsg{buffer: testBuffer, err: nil})
|
s.write(recvMsg{buffer: mem.NewBuffer(&testData, nil), err: nil})
|
||||||
s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")})
|
s.write(recvMsg{buffer: mem.NewBuffer(&testData, nil), err: errors.New("different error from first")})
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
inBuf := make([]byte, 1)
|
inBuf := make([]byte, 1)
|
||||||
actualCount, actualErr := s.Read(inBuf)
|
actualCount, actualErr := s.readTo(inBuf)
|
||||||
if actualCount != 0 {
|
if actualCount != 0 {
|
||||||
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
|
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
|
||||||
}
|
}
|
||||||
|
|
@ -2208,11 +2241,11 @@ func (s) TestPingPong1B(t *testing.T) {
|
||||||
runPingPongTest(t, 1)
|
runPingPongTest(t, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestPingPong1KB(t *testing.T) {
|
func TestPingPong1KB(t *testing.T) {
|
||||||
runPingPongTest(t, 1024)
|
runPingPongTest(t, 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestPingPong64KB(t *testing.T) {
|
func TestPingPong64KB(t *testing.T) {
|
||||||
runPingPongTest(t, 65536)
|
runPingPongTest(t, 65536)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2247,24 +2280,24 @@ func runPingPongTest(t *testing.T, msgSize int) {
|
||||||
opts := &Options{}
|
opts := &Options{}
|
||||||
incomingHeader := make([]byte, 5)
|
incomingHeader := make([]byte, 5)
|
||||||
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, time.Second)
|
ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
if err := client.Write(stream, outgoingHeader, msg, opts); err != nil {
|
if err := client.Write(stream, outgoingHeader, newBufferSlice(msg), opts); err != nil {
|
||||||
t.Fatalf("Error on client while writing message. Err: %v", err)
|
t.Fatalf("Error on client while writing message. Err: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := stream.Read(incomingHeader); err != nil {
|
if _, err := stream.readTo(incomingHeader); err != nil {
|
||||||
t.Fatalf("Error on client while reading data header. Err: %v", err)
|
t.Fatalf("Error on client while reading data header. Err: %v", err)
|
||||||
}
|
}
|
||||||
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
||||||
recvMsg := make([]byte, int(sz))
|
recvMsg := make([]byte, int(sz))
|
||||||
if _, err := stream.Read(recvMsg); err != nil {
|
if _, err := stream.readTo(recvMsg); err != nil {
|
||||||
t.Fatalf("Error on client while reading data. Err: %v", err)
|
t.Fatalf("Error on client while reading data. Err: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client.Write(stream, nil, nil, &Options{Last: true})
|
client.Write(stream, nil, nil, &Options{Last: true})
|
||||||
if _, err := stream.Read(incomingHeader); err != io.EOF {
|
if _, err := stream.readTo(incomingHeader); err != io.EOF {
|
||||||
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,10 +29,10 @@ import (
|
||||||
// decreased memory allocation.
|
// decreased memory allocation.
|
||||||
type BufferPool interface {
|
type BufferPool interface {
|
||||||
// Get returns a buffer with specified length from the pool.
|
// Get returns a buffer with specified length from the pool.
|
||||||
Get(length int) []byte
|
Get(length int) *[]byte
|
||||||
|
|
||||||
// Put returns a buffer to the pool.
|
// Put returns a buffer to the pool.
|
||||||
Put([]byte)
|
Put(*[]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultBufferPoolSizes = []int{
|
var defaultBufferPoolSizes = []int{
|
||||||
|
|
@ -48,7 +48,13 @@ var defaultBufferPool BufferPool
|
||||||
func init() {
|
func init() {
|
||||||
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
|
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
|
||||||
|
|
||||||
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) { defaultBufferPool = pool }
|
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
|
||||||
|
defaultBufferPool = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
internal.SetBufferPoolingThresholdForTesting = func(threshold int) {
|
||||||
|
bufferPoolingThreshold = threshold
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultBufferPool returns the current default buffer pool. It is a BufferPool
|
// DefaultBufferPool returns the current default buffer pool. It is a BufferPool
|
||||||
|
|
@ -78,12 +84,12 @@ type tieredBufferPool struct {
|
||||||
fallbackPool simpleBufferPool
|
fallbackPool simpleBufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *tieredBufferPool) Get(size int) []byte {
|
func (p *tieredBufferPool) Get(size int) *[]byte {
|
||||||
return p.getPool(size).Get(size)
|
return p.getPool(size).Get(size)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *tieredBufferPool) Put(buf []byte) {
|
func (p *tieredBufferPool) Put(buf *[]byte) {
|
||||||
p.getPool(cap(buf)).Put(buf)
|
p.getPool(cap(*buf)).Put(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *tieredBufferPool) getPool(size int) BufferPool {
|
func (p *tieredBufferPool) getPool(size int) BufferPool {
|
||||||
|
|
@ -111,21 +117,22 @@ type sizedBufferPool struct {
|
||||||
defaultSize int
|
defaultSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *sizedBufferPool) Get(size int) []byte {
|
func (p *sizedBufferPool) Get(size int) *[]byte {
|
||||||
bs := *p.pool.Get().(*[]byte)
|
buf := p.pool.Get().(*[]byte)
|
||||||
return bs[:size]
|
b := *buf
|
||||||
|
clear(b[:cap(b)])
|
||||||
|
*buf = b[:size]
|
||||||
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *sizedBufferPool) Put(buf []byte) {
|
func (p *sizedBufferPool) Put(buf *[]byte) {
|
||||||
if cap(buf) < p.defaultSize {
|
if cap(*buf) < p.defaultSize {
|
||||||
// Ignore buffers that are too small to fit in the pool. Otherwise, when
|
// Ignore buffers that are too small to fit in the pool. Otherwise, when
|
||||||
// Get is called it will panic as it tries to index outside the bounds
|
// Get is called it will panic as it tries to index outside the bounds
|
||||||
// of the buffer.
|
// of the buffer.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buf = buf[:cap(buf)]
|
p.pool.Put(buf)
|
||||||
clear(buf)
|
|
||||||
p.pool.Put(&buf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSizedBufferPool(size int) *sizedBufferPool {
|
func newSizedBufferPool(size int) *sizedBufferPool {
|
||||||
|
|
@ -150,10 +157,11 @@ type simpleBufferPool struct {
|
||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *simpleBufferPool) Get(size int) []byte {
|
func (p *simpleBufferPool) Get(size int) *[]byte {
|
||||||
bs, ok := p.pool.Get().(*[]byte)
|
bs, ok := p.pool.Get().(*[]byte)
|
||||||
if ok && cap(*bs) >= size {
|
if ok && cap(*bs) >= size {
|
||||||
return (*bs)[:size]
|
*bs = (*bs)[:size]
|
||||||
|
return bs
|
||||||
}
|
}
|
||||||
|
|
||||||
// A buffer was pulled from the pool, but it is too small. Put it back in
|
// A buffer was pulled from the pool, but it is too small. Put it back in
|
||||||
|
|
@ -162,13 +170,12 @@ func (p *simpleBufferPool) Get(size int) []byte {
|
||||||
p.pool.Put(bs)
|
p.pool.Put(bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
return make([]byte, size)
|
b := make([]byte, size)
|
||||||
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *simpleBufferPool) Put(buf []byte) {
|
func (p *simpleBufferPool) Put(buf *[]byte) {
|
||||||
buf = buf[:cap(buf)]
|
p.pool.Put(buf)
|
||||||
clear(buf)
|
|
||||||
p.pool.Put(&buf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ BufferPool = NopBufferPool{}
|
var _ BufferPool = NopBufferPool{}
|
||||||
|
|
@ -177,10 +184,11 @@ var _ BufferPool = NopBufferPool{}
|
||||||
type NopBufferPool struct{}
|
type NopBufferPool struct{}
|
||||||
|
|
||||||
// Get returns a buffer with specified length from the pool.
|
// Get returns a buffer with specified length from the pool.
|
||||||
func (NopBufferPool) Get(length int) []byte {
|
func (NopBufferPool) Get(length int) *[]byte {
|
||||||
return make([]byte, length)
|
b := make([]byte, length)
|
||||||
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put returns a buffer to the pool.
|
// Put returns a buffer to the pool.
|
||||||
func (NopBufferPool) Put([]byte) {
|
func (NopBufferPool) Put(*[]byte) {
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@
|
||||||
package mem_test
|
package mem_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"testing"
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"google.golang.org/grpc/mem"
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,8 +39,8 @@ func (s) TestBufferPool(t *testing.T) {
|
||||||
for _, p := range pools {
|
for _, p := range pools {
|
||||||
for _, l := range testSizes {
|
for _, l := range testSizes {
|
||||||
bs := p.Get(l)
|
bs := p.Get(l)
|
||||||
if len(bs) != l {
|
if len(*bs) != l {
|
||||||
t.Fatalf("Get(%d) returned buffer of length %d, want %d", l, len(bs), l)
|
t.Fatalf("Get(%d) returned buffer of length %d, want %d", l, len(*bs), l)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Put(bs)
|
p.Put(bs)
|
||||||
|
|
@ -50,24 +51,37 @@ func (s) TestBufferPool(t *testing.T) {
|
||||||
func (s) TestBufferPoolClears(t *testing.T) {
|
func (s) TestBufferPoolClears(t *testing.T) {
|
||||||
pool := mem.NewTieredBufferPool(4)
|
pool := mem.NewTieredBufferPool(4)
|
||||||
|
|
||||||
buf := pool.Get(4)
|
for {
|
||||||
copy(buf, "1234")
|
buf1 := pool.Get(4)
|
||||||
pool.Put(buf)
|
copy(*buf1, "1234")
|
||||||
|
pool.Put(buf1)
|
||||||
|
|
||||||
if !cmp.Equal(buf, make([]byte, 4)) {
|
buf2 := pool.Get(4)
|
||||||
t.Fatalf("buffer not cleared")
|
if unsafe.SliceData(*buf1) != unsafe.SliceData(*buf2) {
|
||||||
|
pool.Put(buf2)
|
||||||
|
// This test is only relevant if a buffer is reused, otherwise try again. This
|
||||||
|
// can happen if a GC pause happens between putting the buffer back in the pool
|
||||||
|
// and getting a new one.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(*buf1, make([]byte, 4)) {
|
||||||
|
t.Fatalf("buffer not cleared")
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestBufferPoolIgnoresShortBuffers(t *testing.T) {
|
func (s) TestBufferPoolIgnoresShortBuffers(t *testing.T) {
|
||||||
pool := mem.NewTieredBufferPool(10, 20)
|
pool := mem.NewTieredBufferPool(10, 20)
|
||||||
buf := pool.Get(1)
|
buf := pool.Get(1)
|
||||||
if cap(buf) != 10 {
|
if cap(*buf) != 10 {
|
||||||
t.Fatalf("Get(1) returned buffer with capacity: %d, want 10", cap(buf))
|
t.Fatalf("Get(1) returned buffer with capacity: %d, want 10", cap(*buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert a short buffer into the pool, which is currently empty.
|
// Insert a short buffer into the pool, which is currently empty.
|
||||||
pool.Put(make([]byte, 1))
|
short := make([]byte, 1)
|
||||||
|
pool.Put(&short)
|
||||||
// Then immediately request a buffer that would be pulled from the pool where the
|
// Then immediately request a buffer that would be pulled from the pool where the
|
||||||
// short buffer would have been returned. If the short buffer is pulled from the
|
// short buffer would have been returned. If the short buffer is pulled from the
|
||||||
// pool, it could cause a panic.
|
// pool, it could cause a panic.
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
package mem
|
package mem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"compress/flate"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -36,7 +37,7 @@ import (
|
||||||
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
|
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
|
||||||
// the burden on the caller by never returning a mem.BufferSlice that needs to
|
// the burden on the caller by never returning a mem.BufferSlice that needs to
|
||||||
// be freed if the error is non-nil, unless explicitly stated.
|
// be freed if the error is non-nil, unless explicitly stated.
|
||||||
type BufferSlice []*Buffer
|
type BufferSlice []Buffer
|
||||||
|
|
||||||
// Len returns the sum of the length of all the Buffers in this slice.
|
// Len returns the sum of the length of all the Buffers in this slice.
|
||||||
//
|
//
|
||||||
|
|
@ -52,14 +53,11 @@ func (s BufferSlice) Len() int {
|
||||||
return length
|
return length
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ref returns a new BufferSlice containing a new reference of each Buffer in the
|
// Ref invokes Ref on each buffer in the slice.
|
||||||
// input slice.
|
func (s BufferSlice) Ref() {
|
||||||
func (s BufferSlice) Ref() BufferSlice {
|
for _, b := range s {
|
||||||
out := make(BufferSlice, len(s))
|
b.Ref()
|
||||||
for i, b := range s {
|
|
||||||
out[i] = b.Ref()
|
|
||||||
}
|
}
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free invokes Buffer.Free() on each Buffer in the slice.
|
// Free invokes Buffer.Free() on each Buffer in the slice.
|
||||||
|
|
@ -97,54 +95,73 @@ func (s BufferSlice) Materialize() []byte {
|
||||||
// to a single Buffer pulled from the given BufferPool. As a special case, if the
|
// to a single Buffer pulled from the given BufferPool. As a special case, if the
|
||||||
// input BufferSlice only actually has one Buffer, this function has nothing to
|
// input BufferSlice only actually has one Buffer, this function has nothing to
|
||||||
// do and simply returns said Buffer.
|
// do and simply returns said Buffer.
|
||||||
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) *Buffer {
|
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
|
||||||
if len(s) == 1 {
|
if len(s) == 1 {
|
||||||
return s[0].Ref()
|
s[0].Ref()
|
||||||
|
return s[0]
|
||||||
}
|
}
|
||||||
buf := pool.Get(s.Len())
|
sLen := s.Len()
|
||||||
s.CopyTo(buf)
|
if sLen == 0 {
|
||||||
return NewBuffer(buf, pool.Put)
|
return emptyBuffer{}
|
||||||
|
}
|
||||||
|
buf := pool.Get(sLen)
|
||||||
|
s.CopyTo(*buf)
|
||||||
|
return NewBuffer(buf, pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reader returns a new Reader for the input slice after taking references to
|
// Reader returns a new Reader for the input slice after taking references to
|
||||||
// each underlying buffer.
|
// each underlying buffer.
|
||||||
func (s BufferSlice) Reader() *Reader {
|
func (s BufferSlice) Reader() Reader {
|
||||||
return &Reader{
|
s.Ref()
|
||||||
data: s.Ref(),
|
return &sliceReader{
|
||||||
|
data: s,
|
||||||
len: s.Len(),
|
len: s.Len(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ io.ReadCloser = (*Reader)(nil)
|
|
||||||
|
|
||||||
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
|
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
|
||||||
// with other parts systems. It also provides an additional convenience method
|
// with other parts systems. It also provides an additional convenience method
|
||||||
// Remaining(), which returns the number of unread bytes remaining in the slice.
|
// Remaining(), which returns the number of unread bytes remaining in the slice.
|
||||||
//
|
// Buffers will be freed as they are read.
|
||||||
// Note that reading data from the reader does not free the underlying buffers!
|
type Reader interface {
|
||||||
// Only calling Close once all data is read will free the buffers.
|
flate.Reader
|
||||||
type Reader struct {
|
// Close frees the underlying BufferSlice and never returns an error. Subsequent
|
||||||
|
// calls to Read will return (0, io.EOF).
|
||||||
|
Close() error
|
||||||
|
// Remaining returns the number of unread bytes remaining in the slice.
|
||||||
|
Remaining() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type sliceReader struct {
|
||||||
data BufferSlice
|
data BufferSlice
|
||||||
len int
|
len int
|
||||||
// The index into data[0].ReadOnlyData().
|
// The index into data[0].ReadOnlyData().
|
||||||
bufferIdx int
|
bufferIdx int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remaining returns the number of unread bytes remaining in the slice.
|
func (r *sliceReader) Remaining() int {
|
||||||
func (r *Reader) Remaining() int {
|
|
||||||
return r.len
|
return r.len
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close frees the underlying BufferSlice and never returns an error. Subsequent
|
func (r *sliceReader) Close() error {
|
||||||
// calls to Read will return (0, io.EOF).
|
|
||||||
func (r *Reader) Close() error {
|
|
||||||
r.data.Free()
|
r.data.Free()
|
||||||
r.data = nil
|
r.data = nil
|
||||||
r.len = 0
|
r.len = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reader) Read(buf []byte) (n int, _ error) {
|
func (r *sliceReader) freeFirstBufferIfEmpty() bool {
|
||||||
|
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
r.data[0].Free()
|
||||||
|
r.data = r.data[1:]
|
||||||
|
r.bufferIdx = 0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sliceReader) Read(buf []byte) (n int, _ error) {
|
||||||
if r.len == 0 {
|
if r.len == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
@ -159,19 +176,32 @@ func (r *Reader) Read(buf []byte) (n int, _ error) {
|
||||||
n += copied // Increment the total number of bytes read.
|
n += copied // Increment the total number of bytes read.
|
||||||
buf = buf[copied:] // Shrink the given byte slice.
|
buf = buf[copied:] // Shrink the given byte slice.
|
||||||
|
|
||||||
// If we have copied all of the data from the first Buffer, free it and
|
// If we have copied all the data from the first Buffer, free it and advance to
|
||||||
// advance to the next in the slice.
|
// the next in the slice.
|
||||||
if r.bufferIdx == len(data) {
|
r.freeFirstBufferIfEmpty()
|
||||||
oldBuffer := r.data[0]
|
|
||||||
oldBuffer.Free()
|
|
||||||
r.data = r.data[1:]
|
|
||||||
r.bufferIdx = 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *sliceReader) ReadByte() (byte, error) {
|
||||||
|
if r.len == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// There may be any number of empty buffers in the slice, clear them all until a
|
||||||
|
// non-empty buffer is reached. This is guaranteed to exit since r.len is not 0.
|
||||||
|
for r.freeFirstBufferIfEmpty() {
|
||||||
|
}
|
||||||
|
|
||||||
|
b := r.data[0].ReadOnlyData()[r.bufferIdx]
|
||||||
|
r.len--
|
||||||
|
r.bufferIdx++
|
||||||
|
// Free the first buffer in the slice if the last byte was read
|
||||||
|
r.freeFirstBufferIfEmpty()
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
var _ io.Writer = (*writer)(nil)
|
var _ io.Writer = (*writer)(nil)
|
||||||
|
|
||||||
type writer struct {
|
type writer struct {
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ import (
|
||||||
"google.golang.org/grpc/mem"
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func newBuffer(data []byte, pool mem.BufferPool) mem.Buffer {
|
||||||
|
return mem.NewBuffer(&data, pool)
|
||||||
|
}
|
||||||
|
|
||||||
func (s) TestBufferSlice_Len(t *testing.T) {
|
func (s) TestBufferSlice_Len(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
@ -40,15 +44,15 @@ func (s) TestBufferSlice_Len(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
in: mem.BufferSlice{mem.NewBuffer([]byte("abcd"), nil)},
|
in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)},
|
||||||
want: 4,
|
want: 4,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
in: mem.BufferSlice{
|
in: mem.BufferSlice{
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
},
|
},
|
||||||
want: 12,
|
want: 12,
|
||||||
},
|
},
|
||||||
|
|
@ -65,15 +69,15 @@ func (s) TestBufferSlice_Len(t *testing.T) {
|
||||||
func (s) TestBufferSlice_Ref(t *testing.T) {
|
func (s) TestBufferSlice_Ref(t *testing.T) {
|
||||||
// Create a new buffer slice and a reference to it.
|
// Create a new buffer slice and a reference to it.
|
||||||
bs := mem.BufferSlice{
|
bs := mem.BufferSlice{
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
}
|
}
|
||||||
bsRef := bs.Ref()
|
bs.Ref()
|
||||||
|
|
||||||
// Free the original buffer slice and verify that the reference can still
|
// Free the original buffer slice and verify that the reference can still
|
||||||
// read data from it.
|
// read data from it.
|
||||||
bs.Free()
|
bs.Free()
|
||||||
got := bsRef.Materialize()
|
got := bs.Materialize()
|
||||||
want := []byte("abcdabcd")
|
want := []byte("abcdabcd")
|
||||||
if !bytes.Equal(got, want) {
|
if !bytes.Equal(got, want) {
|
||||||
t.Errorf("BufferSlice.Materialize() = %s, want %s", string(got), string(want))
|
t.Errorf("BufferSlice.Materialize() = %s, want %s", string(got), string(want))
|
||||||
|
|
@ -89,16 +93,16 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
in: mem.BufferSlice{mem.NewBuffer([]byte("abcd"), nil)},
|
in: mem.BufferSlice{newBuffer([]byte("abcd"), nil)},
|
||||||
pool: nil, // MaterializeToBuffer should not use the pool in this case.
|
pool: nil, // MaterializeToBuffer should not use the pool in this case.
|
||||||
wantData: []byte("abcd"),
|
wantData: []byte("abcd"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
in: mem.BufferSlice{
|
in: mem.BufferSlice{
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
},
|
},
|
||||||
pool: mem.DefaultBufferPool(),
|
pool: mem.DefaultBufferPool(),
|
||||||
wantData: []byte("abcdabcdabcd"),
|
wantData: []byte("abcdabcdabcd"),
|
||||||
|
|
@ -106,6 +110,7 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defer tt.in.Free()
|
||||||
got := tt.in.MaterializeToBuffer(tt.pool)
|
got := tt.in.MaterializeToBuffer(tt.pool)
|
||||||
defer got.Free()
|
defer got.Free()
|
||||||
if !bytes.Equal(got.ReadOnlyData(), tt.wantData) {
|
if !bytes.Equal(got.ReadOnlyData(), tt.wantData) {
|
||||||
|
|
@ -117,9 +122,9 @@ func (s) TestBufferSlice_MaterializeToBuffer(t *testing.T) {
|
||||||
|
|
||||||
func (s) TestBufferSlice_Reader(t *testing.T) {
|
func (s) TestBufferSlice_Reader(t *testing.T) {
|
||||||
bs := mem.BufferSlice{
|
bs := mem.BufferSlice{
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
mem.NewBuffer([]byte("abcd"), nil),
|
newBuffer([]byte("abcd"), nil),
|
||||||
}
|
}
|
||||||
wantData := []byte("abcdabcdabcd")
|
wantData := []byte("abcdabcdabcd")
|
||||||
|
|
||||||
|
|
|
||||||
217
mem/buffers.go
217
mem/buffers.go
|
|
@ -27,13 +27,14 @@ package mem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Buffer represents a reference counted piece of data (in bytes) that can be
|
// A Buffer represents a reference counted piece of data (in bytes) that can be
|
||||||
// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be
|
// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be
|
||||||
// released by calling Free(), which invokes the given free function only after
|
// released by calling Free(), which invokes the free function given at creation
|
||||||
// all references are released.
|
// only after all references are released.
|
||||||
//
|
//
|
||||||
// Note that a Buffer is not safe for concurrent access and instead each
|
// Note that a Buffer is not safe for concurrent access and instead each
|
||||||
// goroutine should use its own reference to the data, which can be acquired via
|
// goroutine should use its own reference to the data, which can be acquired via
|
||||||
|
|
@ -41,23 +42,61 @@ import (
|
||||||
//
|
//
|
||||||
// Attempts to access the underlying data after releasing the reference to the
|
// Attempts to access the underlying data after releasing the reference to the
|
||||||
// Buffer will panic.
|
// Buffer will panic.
|
||||||
type Buffer struct {
|
type Buffer interface {
|
||||||
data []byte
|
// ReadOnlyData returns the underlying byte slice. Note that it is undefined
|
||||||
refs *atomic.Int32
|
// behavior to modify the contents of this slice in any way.
|
||||||
free func()
|
ReadOnlyData() []byte
|
||||||
freed bool
|
// Ref increases the reference counter for this Buffer.
|
||||||
|
Ref()
|
||||||
|
// Free decrements this Buffer's reference counter and frees the underlying
|
||||||
|
// byte slice if the counter reaches 0 as a result of this call.
|
||||||
|
Free()
|
||||||
|
// Len returns the Buffer's size.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
split(n int) (left, right Buffer)
|
||||||
|
read(buf []byte) (int, Buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBuffer creates a new Buffer from the given data, initializing the
|
var (
|
||||||
// reference counter to 1. The given free function is called when all references
|
bufferPoolingThreshold = 1 << 10
|
||||||
// to the returned Buffer are released.
|
|
||||||
|
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
|
||||||
|
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsBelowBufferPoolingThreshold(size int) bool {
|
||||||
|
return size <= bufferPoolingThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
type buffer struct {
|
||||||
|
origData *[]byte
|
||||||
|
data []byte
|
||||||
|
refs *atomic.Int32
|
||||||
|
pool BufferPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBuffer() *buffer {
|
||||||
|
return bufferObjectPool.Get().(*buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBuffer creates a new Buffer from the given data, initializing the reference
|
||||||
|
// counter to 1. The data will then be returned to the given pool when all
|
||||||
|
// references to the returned Buffer are released. As a special case to avoid
|
||||||
|
// additional allocations, if the given buffer pool is nil, the returned buffer
|
||||||
|
// will be a "no-op" Buffer where invoking Buffer.Free() does nothing and the
|
||||||
|
// underlying data is never freed.
|
||||||
//
|
//
|
||||||
// Note that the backing array of the given data is not copied.
|
// Note that the backing array of the given data is not copied.
|
||||||
func NewBuffer(data []byte, onFree func([]byte)) *Buffer {
|
func NewBuffer(data *[]byte, pool BufferPool) Buffer {
|
||||||
b := &Buffer{data: data, refs: new(atomic.Int32)}
|
if pool == nil || IsBelowBufferPoolingThreshold(len(*data)) {
|
||||||
if onFree != nil {
|
return (SliceBuffer)(*data)
|
||||||
b.free = func() { onFree(data) }
|
|
||||||
}
|
}
|
||||||
|
b := newBuffer()
|
||||||
|
b.origData = data
|
||||||
|
b.data = *data
|
||||||
|
b.pool = pool
|
||||||
|
b.refs = refObjectPool.Get().(*atomic.Int32)
|
||||||
b.refs.Add(1)
|
b.refs.Add(1)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
@ -68,82 +107,146 @@ func NewBuffer(data []byte, onFree func([]byte)) *Buffer {
|
||||||
// It acquires a []byte from the given pool and copies over the backing array
|
// It acquires a []byte from the given pool and copies over the backing array
|
||||||
// of the given data. The []byte acquired from the pool is returned to the
|
// of the given data. The []byte acquired from the pool is returned to the
|
||||||
// pool when all references to the returned Buffer are released.
|
// pool when all references to the returned Buffer are released.
|
||||||
func Copy(data []byte, pool BufferPool) *Buffer {
|
func Copy(data []byte, pool BufferPool) Buffer {
|
||||||
|
if IsBelowBufferPoolingThreshold(len(data)) {
|
||||||
|
buf := make(SliceBuffer, len(data))
|
||||||
|
copy(buf, data)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
buf := pool.Get(len(data))
|
buf := pool.Get(len(data))
|
||||||
copy(buf, data)
|
copy(*buf, data)
|
||||||
return NewBuffer(buf, pool.Put)
|
return NewBuffer(buf, pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadOnlyData returns the underlying byte slice. Note that it is undefined
|
func (b *buffer) ReadOnlyData() []byte {
|
||||||
// behavior to modify the contents of this slice in any way.
|
if b.refs == nil {
|
||||||
func (b *Buffer) ReadOnlyData() []byte {
|
|
||||||
if b.freed {
|
|
||||||
panic("Cannot read freed buffer")
|
panic("Cannot read freed buffer")
|
||||||
}
|
}
|
||||||
return b.data
|
return b.data
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ref returns a new reference to this Buffer's underlying byte slice.
|
func (b *buffer) Ref() {
|
||||||
func (b *Buffer) Ref() *Buffer {
|
if b.refs == nil {
|
||||||
if b.freed {
|
|
||||||
panic("Cannot ref freed buffer")
|
panic("Cannot ref freed buffer")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.refs.Add(1)
|
b.refs.Add(1)
|
||||||
return &Buffer{
|
|
||||||
data: b.data,
|
|
||||||
refs: b.refs,
|
|
||||||
free: b.free,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free decrements this Buffer's reference counter and frees the underlying
|
func (b *buffer) Free() {
|
||||||
// byte slice if the counter reaches 0 as a result of this call.
|
if b.refs == nil {
|
||||||
func (b *Buffer) Free() {
|
panic("Cannot free freed buffer")
|
||||||
if b.freed {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
b.freed = true
|
|
||||||
refs := b.refs.Add(-1)
|
refs := b.refs.Add(-1)
|
||||||
if refs == 0 && b.free != nil {
|
switch {
|
||||||
b.free()
|
case refs > 0:
|
||||||
|
return
|
||||||
|
case refs == 0:
|
||||||
|
if b.pool != nil {
|
||||||
|
b.pool.Put(b.origData)
|
||||||
|
}
|
||||||
|
|
||||||
|
refObjectPool.Put(b.refs)
|
||||||
|
b.origData = nil
|
||||||
|
b.data = nil
|
||||||
|
b.refs = nil
|
||||||
|
b.pool = nil
|
||||||
|
bufferObjectPool.Put(b)
|
||||||
|
default:
|
||||||
|
panic("Cannot free freed buffer")
|
||||||
}
|
}
|
||||||
b.data = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the Buffer's size.
|
func (b *buffer) Len() int {
|
||||||
func (b *Buffer) Len() int {
|
|
||||||
// Convenience: io.Reader returns (n int, err error), and n is often checked
|
|
||||||
// before err is checked. To mimic this, Len() should work on nil Buffers.
|
|
||||||
if b == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return len(b.ReadOnlyData())
|
return len(b.ReadOnlyData())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split modifies the receiver to point to the first n bytes while it returns a
|
func (b *buffer) split(n int) (Buffer, Buffer) {
|
||||||
// new reference to the remaining bytes. The returned Buffer functions just like
|
if b.refs == nil {
|
||||||
// a normal reference acquired using Ref().
|
|
||||||
func (b *Buffer) Split(n int) *Buffer {
|
|
||||||
if b.freed {
|
|
||||||
panic("Cannot split freed buffer")
|
panic("Cannot split freed buffer")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.refs.Add(1)
|
b.refs.Add(1)
|
||||||
|
split := newBuffer()
|
||||||
|
split.origData = b.origData
|
||||||
|
split.data = b.data[n:]
|
||||||
|
split.refs = b.refs
|
||||||
|
split.pool = b.pool
|
||||||
|
|
||||||
split := &Buffer{
|
b.data = b.data[:n]
|
||||||
refs: b.refs,
|
|
||||||
free: b.free,
|
return b, split
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) read(buf []byte) (int, Buffer) {
|
||||||
|
if b.refs == nil {
|
||||||
|
panic("Cannot read freed buffer")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.data, split.data = b.data[:n], b.data[n:]
|
n := copy(buf, b.data)
|
||||||
|
if n == len(b.data) {
|
||||||
|
b.Free()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
return split
|
b.data = b.data[n:]
|
||||||
|
return n, b
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string representation of the buffer. May be used for
|
// String returns a string representation of the buffer. May be used for
|
||||||
// debugging purposes.
|
// debugging purposes.
|
||||||
func (b *Buffer) String() string {
|
func (b *buffer) String() string {
|
||||||
return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData()))
|
return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ReadUnsafe(dst []byte, buf Buffer) (int, Buffer) {
|
||||||
|
return buf.read(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitUnsafe modifies the receiver to point to the first n bytes while it
|
||||||
|
// returns a new reference to the remaining bytes. The returned Buffer functions
|
||||||
|
// just like a normal reference acquired using Ref().
|
||||||
|
func SplitUnsafe(buf Buffer, n int) (left, right Buffer) {
|
||||||
|
return buf.split(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
type emptyBuffer struct{}
|
||||||
|
|
||||||
|
func (e emptyBuffer) ReadOnlyData() []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e emptyBuffer) Ref() {}
|
||||||
|
func (e emptyBuffer) Free() {}
|
||||||
|
|
||||||
|
func (e emptyBuffer) Len() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e emptyBuffer) split(n int) (left, right Buffer) {
|
||||||
|
return e, e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e emptyBuffer) read(buf []byte) (int, Buffer) {
|
||||||
|
return 0, e
|
||||||
|
}
|
||||||
|
|
||||||
|
type SliceBuffer []byte
|
||||||
|
|
||||||
|
func (s SliceBuffer) ReadOnlyData() []byte { return s }
|
||||||
|
func (s SliceBuffer) Ref() {}
|
||||||
|
func (s SliceBuffer) Free() {}
|
||||||
|
func (s SliceBuffer) Len() int { return len(s) }
|
||||||
|
|
||||||
|
func (s SliceBuffer) split(n int) (left, right Buffer) {
|
||||||
|
return s[:n], s[n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s SliceBuffer) read(buf []byte) (int, Buffer) {
|
||||||
|
n := copy(buf, s)
|
||||||
|
if n == len(s) {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return n, s[n:]
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,24 +20,20 @@ package mem_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"google.golang.org/grpc/internal"
|
||||||
"google.golang.org/grpc/internal/grpctest"
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
"google.golang.org/grpc/mem"
|
"google.golang.org/grpc/mem"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultTestTimeout = 5 * time.Second
|
|
||||||
defaultTestShortTimeout = 100 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
type s struct {
|
type s struct {
|
||||||
grpctest.Tester
|
grpctest.Tester
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test(t *testing.T) {
|
func Test(t *testing.T) {
|
||||||
|
internal.SetBufferPoolingThresholdForTesting.(func(int))(0)
|
||||||
|
|
||||||
grpctest.RunSubTests(t, s{})
|
grpctest.RunSubTests(t, s{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -45,29 +41,23 @@ func Test(t *testing.T) {
|
||||||
// the free function with the correct data.
|
// the free function with the correct data.
|
||||||
func (s) TestBuffer_NewBufferAndFree(t *testing.T) {
|
func (s) TestBuffer_NewBufferAndFree(t *testing.T) {
|
||||||
data := "abcd"
|
data := "abcd"
|
||||||
errCh := make(chan error, 1)
|
freed := false
|
||||||
freeF := func(got []byte) {
|
freeF := poolFunc(func(got *[]byte) {
|
||||||
if !bytes.Equal(got, []byte(data)) {
|
if !bytes.Equal(*got, []byte(data)) {
|
||||||
errCh <- fmt.Errorf("Free function called with bytes %s, want %s", string(got), string(data))
|
t.Fatalf("Free function called with bytes %s, want %s", string(*got), data)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
errCh <- nil
|
freed = true
|
||||||
}
|
})
|
||||||
|
|
||||||
buf := mem.NewBuffer([]byte(data), freeF)
|
buf := newBuffer([]byte(data), freeF)
|
||||||
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
||||||
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is invoked when all references are freed.
|
// Verify that the free function is invoked when all references are freed.
|
||||||
buf.Free()
|
buf.Free()
|
||||||
select {
|
if !freed {
|
||||||
case err := <-errCh:
|
t.Fatalf("Buffer not freed")
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
case <-time.After(defaultTestTimeout):
|
|
||||||
t.Fatalf("Timeout waiting for Buffer to be freed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,84 +66,87 @@ func (s) TestBuffer_NewBufferAndFree(t *testing.T) {
|
||||||
// correct data, but only after all references are released.
|
// correct data, but only after all references are released.
|
||||||
func (s) TestBuffer_NewBufferRefAndFree(t *testing.T) {
|
func (s) TestBuffer_NewBufferRefAndFree(t *testing.T) {
|
||||||
data := "abcd"
|
data := "abcd"
|
||||||
errCh := make(chan error, 1)
|
freed := false
|
||||||
freeF := func(got []byte) {
|
freeF := poolFunc(func(got *[]byte) {
|
||||||
if !bytes.Equal(got, []byte(data)) {
|
if !bytes.Equal(*got, []byte(data)) {
|
||||||
errCh <- fmt.Errorf("Free function called with bytes %s, want %s", string(got), string(data))
|
t.Fatalf("Free function called with bytes %s, want %s", string(*got), string(data))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
errCh <- nil
|
freed = true
|
||||||
}
|
})
|
||||||
|
|
||||||
buf := mem.NewBuffer([]byte(data), freeF)
|
buf := newBuffer([]byte(data), freeF)
|
||||||
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
||||||
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
bufRef := buf.Ref()
|
buf.Ref()
|
||||||
if got := bufRef.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
||||||
t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is not invoked when all references are yet
|
// Verify that the free function is not invoked when all references are yet
|
||||||
// to be freed.
|
// to be freed.
|
||||||
buf.Free()
|
buf.Free()
|
||||||
select {
|
if freed {
|
||||||
case <-errCh:
|
|
||||||
t.Fatalf("Free function called before all references freed")
|
t.Fatalf("Free function called before all references freed")
|
||||||
case <-time.After(defaultTestShortTimeout):
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is invoked when all references are freed.
|
// Verify that the free function is invoked when all references are freed.
|
||||||
bufRef.Free()
|
buf.Free()
|
||||||
select {
|
if !freed {
|
||||||
case err := <-errCh:
|
t.Fatalf("Buffer not freed")
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
case <-time.After(defaultTestTimeout):
|
|
||||||
t.Fatalf("Timeout waiting for Buffer to be freed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// testBufferPool is a buffer pool that makes new buffer without pooling, and
|
func (s) TestBuffer_FreeAfterFree(t *testing.T) {
|
||||||
// notifies on a channel that a buffer was returned to the pool.
|
buf := newBuffer([]byte("abcd"), mem.NopBufferPool{})
|
||||||
type testBufferPool struct {
|
if buf.Len() != 4 {
|
||||||
putCh chan []byte
|
t.Fatalf("Buffer length is %d, want 4", buf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that a double free does panic.
|
||||||
|
buf.Free()
|
||||||
|
defer checkForPanic(t, "Cannot free freed buffer")
|
||||||
|
buf.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testBufferPool) Get(length int) []byte {
|
type singleBufferPool struct {
|
||||||
return make([]byte, length)
|
t *testing.T
|
||||||
|
data *[]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testBufferPool) Put(data []byte) {
|
func (s *singleBufferPool) Get(length int) *[]byte {
|
||||||
t.putCh <- data
|
if len(*s.data) != length {
|
||||||
|
s.t.Fatalf("Invalid requested length, got %d want %d", length, len(*s.data))
|
||||||
|
}
|
||||||
|
return s.data
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestBufferPool() *testBufferPool {
|
func (s *singleBufferPool) Put(b *[]byte) {
|
||||||
return &testBufferPool{putCh: make(chan []byte, 1)}
|
if s.data != b {
|
||||||
|
s.t.Fatalf("Wrong buffer returned to pool, got %p want %p", b, s.data)
|
||||||
|
}
|
||||||
|
s.data = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that a buffer created with Copy, which when later freed, returns the underlying
|
// Tests that a buffer created with Copy, which when later freed, returns the underlying
|
||||||
// byte slice to the buffer pool.
|
// byte slice to the buffer pool.
|
||||||
func (s) TestBuffer_CopyAndFree(t *testing.T) {
|
func (s) TestBuffer_CopyAndFree(t *testing.T) {
|
||||||
data := "abcd"
|
data := []byte("abcd")
|
||||||
testPool := newTestBufferPool()
|
testPool := &singleBufferPool{
|
||||||
|
t: t,
|
||||||
|
data: &data,
|
||||||
|
}
|
||||||
|
|
||||||
buf := mem.Copy([]byte(data), testPool)
|
buf := mem.Copy(data, testPool)
|
||||||
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, data) {
|
||||||
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is invoked when all references are freed.
|
// Verify that the free function is invoked when all references are freed.
|
||||||
buf.Free()
|
buf.Free()
|
||||||
select {
|
if testPool.data != nil {
|
||||||
case got := <-testPool.putCh:
|
t.Fatalf("Buffer not freed")
|
||||||
if !bytes.Equal(got, []byte(data)) {
|
|
||||||
t.Fatalf("Free function called with bytes %s, want %s", string(got), string(data))
|
|
||||||
}
|
|
||||||
case <-time.After(defaultTestTimeout):
|
|
||||||
t.Fatalf("Timeout waiting for Buffer to be freed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -161,68 +154,103 @@ func (s) TestBuffer_CopyAndFree(t *testing.T) {
|
||||||
// acquired, which when later freed, returns the underlying byte slice to the
|
// acquired, which when later freed, returns the underlying byte slice to the
|
||||||
// buffer pool.
|
// buffer pool.
|
||||||
func (s) TestBuffer_CopyRefAndFree(t *testing.T) {
|
func (s) TestBuffer_CopyRefAndFree(t *testing.T) {
|
||||||
data := "abcd"
|
data := []byte("abcd")
|
||||||
testPool := newTestBufferPool()
|
testPool := &singleBufferPool{
|
||||||
|
t: t,
|
||||||
|
data: &data,
|
||||||
|
}
|
||||||
|
|
||||||
buf := mem.Copy([]byte(data), testPool)
|
buf := mem.Copy(data, testPool)
|
||||||
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, data) {
|
||||||
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
bufRef := buf.Ref()
|
buf.Ref()
|
||||||
if got := bufRef.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
if got := buf.ReadOnlyData(); !bytes.Equal(got, []byte(data)) {
|
||||||
t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data))
|
t.Fatalf("New reference to the Buffer contains data %s, want %s", string(got), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is not invoked when all references are yet
|
// Verify that the free function is not invoked when all references are yet
|
||||||
// to be freed.
|
// to be freed.
|
||||||
buf.Free()
|
buf.Free()
|
||||||
select {
|
if testPool.data == nil {
|
||||||
case <-testPool.putCh:
|
|
||||||
t.Fatalf("Free function called before all references freed")
|
t.Fatalf("Free function called before all references freed")
|
||||||
case <-time.After(defaultTestShortTimeout):
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the free function is invoked when all references are freed.
|
// Verify that the free function is invoked when all references are freed.
|
||||||
bufRef.Free()
|
buf.Free()
|
||||||
select {
|
if testPool.data != nil {
|
||||||
case got := <-testPool.putCh:
|
t.Fatalf("Free never called")
|
||||||
if !bytes.Equal(got, []byte(data)) {
|
|
||||||
t.Fatalf("Free function called with bytes %s, want %s", string(got), string(data))
|
|
||||||
}
|
|
||||||
case <-time.After(defaultTestTimeout):
|
|
||||||
t.Fatalf("Timeout waiting for Buffer to be freed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s) TestBuffer_ReadOnlyDataAfterFree(t *testing.T) {
|
||||||
|
// Verify that reading before freeing does not panic.
|
||||||
|
buf := newBuffer([]byte("abcd"), mem.NopBufferPool{})
|
||||||
|
buf.ReadOnlyData()
|
||||||
|
|
||||||
|
buf.Free()
|
||||||
|
defer checkForPanic(t, "Cannot read freed buffer")
|
||||||
|
buf.ReadOnlyData()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestBuffer_RefAfterFree(t *testing.T) {
|
||||||
|
// Verify that acquiring a ref before freeing does not panic.
|
||||||
|
buf := newBuffer([]byte("abcd"), mem.NopBufferPool{})
|
||||||
|
buf.Ref()
|
||||||
|
|
||||||
|
// This first call should not panc and bring the ref counter down to 1
|
||||||
|
buf.Free()
|
||||||
|
// This second call actually frees the buffer
|
||||||
|
buf.Free()
|
||||||
|
defer checkForPanic(t, "Cannot ref freed buffer")
|
||||||
|
buf.Ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestBuffer_SplitAfterFree(t *testing.T) {
|
||||||
|
// Verify that splitting before freeing does not panic.
|
||||||
|
buf := newBuffer([]byte("abcd"), mem.NopBufferPool{})
|
||||||
|
buf, bufSplit := mem.SplitUnsafe(buf, 2)
|
||||||
|
|
||||||
|
bufSplit.Free()
|
||||||
|
buf.Free()
|
||||||
|
defer checkForPanic(t, "Cannot split freed buffer")
|
||||||
|
mem.SplitUnsafe(buf, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
type poolFunc func(*[]byte)
|
||||||
|
|
||||||
|
func (p poolFunc) Get(length int) *[]byte {
|
||||||
|
panic("Get should never be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p poolFunc) Put(i *[]byte) {
|
||||||
|
p(i)
|
||||||
|
}
|
||||||
|
|
||||||
func (s) TestBuffer_Split(t *testing.T) {
|
func (s) TestBuffer_Split(t *testing.T) {
|
||||||
ready := false
|
ready := false
|
||||||
freed := false
|
freed := false
|
||||||
data := []byte{1, 2, 3, 4}
|
data := []byte{1, 2, 3, 4}
|
||||||
buf := mem.NewBuffer(data, func(bytes []byte) {
|
buf := mem.NewBuffer(&data, poolFunc(func(bytes *[]byte) {
|
||||||
if !ready {
|
if !ready {
|
||||||
t.Fatalf("Freed too early")
|
t.Fatalf("Freed too early")
|
||||||
}
|
}
|
||||||
freed = true
|
freed = true
|
||||||
})
|
}))
|
||||||
checkBufData := func(b *mem.Buffer, expected []byte) {
|
checkBufData := func(b mem.Buffer, expected []byte) {
|
||||||
|
t.Helper()
|
||||||
if !bytes.Equal(b.ReadOnlyData(), expected) {
|
if !bytes.Equal(b.ReadOnlyData(), expected) {
|
||||||
t.Fatalf("Buffer did not contain expected data %v, got %v", expected, b.ReadOnlyData())
|
t.Fatalf("Buffer did not contain expected data %v, got %v", expected, b.ReadOnlyData())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take a ref of the original buffer
|
buf, split1 := mem.SplitUnsafe(buf, 2)
|
||||||
ref1 := buf.Ref()
|
|
||||||
|
|
||||||
split1 := buf.Split(2)
|
|
||||||
checkBufData(buf, data[:2])
|
checkBufData(buf, data[:2])
|
||||||
checkBufData(split1, data[2:])
|
checkBufData(split1, data[2:])
|
||||||
// Check that even though buf was split, the reference wasn't modified
|
|
||||||
checkBufData(ref1, data)
|
|
||||||
ref1.Free()
|
|
||||||
|
|
||||||
// Check that splitting the buffer more than once works as intended.
|
// Check that splitting the buffer more than once works as intended.
|
||||||
split2 := split1.Split(1)
|
split1, split2 := mem.SplitUnsafe(split1, 1)
|
||||||
checkBufData(split1, data[2:3])
|
checkBufData(split1, data[2:3])
|
||||||
checkBufData(split2, data[3:])
|
checkBufData(split2, data[3:])
|
||||||
|
|
||||||
|
|
@ -242,52 +270,9 @@ func checkForPanic(t *testing.T, wantErr string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
r := recover()
|
r := recover()
|
||||||
if r == nil {
|
if r == nil {
|
||||||
t.Fatalf("Use after free dit not panic")
|
t.Fatalf("Use after free did not panic")
|
||||||
}
|
}
|
||||||
if r.(string) != wantErr {
|
if msg, ok := r.(string); !ok || msg != wantErr {
|
||||||
t.Fatalf("panic called with %v, want %s", r, wantErr)
|
t.Fatalf("panic called with %v, want %s", r, wantErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestBuffer_ReadOnlyDataAfterFree(t *testing.T) {
|
|
||||||
// Verify that reading before freeing does not panic.
|
|
||||||
buf := mem.NewBuffer([]byte("abcd"), nil)
|
|
||||||
buf.ReadOnlyData()
|
|
||||||
|
|
||||||
buf.Free()
|
|
||||||
defer checkForPanic(t, "Cannot read freed buffer")
|
|
||||||
buf.ReadOnlyData()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestBuffer_RefAfterFree(t *testing.T) {
|
|
||||||
// Verify that acquiring a ref before freeing does not panic.
|
|
||||||
buf := mem.NewBuffer([]byte("abcd"), nil)
|
|
||||||
bufRef := buf.Ref()
|
|
||||||
defer bufRef.Free()
|
|
||||||
|
|
||||||
buf.Free()
|
|
||||||
defer checkForPanic(t, "Cannot ref freed buffer")
|
|
||||||
buf.Ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestBuffer_SplitAfterFree(t *testing.T) {
|
|
||||||
// Verify that splitting before freeing does not panic.
|
|
||||||
buf := mem.NewBuffer([]byte("abcd"), nil)
|
|
||||||
bufSplit := buf.Split(2)
|
|
||||||
defer bufSplit.Free()
|
|
||||||
|
|
||||||
buf.Free()
|
|
||||||
defer checkForPanic(t, "Cannot split freed buffer")
|
|
||||||
buf.Split(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s) TestBuffer_FreeAfterFree(t *testing.T) {
|
|
||||||
buf := mem.NewBuffer([]byte("abcd"), nil)
|
|
||||||
if buf.Len() != 4 {
|
|
||||||
t.Fatalf("Buffer length is %d, want 4", buf.Len())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that a double free does not panic.
|
|
||||||
buf.Free()
|
|
||||||
buf.Free()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
28
preloader.go
28
preloader.go
|
|
@ -20,6 +20,7 @@ package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -31,9 +32,10 @@ import (
|
||||||
// later release.
|
// later release.
|
||||||
type PreparedMsg struct {
|
type PreparedMsg struct {
|
||||||
// Struct for preparing msg before sending them
|
// Struct for preparing msg before sending them
|
||||||
encodedData []byte
|
encodedData mem.BufferSlice
|
||||||
hdr []byte
|
hdr []byte
|
||||||
payload []byte
|
payload mem.BufferSlice
|
||||||
|
pf payloadFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode marshalls and compresses the message using the codec and compressor for the stream.
|
// Encode marshalls and compresses the message using the codec and compressor for the stream.
|
||||||
|
|
@ -57,11 +59,27 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.encodedData = data
|
|
||||||
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp)
|
materializedData := data.Materialize()
|
||||||
|
data.Free()
|
||||||
|
p.encodedData = mem.BufferSlice{mem.NewBuffer(&materializedData, nil)}
|
||||||
|
|
||||||
|
// TODO: it should be possible to grab the bufferPool from the underlying
|
||||||
|
// stream implementation with a type cast to its actual type (such as
|
||||||
|
// addrConnStream) and accessing the buffer pool directly.
|
||||||
|
var compData mem.BufferSlice
|
||||||
|
compData, p.pf, err = compress(p.encodedData, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp, mem.DefaultBufferPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.hdr, p.payload = msgHeader(data, compData)
|
|
||||||
|
if p.pf.isCompressed() {
|
||||||
|
materializedCompData := compData.Materialize()
|
||||||
|
compData.Free()
|
||||||
|
compData = mem.BufferSlice{mem.NewBuffer(&materializedCompData, nil)}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.hdr, p.payload = msgHeader(p.encodedData, compData, p.pf)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
267
rpc_util.go
267
rpc_util.go
|
|
@ -19,7 +19,6 @@
|
||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
@ -35,6 +34,7 @@ import (
|
||||||
"google.golang.org/grpc/encoding"
|
"google.golang.org/grpc/encoding"
|
||||||
"google.golang.org/grpc/encoding/proto"
|
"google.golang.org/grpc/encoding/proto"
|
||||||
"google.golang.org/grpc/internal/transport"
|
"google.golang.org/grpc/internal/transport"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
|
|
@ -511,11 +511,51 @@ type ForceCodecCallOption struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ForceCodecCallOption) before(c *callInfo) error {
|
func (o ForceCodecCallOption) before(c *callInfo) error {
|
||||||
c.codec = o.Codec
|
c.codec = newCodecV1Bridge(o.Codec)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
|
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
|
||||||
|
|
||||||
|
// ForceCodecV2 returns a CallOption that will set codec to be used for all
|
||||||
|
// request and response messages for a call. The result of calling Name() will
|
||||||
|
// be used as the content-subtype after converting to lowercase, unless
|
||||||
|
// CallContentSubtype is also used.
|
||||||
|
//
|
||||||
|
// See Content-Type on
|
||||||
|
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
||||||
|
// more details. Also see the documentation on RegisterCodec and
|
||||||
|
// CallContentSubtype for more details on the interaction between Codec and
|
||||||
|
// content-subtype.
|
||||||
|
//
|
||||||
|
// This function is provided for advanced users; prefer to use only
|
||||||
|
// CallContentSubtype to select a registered codec instead.
|
||||||
|
//
|
||||||
|
// # Experimental
|
||||||
|
//
|
||||||
|
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
||||||
|
// later release.
|
||||||
|
func ForceCodecV2(codec encoding.CodecV2) CallOption {
|
||||||
|
return ForceCodecV2CallOption{CodecV2: codec}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceCodecV2CallOption is a CallOption that indicates the codec used for
|
||||||
|
// marshaling messages.
|
||||||
|
//
|
||||||
|
// # Experimental
|
||||||
|
//
|
||||||
|
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
|
||||||
|
// later release.
|
||||||
|
type ForceCodecV2CallOption struct {
|
||||||
|
CodecV2 encoding.CodecV2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o ForceCodecV2CallOption) before(c *callInfo) error {
|
||||||
|
c.codec = o.CodecV2
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o ForceCodecV2CallOption) after(c *callInfo, attempt *csAttempt) {}
|
||||||
|
|
||||||
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
|
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
|
||||||
// an encoding.Codec.
|
// an encoding.Codec.
|
||||||
//
|
//
|
||||||
|
|
@ -536,7 +576,7 @@ type CustomCodecCallOption struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o CustomCodecCallOption) before(c *callInfo) error {
|
func (o CustomCodecCallOption) before(c *callInfo) error {
|
||||||
c.codec = o.Codec
|
c.codec = newCodecV0Bridge(o.Codec)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
|
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
|
||||||
|
|
@ -577,19 +617,28 @@ const (
|
||||||
compressionMade payloadFormat = 1 // compressed
|
compressionMade payloadFormat = 1 // compressed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (pf payloadFormat) isCompressed() bool {
|
||||||
|
return pf == compressionMade
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamReader interface {
|
||||||
|
ReadHeader(header []byte) error
|
||||||
|
Read(n int) (mem.BufferSlice, error)
|
||||||
|
}
|
||||||
|
|
||||||
// parser reads complete gRPC messages from the underlying reader.
|
// parser reads complete gRPC messages from the underlying reader.
|
||||||
type parser struct {
|
type parser struct {
|
||||||
// r is the underlying reader.
|
// r is the underlying reader.
|
||||||
// See the comment on recvMsg for the permissible
|
// See the comment on recvMsg for the permissible
|
||||||
// error types.
|
// error types.
|
||||||
r io.Reader
|
r streamReader
|
||||||
|
|
||||||
// The header of a gRPC message. Find more detail at
|
// The header of a gRPC message. Find more detail at
|
||||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
|
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
|
||||||
header [5]byte
|
header [5]byte
|
||||||
|
|
||||||
// recvBufferPool is the pool of shared receive buffers.
|
// bufferPool is the pool of shared receive buffers.
|
||||||
recvBufferPool SharedBufferPool
|
bufferPool mem.BufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvMsg reads a complete gRPC message from the stream.
|
// recvMsg reads a complete gRPC message from the stream.
|
||||||
|
|
@ -604,14 +653,15 @@ type parser struct {
|
||||||
// - an error from the status package
|
// - an error from the status package
|
||||||
//
|
//
|
||||||
// No other error values or types must be returned, which also means
|
// No other error values or types must be returned, which also means
|
||||||
// that the underlying io.Reader must not return an incompatible
|
// that the underlying streamReader must not return an incompatible
|
||||||
// error.
|
// error.
|
||||||
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
|
func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) {
|
||||||
if _, err := p.r.Read(p.header[:]); err != nil {
|
err := p.r.ReadHeader(p.header[:])
|
||||||
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf = payloadFormat(p.header[0])
|
pf := payloadFormat(p.header[0])
|
||||||
length := binary.BigEndian.Uint32(p.header[1:])
|
length := binary.BigEndian.Uint32(p.header[1:])
|
||||||
|
|
||||||
if length == 0 {
|
if length == 0 {
|
||||||
|
|
@ -623,20 +673,21 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
|
||||||
if int(length) > maxReceiveMessageSize {
|
if int(length) > maxReceiveMessageSize {
|
||||||
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
|
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
|
||||||
}
|
}
|
||||||
msg = p.recvBufferPool.Get(int(length))
|
|
||||||
if _, err := p.r.Read(msg); err != nil {
|
data, err := p.r.Read(int(length))
|
||||||
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
return pf, msg, nil
|
return pf, data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// encode serializes msg and returns a buffer containing the message, or an
|
// encode serializes msg and returns a buffer containing the message, or an
|
||||||
// error if it is too large to be transmitted by grpc. If msg is nil, it
|
// error if it is too large to be transmitted by grpc. If msg is nil, it
|
||||||
// generates an empty message.
|
// generates an empty message.
|
||||||
func encode(c baseCodec, msg any) ([]byte, error) {
|
func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
|
||||||
if msg == nil { // NOTE: typed nils will not be caught by this check
|
if msg == nil { // NOTE: typed nils will not be caught by this check
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
@ -644,7 +695,8 @@ func encode(c baseCodec, msg any) ([]byte, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
||||||
}
|
}
|
||||||
if uint(len(b)) > math.MaxUint32 {
|
if uint(b.Len()) > math.MaxUint32 {
|
||||||
|
b.Free()
|
||||||
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
|
|
@ -655,34 +707,41 @@ func encode(c baseCodec, msg any) ([]byte, error) {
|
||||||
// indicating no compression was done.
|
// indicating no compression was done.
|
||||||
//
|
//
|
||||||
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
|
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
|
||||||
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
|
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
|
||||||
if compressor == nil && cp == nil {
|
if (compressor == nil && cp == nil) || in.Len() == 0 {
|
||||||
return nil, nil
|
return nil, compressionNone, nil
|
||||||
}
|
|
||||||
if len(in) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
var out mem.BufferSlice
|
||||||
|
w := mem.NewWriter(&out, pool)
|
||||||
wrapErr := func(err error) error {
|
wrapErr := func(err error) error {
|
||||||
|
out.Free()
|
||||||
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||||
}
|
}
|
||||||
cbuf := &bytes.Buffer{}
|
|
||||||
if compressor != nil {
|
if compressor != nil {
|
||||||
z, err := compressor.Compress(cbuf)
|
z, err := compressor.Compress(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapErr(err)
|
return nil, 0, wrapErr(err)
|
||||||
}
|
}
|
||||||
if _, err := z.Write(in); err != nil {
|
for _, b := range in {
|
||||||
return nil, wrapErr(err)
|
if _, err := z.Write(b.ReadOnlyData()); err != nil {
|
||||||
|
return nil, 0, wrapErr(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := z.Close(); err != nil {
|
if err := z.Close(); err != nil {
|
||||||
return nil, wrapErr(err)
|
return nil, 0, wrapErr(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := cp.Do(cbuf, in); err != nil {
|
// This is obviously really inefficient since it fully materializes the data, but
|
||||||
return nil, wrapErr(err)
|
// there is no way around this with the old Compressor API. At least it attempts
|
||||||
|
// to return the buffer to the provider, in the hopes it can be reused (maybe
|
||||||
|
// even by a subsequent call to this very function).
|
||||||
|
buf := in.MaterializeToBuffer(pool)
|
||||||
|
defer buf.Free()
|
||||||
|
if err := cp.Do(w, buf.ReadOnlyData()); err != nil {
|
||||||
|
return nil, 0, wrapErr(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cbuf.Bytes(), nil
|
return out, compressionMade, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -693,28 +752,31 @@ const (
|
||||||
|
|
||||||
// msgHeader returns a 5-byte header for the message being transmitted and the
|
// msgHeader returns a 5-byte header for the message being transmitted and the
|
||||||
// payload, which is compData if non-nil or data otherwise.
|
// payload, which is compData if non-nil or data otherwise.
|
||||||
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
|
func msgHeader(data, compData mem.BufferSlice, pf payloadFormat) (hdr []byte, payload mem.BufferSlice) {
|
||||||
hdr = make([]byte, headerLen)
|
hdr = make([]byte, headerLen)
|
||||||
if compData != nil {
|
hdr[0] = byte(pf)
|
||||||
hdr[0] = byte(compressionMade)
|
|
||||||
data = compData
|
var length uint32
|
||||||
|
if pf.isCompressed() {
|
||||||
|
length = uint32(compData.Len())
|
||||||
|
payload = compData
|
||||||
} else {
|
} else {
|
||||||
hdr[0] = byte(compressionNone)
|
length = uint32(data.Len())
|
||||||
|
payload = data
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write length of payload into buf
|
// Write length of payload into buf
|
||||||
binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
|
binary.BigEndian.PutUint32(hdr[payloadLen:], length)
|
||||||
return hdr, data
|
return hdr, payload
|
||||||
}
|
}
|
||||||
|
|
||||||
func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload {
|
func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time) *stats.OutPayload {
|
||||||
return &stats.OutPayload{
|
return &stats.OutPayload{
|
||||||
Client: client,
|
Client: client,
|
||||||
Payload: msg,
|
Payload: msg,
|
||||||
Data: data,
|
Length: dataLength,
|
||||||
Length: len(data),
|
WireLength: payloadLength + headerLen,
|
||||||
WireLength: len(payload) + headerLen,
|
CompressedLength: payloadLength,
|
||||||
CompressedLength: len(payload),
|
|
||||||
SentTime: t,
|
SentTime: t,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -741,7 +803,13 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
|
||||||
|
|
||||||
type payloadInfo struct {
|
type payloadInfo struct {
|
||||||
compressedLength int // The compressed length got from wire.
|
compressedLength int // The compressed length got from wire.
|
||||||
uncompressedBytes []byte
|
uncompressedBytes mem.BufferSlice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *payloadInfo) free() {
|
||||||
|
if p != nil && p.uncompressedBytes != nil {
|
||||||
|
p.uncompressedBytes.Free()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
|
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
|
||||||
|
|
@ -751,96 +819,113 @@ type payloadInfo struct {
|
||||||
// TODO: Refactor this function to reduce the number of arguments.
|
// TODO: Refactor this function to reduce the number of arguments.
|
||||||
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
|
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
|
||||||
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
|
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
|
||||||
) (uncompressedBuf []byte, cancel func(), err error) {
|
) (out mem.BufferSlice, err error) {
|
||||||
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
|
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
compressedLength := compressed.Len()
|
||||||
|
|
||||||
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
|
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
|
||||||
return nil, nil, st.Err()
|
compressed.Free()
|
||||||
|
return nil, st.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
var size int
|
var size int
|
||||||
if pf == compressionMade {
|
if pf.isCompressed() {
|
||||||
|
defer compressed.Free()
|
||||||
|
|
||||||
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
||||||
// use this decompressor as the default.
|
// use this decompressor as the default.
|
||||||
if dc != nil {
|
if dc != nil {
|
||||||
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
|
var uncompressedBuf []byte
|
||||||
|
uncompressedBuf, err = dc.Do(compressed.Reader())
|
||||||
|
if err == nil {
|
||||||
|
out = mem.BufferSlice{mem.NewBuffer(&uncompressedBuf, nil)}
|
||||||
|
}
|
||||||
size = len(uncompressedBuf)
|
size = len(uncompressedBuf)
|
||||||
} else {
|
} else {
|
||||||
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
|
out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
||||||
}
|
}
|
||||||
if size > maxReceiveMessageSize {
|
if size > maxReceiveMessageSize {
|
||||||
|
out.Free()
|
||||||
// TODO: Revisit the error code. Currently keep it consistent with java
|
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||||
// implementation.
|
// implementation.
|
||||||
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
|
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
uncompressedBuf = compressedBuf
|
out = compressed
|
||||||
}
|
}
|
||||||
|
|
||||||
if payInfo != nil {
|
if payInfo != nil {
|
||||||
payInfo.compressedLength = len(compressedBuf)
|
payInfo.compressedLength = compressedLength
|
||||||
payInfo.uncompressedBytes = uncompressedBuf
|
out.Ref()
|
||||||
|
payInfo.uncompressedBytes = out
|
||||||
cancel = func() {}
|
|
||||||
} else {
|
|
||||||
cancel = func() {
|
|
||||||
p.recvBufferPool.Put(&compressedBuf)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return uncompressedBuf, cancel, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Using compressor, decompress d, returning data and size.
|
// Using compressor, decompress d, returning data and size.
|
||||||
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
|
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
|
||||||
func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
|
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
|
||||||
dcReader, err := compressor.Decompress(bytes.NewReader(d))
|
dcReader, err := compressor.Decompress(d.Reader())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
if sizer, ok := compressor.(interface {
|
|
||||||
DecompressedSize(compressedBytes []byte) int
|
// TODO: Can/should this still be preserved with the new BufferSlice API? Are
|
||||||
}); ok {
|
// there any actual benefits to allocating a single large buffer instead of
|
||||||
if size := sizer.DecompressedSize(d); size >= 0 {
|
// multiple smaller ones?
|
||||||
if size > maxReceiveMessageSize {
|
//if sizer, ok := compressor.(interface {
|
||||||
return nil, size, nil
|
// DecompressedSize(compressedBytes []byte) int
|
||||||
}
|
//}); ok {
|
||||||
// size is used as an estimate to size the buffer, but we
|
// if size := sizer.DecompressedSize(d); size >= 0 {
|
||||||
// will read more data if available.
|
// if size > maxReceiveMessageSize {
|
||||||
// +MinRead so ReadFrom will not reallocate if size is correct.
|
// return nil, size, nil
|
||||||
//
|
// }
|
||||||
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
|
// // size is used as an estimate to size the buffer, but we
|
||||||
// we can also utilize the recv buffer pool here.
|
// // will read more data if available.
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
|
// // +MinRead so ReadFrom will not reallocate if size is correct.
|
||||||
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
// //
|
||||||
return buf.Bytes(), int(bytesRead), err
|
// // TODO: If we ensure that the buffer size is the same as the DecompressedSize,
|
||||||
}
|
// // we can also utilize the recv buffer pool here.
|
||||||
|
// buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
|
||||||
|
// bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
||||||
|
// return buf.Bytes(), int(bytesRead), err
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
var out mem.BufferSlice
|
||||||
|
_, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
||||||
|
if err != nil {
|
||||||
|
out.Free()
|
||||||
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
// Read from LimitReader with limit max+1. So if the underlying
|
return out, out.Len(), nil
|
||||||
// reader is over limit, the result will be bigger than max.
|
|
||||||
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
|
|
||||||
return d, len(d), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For the two compressor parameters, both should not be set, but if they are,
|
// For the two compressor parameters, both should not be set, but if they are,
|
||||||
// dc takes precedence over compressor.
|
// dc takes precedence over compressor.
|
||||||
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
||||||
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
|
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
|
||||||
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
|
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := c.Unmarshal(buf, m); err != nil {
|
// If the codec wants its own reference to the data, it can get it. Otherwise, always
|
||||||
|
// free the buffers.
|
||||||
|
defer data.Free()
|
||||||
|
|
||||||
|
if err := c.Unmarshal(data, m); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
|
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -943,7 +1028,7 @@ func setCallInfoCodec(c *callInfo) error {
|
||||||
// encoding.Codec (Name vs. String method name). We only support
|
// encoding.Codec (Name vs. String method name). We only support
|
||||||
// setting content subtype from encoding.Codec to avoid a behavior
|
// setting content subtype from encoding.Codec to avoid a behavior
|
||||||
// change with the deprecated version.
|
// change with the deprecated version.
|
||||||
if ec, ok := c.codec.(encoding.Codec); ok {
|
if ec, ok := c.codec.(encoding.CodecV2); ok {
|
||||||
c.contentSubtype = strings.ToLower(ec.Name())
|
c.contentSubtype = strings.ToLower(ec.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -952,12 +1037,12 @@ func setCallInfoCodec(c *callInfo) error {
|
||||||
|
|
||||||
if c.contentSubtype == "" {
|
if c.contentSubtype == "" {
|
||||||
// No codec specified in CallOptions; use proto by default.
|
// No codec specified in CallOptions; use proto by default.
|
||||||
c.codec = encoding.GetCodec(proto.Name)
|
c.codec = getCodec(proto.Name)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// c.contentSubtype is already lowercased in CallContentSubtype
|
// c.contentSubtype is already lowercased in CallContentSubtype
|
||||||
c.codec = encoding.GetCodec(c.contentSubtype)
|
c.codec = getCodec(c.contentSubtype)
|
||||||
if c.codec == nil {
|
if c.codec == nil {
|
||||||
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
|
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,21 +27,45 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/encoding"
|
|
||||||
protoenc "google.golang.org/grpc/encoding/proto"
|
protoenc "google.golang.org/grpc/encoding/proto"
|
||||||
"google.golang.org/grpc/internal/testutils"
|
"google.golang.org/grpc/internal/testutils"
|
||||||
"google.golang.org/grpc/internal/transport"
|
"google.golang.org/grpc/internal/transport"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
perfpb "google.golang.org/grpc/test/codec_perf"
|
perfpb "google.golang.org/grpc/test/codec_perf"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fullReader struct {
|
type fullReader struct {
|
||||||
reader io.Reader
|
data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f fullReader) Read(p []byte) (int, error) {
|
func (f *fullReader) ReadHeader(header []byte) error {
|
||||||
return io.ReadFull(f.reader, p)
|
buf, err := f.Read(len(header))
|
||||||
|
defer buf.Free()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.CopyTo(header)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fullReader) Read(n int) (mem.BufferSlice, error) {
|
||||||
|
if len(f.data) == 0 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(f.data) < n {
|
||||||
|
data := f.data
|
||||||
|
f.data = nil
|
||||||
|
return mem.BufferSlice{mem.NewBuffer(&data, nil)}, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := f.data[:n]
|
||||||
|
f.data = f.data[n:]
|
||||||
|
|
||||||
|
return mem.BufferSlice{mem.NewBuffer(&buf, nil)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
|
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
|
||||||
|
|
@ -64,10 +88,10 @@ func (s) TestSimpleParsing(t *testing.T) {
|
||||||
// Check that messages with length >= 2^24 are parsed.
|
// Check that messages with length >= 2^24 are parsed.
|
||||||
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
|
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
|
||||||
} {
|
} {
|
||||||
buf := fullReader{bytes.NewReader(test.p)}
|
buf := &fullReader{test.p}
|
||||||
parser := &parser{r: buf, recvBufferPool: nopBufferPool{}}
|
parser := &parser{r: buf, bufferPool: mem.DefaultBufferPool()}
|
||||||
pt, b, err := parser.recvMsg(math.MaxInt32)
|
pt, b, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
if err != test.err || !bytes.Equal(b.Materialize(), test.b) || pt != test.pt {
|
||||||
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -76,8 +100,8 @@ func (s) TestSimpleParsing(t *testing.T) {
|
||||||
func (s) TestMultipleParsing(t *testing.T) {
|
func (s) TestMultipleParsing(t *testing.T) {
|
||||||
// Set a byte stream consists of 3 messages with their headers.
|
// Set a byte stream consists of 3 messages with their headers.
|
||||||
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
|
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
|
||||||
b := fullReader{bytes.NewReader(p)}
|
b := &fullReader{p}
|
||||||
parser := &parser{r: b, recvBufferPool: nopBufferPool{}}
|
parser := &parser{r: b, bufferPool: mem.DefaultBufferPool()}
|
||||||
|
|
||||||
wantRecvs := []struct {
|
wantRecvs := []struct {
|
||||||
pt payloadFormat
|
pt payloadFormat
|
||||||
|
|
@ -89,7 +113,7 @@ func (s) TestMultipleParsing(t *testing.T) {
|
||||||
}
|
}
|
||||||
for i, want := range wantRecvs {
|
for i, want := range wantRecvs {
|
||||||
pt, data, err := parser.recvMsg(math.MaxInt32)
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
if err != nil || pt != want.pt || !reflect.DeepEqual(data.Materialize(), want.data) {
|
||||||
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
||||||
i, p, pt, data, err, want.pt, want.data)
|
i, p, pt, data, err, want.pt, want.data)
|
||||||
}
|
}
|
||||||
|
|
@ -113,12 +137,12 @@ func (s) TestEncode(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
|
{nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
|
||||||
} {
|
} {
|
||||||
data, err := encode(encoding.GetCodec(protoenc.Name), test.msg)
|
data, err := encode(getCodec(protoenc.Name), test.msg)
|
||||||
if err != test.err || !bytes.Equal(data, test.data) {
|
if err != test.err || !bytes.Equal(data.Materialize(), test.data) {
|
||||||
t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err)
|
t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if hdr, _ := msgHeader(data, nil); !bytes.Equal(hdr, test.hdr) {
|
if hdr, _ := msgHeader(data, nil, compressionNone); !bytes.Equal(hdr, test.hdr) {
|
||||||
t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr)
|
t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -194,7 +218,7 @@ func (s) TestToRPCErr(t *testing.T) {
|
||||||
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
||||||
// bytes.
|
// bytes.
|
||||||
func bmEncode(b *testing.B, mSize int) {
|
func bmEncode(b *testing.B, mSize int) {
|
||||||
cdc := encoding.GetCodec(protoenc.Name)
|
cdc := getCodec(protoenc.Name)
|
||||||
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
||||||
encodeData, _ := encode(cdc, msg)
|
encodeData, _ := encode(cdc, msg)
|
||||||
encodedSz := int64(len(encodeData))
|
encodedSz := int64(len(encodeData))
|
||||||
|
|
|
||||||
96
server.go
96
server.go
|
|
@ -45,6 +45,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/grpcutil"
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
"google.golang.org/grpc/internal/transport"
|
"google.golang.org/grpc/internal/transport"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
|
|
@ -80,7 +81,7 @@ func init() {
|
||||||
}
|
}
|
||||||
internal.BinaryLogger = binaryLogger
|
internal.BinaryLogger = binaryLogger
|
||||||
internal.JoinServerOptions = newJoinServerOption
|
internal.JoinServerOptions = newJoinServerOption
|
||||||
internal.RecvBufferPool = recvBufferPool
|
internal.BufferPool = bufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
var statusOK = status.New(codes.OK, "")
|
var statusOK = status.New(codes.OK, "")
|
||||||
|
|
@ -170,7 +171,7 @@ type serverOptions struct {
|
||||||
maxHeaderListSize *uint32
|
maxHeaderListSize *uint32
|
||||||
headerTableSize *uint32
|
headerTableSize *uint32
|
||||||
numServerWorkers uint32
|
numServerWorkers uint32
|
||||||
recvBufferPool SharedBufferPool
|
bufferPool mem.BufferPool
|
||||||
waitForHandlers bool
|
waitForHandlers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -181,7 +182,7 @@ var defaultServerOptions = serverOptions{
|
||||||
connectionTimeout: 120 * time.Second,
|
connectionTimeout: 120 * time.Second,
|
||||||
writeBufferSize: defaultWriteBufSize,
|
writeBufferSize: defaultWriteBufSize,
|
||||||
readBufferSize: defaultReadBufSize,
|
readBufferSize: defaultReadBufSize,
|
||||||
recvBufferPool: nopBufferPool{},
|
bufferPool: mem.DefaultBufferPool(),
|
||||||
}
|
}
|
||||||
var globalServerOptions []ServerOption
|
var globalServerOptions []ServerOption
|
||||||
|
|
||||||
|
|
@ -313,7 +314,7 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
|
||||||
// Will be supported throughout 1.x.
|
// Will be supported throughout 1.x.
|
||||||
func CustomCodec(codec Codec) ServerOption {
|
func CustomCodec(codec Codec) ServerOption {
|
||||||
return newFuncServerOption(func(o *serverOptions) {
|
return newFuncServerOption(func(o *serverOptions) {
|
||||||
o.codec = codec
|
o.codec = newCodecV0Bridge(codec)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -342,7 +343,22 @@ func CustomCodec(codec Codec) ServerOption {
|
||||||
// later release.
|
// later release.
|
||||||
func ForceServerCodec(codec encoding.Codec) ServerOption {
|
func ForceServerCodec(codec encoding.Codec) ServerOption {
|
||||||
return newFuncServerOption(func(o *serverOptions) {
|
return newFuncServerOption(func(o *serverOptions) {
|
||||||
o.codec = codec
|
o.codec = newCodecV1Bridge(codec)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceServerCodecV2 is the equivalent of ForceServerCodec, but for the new
|
||||||
|
// CodecV2 interface.
|
||||||
|
//
|
||||||
|
// Will be supported throughout 1.x.
|
||||||
|
//
|
||||||
|
// # Experimental
|
||||||
|
//
|
||||||
|
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
||||||
|
// later release.
|
||||||
|
func ForceServerCodecV2(codecV2 encoding.CodecV2) ServerOption {
|
||||||
|
return newFuncServerOption(func(o *serverOptions) {
|
||||||
|
o.codec = codecV2
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -592,26 +608,9 @@ func WaitForHandlers(w bool) ServerOption {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecvBufferPool returns a ServerOption that configures the server
|
func bufferPool(bufferPool mem.BufferPool) ServerOption {
|
||||||
// to use the provided shared buffer pool for parsing incoming messages. Depending
|
|
||||||
// on the application's workload, this could result in reduced memory allocation.
|
|
||||||
//
|
|
||||||
// If you are unsure about how to implement a memory pool but want to utilize one,
|
|
||||||
// begin with grpc.NewSharedBufferPool.
|
|
||||||
//
|
|
||||||
// Note: The shared buffer pool feature will not be active if any of the following
|
|
||||||
// options are used: StatsHandler, EnableTracing, or binary logging. In such
|
|
||||||
// cases, the shared buffer pool will be ignored.
|
|
||||||
//
|
|
||||||
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
|
|
||||||
// v1.60.0 or later.
|
|
||||||
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
|
|
||||||
return recvBufferPool(bufferPool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
|
|
||||||
return newFuncServerOption(func(o *serverOptions) {
|
return newFuncServerOption(func(o *serverOptions) {
|
||||||
o.recvBufferPool = bufferPool
|
o.bufferPool = bufferPool
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -980,6 +979,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
|
||||||
ChannelzParent: s.channelz,
|
ChannelzParent: s.channelz,
|
||||||
MaxHeaderListSize: s.opts.maxHeaderListSize,
|
MaxHeaderListSize: s.opts.maxHeaderListSize,
|
||||||
HeaderTableSize: s.opts.headerTableSize,
|
HeaderTableSize: s.opts.headerTableSize,
|
||||||
|
BufferPool: s.opts.bufferPool,
|
||||||
}
|
}
|
||||||
st, err := transport.NewServerTransport(c, config)
|
st, err := transport.NewServerTransport(c, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1072,7 +1072,7 @@ var _ http.Handler = (*Server)(nil)
|
||||||
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
||||||
// later release.
|
// later release.
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers)
|
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Errors returned from transport.NewServerHandlerTransport have
|
// Errors returned from transport.NewServerHandlerTransport have
|
||||||
// already been written to w.
|
// already been written to w.
|
||||||
|
|
@ -1142,20 +1142,35 @@ func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport,
|
||||||
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
|
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
compData, err := compress(data, cp, comp)
|
|
||||||
|
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
data.Free()
|
||||||
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
|
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hdr, payload := msgHeader(data, compData)
|
|
||||||
|
hdr, payload := msgHeader(data, compData, pf)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
compData.Free()
|
||||||
|
data.Free()
|
||||||
|
// payload does not need to be freed here, it is either data or compData, both of
|
||||||
|
// which are already freed.
|
||||||
|
}()
|
||||||
|
|
||||||
|
dataLen := data.Len()
|
||||||
|
payloadLen := payload.Len()
|
||||||
// TODO(dfawley): should we be checking len(data) instead?
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
if len(payload) > s.opts.maxSendMessageSize {
|
if payloadLen > s.opts.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
err = t.Write(stream, hdr, payload, opts)
|
err = t.Write(stream, hdr, payload, opts)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, sh := range s.opts.statsHandlers {
|
if len(s.opts.statsHandlers) != 0 {
|
||||||
sh.HandleRPC(ctx, outPayload(false, msg, data, payload, time.Now()))
|
for _, sh := range s.opts.statsHandlers {
|
||||||
|
sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|
@ -1334,9 +1349,10 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
|
||||||
var payInfo *payloadInfo
|
var payInfo *payloadInfo
|
||||||
if len(shs) != 0 || len(binlogs) != 0 {
|
if len(shs) != 0 || len(binlogs) != 0 {
|
||||||
payInfo = &payloadInfo{}
|
payInfo = &payloadInfo{}
|
||||||
|
defer payInfo.free()
|
||||||
}
|
}
|
||||||
|
|
||||||
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
|
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
|
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
|
||||||
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
|
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
|
||||||
|
|
@ -1347,24 +1363,22 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
|
||||||
t.IncrMsgRecv()
|
t.IncrMsgRecv()
|
||||||
}
|
}
|
||||||
df := func(v any) error {
|
df := func(v any) error {
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
|
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sh := range shs {
|
for _, sh := range shs {
|
||||||
sh.HandleRPC(ctx, &stats.InPayload{
|
sh.HandleRPC(ctx, &stats.InPayload{
|
||||||
RecvTime: time.Now(),
|
RecvTime: time.Now(),
|
||||||
Payload: v,
|
Payload: v,
|
||||||
Length: len(d),
|
Length: d.Len(),
|
||||||
WireLength: payInfo.compressedLength + headerLen,
|
WireLength: payInfo.compressedLength + headerLen,
|
||||||
CompressedLength: payInfo.compressedLength,
|
CompressedLength: payInfo.compressedLength,
|
||||||
Data: d,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(binlogs) != 0 {
|
if len(binlogs) != 0 {
|
||||||
cm := &binarylog.ClientMessage{
|
cm := &binarylog.ClientMessage{
|
||||||
Message: d,
|
Message: d.Materialize(),
|
||||||
}
|
}
|
||||||
for _, binlog := range binlogs {
|
for _, binlog := range binlogs {
|
||||||
binlog.Log(ctx, cm)
|
binlog.Log(ctx, cm)
|
||||||
|
|
@ -1548,7 +1562,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
t: t,
|
t: t,
|
||||||
s: stream,
|
s: stream,
|
||||||
p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool},
|
p: &parser{r: stream, bufferPool: s.opts.bufferPool},
|
||||||
codec: s.getCodec(stream.ContentSubtype()),
|
codec: s.getCodec(stream.ContentSubtype()),
|
||||||
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
|
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
|
||||||
maxSendMessageSize: s.opts.maxSendMessageSize,
|
maxSendMessageSize: s.opts.maxSendMessageSize,
|
||||||
|
|
@ -1963,12 +1977,12 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
|
||||||
return s.opts.codec
|
return s.opts.codec
|
||||||
}
|
}
|
||||||
if contentSubtype == "" {
|
if contentSubtype == "" {
|
||||||
return encoding.GetCodec(proto.Name)
|
return getCodec(proto.Name)
|
||||||
}
|
}
|
||||||
codec := encoding.GetCodec(contentSubtype)
|
codec := getCodec(contentSubtype)
|
||||||
if codec == nil {
|
if codec == nil {
|
||||||
logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
|
logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
|
||||||
return encoding.GetCodec(proto.Name)
|
return getCodec(proto.Name)
|
||||||
}
|
}
|
||||||
return codec
|
return codec
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,154 +0,0 @@
|
||||||
/*
|
|
||||||
*
|
|
||||||
* Copyright 2023 gRPC authors.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
package grpc
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// SharedBufferPool is a pool of buffers that can be shared, resulting in
|
|
||||||
// decreased memory allocation. Currently, in gRPC-go, it is only utilized
|
|
||||||
// for parsing incoming messages.
|
|
||||||
//
|
|
||||||
// # Experimental
|
|
||||||
//
|
|
||||||
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
|
||||||
// later release.
|
|
||||||
type SharedBufferPool interface {
|
|
||||||
// Get returns a buffer with specified length from the pool.
|
|
||||||
//
|
|
||||||
// The returned byte slice may be not zero initialized.
|
|
||||||
Get(length int) []byte
|
|
||||||
|
|
||||||
// Put returns a buffer to the pool.
|
|
||||||
Put(*[]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSharedBufferPool creates a simple SharedBufferPool with buckets
|
|
||||||
// of different sizes to optimize memory usage. This prevents the pool from
|
|
||||||
// wasting large amounts of memory, even when handling messages of varying sizes.
|
|
||||||
//
|
|
||||||
// # Experimental
|
|
||||||
//
|
|
||||||
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
|
||||||
// later release.
|
|
||||||
func NewSharedBufferPool() SharedBufferPool {
|
|
||||||
return &simpleSharedBufferPool{
|
|
||||||
pools: [poolArraySize]simpleSharedBufferChildPool{
|
|
||||||
newBytesPool(level0PoolMaxSize),
|
|
||||||
newBytesPool(level1PoolMaxSize),
|
|
||||||
newBytesPool(level2PoolMaxSize),
|
|
||||||
newBytesPool(level3PoolMaxSize),
|
|
||||||
newBytesPool(level4PoolMaxSize),
|
|
||||||
newBytesPool(0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// simpleSharedBufferPool is a simple implementation of SharedBufferPool.
|
|
||||||
type simpleSharedBufferPool struct {
|
|
||||||
pools [poolArraySize]simpleSharedBufferChildPool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *simpleSharedBufferPool) Get(size int) []byte {
|
|
||||||
return p.pools[p.poolIdx(size)].Get(size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *simpleSharedBufferPool) Put(bs *[]byte) {
|
|
||||||
p.pools[p.poolIdx(cap(*bs))].Put(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *simpleSharedBufferPool) poolIdx(size int) int {
|
|
||||||
switch {
|
|
||||||
case size <= level0PoolMaxSize:
|
|
||||||
return level0PoolIdx
|
|
||||||
case size <= level1PoolMaxSize:
|
|
||||||
return level1PoolIdx
|
|
||||||
case size <= level2PoolMaxSize:
|
|
||||||
return level2PoolIdx
|
|
||||||
case size <= level3PoolMaxSize:
|
|
||||||
return level3PoolIdx
|
|
||||||
case size <= level4PoolMaxSize:
|
|
||||||
return level4PoolIdx
|
|
||||||
default:
|
|
||||||
return levelMaxPoolIdx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
level0PoolMaxSize = 16 // 16 B
|
|
||||||
level1PoolMaxSize = level0PoolMaxSize * 16 // 256 B
|
|
||||||
level2PoolMaxSize = level1PoolMaxSize * 16 // 4 KB
|
|
||||||
level3PoolMaxSize = level2PoolMaxSize * 16 // 64 KB
|
|
||||||
level4PoolMaxSize = level3PoolMaxSize * 16 // 1 MB
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
level0PoolIdx = iota
|
|
||||||
level1PoolIdx
|
|
||||||
level2PoolIdx
|
|
||||||
level3PoolIdx
|
|
||||||
level4PoolIdx
|
|
||||||
levelMaxPoolIdx
|
|
||||||
poolArraySize
|
|
||||||
)
|
|
||||||
|
|
||||||
type simpleSharedBufferChildPool interface {
|
|
||||||
Get(size int) []byte
|
|
||||||
Put(any)
|
|
||||||
}
|
|
||||||
|
|
||||||
type bufferPool struct {
|
|
||||||
sync.Pool
|
|
||||||
|
|
||||||
defaultSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *bufferPool) Get(size int) []byte {
|
|
||||||
bs := p.Pool.Get().(*[]byte)
|
|
||||||
|
|
||||||
if cap(*bs) < size {
|
|
||||||
p.Pool.Put(bs)
|
|
||||||
|
|
||||||
return make([]byte, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (*bs)[:size]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBytesPool(size int) simpleSharedBufferChildPool {
|
|
||||||
return &bufferPool{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
bs := make([]byte, size)
|
|
||||||
return &bs
|
|
||||||
},
|
|
||||||
},
|
|
||||||
defaultSize: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// nopBufferPool is a buffer pool just makes new buffer without pooling.
|
|
||||||
type nopBufferPool struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nopBufferPool) Get(length int) []byte {
|
|
||||||
return make([]byte, length)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nopBufferPool) Put(*[]byte) {
|
|
||||||
}
|
|
||||||
|
|
@ -77,9 +77,6 @@ type InPayload struct {
|
||||||
// the call to HandleRPC which provides the InPayload returns and must be
|
// the call to HandleRPC which provides the InPayload returns and must be
|
||||||
// copied if needed later.
|
// copied if needed later.
|
||||||
Payload any
|
Payload any
|
||||||
// Data is the serialized message payload.
|
|
||||||
// Deprecated: Data will be removed in the next release.
|
|
||||||
Data []byte
|
|
||||||
|
|
||||||
// Length is the size of the uncompressed payload data. Does not include any
|
// Length is the size of the uncompressed payload data. Does not include any
|
||||||
// framing (gRPC or HTTP/2).
|
// framing (gRPC or HTTP/2).
|
||||||
|
|
@ -150,9 +147,6 @@ type OutPayload struct {
|
||||||
// the call to HandleRPC which provides the OutPayload returns and must be
|
// the call to HandleRPC which provides the OutPayload returns and must be
|
||||||
// copied if needed later.
|
// copied if needed later.
|
||||||
Payload any
|
Payload any
|
||||||
// Data is the serialized message payload.
|
|
||||||
// Deprecated: Data will be removed in the next release.
|
|
||||||
Data []byte
|
|
||||||
// Length is the size of the uncompressed payload data. Does not include any
|
// Length is the size of the uncompressed payload data. Does not include any
|
||||||
// framing (gRPC or HTTP/2).
|
// framing (gRPC or HTTP/2).
|
||||||
Length int
|
Length int
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/internal"
|
"google.golang.org/grpc/internal"
|
||||||
|
|
@ -38,6 +39,7 @@ import (
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
|
||||||
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||||
|
|
@ -538,40 +540,29 @@ func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
|
||||||
if d.ctx == nil {
|
if d.ctx == nil {
|
||||||
t.Fatalf("d.ctx = nil, want <non-nil>")
|
t.Fatalf("d.ctx = nil, want <non-nil>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var idx *int
|
||||||
|
var payloads []proto.Message
|
||||||
if d.client {
|
if d.client {
|
||||||
b, err := proto.Marshal(e.responses[e.respIdx])
|
idx = &e.respIdx
|
||||||
if err != nil {
|
payloads = e.responses
|
||||||
t.Fatalf("failed to marshal message: %v", err)
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
|
|
||||||
t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
|
|
||||||
}
|
|
||||||
e.respIdx++
|
|
||||||
if string(st.Data) != string(b) {
|
|
||||||
t.Fatalf("st.Data = %v, want %v", st.Data, b)
|
|
||||||
}
|
|
||||||
if st.Length != len(b) {
|
|
||||||
t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
b, err := proto.Marshal(e.requests[e.reqIdx])
|
idx = &e.reqIdx
|
||||||
if err != nil {
|
payloads = e.requests
|
||||||
t.Fatalf("failed to marshal message: %v", err)
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
|
|
||||||
t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
|
|
||||||
}
|
|
||||||
e.reqIdx++
|
|
||||||
if string(st.Data) != string(b) {
|
|
||||||
t.Fatalf("st.Data = %v, want %v", st.Data, b)
|
|
||||||
}
|
|
||||||
if st.Length != len(b) {
|
|
||||||
t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wantPayload := payloads[*idx]
|
||||||
|
if diff := cmp.Diff(wantPayload, st.Payload.(proto.Message), protocmp.Transform()); diff != "" {
|
||||||
|
t.Fatalf("unexpected difference in st.Payload (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
*idx++
|
||||||
|
if st.Length != proto.Size(wantPayload) {
|
||||||
|
t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(wantPayload))
|
||||||
|
}
|
||||||
|
|
||||||
// Below are sanity checks that WireLength and RecvTime are populated.
|
// Below are sanity checks that WireLength and RecvTime are populated.
|
||||||
// TODO: check values of WireLength and RecvTime.
|
// TODO: check values of WireLength and RecvTime.
|
||||||
if len(st.Data) > 0 && st.CompressedLength == 0 {
|
if st.Length > 0 && st.CompressedLength == 0 {
|
||||||
t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
|
t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
|
||||||
st.CompressedLength)
|
st.CompressedLength)
|
||||||
}
|
}
|
||||||
|
|
@ -657,40 +648,29 @@ func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
|
||||||
if d.ctx == nil {
|
if d.ctx == nil {
|
||||||
t.Fatalf("d.ctx = nil, want <non-nil>")
|
t.Fatalf("d.ctx = nil, want <non-nil>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var idx *int
|
||||||
|
var payloads []proto.Message
|
||||||
if d.client {
|
if d.client {
|
||||||
b, err := proto.Marshal(e.requests[e.reqIdx])
|
idx = &e.reqIdx
|
||||||
if err != nil {
|
payloads = e.requests
|
||||||
t.Fatalf("failed to marshal message: %v", err)
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
|
|
||||||
t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
|
|
||||||
}
|
|
||||||
e.reqIdx++
|
|
||||||
if string(st.Data) != string(b) {
|
|
||||||
t.Fatalf("st.Data = %v, want %v", st.Data, b)
|
|
||||||
}
|
|
||||||
if st.Length != len(b) {
|
|
||||||
t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
b, err := proto.Marshal(e.responses[e.respIdx])
|
idx = &e.respIdx
|
||||||
if err != nil {
|
payloads = e.responses
|
||||||
t.Fatalf("failed to marshal message: %v", err)
|
|
||||||
}
|
|
||||||
if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
|
|
||||||
t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
|
|
||||||
}
|
|
||||||
e.respIdx++
|
|
||||||
if string(st.Data) != string(b) {
|
|
||||||
t.Fatalf("st.Data = %v, want %v", st.Data, b)
|
|
||||||
}
|
|
||||||
if st.Length != len(b) {
|
|
||||||
t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// Below are sanity checks that WireLength and SentTime are populated.
|
|
||||||
|
expectedPayload := payloads[*idx]
|
||||||
|
if !proto.Equal(st.Payload.(proto.Message), expectedPayload) {
|
||||||
|
t.Fatalf("st.Payload = %v, want %v", st.Payload, expectedPayload)
|
||||||
|
}
|
||||||
|
*idx++
|
||||||
|
if st.Length != proto.Size(expectedPayload) {
|
||||||
|
t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(expectedPayload))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below are sanity checks that Length, CompressedLength and SentTime are populated.
|
||||||
// TODO: check values of WireLength and SentTime.
|
// TODO: check values of WireLength and SentTime.
|
||||||
if len(st.Data) > 0 && st.WireLength == 0 {
|
if st.Length > 0 && st.WireLength == 0 {
|
||||||
t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
|
t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
|
||||||
st.WireLength)
|
st.WireLength)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
181
stream.go
181
stream.go
|
|
@ -41,6 +41,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/serviceconfig"
|
"google.golang.org/grpc/internal/serviceconfig"
|
||||||
istatus "google.golang.org/grpc/internal/status"
|
istatus "google.golang.org/grpc/internal/status"
|
||||||
"google.golang.org/grpc/internal/transport"
|
"google.golang.org/grpc/internal/transport"
|
||||||
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
|
|
@ -359,7 +360,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
|
||||||
cs.attempt = a
|
cs.attempt = a
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil {
|
if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) }); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -517,7 +518,7 @@ func (a *csAttempt) newStream() error {
|
||||||
}
|
}
|
||||||
a.s = s
|
a.s = s
|
||||||
a.ctx = s.Context()
|
a.ctx = s.Context()
|
||||||
a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool}
|
a.p = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -566,10 +567,15 @@ type clientStream struct {
|
||||||
// place where we need to check if the attempt is nil.
|
// place where we need to check if the attempt is nil.
|
||||||
attempt *csAttempt
|
attempt *csAttempt
|
||||||
// TODO(hedging): hedging will have multiple attempts simultaneously.
|
// TODO(hedging): hedging will have multiple attempts simultaneously.
|
||||||
committed bool // active attempt committed for retry?
|
committed bool // active attempt committed for retry?
|
||||||
onCommit func()
|
onCommit func()
|
||||||
buffer []func(a *csAttempt) error // operations to replay on retry
|
replayBuffer []replayOp // operations to replay on retry
|
||||||
bufferSize int // current size of buffer
|
replayBufferSize int // current size of replayBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type replayOp struct {
|
||||||
|
op func(a *csAttempt) error
|
||||||
|
cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// csAttempt implements a single transport stream attempt within a
|
// csAttempt implements a single transport stream attempt within a
|
||||||
|
|
@ -607,7 +613,12 @@ func (cs *clientStream) commitAttemptLocked() {
|
||||||
cs.onCommit()
|
cs.onCommit()
|
||||||
}
|
}
|
||||||
cs.committed = true
|
cs.committed = true
|
||||||
cs.buffer = nil
|
for _, op := range cs.replayBuffer {
|
||||||
|
if op.cleanup != nil {
|
||||||
|
op.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cs.replayBuffer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) commitAttempt() {
|
func (cs *clientStream) commitAttempt() {
|
||||||
|
|
@ -732,7 +743,7 @@ func (cs *clientStream) retryLocked(attempt *csAttempt, lastErr error) error {
|
||||||
// the stream is canceled.
|
// the stream is canceled.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Note that the first op in the replay buffer always sets cs.attempt
|
// Note that the first op in replayBuffer always sets cs.attempt
|
||||||
// if it is able to pick a transport and create a stream.
|
// if it is able to pick a transport and create a stream.
|
||||||
if lastErr = cs.replayBufferLocked(attempt); lastErr == nil {
|
if lastErr = cs.replayBufferLocked(attempt); lastErr == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -761,7 +772,7 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())
|
||||||
// already be status errors.
|
// already be status errors.
|
||||||
return toRPCErr(op(cs.attempt))
|
return toRPCErr(op(cs.attempt))
|
||||||
}
|
}
|
||||||
if len(cs.buffer) == 0 {
|
if len(cs.replayBuffer) == 0 {
|
||||||
// For the first op, which controls creation of the stream and
|
// For the first op, which controls creation of the stream and
|
||||||
// assigns cs.attempt, we need to create a new attempt inline
|
// assigns cs.attempt, we need to create a new attempt inline
|
||||||
// before executing the first op. On subsequent ops, the attempt
|
// before executing the first op. On subsequent ops, the attempt
|
||||||
|
|
@ -851,25 +862,26 @@ func (cs *clientStream) Trailer() metadata.MD {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error {
|
func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error {
|
||||||
for _, f := range cs.buffer {
|
for _, f := range cs.replayBuffer {
|
||||||
if err := f(attempt); err != nil {
|
if err := f.op(attempt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error) {
|
func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error, cleanup func()) {
|
||||||
// Note: we still will buffer if retry is disabled (for transparent retries).
|
// Note: we still will buffer if retry is disabled (for transparent retries).
|
||||||
if cs.committed {
|
if cs.committed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cs.bufferSize += sz
|
cs.replayBufferSize += sz
|
||||||
if cs.bufferSize > cs.callInfo.maxRetryRPCBufferSize {
|
if cs.replayBufferSize > cs.callInfo.maxRetryRPCBufferSize {
|
||||||
cs.commitAttemptLocked()
|
cs.commitAttemptLocked()
|
||||||
|
cleanup()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cs.buffer = append(cs.buffer, op)
|
cs.replayBuffer = append(cs.replayBuffer, replayOp{op: op, cleanup: cleanup})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) SendMsg(m any) (err error) {
|
func (cs *clientStream) SendMsg(m any) (err error) {
|
||||||
|
|
@ -891,23 +903,50 @@ func (cs *clientStream) SendMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// load hdr, payload, data
|
// load hdr, payload, data
|
||||||
hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
|
hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.cp, cs.comp, cs.cc.dopts.copts.BufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
data.Free()
|
||||||
|
// only free payload if compression was made, and therefore it is a different set
|
||||||
|
// of buffers from data.
|
||||||
|
if pf.isCompressed() {
|
||||||
|
payload.Free()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dataLen := data.Len()
|
||||||
|
payloadLen := payload.Len()
|
||||||
// TODO(dfawley): should we be checking len(data) instead?
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
if len(payload) > *cs.callInfo.maxSendMessageSize {
|
if payloadLen > *cs.callInfo.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, *cs.callInfo.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// always take an extra ref in case data == payload (i.e. when the data isn't
|
||||||
|
// compressed). The original ref will always be freed by the deferred free above.
|
||||||
|
payload.Ref()
|
||||||
op := func(a *csAttempt) error {
|
op := func(a *csAttempt) error {
|
||||||
return a.sendMsg(m, hdr, payload, data)
|
return a.sendMsg(m, hdr, payload, dataLen, payloadLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// onSuccess is invoked when the op is captured for a subsequent retry. If the
|
||||||
|
// stream was established by a previous message and therefore retries are
|
||||||
|
// disabled, onSuccess will not be invoked, and payloadRef can be freed
|
||||||
|
// immediately.
|
||||||
|
onSuccessCalled := false
|
||||||
|
err = cs.withRetry(op, func() {
|
||||||
|
cs.bufferForRetryLocked(len(hdr)+payloadLen, op, payload.Free)
|
||||||
|
onSuccessCalled = true
|
||||||
|
})
|
||||||
|
if !onSuccessCalled {
|
||||||
|
payload.Free()
|
||||||
}
|
}
|
||||||
err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
|
|
||||||
if len(cs.binlogs) != 0 && err == nil {
|
if len(cs.binlogs) != 0 && err == nil {
|
||||||
cm := &binarylog.ClientMessage{
|
cm := &binarylog.ClientMessage{
|
||||||
OnClientSide: true,
|
OnClientSide: true,
|
||||||
Message: data,
|
Message: data.Materialize(),
|
||||||
}
|
}
|
||||||
for _, binlog := range cs.binlogs {
|
for _, binlog := range cs.binlogs {
|
||||||
binlog.Log(cs.ctx, cm)
|
binlog.Log(cs.ctx, cm)
|
||||||
|
|
@ -924,6 +963,7 @@ func (cs *clientStream) RecvMsg(m any) error {
|
||||||
var recvInfo *payloadInfo
|
var recvInfo *payloadInfo
|
||||||
if len(cs.binlogs) != 0 {
|
if len(cs.binlogs) != 0 {
|
||||||
recvInfo = &payloadInfo{}
|
recvInfo = &payloadInfo{}
|
||||||
|
defer recvInfo.free()
|
||||||
}
|
}
|
||||||
err := cs.withRetry(func(a *csAttempt) error {
|
err := cs.withRetry(func(a *csAttempt) error {
|
||||||
return a.recvMsg(m, recvInfo)
|
return a.recvMsg(m, recvInfo)
|
||||||
|
|
@ -931,7 +971,7 @@ func (cs *clientStream) RecvMsg(m any) error {
|
||||||
if len(cs.binlogs) != 0 && err == nil {
|
if len(cs.binlogs) != 0 && err == nil {
|
||||||
sm := &binarylog.ServerMessage{
|
sm := &binarylog.ServerMessage{
|
||||||
OnClientSide: true,
|
OnClientSide: true,
|
||||||
Message: recvInfo.uncompressedBytes,
|
Message: recvInfo.uncompressedBytes.Materialize(),
|
||||||
}
|
}
|
||||||
for _, binlog := range cs.binlogs {
|
for _, binlog := range cs.binlogs {
|
||||||
binlog.Log(cs.ctx, sm)
|
binlog.Log(cs.ctx, sm)
|
||||||
|
|
@ -958,7 +998,7 @@ func (cs *clientStream) CloseSend() error {
|
||||||
// RecvMsg. This also matches historical behavior.
|
// RecvMsg. This also matches historical behavior.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) })
|
cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) })
|
||||||
if len(cs.binlogs) != 0 {
|
if len(cs.binlogs) != 0 {
|
||||||
chc := &binarylog.ClientHalfClose{
|
chc := &binarylog.ClientHalfClose{
|
||||||
OnClientSide: true,
|
OnClientSide: true,
|
||||||
|
|
@ -1034,7 +1074,7 @@ func (cs *clientStream) finish(err error) {
|
||||||
cs.cancel()
|
cs.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error {
|
func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength, payloadLength int) error {
|
||||||
cs := a.cs
|
cs := a.cs
|
||||||
if a.trInfo != nil {
|
if a.trInfo != nil {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
|
|
@ -1052,8 +1092,10 @@ func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error {
|
||||||
}
|
}
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
for _, sh := range a.statsHandlers {
|
if len(a.statsHandlers) != 0 {
|
||||||
sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now()))
|
for _, sh := range a.statsHandlers {
|
||||||
|
sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channelz.IsOn() {
|
if channelz.IsOn() {
|
||||||
a.t.IncrMsgSent()
|
a.t.IncrMsgSent()
|
||||||
|
|
@ -1065,6 +1107,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
|
||||||
cs := a.cs
|
cs := a.cs
|
||||||
if len(a.statsHandlers) != 0 && payInfo == nil {
|
if len(a.statsHandlers) != 0 && payInfo == nil {
|
||||||
payInfo = &payloadInfo{}
|
payInfo = &payloadInfo{}
|
||||||
|
defer payInfo.free()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !a.decompSet {
|
if !a.decompSet {
|
||||||
|
|
@ -1102,14 +1145,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
|
||||||
}
|
}
|
||||||
for _, sh := range a.statsHandlers {
|
for _, sh := range a.statsHandlers {
|
||||||
sh.HandleRPC(a.ctx, &stats.InPayload{
|
sh.HandleRPC(a.ctx, &stats.InPayload{
|
||||||
Client: true,
|
Client: true,
|
||||||
RecvTime: time.Now(),
|
RecvTime: time.Now(),
|
||||||
Payload: m,
|
Payload: m,
|
||||||
// TODO truncate large payload.
|
|
||||||
Data: payInfo.uncompressedBytes,
|
|
||||||
WireLength: payInfo.compressedLength + headerLen,
|
WireLength: payInfo.compressedLength + headerLen,
|
||||||
CompressedLength: payInfo.compressedLength,
|
CompressedLength: payInfo.compressedLength,
|
||||||
Length: len(payInfo.uncompressedBytes),
|
Length: payInfo.uncompressedBytes.Len(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if channelz.IsOn() {
|
if channelz.IsOn() {
|
||||||
|
|
@ -1273,7 +1314,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
as.s = s
|
as.s = s
|
||||||
as.p = &parser{r: s, recvBufferPool: ac.dopts.recvBufferPool}
|
as.p = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
|
||||||
ac.incrCallsStarted()
|
ac.incrCallsStarted()
|
||||||
if desc != unaryStreamDesc {
|
if desc != unaryStreamDesc {
|
||||||
// Listen on stream context to cleanup when the stream context is
|
// Listen on stream context to cleanup when the stream context is
|
||||||
|
|
@ -1370,17 +1411,26 @@ func (as *addrConnStream) SendMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// load hdr, payload, data
|
// load hdr, payload, data
|
||||||
hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp)
|
hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.cp, as.comp, as.ac.dopts.copts.BufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
data.Free()
|
||||||
|
// only free payload if compression was made, and therefore it is a different set
|
||||||
|
// of buffers from data.
|
||||||
|
if pf.isCompressed() {
|
||||||
|
payload.Free()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// TODO(dfawley): should we be checking len(data) instead?
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
if len(payld) > *as.callInfo.maxSendMessageSize {
|
if payload.Len() > *as.callInfo.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payload.Len(), *as.callInfo.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
|
if err := as.t.Write(as.s, hdr, payload, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
|
||||||
if !as.desc.ClientStreams {
|
if !as.desc.ClientStreams {
|
||||||
// For non-client-streaming RPCs, we return nil instead of EOF on error
|
// For non-client-streaming RPCs, we return nil instead of EOF on error
|
||||||
// because the generated code requires it. finish is not called; RecvMsg()
|
// because the generated code requires it. finish is not called; RecvMsg()
|
||||||
|
|
@ -1639,18 +1689,31 @@ func (ss *serverStream) SendMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// load hdr, payload, data
|
// load hdr, payload, data
|
||||||
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
|
hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.cp, ss.comp, ss.p.bufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
data.Free()
|
||||||
|
// only free payload if compression was made, and therefore it is a different set
|
||||||
|
// of buffers from data.
|
||||||
|
if pf.isCompressed() {
|
||||||
|
payload.Free()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dataLen := data.Len()
|
||||||
|
payloadLen := payload.Len()
|
||||||
|
|
||||||
// TODO(dfawley): should we be checking len(data) instead?
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
if len(payload) > ss.maxSendMessageSize {
|
if payloadLen > ss.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, ss.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
|
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ss.binlogs) != 0 {
|
if len(ss.binlogs) != 0 {
|
||||||
if !ss.serverHeaderBinlogged {
|
if !ss.serverHeaderBinlogged {
|
||||||
h, _ := ss.s.Header()
|
h, _ := ss.s.Header()
|
||||||
|
|
@ -1663,7 +1726,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sm := &binarylog.ServerMessage{
|
sm := &binarylog.ServerMessage{
|
||||||
Message: data,
|
Message: data.Materialize(),
|
||||||
}
|
}
|
||||||
for _, binlog := range ss.binlogs {
|
for _, binlog := range ss.binlogs {
|
||||||
binlog.Log(ss.ctx, sm)
|
binlog.Log(ss.ctx, sm)
|
||||||
|
|
@ -1671,7 +1734,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
if len(ss.statsHandler) != 0 {
|
if len(ss.statsHandler) != 0 {
|
||||||
for _, sh := range ss.statsHandler {
|
for _, sh := range ss.statsHandler {
|
||||||
sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now()))
|
sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -1708,6 +1771,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
|
||||||
var payInfo *payloadInfo
|
var payInfo *payloadInfo
|
||||||
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
|
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
|
||||||
payInfo = &payloadInfo{}
|
payInfo = &payloadInfo{}
|
||||||
|
defer payInfo.free()
|
||||||
}
|
}
|
||||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
|
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
|
@ -1727,11 +1791,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
|
||||||
if len(ss.statsHandler) != 0 {
|
if len(ss.statsHandler) != 0 {
|
||||||
for _, sh := range ss.statsHandler {
|
for _, sh := range ss.statsHandler {
|
||||||
sh.HandleRPC(ss.s.Context(), &stats.InPayload{
|
sh.HandleRPC(ss.s.Context(), &stats.InPayload{
|
||||||
RecvTime: time.Now(),
|
RecvTime: time.Now(),
|
||||||
Payload: m,
|
Payload: m,
|
||||||
// TODO truncate large payload.
|
Length: payInfo.uncompressedBytes.Len(),
|
||||||
Data: payInfo.uncompressedBytes,
|
|
||||||
Length: len(payInfo.uncompressedBytes),
|
|
||||||
WireLength: payInfo.compressedLength + headerLen,
|
WireLength: payInfo.compressedLength + headerLen,
|
||||||
CompressedLength: payInfo.compressedLength,
|
CompressedLength: payInfo.compressedLength,
|
||||||
})
|
})
|
||||||
|
|
@ -1739,7 +1801,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
|
||||||
}
|
}
|
||||||
if len(ss.binlogs) != 0 {
|
if len(ss.binlogs) != 0 {
|
||||||
cm := &binarylog.ClientMessage{
|
cm := &binarylog.ClientMessage{
|
||||||
Message: payInfo.uncompressedBytes,
|
Message: payInfo.uncompressedBytes.Materialize(),
|
||||||
}
|
}
|
||||||
for _, binlog := range ss.binlogs {
|
for _, binlog := range ss.binlogs {
|
||||||
binlog.Log(ss.ctx, cm)
|
binlog.Log(ss.ctx, cm)
|
||||||
|
|
@ -1754,23 +1816,26 @@ func MethodFromServerStream(stream ServerStream) (string, bool) {
|
||||||
return Method(stream.Context())
|
return Method(stream.Context())
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareMsg returns the hdr, payload and data
|
// prepareMsg returns the hdr, payload and data using the compressors passed or
|
||||||
// using the compressors passed or using the
|
// using the passed preparedmsg. The returned boolean indicates whether
|
||||||
// passed preparedmsg
|
// compression was made and therefore whether the payload needs to be freed in
|
||||||
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) {
|
// addition to the returned data. Freeing the payload if the returned boolean is
|
||||||
|
// false can lead to undefined behavior.
|
||||||
|
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) {
|
||||||
if preparedMsg, ok := m.(*PreparedMsg); ok {
|
if preparedMsg, ok := m.(*PreparedMsg); ok {
|
||||||
return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil
|
return preparedMsg.hdr, preparedMsg.encodedData, preparedMsg.payload, preparedMsg.pf, nil
|
||||||
}
|
}
|
||||||
// The input interface is not a prepared msg.
|
// The input interface is not a prepared msg.
|
||||||
// Marshal and Compress the data at this point
|
// Marshal and Compress the data at this point
|
||||||
data, err = encode(codec, m)
|
data, err = encode(codec, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
compData, err := compress(data, cp, comp)
|
compData, pf, err := compress(data, cp, comp, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
data.Free()
|
||||||
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
hdr, payload = msgHeader(data, compData)
|
hdr, payload = msgHeader(data, compData, pf)
|
||||||
return hdr, payload, data, nil
|
return hdr, data, payload, pf, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -158,4 +158,7 @@ func (s) TestCancelWhileRecvingWithCompression(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := ss.CC.Close(); err != nil {
|
||||||
|
t.Fatalf("Close failed with %v, want nil", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -455,7 +455,7 @@ func (s) TestRetryStreaming(t *testing.T) {
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
func() {
|
func() {
|
||||||
serverOpIter = 0
|
serverOpIter = 0
|
||||||
serverOps = tc.serverOps
|
serverOps = tc.serverOps
|
||||||
|
|
@ -464,9 +464,9 @@ func (s) TestRetryStreaming(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v: Error while creating stream: %v", tc.desc, err)
|
t.Fatalf("%v: Error while creating stream: %v", tc.desc, err)
|
||||||
}
|
}
|
||||||
for _, op := range tc.clientOps {
|
for j, op := range tc.clientOps {
|
||||||
if err := op(stream); err != nil {
|
if err := op(stream); err != nil {
|
||||||
t.Errorf("%v: %v", tc.desc, err)
|
t.Errorf("%d %d %v: %v", i, j, tc.desc, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue