diff --git a/channel.go b/channel.go index 2339767..a71260b 100644 --- a/channel.go +++ b/channel.go @@ -5,101 +5,95 @@ import ( "context" "encoding/binary" "io" - "net" - "github.com/containerd/containerd/log" - "github.com/gogo/protobuf/proto" "github.com/pkg/errors" ) -const maxMessageSize = 8 << 10 // TODO(stevvooe): Cut these down, since they are pre-alloced. +const ( + messageHeaderLength = 10 + messageLengthMax = 8 << 10 +) -type channel struct { - conn net.Conn - bw *bufio.Writer - br *bufio.Reader +type messageType uint8 + +const ( + messageTypeRequest messageType = 0x1 + messageTypeResponse messageType = 0x2 +) + +// messageHeader represents the fixed-length message header of 10 bytes sent +// with every request. +type messageHeader struct { + Length uint32 // length excluding this header. b[:4] + StreamID uint32 // identifies which request stream message is a part of. b[4:8] + Type messageType // message type b[8] + Flags uint8 // reserved b[9] } -func newChannel(conn net.Conn) *channel { +func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) { + _, err := io.ReadFull(r, p[:messageHeaderLength]) + if err != nil { + return messageHeader{}, err + } + + return messageHeader{ + Length: binary.BigEndian.Uint32(p[:4]), + StreamID: binary.BigEndian.Uint32(p[4:8]), + Type: messageType(p[8]), + Flags: p[9], + }, nil +} + +func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { + binary.BigEndian.PutUint32(p[:4], mh.Length) + binary.BigEndian.PutUint32(p[4:8], mh.StreamID) + p[8] = byte(mh.Type) + p[9] = mh.Flags + + _, err := w.Write(p[:]) + return err +} + +type channel struct { + bw *bufio.Writer + br *bufio.Reader + hrbuf [messageHeaderLength]byte // avoid alloc when reading header + hwbuf [messageHeaderLength]byte +} + +func newChannel(w io.Writer, r io.Reader) *channel { return &channel{ - conn: conn, - bw: bufio.NewWriterSize(conn, maxMessageSize), - br: bufio.NewReaderSize(conn, maxMessageSize), + bw: bufio.NewWriter(w), + br: bufio.NewReader(r), } } -func (ch *channel) recv(ctx context.Context, msg interface{}) error { - defer log.G(ctx).WithField("msg", msg).Info("recv") +func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) { + mh, err := readMessageHeader(ch.hrbuf[:], ch.br) + if err != nil { + return messageHeader{}, err + } - // TODO(stevvooe): Use `bufio.Reader.Peek` here to remove this allocation. - var p [maxMessageSize]byte - n, err := readmsg(ch.br, p[:]) + if mh.Length > uint32(len(p)) { + return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p)) + } + + if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil { + return messageHeader{}, errors.Wrapf(err, "failed reading message") + } + + return mh, nil +} + +func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { + if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { + return err + } + + _, err := ch.bw.Write(p) if err != nil { return err } - switch msg := msg.(type) { - case proto.Message: - return proto.Unmarshal(p[:n], msg) - default: - return errors.Errorf("unnsupported type in channel: %#v", msg) - } -} - -func (ch *channel) send(ctx context.Context, msg interface{}) error { - log.G(ctx).WithField("msg", msg).Info("send") - var p []byte - switch msg := msg.(type) { - case proto.Message: - var err error - // TODO(stevvooe): trickiest allocation of the bunch. This will be hard - // to get rid of without using `MarshalTo` directly. - p, err = proto.Marshal(msg) - if err != nil { - return err - } - default: - return errors.Errorf("unsupported type recv from channel: %#v", msg) - } - - return writemsg(ch.bw, p) -} - -func readmsg(r *bufio.Reader, p []byte) (int, error) { - mlen, err := binary.ReadVarint(r) - if err != nil { - return 0, errors.Wrapf(err, "failed reading message size") - } - - if mlen > int64(len(p)) { - return 0, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mlen, len(p)) - } - - nn, err := io.ReadFull(r, p[:mlen]) - if err != nil { - return 0, errors.Wrapf(err, "failed reading message size") - } - - if int64(nn) != mlen { - return 0, errors.Errorf("mismatched read against message length %v != %v", nn, mlen) - } - - return int(mlen), nil -} - -func writemsg(w *bufio.Writer, p []byte) error { - var ( - mlenp [binary.MaxVarintLen64]byte - n = binary.PutVarint(mlenp[:], int64(len(p))) - ) - - if _, err := w.Write(mlenp[:n]); err != nil { - return errors.Wrapf(err, "failed writing message header") - } - - if _, err := w.Write(p); err != nil { - return errors.Wrapf(err, "failed writing message") - } - - return w.Flush() + return ch.bw.Flush() } diff --git a/channel_test.go b/channel_test.go index 8bff326..58ac82b 100644 --- a/channel_test.go +++ b/channel_test.go @@ -3,6 +3,7 @@ package ttrpc import ( "bufio" "bytes" + "context" "io" "reflect" "testing" @@ -12,8 +13,10 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( - channel bytes.Buffer - w = bufio.NewWriter(&channel) + ctx = context.Background() + buffer bytes.Buffer + w = bufio.NewWriter(&buffer) + ch = newChannel(w, nil) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -21,20 +24,21 @@ func TestReadWriteMessage(t *testing.T) { } ) - for _, msg := range messages { - if err := writemsg(w, msg); err != nil { + for i, msg := range messages { + if err := ch.send(ctx, uint32(i), 1, msg); err != nil { t.Fatal(err) } } var ( received [][]byte - r = bufio.NewReader(bytes.NewReader(channel.Bytes())) + r = bufio.NewReader(bytes.NewReader(buffer.Bytes())) + rch = newChannel(nil, r) ) for { var p [4096]byte - n, err := readmsg(r, p[:]) + mh, err := rch.recv(ctx, p[:]) if err != nil { if errors.Cause(err) != io.EOF { t.Fatal(err) @@ -42,7 +46,7 @@ func TestReadWriteMessage(t *testing.T) { break } - received = append(received, p[:n]) + received = append(received, p[:mh.Length]) } if !reflect.DeepEqual(received, messages) { @@ -52,21 +56,25 @@ func TestReadWriteMessage(t *testing.T) { func TestSmallBuffer(t *testing.T) { var ( - channel bytes.Buffer - w = bufio.NewWriter(&channel) - msg = []byte("a message of massive length") + ctx = context.Background() + buffer bytes.Buffer + w = bufio.NewWriter(&buffer) + ch = newChannel(w, nil) + msg = []byte("a message of massive length") ) - if err := writemsg(w, msg); err != nil { + if err := ch.send(ctx, 1, 1, msg); err != nil { t.Fatal(err) } // now, read it off the channel with a small buffer var ( - p = make([]byte, len(msg)-1) - r = bufio.NewReader(bytes.NewReader(channel.Bytes())) + p = make([]byte, len(msg)-1) + r = bufio.NewReader(bytes.NewReader(buffer.Bytes())) + rch = newChannel(nil, r) ) - _, err := readmsg(r, p[:]) + + _, err := rch.recv(ctx, p[:]) if err == nil { t.Fatalf("error expected reading with small buffer") } @@ -75,41 +83,3 @@ func TestSmallBuffer(t *testing.T) { t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer) } } - -func BenchmarkReadWrite(b *testing.B) { - b.StopTimer() - var ( - messages = [][]byte{ - []byte("hello"), - []byte("this is a test"), - []byte("of message framing"), - } - total int64 - channel bytes.Buffer - w = bufio.NewWriter(&channel) - p [4096]byte - ) - - b.ReportAllocs() - b.StartTimer() - for i := 0; i < b.N; i++ { - msg := messages[i%len(messages)] - if err := writemsg(w, msg); err != nil { - b.Fatal(err) - } - total += int64(len(msg)) - } - b.SetBytes(total) - - r := bufio.NewReader(bytes.NewReader(channel.Bytes())) - for i := 0; i < b.N; i++ { - _, err := readmsg(r, p[:]) - if err != nil { - if errors.Cause(err) != io.EOF { - b.Fatal(err) - } - - break - } - } -} diff --git a/client.go b/client.go index 437c9f3..265ce95 100644 --- a/client.go +++ b/client.go @@ -3,56 +3,191 @@ package ttrpc import ( "context" "net" + "sync" + "sync/atomic" "github.com/gogo/protobuf/proto" - "github.com/pkg/errors" ) type Client struct { - channel *channel + codec codec + channel *channel + requestID uint32 + sendRequests chan sendRequest + recvRequests chan recvRequest + + closed chan struct{} + closeOnce sync.Once + done chan struct{} + err error } func NewClient(conn net.Conn) *Client { - return &Client{ - channel: newChannel(conn), + c := &Client{ + codec: codec{}, + channel: newChannel(conn, conn), + sendRequests: make(chan sendRequest), + recvRequests: make(chan recvRequest), + closed: make(chan struct{}), + done: make(chan struct{}), } + + go c.run() + return c } func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error { - var payload []byte - switch v := req.(type) { - case proto.Message: - var err error - payload, err = proto.Marshal(v) - if err != nil { - return err - } - default: - return errors.Errorf("ttrpc: unknown request type: %T", req) + payload, err := c.codec.Marshal(req) + if err != nil { + return err } + requestID := atomic.AddUint32(&c.requestID, 1) request := Request{ Service: service, Method: method, Payload: payload, } - if err := c.channel.send(ctx, &request); err != nil { + if err := c.send(ctx, requestID, &request); err != nil { return err } var response Response - if err := c.channel.recv(ctx, &response); err != nil { + if err := c.recv(ctx, requestID, &response); err != nil { return err } - switch v := resp.(type) { - case proto.Message: - if err := proto.Unmarshal(response.Payload, v); err != nil { - return err - } - default: - return errors.Errorf("ttrpc: unknown response type: %T", resp) - } + + return c.codec.Unmarshal(response.Payload, resp) +} + +func (c *Client) Close() error { + c.closeOnce.Do(func() { + close(c.closed) + }) return nil } + +type sendRequest struct { + ctx context.Context + id uint32 + msg interface{} + err chan error +} + +func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error { + errs := make(chan error, 1) + select { + case c.sendRequests <- sendRequest{ + ctx: ctx, + id: id, + msg: msg, + err: errs, + }: + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + return c.err + } + + select { + case err := <-errs: + return err + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + return c.err + } +} + +type recvRequest struct { + id uint32 + msg interface{} + err chan error +} + +func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error { + errs := make(chan error, 1) + select { + case c.recvRequests <- recvRequest{ + id: id, + msg: msg, + err: errs, + }: + case <-c.done: + return c.err + case <-ctx.Done(): + return ctx.Err() + } + + select { + case err := <-errs: + return err + case <-c.done: + return c.err + case <-ctx.Done(): + return ctx.Err() + } +} + +type received struct { + mh messageHeader + p []byte + err error +} + +func (c *Client) run() { + defer close(c.done) + var ( + waiters = map[uint32]recvRequest{} + queued = map[uint32]received{} // messages unmatched by waiter + incoming = make(chan received) + ) + + go func() { + // start one more goroutine to recv messages without blocking. + for { + var p [messageLengthMax]byte + mh, err := c.channel.recv(context.TODO(), p[:]) + select { + case incoming <- received{ + mh: mh, + p: p[:mh.Length], + err: err, + }: + case <-c.done: + return + } + } + }() + + for { + select { + case req := <-c.sendRequests: + if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil { + req.err <- err + } else { + req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p) + } + case req := <-c.recvRequests: + if r, ok := queued[req.id]; ok { + req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message)) + } + waiters[req.id] = req + case r := <-incoming: + if r.err != nil { + c.err = r.err + return + } + + if waiter, ok := waiters[r.mh.StreamID]; ok { + waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message)) + } else { + queued[r.mh.StreamID] = r + } + case <-c.closed: + return + } + } +} diff --git a/server.go b/server.go index 812f911..7db9e9a 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( type Server struct { services *serviceSet + codec codec } func NewServer() *Server { @@ -43,35 +44,91 @@ func (s *Server) Serve(l net.Listener) error { func (s *Server) handleConn(conn net.Conn) { defer conn.Close() + type ( + request struct { + id uint32 + req *Request + } + + response struct { + id uint32 + resp *Response + } + ) + var ( - ch = newChannel(conn) - req Request + ch = newChannel(conn, conn) ctx, cancel = context.WithCancel(context.Background()) + responses = make(chan response) + requests = make(chan request) + recvErr = make(chan error, 1) + done = make(chan struct{}) ) defer cancel() + defer close(done) - // TODO(stevvooe): Recover here or in dispatch to handle panics in service - // methods. + go func() { + defer close(recvErr) + var p [messageLengthMax]byte + for { + mh, err := ch.recv(ctx, p[:]) + if err != nil { + recvErr <- err + return + } + + if mh.Type != messageTypeRequest { + // we must ignore this for future compat. + continue + } + + var req Request + if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil { + recvErr <- err + return + } + + select { + case requests <- request{ + id: mh.StreamID, + req: &req, + }: + case <-done: + } + } + }() - // every connection is just a simple in/out request loop. No complexity for - // multiplexing streams or dealing with head of line blocking, as this - // isn't necessary for shim control. for { - if err := ch.recv(ctx, &req); err != nil { - log.L.WithError(err).Error("failed receiving message on channel") - return - } + select { + case request := <-requests: + go func(id uint32) { + p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) + resp := &Response{ + Status: status.Proto(), + Payload: p, + } - p, status := s.services.call(ctx, req.Service, req.Method, req.Payload) - - resp := &Response{ - Status: status.Proto(), - Payload: p, - } - - if err := ch.send(ctx, resp); err != nil { - log.L.WithError(err).Error("failed sending message on channel") + select { + case responses <- response{ + id: id, + resp: resp, + }: + case <-done: + } + }(request.id) + case response := <-responses: + p, err := s.codec.Marshal(response.resp) + if err != nil { + log.L.WithError(err).Error("failed marshaling response") + return + } + if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { + log.L.WithError(err).Error("failed sending message on channel") + return + } + case err := <-recvErr: + log.L.WithError(err).Error("error receiving message") return } } diff --git a/services.go b/services.go index 943404c..b9a749e 100644 --- a/services.go +++ b/services.go @@ -6,7 +6,6 @@ import ( "os" "path" - "github.com/containerd/containerd/log" "github.com/gogo/protobuf/proto" "github.com/pkg/errors" "google.golang.org/grpc/codes" @@ -52,7 +51,6 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p } func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) { - ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", fullPath(serviceName, methodName))) method, err := s.resolve(serviceName, methodName) if err != nil { return nil, err