podman/vendor/github.com/dtylman/scp/scp.go

154 lines
3.1 KiB
Go

package scp
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
const (
	// buffSize is the size in bytes (256 KiB) of the scratch buffer copyN
	// hands to io.CopyBuffer.
	buffSize = 256 * 1024
	// fileMode is the octal permission string CopyTo places in the scp "C"
	// control line, i.e. the mode requested for every uploaded file.
	fileMode = "0644"
)
// CopyTo copies the local file at path local to path remote on the host
// behind sshClient, driving the remote side with "scp -t" (sink mode). It
// returns the number of file-content bytes sent.
//
// The remote file is always requested with mode fileMode, whatever the local
// file's permissions are. Output written by the remote scp is captured and
// only logged at debug level, and the remote exit status is not checked, so a
// failure the remote reports after the data has been sent is not returned.
//
// NOTE(review): filepath.Dir/Base split remote using the local OS separator,
// which is wrong for a POSIX remote when this runs on Windows; and the
// directory is concatenated unquoted into a remote shell command, so paths
// containing spaces or shell metacharacters will break.
func CopyTo(sshClient *ssh.Client, local string, remote string) (int64, error) {
	// Open and stat the local file before touching the remote side, so a bad
	// local path fails without starting a remote scp process.
	localFile, err := os.Open(local)
	if err != nil {
		return 0, err
	}
	// Release the descriptor on every return path.
	defer localFile.Close()
	fileInfo, err := localFile.Stat()
	if err != nil {
		return 0, err
	}

	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()
	stderr := &bytes.Buffer{}
	session.Stderr = stderr
	stdout := &bytes.Buffer{}
	session.Stdout = stdout
	writer, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer writer.Close()

	err = session.Start("scp -t " + filepath.Dir(remote))
	if err != nil {
		return 0, err
	}
	// Control line announcing one file: "C<mode> <size> <name>\n".
	_, err = fmt.Fprintf(writer, "C%s %d %s\n", fileMode, fileInfo.Size(), filepath.Base(remote))
	if err != nil {
		return 0, err
	}
	n, err := copyN(writer, localFile, fileInfo.Size())
	if err != nil {
		return 0, err
	}
	err = ack(writer)
	if err != nil {
		return 0, err
	}
	err = session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stdout:%v. stderr:%v", n, fileInfo.Size(), err, stdout, stderr)
	//NOTE: Process exited with status 1 is not an error, it just how scp work. (waiting for the next control message and we send EOF)
	return n, nil
}
// CopyFrom copies the file at path remote on the host behind sshClient to the
// local path local, driving the remote side with "scp -f" (source mode). It
// returns the number of file-content bytes written locally.
//
// An error is returned when a session, pipe, protocol or local-file step
// fails, or when the first message read from the remote is an error or
// warning message. The exit status of the remote scp process is only logged
// at debug level, never returned.
//
// NOTE(review): remote is concatenated unquoted into the remote shell
// command, so paths containing spaces or shell metacharacters will break.
func CopyFrom(sshClient *ssh.Client, remote string, local string) (int64, error) {
	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	errBuf := new(bytes.Buffer)
	session.Stderr = errBuf

	// Both pipes are requested before the remote command is started.
	in, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer in.Close()
	out, err := session.StdoutPipe()
	if err != nil {
		return 0, err
	}

	if err = session.Start("scp -f " + remote); err != nil {
		return 0, err
	}
	// Acknowledge once before reading the first message from the remote.
	if err = ack(in); err != nil {
		return 0, err
	}

	header, err := NewMessageFromReader(out)
	if err != nil {
		return 0, err
	}
	if header.Type == ErrorMessage || header.Type == WarnMessage {
		return 0, header.Error
	}
	log.Debugf("Receiving %v", header)

	// Acknowledge the header before reading the file body.
	if err = ack(in); err != nil {
		return 0, err
	}

	dst, err := os.Create(local)
	if err != nil {
		return 0, err
	}
	defer dst.Close()

	copied, err := copyN(dst, out, header.Size)
	if err != nil {
		return 0, err
	}
	// Sync, then close explicitly so a close error is reported; the deferred
	// Close only matters for the early returns above.
	if err = dst.Sync(); err != nil {
		return 0, err
	}
	if err = dst.Close(); err != nil {
		return 0, err
	}

	waitErr := session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stderr:%v", copied, header.Size, waitErr, errBuf)
	return copied, nil
}
func ack(writer io.Writer) error {
var msg = []byte{0, 0, 10, 13}
n, err := writer.Write(msg)
if err != nil {
return err
}
if n < len(msg) {
return errors.New("Failed to write ack buffer")
}
return nil
}
func copyN(writer io.Writer, src io.Reader, size int64) (int64, error) {
reader := io.LimitReader(src, size)
var total int64
for total < size {
n, err := io.CopyBuffer(writer, reader, make([]byte, buffSize))
log.Debugf("Copied chunk %v total: %v out of %v err: %v ", n, total, size, err)
if err != nil {
return 0, err
}
total += n
}
return total, nil
}