mirror of https://github.com/containers/podman.git
154 lines
3.1 KiB
Go
154 lines
3.1 KiB
Go
package scp
|
|
|
|
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"path"
	"path/filepath"

	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
|
|
|
|
const (
	// fileMode is the permission string placed in the "C" control record
	// that CopyTo sends for every uploaded file.
	fileMode = "0644"

	// buffSize is the size in bytes (256 KiB) of the scratch buffer used by
	// copyN when streaming file contents.
	buffSize = 256 << 10
)
|
|
|
|
// CopyTo uploads the local file at local to the path remote on the host
// reached through sshClient. It runs "scp -t <dir of remote>" remotely and
// speaks the scp sink protocol to it. The file is always announced with mode
// fileMode and the base name of remote.
//
// It returns the number of bytes of file content sent. The exit status of the
// remote scp process is only logged, never returned (see the NOTE below).
//
// remote is treated as a slash-separated path on the remote host, whatever
// the local OS is.
//
// NOTE(review): the remote directory is concatenated into a shell command
// without quoting. Paths containing spaces or shell metacharacters will not
// behave as expected, so callers should pass trusted, simple paths.
func CopyTo(sshClient *ssh.Client, local string, remote string) (int64, error) {
	// Open and validate the local file before touching the remote side, so a
	// missing or unreadable file does not spawn a remote scp needlessly.
	localFile, err := os.Open(local)
	if err != nil {
		return 0, err
	}
	defer localFile.Close()

	fileInfo, err := localFile.Stat()
	if err != nil {
		return 0, err
	}
	if fileInfo.IsDir() {
		// Only single files are supported. Without this check a directory's
		// size would be announced in the header and the read would then fail.
		return 0, fmt.Errorf("cannot copy %q: is a directory", local)
	}

	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	stderr := &bytes.Buffer{}
	session.Stderr = stderr
	stdout := &bytes.Buffer{}
	session.Stdout = stdout

	writer, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer writer.Close()

	// The remote path is slash-separated. filepath.Dir would rewrite it with
	// backslashes on Windows clients, so use package path on a normalized copy.
	remotePath := filepath.ToSlash(remote)
	err = session.Start("scp -t " + path.Dir(remotePath))
	if err != nil {
		return 0, err
	}

	_, err = fmt.Fprintf(writer, "C%s %d %s\n", fileMode, fileInfo.Size(), path.Base(remotePath))
	if err != nil {
		return 0, err
	}

	n, err := copyN(writer, localFile, fileInfo.Size())
	if err != nil {
		return 0, err
	}

	err = ack(writer)
	if err != nil {
		return 0, err
	}

	// Close stdin explicitly so the remote scp receives EOF before we wait on
	// it. Otherwise the deferred Close would only run after Wait returns.
	// The error is ignored because the transfer itself has already finished.
	_ = writer.Close()

	err = session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stdout:%v. stderr:%v", n, fileInfo.Size(), err, stdout, stderr)
	// NOTE: "Process exited with status 1" is not an error; that is how scp
	// behaves when it waits for the next control message and gets EOF/garbage.
	return n, nil
}
|
|
|
|
// CopyFrom downloads the file at remote, on the host reached through
// sshClient, to the local path local. It runs "scp -f <remote>" remotely and
// speaks the scp source protocol to it.
//
// It returns the number of bytes of file content written to local. If the
// first message from the remote is an error or warning, that message's Error
// is returned. Any existing file at local is truncated by os.Create. The exit
// status of the remote scp process is only logged, never returned.
//
// NOTE(review): remote is concatenated into a shell command without quoting.
// Confirm callers only pass trusted paths without spaces or shell
// metacharacters.
func CopyFrom(sshClient *ssh.Client, remote string, local string) (int64, error) {
	session, err := sshClient.NewSession()
	if err != nil {
		return 0, err
	}
	defer session.Close()

	errBuf := &bytes.Buffer{}
	session.Stderr = errBuf

	remoteIn, err := session.StdinPipe()
	if err != nil {
		return 0, err
	}
	defer remoteIn.Close()

	remoteOut, err := session.StdoutPipe()
	if err != nil {
		return 0, err
	}

	if err := session.Start("scp -f " + remote); err != nil {
		return 0, err
	}

	// Signal readiness before the remote sends its first control message.
	if err := ack(remoteIn); err != nil {
		return 0, err
	}

	header, err := NewMessageFromReader(remoteOut)
	if err != nil {
		return 0, err
	}
	if header.Type == ErrorMessage || header.Type == WarnMessage {
		return 0, header.Error
	}
	log.Debugf("Receiving %v", header)

	// Acknowledge the header; the file body follows on remoteOut.
	if err := ack(remoteIn); err != nil {
		return 0, err
	}

	dst, err := os.Create(local)
	if err != nil {
		return 0, err
	}
	defer dst.Close()

	written, err := copyN(dst, remoteOut, header.Size)
	if err != nil {
		return 0, err
	}
	// Flush and close explicitly so their errors are reported. The deferred
	// Close above only covers the early-return paths.
	if err := dst.Sync(); err != nil {
		return 0, err
	}
	if err := dst.Close(); err != nil {
		return 0, err
	}

	waitErr := session.Wait()
	log.Debugf("Copied %v bytes out of %v. err: %v stderr:%v", written, header.Size, waitErr, errBuf)
	return written, nil
}
|
|
|
|
func ack(writer io.Writer) error {
|
|
var msg = []byte{0, 0, 10, 13}
|
|
n, err := writer.Write(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n < len(msg) {
|
|
return errors.New("Failed to write ack buffer")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func copyN(writer io.Writer, src io.Reader, size int64) (int64, error) {
|
|
reader := io.LimitReader(src, size)
|
|
var total int64
|
|
for total < size {
|
|
n, err := io.CopyBuffer(writer, reader, make([]byte, buffSize))
|
|
log.Debugf("Copied chunk %v total: %v out of %v err: %v ", n, total, size, err)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
total += n
|
|
}
|
|
return total, nil
|
|
}
|