From a1ed1ec1fab8a4e54f68a263ecbafc492bf17e38 Mon Sep 17 00:00:00 2001 From: Solomon Hykes Date: Fri, 9 May 2014 20:55:39 -0700 Subject: [PATCH] inmem: switch to a Sender/Receiver/Pipe interface similar to io.Pipe Signed-off-by: Solomon Hykes --- inmem/inmem.go | 243 +++++++++++++++++++++++++++++++------------- inmem/inmem_test.go | 27 +++-- 2 files changed, 190 insertions(+), 80 deletions(-) diff --git a/inmem/inmem.go b/inmem/inmem.go index 67302e70d6..12e2791199 100644 --- a/inmem/inmem.go +++ b/inmem/inmem.go @@ -1,88 +1,193 @@ package inmem import ( - "fmt" "io" "sync" ) -type Handle interface { - Send(msg *Message, mode int) (Handle, error) - Receive(mode int) (*Message, Handle, error) - CloseWrite() error +type Sender interface { + Send(msg *Message, mode int) (Receiver, Sender, error) + Close() error +} + +type Receiver interface { + Receive(mode int) (*Message, Receiver, Sender, error) + Close() error +} + +type Message struct { + Name string + Args []string + Data string } const ( - R = 1 << (32 - 1 - iota) + R = 1 << (32 - 1 - iota) W ) -type Message struct { - Name string - Args []string - Data string +func Pipe() (*PipeReceiver, *PipeSender) { + p := new(pipe) + p.rwait.L = &p.l + p.wwait.L = &p.l + r := &PipeReceiver{p} + w := &PipeSender{p} + return r, w } - -func Pipe() (Handle, Handle) { - red := make(chan *pipeMessage) - black := make(chan *pipeMessage) - return &PipeHandle{r: red, w: black}, &PipeHandle{r: black, w: red} -} - -type PipeHandle struct { - sync.RWMutex - r chan *pipeMessage - w chan *pipeMessage -} - -func (h *PipeHandle) Send(msg *Message, mode int) (Handle, error) { - h.RLock() - defer h.RUnlock() - if h.w == nil { - return nil, fmt.Errorf("closed pipe") - } - var ( - rh Handle - lh Handle - ) - if mode&(R|W) != 0 { - rh, lh = Pipe() - if mode&W == 0 { - lh.CloseWrite() - } - } - h.w <-&pipeMessage{msg, rh} - return lh, nil -} - -func (h *PipeHandle) Receive(mode int) (*Message, Handle, error) { - pmsg, ok := <-h.r - if !ok { - return nil, nil, io.EOF - } - var handle Handle - if pmsg.handle != nil && mode&W == 0 { - pmsg.handle.CloseWrite() - } - if mode&(R|W) != 0 { - handle = pmsg.handle - } - return pmsg.payload, handle, nil -} - -func (h *PipeHandle) CloseWrite() error { - h.Lock() - defer h.Unlock() - if h.w == nil { - return fmt.Errorf("already closed") - } - close(h.w) - h.w = nil - return nil +type pipe struct { + ch chan *pipeMessage + rwait sync.Cond + wwait sync.Cond + l sync.Mutex + rl sync.Mutex + wl sync.Mutex + rerr error // if reader closed, error to give writes + werr error // if writer closed, error to give reads + pmsg *pipeMessage } type pipeMessage struct { - payload *Message - handle Handle + msg *Message + out *PipeSender + in *PipeReceiver +} + +func (p *pipe) send(msg *Message, mode int) (in Receiver, out Sender, err error) { + // Prepare the message + pmsg := &pipeMessage{msg: msg} + if mode&R != 0 { + in, pmsg.out = Pipe() + defer func() { + if err != nil { + in.Close() + in = nil + pmsg.out.Close() + } + }() + } + if mode&W != 0 { + pmsg.in, out = Pipe() + defer func() { + if err != nil { + out.Close() + out = nil + pmsg.in.Close() + } + }() + } + // One writer at a time. + p.wl.Lock() + defer p.wl.Unlock() + + p.l.Lock() + defer p.l.Unlock() + p.pmsg = pmsg + p.rwait.Signal() + for { + if p.pmsg == nil { + break + } + if p.rerr != nil { + err = p.rerr + break + } + if p.werr != nil { + err = io.ErrClosedPipe + } + p.wwait.Wait() + } + p.pmsg = nil // in case of rerr or werr + return +} + +func (p *pipe) receive(mode int) (msg *Message, in Receiver, out Sender, err error) { + p.rl.Lock() + defer p.rl.Unlock() + + p.l.Lock() + defer p.l.Unlock() + for { + if p.rerr != nil { + return nil, nil, nil, io.ErrClosedPipe + } + if p.pmsg != nil { + break + } + if p.werr != nil { + return nil, nil, nil, p.werr + } + p.rwait.Wait() + } + pmsg := p.pmsg + if pmsg.out != nil && mode&W == 0 { + pmsg.out.Close() + } + if pmsg.in != nil && mode&R == 0 { + pmsg.in.Close() + } + p.pmsg = nil + msg = pmsg.msg + p.wwait.Signal() + return +} + +func (p *pipe) rclose(err error) { + if err == nil { + err = io.ErrClosedPipe + } + p.l.Lock() + defer p.l.Unlock() + p.rerr = err + p.rwait.Signal() + p.wwait.Signal() +} + +func (p *pipe) wclose(err error) { + if err == nil { + err = io.EOF + } + p.l.Lock() + defer p.l.Unlock() + p.werr = err + p.rwait.Signal() + p.wwait.Signal() +} + +// PipeReceiver + +type PipeReceiver struct { + p *pipe +} + +func (r *PipeReceiver) Receive(mode int) (*Message, Receiver, Sender, error) { + return r.p.receive(mode) +} + +func (r *PipeReceiver) Close() error { + return r.CloseWithError(nil) +} + +func (r *PipeReceiver) CloseWithError(err error) error { + r.p.rclose(err) + return nil +} + +// PipeSender + +type PipeSender struct { + p *pipe +} + +func (w *PipeSender) Send(msg *Message, mode int) (Receiver, Sender, error) { + return w.p.send(msg, mode) +} + +func (w *PipeSender) Close() error { + return w.CloseWithError(nil) +} + +func (w *PipeSender) CloseWithError(err error) error { + w.p.wclose(err) + return nil } diff --git a/inmem/inmem_test.go b/inmem/inmem_test.go index 4bac748dd4..03267eb4f1 100644 --- a/inmem/inmem_test.go +++ b/inmem/inmem_test.go @@ -6,13 +6,11 @@ import ( ) func TestSimpleSend(t *testing.T) { - a, b := Pipe() - defer a.CloseWrite() - defer b.CloseWrite() + r, w := Pipe() onTimeout := time.After(100 * time.Millisecond) onRcv := make(chan bool) go func() { - msg, h, err := b.Receive(0) + msg, in, out, err := r.Receive(0) if err != nil { t.Fatal(err) } @@ -25,20 +23,27 @@ func TestSimpleSend(t *testing.T) { if len(msg.Args) != 0 { t.Fatalf("%#v", *msg) } - if h != nil { - t.Fatalf("%#v", h) + if in != nil { + t.Fatalf("%#v", in) + } + if out != nil { + t.Fatalf("%#v", out) } close(onRcv) }() - h, err := a.Send(&Message{Name:"print", Data: "hello world"}, 0) + in, out, err := w.Send(&Message{Name: "print", Data: "hello world"}, 0) if err != nil { t.Fatal(err) } - if h != nil { - t.Fatalf("%#v", h) + if in != nil { + t.Fatalf("%#v", in) + } + if out != nil { + t.Fatalf("%#v", out) } select { - case <-onTimeout: t.Fatalf("timeout") - case <-onRcv: + case <-onTimeout: + t.Fatalf("timeout") + case <-onRcv: } }