mirror of https://github.com/knative/func.git
test: tests for SSH connector (#2003)
Signed-off-by: Matej Vasek <mvasek@redhat.com>
This commit is contained in:
parent
a3ac5e7248
commit
d65b812266
1
go.mod
1
go.mod
|
@ -27,6 +27,7 @@ require (
|
|||
github.com/google/go-containerregistry v0.15.2
|
||||
github.com/google/go-github/v49 v49.1.0
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95
|
||||
github.com/heroku/color v0.0.6
|
||||
github.com/hinshun/vt10x v0.0.0-20220228203356-1ab2cad5fd82
|
||||
github.com/manifestival/client-go-client v0.5.0
|
||||
|
|
3
go.sum
3
go.sum
|
@ -569,6 +569,8 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP
|
|||
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 h1:S4qyfL2sEm5Budr4KVMyEniCy+PbS55651I/a+Kn/NQ=
|
||||
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95/go.mod h1:QiyDdbZLaJ/mZP4Zwc9g2QsfaEA4o7XvvgZegSci5/E=
|
||||
github.com/heroku/color v0.0.6 h1:UTFFMrmMLFcL3OweqP1lAdp8i1y/9oHqkeHjQ/b/Ny0=
|
||||
github.com/heroku/color v0.0.6/go.mod h1:ZBvOcx7cTF2QKOv4LbmoBtNl5uB17qWxGuzZrsi1wLU=
|
||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
||||
|
@ -1155,6 +1157,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
|
|
@ -0,0 +1,597 @@
|
|||
package ssh_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/md5"
|
||||
"crypto/rsa"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
lock sync.Locker
|
||||
dockerServer http.Server
|
||||
dockerListener listener
|
||||
dockerHost string
|
||||
hostIPv4 string
|
||||
hostIPv6 string
|
||||
portIPv4 int
|
||||
portIPv6 int
|
||||
hasDialStdio bool
|
||||
isWin bool
|
||||
serverKeys []any
|
||||
authorizedKeys []any
|
||||
}
|
||||
|
||||
func (s *SSHServer) SetIsWindows(v bool) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.isWin = v
|
||||
}
|
||||
|
||||
func (s *SSHServer) IsWindows() bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return s.isWin
|
||||
}
|
||||
|
||||
func (s *SSHServer) SetDockerHostEnvVar(host string) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.dockerHost = host
|
||||
}
|
||||
|
||||
func (s *SSHServer) GetDockerHostEnvVar() string {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return s.dockerHost
|
||||
}
|
||||
|
||||
func (s *SSHServer) HasDialStdio() bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return s.hasDialStdio
|
||||
}
|
||||
|
||||
func (s *SSHServer) SetHasDialStdio(v bool) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.hasDialStdio = v
|
||||
}
|
||||
|
||||
const dockerUnixSocket = "/home/testuser/test.sock"
|
||||
const dockerTCPSocket = "localhost:1234"
|
||||
|
||||
// We need to set up SSH server against which we will run the tests.
|
||||
// This will return SSHServer structure representing the state of the testing server.
|
||||
func prepareSSHServer(t *testing.T, authorizedKeys ...any) (sshServer *SSHServer, err error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
httpServerErrChan := make(chan error)
|
||||
pollingLoopErr := make(chan error)
|
||||
pollingLoopIPv6Err := make(chan error)
|
||||
|
||||
handlePing := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Add("Content-Type", "text/plain")
|
||||
writer.WriteHeader(200)
|
||||
_, _ = writer.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
sshServer = &SSHServer{
|
||||
dockerServer: http.Server{
|
||||
Handler: handlePing,
|
||||
},
|
||||
dockerListener: listener{conns: make(chan net.Conn), closed: make(chan struct{})},
|
||||
lock: &sync.Mutex{},
|
||||
authorizedKeys: authorizedKeys,
|
||||
}
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano())))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sshServer.serverKeys = []any{rsaKey, ecdsaKey}
|
||||
|
||||
sshTCPListener, err := net.Listen("tcp4", "localhost:0")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
hasIPv6 := true
|
||||
sshTCP6Listener, err := net.Listen("tcp6", "localhost:0")
|
||||
if err != nil {
|
||||
hasIPv6 = false
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
host, p, err := net.SplitHostPort(sshTCPListener.Addr().String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
port, err := strconv.ParseInt(p, 10, 32)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sshServer.hostIPv4 = host
|
||||
sshServer.portIPv4 = int(port)
|
||||
|
||||
if hasIPv6 {
|
||||
host, p, err = net.SplitHostPort(sshTCP6Listener.Addr().String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
port, err = strconv.ParseInt(p, 10, 32)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sshServer.hostIPv6 = host
|
||||
sshServer.portIPv6 = int(port)
|
||||
}
|
||||
|
||||
t.Logf("Listening on %s", sshTCPListener.Addr())
|
||||
if hasIPv6 {
|
||||
t.Logf("Listening on %s", sshTCP6Listener.Addr())
|
||||
}
|
||||
|
||||
go func() {
|
||||
httpServerErrChan <- sshServer.dockerServer.Serve(&sshServer.dockerListener)
|
||||
}()
|
||||
|
||||
stopSSH := func() {
|
||||
var err error
|
||||
cancel()
|
||||
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
err = sshServer.dockerServer.Shutdown(stopCtx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = <-httpServerErrChan
|
||||
if err != nil && !strings.Contains(err.Error(), "Server closed") {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
sshTCPListener.Close()
|
||||
err = <-pollingLoopErr
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if hasIPv6 {
|
||||
sshTCP6Listener.Close()
|
||||
err = <-pollingLoopIPv6Err
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Cleanup(stopSSH)
|
||||
|
||||
connChan := make(chan net.Conn)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
tcpConn, err := sshTCPListener.Accept()
|
||||
if err != nil {
|
||||
pollingLoopErr <- err
|
||||
return
|
||||
}
|
||||
connChan <- tcpConn
|
||||
}
|
||||
}()
|
||||
|
||||
if hasIPv6 {
|
||||
go func() {
|
||||
for {
|
||||
tcpConn, err := sshTCP6Listener.Accept()
|
||||
if err != nil {
|
||||
pollingLoopIPv6Err <- err
|
||||
return
|
||||
}
|
||||
connChan <- tcpConn
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn := <-connChan
|
||||
go func(conn net.Conn) {
|
||||
err := sshServer.handleConnection(ctx, conn)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return sshServer, err
|
||||
}
|
||||
|
||||
func (s *SSHServer) setupServerAuth() (conf *ssh.ServerConfig, err error) {
|
||||
passwd := map[string]string{
|
||||
"testuser": "idkfa",
|
||||
"root": "iddqd",
|
||||
}
|
||||
|
||||
authorizedKeys := make(map[[16]byte][]byte, len(s.authorizedKeys))
|
||||
for _, key := range s.authorizedKeys {
|
||||
var pk ssh.PublicKey
|
||||
pk, err = ssh.NewPublicKey(key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
bs := pk.Marshal()
|
||||
authorizedKeys[md5.Sum(bs)] = bs
|
||||
}
|
||||
|
||||
conf = &ssh.ServerConfig{
|
||||
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if p, ok := passwd[conn.User()]; ok && p == string(password) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("incorrect password %q for user %q", string(password), conn.User())
|
||||
},
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
keyBytes := key.Marshal()
|
||||
if b, ok := authorizedKeys[md5.Sum(keyBytes)]; ok && bytes.Equal(b, keyBytes) {
|
||||
return &ssh.Permissions{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("untrusted public key: %q", string(keyBytes))
|
||||
},
|
||||
}
|
||||
|
||||
for _, k := range s.serverKeys {
|
||||
signer, e := ssh.NewSignerFromKey(k)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
conf.AddHostKey(signer)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleConnection(ctx context.Context, conn net.Conn) error {
|
||||
config, err := s.setupServerAuth()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "cannot load auth: %v\n", err)
|
||||
}
|
||||
sshConn, newChannels, reqs, err := ssh.NewServerConn(conn, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = sshConn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ssh.DiscardRequests(reqs)
|
||||
}()
|
||||
|
||||
for newChannel := range newChannels {
|
||||
wg.Add(1)
|
||||
go func(newChannel ssh.NewChannel) {
|
||||
defer wg.Done()
|
||||
s.handleChannel(newChannel)
|
||||
}(newChannel)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleChannel(newChannel ssh.NewChannel) {
|
||||
var err error
|
||||
switch newChannel.ChannelType() {
|
||||
case "session":
|
||||
s.handleSession(newChannel)
|
||||
case "direct-streamlocal@openssh.com", "direct-tcpip":
|
||||
s.handleTunnel(newChannel)
|
||||
default:
|
||||
err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType()))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleSession(newChannel ssh.NewChannel) {
|
||||
ch, reqs, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer ch.Close()
|
||||
for req := range reqs {
|
||||
if req.Type == "exec" {
|
||||
s.handleExec(ch, req)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleExec(ch ssh.Channel, req *ssh.Request) {
|
||||
var err error
|
||||
err = req.Reply(true, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
execData := struct {
|
||||
Command string
|
||||
}{}
|
||||
err = ssh.Unmarshal(req.Payload, &execData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
sendExitCode := func(ret uint32) {
|
||||
msg := []byte{0, 0, 0, 0}
|
||||
binary.BigEndian.PutUint32(msg, ret)
|
||||
_, err = ch.SendRequest("exit-status", false, msg)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
var ret uint32
|
||||
switch {
|
||||
case execData.Command == "set":
|
||||
ret = 0
|
||||
dh := s.GetDockerHostEnvVar()
|
||||
if dh != "" {
|
||||
_, _ = fmt.Fprintf(ch, "DOCKER_HOST=%s\n", dh)
|
||||
}
|
||||
case execData.Command == "systeminfo" && s.IsWindows():
|
||||
_, _ = fmt.Fprintln(ch, "something Windows something")
|
||||
ret = 0
|
||||
case execData.Command == "docker system dial-stdio --help" && s.HasDialStdio():
|
||||
_, _ = fmt.Fprintln(ch, "\nUsage: docker system dial-stdio\n\nProxy the stdio stream to the daemon connection. Should not be invoked manually.")
|
||||
ret = 0
|
||||
case execData.Command == "docker system dial-stdio" && s.HasDialStdio():
|
||||
pr, pw, conn := newPipeConn()
|
||||
|
||||
select {
|
||||
case s.dockerListener.conns <- conn:
|
||||
case <-s.dockerListener.closed:
|
||||
err = ch.Close()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
cpDone := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
_, err = io.Copy(pw, ch)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
err = pw.Close()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
cpDone <- struct{}{}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(ch, pr)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
err = pr.Close()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
|
||||
<-cpDone
|
||||
|
||||
<-conn.closed
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
|
||||
ret = 0
|
||||
default:
|
||||
_, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", execData.Command)
|
||||
ret = 127
|
||||
}
|
||||
sendExitCode(ret)
|
||||
}
|
||||
|
||||
func newPipeConn() (*io.PipeReader, *io.PipeWriter, *rwcConn) {
|
||||
pr0, pw0 := io.Pipe()
|
||||
pr1, pw1 := io.Pipe()
|
||||
rwc := pipeReaderWriterCloser{r: pr0, w: pw1}
|
||||
return pr1, pw0, newRWCConn(rwc)
|
||||
}
|
||||
|
||||
type pipeReaderWriterCloser struct {
|
||||
r *io.PipeReader
|
||||
w *io.PipeWriter
|
||||
}
|
||||
|
||||
func (d pipeReaderWriterCloser) Read(p []byte) (n int, err error) {
|
||||
return d.r.Read(p)
|
||||
}
|
||||
|
||||
func (d pipeReaderWriterCloser) Write(p []byte) (n int, err error) {
|
||||
return d.w.Write(p)
|
||||
}
|
||||
|
||||
func (d pipeReaderWriterCloser) Close() error {
|
||||
err := d.r.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.w.Close()
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleTunnel(newChannel ssh.NewChannel) {
|
||||
var err error
|
||||
|
||||
switch newChannel.ChannelType() {
|
||||
case "direct-streamlocal@openssh.com":
|
||||
bs := newChannel.ExtraData()
|
||||
unixExtraData := struct {
|
||||
SocketPath string
|
||||
Reserved0 string
|
||||
Reserved1 uint32
|
||||
}{}
|
||||
err = ssh.Unmarshal(bs, &unixExtraData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
if unixExtraData.SocketPath != dockerUnixSocket {
|
||||
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", unixExtraData.SocketPath))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
case "direct-tcpip":
|
||||
bs := newChannel.ExtraData()
|
||||
tcpExtraData := struct { //nolint:maligned
|
||||
HostLocal string
|
||||
PortLocal uint32
|
||||
HostRemote string
|
||||
PortRemote uint32
|
||||
}{}
|
||||
err = ssh.Unmarshal(bs, &tcpExtraData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
hostPort := fmt.Sprintf("%s:%d", tcpExtraData.HostLocal, tcpExtraData.PortLocal)
|
||||
if hostPort != dockerTCPSocket {
|
||||
err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: '%s:%d'", tcpExtraData.HostLocal, tcpExtraData.PortLocal))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch, _, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
return
|
||||
}
|
||||
conn := newRWCConn(ch)
|
||||
select {
|
||||
case s.dockerListener.conns <- conn:
|
||||
case <-s.dockerListener.closed:
|
||||
err = ch.Close()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "err: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
<-conn.closed
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
conns chan net.Conn
|
||||
closed chan struct{}
|
||||
o sync.Once
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return nil, net.ErrClosed
|
||||
case conn := <-l.conns:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) Close() error {
|
||||
l.o.Do(func() {
|
||||
close(l.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
|
||||
}
|
||||
|
||||
func newRWCConn(rwc io.ReadWriteCloser) *rwcConn {
|
||||
return &rwcConn{rwc: rwc, closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
type rwcConn struct {
|
||||
rwc io.ReadWriteCloser
|
||||
closed chan struct{}
|
||||
o sync.Once
|
||||
}
|
||||
|
||||
func (c *rwcConn) Read(b []byte) (n int, err error) {
|
||||
return c.rwc.Read(b)
|
||||
}
|
||||
|
||||
func (c *rwcConn) Write(b []byte) (n int, err error) {
|
||||
return c.rwc.Write(b)
|
||||
}
|
||||
|
||||
func (c *rwcConn) Close() error {
|
||||
c.o.Do(func() {
|
||||
close(c.closed)
|
||||
})
|
||||
return c.rwc.Close()
|
||||
}
|
||||
|
||||
func (c *rwcConn) LocalAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *rwcConn) RemoteAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "@", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *rwcConn) SetDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (c *rwcConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (c *rwcConn) SetWriteDeadline(t time.Time) error { return nil }
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,25 @@
|
|||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package ssh_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
func fixupPrivateKeyMod(path string) {
|
||||
err := os.Chmod(path, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listen(addr string) (net.Listener, error) {
|
||||
return net.Listen("unix", addr)
|
||||
}
|
||||
|
||||
func isErrClosed(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed)
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package ssh_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"github.com/hectane/go-acl"
|
||||
)
|
||||
|
||||
func fixupPrivateKeyMod(path string) {
|
||||
usr, err := user.Current()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mode := uint32(0600)
|
||||
err = acl.Apply(path,
|
||||
true,
|
||||
false,
|
||||
acl.GrantName(((mode&0700)<<23)|((mode&0200)<<9), usr.Username))
|
||||
|
||||
// See https://github.com/hectane/go-acl/issues/1
|
||||
if err != nil && err.Error() != "The operation completed successfully." {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listen(addr string) (net.Listener, error) {
|
||||
if strings.Contains(addr, "\\pipe\\") {
|
||||
return winio.ListenPipe(addr, nil)
|
||||
}
|
||||
return net.Listen("unix", addr)
|
||||
}
|
||||
|
||||
func isErrClosed(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, winio.ErrPipeListenerClosed) || errors.Is(err, winio.ErrFileClosed)
|
||||
}
|
Loading…
Reference in New Issue