mirror of https://github.com/knative/func.git
598 lines
13 KiB
Go
598 lines
13 KiB
Go
package ssh_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/md5"
|
|
"crypto/rsa"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type SSHServer struct {
|
|
lock sync.Locker
|
|
dockerServer http.Server
|
|
dockerListener listener
|
|
dockerHost string
|
|
hostIPv4 string
|
|
hostIPv6 string
|
|
portIPv4 int
|
|
portIPv6 int
|
|
hasDialStdio bool
|
|
isWin bool
|
|
serverKeys []any
|
|
authorizedKeys []any
|
|
}
|
|
|
|
func (s *SSHServer) SetIsWindows(v bool) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.isWin = v
|
|
}
|
|
|
|
func (s *SSHServer) IsWindows() bool {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
return s.isWin
|
|
}
|
|
|
|
func (s *SSHServer) SetDockerHostEnvVar(host string) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.dockerHost = host
|
|
}
|
|
|
|
func (s *SSHServer) GetDockerHostEnvVar() string {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
return s.dockerHost
|
|
}
|
|
|
|
func (s *SSHServer) HasDialStdio() bool {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
return s.hasDialStdio
|
|
}
|
|
|
|
func (s *SSHServer) SetHasDialStdio(v bool) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.hasDialStdio = v
|
|
}
|
|
|
|
const (
	// dockerUnixSocket is the only Unix socket path the fake server forwards
	// "direct-streamlocal@openssh.com" channels to; the in-memory Docker
	// connections also report it as their local address.
	dockerUnixSocket = "/home/testuser/test.sock"

	// dockerTCPSocket is the only host:port the fake server forwards
	// "direct-tcpip" channels to.
	dockerTCPSocket = "localhost:1234"
)
|
|
|
|
// We need to set up SSH server against which we will run the tests.
|
|
// This will return SSHServer structure representing the state of the testing server.
|
|
func prepareSSHServer(t *testing.T, authorizedKeys ...any) (sshServer *SSHServer, err error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer func() {
|
|
if err != nil {
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
httpServerErrChan := make(chan error)
|
|
pollingLoopErr := make(chan error)
|
|
pollingLoopIPv6Err := make(chan error)
|
|
|
|
handlePing := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
writer.Header().Add("Content-Type", "text/plain")
|
|
writer.WriteHeader(200)
|
|
_, _ = writer.Write([]byte("OK"))
|
|
})
|
|
|
|
sshServer = &SSHServer{
|
|
dockerServer: http.Server{
|
|
Handler: handlePing,
|
|
},
|
|
dockerListener: listener{conns: make(chan net.Conn), closed: make(chan struct{})},
|
|
lock: &sync.Mutex{},
|
|
authorizedKeys: authorizedKeys,
|
|
}
|
|
|
|
rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano())))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
sshServer.serverKeys = []any{rsaKey, ecdsaKey}
|
|
|
|
sshTCPListener, err := net.Listen("tcp4", "localhost:0")
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
hasIPv6 := true
|
|
sshTCP6Listener, err := net.Listen("tcp6", "localhost:0")
|
|
if err != nil {
|
|
hasIPv6 = false
|
|
t.Log(err)
|
|
}
|
|
|
|
host, p, err := net.SplitHostPort(sshTCPListener.Addr().String())
|
|
if err != nil {
|
|
return
|
|
}
|
|
port, err := strconv.ParseInt(p, 10, 32)
|
|
if err != nil {
|
|
return
|
|
}
|
|
sshServer.hostIPv4 = host
|
|
sshServer.portIPv4 = int(port)
|
|
|
|
if hasIPv6 {
|
|
host, p, err = net.SplitHostPort(sshTCP6Listener.Addr().String())
|
|
if err != nil {
|
|
return
|
|
}
|
|
port, err = strconv.ParseInt(p, 10, 32)
|
|
if err != nil {
|
|
return
|
|
}
|
|
sshServer.hostIPv6 = host
|
|
sshServer.portIPv6 = int(port)
|
|
}
|
|
|
|
t.Logf("Listening on %s", sshTCPListener.Addr())
|
|
if hasIPv6 {
|
|
t.Logf("Listening on %s", sshTCP6Listener.Addr())
|
|
}
|
|
|
|
go func() {
|
|
httpServerErrChan <- sshServer.dockerServer.Serve(&sshServer.dockerListener)
|
|
}()
|
|
|
|
stopSSH := func() {
|
|
var err error
|
|
cancel()
|
|
|
|
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
err = sshServer.dockerServer.Shutdown(stopCtx)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
err = <-httpServerErrChan
|
|
if err != nil && !strings.Contains(err.Error(), "Server closed") {
|
|
t.Error(err)
|
|
}
|
|
|
|
sshTCPListener.Close()
|
|
err = <-pollingLoopErr
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
t.Error(err)
|
|
}
|
|
|
|
if hasIPv6 {
|
|
sshTCP6Listener.Close()
|
|
err = <-pollingLoopIPv6Err
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
}
|
|
t.Cleanup(stopSSH)
|
|
|
|
connChan := make(chan net.Conn)
|
|
|
|
go func() {
|
|
for {
|
|
tcpConn, err := sshTCPListener.Accept()
|
|
if err != nil {
|
|
pollingLoopErr <- err
|
|
return
|
|
}
|
|
connChan <- tcpConn
|
|
}
|
|
}()
|
|
|
|
if hasIPv6 {
|
|
go func() {
|
|
for {
|
|
tcpConn, err := sshTCP6Listener.Accept()
|
|
if err != nil {
|
|
pollingLoopIPv6Err <- err
|
|
return
|
|
}
|
|
connChan <- tcpConn
|
|
}
|
|
}()
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
conn := <-connChan
|
|
go func(conn net.Conn) {
|
|
err := sshServer.handleConnection(ctx, conn)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
}(conn)
|
|
}
|
|
}()
|
|
|
|
return sshServer, err
|
|
}
|
|
|
|
func (s *SSHServer) setupServerAuth() (conf *ssh.ServerConfig, err error) {
|
|
passwd := map[string]string{
|
|
"testuser": "idkfa",
|
|
"root": "iddqd",
|
|
}
|
|
|
|
authorizedKeys := make(map[[16]byte][]byte, len(s.authorizedKeys))
|
|
for _, key := range s.authorizedKeys {
|
|
var pk ssh.PublicKey
|
|
pk, err = ssh.NewPublicKey(key)
|
|
if err != nil {
|
|
return
|
|
}
|
|
bs := pk.Marshal()
|
|
authorizedKeys[md5.Sum(bs)] = bs
|
|
}
|
|
|
|
conf = &ssh.ServerConfig{
|
|
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
if p, ok := passwd[conn.User()]; ok && p == string(password) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("incorrect password %q for user %q", string(password), conn.User())
|
|
},
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
keyBytes := key.Marshal()
|
|
if b, ok := authorizedKeys[md5.Sum(keyBytes)]; ok && bytes.Equal(b, keyBytes) {
|
|
return &ssh.Permissions{}, nil
|
|
}
|
|
return nil, fmt.Errorf("untrusted public key: %q", string(keyBytes))
|
|
},
|
|
}
|
|
|
|
for _, k := range s.serverKeys {
|
|
signer, e := ssh.NewSignerFromKey(k)
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
conf.AddHostKey(signer)
|
|
}
|
|
|
|
return conf, nil
|
|
}
|
|
|
|
func (s *SSHServer) handleConnection(ctx context.Context, conn net.Conn) error {
|
|
config, err := s.setupServerAuth()
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "cannot load auth: %v\n", err)
|
|
}
|
|
sshConn, newChannels, reqs, err := ssh.NewServerConn(conn, config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
err = sshConn.Close()
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
ssh.DiscardRequests(reqs)
|
|
}()
|
|
|
|
for newChannel := range newChannels {
|
|
wg.Add(1)
|
|
go func(newChannel ssh.NewChannel) {
|
|
defer wg.Done()
|
|
s.handleChannel(newChannel)
|
|
}(newChannel)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SSHServer) handleChannel(newChannel ssh.NewChannel) {
|
|
var err error
|
|
switch newChannel.ChannelType() {
|
|
case "session":
|
|
s.handleSession(newChannel)
|
|
case "direct-streamlocal@openssh.com", "direct-tcpip":
|
|
s.handleTunnel(newChannel)
|
|
default:
|
|
err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType()))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SSHServer) handleSession(newChannel ssh.NewChannel) {
|
|
ch, reqs, err := newChannel.Accept()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
|
|
defer ch.Close()
|
|
for req := range reqs {
|
|
if req.Type == "exec" {
|
|
s.handleExec(ch, req)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SSHServer) handleExec(ch ssh.Channel, req *ssh.Request) {
|
|
var err error
|
|
err = req.Reply(true, nil)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
execData := struct {
|
|
Command string
|
|
}{}
|
|
err = ssh.Unmarshal(req.Payload, &execData)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
|
|
sendExitCode := func(ret uint32) {
|
|
msg := []byte{0, 0, 0, 0}
|
|
binary.BigEndian.PutUint32(msg, ret)
|
|
_, err = ch.SendRequest("exit-status", false, msg)
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
}
|
|
|
|
var ret uint32
|
|
switch {
|
|
case execData.Command == "set":
|
|
ret = 0
|
|
dh := s.GetDockerHostEnvVar()
|
|
if dh != "" {
|
|
_, _ = fmt.Fprintf(ch, "DOCKER_HOST=%s\n", dh)
|
|
}
|
|
case execData.Command == "systeminfo" && s.IsWindows():
|
|
_, _ = fmt.Fprintln(ch, "something Windows something")
|
|
ret = 0
|
|
case execData.Command == "docker system dial-stdio --help" && s.HasDialStdio():
|
|
_, _ = fmt.Fprintln(ch, "\nUsage: docker system dial-stdio\n\nProxy the stdio stream to the daemon connection. Should not be invoked manually.")
|
|
ret = 0
|
|
case execData.Command == "docker system dial-stdio" && s.HasDialStdio():
|
|
pr, pw, conn := newPipeConn()
|
|
|
|
select {
|
|
case s.dockerListener.conns <- conn:
|
|
case <-s.dockerListener.closed:
|
|
err = ch.Close()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
}
|
|
|
|
cpDone := make(chan struct{})
|
|
go func() {
|
|
var err error
|
|
_, err = io.Copy(pw, ch)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
err = pw.Close()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
cpDone <- struct{}{}
|
|
}()
|
|
|
|
_, err = io.Copy(ch, pr)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
err = pr.Close()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
|
|
<-cpDone
|
|
|
|
<-conn.closed
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
|
|
ret = 0
|
|
default:
|
|
_, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", execData.Command)
|
|
ret = 127
|
|
}
|
|
sendExitCode(ret)
|
|
}
|
|
|
|
func newPipeConn() (*io.PipeReader, *io.PipeWriter, *rwcConn) {
|
|
pr0, pw0 := io.Pipe()
|
|
pr1, pw1 := io.Pipe()
|
|
rwc := pipeReaderWriterCloser{r: pr0, w: pw1}
|
|
return pr1, pw0, newRWCConn(rwc)
|
|
}
|
|
|
|
// pipeReaderWriterCloser joins the read end of one io.Pipe and the write end
// of another into a single io.ReadWriteCloser.
type pipeReaderWriterCloser struct {
	r *io.PipeReader // source for Read
	w *io.PipeWriter // destination for Write
}

// Read reads from the wrapped pipe reader.
func (rw pipeReaderWriterCloser) Read(p []byte) (n int, err error) {
	n, err = rw.r.Read(p)
	return n, err
}

// Write writes to the wrapped pipe writer.
func (rw pipeReaderWriterCloser) Write(p []byte) (n int, err error) {
	n, err = rw.w.Write(p)
	return n, err
}

// Close closes the reader and then the writer. If closing the reader fails,
// that error is returned and the writer is left open.
func (rw pipeReaderWriterCloser) Close() error {
	if readerErr := rw.r.Close(); readerErr != nil {
		return readerErr
	}
	return rw.w.Close()
}
|
|
|
|
func (s *SSHServer) handleTunnel(newChannel ssh.NewChannel) {
|
|
var err error
|
|
|
|
switch newChannel.ChannelType() {
|
|
case "direct-streamlocal@openssh.com":
|
|
bs := newChannel.ExtraData()
|
|
unixExtraData := struct {
|
|
SocketPath string
|
|
Reserved0 string
|
|
Reserved1 uint32
|
|
}{}
|
|
err = ssh.Unmarshal(bs, &unixExtraData)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
if unixExtraData.SocketPath != dockerUnixSocket {
|
|
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", unixExtraData.SocketPath))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
return
|
|
}
|
|
case "direct-tcpip":
|
|
bs := newChannel.ExtraData()
|
|
tcpExtraData := struct { //nolint:maligned
|
|
HostLocal string
|
|
PortLocal uint32
|
|
HostRemote string
|
|
PortRemote uint32
|
|
}{}
|
|
err = ssh.Unmarshal(bs, &tcpExtraData)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
|
|
hostPort := fmt.Sprintf("%s:%d", tcpExtraData.HostLocal, tcpExtraData.PortLocal)
|
|
if hostPort != dockerTCPSocket {
|
|
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: '%s:%d'", tcpExtraData.HostLocal, tcpExtraData.PortLocal))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
ch, _, err := newChannel.Accept()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
return
|
|
}
|
|
conn := newRWCConn(ch)
|
|
select {
|
|
case s.dockerListener.conns <- conn:
|
|
case <-s.dockerListener.closed:
|
|
err = ch.Close()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
|
}
|
|
return
|
|
}
|
|
<-conn.closed
|
|
}
|
|
|
|
type listener struct {
|
|
conns chan net.Conn
|
|
closed chan struct{}
|
|
o sync.Once
|
|
}
|
|
|
|
func (l *listener) Accept() (net.Conn, error) {
|
|
select {
|
|
case <-l.closed:
|
|
return nil, net.ErrClosed
|
|
case conn := <-l.conns:
|
|
return conn, nil
|
|
}
|
|
}
|
|
|
|
func (l *listener) Close() error {
|
|
l.o.Do(func() {
|
|
close(l.closed)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (l *listener) Addr() net.Addr {
|
|
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
|
|
}
|
|
|
|
func newRWCConn(rwc io.ReadWriteCloser) *rwcConn {
|
|
return &rwcConn{rwc: rwc, closed: make(chan struct{})}
|
|
}
|
|
|
|
type rwcConn struct {
|
|
rwc io.ReadWriteCloser
|
|
closed chan struct{}
|
|
o sync.Once
|
|
}
|
|
|
|
func (c *rwcConn) Read(b []byte) (n int, err error) {
|
|
return c.rwc.Read(b)
|
|
}
|
|
|
|
func (c *rwcConn) Write(b []byte) (n int, err error) {
|
|
return c.rwc.Write(b)
|
|
}
|
|
|
|
func (c *rwcConn) Close() error {
|
|
c.o.Do(func() {
|
|
close(c.closed)
|
|
})
|
|
return c.rwc.Close()
|
|
}
|
|
|
|
func (c *rwcConn) LocalAddr() net.Addr {
|
|
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
|
|
}
|
|
|
|
func (c *rwcConn) RemoteAddr() net.Addr {
|
|
return &net.UnixAddr{Name: "@", Net: "unix"}
|
|
}
|
|
|
|
func (c *rwcConn) SetDeadline(t time.Time) error { return nil }
|
|
|
|
func (c *rwcConn) SetReadDeadline(t time.Time) error { return nil }
|
|
|
|
func (c *rwcConn) SetWriteDeadline(t time.Time) error { return nil }
|