internal/transport: Wait for server goroutines to exit during shutdown in test (#8306)

This commit is contained in:
Arjan Singh Bal 2025-05-21 09:24:41 +05:30 committed by GitHub
parent aaabd60df2
commit 6995ef2ab6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 91 additions and 33 deletions

View File

@ -320,21 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream)
} }
type server struct { type server struct {
lis net.Listener lis net.Listener
port string port string
startedErr chan error // error (or nil) with server start value startedErr chan error // error (or nil) with server start value
mu sync.Mutex mu sync.Mutex
conns map[ServerTransport]net.Conn conns map[ServerTransport]net.Conn
h *testStreamHandler h *testStreamHandler
ready chan struct{} ready chan struct{}
channelz *channelz.Server channelz *channelz.Server
servingTasksDone chan struct{}
} }
func newTestServer() *server { func newTestServer() *server {
return &server{ return &server{
startedErr: make(chan error, 1), startedErr: make(chan error, 1),
ready: make(chan struct{}), ready: make(chan struct{}),
channelz: channelz.RegisterServer("test server"), servingTasksDone: make(chan struct{}),
channelz: channelz.RegisterServer("test server"),
} }
} }
@ -358,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.port = p s.port = p
s.conns = make(map[ServerTransport]net.Conn) s.conns = make(map[ServerTransport]net.Conn)
s.startedErr <- nil s.startedErr <- nil
wg := sync.WaitGroup{}
defer func() {
wg.Wait()
close(s.servingTasksDone)
}()
for { for {
conn, err := s.lis.Accept() conn, err := s.lis.Accept()
if err != nil { if err != nil {
@ -383,40 +391,89 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock() s.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel() defer cancel()
wg.Add(1)
switch ht { switch ht {
case notifyCall: case notifyCall:
go transport.HandleStreams(ctx, h.handleStreamAndNotify) go func() {
transport.HandleStreams(ctx, h.handleStreamAndNotify)
wg.Done()
}()
case suspended: case suspended:
go transport.HandleStreams(ctx, func(*ServerStream) {}) go func() {
transport.HandleStreams(ctx, func(*ServerStream) {})
wg.Done()
}()
case misbehaved: case misbehaved:
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStreamMisbehave(t, s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStreamMisbehave(t, s)
wg.Done()
}()
})
wg.Done()
}()
case encodingRequiredStatus: case encodingRequiredStatus:
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStreamEncodingRequiredStatus(s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStreamEncodingRequiredStatus(s)
wg.Done()
}()
})
wg.Done()
}()
case invalidHeaderField: case invalidHeaderField:
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStreamInvalidHeaderField(s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStreamInvalidHeaderField(s)
wg.Done()
}()
})
wg.Done()
}()
case delayRead: case delayRead:
h.notify = make(chan struct{}) h.notify = make(chan struct{})
h.getNotified = make(chan struct{}) h.getNotified = make(chan struct{})
s.mu.Lock() s.mu.Lock()
close(s.ready) close(s.ready)
s.mu.Unlock() s.mu.Unlock()
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStreamDelayRead(t, s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStreamDelayRead(t, s)
wg.Done()
}()
})
wg.Done()
}()
case pingpong: case pingpong:
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStreamPingPong(t, s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStreamPingPong(t, s)
wg.Done()
}()
})
wg.Done()
}()
default: default:
go transport.HandleStreams(ctx, func(s *ServerStream) { go func() {
go h.handleStream(t, s) transport.HandleStreams(ctx, func(s *ServerStream) {
}) wg.Add(1)
go func() {
h.handleStream(t, s)
wg.Done()
}()
})
wg.Done()
}()
} }
} }
} }
@ -440,6 +497,7 @@ func (s *server) stop() {
} }
s.conns = nil s.conns = nil
s.mu.Unlock() s.mu.Unlock()
<-s.servingTasksDone
} }
func (s *server) addr() string { func (s *server) addr() string {
@ -2254,11 +2312,11 @@ func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1) runPingPongTest(t, 1)
} }
func TestPingPong1KB(t *testing.T) { func (s) TestPingPong1KB(t *testing.T) {
runPingPongTest(t, 1024) runPingPongTest(t, 1024)
} }
func TestPingPong64KB(t *testing.T) { func (s) TestPingPong64KB(t *testing.T) {
runPingPongTest(t, 65536) runPingPongTest(t, 65536)
} }