server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703)

commit f2180b4d54 (parent 313861efe5)
Doug Fawley, 2023-10-10 10:51:45 -07:00, committed by GitHub
5 changed files with 210 additions and 45 deletions


@@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
}
})
}
type ifNop interface {
	nop()
}

type alwaysNop struct{}

func (alwaysNop) nop() {}

type concreteNop struct {
	isNop atomic.Bool
	i     int
}

func (c *concreteNop) nop() {
	if c.isNop.Load() {
		return
	}
	c.i++
}

func BenchmarkInterfaceNop(b *testing.B) {
	n := ifNop(alwaysNop{})
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			n.nop()
		}
	})
}

func BenchmarkConcreteNop(b *testing.B) {
	n := &concreteNop{}
	n.isNop.Store(true)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			n.nop()
		}
	})
}


@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
ID: http2.SettingMaxFrameSize,
Val: http2MaxFrameLen,
}}
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
// permitted in the HTTP2 spec.
maxStreams := config.MaxStreams
if maxStreams == 0 {
maxStreams = math.MaxUint32
} else {
if config.MaxStreams != math.MaxUint32 {
isettings = append(isettings, http2.Setting{
ID: http2.SettingMaxConcurrentStreams,
Val: maxStreams,
Val: config.MaxStreams,
})
}
dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
maxStreams: maxStreams,
maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,


@@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
return
}
rawConn := conn
if serverConfig.MaxStreams == 0 {
serverConfig.MaxStreams = math.MaxUint32
}
transport, err := NewServerTransport(conn, serverConfig)
if err != nil {
return
@@ -425,8 +428,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
return server
}
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
}
func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -521,7 +524,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
// Tests that when streamID > MaxStreamId, the current client transport drains.
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
callHdr := &CallHdr{
@@ -566,7 +569,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
}
func (s) TestClientSendAndReceive(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -606,7 +609,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
}
func (s) TestClientErrorNotify(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
go server.stop()
// ct.reader should detect the error and activate ct.Error().
@@ -640,7 +643,7 @@ func performOneRPC(ct ClientTransport) {
}
func (s) TestClientMix(t *testing.T) {
s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
s, ct, cancel := setUp(t, 0, normal)
defer cancel()
time.AfterFunc(time.Second, s.stop)
go func(ct ClientTransport) {
@@ -654,7 +657,7 @@ func (s) TestClientMix(t *testing.T) {
}
func (s) TestLargeMessage(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -789,7 +792,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
// proceed until they complete naturally, while not allowing creation of new
// streams during this window.
func (s) TestGracefulClose(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
server, ct, cancel := setUp(t, 0, pingpong)
defer cancel()
defer func() {
// Stop the server's listener to make the server's goroutines terminate
@@ -855,7 +858,7 @@ func (s) TestGracefulClose(t *testing.T) {
}
func (s) TestLargeMessageSuspension(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -963,7 +966,7 @@ func (s) TestMaxStreams(t *testing.T) {
}
func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -1435,7 +1438,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
var encodingTestStatus = status.New(codes.Internal, "\n")
func (s) TestEncodingRequiredStatus(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -1463,7 +1466,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
}
func (s) TestInvalidHeaderField(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
@@ -1485,7 +1488,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
}
func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
defer server.stop()
defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2153,7 +2156,7 @@ func (s) TestPingPong1MB(t *testing.T) {
// This is a stress-test of flow control logic.
func runPingPongTest(t *testing.T, msgSize int) {
server, client, cancel := setUp(t, 0, 0, pingpong)
server, client, cancel := setUp(t, 0, pingpong)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2235,7 +2238,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
}
}()
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer ct.Close(fmt.Errorf("closed manually by test"))
defer server.stop()
@@ -2594,7 +2597,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
func (s) TestPeerSetInServerContext(t *testing.T) {
// create client and server transports.
server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
server, client, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))


@@ -115,12 +115,6 @@ type serviceInfo struct {
mdata any
}
type serverWorkerData struct {
st transport.ServerTransport
wg *sync.WaitGroup
stream *transport.Stream
}
// Server is a gRPC server to serve RPC requests.
type Server struct {
opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
channelzID *channelz.Identifier
czData *channelzData
serverWorkerChannel chan *serverWorkerData
serverWorkerChannel chan func()
}
type serverOptions struct {
@@ -179,6 +173,7 @@ type serverOptions struct {
}
var defaultServerOptions = serverOptions{
maxConcurrentStreams: math.MaxUint32,
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
maxSendMessageSize: defaultServerMaxSendMessageSize,
connectionTimeout: 120 * time.Second,
@@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
if n == 0 {
n = math.MaxUint32
}
return newFuncServerOption(func(o *serverOptions) {
o.maxConcurrentStreams = n
})
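Note: after this change the option treats 0 as "no limit" (math.MaxUint32), so the server and its transports always see a concrete bound. A minimal usage sketch with the public API (listener address and limit value chosen arbitrarily for illustration):

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// Cap each connection (ServerTransport) at 100 concurrently running
	// stream handlers; additional streams wait until a handler finishes.
	srv := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	// Register services here before serving.
	if err := srv.Serve(lis); err != nil {
		log.Fatalf("Serve: %v", err)
	}
}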
@@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
// [1] https://github.com/golang/go/issues/18138
func (s *Server) serverWorker() {
for completed := 0; completed < serverWorkerResetThreshold; completed++ {
data, ok := <-s.serverWorkerChannel
f, ok := <-s.serverWorkerChannel
if !ok {
return
}
s.handleSingleStream(data)
f()
}
go s.serverWorker()
}
func (s *Server) handleSingleStream(data *serverWorkerData) {
defer data.wg.Done()
s.handleStream(data.st, data.stream)
}
// initServerWorkers creates worker goroutines and a channel to process incoming
// connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
s.serverWorkerChannel = make(chan *serverWorkerData)
s.serverWorkerChannel = make(chan func())
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
go s.serverWorker()
}
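The worker channel above now carries bare closures. As a mental model (illustrative names only, not the gRPC implementation, which also replaces each worker after serverWorkerResetThreshold tasks to bound stack growth), the pool is simply a fixed number of goroutines draining a channel:

// startWorkers launches n long-lived goroutines that execute queued closures.
func startWorkers(n uint32, tasks <-chan func()) {
	for i := uint32(0); i < n; i++ {
		go func() {
			for f := range tasks {
				f() // run one stream handler, then wait for the next
			}
		}()
	}
}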
@@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
defer st.Close(errors.New("finished serving streams for the server transport"))
var wg sync.WaitGroup
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(func(stream *transport.Stream) {
wg.Add(1)
streamQuota.acquire()
f := func() {
defer streamQuota.release()
defer wg.Done()
s.handleStream(st, stream)
}
if s.opts.numServerWorkers > 0 {
data := &serverWorkerData{st: st, wg: &wg, stream: stream}
select {
case s.serverWorkerChannel <- data:
case s.serverWorkerChannel <- f:
return
default:
// If all stream workers are busy, fallback to the default code path.
}
}
go func() {
defer wg.Done()
s.handleStream(st, stream)
}()
go f()
})
wg.Wait()
}
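Read together, the HandleStreams callback above implements a bounded dispatch: acquire a slot before the stream is handed off, release it when the handler returns. A simplified sketch of that shape (handler body and transport plumbing elided; with a nil worker channel the select always takes the default branch, matching the numServerWorkers == 0 path):

func dispatch(quota *atomicSemaphore, workers chan<- func(), wg *sync.WaitGroup, handle func()) {
	wg.Add(1)
	quota.acquire() // blocks the transport's reader loop once MaxConcurrentStreams handlers are running
	f := func() {
		defer quota.release()
		defer wg.Done()
		handle()
	}
	select {
	case workers <- f: // an idle worker picks up the handler
	default:
		go f() // no pool, or all workers busy: run in a fresh goroutine
	}
}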
@@ -2077,3 +2075,34 @@ func validateSendCompressor(name, clientCompressors string) error {
}
return fmt.Errorf("client does not support compressor %q", name)
}
// atomicSemaphore implements a blocking, counting semaphore. acquire should be
// called synchronously; release may be called asynchronously.
type atomicSemaphore struct {
	n    atomic.Int64
	wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
	if q.n.Add(-1) < 0 {
		// We ran out of quota. Block until a release happens.
		<-q.wait
	}
}

func (q *atomicSemaphore) release() {
	// N.B. the "<= 0" check below should allow for this to work with multiple
	// concurrent calls to acquire, but also note that with synchronous calls to
	// acquire, as our system does, n will never be less than -1. There are
	// fairness issues (queuing) to consider if this was to be generalized.
	if q.n.Add(1) <= 0 {
		// An acquire was waiting on us. Unblock it.
		q.wait <- struct{}{}
	}
}

func newHandlerQuota(n uint32) *atomicSemaphore {
	a := &atomicSemaphore{wait: make(chan struct{}, 1)}
	a.n.Store(int64(n))
	return a
}
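For intuition, a small in-package sketch of the blocking behavior (hypothetical snippet, not part of the commit): with a quota of 1, a second acquire parks on the wait channel until the first holder calls release.

func demoHandlerQuota() {
	q := newHandlerQuota(1)
	q.acquire() // first stream is admitted immediately
	second := make(chan struct{})
	go func() {
		q.acquire() // a second stream blocks here while the first handler runs
		close(second)
	}()
	q.release() // first handler finishes; the waiting acquire (if parked) is woken
	<-second    // second stream is now admitted
	q.release()
}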

server_ext_test.go (new file, 99 lines)

@@ -0,0 +1,99 @@
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package grpc_test

import (
	"context"
	"io"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/internal/grpcsync"
	"google.golang.org/grpc/internal/stubserver"

	testgrpc "google.golang.org/grpc/interop/grpc_testing"
)

// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server
// handlers are active at one time.
func (s) TestServer_MaxHandlers(t *testing.T) {
	started := make(chan struct{})
	blockCalls := grpcsync.NewEvent()

	// This stub server does not properly respect the stream context, so it will
	// not exit when the context is canceled.
	ss := stubserver.StubServer{
		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
			started <- struct{}{}
			<-blockCalls.Done()
			return nil
		},
	}
	if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil {
		t.Fatal("Error starting server:", err)
	}
	defer ss.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	// Start one RPC to the server.
	ctx1, cancel1 := context.WithCancel(ctx)
	_, err := ss.Client.FullDuplexCall(ctx1)
	if err != nil {
		t.Fatal("Error starting call:", err)
	}

	// Wait for the handler to be invoked.
	select {
	case <-started:
	case <-ctx.Done():
		t.Fatalf("Timed out waiting for RPC to start on server.")
	}

	// Cancel it on the client. The server handler will still be running.
	cancel1()

	ctx2, cancel2 := context.WithCancel(ctx)
	defer cancel2()
	s, err := ss.Client.FullDuplexCall(ctx2)
	if err != nil {
		t.Fatal("Error starting call:", err)
	}

	// After 100ms, allow the first call to unblock. That should allow the
	// second RPC to run and finish.
	select {
	case <-started:
		blockCalls.Fire()
		t.Fatalf("RPC started unexpectedly.")
	case <-time.After(100 * time.Millisecond):
		blockCalls.Fire()
	}

	select {
	case <-started:
	case <-ctx.Done():
		t.Fatalf("Timed out waiting for second RPC to start on server.")
	}

	if _, err := s.Recv(); err != io.EOF {
		t.Fatal("Received unexpected RPC error:", err)
	}
}