Cache SSH connections

The underlying SSH connections are kept open and are reused
across several SSH sessions. This is due to upstream issues in
which concurrent/parallel SSH connections may lead to instability.

https://github.com/golang/go/issues/51926
https://github.com/golang/go/issues/27140
Signed-off-by: Paulo Gomes <paulo.gomes@weave.works>
This commit is contained in:
Paulo Gomes 2022-03-28 11:58:10 +01:00
parent 017707a71c
commit 92ad1f813b
No known key found for this signature in database
GPG Key ID: 9995233870E99BEE
4 changed files with 319 additions and 92 deletions

View File

@ -288,11 +288,13 @@ func (self *httpSmartSubtransportStream) Free() {
if self.resp != nil {
traceLog.Info("[http]: httpSmartSubtransportStream.Free()")
// ensure body is fully processed and closed
// for increased likelihood of transport reuse in HTTP/1.x.
// it should not be a problem to do this more than once.
io.Copy(io.Discard, self.resp.Body)
self.resp.Body.Close()
if self.resp.Body != nil {
// ensure body is fully processed and closed
// for increased likelihood of transport reuse in HTTP/1.x.
// it should not be a problem to do this more than once.
_, _ = io.Copy(io.Discard, self.resp.Body) // errors can be safely ignored
_ = self.resp.Body.Close() // errors can be safely ignored
}
}
}
@ -366,8 +368,11 @@ func (self *httpSmartSubtransportStream) sendRequest() error {
if req.Method == "POST" && resp.StatusCode >= 301 && resp.StatusCode <= 308 {
// ensure body is fully processed and closed
// for increased likelihood of transport reuse in HTTP/1.x.
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body) // errors can be safely ignored
if err := resp.Body.Close(); err != nil {
return err
}
// The next try will go against the new destination
self.req.URL, err = resp.Location()
@ -386,8 +391,10 @@ func (self *httpSmartSubtransportStream) sendRequest() error {
// ensure body is fully processed and closed
// for increased likelihood of transport reuse in HTTP/1.x.
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body) // errors can be safely ignored
if err := resp.Body.Close(); err != nil {
return err
}
return fmt.Errorf("Unhandled HTTP error %s", resp.Status)
}

View File

@ -36,7 +36,7 @@ import (
func TestHttpAction_CreateClientRequest(t *testing.T) {
tests := []struct {
description string
name string
url string
expectedUrl string
expectedMethod string
@ -46,7 +46,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr error
}{
{
description: "Uploadpack: no changes when no options found",
name: "Uploadpack: no changes when no options found",
url: "https://sometarget/abc",
expectedUrl: "https://sometarget/abc/git-upload-pack",
expectedMethod: "POST",
@ -56,7 +56,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr: nil,
},
{
description: "UploadpackLs: no changes when no options found",
name: "UploadpackLs: no changes when no options found",
url: "https://sometarget/abc",
expectedUrl: "https://sometarget/abc/info/refs?service=git-upload-pack",
expectedMethod: "GET",
@ -66,7 +66,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr: nil,
},
{
description: "Receivepack: no changes when no options found",
name: "Receivepack: no changes when no options found",
url: "https://sometarget/abc",
expectedUrl: "https://sometarget/abc/git-receive-pack",
expectedMethod: "POST",
@ -76,7 +76,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr: nil,
},
{
description: "ReceivepackLs: no changes when no options found",
name: "ReceivepackLs: no changes when no options found",
url: "https://sometarget/abc",
expectedUrl: "https://sometarget/abc/info/refs?service=git-receive-pack",
expectedMethod: "GET",
@ -86,7 +86,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr: nil,
},
{
description: "override URL via options",
name: "override URL via options",
url: "https://initial-target/abc",
expectedUrl: "https://final-target/git-upload-pack",
expectedMethod: "POST",
@ -98,7 +98,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
wantedErr: nil,
},
{
description: "error when no http.transport provided",
name: "error when no http.transport provided",
url: "https://initial-target/abc",
expectedUrl: "",
expectedMethod: "",
@ -110,29 +110,31 @@ func TestHttpAction_CreateClientRequest(t *testing.T) {
}
for _, tt := range tests {
if tt.opts != nil {
AddTransportOptions(tt.url, *tt.opts)
}
_, req, err := createClientRequest(tt.url, tt.action, tt.transport)
if tt.wantedErr != nil {
if tt.wantedErr.Error() != err.Error() {
t.Errorf("%s: wanted: %v got: %v", tt.description, tt.wantedErr, err)
t.Run(tt.name, func(t *testing.T) {
if tt.opts != nil {
AddTransportOptions(tt.url, *tt.opts)
}
} else {
assert.Equal(t, req.URL.String(), tt.expectedUrl)
assert.Equal(t, req.Method, tt.expectedMethod)
}
if tt.opts != nil {
RemoveTransportOptions(tt.url)
}
_, req, err := createClientRequest(tt.url, tt.action, tt.transport)
if tt.wantedErr != nil {
if tt.wantedErr.Error() != err.Error() {
t.Errorf("wanted: %v got: %v", tt.wantedErr, err)
}
} else {
assert.Equal(t, req.URL.String(), tt.expectedUrl)
assert.Equal(t, req.Method, tt.expectedMethod)
}
if tt.opts != nil {
RemoveTransportOptions(tt.url)
}
})
}
}
func TestOptions(t *testing.T) {
tests := []struct {
description string
name string
registerOpts bool
url string
opts TransportOptions
@ -140,7 +142,7 @@ func TestOptions(t *testing.T) {
expectedOpts *TransportOptions
}{
{
description: "return registered option",
name: "return registered option",
registerOpts: true,
url: "https://target/?123",
opts: TransportOptions{},
@ -148,7 +150,7 @@ func TestOptions(t *testing.T) {
expectedOpts: &TransportOptions{},
},
{
description: "match registered options",
name: "match registered options",
registerOpts: true,
url: "https://target/?876",
opts: TransportOptions{
@ -162,7 +164,7 @@ func TestOptions(t *testing.T) {
},
},
{
description: "ignore when options not registered",
name: "ignore when options not registered",
registerOpts: false,
url: "",
opts: TransportOptions{},
@ -172,28 +174,30 @@ func TestOptions(t *testing.T) {
}
for _, tt := range tests {
if tt.registerOpts {
AddTransportOptions(tt.url, tt.opts)
}
opts, found := transportOptions(tt.url)
if tt.expectOpts != found {
t.Errorf("%s: wanted %v got %v", tt.description, tt.expectOpts, found)
}
if tt.expectOpts {
if reflect.DeepEqual(opts, *tt.expectedOpts) {
t.Errorf("%s: wanted %v got %v", tt.description, *tt.expectedOpts, opts)
t.Run(tt.name, func(t *testing.T) {
if tt.registerOpts {
AddTransportOptions(tt.url, tt.opts)
}
}
if tt.registerOpts {
RemoveTransportOptions(tt.url)
}
opts, found := transportOptions(tt.url)
if tt.expectOpts != found {
t.Errorf("%s: wanted %v got %v", tt.name, tt.expectOpts, found)
}
if _, found = transportOptions(tt.url); found {
t.Errorf("%s: option for %s was not removed", tt.description, tt.url)
}
if tt.expectOpts {
if reflect.DeepEqual(opts, *tt.expectedOpts) {
t.Errorf("%s: wanted %v got %v", tt.name, *tt.expectedOpts, opts)
}
}
if tt.registerOpts {
RemoveTransportOptions(tt.url)
}
if _, found = transportOptions(tt.url); found {
t.Errorf("%s: option for %s was not removed", tt.name, tt.url)
}
})
}
}

View File

@ -53,6 +53,8 @@ import (
"net/url"
"runtime"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
@ -62,6 +64,17 @@ import (
// registerManagedSSH registers a Go-native implementation of
// SSH transport that doesn't rely on any lower-level libraries
// such as libssh2.
//
// The underlying SSH connections are kept open and are reused
// across several SSH sessions. This is due to upstream issues in
// which concurrent/parallel SSH connections may lead to instability.
//
// Connections are created on first attempt to use a given remote. The
// connection is removed from the cache on the first failed session related
// operation.
//
// https://github.com/golang/go/issues/51926
// https://github.com/golang/go/issues/27140
func registerManagedSSH() error {
for _, protocol := range []string{"ssh", "ssh+git", "git+ssh"} {
_, err := git2go.NewRegisteredSmartTransport(protocol, false, sshSmartSubtransportFactory)
@ -89,6 +102,18 @@ type sshSmartSubtransport struct {
currentStream *sshSmartSubtransportStream
}
// aMux is the read-write mutex to control access to sshClients.
var aMux sync.RWMutex
// sshClients stores active ssh clients/connections to be reused.
//
// Once opened, connections will be kept cached until an error occurs
// during SSH commands, by which point it will be discarded, leading to
// a follow-up cache miss.
//
// The key must be based on cacheKey, refer to that function's comments.
var sshClients map[string]*ssh.Client = make(map[string]*ssh.Client)
func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
@ -135,7 +160,14 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
}
defer cred.Free()
sshConfig, err := getSSHConfigFromCredential(cred)
var addr string
if u.Port() != "" {
addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port())
} else {
addr = fmt.Sprintf("%s:22", u.Hostname())
}
ckey, sshConfig, err := cacheKeyAndConfig(addr, cred)
if err != nil {
return nil, err
}
@ -156,52 +188,66 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
return t.transport.SmartCertificateCheck(cert, true, hostname)
}
var addr string
if u.Port() != "" {
addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port())
} else {
addr = fmt.Sprintf("%s:22", u.Hostname())
aMux.RLock()
if c, ok := sshClients[ckey]; ok {
traceLog.Info("[ssh]: cache hit", "remoteAddress", addr)
t.client = c
}
aMux.RUnlock()
// In some scenarios the ssh handshake can hang indefinitely at
// golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
//
// xref: https://github.com/golang/go/issues/51926
done := make(chan error, 1)
go func() {
t.client, err = ssh.Dial("tcp", addr, sshConfig)
done <- err
}()
if t.client == nil {
traceLog.Info("[ssh]: cache miss", "remoteAddress", addr)
select {
case doneErr := <-done:
if doneErr != nil {
err = fmt.Errorf("ssh.Dial: %w", doneErr)
aMux.Lock()
defer aMux.Unlock()
// In some scenarios the ssh handshake can hang indefinitely at
// golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
//
// xref: https://github.com/golang/go/issues/51926
done := make(chan error, 1)
go func() {
t.client, err = ssh.Dial("tcp", addr, sshConfig)
done <- err
}()
dialTimeout := sshConfig.Timeout + (30 * time.Second)
select {
case doneErr := <-done:
if doneErr != nil {
err = fmt.Errorf("ssh.Dial: %w", doneErr)
}
case <-time.After(dialTimeout):
err = fmt.Errorf("timed out waiting for ssh.Dial after %s", dialTimeout)
}
case <-time.After(sshConfig.Timeout + (5 * time.Second)):
err = fmt.Errorf("timed out waiting for ssh.Dial")
}
if err != nil {
return nil, err
if err != nil {
return nil, err
}
sshClients[ckey] = t.client
}
traceLog.Info("[ssh]: creating new ssh session")
if t.session, err = t.client.NewSession(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}
t.stdin, err = t.session.StdinPipe()
if err != nil {
if t.stdin, err = t.session.StdinPipe(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}
t.stdout, err = t.session.StdoutPipe()
if err != nil {
if t.stdout, err = t.session.StdoutPipe(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}
traceLog.Info("[ssh]: run on remote", "cmd", cmd)
if err := t.session.Start(cmd); err != nil {
discardCachedSshClient(ckey)
return nil, err
}
@ -214,15 +260,25 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
}
func (t *sshSmartSubtransport) Close() error {
var returnErr error
traceLog.Info("[ssh]: sshSmartSubtransport.Close()")
t.currentStream = nil
if t.client != nil {
t.stdin.Close()
t.session.Wait()
t.session.Close()
if err := t.stdin.Close(); err != nil {
returnErr = fmt.Errorf("cannot close stdin: %w", err)
}
t.client = nil
}
return nil
if t.session != nil {
traceLog.Info("[ssh]: skipping session.wait")
traceLog.Info("[ssh]: session.Close()")
if err := t.session.Close(); err != nil {
returnErr = fmt.Errorf("cannot close session: %w", err)
}
}
return returnErr
}
func (t *sshSmartSubtransport) Free() {
@ -245,19 +301,23 @@ func (stream *sshSmartSubtransportStream) Free() {
traceLog.Info("[ssh]: sshSmartSubtransportStream.Free()")
}
func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, error) {
func cacheKeyAndConfig(remoteAddress string, cred *git2go.Credential) (string, *ssh.ClientConfig, error) {
username, _, privatekey, passphrase, err := cred.GetSSHKey()
if err != nil {
return nil, err
return "", nil, err
}
var pemBytes []byte
if cred.Type() == git2go.CredentialTypeSSHMemory {
pemBytes = []byte(privatekey)
} else {
return nil, fmt.Errorf("file based SSH credential is not supported")
return "", nil, fmt.Errorf("file based SSH credential is not supported")
}
// must include the passphrase, otherwise a caller that knows the private key, but
// not its passphrase would be able to bypass auth.
ck := cacheKey(remoteAddress, username, passphrase, pemBytes)
var key ssh.Signer
if passphrase != "" {
key, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase))
@ -266,12 +326,44 @@ func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, err
}
if err != nil {
return nil, err
return "", nil, err
}
return &ssh.ClientConfig{
cfg := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
Timeout: sshConnectionTimeOut,
}, nil
}
return ck, cfg, nil
}
// cacheKey generates a cache key that is multi-tenancy safe.
//
// Stablishing multiple and concurrent ssh connections leads to stability
// issues documented above. However, the caching/sharing of already stablished
// connections could represent a vector for users to bypass the ssh authentication
// mechanism.
//
// cacheKey tries to ensure that connections are only shared by users that
// have the exact same remoteAddress and credentials.
func cacheKey(remoteAddress, userName, passphrase string, pubKey []byte) string {
h := sha256.New()
v := fmt.Sprintf("%s-%s-%s-%v", remoteAddress, userName, passphrase, pubKey)
h.Write([]byte(v))
return fmt.Sprintf("%x", h.Sum(nil))
}
// discardCachedSshClient discards the cached ssh client, forcing the next git operation
// to create a new one via ssh.Dial.
func discardCachedSshClient(key string) {
aMux.Lock()
defer aMux.Unlock()
if _, found := sshClients[key]; found {
traceLog.Info("[ssh]: discard cached ssh client")
delete(sshClients, key)
}
}

View File

@ -0,0 +1,124 @@
/*
Copyright 2022 The Flux authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package managed
import (
"testing"
)
func TestCacheKey(t *testing.T) {
tests := []struct {
name string
remoteAddress1 string
user1 string
passphrase1 string
pubKey1 []byte
remoteAddress2 string
user2 string
passphrase2 string
pubKey2 []byte
expectMatch bool
}{
{
name: "same remote addresses with no config",
remoteAddress1: "1.1.1.1",
remoteAddress2: "1.1.1.1",
expectMatch: true,
},
{
name: "same remote addresses with different config",
remoteAddress1: "1.1.1.1",
user1: "joe",
remoteAddress2: "1.1.1.1",
user2: "another-joe",
expectMatch: false,
},
{
name: "different remote addresses with no config",
remoteAddress1: "8.8.8.8",
remoteAddress2: "1.1.1.1",
expectMatch: false,
},
{
name: "different remote addresses with same config",
remoteAddress1: "8.8.8.8",
user1: "legit",
remoteAddress2: "1.1.1.1",
user2: "legit",
expectMatch: false,
},
{
name: "same remote addresses with same pubkey signers",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
pubKey2: []byte{255, 123, 0},
expectMatch: true,
},
{
name: "same remote addresses with different pubkey signers",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
pubKey2: []byte{0, 123, 0},
expectMatch: false,
},
{
name: "same remote addresses with pubkey signers and passphrases",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
passphrase1: "same-pass",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
passphrase2: "same-pass",
pubKey2: []byte{255, 123, 0},
expectMatch: true,
},
{
name: "same remote addresses with pubkey signers and different passphrases",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
passphrase1: "same-pass",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
passphrase2: "different-pass",
pubKey2: []byte{255, 123, 0},
expectMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cacheKey1 := cacheKey(tt.remoteAddress1, tt.user1, tt.passphrase1, tt.pubKey1)
cacheKey2 := cacheKey(tt.remoteAddress2, tt.user2, tt.passphrase2, tt.pubKey2)
if tt.expectMatch && cacheKey1 != cacheKey2 {
t.Errorf("cache keys '%s' and '%s' should match", cacheKey1, cacheKey2)
}
if !tt.expectMatch && cacheKey1 == cacheKey2 {
t.Errorf("cache keys '%s' and '%s' should not match", cacheKey1, cacheKey2)
}
})
}
}