mirror of https://github.com/grpc/grpc-go.git
internal/transport: Wait for server goroutines to exit during shutdown in test (#8306)
This commit is contained in:
parent
aaabd60df2
commit
6995ef2ab6
|
@ -320,21 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
lis net.Listener
|
lis net.Listener
|
||||||
port string
|
port string
|
||||||
startedErr chan error // error (or nil) with server start value
|
startedErr chan error // error (or nil) with server start value
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
conns map[ServerTransport]net.Conn
|
conns map[ServerTransport]net.Conn
|
||||||
h *testStreamHandler
|
h *testStreamHandler
|
||||||
ready chan struct{}
|
ready chan struct{}
|
||||||
channelz *channelz.Server
|
channelz *channelz.Server
|
||||||
|
servingTasksDone chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServer() *server {
|
func newTestServer() *server {
|
||||||
return &server{
|
return &server{
|
||||||
startedErr: make(chan error, 1),
|
startedErr: make(chan error, 1),
|
||||||
ready: make(chan struct{}),
|
ready: make(chan struct{}),
|
||||||
channelz: channelz.RegisterServer("test server"),
|
servingTasksDone: make(chan struct{}),
|
||||||
|
channelz: channelz.RegisterServer("test server"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
|
||||||
s.port = p
|
s.port = p
|
||||||
s.conns = make(map[ServerTransport]net.Conn)
|
s.conns = make(map[ServerTransport]net.Conn)
|
||||||
s.startedErr <- nil
|
s.startedErr <- nil
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
defer func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(s.servingTasksDone)
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := s.lis.Accept()
|
conn, err := s.lis.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -383,40 +391,89 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
wg.Add(1)
|
||||||
switch ht {
|
switch ht {
|
||||||
case notifyCall:
|
case notifyCall:
|
||||||
go transport.HandleStreams(ctx, h.handleStreamAndNotify)
|
go func() {
|
||||||
|
transport.HandleStreams(ctx, h.handleStreamAndNotify)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case suspended:
|
case suspended:
|
||||||
go transport.HandleStreams(ctx, func(*ServerStream) {})
|
go func() {
|
||||||
|
transport.HandleStreams(ctx, func(*ServerStream) {})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case misbehaved:
|
case misbehaved:
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStreamMisbehave(t, s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStreamMisbehave(t, s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case encodingRequiredStatus:
|
case encodingRequiredStatus:
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStreamEncodingRequiredStatus(s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStreamEncodingRequiredStatus(s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case invalidHeaderField:
|
case invalidHeaderField:
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStreamInvalidHeaderField(s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStreamInvalidHeaderField(s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case delayRead:
|
case delayRead:
|
||||||
h.notify = make(chan struct{})
|
h.notify = make(chan struct{})
|
||||||
h.getNotified = make(chan struct{})
|
h.getNotified = make(chan struct{})
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
close(s.ready)
|
close(s.ready)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStreamDelayRead(t, s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStreamDelayRead(t, s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
case pingpong:
|
case pingpong:
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStreamPingPong(t, s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStreamPingPong(t, s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
default:
|
default:
|
||||||
go transport.HandleStreams(ctx, func(s *ServerStream) {
|
go func() {
|
||||||
go h.handleStream(t, s)
|
transport.HandleStreams(ctx, func(s *ServerStream) {
|
||||||
})
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
h.handleStream(t, s)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -440,6 +497,7 @@ func (s *server) stop() {
|
||||||
}
|
}
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
<-s.servingTasksDone
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) addr() string {
|
func (s *server) addr() string {
|
||||||
|
@ -2254,11 +2312,11 @@ func (s) TestPingPong1B(t *testing.T) {
|
||||||
runPingPongTest(t, 1)
|
runPingPongTest(t, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPingPong1KB(t *testing.T) {
|
func (s) TestPingPong1KB(t *testing.T) {
|
||||||
runPingPongTest(t, 1024)
|
runPingPongTest(t, 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPingPong64KB(t *testing.T) {
|
func (s) TestPingPong64KB(t *testing.T) {
|
||||||
runPingPongTest(t, 65536)
|
runPingPongTest(t, 65536)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue