ttrpc: handle concurrent requests and responses

With this changeset, ttrpc can now handle multiple outstanding requests
and responses on the same connection without blocking. On the
server side, we dispatch a goroutine per outstanding request. On the
client side, a management goroutine dispatches responses to blocked
waiters.

The protocol has been changed to support this behavior by including a
"stream id" that can be used to identify which request a response
belongs to on the client side of the connection. With these changes, we
should also be able to support streams in the future.

Signed-off-by: Stephen J Day <stephen.day@docker.com>
Stephen J Day, 2017-11-21 21:38:38 -08:00
commit 7f752bf263 (parent 2a81659f49)
GPG Key ID: 67B3DED84EDC823F
5 changed files with 333 additions and 179 deletions

channel.go

@@ -5,101 +5,95 @@ import (
 	"context"
 	"encoding/binary"
 	"io"
-	"net"
 
-	"github.com/containerd/containerd/log"
-	"github.com/gogo/protobuf/proto"
 	"github.com/pkg/errors"
 )
 
-const maxMessageSize = 8 << 10 // TODO(stevvooe): Cut these down, since they are pre-alloced.
+const (
+	messageHeaderLength = 10
+	messageLengthMax    = 8 << 10
+)
+
+type messageType uint8
+
+const (
+	messageTypeRequest  messageType = 0x1
+	messageTypeResponse messageType = 0x2
+)
+
+// messageHeader represents the fixed-length message header of 10 bytes sent
+// with every request.
+type messageHeader struct {
+	Length   uint32      // length excluding this header. b[:4]
+	StreamID uint32      // identifies which request stream message is a part of. b[4:8]
+	Type     messageType // message type b[8]
+	Flags    uint8       // reserved b[9]
+}
+
+func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
+	_, err := io.ReadFull(r, p[:messageHeaderLength])
+	if err != nil {
+		return messageHeader{}, err
+	}
+
+	return messageHeader{
+		Length:   binary.BigEndian.Uint32(p[:4]),
+		StreamID: binary.BigEndian.Uint32(p[4:8]),
+		Type:     messageType(p[8]),
+		Flags:    p[9],
+	}, nil
+}
+
+func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
+	binary.BigEndian.PutUint32(p[:4], mh.Length)
+	binary.BigEndian.PutUint32(p[4:8], mh.StreamID)
+	p[8] = byte(mh.Type)
+	p[9] = mh.Flags
+
+	_, err := w.Write(p[:])
+	return err
+}
 
 type channel struct {
-	conn net.Conn
 	bw   *bufio.Writer
 	br   *bufio.Reader
+	hrbuf [messageHeaderLength]byte // avoid alloc when reading header
+	hwbuf [messageHeaderLength]byte
 }
 
-func newChannel(conn net.Conn) *channel {
+func newChannel(w io.Writer, r io.Reader) *channel {
 	return &channel{
-		conn: conn,
-		bw:   bufio.NewWriterSize(conn, maxMessageSize),
-		br:   bufio.NewReaderSize(conn, maxMessageSize),
+		bw: bufio.NewWriter(w),
+		br: bufio.NewReader(r),
 	}
 }
 
-func (ch *channel) recv(ctx context.Context, msg interface{}) error {
-	defer log.G(ctx).WithField("msg", msg).Info("recv")
-
-	// TODO(stevvooe): Use `bufio.Reader.Peek` here to remove this allocation.
-	var p [maxMessageSize]byte
-	n, err := readmsg(ch.br, p[:])
-	if err != nil {
-		return err
-	}
-
-	switch msg := msg.(type) {
-	case proto.Message:
-		return proto.Unmarshal(p[:n], msg)
-	default:
-		return errors.Errorf("unnsupported type in channel: %#v", msg)
-	}
-}
-
-func (ch *channel) send(ctx context.Context, msg interface{}) error {
-	log.G(ctx).WithField("msg", msg).Info("send")
-
-	var p []byte
-	switch msg := msg.(type) {
-	case proto.Message:
-		var err error
-		// TODO(stevvooe): trickiest allocation of the bunch. This will be hard
-		// to get rid of without using `MarshalTo` directly.
-		p, err = proto.Marshal(msg)
-		if err != nil {
-			return err
-		}
-	default:
-		return errors.Errorf("unsupported type recv from channel: %#v", msg)
-	}
-
-	return writemsg(ch.bw, p)
-}
-
-func readmsg(r *bufio.Reader, p []byte) (int, error) {
-	mlen, err := binary.ReadVarint(r)
-	if err != nil {
-		return 0, errors.Wrapf(err, "failed reading message size")
-	}
-
-	if mlen > int64(len(p)) {
-		return 0, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mlen, len(p))
-	}
-
-	nn, err := io.ReadFull(r, p[:mlen])
-	if err != nil {
-		return 0, errors.Wrapf(err, "failed reading message size")
-	}
-
-	if int64(nn) != mlen {
-		return 0, errors.Errorf("mismatched read against message length %v != %v", nn, mlen)
-	}
-
-	return int(mlen), nil
-}
-
-func writemsg(w *bufio.Writer, p []byte) error {
-	var (
-		mlenp [binary.MaxVarintLen64]byte
-		n     = binary.PutVarint(mlenp[:], int64(len(p)))
-	)
-
-	if _, err := w.Write(mlenp[:n]); err != nil {
-		return errors.Wrapf(err, "failed writing message header")
-	}
-
-	if _, err := w.Write(p); err != nil {
-		return errors.Wrapf(err, "failed writing message")
-	}
-
-	return w.Flush()
-}
+func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
+	mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
+	if err != nil {
+		return messageHeader{}, err
+	}
+
+	if mh.Length > uint32(len(p)) {
+		return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
+	}
+
+	if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
+		return messageHeader{}, errors.Wrapf(err, "failed reading message")
+	}
+
+	return mh, nil
+}
+
+func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
+	if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
+		return err
+	}
+
+	_, err := ch.bw.Write(p)
+	if err != nil {
+		return err
+	}
+
+	return ch.bw.Flush()
+}
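
To make the new framing concrete, here is a small sketch (not part of the commit) that frames the payload "hello" as a request on stream 1 using the channel API above, with the resulting bytes annotated:

var buf bytes.Buffer
ch := newChannel(&buf, nil) // write side only; no reader needed here

if err := ch.send(context.Background(), 1, messageTypeRequest, []byte("hello")); err != nil {
	// handle write/flush error
}

// buf now holds the fixed 10-byte big-endian header, then the payload:
//   00 00 00 05    Length   (payload bytes only; the header is excluded)
//   00 00 00 01    StreamID (a response reuses the ID of its request)
//   01             Type     (messageTypeRequest)
//   00             Flags    (reserved)
//   68 65 6c 6c 6f "hello"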

channel_test.go

@@ -3,6 +3,7 @@ package ttrpc
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"io"
 	"reflect"
 	"testing"
@@ -12,8 +13,10 @@ import (
 func TestReadWriteMessage(t *testing.T) {
 	var (
-		channel  bytes.Buffer
-		w        = bufio.NewWriter(&channel)
+		ctx      = context.Background()
+		buffer   bytes.Buffer
+		w        = bufio.NewWriter(&buffer)
+		ch       = newChannel(w, nil)
 		messages = [][]byte{
 			[]byte("hello"),
 			[]byte("this is a test"),
@@ -21,20 +24,21 @@ func TestReadWriteMessage(t *testing.T) {
 		}
 	)
 
-	for _, msg := range messages {
-		if err := writemsg(w, msg); err != nil {
+	for i, msg := range messages {
+		if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
 			t.Fatal(err)
 		}
 	}
 
 	var (
 		received [][]byte
-		r        = bufio.NewReader(bytes.NewReader(channel.Bytes()))
+		r        = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
+		rch      = newChannel(nil, r)
 	)
 
 	for {
 		var p [4096]byte
-		n, err := readmsg(r, p[:])
+		mh, err := rch.recv(ctx, p[:])
 		if err != nil {
 			if errors.Cause(err) != io.EOF {
 				t.Fatal(err)
@@ -42,7 +46,7 @@ func TestReadWriteMessage(t *testing.T) {
 			break
 		}
 
-		received = append(received, p[:n])
+		received = append(received, p[:mh.Length])
 	}
 
 	if !reflect.DeepEqual(received, messages) {
@@ -52,21 +56,25 @@ func TestReadWriteMessage(t *testing.T) {
 func TestSmallBuffer(t *testing.T) {
 	var (
-		channel bytes.Buffer
-		w       = bufio.NewWriter(&channel)
+		ctx    = context.Background()
+		buffer bytes.Buffer
+		w      = bufio.NewWriter(&buffer)
+		ch     = newChannel(w, nil)
 		msg    = []byte("a message of massive length")
 	)
 
-	if err := writemsg(w, msg); err != nil {
+	if err := ch.send(ctx, 1, 1, msg); err != nil {
 		t.Fatal(err)
 	}
 
 	// now, read it off the channel with a small buffer
 	var (
-		p = make([]byte, len(msg)-1)
-		r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
+		p   = make([]byte, len(msg)-1)
+		r   = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
+		rch = newChannel(nil, r)
	)
 
-	_, err := readmsg(r, p[:])
+	_, err := rch.recv(ctx, p[:])
 	if err == nil {
 		t.Fatalf("error expected reading with small buffer")
 	}
@@ -75,41 +83,3 @@ func TestSmallBuffer(t *testing.T) {
 		t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
 	}
 }
-
-func BenchmarkReadWrite(b *testing.B) {
-	b.StopTimer()
-	var (
-		messages = [][]byte{
-			[]byte("hello"),
-			[]byte("this is a test"),
-			[]byte("of message framing"),
-		}
-		total   int64
-		channel bytes.Buffer
-		w       = bufio.NewWriter(&channel)
-		p       [4096]byte
-	)
-
-	b.ReportAllocs()
-	b.StartTimer()
-
-	for i := 0; i < b.N; i++ {
-		msg := messages[i%len(messages)]
-		if err := writemsg(w, msg); err != nil {
-			b.Fatal(err)
-		}
-		total += int64(len(msg))
-	}
-	b.SetBytes(total)
-
-	r := bufio.NewReader(bytes.NewReader(channel.Bytes()))
-	for i := 0; i < b.N; i++ {
-		_, err := readmsg(r, p[:])
-		if err != nil {
-			if errors.Cause(err) != io.EOF {
-				b.Fatal(err)
-			}
-			break
-		}
-	}
-}

client.go

@@ -3,56 +3,191 @@ package ttrpc
 import (
 	"context"
 	"net"
+	"sync"
+	"sync/atomic"
 
 	"github.com/gogo/protobuf/proto"
-	"github.com/pkg/errors"
 )
 
 type Client struct {
+	codec   codec
 	channel *channel
+
+	requestID    uint32
+	sendRequests chan sendRequest
+	recvRequests chan recvRequest
+
+	closed    chan struct{}
+	closeOnce sync.Once
+	done      chan struct{}
+	err       error
 }
 
 func NewClient(conn net.Conn) *Client {
-	return &Client{
-		channel: newChannel(conn),
+	c := &Client{
+		codec:        codec{},
+		channel:      newChannel(conn, conn),
+		sendRequests: make(chan sendRequest),
+		recvRequests: make(chan recvRequest),
+		closed:       make(chan struct{}),
+		done:         make(chan struct{}),
 	}
+
+	go c.run()
+	return c
 }
 
 func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
-	var payload []byte
-	switch v := req.(type) {
-	case proto.Message:
-		var err error
-		payload, err = proto.Marshal(v)
-		if err != nil {
-			return err
-		}
-	default:
-		return errors.Errorf("ttrpc: unknown request type: %T", req)
+	payload, err := c.codec.Marshal(req)
+	if err != nil {
+		return err
 	}
 
+	requestID := atomic.AddUint32(&c.requestID, 1)
 	request := Request{
 		Service: service,
 		Method:  method,
 		Payload: payload,
 	}
 
-	if err := c.channel.send(ctx, &request); err != nil {
+	if err := c.send(ctx, requestID, &request); err != nil {
 		return err
 	}
 
 	var response Response
-	if err := c.channel.recv(ctx, &response); err != nil {
+	if err := c.recv(ctx, requestID, &response); err != nil {
 		return err
 	}
 
-	switch v := resp.(type) {
-	case proto.Message:
-		if err := proto.Unmarshal(response.Payload, v); err != nil {
-			return err
-		}
-	default:
-		return errors.Errorf("ttrpc: unknown response type: %T", resp)
-	}
-
-	return nil
+	return c.codec.Unmarshal(response.Payload, resp)
+}
+
+func (c *Client) Close() error {
+	c.closeOnce.Do(func() {
+		close(c.closed)
+	})
+
+	return nil
+}
+
+type sendRequest struct {
+	ctx context.Context
+	id  uint32
+	msg interface{}
+	err chan error
+}
+
+func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
+	errs := make(chan error, 1)
+	select {
+	case c.sendRequests <- sendRequest{
+		ctx: ctx,
+		id:  id,
+		msg: msg,
+		err: errs,
+	}:
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.done:
+		return c.err
+	}
+
+	select {
+	case err := <-errs:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.done:
+		return c.err
+	}
+}
+
+type recvRequest struct {
+	id  uint32
+	msg interface{}
+	err chan error
+}
+
+func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
+	errs := make(chan error, 1)
+	select {
+	case c.recvRequests <- recvRequest{
+		id:  id,
+		msg: msg,
+		err: errs,
+	}:
+	case <-c.done:
+		return c.err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+
+	select {
+	case err := <-errs:
+		return err
+	case <-c.done:
+		return c.err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+type received struct {
+	mh  messageHeader
+	p   []byte
+	err error
+}
+
+func (c *Client) run() {
+	defer close(c.done)
+
+	var (
+		waiters  = map[uint32]recvRequest{}
+		queued   = map[uint32]received{} // messages unmatched by waiter
+		incoming = make(chan received)
+	)
+
+	go func() {
+		// start one more goroutine to recv messages without blocking.
+		for {
+			var p [messageLengthMax]byte
+			mh, err := c.channel.recv(context.TODO(), p[:])
+			select {
+			case incoming <- received{
+				mh:  mh,
+				p:   p[:mh.Length],
+				err: err,
+			}:
+			case <-c.done:
+				return
+			}
+		}
+	}()
+
+	for {
+		select {
+		case req := <-c.sendRequests:
+			if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
+				req.err <- err
+			} else {
+				req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
+			}
+		case req := <-c.recvRequests:
+			if r, ok := queued[req.id]; ok {
+				req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
+			}
+
+			waiters[req.id] = req
+		case r := <-incoming:
+			if r.err != nil {
+				c.err = r.err
+				return
+			}
+
+			if waiter, ok := waiters[r.mh.StreamID]; ok {
+				waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
+			} else {
+				queued[r.mh.StreamID] = r
+			}
+		case <-c.closed:
+			return
+		}
+	}
+}
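
As a sketch of what the new client makes possible (illustrative only and not part of the commit: EchoRequest and EchoResponse stand in for generated proto message types, and the socket path is made up), several calls can now be in flight on a single connection:

conn, err := net.Dial("unix", "/run/example/ttrpc.sock") // illustrative address
if err != nil {
	// handle dial error
}
client := NewClient(conn)
defer client.Close()

var wg sync.WaitGroup
for i := 0; i < 10; i++ {
	wg.Add(1)
	go func(i int) {
		defer wg.Done()
		req := &EchoRequest{Msg: fmt.Sprintf("ping-%d", i)} // hypothetical type
		var resp EchoResponse                               // hypothetical type
		// Each Call gets a fresh stream ID; run() routes the reply back
		// to this waiter no matter what order responses arrive in.
		if err := client.Call(context.Background(), "Echo", "Echo", req, &resp); err != nil {
			// handle per-call error
		}
	}(i)
}
wg.Wait()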

server.go

@@ -9,6 +9,7 @@ import (
 type Server struct {
 	services *serviceSet
+	codec    codec
 }
 
 func NewServer() *Server {
@@ -43,36 +44,92 @@ func (s *Server) Serve(l net.Listener) error {
 func (s *Server) handleConn(conn net.Conn) {
 	defer conn.Close()
 
+	type (
+		request struct {
+			id  uint32
+			req *Request
+		}
+
+		response struct {
+			id   uint32
+			resp *Response
+		}
+	)
+
 	var (
-		ch          = newChannel(conn)
-		req         Request
+		ch          = newChannel(conn, conn)
 		ctx, cancel = context.WithCancel(context.Background())
+		responses   = make(chan response)
+		requests    = make(chan request)
+		recvErr     = make(chan error, 1)
+		done        = make(chan struct{})
 	)
 
 	defer cancel()
+	defer close(done)
 
-	// TODO(stevvooe): Recover here or in dispatch to handle panics in service
-	// methods.
-
-	// every connection is just a simple in/out request loop. No complexity for
-	// multiplexing streams or dealing with head of line blocking, as this
-	// isn't necessary for shim control.
-	for {
-		if err := ch.recv(ctx, &req); err != nil {
-			log.L.WithError(err).Error("failed receiving message on channel")
-			return
-		}
-
-		p, status := s.services.call(ctx, req.Service, req.Method, req.Payload)
-		resp := &Response{
-			Status:  status.Proto(),
-			Payload: p,
-		}
-
-		if err := ch.send(ctx, resp); err != nil {
-			log.L.WithError(err).Error("failed sending message on channel")
-			return
-		}
+	go func() {
+		defer close(recvErr)
+		var p [messageLengthMax]byte
+		for {
+			mh, err := ch.recv(ctx, p[:])
+			if err != nil {
+				recvErr <- err
+				return
+			}
+
+			if mh.Type != messageTypeRequest {
+				// we must ignore this for future compat.
+				continue
+			}
+
+			var req Request
+			if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
+				recvErr <- err
+				return
+			}
+
+			select {
+			case requests <- request{
+				id:  mh.StreamID,
+				req: &req,
+			}:
+			case <-done:
+			}
+		}
+	}()
+
+	for {
+		select {
+		case request := <-requests:
+			go func(id uint32) {
+				p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
+				resp := &Response{
+					Status:  status.Proto(),
+					Payload: p,
+				}
+
+				select {
+				case responses <- response{
+					id:   id,
+					resp: resp,
+				}:
+				case <-done:
+				}
+			}(request.id)
+		case response := <-responses:
+			p, err := s.codec.Marshal(response.resp)
+			if err != nil {
+				log.L.WithError(err).Error("failed marshaling response")
+				return
+			}
+
+			if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
+				log.L.WithError(err).Error("failed sending message on channel")
+				return
+			}
+		case err := <-recvErr:
+			log.L.WithError(err).Error("error receiving message")
+			return
+		}
 	}
 }
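
For completeness, a minimal server-side sketch (not part of the commit; the listener address is made up and the service registration step is elided, since its API does not appear in this diff). Each accepted connection gets the concurrent handleConn loop above, dispatching a goroutine per request:

l, err := net.Listen("unix", "/run/example/ttrpc.sock") // illustrative address
if err != nil {
	// handle listen error
}

s := NewServer()
// ... register services on s here (registration API not shown in this diff) ...

// Serve accepts connections and runs handleConn for each one; responses
// are written back on the stream IDs of their requests, in whatever
// order the handlers finish.
if err := s.Serve(l); err != nil {
	// handle serve error
}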

services.go

@@ -6,7 +6,6 @@ import (
 	"os"
 	"path"
 
-	"github.com/containerd/containerd/log"
 	"github.com/gogo/protobuf/proto"
 	"github.com/pkg/errors"
 	"google.golang.org/grpc/codes"
@@ -52,7 +51,6 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p
 }
 
 func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) {
-	ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", fullPath(serviceName, methodName)))
 	method, err := s.resolve(serviceName, methodName)
 	if err != nil {
 		return nil, err