test: tests for SSH connector (#2003)

Signed-off-by: Matej Vasek <mvasek@redhat.com>
This commit is contained in:
Matej Vasek 2023-10-23 14:27:45 +02:00 committed by GitHub
parent a3ac5e7248
commit d65b812266
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1691 additions and 0 deletions

1
go.mod
View File

@ -27,6 +27,7 @@ require (
github.com/google/go-containerregistry v0.15.2
github.com/google/go-github/v49 v49.1.0
github.com/google/uuid v1.3.1
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95
github.com/heroku/color v0.0.6
github.com/hinshun/vt10x v0.0.0-20220228203356-1ab2cad5fd82
github.com/manifestival/client-go-client v0.5.0

3
go.sum
View File

@ -569,6 +569,8 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 h1:S4qyfL2sEm5Budr4KVMyEniCy+PbS55651I/a+Kn/NQ=
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95/go.mod h1:QiyDdbZLaJ/mZP4Zwc9g2QsfaEA4o7XvvgZegSci5/E=
github.com/heroku/color v0.0.6 h1:UTFFMrmMLFcL3OweqP1lAdp8i1y/9oHqkeHjQ/b/Ny0=
github.com/heroku/color v0.0.6/go.mod h1:ZBvOcx7cTF2QKOv4LbmoBtNl5uB17qWxGuzZrsi1wLU=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
@ -1155,6 +1157,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

597
pkg/ssh/server_test.go Normal file
View File

@ -0,0 +1,597 @@
package ssh_test
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/md5"
"crypto/rsa"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
type SSHServer struct {
lock sync.Locker
dockerServer http.Server
dockerListener listener
dockerHost string
hostIPv4 string
hostIPv6 string
portIPv4 int
portIPv6 int
hasDialStdio bool
isWin bool
serverKeys []any
authorizedKeys []any
}
func (s *SSHServer) SetIsWindows(v bool) {
s.lock.Lock()
defer s.lock.Unlock()
s.isWin = v
}
func (s *SSHServer) IsWindows() bool {
s.lock.Lock()
defer s.lock.Unlock()
return s.isWin
}
func (s *SSHServer) SetDockerHostEnvVar(host string) {
s.lock.Lock()
defer s.lock.Unlock()
s.dockerHost = host
}
func (s *SSHServer) GetDockerHostEnvVar() string {
s.lock.Lock()
defer s.lock.Unlock()
return s.dockerHost
}
func (s *SSHServer) HasDialStdio() bool {
s.lock.Lock()
defer s.lock.Unlock()
return s.hasDialStdio
}
func (s *SSHServer) SetHasDialStdio(v bool) {
s.lock.Lock()
defer s.lock.Unlock()
s.hasDialStdio = v
}
const dockerUnixSocket = "/home/testuser/test.sock"
const dockerTCPSocket = "localhost:1234"
// We need to set up SSH server against which we will run the tests.
// This will return SSHServer structure representing the state of the testing server.
func prepareSSHServer(t *testing.T, authorizedKeys ...any) (sshServer *SSHServer, err error) {
ctx, cancel := context.WithCancel(context.Background())
defer func() {
if err != nil {
cancel()
}
}()
httpServerErrChan := make(chan error)
pollingLoopErr := make(chan error)
pollingLoopIPv6Err := make(chan error)
handlePing := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Add("Content-Type", "text/plain")
writer.WriteHeader(200)
_, _ = writer.Write([]byte("OK"))
})
sshServer = &SSHServer{
dockerServer: http.Server{
Handler: handlePing,
},
dockerListener: listener{conns: make(chan net.Conn), closed: make(chan struct{})},
lock: &sync.Mutex{},
authorizedKeys: authorizedKeys,
}
rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048)
if err != nil {
t.Fatal(err)
}
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano())))
if err != nil {
t.Fatal(err)
}
sshServer.serverKeys = []any{rsaKey, ecdsaKey}
sshTCPListener, err := net.Listen("tcp4", "localhost:0")
if err != nil {
return
}
hasIPv6 := true
sshTCP6Listener, err := net.Listen("tcp6", "localhost:0")
if err != nil {
hasIPv6 = false
t.Log(err)
}
host, p, err := net.SplitHostPort(sshTCPListener.Addr().String())
if err != nil {
return
}
port, err := strconv.ParseInt(p, 10, 32)
if err != nil {
return
}
sshServer.hostIPv4 = host
sshServer.portIPv4 = int(port)
if hasIPv6 {
host, p, err = net.SplitHostPort(sshTCP6Listener.Addr().String())
if err != nil {
return
}
port, err = strconv.ParseInt(p, 10, 32)
if err != nil {
return
}
sshServer.hostIPv6 = host
sshServer.portIPv6 = int(port)
}
t.Logf("Listening on %s", sshTCPListener.Addr())
if hasIPv6 {
t.Logf("Listening on %s", sshTCP6Listener.Addr())
}
go func() {
httpServerErrChan <- sshServer.dockerServer.Serve(&sshServer.dockerListener)
}()
stopSSH := func() {
var err error
cancel()
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = sshServer.dockerServer.Shutdown(stopCtx)
if err != nil {
t.Error(err)
}
err = <-httpServerErrChan
if err != nil && !strings.Contains(err.Error(), "Server closed") {
t.Error(err)
}
sshTCPListener.Close()
err = <-pollingLoopErr
if err != nil && !errors.Is(err, net.ErrClosed) {
t.Error(err)
}
if hasIPv6 {
sshTCP6Listener.Close()
err = <-pollingLoopIPv6Err
if err != nil && !errors.Is(err, net.ErrClosed) {
t.Error(err)
}
}
}
t.Cleanup(stopSSH)
connChan := make(chan net.Conn)
go func() {
for {
tcpConn, err := sshTCPListener.Accept()
if err != nil {
pollingLoopErr <- err
return
}
connChan <- tcpConn
}
}()
if hasIPv6 {
go func() {
for {
tcpConn, err := sshTCP6Listener.Accept()
if err != nil {
pollingLoopIPv6Err <- err
return
}
connChan <- tcpConn
}
}()
}
go func() {
for {
conn := <-connChan
go func(conn net.Conn) {
err := sshServer.handleConnection(ctx, conn)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
}(conn)
}
}()
return sshServer, err
}
func (s *SSHServer) setupServerAuth() (conf *ssh.ServerConfig, err error) {
passwd := map[string]string{
"testuser": "idkfa",
"root": "iddqd",
}
authorizedKeys := make(map[[16]byte][]byte, len(s.authorizedKeys))
for _, key := range s.authorizedKeys {
var pk ssh.PublicKey
pk, err = ssh.NewPublicKey(key)
if err != nil {
return
}
bs := pk.Marshal()
authorizedKeys[md5.Sum(bs)] = bs
}
conf = &ssh.ServerConfig{
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if p, ok := passwd[conn.User()]; ok && p == string(password) {
return nil, nil
}
return nil, fmt.Errorf("incorrect password %q for user %q", string(password), conn.User())
},
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
keyBytes := key.Marshal()
if b, ok := authorizedKeys[md5.Sum(keyBytes)]; ok && bytes.Equal(b, keyBytes) {
return &ssh.Permissions{}, nil
}
return nil, fmt.Errorf("untrusted public key: %q", string(keyBytes))
},
}
for _, k := range s.serverKeys {
signer, e := ssh.NewSignerFromKey(k)
if e != nil {
return nil, e
}
conf.AddHostKey(signer)
}
return conf, nil
}
func (s *SSHServer) handleConnection(ctx context.Context, conn net.Conn) error {
config, err := s.setupServerAuth()
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "cannot load auth: %v\n", err)
}
sshConn, newChannels, reqs, err := ssh.NewServerConn(conn, config)
if err != nil {
return err
}
go func() {
<-ctx.Done()
err = sshConn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
}()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ssh.DiscardRequests(reqs)
}()
for newChannel := range newChannels {
wg.Add(1)
go func(newChannel ssh.NewChannel) {
defer wg.Done()
s.handleChannel(newChannel)
}(newChannel)
}
wg.Wait()
return nil
}
func (s *SSHServer) handleChannel(newChannel ssh.NewChannel) {
var err error
switch newChannel.ChannelType() {
case "session":
s.handleSession(newChannel)
case "direct-streamlocal@openssh.com", "direct-tcpip":
s.handleTunnel(newChannel)
default:
err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType()))
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
}
}
func (s *SSHServer) handleSession(newChannel ssh.NewChannel) {
ch, reqs, err := newChannel.Accept()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
defer ch.Close()
for req := range reqs {
if req.Type == "exec" {
s.handleExec(ch, req)
break
}
}
}
func (s *SSHServer) handleExec(ch ssh.Channel, req *ssh.Request) {
var err error
err = req.Reply(true, nil)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
execData := struct {
Command string
}{}
err = ssh.Unmarshal(req.Payload, &execData)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
sendExitCode := func(ret uint32) {
msg := []byte{0, 0, 0, 0}
binary.BigEndian.PutUint32(msg, ret)
_, err = ch.SendRequest("exit-status", false, msg)
if err != nil && !errors.Is(err, io.EOF) {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
}
var ret uint32
switch {
case execData.Command == "set":
ret = 0
dh := s.GetDockerHostEnvVar()
if dh != "" {
_, _ = fmt.Fprintf(ch, "DOCKER_HOST=%s\n", dh)
}
case execData.Command == "systeminfo" && s.IsWindows():
_, _ = fmt.Fprintln(ch, "something Windows something")
ret = 0
case execData.Command == "docker system dial-stdio --help" && s.HasDialStdio():
_, _ = fmt.Fprintln(ch, "\nUsage: docker system dial-stdio\n\nProxy the stdio stream to the daemon connection. Should not be invoked manually.")
ret = 0
case execData.Command == "docker system dial-stdio" && s.HasDialStdio():
pr, pw, conn := newPipeConn()
select {
case s.dockerListener.conns <- conn:
case <-s.dockerListener.closed:
err = ch.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
}
cpDone := make(chan struct{})
go func() {
var err error
_, err = io.Copy(pw, ch)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
err = pw.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
cpDone <- struct{}{}
}()
_, err = io.Copy(ch, pr)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
err = pr.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
<-cpDone
<-conn.closed
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
ret = 0
default:
_, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", execData.Command)
ret = 127
}
sendExitCode(ret)
}
func newPipeConn() (*io.PipeReader, *io.PipeWriter, *rwcConn) {
pr0, pw0 := io.Pipe()
pr1, pw1 := io.Pipe()
rwc := pipeReaderWriterCloser{r: pr0, w: pw1}
return pr1, pw0, newRWCConn(rwc)
}
type pipeReaderWriterCloser struct {
r *io.PipeReader
w *io.PipeWriter
}
func (d pipeReaderWriterCloser) Read(p []byte) (n int, err error) {
return d.r.Read(p)
}
func (d pipeReaderWriterCloser) Write(p []byte) (n int, err error) {
return d.w.Write(p)
}
func (d pipeReaderWriterCloser) Close() error {
err := d.r.Close()
if err != nil {
return err
}
return d.w.Close()
}
func (s *SSHServer) handleTunnel(newChannel ssh.NewChannel) {
var err error
switch newChannel.ChannelType() {
case "direct-streamlocal@openssh.com":
bs := newChannel.ExtraData()
unixExtraData := struct {
SocketPath string
Reserved0 string
Reserved1 uint32
}{}
err = ssh.Unmarshal(bs, &unixExtraData)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
if unixExtraData.SocketPath != dockerUnixSocket {
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", unixExtraData.SocketPath))
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
return
}
case "direct-tcpip":
bs := newChannel.ExtraData()
tcpExtraData := struct { //nolint:maligned
HostLocal string
PortLocal uint32
HostRemote string
PortRemote uint32
}{}
err = ssh.Unmarshal(bs, &tcpExtraData)
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
hostPort := fmt.Sprintf("%s:%d", tcpExtraData.HostLocal, tcpExtraData.PortLocal)
if hostPort != dockerTCPSocket {
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: '%s:%d'", tcpExtraData.HostLocal, tcpExtraData.PortLocal))
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
return
}
}
ch, _, err := newChannel.Accept()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
return
}
conn := newRWCConn(ch)
select {
case s.dockerListener.conns <- conn:
case <-s.dockerListener.closed:
err = ch.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "err: %v\n", err)
}
return
}
<-conn.closed
}
type listener struct {
conns chan net.Conn
closed chan struct{}
o sync.Once
}
func (l *listener) Accept() (net.Conn, error) {
select {
case <-l.closed:
return nil, net.ErrClosed
case conn := <-l.conns:
return conn, nil
}
}
func (l *listener) Close() error {
l.o.Do(func() {
close(l.closed)
})
return nil
}
func (l *listener) Addr() net.Addr {
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
}
func newRWCConn(rwc io.ReadWriteCloser) *rwcConn {
return &rwcConn{rwc: rwc, closed: make(chan struct{})}
}
type rwcConn struct {
rwc io.ReadWriteCloser
closed chan struct{}
o sync.Once
}
func (c *rwcConn) Read(b []byte) (n int, err error) {
return c.rwc.Read(b)
}
func (c *rwcConn) Write(b []byte) (n int, err error) {
return c.rwc.Write(b)
}
func (c *rwcConn) Close() error {
c.o.Do(func() {
close(c.closed)
})
return c.rwc.Close()
}
func (c *rwcConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
}
func (c *rwcConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "@", Net: "unix"}
}
func (c *rwcConn) SetDeadline(t time.Time) error { return nil }
func (c *rwcConn) SetReadDeadline(t time.Time) error { return nil }
func (c *rwcConn) SetWriteDeadline(t time.Time) error { return nil }

1026
pkg/ssh/ssh_dialer_test.go Normal file

File diff suppressed because it is too large Load Diff

25
pkg/ssh/ssh_posix_test.go Normal file
View File

@ -0,0 +1,25 @@
//go:build !windows
// +build !windows
package ssh_test
import (
"errors"
"net"
"os"
)
func fixupPrivateKeyMod(path string) {
err := os.Chmod(path, 0600)
if err != nil {
panic(err)
}
}
func listen(addr string) (net.Listener, error) {
return net.Listen("unix", addr)
}
func isErrClosed(err error) bool {
return errors.Is(err, net.ErrClosed)
}

View File

@ -0,0 +1,39 @@
package ssh_test
import (
"errors"
"net"
"os/user"
"strings"
"github.com/Microsoft/go-winio"
"github.com/hectane/go-acl"
)
func fixupPrivateKeyMod(path string) {
usr, err := user.Current()
if err != nil {
panic(err)
}
mode := uint32(0600)
err = acl.Apply(path,
true,
false,
acl.GrantName(((mode&0700)<<23)|((mode&0200)<<9), usr.Username))
// See https://github.com/hectane/go-acl/issues/1
if err != nil && err.Error() != "The operation completed successfully." {
panic(err)
}
}
func listen(addr string) (net.Listener, error) {
if strings.Contains(addr, "\\pipe\\") {
return winio.ListenPipe(addr, nil)
}
return net.Listen("unix", addr)
}
func isErrClosed(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, winio.ErrPipeListenerClosed) || errors.Is(err, winio.ErrFileClosed)
}