mirror of https://github.com/containers/podman.git
377 lines
8.8 KiB
Go
377 lines
8.8 KiB
Go
package socket
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/digitalocean/go-libvirt/internal/constants"
|
|
)
|
|
|
|
// disconnectTimeout bounds how long Disconnect waits for the listen goroutine
// to exit once the underlying connection has been closed.
const disconnectTimeout = 5 * time.Second
|
|
|
|
// Status values for requests and responses, as carried in Header.Status.
const (
	// StatusOK is always set for method calls or events.
	// On a reply it signals that the method completed successfully; on a
	// stream it confirms that the end of file has been reached.
	StatusOK = iota

	// StatusError on a reply signals that the method call failed and that
	// error information is being returned. On a stream it signals that the
	// stream was aborted before all data had been sent.
	StatusError

	// StatusContinue is only used for streams, where it signals that
	// further data packets will follow.
	StatusContinue
)
|
|
|
|
// Packet types for requests and responses, as carried in Header.Type.
const (
	// Call marks a call made to the remote server.
	Call = iota

	// Reply marks a reply sent by the server.
	Reply

	// Message marks an asynchronous notification.
	Message

	// Stream marks a stream data packet.
	Stream

	// CallWithFDs is used by a client to indicate that the request has
	// arguments with file descriptors.
	CallWithFDs

	// ReplyWithFDs is the server-side counterpart of CallWithFDs: it marks
	// a reply that is accompanied by file descriptors.
	ReplyWithFDs
)
|
|
|
|
// Dialer abstracts how the connection underlying a libvirt socket is
// established.
type Dialer interface {
	// Dial opens a new connection, or reports why one could not be opened.
	Dial() (conn net.Conn, err error)
}
|
|
|
|
// Router is an interface used to route packets to the appropriate clients.
|
|
type Router interface {
|
|
Route(*Header, []byte)
|
|
}
|
|
|
|
// Socket represents a libvirt Socket and its connection state
|
|
type Socket struct {
|
|
dialer Dialer
|
|
router Router
|
|
|
|
conn net.Conn
|
|
reader *bufio.Reader
|
|
writer *bufio.Writer
|
|
// used to serialize any Socket writes and any updates to conn, r, or w
|
|
mu *sync.Mutex
|
|
|
|
// disconnected is closed when the listen goroutine associated with a
|
|
// Socket connection has returned.
|
|
disconnected chan struct{}
|
|
}
|
|
|
|
// packet represents a RPC request or response.
|
|
type packet struct {
|
|
// Size of packet, in bytes, including length.
|
|
// Len + Header + Payload
|
|
Len uint32
|
|
Header Header
|
|
}
|
|
|
|
// Global packet instance, for use with unsafe.Sizeof()
|
|
var _p packet
|
|
|
|
// Header is the header of a libvirt RPC packet. The field order matches the
// order in which the values are encoded and decoded on the wire, so it must
// not be changed.
type Header struct {
	// Program is the program identifier.
	Program uint32

	// Version is the program version.
	Version uint32

	// Procedure is the remote procedure identifier.
	Procedure uint32

	// Type is the call type, e.g., Reply.
	Type uint32

	// Serial is the call serial number.
	Serial int32

	// Status is the request status, e.g., StatusOK.
	Status uint32
}
|
|
|
|
// New initializes a new type for managing the Socket.
|
|
func New(dialer Dialer, router Router) *Socket {
|
|
s := &Socket{
|
|
dialer: dialer,
|
|
router: router,
|
|
disconnected: make(chan struct{}),
|
|
mu: &sync.Mutex{},
|
|
}
|
|
|
|
// we start with a closed channel since that indicates no connection
|
|
close(s.disconnected)
|
|
|
|
return s
|
|
}
|
|
|
|
// Connect uses the dialer provided on creation to establish
|
|
// underlying physical connection to the desired libvirt.
|
|
func (s *Socket) Connect() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if !s.isDisconnected() {
|
|
return errors.New("already connected to socket")
|
|
}
|
|
conn, err := s.dialer.Dial()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.conn = conn
|
|
s.reader = bufio.NewReader(conn)
|
|
s.writer = bufio.NewWriter(conn)
|
|
s.disconnected = make(chan struct{})
|
|
|
|
go s.listenAndRoute()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes the Socket connection to libvirt and waits for the reader
|
|
// gorouting to shut down.
|
|
func (s *Socket) Disconnect() error {
|
|
// just return if we're already disconnected
|
|
if s.isDisconnected() {
|
|
return nil
|
|
}
|
|
|
|
err := s.conn.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// now we wait for the reader to return so as not to avoid it nil
|
|
// referencing
|
|
// Put this in a select,
|
|
// and have it only nil out the conn value if it doesn't fail
|
|
select {
|
|
case <-s.disconnected:
|
|
case <-time.After(disconnectTimeout):
|
|
return errors.New("timed out waiting for Disconnect cleanup")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnected returns a channel that will be closed once the current
|
|
// connection is closed. This can happen due to an explicit call to Disconnect
|
|
// from the client, or due to non-temporary Read or Write errors encountered.
|
|
func (s *Socket) Disconnected() <-chan struct{} {
|
|
return s.disconnected
|
|
}
|
|
|
|
// isDisconnected is a non-blocking function to query whether a connection
|
|
// is disconnected or not.
|
|
func (s *Socket) isDisconnected() bool {
|
|
select {
|
|
case <-s.disconnected:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// listenAndRoute reads packets from the Socket and calls the provided
|
|
// Router function to route them
|
|
func (s *Socket) listenAndRoute() {
|
|
// only returns once it detects a non-temporary error related to the
|
|
// underlying connection
|
|
listen(s.reader, s.router)
|
|
|
|
// signal any clients listening that the connection has been disconnected
|
|
close(s.disconnected)
|
|
}
|
|
|
|
// listen processes incoming data and routes
|
|
// responses to their respective callback handler.
|
|
func listen(s io.Reader, router Router) {
|
|
for {
|
|
// response packet length
|
|
length, err := pktlen(s)
|
|
if err != nil {
|
|
if isTemporary(err) {
|
|
continue
|
|
}
|
|
// connection is no longer valid, so shutdown
|
|
return
|
|
}
|
|
|
|
// response header
|
|
h, err := extractHeader(s)
|
|
if err != nil {
|
|
// invalid packet
|
|
continue
|
|
}
|
|
|
|
// payload: packet length minus what was previously read
|
|
size := int(length) - int(unsafe.Sizeof(_p))
|
|
buf := make([]byte, size)
|
|
_, err = io.ReadFull(s, buf)
|
|
if err != nil {
|
|
// invalid packet
|
|
continue
|
|
}
|
|
|
|
// route response to caller
|
|
router.Route(h, buf)
|
|
}
|
|
}
|
|
|
|
// isTemporary reports whether an error returned from a read is transient,
// meaning the connection can keep being used.
//
// Only errors that are, or wrap, a *net.OpError can be temporary; for those
// the answer is whatever the OpError's Temporary method reports. Every other
// error (nil included) yields false. Errors outside the net.OpError family
// tend to be things like io.EOF, syscall.EINVAL, or io.ErrClosedPipe, all of
// which indicate that the connection in use is no longer valid.
func isTemporary(err error) bool {
	// errors.As also matches an OpError that has been wrapped with extra
	// context, which a plain type assertion would miss.
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		return opErr.Temporary()
	}
	return false
}
|
|
|
|
// pktlen returns the length of an incoming RPC packet. Read errors will
|
|
// result in a returned response length of 0 and a non-nil error.
|
|
func pktlen(r io.Reader) (uint32, error) {
|
|
buf := make([]byte, unsafe.Sizeof(_p.Len))
|
|
|
|
// extract the packet's length from the header
|
|
_, err := io.ReadFull(r, buf)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return binary.BigEndian.Uint32(buf), nil
|
|
}
|
|
|
|
// extractHeader returns the decoded header from an incoming response.
|
|
func extractHeader(r io.Reader) (*Header, error) {
|
|
buf := make([]byte, unsafe.Sizeof(_p.Header))
|
|
|
|
// extract the packet's header from r
|
|
_, err := io.ReadFull(r, buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Header{
|
|
Program: binary.BigEndian.Uint32(buf[0:4]),
|
|
Version: binary.BigEndian.Uint32(buf[4:8]),
|
|
Procedure: binary.BigEndian.Uint32(buf[8:12]),
|
|
Type: binary.BigEndian.Uint32(buf[12:16]),
|
|
Serial: int32(binary.BigEndian.Uint32(buf[16:20])),
|
|
Status: binary.BigEndian.Uint32(buf[20:24]),
|
|
}, nil
|
|
}
|
|
|
|
// SendPacket sends a packet to libvirt on the socket connection.
|
|
func (s *Socket) SendPacket(
|
|
serial int32,
|
|
proc uint32,
|
|
program uint32,
|
|
payload []byte,
|
|
typ uint32,
|
|
status uint32,
|
|
) error {
|
|
p := packet{
|
|
Header: Header{
|
|
Program: program,
|
|
Version: constants.ProtocolVersion,
|
|
Procedure: proc,
|
|
Type: typ,
|
|
Serial: serial,
|
|
Status: status,
|
|
},
|
|
}
|
|
|
|
size := int(unsafe.Sizeof(p.Len)) + int(unsafe.Sizeof(p.Header))
|
|
if payload != nil {
|
|
size += len(payload)
|
|
}
|
|
p.Len = uint32(size)
|
|
|
|
if s.isDisconnected() {
|
|
// this mirrors what a lot of net code return on use of a no
|
|
// longer valid connection
|
|
return syscall.EINVAL
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
err := binary.Write(s.writer, binary.BigEndian, p)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// write payload
|
|
if payload != nil {
|
|
err = binary.Write(s.writer, binary.BigEndian, payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return s.writer.Flush()
|
|
}
|
|
|
|
// SendStream sends a stream of packets to libvirt on the socket connection.
|
|
func (s *Socket) SendStream(serial int32, proc uint32, program uint32,
|
|
stream io.Reader, abort chan bool) error {
|
|
// Keep total packet length under 4 MiB to follow possible limitation in libvirt server code
|
|
buf := make([]byte, 4*MiB-unsafe.Sizeof(_p))
|
|
for {
|
|
select {
|
|
case <-abort:
|
|
return s.SendPacket(serial, proc, program, nil, Stream, StatusError)
|
|
default:
|
|
}
|
|
n, err := stream.Read(buf)
|
|
if n > 0 {
|
|
err2 := s.SendPacket(serial, proc, program, buf[:n], Stream, StatusContinue)
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
}
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return s.SendPacket(serial, proc, program, nil, Stream, StatusOK)
|
|
}
|
|
// keep original error
|
|
err2 := s.SendPacket(serial, proc, program, nil, Stream, StatusError)
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
}
|