podman/vendor/github.com/digitalocean/go-libvirt/socket/socket.go

377 lines
8.8 KiB
Go

package socket
import (
"bufio"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"syscall"
"time"
"unsafe"
"github.com/digitalocean/go-libvirt/internal/constants"
)
// disconnectTimeout is how long Disconnect waits, after closing the
// connection, for the listen goroutine to signal that it has returned before
// giving up with an error.
const disconnectTimeout = 5 * time.Second
// Request and response statuses, carried in Header.Status.
// The values are assigned by iota in declaration order
// (StatusOK = 0, StatusError = 1, StatusContinue = 2), so the order of the
// entries below must not be changed.
const (
// StatusOK is always set for method calls or events.
// For replies it indicates successful completion of the method.
// For streams it indicates confirmation of the end of file on the stream.
StatusOK = iota
// StatusError for replies indicates that the method call failed
// and error information is being returned. For streams this indicates
// that not all data was sent and the stream has aborted.
StatusError
// StatusContinue is only used for streams.
// This indicates that further data packets will be following.
StatusContinue
)
// Request and response types, carried in Header.Type.
// The values are assigned by iota in declaration order
// (Call = 0 through ReplyWithFDs = 5), so the order of the entries below
// must not be changed.
const (
// Call is used when making calls to the remote server.
Call = iota
// Reply indicates a server reply.
Reply
// Message is an asynchronous notification.
Message
// Stream represents a stream data packet.
Stream
// CallWithFDs is used by a client to indicate the request has
// arguments with file descriptors.
CallWithFDs
// ReplyWithFDs is used by a server to indicate the request has
// arguments with file descriptors.
ReplyWithFDs
)
// Dialer is an interface for connecting to libvirt's underlying socket.
// Socket.Connect calls it once per connection attempt.
type Dialer interface {
// Dial establishes the connection and returns it, or returns a non-nil
// error if the connection could not be made. The returned net.Conn is the
// one the Socket reads packets from and writes packets to.
Dial() (net.Conn, error)
}
// Router is an interface used to route packets to the appropriate clients.
type Router interface {
// Route is invoked by the listen loop once for every packet that was read
// completely, with the decoded header and the raw payload bytes that
// followed it on the wire. It is called from the listen goroutine, so the
// next packet is not read until Route returns.
Route(*Header, []byte)
}
// Socket represents a libvirt Socket and its connection state
type Socket struct {
// dialer creates the underlying connection when Connect is called.
dialer Dialer
// router receives every packet read by the listen goroutine.
router Router
// conn is the current connection; it is set by Connect and closed by
// Disconnect.
conn net.Conn
// reader buffers reads from conn and is consumed by the listen goroutine.
reader *bufio.Reader
// writer buffers writes to conn; SendPacket flushes it after each packet.
writer *bufio.Writer
// used to serialize any Socket writes and any updates to conn, reader, or
// writer (held by Connect and SendPacket)
mu *sync.Mutex
// disconnected is closed when the listen goroutine associated with a
// Socket connection has returned. New starts it out closed and Connect
// replaces it with a fresh open channel.
disconnected chan struct{}
}
// packet represents a RPC request or response.
// It holds only the fixed-size prefix of a packet; the payload is handled
// separately. SendPacket serializes this struct with binary.Write in
// big-endian order, so the field order here is the order on the wire.
type packet struct {
// Size of packet, in bytes, including length.
// Len + Header + Payload
Len uint32
// Header follows the length on the wire and precedes the payload.
Header Header
}
// Global packet instance, for use with unsafe.Sizeof()
// In this file it is only ever passed to unsafe.Sizeof, to derive the
// fixed wire sizes of the length field, the header, and the two combined.
var _p packet
// Header is a libvirt rpc packet header
// extractHeader decodes the fields in declaration order, each from four
// big-endian bytes, so the field order and widths must not be changed.
type Header struct {
// Program identifier
Program uint32
// Program version
Version uint32
// Remote procedure identifier
Procedure uint32
// Call type, e.g., Reply
Type uint32
// Call serial number
Serial int32
// Request status, e.g., StatusOK
Status uint32
}
// New returns a Socket that will use dialer to establish connections and
// router to dispatch incoming packets. The returned Socket is not connected;
// call Connect to open the connection.
func New(dialer Dialer, router Router) *Socket {
	// A closed disconnected channel is how "no connection" is represented,
	// so the channel is created already closed.
	closed := make(chan struct{})
	close(closed)

	return &Socket{
		dialer:       dialer,
		router:       router,
		mu:           new(sync.Mutex),
		disconnected: closed,
	}
}
// Connect dials libvirt with the Dialer supplied to New and, on success,
// starts the goroutine that reads and routes incoming packets.
// It returns an error if the Socket is already connected or if dialing fails.
func (s *Socket) Connect() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if connected := !s.isDisconnected(); connected {
		return errors.New("already connected to socket")
	}

	c, dialErr := s.dialer.Dial()
	if dialErr != nil {
		return dialErr
	}

	s.conn, s.reader, s.writer = c, bufio.NewReader(c), bufio.NewWriter(c)

	// A fresh, open channel marks the Socket as connected; the listen
	// goroutine closes it when it exits.
	s.disconnected = make(chan struct{})
	go s.listenAndRoute()

	return nil
}
// Disconnect closes the Socket connection to libvirt and waits for the
// reader goroutine to shut down. It is a no-op when the Socket is already
// disconnected. An error is returned if closing the connection fails or if
// the reader goroutine has not exited within disconnectTimeout.
func (s *Socket) Disconnect() error {
	// Nothing to do when there is no live connection.
	if s.isDisconnected() {
		return nil
	}

	if closeErr := s.conn.Close(); closeErr != nil {
		return closeErr
	}

	// Closing the connection makes the reader goroutine's pending read fail;
	// wait (bounded) for it to notice and signal that it has returned.
	expired := time.After(disconnectTimeout)
	select {
	case <-s.disconnected:
		return nil
	case <-expired:
		return errors.New("timed out waiting for Disconnect cleanup")
	}
}
// Disconnected exposes, as a receive-only channel, the signal that is closed
// once the current connection is gone. That happens after an explicit call
// to Disconnect, or when the reader goroutine stops because of a
// non-temporary read error.
func (s *Socket) Disconnected() <-chan struct{} {
	var done <-chan struct{} = s.disconnected
	return done
}
// isDisconnected reports, without blocking, whether the disconnected channel
// has been closed, i.e. whether the Socket currently has no live connection.
func (s *Socket) isDisconnected() bool {
	closed := false
	select {
	case <-s.disconnected:
		// A receive on a closed channel succeeds immediately.
		closed = true
	default:
		// Channel still open: the reader goroutine is running.
	}
	return closed
}
// listenAndRoute is the body of the reader goroutine started by Connect.
// It feeds packets read from the Socket to the Router until the read loop
// gives up, then announces the disconnection.
func (s *Socket) listenAndRoute() {
	src, dst := s.reader, s.router

	// Blocks until listen hits a non-temporary error on the connection.
	listen(src, dst)

	// Wake up everyone waiting on Disconnected().
	close(s.disconnected)
}
// listen processes incoming data and routes responses to their respective
// callback handler.
//
// It loops reading packets from s (length prefix, header, payload) and hands
// each complete packet to router.Route. It returns only when the connection
// is no longer usable: on a non-temporary error while reading a length
// prefix, or when a packet advertises a length smaller than the fixed-size
// length+header prefix, which means the stream framing can no longer be
// trusted.
func listen(s io.Reader, router Router) {
	for {
		// response packet length
		length, err := pktlen(s)
		if err != nil {
			if isTemporary(err) {
				continue
			}
			// connection is no longer valid, so shutdown
			return
		}

		// response header
		h, err := extractHeader(s)
		if err != nil {
			// invalid packet
			continue
		}

		// payload: packet length minus what was previously read
		size := int(length) - int(unsafe.Sizeof(_p))
		if size < 0 {
			// The advertised length does not even cover the length field
			// and header that were already consumed. Allocating a buffer of
			// negative size would panic, and there is no way to find the
			// next packet boundary, so treat the connection as invalid.
			return
		}

		buf := make([]byte, size)
		if _, err = io.ReadFull(s, buf); err != nil {
			// invalid packet
			continue
		}

		// route response to caller
		router.Route(h, buf)
	}
}
// isTemporary returns true if the error returned from a read is transient.
// If the error is, or wraps, a *net.OpError, the result is whether that net
// connection error condition is temporary (which means we can keep using
// the connection).
// Errors that do not contain a *net.OpError tend to be things like io.EOF,
// syscall.EINVAL, or io.ErrClosedPipe (i.e. all things that
// indicate the connection in use is no longer valid), so false is returned
// for them. A nil error also yields false.
func isTemporary(err error) bool {
	// errors.As walks the wrap chain, so an OpError that was annotated with
	// extra context by an intermediate layer is still recognized.
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		return opErr.Temporary()
	}
	return false
}
// pktlen reads the length prefix of the next RPC packet from r and returns
// it decoded from big-endian. If the prefix cannot be read in full, the
// returned length is 0 and the read error is passed through.
func pktlen(r io.Reader) (uint32, error) {
	// The prefix is exactly as wide as the packet's Len field.
	prefix := make([]byte, unsafe.Sizeof(_p.Len))

	if _, readErr := io.ReadFull(r, prefix); readErr != nil {
		return 0, readErr
	}

	n := binary.BigEndian.Uint32(prefix)
	return n, nil
}
// extractHeader reads one fixed-size packet header from r and decodes it.
// The header consists of six consecutive big-endian 32-bit words, in the
// order of the Header fields. If the header cannot be read in full, nil and
// the read error are returned.
func extractHeader(r io.Reader) (*Header, error) {
	raw := make([]byte, unsafe.Sizeof(_p.Header))
	if _, readErr := io.ReadFull(r, raw); readErr != nil {
		return nil, readErr
	}

	// Decode every 4-byte word up front, then assign them by position.
	var words [6]uint32
	for i := range words {
		off := i * 4
		words[i] = binary.BigEndian.Uint32(raw[off : off+4])
	}

	hdr := &Header{
		Program:   words[0],
		Version:   words[1],
		Procedure: words[2],
		Type:      words[3],
		Serial:    int32(words[4]),
		Status:    words[5],
	}
	return hdr, nil
}
// SendPacket sends a packet to libvirt on the socket connection.
//
// The packet consists of the length prefix, a header built from serial,
// proc, program, typ and status (with the protocol version taken from the
// constants package), and then payload, which may be nil or empty for
// packets that carry no data. The buffered writer is flushed before
// returning, so on success the whole packet has been handed to the
// connection.
//
// It returns syscall.EINVAL if the Socket is disconnected, or the error
// from writing or flushing otherwise.
func (s *Socket) SendPacket(
	serial int32,
	proc uint32,
	program uint32,
	payload []byte,
	typ uint32,
	status uint32,
) error {
	p := packet{
		Header: Header{
			Program:   program,
			Version:   constants.ProtocolVersion,
			Procedure: proc,
			Type:      typ,
			Serial:    serial,
			Status:    status,
		},
	}

	// Total size on the wire: length field + header + payload.
	// len of a nil slice is 0, so no nil check is needed.
	size := int(unsafe.Sizeof(p.Len)) + int(unsafe.Sizeof(p.Header)) + len(payload)
	p.Len = uint32(size)

	if s.isDisconnected() {
		// this mirrors what a lot of net code return on use of a no
		// longer valid connection
		return syscall.EINVAL
	}

	// Hold the lock for the whole packet so that concurrent senders cannot
	// interleave their bytes.
	s.mu.Lock()
	defer s.mu.Unlock()

	// length prefix and header, big-endian
	if err := binary.Write(s.writer, binary.BigEndian, p); err != nil {
		return err
	}

	// The payload is raw bytes that go out verbatim, so write it straight
	// to the buffered writer instead of routing it through the encoding
	// layer of binary.Write.
	if len(payload) > 0 {
		if _, err := s.writer.Write(payload); err != nil {
			return err
		}
	}

	return s.writer.Flush()
}
// SendStream reads stream until it is exhausted and forwards its contents to
// libvirt as a sequence of Stream packets on the socket connection.
//
// Every chunk read is sent with StatusContinue. When stream reports io.EOF a
// final empty packet with StatusOK is sent. If abort becomes readable, or
// stream fails with any other error, an empty StatusError packet is sent
// instead; in the read-failure case the read error is returned unless
// sending that packet fails too, in which case the send error is returned.
func (s *Socket) SendStream(serial int32, proc uint32, program uint32,
	stream io.Reader, abort chan bool) error {
	// Keep total packet length under 4 MiB to follow possible limitation in libvirt server code
	chunk := make([]byte, 4*MiB-unsafe.Sizeof(_p))

	// send emits one Stream packet for this call with the given status.
	send := func(data []byte, status uint32) error {
		return s.SendPacket(serial, proc, program, data, Stream, status)
	}

	for {
		// Non-blocking check for an abort request before each read.
		select {
		case <-abort:
			return send(nil, StatusError)
		default:
		}

		n, readErr := stream.Read(chunk)

		// A reader may return data together with an error, so forward the
		// data first.
		if n > 0 {
			if sendErr := send(chunk[:n], StatusContinue); sendErr != nil {
				return sendErr
			}
		}

		switch {
		case readErr == nil:
			continue
		case readErr == io.EOF:
			return send(nil, StatusOK)
		}

		// keep original error
		if sendErr := send(nil, StatusError); sendErr != nil {
			return sendErr
		}
		return readErr
	}
}