internal/transport: Wait for server goroutines to exit during shutdown in test (#8306)

This commit is contained in:
Arjan Singh Bal 2025-05-21 09:24:41 +05:30 committed by GitHub
parent aaabd60df2
commit 6995ef2ab6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 91 additions and 33 deletions

View File

@ -320,21 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream)
}
type server struct {
lis net.Listener
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelz *channelz.Server
lis net.Listener
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelz *channelz.Server
servingTasksDone chan struct{}
}
func newTestServer() *server {
return &server{
startedErr: make(chan error, 1),
ready: make(chan struct{}),
channelz: channelz.RegisterServer("test server"),
startedErr: make(chan error, 1),
ready: make(chan struct{}),
servingTasksDone: make(chan struct{}),
channelz: channelz.RegisterServer("test server"),
}
}
@ -358,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.port = p
s.conns = make(map[ServerTransport]net.Conn)
s.startedErr <- nil
wg := sync.WaitGroup{}
defer func() {
wg.Wait()
close(s.servingTasksDone)
}()
for {
conn, err := s.lis.Accept()
if err != nil {
@ -383,40 +391,89 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
wg.Add(1)
switch ht {
case notifyCall:
go transport.HandleStreams(ctx, h.handleStreamAndNotify)
go func() {
transport.HandleStreams(ctx, h.handleStreamAndNotify)
wg.Done()
}()
case suspended:
go transport.HandleStreams(ctx, func(*ServerStream) {})
go func() {
transport.HandleStreams(ctx, func(*ServerStream) {})
wg.Done()
}()
case misbehaved:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamMisbehave(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStreamMisbehave(t, s)
wg.Done()
}()
})
wg.Done()
}()
case encodingRequiredStatus:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamEncodingRequiredStatus(s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStreamEncodingRequiredStatus(s)
wg.Done()
}()
})
wg.Done()
}()
case invalidHeaderField:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamInvalidHeaderField(s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStreamInvalidHeaderField(s)
wg.Done()
}()
})
wg.Done()
}()
case delayRead:
h.notify = make(chan struct{})
h.getNotified = make(chan struct{})
s.mu.Lock()
close(s.ready)
s.mu.Unlock()
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamDelayRead(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStreamDelayRead(t, s)
wg.Done()
}()
})
wg.Done()
}()
case pingpong:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamPingPong(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStreamPingPong(t, s)
wg.Done()
}()
})
wg.Done()
}()
default:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStream(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
wg.Add(1)
go func() {
h.handleStream(t, s)
wg.Done()
}()
})
wg.Done()
}()
}
}
}
@ -440,6 +497,7 @@ func (s *server) stop() {
}
s.conns = nil
s.mu.Unlock()
<-s.servingTasksDone
}
func (s *server) addr() string {
@ -2254,11 +2312,11 @@ func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1)
}
func TestPingPong1KB(t *testing.T) {
func (s) TestPingPong1KB(t *testing.T) {
runPingPongTest(t, 1024)
}
func TestPingPong64KB(t *testing.T) {
func (s) TestPingPong64KB(t *testing.T) {
runPingPongTest(t, 65536)
}