From 2bb318258959db281674bc6fd67b5167b7ff0d65 Mon Sep 17 00:00:00 2001 From: dfawley Date: Thu, 20 Jul 2017 14:09:45 -0700 Subject: [PATCH] Fix bufconn.Close to not be blocking. (#1377) --- test/bufconn/bufconn.go | 40 +++++++++++++++++++----------------- test/bufconn/bufconn_test.go | 32 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/test/bufconn/bufconn.go b/test/bufconn/bufconn.go index 7d244d5c9..bc0ab839f 100644 --- a/test/bufconn/bufconn.go +++ b/test/bufconn/bufconn.go @@ -31,10 +31,10 @@ import ( // Listener implements a net.Listener that creates local, buffered net.Conns // via its Accept and Dial method. type Listener struct { - mu sync.Mutex - sz int - ch chan net.Conn - closed bool + mu sync.Mutex + sz int + ch chan net.Conn + done chan struct{} } var errClosed = fmt.Errorf("Closed") @@ -42,28 +42,31 @@ var errClosed = fmt.Errorf("Closed") // Listen returns a Listener that can only be contacted by its own Dialers and // creates buffered connections between the two. func Listen(sz int) *Listener { - return &Listener{sz: sz, ch: make(chan net.Conn)} + return &Listener{sz: sz, ch: make(chan net.Conn), done: make(chan struct{})} } // Accept blocks until Dial is called, then returns a net.Conn for the server // half of the connection. func (l *Listener) Accept() (net.Conn, error) { - c := <-l.ch - if c == nil { + select { + case <-l.done: return nil, errClosed + case c := <-l.ch: + return c, nil } - return c, nil } // Close stops the listener. func (l *Listener) Close() error { l.mu.Lock() defer l.mu.Unlock() - if l.closed { - return nil + select { + case <-l.done: + // Already closed. + break + default: + close(l.done) } - l.closed = true - close(l.ch) return nil } @@ -74,14 +77,13 @@ func (l *Listener) Addr() net.Addr { return addr{} } // providing it the server half of the connection, and returns the client half // of the connection. func (l *Listener) Dial() (net.Conn, error) { - l.mu.Lock() - defer l.mu.Unlock() - if l.closed { - return nil, errClosed - } p1, p2 := newPipe(l.sz), newPipe(l.sz) - l.ch <- &conn{p1, p2} - return &conn{p2, p1}, nil + select { + case <-l.done: + return nil, errClosed + case l.ch <- &conn{p1, p2}: + return &conn{p2, p1}, nil + } } type pipe struct { diff --git a/test/bufconn/bufconn_test.go b/test/bufconn/bufconn_test.go index 7d7b9207a..0f7bc2227 100644 --- a/test/bufconn/bufconn_test.go +++ b/test/bufconn/bufconn_test.go @@ -115,3 +115,35 @@ func TestListener(t *testing.T) { t.Fatalf(err.Error()) } } + +func TestCloseWhileDialing(t *testing.T) { + l := Listen(7) + var c net.Conn + var err error + done := make(chan struct{}) + go func() { + c, err = l.Dial() + close(done) + }() + l.Close() + <-done + if c != nil || err != errClosed { + t.Fatalf("c, err = %v, %v; want nil, %v", c, err, errClosed) + } +} + +func TestCloseWhileAccepting(t *testing.T) { + l := Listen(7) + var c net.Conn + var err error + done := make(chan struct{}) + go func() { + c, err = l.Accept() + close(done) + }() + l.Close() + <-done + if c != nil || err != errClosed { + t.Fatalf("c, err = %v, %v; want nil, %v", c, err, errClosed) + } +}