package ssh_test

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/md5"
	"crypto/rsa"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/crypto/ssh"
)

// SSHServer holds the state of the in-process SSH server used by the tests.
// The test knobs dockerHost, hasDialStdio and isWin are read and written only
// while holding lock, through the accessor methods below.
type SSHServer struct {
	lock           sync.Locker
	dockerServer   http.Server
	dockerListener listener
	dockerHost     string
	hostIPv4       string
	hostIPv6       string
	portIPv4       int
	portIPv6       int
	hasDialStdio   bool
	isWin          bool
	serverKeys     []any
	authorizedKeys []any
}

// withLock runs fn while holding s.lock.
func (s *SSHServer) withLock(fn func()) {
	s.lock.Lock()
	defer s.lock.Unlock()
	fn()
}

// SetIsWindows sets whether the server should pretend to be a Windows host.
func (s *SSHServer) SetIsWindows(v bool) {
	s.withLock(func() { s.isWin = v })
}

// IsWindows reports whether the server pretends to be a Windows host.
func (s *SSHServer) IsWindows() bool {
	var isWin bool
	s.withLock(func() { isWin = s.isWin })
	return isWin
}

// SetDockerHostEnvVar sets the DOCKER_HOST value reported by the "set" command.
func (s *SSHServer) SetDockerHostEnvVar(host string) {
	s.withLock(func() { s.dockerHost = host })
}

// GetDockerHostEnvVar returns the DOCKER_HOST value reported by the "set" command.
func (s *SSHServer) GetDockerHostEnvVar() string {
	var host string
	s.withLock(func() { host = s.dockerHost })
	return host
}

// HasDialStdio reports whether "docker system dial-stdio" is supported.
func (s *SSHServer) HasDialStdio() bool {
	var enabled bool
	s.withLock(func() { enabled = s.hasDialStdio })
	return enabled
}

// SetHasDialStdio sets whether "docker system dial-stdio" is supported.
func (s *SSHServer) SetHasDialStdio(v bool) {
	s.withLock(func() { s.hasDialStdio = v })
}

// Endpoints that tunnels must target to reach the fake Docker daemon.
const (
	dockerUnixSocket = "/home/testuser/test.sock"
	dockerTCPSocket  = "localhost:1234"
)

// We need to set up an SSH server against which we will run the tests.
// This returns an SSHServer structure representing the state of the testing server.
func prepareSSHServer(t *testing.T, authorizedKeys ...any) (sshServer *SSHServer, err error) { ctx, cancel := context.WithCancel(context.Background()) defer func() { if err != nil { cancel() } }() httpServerErrChan := make(chan error) pollingLoopErr := make(chan error) pollingLoopIPv6Err := make(chan error) handlePing := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { writer.Header().Add("Content-Type", "text/plain") writer.WriteHeader(200) _, _ = writer.Write([]byte("OK")) }) sshServer = &SSHServer{ dockerServer: http.Server{ Handler: handlePing, }, dockerListener: listener{conns: make(chan net.Conn), closed: make(chan struct{})}, lock: &sync.Mutex{}, authorizedKeys: authorizedKeys, } rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048) if err != nil { t.Fatal(err) } ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano()))) if err != nil { t.Fatal(err) } sshServer.serverKeys = []any{rsaKey, ecdsaKey} sshTCPListener, err := net.Listen("tcp4", "localhost:0") if err != nil { return } hasIPv6 := true sshTCP6Listener, err := net.Listen("tcp6", "localhost:0") if err != nil { hasIPv6 = false t.Log(err) } host, p, err := net.SplitHostPort(sshTCPListener.Addr().String()) if err != nil { return } port, err := strconv.ParseInt(p, 10, 32) if err != nil { return } sshServer.hostIPv4 = host sshServer.portIPv4 = int(port) if hasIPv6 { host, p, err = net.SplitHostPort(sshTCP6Listener.Addr().String()) if err != nil { return } port, err = strconv.ParseInt(p, 10, 32) if err != nil { return } sshServer.hostIPv6 = host sshServer.portIPv6 = int(port) } t.Logf("Listening on %s", sshTCPListener.Addr()) if hasIPv6 { t.Logf("Listening on %s", sshTCP6Listener.Addr()) } go func() { httpServerErrChan <- sshServer.dockerServer.Serve(&sshServer.dockerListener) }() stopSSH := func() { var err error cancel() stopCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer 
cancel() err = sshServer.dockerServer.Shutdown(stopCtx) if err != nil { t.Error(err) } err = <-httpServerErrChan if err != nil && !strings.Contains(err.Error(), "Server closed") { t.Error(err) } sshTCPListener.Close() err = <-pollingLoopErr if err != nil && !errors.Is(err, net.ErrClosed) { t.Error(err) } if hasIPv6 { sshTCP6Listener.Close() err = <-pollingLoopIPv6Err if err != nil && !errors.Is(err, net.ErrClosed) { t.Error(err) } } } t.Cleanup(stopSSH) connChan := make(chan net.Conn) go func() { for { tcpConn, err := sshTCPListener.Accept() if err != nil { pollingLoopErr <- err return } connChan <- tcpConn } }() if hasIPv6 { go func() { for { tcpConn, err := sshTCP6Listener.Accept() if err != nil { pollingLoopIPv6Err <- err return } connChan <- tcpConn } }() } go func() { for { conn := <-connChan go func(conn net.Conn) { err := sshServer.handleConnection(ctx, conn) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } }(conn) } }() return sshServer, err } func (s *SSHServer) setupServerAuth() (conf *ssh.ServerConfig, err error) { passwd := map[string]string{ "testuser": "idkfa", "root": "iddqd", } authorizedKeys := make(map[[16]byte][]byte, len(s.authorizedKeys)) for _, key := range s.authorizedKeys { var pk ssh.PublicKey pk, err = ssh.NewPublicKey(key) if err != nil { return } bs := pk.Marshal() authorizedKeys[md5.Sum(bs)] = bs } conf = &ssh.ServerConfig{ PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if p, ok := passwd[conn.User()]; ok && p == string(password) { return nil, nil } return nil, fmt.Errorf("incorrect password %q for user %q", string(password), conn.User()) }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { keyBytes := key.Marshal() if b, ok := authorizedKeys[md5.Sum(keyBytes)]; ok && bytes.Equal(b, keyBytes) { return &ssh.Permissions{}, nil } return nil, fmt.Errorf("untrusted public key: %q", string(keyBytes)) }, } for _, k := range s.serverKeys { signer, 
e := ssh.NewSignerFromKey(k) if e != nil { return nil, e } conf.AddHostKey(signer) } return conf, nil } func (s *SSHServer) handleConnection(ctx context.Context, conn net.Conn) error { config, err := s.setupServerAuth() if err != nil { _, _ = fmt.Fprintf(os.Stderr, "cannot load auth: %v\n", err) } sshConn, newChannels, reqs, err := ssh.NewServerConn(conn, config) if err != nil { return err } go func() { <-ctx.Done() err = sshConn.Close() if err != nil && !errors.Is(err, net.ErrClosed) { fmt.Fprintf(os.Stderr, "err: %v\n", err) } }() var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() ssh.DiscardRequests(reqs) }() for newChannel := range newChannels { wg.Add(1) go func(newChannel ssh.NewChannel) { defer wg.Done() s.handleChannel(newChannel) }(newChannel) } wg.Wait() return nil } func (s *SSHServer) handleChannel(newChannel ssh.NewChannel) { var err error switch newChannel.ChannelType() { case "session": s.handleSession(newChannel) case "direct-streamlocal@openssh.com", "direct-tcpip": s.handleTunnel(newChannel) default: err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType())) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } } } func (s *SSHServer) handleSession(newChannel ssh.NewChannel) { ch, reqs, err := newChannel.Accept() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } defer ch.Close() for req := range reqs { if req.Type == "exec" { s.handleExec(ch, req) break } } } func (s *SSHServer) handleExec(ch ssh.Channel, req *ssh.Request) { var err error err = req.Reply(true, nil) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } execData := struct { Command string }{} err = ssh.Unmarshal(req.Payload, &execData) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } sendExitCode := func(ret uint32) { msg := []byte{0, 0, 0, 0} binary.BigEndian.PutUint32(msg, ret) _, err = ch.SendRequest("exit-status", false, msg) if err != nil && 
!errors.Is(err, io.EOF) { fmt.Fprintf(os.Stderr, "err: %v\n", err) } } var ret uint32 switch { case execData.Command == "set": ret = 0 dh := s.GetDockerHostEnvVar() if dh != "" { _, _ = fmt.Fprintf(ch, "DOCKER_HOST=%s\n", dh) } case execData.Command == "systeminfo" && s.IsWindows(): _, _ = fmt.Fprintln(ch, "something Windows something") ret = 0 case execData.Command == "docker system dial-stdio --help" && s.HasDialStdio(): _, _ = fmt.Fprintln(ch, "\nUsage: docker system dial-stdio\n\nProxy the stdio stream to the daemon connection. Should not be invoked manually.") ret = 0 case execData.Command == "docker system dial-stdio" && s.HasDialStdio(): pr, pw, conn := newPipeConn() select { case s.dockerListener.conns <- conn: case <-s.dockerListener.closed: err = ch.Close() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } } cpDone := make(chan struct{}) go func() { var err error _, err = io.Copy(pw, ch) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } err = pw.Close() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } cpDone <- struct{}{} }() _, err = io.Copy(ch, pr) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } err = pr.Close() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } <-cpDone <-conn.closed if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } ret = 0 default: _, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", execData.Command) ret = 127 } sendExitCode(ret) } func newPipeConn() (*io.PipeReader, *io.PipeWriter, *rwcConn) { pr0, pw0 := io.Pipe() pr1, pw1 := io.Pipe() rwc := pipeReaderWriterCloser{r: pr0, w: pw1} return pr1, pw0, newRWCConn(rwc) } type pipeReaderWriterCloser struct { r *io.PipeReader w *io.PipeWriter } func (d pipeReaderWriterCloser) Read(p []byte) (n int, err error) { return d.r.Read(p) } func (d pipeReaderWriterCloser) Write(p []byte) (n int, err error) { return d.w.Write(p) } func (d pipeReaderWriterCloser) Close() error { err := d.r.Close() if err != nil { return err } return 
d.w.Close() } func (s *SSHServer) handleTunnel(newChannel ssh.NewChannel) { var err error switch newChannel.ChannelType() { case "direct-streamlocal@openssh.com": bs := newChannel.ExtraData() unixExtraData := struct { SocketPath string Reserved0 string Reserved1 uint32 }{} err = ssh.Unmarshal(bs, &unixExtraData) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } if unixExtraData.SocketPath != dockerUnixSocket { err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", unixExtraData.SocketPath)) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } return } case "direct-tcpip": bs := newChannel.ExtraData() tcpExtraData := struct { //nolint:maligned HostLocal string PortLocal uint32 HostRemote string PortRemote uint32 }{} err = ssh.Unmarshal(bs, &tcpExtraData) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } hostPort := fmt.Sprintf("%s:%d", tcpExtraData.HostLocal, tcpExtraData.PortLocal) if hostPort != dockerTCPSocket { err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: '%s:%d'", tcpExtraData.HostLocal, tcpExtraData.PortLocal)) if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } return } } ch, _, err := newChannel.Accept() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) return } conn := newRWCConn(ch) select { case s.dockerListener.conns <- conn: case <-s.dockerListener.closed: err = ch.Close() if err != nil { fmt.Fprintf(os.Stderr, "err: %v\n", err) } return } <-conn.closed } type listener struct { conns chan net.Conn closed chan struct{} o sync.Once } func (l *listener) Accept() (net.Conn, error) { select { case <-l.closed: return nil, net.ErrClosed case conn := <-l.conns: return conn, nil } } func (l *listener) Close() error { l.o.Do(func() { close(l.closed) }) return nil } func (l *listener) Addr() net.Addr { return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"} } func newRWCConn(rwc io.ReadWriteCloser) *rwcConn { return &rwcConn{rwc: rwc, closed: make(chan struct{})} } 
type rwcConn struct { rwc io.ReadWriteCloser closed chan struct{} o sync.Once } func (c *rwcConn) Read(b []byte) (n int, err error) { return c.rwc.Read(b) } func (c *rwcConn) Write(b []byte) (n int, err error) { return c.rwc.Write(b) } func (c *rwcConn) Close() error { c.o.Do(func() { close(c.closed) }) return c.rwc.Close() } func (c *rwcConn) LocalAddr() net.Addr { return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"} } func (c *rwcConn) RemoteAddr() net.Addr { return &net.UnixAddr{Name: "@", Net: "unix"} } func (c *rwcConn) SetDeadline(t time.Time) error { return nil } func (c *rwcConn) SetReadDeadline(t time.Time) error { return nil } func (c *rwcConn) SetWriteDeadline(t time.Time) error { return nil }