mirror of https://github.com/docker/buildx.git
				
				
				
			
		
			
				
	
	
		
			513 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			513 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|    Copyright The containerd Authors.
 | |
| 
 | |
|    Licensed under the Apache License, Version 2.0 (the "License");
 | |
|    you may not use this file except in compliance with the License.
 | |
|    You may obtain a copy of the License at
 | |
| 
 | |
|        http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
|    Unless required by applicable law or agreed to in writing, software
 | |
|    distributed under the License is distributed on an "AS IS" BASIS,
 | |
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|    See the License for the specific language governing permissions and
 | |
|    limitations under the License.
 | |
| */
 | |
| 
 | |
| package ttrpc
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"syscall"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/sirupsen/logrus"
 | |
| 	"google.golang.org/grpc/codes"
 | |
| 	"google.golang.org/grpc/status"
 | |
| 	"google.golang.org/protobuf/proto"
 | |
| )
 | |
| 
 | |
| // Client for a ttrpc server
 | |
| type Client struct {
 | |
| 	codec   codec
 | |
| 	conn    net.Conn
 | |
| 	channel *channel
 | |
| 
 | |
| 	streamLock   sync.RWMutex
 | |
| 	streams      map[streamID]*stream
 | |
| 	nextStreamID streamID
 | |
| 	sendLock     sync.Mutex
 | |
| 
 | |
| 	ctx    context.Context
 | |
| 	closed func()
 | |
| 
 | |
| 	closeOnce       sync.Once
 | |
| 	userCloseFunc   func()
 | |
| 	userCloseWaitCh chan struct{}
 | |
| 
 | |
| 	interceptor UnaryClientInterceptor
 | |
| }
 | |
| 
 | |
| // ClientOpts configures a client
 | |
| type ClientOpts func(c *Client)
 | |
| 
 | |
| // WithOnClose sets the close func whenever the client's Close() method is called
 | |
| func WithOnClose(onClose func()) ClientOpts {
 | |
| 	return func(c *Client) {
 | |
| 		c.userCloseFunc = onClose
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // WithUnaryClientInterceptor sets the provided client interceptor
 | |
| func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
 | |
| 	return func(c *Client) {
 | |
| 		c.interceptor = i
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // NewClient creates a new ttrpc client using the given connection
 | |
| func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
 | |
| 	ctx, cancel := context.WithCancel(context.Background())
 | |
| 	channel := newChannel(conn)
 | |
| 	c := &Client{
 | |
| 		codec:           codec{},
 | |
| 		conn:            conn,
 | |
| 		channel:         channel,
 | |
| 		streams:         make(map[streamID]*stream),
 | |
| 		nextStreamID:    1,
 | |
| 		closed:          cancel,
 | |
| 		ctx:             ctx,
 | |
| 		userCloseFunc:   func() {},
 | |
| 		userCloseWaitCh: make(chan struct{}),
 | |
| 		interceptor:     defaultClientInterceptor,
 | |
| 	}
 | |
| 
 | |
| 	for _, o := range opts {
 | |
| 		o(c)
 | |
| 	}
 | |
| 
 | |
| 	go c.run()
 | |
| 	return c
 | |
| }
 | |
| 
 | |
| func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
 | |
| 	c.sendLock.Lock()
 | |
| 	defer c.sendLock.Unlock()
 | |
| 	return c.channel.send(sid, mt, flags, b)
 | |
| }
 | |
| 
 | |
| // Call makes a unary request and returns with response
 | |
| func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
 | |
| 	payload, err := c.codec.Marshal(req)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var (
 | |
| 		creq = &Request{
 | |
| 			Service: service,
 | |
| 			Method:  method,
 | |
| 			Payload: payload,
 | |
| 			// TODO: metadata from context
 | |
| 		}
 | |
| 
 | |
| 		cresp = &Response{}
 | |
| 	)
 | |
| 
 | |
| 	if metadata, ok := GetMetadata(ctx); ok {
 | |
| 		metadata.setRequest(creq)
 | |
| 	}
 | |
| 
 | |
| 	if dl, ok := ctx.Deadline(); ok {
 | |
| 		creq.TimeoutNano = time.Until(dl).Nanoseconds()
 | |
| 	}
 | |
| 
 | |
| 	info := &UnaryClientInfo{
 | |
| 		FullMethod: fullPath(service, method),
 | |
| 	}
 | |
| 	if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
 | |
| 		return status.ErrorProto(cresp.Status)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// StreamDesc describes which sides of a stream are streaming: the client,
// the server, or both.
type StreamDesc struct {
	// StreamingClient reports whether the client may send messages on the
	// stream after it has been opened.
	StreamingClient bool

	// StreamingServer reports whether the server may send data messages on
	// the stream.
	StreamingServer bool
}
 | |
| 
 | |
// ClientStream is the client's handle on an open stream, used to send and
// receive messages on it.
type ClientStream interface {
	// CloseSend closes the sending side of the stream.
	CloseSend() error
	// SendMsg sends m on the stream.
	SendMsg(m interface{}) error
	// RecvMsg receives the next message from the stream into m.
	RecvMsg(m interface{}) error
}
 | |
| 
 | |
| type clientStream struct {
 | |
| 	ctx          context.Context
 | |
| 	s            *stream
 | |
| 	c            *Client
 | |
| 	desc         *StreamDesc
 | |
| 	localClosed  bool
 | |
| 	remoteClosed bool
 | |
| }
 | |
| 
 | |
| func (cs *clientStream) CloseSend() error {
 | |
| 	if !cs.desc.StreamingClient {
 | |
| 		return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
 | |
| 	}
 | |
| 	if cs.localClosed {
 | |
| 		return ErrStreamClosed
 | |
| 	}
 | |
| 	err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
 | |
| 	if err != nil {
 | |
| 		return filterCloseErr(err)
 | |
| 	}
 | |
| 	cs.localClosed = true
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (cs *clientStream) SendMsg(m interface{}) error {
 | |
| 	if !cs.desc.StreamingClient {
 | |
| 		return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
 | |
| 	}
 | |
| 	if cs.localClosed {
 | |
| 		return ErrStreamClosed
 | |
| 	}
 | |
| 
 | |
| 	var (
 | |
| 		payload []byte
 | |
| 		err     error
 | |
| 	)
 | |
| 	if m != nil {
 | |
| 		payload, err = cs.c.codec.Marshal(m)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err = cs.s.send(messageTypeData, 0, payload)
 | |
| 	if err != nil {
 | |
| 		return filterCloseErr(err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (cs *clientStream) RecvMsg(m interface{}) error {
 | |
| 	if cs.remoteClosed {
 | |
| 		return io.EOF
 | |
| 	}
 | |
| 
 | |
| 	var msg *streamMessage
 | |
| 	select {
 | |
| 	case <-cs.ctx.Done():
 | |
| 		return cs.ctx.Err()
 | |
| 	case <-cs.s.recvClose:
 | |
| 		// If recv has a pending message, process that first
 | |
| 		select {
 | |
| 		case msg = <-cs.s.recv:
 | |
| 		default:
 | |
| 			return cs.s.recvErr
 | |
| 		}
 | |
| 	case msg = <-cs.s.recv:
 | |
| 	}
 | |
| 
 | |
| 	if msg.header.Type == messageTypeResponse {
 | |
| 		resp := &Response{}
 | |
| 		err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
 | |
| 		// return the payload buffer for reuse
 | |
| 		cs.c.channel.putmbuf(msg.payload)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
 | |
| 			return status.ErrorProto(resp.Status)
 | |
| 		}
 | |
| 
 | |
| 		cs.c.deleteStream(cs.s)
 | |
| 		cs.remoteClosed = true
 | |
| 
 | |
| 		return nil
 | |
| 	} else if msg.header.Type == messageTypeData {
 | |
| 		if !cs.desc.StreamingServer {
 | |
| 			cs.c.deleteStream(cs.s)
 | |
| 			cs.remoteClosed = true
 | |
| 			return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
 | |
| 		}
 | |
| 		if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
 | |
| 			cs.c.deleteStream(cs.s)
 | |
| 			cs.remoteClosed = true
 | |
| 
 | |
| 			if msg.header.Flags&flagNoData == flagNoData {
 | |
| 				return io.EOF
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
 | |
| 		cs.c.channel.putmbuf(msg.payload)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
 | |
| }
 | |
| 
 | |
| // Close closes the ttrpc connection and underlying connection
 | |
| func (c *Client) Close() error {
 | |
| 	c.closeOnce.Do(func() {
 | |
| 		c.closed()
 | |
| 
 | |
| 		c.conn.Close()
 | |
| 	})
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // UserOnCloseWait is used to blocks untils the user's on-close callback
 | |
| // finishes.
 | |
| func (c *Client) UserOnCloseWait(ctx context.Context) error {
 | |
| 	select {
 | |
| 	case <-c.userCloseWaitCh:
 | |
| 		return nil
 | |
| 	case <-ctx.Done():
 | |
| 		return ctx.Err()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c *Client) run() {
 | |
| 	err := c.receiveLoop()
 | |
| 	c.Close()
 | |
| 	c.cleanupStreams(err)
 | |
| 
 | |
| 	c.userCloseFunc()
 | |
| 	close(c.userCloseWaitCh)
 | |
| }
 | |
| 
 | |
| func (c *Client) receiveLoop() error {
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-c.ctx.Done():
 | |
| 			return ErrClosed
 | |
| 		default:
 | |
| 			var (
 | |
| 				msg = &streamMessage{}
 | |
| 				err error
 | |
| 			)
 | |
| 
 | |
| 			msg.header, msg.payload, err = c.channel.recv()
 | |
| 			if err != nil {
 | |
| 				_, ok := status.FromError(err)
 | |
| 				if !ok {
 | |
| 					// treat all errors that are not an rpc status as terminal.
 | |
| 					// all others poison the connection.
 | |
| 					return filterCloseErr(err)
 | |
| 				}
 | |
| 			}
 | |
| 			sid := streamID(msg.header.StreamID)
 | |
| 			s := c.getStream(sid)
 | |
| 			if s == nil {
 | |
| 				logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if err != nil {
 | |
| 				s.closeWithError(err)
 | |
| 			} else {
 | |
| 				if err := s.receive(c.ctx, msg); err != nil {
 | |
| 					logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // createStream creates a new stream and registers it with the client
 | |
| // Introduce stream types for multiple or single response
 | |
| func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
 | |
| 	c.streamLock.Lock()
 | |
| 
 | |
| 	// Check if closed since lock acquired to prevent adding
 | |
| 	// anything after cleanup completes
 | |
| 	select {
 | |
| 	case <-c.ctx.Done():
 | |
| 		c.streamLock.Unlock()
 | |
| 		return nil, ErrClosed
 | |
| 	default:
 | |
| 	}
 | |
| 
 | |
| 	// Stream ID should be allocated at same time
 | |
| 	s := newStream(c.nextStreamID, c)
 | |
| 	c.streams[s.id] = s
 | |
| 	c.nextStreamID = c.nextStreamID + 2
 | |
| 
 | |
| 	c.sendLock.Lock()
 | |
| 	defer c.sendLock.Unlock()
 | |
| 	c.streamLock.Unlock()
 | |
| 
 | |
| 	if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
 | |
| 		return s, filterCloseErr(err)
 | |
| 	}
 | |
| 
 | |
| 	return s, nil
 | |
| }
 | |
| 
 | |
| func (c *Client) deleteStream(s *stream) {
 | |
| 	c.streamLock.Lock()
 | |
| 	delete(c.streams, s.id)
 | |
| 	c.streamLock.Unlock()
 | |
| 	s.closeWithError(nil)
 | |
| }
 | |
| 
 | |
| func (c *Client) getStream(sid streamID) *stream {
 | |
| 	c.streamLock.RLock()
 | |
| 	s := c.streams[sid]
 | |
| 	c.streamLock.RUnlock()
 | |
| 	return s
 | |
| }
 | |
| 
 | |
| func (c *Client) cleanupStreams(err error) {
 | |
| 	c.streamLock.Lock()
 | |
| 	defer c.streamLock.Unlock()
 | |
| 
 | |
| 	for sid, s := range c.streams {
 | |
| 		s.closeWithError(err)
 | |
| 		delete(c.streams, sid)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
 | |
| // returning from call or handling errors from main read loop.
 | |
| //
 | |
| // This purposely ignores errors with a wrapped cause.
 | |
| func filterCloseErr(err error) error {
 | |
| 	switch {
 | |
| 	case err == nil:
 | |
| 		return nil
 | |
| 	case err == io.EOF:
 | |
| 		return ErrClosed
 | |
| 	case errors.Is(err, io.ErrClosedPipe):
 | |
| 		return ErrClosed
 | |
| 	case errors.Is(err, io.EOF):
 | |
| 		return ErrClosed
 | |
| 	case strings.Contains(err.Error(), "use of closed network connection"):
 | |
| 		return ErrClosed
 | |
| 	default:
 | |
| 		// if we have an epipe on a write or econnreset on a read , we cast to errclosed
 | |
| 		var oerr *net.OpError
 | |
| 		if errors.As(err, &oerr) {
 | |
| 			if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
 | |
| 				(oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
 | |
| 				return ErrClosed
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // NewStream creates a new stream with the given stream descriptor to the
 | |
| // specified service and method. If not a streaming client, the request object
 | |
| // may be provided.
 | |
| func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
 | |
| 	var payload []byte
 | |
| 	if req != nil {
 | |
| 		var err error
 | |
| 		payload, err = c.codec.Marshal(req)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	request := &Request{
 | |
| 		Service: service,
 | |
| 		Method:  method,
 | |
| 		Payload: payload,
 | |
| 		// TODO: metadata from context
 | |
| 	}
 | |
| 	p, err := c.codec.Marshal(request)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	var flags uint8
 | |
| 	if desc.StreamingClient {
 | |
| 		flags = flagRemoteOpen
 | |
| 	} else {
 | |
| 		flags = flagRemoteClosed
 | |
| 	}
 | |
| 	s, err := c.createStream(flags, p)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &clientStream{
 | |
| 		ctx:  ctx,
 | |
| 		s:    s,
 | |
| 		c:    c,
 | |
| 		desc: desc,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
 | |
| 	p, err := c.codec.Marshal(req)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	s, err := c.createStream(0, p)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer c.deleteStream(s)
 | |
| 
 | |
| 	var msg *streamMessage
 | |
| 	select {
 | |
| 	case <-ctx.Done():
 | |
| 		return ctx.Err()
 | |
| 	case <-c.ctx.Done():
 | |
| 		return ErrClosed
 | |
| 	case <-s.recvClose:
 | |
| 		// If recv has a pending message, process that first
 | |
| 		select {
 | |
| 		case msg = <-s.recv:
 | |
| 		default:
 | |
| 			return s.recvErr
 | |
| 		}
 | |
| 	case msg = <-s.recv:
 | |
| 	}
 | |
| 
 | |
| 	if msg.header.Type == messageTypeResponse {
 | |
| 		err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
 | |
| 	} else {
 | |
| 		err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
 | |
| 	}
 | |
| 
 | |
| 	// return the payload buffer for reuse
 | |
| 	c.channel.putmbuf(msg.payload)
 | |
| 
 | |
| 	return err
 | |
| }
 |