mirror of https://github.com/grpc/grpc-go.git
server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703)
This commit is contained in:
parent
313861efe5
commit
f2180b4d54
|
|
@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
type ifNop interface {
|
||||
nop()
|
||||
}
|
||||
|
||||
type alwaysNop struct{}
|
||||
|
||||
func (alwaysNop) nop() {}
|
||||
|
||||
type concreteNop struct {
|
||||
isNop atomic.Bool
|
||||
i int
|
||||
}
|
||||
|
||||
func (c *concreteNop) nop() {
|
||||
if c.isNop.Load() {
|
||||
return
|
||||
}
|
||||
c.i++
|
||||
}
|
||||
|
||||
func BenchmarkInterfaceNop(b *testing.B) {
|
||||
n := ifNop(alwaysNop{})
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
n.nop()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkConcreteNop(b *testing.B) {
|
||||
n := &concreteNop{}
|
||||
n.isNop.Store(true)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
n.nop()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
|||
ID: http2.SettingMaxFrameSize,
|
||||
Val: http2MaxFrameLen,
|
||||
}}
|
||||
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
|
||||
// permitted in the HTTP2 spec.
|
||||
maxStreams := config.MaxStreams
|
||||
if maxStreams == 0 {
|
||||
maxStreams = math.MaxUint32
|
||||
} else {
|
||||
if config.MaxStreams != math.MaxUint32 {
|
||||
isettings = append(isettings, http2.Setting{
|
||||
ID: http2.SettingMaxConcurrentStreams,
|
||||
Val: maxStreams,
|
||||
Val: config.MaxStreams,
|
||||
})
|
||||
}
|
||||
dynamicWindow := true
|
||||
|
|
@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
|||
framer: framer,
|
||||
readerDone: make(chan struct{}),
|
||||
writerDone: make(chan struct{}),
|
||||
maxStreams: maxStreams,
|
||||
maxStreams: config.MaxStreams,
|
||||
inTapHandle: config.InTapHandle,
|
||||
fc: &trInFlow{limit: uint32(icwz)},
|
||||
state: reachable,
|
||||
|
|
|
|||
|
|
@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
|
|||
return
|
||||
}
|
||||
rawConn := conn
|
||||
if serverConfig.MaxStreams == 0 {
|
||||
serverConfig.MaxStreams = math.MaxUint32
|
||||
}
|
||||
transport, err := NewServerTransport(conn, serverConfig)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -425,8 +428,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
|
|||
return server
|
||||
}
|
||||
|
||||
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
|
||||
return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
|
||||
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
|
||||
return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
|
||||
}
|
||||
|
||||
func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
|
||||
|
|
@ -521,7 +524,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
|
|||
|
||||
// Tests that when streamID > MaxStreamId, the current client transport drains.
|
||||
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
defer server.stop()
|
||||
callHdr := &CallHdr{
|
||||
|
|
@ -566,7 +569,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestClientSendAndReceive(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -606,7 +609,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestClientErrorNotify(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
go server.stop()
|
||||
// ct.reader should detect the error and activate ct.Error().
|
||||
|
|
@ -640,7 +643,7 @@ func performOneRPC(ct ClientTransport) {
|
|||
}
|
||||
|
||||
func (s) TestClientMix(t *testing.T) {
|
||||
s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
s, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
time.AfterFunc(time.Second, s.stop)
|
||||
go func(ct ClientTransport) {
|
||||
|
|
@ -654,7 +657,7 @@ func (s) TestClientMix(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestLargeMessage(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -789,7 +792,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
|
|||
// proceed until they complete naturally, while not allowing creation of new
|
||||
// streams during this window.
|
||||
func (s) TestGracefulClose(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
|
||||
server, ct, cancel := setUp(t, 0, pingpong)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
// Stop the server's listener to make the server's goroutines terminate
|
||||
|
|
@ -855,7 +858,7 @@ func (s) TestGracefulClose(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestLargeMessageSuspension(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
|
||||
server, ct, cancel := setUp(t, 0, suspended)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -963,7 +966,7 @@ func (s) TestMaxStreams(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
|
||||
server, ct, cancel := setUp(t, 0, suspended)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -1435,7 +1438,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
|
|||
var encodingTestStatus = status.New(codes.Internal, "\n")
|
||||
|
||||
func (s) TestEncodingRequiredStatus(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
|
||||
server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -1463,7 +1466,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestInvalidHeaderField(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
|
||||
server, ct, cancel := setUp(t, 0, invalidHeaderField)
|
||||
defer cancel()
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
|
|
@ -1485,7 +1488,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
|
|||
}
|
||||
|
||||
func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
|
||||
server, ct, cancel := setUp(t, 0, invalidHeaderField)
|
||||
defer cancel()
|
||||
defer server.stop()
|
||||
defer ct.Close(fmt.Errorf("closed manually by test"))
|
||||
|
|
@ -2153,7 +2156,7 @@ func (s) TestPingPong1MB(t *testing.T) {
|
|||
|
||||
// This is a stress-test of flow control logic.
|
||||
func runPingPongTest(t *testing.T, msgSize int) {
|
||||
server, client, cancel := setUp(t, 0, 0, pingpong)
|
||||
server, client, cancel := setUp(t, 0, pingpong)
|
||||
defer cancel()
|
||||
defer server.stop()
|
||||
defer client.Close(fmt.Errorf("closed manually by test"))
|
||||
|
|
@ -2235,7 +2238,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, ct, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
defer ct.Close(fmt.Errorf("closed manually by test"))
|
||||
defer server.stop()
|
||||
|
|
@ -2594,7 +2597,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
|
|||
|
||||
func (s) TestPeerSetInServerContext(t *testing.T) {
|
||||
// create client and server transports.
|
||||
server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
|
||||
server, client, cancel := setUp(t, 0, normal)
|
||||
defer cancel()
|
||||
defer server.stop()
|
||||
defer client.Close(fmt.Errorf("closed manually by test"))
|
||||
|
|
|
|||
71
server.go
71
server.go
|
|
@ -115,12 +115,6 @@ type serviceInfo struct {
|
|||
mdata any
|
||||
}
|
||||
|
||||
type serverWorkerData struct {
|
||||
st transport.ServerTransport
|
||||
wg *sync.WaitGroup
|
||||
stream *transport.Stream
|
||||
}
|
||||
|
||||
// Server is a gRPC server to serve RPC requests.
|
||||
type Server struct {
|
||||
opts serverOptions
|
||||
|
|
@ -145,7 +139,7 @@ type Server struct {
|
|||
channelzID *channelz.Identifier
|
||||
czData *channelzData
|
||||
|
||||
serverWorkerChannel chan *serverWorkerData
|
||||
serverWorkerChannel chan func()
|
||||
}
|
||||
|
||||
type serverOptions struct {
|
||||
|
|
@ -179,6 +173,7 @@ type serverOptions struct {
|
|||
}
|
||||
|
||||
var defaultServerOptions = serverOptions{
|
||||
maxConcurrentStreams: math.MaxUint32,
|
||||
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
|
||||
maxSendMessageSize: defaultServerMaxSendMessageSize,
|
||||
connectionTimeout: 120 * time.Second,
|
||||
|
|
@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
|
|||
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
|
||||
// of concurrent streams to each ServerTransport.
|
||||
func MaxConcurrentStreams(n uint32) ServerOption {
|
||||
if n == 0 {
|
||||
n = math.MaxUint32
|
||||
}
|
||||
return newFuncServerOption(func(o *serverOptions) {
|
||||
o.maxConcurrentStreams = n
|
||||
})
|
||||
|
|
@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
|
|||
// [1] https://github.com/golang/go/issues/18138
|
||||
func (s *Server) serverWorker() {
|
||||
for completed := 0; completed < serverWorkerResetThreshold; completed++ {
|
||||
data, ok := <-s.serverWorkerChannel
|
||||
f, ok := <-s.serverWorkerChannel
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.handleSingleStream(data)
|
||||
f()
|
||||
}
|
||||
go s.serverWorker()
|
||||
}
|
||||
|
||||
func (s *Server) handleSingleStream(data *serverWorkerData) {
|
||||
defer data.wg.Done()
|
||||
s.handleStream(data.st, data.stream)
|
||||
}
|
||||
|
||||
// initServerWorkers creates worker goroutines and a channel to process incoming
|
||||
// connections to reduce the time spent overall on runtime.morestack.
|
||||
func (s *Server) initServerWorkers() {
|
||||
s.serverWorkerChannel = make(chan *serverWorkerData)
|
||||
s.serverWorkerChannel = make(chan func())
|
||||
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
|
||||
go s.serverWorker()
|
||||
}
|
||||
|
|
@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
|
|||
defer st.Close(errors.New("finished serving streams for the server transport"))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
|
||||
st.HandleStreams(func(stream *transport.Stream) {
|
||||
wg.Add(1)
|
||||
|
||||
streamQuota.acquire()
|
||||
f := func() {
|
||||
defer streamQuota.release()
|
||||
defer wg.Done()
|
||||
s.handleStream(st, stream)
|
||||
}
|
||||
|
||||
if s.opts.numServerWorkers > 0 {
|
||||
data := &serverWorkerData{st: st, wg: &wg, stream: stream}
|
||||
select {
|
||||
case s.serverWorkerChannel <- data:
|
||||
case s.serverWorkerChannel <- f:
|
||||
return
|
||||
default:
|
||||
// If all stream workers are busy, fallback to the default code path.
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.handleStream(st, stream)
|
||||
}()
|
||||
go f()
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
@ -2077,3 +2075,34 @@ func validateSendCompressor(name, clientCompressors string) error {
|
|||
}
|
||||
return fmt.Errorf("client does not support compressor %q", name)
|
||||
}
|
||||
|
||||
// atomicSemaphore implements a blocking, counting semaphore. acquire should be
|
||||
// called synchronously; release may be called asynchronously.
|
||||
type atomicSemaphore struct {
|
||||
n atomic.Int64
|
||||
wait chan struct{}
|
||||
}
|
||||
|
||||
func (q *atomicSemaphore) acquire() {
|
||||
if q.n.Add(-1) < 0 {
|
||||
// We ran out of quota. Block until a release happens.
|
||||
<-q.wait
|
||||
}
|
||||
}
|
||||
|
||||
func (q *atomicSemaphore) release() {
|
||||
// N.B. the "<= 0" check below should allow for this to work with multiple
|
||||
// concurrent calls to acquire, but also note that with synchronous calls to
|
||||
// acquire, as our system does, n will never be less than -1. There are
|
||||
// fairness issues (queuing) to consider if this was to be generalized.
|
||||
if q.n.Add(1) <= 0 {
|
||||
// An acquire was waiting on us. Unblock it.
|
||||
q.wait <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func newHandlerQuota(n uint32) *atomicSemaphore {
|
||||
a := &atomicSemaphore{wait: make(chan struct{}, 1)}
|
||||
a.n.Store(int64(n))
|
||||
return a
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2023 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
|
||||
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||
)
|
||||
|
||||
// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server
|
||||
// handlers are active at one time.
|
||||
func (s) TestServer_MaxHandlers(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
blockCalls := grpcsync.NewEvent()
|
||||
|
||||
// This stub server does not properly respect the stream context, so it will
|
||||
// not exit when the context is canceled.
|
||||
ss := stubserver.StubServer{
|
||||
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
|
||||
started <- struct{}{}
|
||||
<-blockCalls.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil {
|
||||
t.Fatal("Error starting server:", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Start one RPC to the server.
|
||||
ctx1, cancel1 := context.WithCancel(ctx)
|
||||
_, err := ss.Client.FullDuplexCall(ctx1)
|
||||
if err != nil {
|
||||
t.Fatal("Error staring call:", err)
|
||||
}
|
||||
|
||||
// Wait for the handler to be invoked.
|
||||
select {
|
||||
case <-started:
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("Timed out waiting for RPC to start on server.")
|
||||
}
|
||||
|
||||
// Cancel it on the client. The server handler will still be running.
|
||||
cancel1()
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
defer cancel2()
|
||||
s, err := ss.Client.FullDuplexCall(ctx2)
|
||||
if err != nil {
|
||||
t.Fatal("Error staring call:", err)
|
||||
}
|
||||
|
||||
// After 100ms, allow the first call to unblock. That should allow the
|
||||
// second RPC to run and finish.
|
||||
select {
|
||||
case <-started:
|
||||
blockCalls.Fire()
|
||||
t.Fatalf("RPC started unexpectedly.")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
blockCalls.Fire()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("Timed out waiting for second RPC to start on server.")
|
||||
}
|
||||
if _, err := s.Recv(); err != io.EOF {
|
||||
t.Fatal("Received unexpected RPC error:", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue