mirror of https://github.com/knative/func.git
454 lines
12 KiB
Go
454 lines
12 KiB
Go
// NOTE: this code is based on "github.com/containers/podman/v3/pkg/bindings"
|
|
|
|
package ssh
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"errors"
	"fmt"
	"net"
	urlPkg "net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/docker/cli/cli/connhelper"
	"github.com/docker/docker/pkg/homedir"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"golang.org/x/crypto/ssh/knownhosts"
)
|
|
|
|
type PasswordCallback func() (string, error)
|
|
type PassPhraseCallback func() (string, error)
|
|
type HostKeyCallback func(hostPort string, pubKey ssh.PublicKey) error
|
|
|
|
type Config struct {
|
|
Identity string
|
|
PassPhrase string
|
|
PasswordCallback PasswordCallback
|
|
PassPhraseCallback PassPhraseCallback
|
|
HostKeyCallback HostKeyCallback
|
|
}
|
|
|
|
// DialContextFn is the signature of a context-aware dial function; it has the
// same shape as (*net.Dialer).DialContext.
type DialContextFn = func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
// NewDialContext allows access to docker daemon in a remote machine using SSH.
|
|
//
|
|
// It creates a new ContextDialer which dials docker daemon in the remote
|
|
// and also returns Docker Host URI as seen by the remote.
|
|
//
|
|
// Knowing the Docker Host is useful when mounting docker socket into a container.
|
|
//
|
|
// Dialing the remote docker daemon can be done in two ways:
|
|
//
|
|
// - Use SSH to tunnel Unix/TCP socket.
|
|
//
|
|
// - Use SSH to execute the "docker system dial-stdio" command in the remote and forward its stdio.
|
|
//
|
|
// The tunnel method is used whenever possible.
|
|
// The "stdio" method is used as a fallback when tunneling is not possible:
|
|
// e.g. when remote uses Windows' named pipe.
|
|
//
|
|
// When tunneling is used all connection dialed
|
|
// by the returned ContextDialer are tunneled via single SSH connection.
|
|
// The connection should be disposed when dialer is no longer needed.
|
|
//
|
|
// For this reason returned ContextDialer may also implement io.Closer.
|
|
// Caller of this function should check if the returned ContextDialer
|
|
// is also an instance of io.Closer and call Close() on it if it is.
|
|
func NewDialContext(url *urlPkg.URL, config Config) (ContextDialer, string, error) {
|
|
sshConfig, err := NewSSHClientConfig(url, config)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
port := url.Port()
|
|
if port == "" {
|
|
port = "22"
|
|
}
|
|
host := url.Hostname()
|
|
|
|
sshClient, err := ssh.Dial("tcp", net.JoinHostPort(host, port), sshConfig)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to dial ssh: %w", err)
|
|
}
|
|
defer func() {
|
|
if sshClient != nil {
|
|
sshClient.Close()
|
|
}
|
|
}()
|
|
|
|
var remoteDockerHost string
|
|
if url.Path != "" {
|
|
remoteDockerHost = fmt.Sprintf(`unix://%s`, url.Path)
|
|
} else {
|
|
remoteDockerHost, err = getRemoteDockerHost(sshClient)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
}
|
|
|
|
network, addr, err := getNetworkAndAddress(remoteDockerHost)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
if network == "npipe" {
|
|
// ssh tunneling doesn't support tunneling of Windows' named pipes
|
|
dialContext, err := stdioDialContext(url, sshClient, config.Identity)
|
|
return contextDialerFn(dialContext), remoteDockerHost, err
|
|
}
|
|
|
|
d := dialer{sshClient: sshClient, addr: addr, network: network}
|
|
// moving ownership of sshClient from this function to the returned structure
|
|
sshClient = nil
|
|
|
|
return &d, remoteDockerHost, nil
|
|
}
|
|
|
|
type dialer struct {
|
|
sshClient *ssh.Client
|
|
network string
|
|
addr string
|
|
}
|
|
|
|
// ContextDialer dials connections under a context. Its single method has the
// same signature as (*net.Dialer).DialContext.
type ContextDialer interface {
	DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
|
|
|
|
// contextDialerFn adapts a plain dial function (a DialContextFn) to the
// ContextDialer interface.
type contextDialerFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
|
|
|
func (n contextDialerFn) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return n(ctx, network, address)
|
|
}
|
|
|
|
func (d *dialer) DialContext(ctx context.Context, n, a string) (net.Conn, error) {
|
|
conn, err := d.Dial(d.network, d.addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go func() {
|
|
if ctx != nil {
|
|
<-ctx.Done()
|
|
conn.Close()
|
|
}
|
|
}()
|
|
return conn, nil
|
|
}
|
|
|
|
func (d *dialer) Dial(n, a string) (net.Conn, error) {
|
|
return d.sshClient.Dial(d.network, d.addr)
|
|
}
|
|
|
|
func (d *dialer) Close() error {
|
|
return d.sshClient.Close()
|
|
}
|
|
|
|
func isWindowsMachine(sshClient *ssh.Client) (bool, error) {
|
|
session, err := sshClient.NewSession()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer session.Close()
|
|
|
|
out, err := session.CombinedOutput("systeminfo")
|
|
if err == nil && strings.Contains(string(out), "Windows") {
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func getRemoteDockerHost(sshClient *ssh.Client) (remoteDockerHost string, err error) {
|
|
session, err := sshClient.NewSession()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer session.Close()
|
|
|
|
out, err := session.CombinedOutput("set")
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
remoteDockerHost = "unix:///var/run/docker.sock"
|
|
isWin, err := isWindowsMachine(sshClient)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if isWin {
|
|
remoteDockerHost = "npipe:////./pipe/docker_engine"
|
|
}
|
|
|
|
scanner := bufio.NewScanner(bytes.NewBuffer(out))
|
|
for scanner.Scan() {
|
|
if strings.HasPrefix(scanner.Text(), "DOCKER_HOST=") {
|
|
parts := strings.SplitN(scanner.Text(), "=", 2)
|
|
remoteDockerHost = strings.Trim(parts[1], `"'`)
|
|
break
|
|
}
|
|
}
|
|
|
|
return remoteDockerHost, err
|
|
}
|
|
|
|
// getNetworkAndAddress splits a Docker host URI into the network name and
// address to dial on the remote side.
//
//   - unix, npipe: the network is the scheme and the address is the URI path.
//   - tcp: the address is the URI host (host[:port]).
//   - fd: dialed as tcp, using the URI host as the address.
//
// Parse failures and any other scheme yield an error.
func getNetworkAndAddress(remoteDockerHost string) (network string, addr string, err error) {
	u, parseErr := urlPkg.Parse(remoteDockerHost)
	if parseErr != nil {
		return "", "", parseErr
	}

	switch scheme := u.Scheme; scheme {
	case "unix", "npipe":
		return scheme, u.Path, nil
	case "tcp", "fd":
		return "tcp", u.Host, nil
	}
	return "", "", errors.New("scheme is not supported")
}
|
|
|
|
func stdioDialContext(url *urlPkg.URL, sshClient *ssh.Client, identity string) (DialContextFn, error) {
|
|
session, err := sshClient.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
|
|
out, err := session.CombinedOutput("docker system dial-stdio --help")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot use dial-stdio: %w (%q)", err, out)
|
|
}
|
|
|
|
var opts []string
|
|
if identity != "" {
|
|
opts = append(opts, "-i", identity)
|
|
}
|
|
|
|
connHelper, err := connhelper.GetConnectionHelperWithSSHOpts(url.String(), opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return connHelper.Dialer, nil
|
|
}
|
|
|
|
// knownKeyNames lists the default private key file names looked up in
// ~/.ssh when neither an explicit identity nor ssh-agent keys are available.
var knownKeyNames = []string{
	"id_rsa",
	"id_dsa",
	"id_ecdsa",
	"id_ecdsa_sk",
	"id_ed25519",
	"id_ed25519_sk",
}
|
|
|
|
func NewSSHClientConfig(url *urlPkg.URL, credentialsConfig Config) (*ssh.ClientConfig, error) {
|
|
var (
|
|
authMethods []ssh.AuthMethod
|
|
signers []ssh.Signer
|
|
err error
|
|
)
|
|
|
|
if pw, found := url.User.Password(); found {
|
|
authMethods = append(authMethods, ssh.Password(pw))
|
|
}
|
|
|
|
// add signer from explicit identity parameter
|
|
if credentialsConfig.Identity != "" {
|
|
s, err := publicKey(credentialsConfig.Identity, []byte(credentialsConfig.Identity), credentialsConfig.PassPhraseCallback)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse identity file: %w", err)
|
|
}
|
|
signers = append(signers, s)
|
|
}
|
|
|
|
// add signers from ssh-agent
|
|
if sock, found := os.LookupEnv("SSH_AUTH_SOCK"); found && sock != "" {
|
|
var agentSigners []ssh.Signer
|
|
var agentConn net.Conn
|
|
agentConn, err = dialSSHAgentConnection(sock)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to ssh-agent's socket: %w", err)
|
|
}
|
|
agentSigners, err = agent.NewClient(agentConn).Signers()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
|
|
}
|
|
signers = append(signers, agentSigners...)
|
|
}
|
|
|
|
// if there is no explicit identity file nor keys from ssh-agent then
|
|
// add keys with standard name from ~/.ssh/
|
|
if len(signers) == 0 {
|
|
var defaultKeyPaths []string
|
|
if home, err := os.UserHomeDir(); err == nil {
|
|
for _, keyName := range knownKeyNames {
|
|
p := filepath.Join(home, ".ssh", keyName)
|
|
|
|
fi, err := os.Stat(p)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if fi.Mode().IsRegular() {
|
|
defaultKeyPaths = append(defaultKeyPaths, p)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(defaultKeyPaths) == 1 {
|
|
s, err := publicKey(defaultKeyPaths[0], []byte(credentialsConfig.PassPhrase), credentialsConfig.PassPhraseCallback)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
signers = append(signers, s)
|
|
}
|
|
}
|
|
|
|
if len(signers) > 0 {
|
|
var dedup = make(map[string]ssh.Signer)
|
|
// Dedup signers based on fingerprint, ssh-agent keys override explicit identity
|
|
for _, s := range signers {
|
|
fp := ssh.FingerprintSHA256(s.PublicKey())
|
|
//if _, found := dedup[fp]; found {
|
|
// key updated
|
|
//}
|
|
dedup[fp] = s
|
|
}
|
|
|
|
var uniq []ssh.Signer
|
|
for _, s := range dedup {
|
|
uniq = append(uniq, s)
|
|
}
|
|
authMethods = append(authMethods, ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
|
|
return uniq, nil
|
|
}))
|
|
}
|
|
|
|
if len(authMethods) == 0 && credentialsConfig.PasswordCallback != nil {
|
|
authMethods = append(authMethods, ssh.PasswordCallback(credentialsConfig.PasswordCallback))
|
|
}
|
|
|
|
const sshTimeout = 5
|
|
clientConfig := &ssh.ClientConfig{
|
|
User: url.User.Username(),
|
|
Auth: authMethods,
|
|
HostKeyCallback: createHostKeyCallback(credentialsConfig.HostKeyCallback),
|
|
HostKeyAlgorithms: []string{
|
|
ssh.KeyAlgoECDSA256,
|
|
ssh.KeyAlgoECDSA384,
|
|
ssh.KeyAlgoECDSA521,
|
|
ssh.KeyAlgoED25519,
|
|
ssh.KeyAlgoRSASHA256,
|
|
ssh.KeyAlgoRSASHA512,
|
|
ssh.KeyAlgoRSA,
|
|
ssh.KeyAlgoDSA,
|
|
},
|
|
Timeout: sshTimeout * time.Second,
|
|
}
|
|
|
|
return clientConfig, nil
|
|
}
|
|
|
|
func publicKey(path string, passphrase []byte, passPhraseCallback PassPhraseCallback) (ssh.Signer, error) {
|
|
key, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read key file: %w", err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(key)
|
|
if err != nil {
|
|
var missingPhraseError *ssh.PassphraseMissingError
|
|
if ok := errors.As(err, &missingPhraseError); !ok {
|
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
|
|
if len(passphrase) == 0 && passPhraseCallback != nil {
|
|
b, err := passPhraseCallback()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
passphrase = []byte(b)
|
|
}
|
|
|
|
return ssh.ParsePrivateKeyWithPassphrase(key, passphrase)
|
|
}
|
|
|
|
return signer, nil
|
|
}
|
|
|
|
func createHostKeyCallback(hostKeyCallback HostKeyCallback) func(hostPort string, remote net.Addr, key ssh.PublicKey) error {
|
|
return func(hostPort string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
host, port := hostPort, "22"
|
|
if _h, _p, err := net.SplitHostPort(host); err == nil {
|
|
host, port = _h, _p
|
|
}
|
|
|
|
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")
|
|
|
|
_, err := os.Stat(knownHosts)
|
|
if err != nil && errors.Is(err, os.ErrNotExist) {
|
|
if hostKeyCallback != nil && hostKeyCallback(hostPort, pubKey) == nil {
|
|
return nil
|
|
}
|
|
return errUnknownServerKey
|
|
}
|
|
|
|
f, err := os.Open(knownHosts)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open known_hosts: %w", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
hashhost := knownhosts.HashHostname(host)
|
|
|
|
var errs []error
|
|
scanner := bufio.NewScanner(f)
|
|
for scanner.Scan() {
|
|
_, hostPorts, _key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes())
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
|
|
for _, hp := range hostPorts {
|
|
h, p := hp, "22"
|
|
if _h, _p, err := net.SplitHostPort(hp); err == nil {
|
|
h, p = _h, _p
|
|
}
|
|
|
|
if (h == host || h == hashhost) && port == p {
|
|
if pubKey.Type() != _key.Type() {
|
|
errs = append(errs, fmt.Errorf("missmatch in type of a key"))
|
|
continue
|
|
}
|
|
if bytes.Equal(_key.Marshal(), pubKey.Marshal()) {
|
|
return nil
|
|
}
|
|
|
|
return errBadServerKey
|
|
}
|
|
}
|
|
}
|
|
|
|
if hostKeyCallback != nil && hostKeyCallback(hostPort, pubKey) == nil {
|
|
return nil
|
|
}
|
|
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("server is not trusted (%v)", errs)
|
|
}
|
|
|
|
return errUnknownServerKey
|
|
}
|
|
}
|
|
|
|
// Messages of the host key verification errors. They are exported as plain
// strings rather than error values because the ssh package does not wrap
// the errors returned by the host key callback, so callers can only match on
// the message text.
var (
	ErrBadServerKeyMsg     = "server key for given host differs from key in known_host"
	ErrUnknownServerKeyMsg = "server key not found in known_hosts"
)

// Error values returned by the host key callback; their text is the
// corresponding exported message above.
var (
	errBadServerKey     = errors.New(ErrBadServerKeyMsg)
	errUnknownServerKey = errors.New(ErrUnknownServerKeyMsg)
)
|