mirror of https://github.com/containerd/ttrpc.git
ttrpc: handle concurrent requests and responses
With this changeset, ttrpc can now handle mutliple outstanding requests and responses on the same connection without blocking. On the server-side, we dispatch a goroutine per outstanding reequest. On the client side, a management goroutine dispatches responses to blocked waiters. The protocol has been changed to support this behavior by including a "stream id" that can used to identify which request a response belongs to on the client-side of the connection. With these changes, we should also be able to support streams in the future. Signed-off-by: Stephen J Day <stephen.day@docker.com>
This commit is contained in:
parent
2a81659f49
commit
7f752bf263
168
channel.go
168
channel.go
|
|
@ -5,101 +5,95 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/containerd/containerd/log"
|
|
||||||
"github.com/gogo/protobuf/proto"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxMessageSize = 8 << 10 // TODO(stevvooe): Cut these down, since they are pre-alloced.
|
const (
|
||||||
|
messageHeaderLength = 10
|
||||||
type channel struct {
|
messageLengthMax = 8 << 10
|
||||||
conn net.Conn
|
|
||||||
bw *bufio.Writer
|
|
||||||
br *bufio.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func newChannel(conn net.Conn) *channel {
|
|
||||||
return &channel{
|
|
||||||
conn: conn,
|
|
||||||
bw: bufio.NewWriterSize(conn, maxMessageSize),
|
|
||||||
br: bufio.NewReaderSize(conn, maxMessageSize),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) recv(ctx context.Context, msg interface{}) error {
|
|
||||||
defer log.G(ctx).WithField("msg", msg).Info("recv")
|
|
||||||
|
|
||||||
// TODO(stevvooe): Use `bufio.Reader.Peek` here to remove this allocation.
|
|
||||||
var p [maxMessageSize]byte
|
|
||||||
n, err := readmsg(ch.br, p[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case proto.Message:
|
|
||||||
return proto.Unmarshal(p[:n], msg)
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unnsupported type in channel: %#v", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) send(ctx context.Context, msg interface{}) error {
|
|
||||||
log.G(ctx).WithField("msg", msg).Info("send")
|
|
||||||
var p []byte
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case proto.Message:
|
|
||||||
var err error
|
|
||||||
// TODO(stevvooe): trickiest allocation of the bunch. This will be hard
|
|
||||||
// to get rid of without using `MarshalTo` directly.
|
|
||||||
p, err = proto.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unsupported type recv from channel: %#v", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return writemsg(ch.bw, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readmsg(r *bufio.Reader, p []byte) (int, error) {
|
|
||||||
mlen, err := binary.ReadVarint(r)
|
|
||||||
if err != nil {
|
|
||||||
return 0, errors.Wrapf(err, "failed reading message size")
|
|
||||||
}
|
|
||||||
|
|
||||||
if mlen > int64(len(p)) {
|
|
||||||
return 0, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mlen, len(p))
|
|
||||||
}
|
|
||||||
|
|
||||||
nn, err := io.ReadFull(r, p[:mlen])
|
|
||||||
if err != nil {
|
|
||||||
return 0, errors.Wrapf(err, "failed reading message size")
|
|
||||||
}
|
|
||||||
|
|
||||||
if int64(nn) != mlen {
|
|
||||||
return 0, errors.Errorf("mismatched read against message length %v != %v", nn, mlen)
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(mlen), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func writemsg(w *bufio.Writer, p []byte) error {
|
|
||||||
var (
|
|
||||||
mlenp [binary.MaxVarintLen64]byte
|
|
||||||
n = binary.PutVarint(mlenp[:], int64(len(p)))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, err := w.Write(mlenp[:n]); err != nil {
|
type messageType uint8
|
||||||
return errors.Wrapf(err, "failed writing message header")
|
|
||||||
|
const (
|
||||||
|
messageTypeRequest messageType = 0x1
|
||||||
|
messageTypeResponse messageType = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
// messageHeader represents the fixed-length message header of 10 bytes sent
|
||||||
|
// with every request.
|
||||||
|
type messageHeader struct {
|
||||||
|
Length uint32 // length excluding this header. b[:4]
|
||||||
|
StreamID uint32 // identifies which request stream message is a part of. b[4:8]
|
||||||
|
Type messageType // message type b[8]
|
||||||
|
Flags uint8 // reserved b[9]
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := w.Write(p); err != nil {
|
func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
|
||||||
return errors.Wrapf(err, "failed writing message")
|
_, err := io.ReadFull(r, p[:messageHeaderLength])
|
||||||
|
if err != nil {
|
||||||
|
return messageHeader{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.Flush()
|
return messageHeader{
|
||||||
|
Length: binary.BigEndian.Uint32(p[:4]),
|
||||||
|
StreamID: binary.BigEndian.Uint32(p[4:8]),
|
||||||
|
Type: messageType(p[8]),
|
||||||
|
Flags: p[9],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
|
||||||
|
binary.BigEndian.PutUint32(p[:4], mh.Length)
|
||||||
|
binary.BigEndian.PutUint32(p[4:8], mh.StreamID)
|
||||||
|
p[8] = byte(mh.Type)
|
||||||
|
p[9] = mh.Flags
|
||||||
|
|
||||||
|
_, err := w.Write(p[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type channel struct {
|
||||||
|
bw *bufio.Writer
|
||||||
|
br *bufio.Reader
|
||||||
|
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
|
||||||
|
hwbuf [messageHeaderLength]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newChannel(w io.Writer, r io.Reader) *channel {
|
||||||
|
return &channel{
|
||||||
|
bw: bufio.NewWriter(w),
|
||||||
|
br: bufio.NewReader(r),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *channel) recv(ctx context.Context, p []byte) (messageHeader, error) {
|
||||||
|
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
|
||||||
|
if err != nil {
|
||||||
|
return messageHeader{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if mh.Length > uint32(len(p)) {
|
||||||
|
return messageHeader{}, errors.Wrapf(io.ErrShortBuffer, "message length %v over buffer size %v", mh.Length, len(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(ch.br, p[:mh.Length]); err != nil {
|
||||||
|
return messageHeader{}, errors.Wrapf(err, "failed reading message")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
|
||||||
|
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ch.bw.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ch.bw.Flush()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package ttrpc
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
@ -12,8 +13,10 @@ import (
|
||||||
|
|
||||||
func TestReadWriteMessage(t *testing.T) {
|
func TestReadWriteMessage(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
channel bytes.Buffer
|
ctx = context.Background()
|
||||||
w = bufio.NewWriter(&channel)
|
buffer bytes.Buffer
|
||||||
|
w = bufio.NewWriter(&buffer)
|
||||||
|
ch = newChannel(w, nil)
|
||||||
messages = [][]byte{
|
messages = [][]byte{
|
||||||
[]byte("hello"),
|
[]byte("hello"),
|
||||||
[]byte("this is a test"),
|
[]byte("this is a test"),
|
||||||
|
|
@ -21,20 +24,21 @@ func TestReadWriteMessage(t *testing.T) {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, msg := range messages {
|
for i, msg := range messages {
|
||||||
if err := writemsg(w, msg); err != nil {
|
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
received [][]byte
|
received [][]byte
|
||||||
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
||||||
|
rch = newChannel(nil, r)
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var p [4096]byte
|
var p [4096]byte
|
||||||
n, err := readmsg(r, p[:])
|
mh, err := rch.recv(ctx, p[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Cause(err) != io.EOF {
|
if errors.Cause(err) != io.EOF {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
@ -42,7 +46,7 @@ func TestReadWriteMessage(t *testing.T) {
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
received = append(received, p[:n])
|
received = append(received, p[:mh.Length])
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(received, messages) {
|
if !reflect.DeepEqual(received, messages) {
|
||||||
|
|
@ -52,21 +56,25 @@ func TestReadWriteMessage(t *testing.T) {
|
||||||
|
|
||||||
func TestSmallBuffer(t *testing.T) {
|
func TestSmallBuffer(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
channel bytes.Buffer
|
ctx = context.Background()
|
||||||
w = bufio.NewWriter(&channel)
|
buffer bytes.Buffer
|
||||||
|
w = bufio.NewWriter(&buffer)
|
||||||
|
ch = newChannel(w, nil)
|
||||||
msg = []byte("a message of massive length")
|
msg = []byte("a message of massive length")
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := writemsg(w, msg); err != nil {
|
if err := ch.send(ctx, 1, 1, msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// now, read it off the channel with a small buffer
|
// now, read it off the channel with a small buffer
|
||||||
var (
|
var (
|
||||||
p = make([]byte, len(msg)-1)
|
p = make([]byte, len(msg)-1)
|
||||||
r = bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
r = bufio.NewReader(bytes.NewReader(buffer.Bytes()))
|
||||||
|
rch = newChannel(nil, r)
|
||||||
)
|
)
|
||||||
_, err := readmsg(r, p[:])
|
|
||||||
|
_, err := rch.recv(ctx, p[:])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("error expected reading with small buffer")
|
t.Fatalf("error expected reading with small buffer")
|
||||||
}
|
}
|
||||||
|
|
@ -75,41 +83,3 @@ func TestSmallBuffer(t *testing.T) {
|
||||||
t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
|
t.Fatalf("errors.Cause(err) should equal io.ErrShortBuffer: %v != %v", err, io.ErrShortBuffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkReadWrite(b *testing.B) {
|
|
||||||
b.StopTimer()
|
|
||||||
var (
|
|
||||||
messages = [][]byte{
|
|
||||||
[]byte("hello"),
|
|
||||||
[]byte("this is a test"),
|
|
||||||
[]byte("of message framing"),
|
|
||||||
}
|
|
||||||
total int64
|
|
||||||
channel bytes.Buffer
|
|
||||||
w = bufio.NewWriter(&channel)
|
|
||||||
p [4096]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.StartTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
msg := messages[i%len(messages)]
|
|
||||||
if err := writemsg(w, msg); err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
total += int64(len(msg))
|
|
||||||
}
|
|
||||||
b.SetBytes(total)
|
|
||||||
|
|
||||||
r := bufio.NewReader(bytes.NewReader(channel.Bytes()))
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_, err := readmsg(r, p[:])
|
|
||||||
if err != nil {
|
|
||||||
if errors.Cause(err) != io.EOF {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
175
client.go
175
client.go
|
|
@ -3,56 +3,191 @@ package ttrpc
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
codec codec
|
||||||
channel *channel
|
channel *channel
|
||||||
|
requestID uint32
|
||||||
|
sendRequests chan sendRequest
|
||||||
|
recvRequests chan recvRequest
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
done chan struct{}
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(conn net.Conn) *Client {
|
func NewClient(conn net.Conn) *Client {
|
||||||
return &Client{
|
c := &Client{
|
||||||
channel: newChannel(conn),
|
codec: codec{},
|
||||||
|
channel: newChannel(conn, conn),
|
||||||
|
sendRequests: make(chan sendRequest),
|
||||||
|
recvRequests: make(chan recvRequest),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go c.run()
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
||||||
var payload []byte
|
payload, err := c.codec.Marshal(req)
|
||||||
switch v := req.(type) {
|
|
||||||
case proto.Message:
|
|
||||||
var err error
|
|
||||||
payload, err = proto.Marshal(v)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return errors.Errorf("ttrpc: unknown request type: %T", req)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
requestID := atomic.AddUint32(&c.requestID, 1)
|
||||||
request := Request{
|
request := Request{
|
||||||
Service: service,
|
Service: service,
|
||||||
Method: method,
|
Method: method,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.channel.send(ctx, &request); err != nil {
|
if err := c.send(ctx, requestID, &request); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var response Response
|
var response Response
|
||||||
if err := c.channel.recv(ctx, &response); err != nil {
|
if err := c.recv(ctx, requestID, &response); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch v := resp.(type) {
|
|
||||||
case proto.Message:
|
return c.codec.Unmarshal(response.Payload, resp)
|
||||||
if err := proto.Unmarshal(response.Payload, v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errors.Errorf("ttrpc: unknown response type: %T", resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
close(c.closed)
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sendRequest struct {
|
||||||
|
ctx context.Context
|
||||||
|
id uint32
|
||||||
|
msg interface{}
|
||||||
|
err chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) send(ctx context.Context, id uint32, msg interface{}) error {
|
||||||
|
errs := make(chan error, 1)
|
||||||
|
select {
|
||||||
|
case c.sendRequests <- sendRequest{
|
||||||
|
ctx: ctx,
|
||||||
|
id: id,
|
||||||
|
msg: msg,
|
||||||
|
err: errs,
|
||||||
|
}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errs:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type recvRequest struct {
|
||||||
|
id uint32
|
||||||
|
msg interface{}
|
||||||
|
err chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) recv(ctx context.Context, id uint32, msg interface{}) error {
|
||||||
|
errs := make(chan error, 1)
|
||||||
|
select {
|
||||||
|
case c.recvRequests <- recvRequest{
|
||||||
|
id: id,
|
||||||
|
msg: msg,
|
||||||
|
err: errs,
|
||||||
|
}:
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errs:
|
||||||
|
return err
|
||||||
|
case <-c.done:
|
||||||
|
return c.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type received struct {
|
||||||
|
mh messageHeader
|
||||||
|
p []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) run() {
|
||||||
|
defer close(c.done)
|
||||||
|
var (
|
||||||
|
waiters = map[uint32]recvRequest{}
|
||||||
|
queued = map[uint32]received{} // messages unmatched by waiter
|
||||||
|
incoming = make(chan received)
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// start one more goroutine to recv messages without blocking.
|
||||||
|
for {
|
||||||
|
var p [messageLengthMax]byte
|
||||||
|
mh, err := c.channel.recv(context.TODO(), p[:])
|
||||||
|
select {
|
||||||
|
case incoming <- received{
|
||||||
|
mh: mh,
|
||||||
|
p: p[:mh.Length],
|
||||||
|
err: err,
|
||||||
|
}:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req := <-c.sendRequests:
|
||||||
|
if p, err := proto.Marshal(req.msg.(proto.Message)); err != nil {
|
||||||
|
req.err <- err
|
||||||
|
} else {
|
||||||
|
req.err <- c.channel.send(req.ctx, req.id, messageTypeRequest, p)
|
||||||
|
}
|
||||||
|
case req := <-c.recvRequests:
|
||||||
|
if r, ok := queued[req.id]; ok {
|
||||||
|
req.err <- proto.Unmarshal(r.p, req.msg.(proto.Message))
|
||||||
|
}
|
||||||
|
waiters[req.id] = req
|
||||||
|
case r := <-incoming:
|
||||||
|
if r.err != nil {
|
||||||
|
c.err = r.err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if waiter, ok := waiters[r.mh.StreamID]; ok {
|
||||||
|
waiter.err <- proto.Unmarshal(r.p, waiter.msg.(proto.Message))
|
||||||
|
} else {
|
||||||
|
queued[r.mh.StreamID] = r
|
||||||
|
}
|
||||||
|
case <-c.closed:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
81
server.go
81
server.go
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
services *serviceSet
|
services *serviceSet
|
||||||
|
codec codec
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
|
|
@ -43,36 +44,92 @@ func (s *Server) Serve(l net.Listener) error {
|
||||||
func (s *Server) handleConn(conn net.Conn) {
|
func (s *Server) handleConn(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
type (
|
||||||
|
request struct {
|
||||||
|
id uint32
|
||||||
|
req *Request
|
||||||
|
}
|
||||||
|
|
||||||
|
response struct {
|
||||||
|
id uint32
|
||||||
|
resp *Response
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ch = newChannel(conn)
|
ch = newChannel(conn, conn)
|
||||||
req Request
|
|
||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
responses = make(chan response)
|
||||||
|
requests = make(chan request)
|
||||||
|
recvErr = make(chan error, 1)
|
||||||
|
done = make(chan struct{})
|
||||||
)
|
)
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
// TODO(stevvooe): Recover here or in dispatch to handle panics in service
|
go func() {
|
||||||
// methods.
|
defer close(recvErr)
|
||||||
|
var p [messageLengthMax]byte
|
||||||
// every connection is just a simple in/out request loop. No complexity for
|
|
||||||
// multiplexing streams or dealing with head of line blocking, as this
|
|
||||||
// isn't necessary for shim control.
|
|
||||||
for {
|
for {
|
||||||
if err := ch.recv(ctx, &req); err != nil {
|
mh, err := ch.recv(ctx, p[:])
|
||||||
log.L.WithError(err).Error("failed receiving message on channel")
|
if err != nil {
|
||||||
|
recvErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, status := s.services.call(ctx, req.Service, req.Method, req.Payload)
|
if mh.Type != messageTypeRequest {
|
||||||
|
// we must ignore this for future compat.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
if err := s.codec.Unmarshal(p[:mh.Length], &req); err != nil {
|
||||||
|
recvErr <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case requests <- request{
|
||||||
|
id: mh.StreamID,
|
||||||
|
req: &req,
|
||||||
|
}:
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case request := <-requests:
|
||||||
|
go func(id uint32) {
|
||||||
|
p, status := s.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload)
|
||||||
resp := &Response{
|
resp := &Response{
|
||||||
Status: status.Proto(),
|
Status: status.Proto(),
|
||||||
Payload: p,
|
Payload: p,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ch.send(ctx, resp); err != nil {
|
select {
|
||||||
|
case responses <- response{
|
||||||
|
id: id,
|
||||||
|
resp: resp,
|
||||||
|
}:
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}(request.id)
|
||||||
|
case response := <-responses:
|
||||||
|
p, err := s.codec.Marshal(response.resp)
|
||||||
|
if err != nil {
|
||||||
|
log.L.WithError(err).Error("failed marshaling response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
|
||||||
log.L.WithError(err).Error("failed sending message on channel")
|
log.L.WithError(err).Error("failed sending message on channel")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case err := <-recvErr:
|
||||||
|
log.L.WithError(err).Error("error receiving message")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/containerd/containerd/log"
|
|
||||||
"github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
|
@ -52,7 +51,6 @@ func (s *serviceSet) call(ctx context.Context, serviceName, methodName string, p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) {
|
func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName string, p []byte) ([]byte, error) {
|
||||||
ctx = log.WithLogger(ctx, log.G(ctx).WithField("method", fullPath(serviceName, methodName)))
|
|
||||||
method, err := s.resolve(serviceName, methodName)
|
method, err := s.resolve(serviceName, methodName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue