mirror of https://github.com/knative/func.git
392 lines
8.4 KiB
Go
392 lines
8.4 KiB
Go
package docker_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/docker/docker/client"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"knative.dev/func/pkg/docker"
|
|
)
|
|
|
|
func TestNewDockerClientWithSSH(t *testing.T) {
|
|
withCleanHome(t)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*1)
|
|
defer cancel()
|
|
|
|
sshConf := startSSH(t)
|
|
|
|
withKnowHosts(t, sshConf.address, sshConf.pubHostKey)
|
|
|
|
t.Setenv("DOCKER_HOST", fmt.Sprintf("ssh://user:pwd@%s", sshConf.address))
|
|
|
|
dockerClient, dockerHostInRemote, err := docker.NewClient(client.DefaultDockerHost)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer dockerClient.Close()
|
|
|
|
if dockerHostInRemote != `unix://`+sshDockerSocket {
|
|
t.Errorf("bad remote DOCKER_HOST: expected %q but got %q", `unix://`+sshDockerSocket, dockerHostInRemote)
|
|
}
|
|
|
|
_, err = dockerClient.Ping(ctx)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
// sshDockerSocket is the path of the Docker unix socket on the emulated
// remote machine.
const sshDockerSocket = "/some/path/docker.sock"
|
|
|
|
type sshConfig struct {
|
|
address string
|
|
pubHostKey ssh.PublicKey
|
|
}
|
|
|
|
// emulates remote machine with docker unix socket at "/some/path/docker.sock"
|
|
func startSSH(t *testing.T, authorizedKeys ...ssh.PublicKey) (settings sshConfig) {
|
|
var err error
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
httpServerErrChan := make(chan error, 1)
|
|
pollingLoopErr := make(chan error, 1)
|
|
|
|
config := &ssh.ServerConfig{
|
|
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
if string(password) != "pwd" {
|
|
return nil, errors.New("bad pwd")
|
|
}
|
|
return &ssh.Permissions{}, nil
|
|
},
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
for _, authKey := range authorizedKeys {
|
|
if bytes.Equal(authKey.Marshal(), key.Marshal()) {
|
|
return &ssh.Permissions{}, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("unknown public key")
|
|
},
|
|
}
|
|
|
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
hostKey, err := ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
config.AddHostKey(hostKey)
|
|
settings.pubHostKey = hostKey.PublicKey()
|
|
|
|
sshTCPListener, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
dockerDaemonServer := http.Server{}
|
|
t.Cleanup(func() {
|
|
var err error
|
|
cancel()
|
|
|
|
err = sshTCPListener.Close()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
err = <-pollingLoopErr
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
t.Error(err)
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
err = dockerDaemonServer.Shutdown(ctx)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
err = <-httpServerErrChan
|
|
if err != nil && !strings.Contains(err.Error(), "Server closed") {
|
|
t.Error(err)
|
|
}
|
|
|
|
})
|
|
|
|
settings.address = sshTCPListener.Addr().String()
|
|
|
|
t.Logf("Listening on %s", sshTCPListener.Addr())
|
|
|
|
// mimics /_ping endpoint
|
|
dockerDaemonServer.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
writer.Header().Add("Content-Type", "text/plain")
|
|
writer.WriteHeader(200)
|
|
_, _ = writer.Write([]byte("OK"))
|
|
})
|
|
|
|
// listener that emulates unix socket in remote accessed via SSH
|
|
dockerDaemonListener := listener{make(chan io.ReadWriteCloser, 128)}
|
|
|
|
go func() {
|
|
httpServerErrChan <- dockerDaemonServer.Serve(dockerDaemonListener)
|
|
}()
|
|
|
|
handleChannel := func(newChannel ssh.NewChannel) {
|
|
switch newChannel.ChannelType() {
|
|
case "session":
|
|
handleSession(t, newChannel)
|
|
case "direct-streamlocal@openssh.com":
|
|
handleTunnel(t, newChannel, dockerDaemonListener)
|
|
default:
|
|
err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType()))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
handleChannels := func(newChannels <-chan ssh.NewChannel) {
|
|
for newChannel := range newChannels {
|
|
go handleChannel(newChannel)
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
tcpConn, err := sshTCPListener.Accept()
|
|
if err != nil {
|
|
pollingLoopErr <- err
|
|
return
|
|
}
|
|
|
|
sshConn, newChannels, reqs, err := ssh.NewServerConn(tcpConn, config)
|
|
if err != nil {
|
|
pollingLoopErr <- err
|
|
return
|
|
}
|
|
go func() {
|
|
<-ctx.Done()
|
|
err = sshConn.Close()
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
t.Error(err)
|
|
}
|
|
}()
|
|
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
go handleChannels(newChannels)
|
|
}
|
|
}()
|
|
|
|
return
|
|
}
|
|
|
|
func handleSession(t *testing.T, newChannel ssh.NewChannel) {
|
|
ch, reqs, err := newChannel.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
go func() {
|
|
defer func() {
|
|
_ = ch.Close()
|
|
}()
|
|
for req := range reqs {
|
|
if req.Type == "exec" {
|
|
err = req.Reply(true, nil)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
data := struct {
|
|
Command string
|
|
}{}
|
|
err = ssh.Unmarshal(req.Payload, &data)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
var ret uint32
|
|
switch data.Command {
|
|
case "set":
|
|
ret = 0
|
|
_, _ = fmt.Fprintf(ch, "DOCKER_HOST=unix://%s\n", sshDockerSocket)
|
|
default:
|
|
_, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", data.Command)
|
|
ret = 127
|
|
}
|
|
msg := []byte{0, 0, 0, 0}
|
|
binary.BigEndian.PutUint32(msg, ret)
|
|
_, err = ch.SendRequest("exit-status", false, msg)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func handleTunnel(t *testing.T, newChannel ssh.NewChannel, dockerDaemonListener listener) {
|
|
var err error
|
|
extraData := newChannel.ExtraData()
|
|
data := struct {
|
|
SocketPath string
|
|
Reserved0 string
|
|
Reserved1 uint32
|
|
}{}
|
|
|
|
err = ssh.Unmarshal(extraData, &data)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
if data.SocketPath != sshDockerSocket {
|
|
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", data.SocketPath))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
ch, reqs, err := newChannel.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
select {
|
|
case dockerDaemonListener.connections <- ch:
|
|
default:
|
|
err = ch.Close()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
ssh.DiscardRequests(reqs)
|
|
}
|
|
|
|
// listener is a net.Listener that is fed through a channel instead of a
// socket: every value sent on connections is returned by a call to Accept.
type listener struct {
	// connections queues the streams waiting to be accepted.
	connections chan io.ReadWriteCloser
}
|
|
|
|
type channelConnection struct {
|
|
ch io.ReadWriteCloser
|
|
}
|
|
|
|
func (c channelConnection) Read(b []byte) (n int, err error) {
|
|
return c.ch.Read(b)
|
|
}
|
|
|
|
func (c channelConnection) Write(b []byte) (n int, err error) {
|
|
return c.ch.Write(b)
|
|
}
|
|
|
|
func (c channelConnection) Close() error {
|
|
return c.ch.Close()
|
|
}
|
|
|
|
func (c channelConnection) LocalAddr() net.Addr {
|
|
return &net.UnixAddr{Name: sshDockerSocket, Net: "unix"}
|
|
}
|
|
|
|
func (c channelConnection) RemoteAddr() net.Addr {
|
|
return &net.UnixAddr{Name: "@", Net: "unix"}
|
|
}
|
|
|
|
func (c channelConnection) SetDeadline(t time.Time) error { return nil }
|
|
|
|
func (c channelConnection) SetReadDeadline(t time.Time) error { return nil }
|
|
|
|
func (c channelConnection) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
func (l listener) Accept() (net.Conn, error) {
|
|
rwc, ok := <-l.connections
|
|
if !ok {
|
|
return nil, errors.New("listener closed")
|
|
}
|
|
return channelConnection{rwc}, nil
|
|
}
|
|
|
|
func (l listener) Close() error {
|
|
close(l.connections)
|
|
return nil
|
|
}
|
|
|
|
func (l listener) Addr() net.Addr {
|
|
return &net.UnixAddr{Name: sshDockerSocket, Net: "unix"}
|
|
}
|
|
|
|
// withCleanHome sets a clean temporary home directory for the duration of the
// test. This prevents interaction with the actual user home, which may
// contain .ssh/.
//
// The home variable is HOME, or USERPROFILE on Windows. The previous value is
// restored and the temporary directory is removed when the test finishes;
// both are handled by the testing package, which also reports any failure.
func withCleanHome(t *testing.T) {
	t.Helper()

	homeName := "HOME"
	if runtime.GOOS == "windows" {
		homeName = "USERPROFILE"
	}

	t.Setenv(homeName, t.TempDir())
}
|
|
|
|
// withKnowHosts creates $HOME/.ssh/known_hosts that trust the host
|
|
func withKnowHosts(t *testing.T, host string, pubKey ssh.PublicKey) {
|
|
t.Helper()
|
|
|
|
var err error
|
|
var home string
|
|
|
|
home, err = os.UserHomeDir()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
knownHosts := filepath.Join(home, ".ssh", "known_hosts")
|
|
|
|
_, err = os.Stat(knownHosts)
|
|
if err == nil || !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatal("known_hosts already exists")
|
|
}
|
|
|
|
err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
knownHostFile, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer knownHostFile.Close()
|
|
|
|
fmt.Fprintf(knownHostFile, "%s %s\n", host, string(ssh.MarshalAuthorizedKey(pubKey)))
|
|
|
|
t.Cleanup(func() {
|
|
os.Remove(knownHosts)
|
|
})
|
|
}
|