diff --git a/go.mod b/go.mod index 57bd1bbc..cc7b7b01 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/onsi/gomega v1.18.1 github.com/otiai10/copy v1.7.0 github.com/spf13/pflag v1.0.5 - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd + golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c google.golang.org/api v0.73.0 gotest.tools v2.2.0+incompatible diff --git a/go.sum b/go.sum index 635aa434..b08cc4fd 100644 --- a/go.sum +++ b/go.sum @@ -1153,6 +1153,8 @@ golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 h1:S25/rfnfsMVgORT4/J61MJ7rdyseOZOyvLIrZEZ7s6s= +golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= diff --git a/internal/helm/repository/chart_repository.go b/internal/helm/repository/chart_repository.go index 3c183ad6..e8154dca 100644 --- a/internal/helm/repository/chart_repository.go +++ b/internal/helm/repository/chart_repository.go @@ -39,7 +39,7 @@ import ( "github.com/fluxcd/pkg/version" "github.com/fluxcd/source-controller/internal/helm" - transport "github.com/fluxcd/source-controller/internal/helm/getter" + "github.com/fluxcd/source-controller/internal/transport" ) var ErrNoChartIndex = errors.New("no chart index") diff --git a/internal/helm/getter/transport.go b/internal/transport/transport.go similarity index 99% rename from internal/helm/getter/transport.go rename to internal/transport/transport.go index 34e0eaf8..89286df7 100644 --- a/internal/helm/getter/transport.go +++ b/internal/transport/transport.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package getter +package transport import ( "crypto/tls" diff --git a/internal/helm/getter/transport_test.go b/internal/transport/transport_test.go similarity index 98% rename from internal/helm/getter/transport_test.go rename to internal/transport/transport_test.go index aea7ffc1..c07a88d5 100644 --- a/internal/helm/getter/transport_test.go +++ b/internal/transport/transport_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package getter +package transport import ( "crypto/tls" diff --git a/main.go b/main.go index 120c83d5..0577de4e 100644 --- a/main.go +++ b/main.go @@ -228,7 +228,7 @@ func main() { }() if managed.Enabled() { - managed.InitManagedTransport() + managed.InitManagedTransport(ctrl.Log.WithName("managed-transport")) } setupLog.Info("starting manager") diff --git a/pkg/git/libgit2/managed/http.go b/pkg/git/libgit2/managed/http.go index 24adfd66..80465756 100644 --- a/pkg/git/libgit2/managed/http.go +++ b/pkg/git/libgit2/managed/http.go @@ -50,12 +50,11 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/url" "sync" - "time" + pool "github.com/fluxcd/source-controller/internal/transport" git2go "github.com/libgit2/git2go/v33" ) @@ -73,15 +72,18 @@ func registerManagedHTTP() error { } func httpSmartSubtransportFactory(remote *git2go.Remote, transport *git2go.Transport) (git2go.SmartSubtransport, error) { + traceLog.Info("[http]: httpSmartSubtransportFactory") sst := &httpSmartSubtransport{ - transport: transport, + transport: transport, + httpTransport: pool.NewOrIdle(nil), } return sst, nil } type httpSmartSubtransport struct { - transport *git2go.Transport + transport *git2go.Transport + httpTransport *http.Transport } func (t *httpSmartSubtransport) Action(targetUrl string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) { @@ -104,25 +106,10 @@ func (t *httpSmartSubtransport) Action(targetUrl string, action git2go.SmartServ proxyFn = http.ProxyURL(parsedUrl) } - httpTransport := &http.Transport{ - // Add the proxy to the http transport. - Proxy: proxyFn, + t.httpTransport.Proxy = proxyFn + t.httpTransport.DisableCompression = false - // Set reasonable timeouts to ensure connections are not - // left open in an idle state, nor they hang indefinitely. - // - // These are based on the official go http.DefaultTransport: - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - client, req, err := createClientRequest(targetUrl, action, httpTransport) + client, req, err := createClientRequest(targetUrl, action, t.httpTransport) if err != nil { return nil, err } @@ -223,10 +210,18 @@ func createClientRequest(targetUrl string, action git2go.SmartServiceAction, t * } func (t *httpSmartSubtransport) Close() error { + traceLog.Info("[http]: httpSmartSubtransport.Close()") return nil } func (t *httpSmartSubtransport) Free() { + traceLog.Info("[http]: httpSmartSubtransport.Free()") + + if t.httpTransport != nil { + traceLog.Info("[http]: release http transport back to pool") + pool.Release(t.httpTransport) + t.httpTransport = nil + } } type httpSmartSubtransportStream struct { @@ -291,7 +286,15 @@ func (self *httpSmartSubtransportStream) Write(buf []byte) (int, error) { func (self *httpSmartSubtransportStream) Free() { if self.resp != nil { - self.resp.Body.Close() + traceLog.Info("[http]: httpSmartSubtransportStream.Free()") + + if self.resp.Body != nil { + // ensure body is fully processed and closed + // for increased likelihood of transport reuse in HTTP/1.x. + // it should not be a problem to do this more than once. + _, _ = io.Copy(io.Discard, self.resp.Body) // errors can be safely ignored + _ = self.resp.Body.Close() // errors can be safely ignored + } } } @@ -354,6 +357,7 @@ func (self *httpSmartSubtransportStream) sendRequest() error { } req.SetBasicAuth(userName, password) + traceLog.Info("[http]: new request", "method", req.Method, "URL", req.URL) resp, err = self.client.Do(req) if err != nil { return err @@ -362,21 +366,36 @@ func (self *httpSmartSubtransportStream) sendRequest() error { // GET requests will be automatically redirected. // POST require the new destination, and also the body content. if req.Method == "POST" && resp.StatusCode >= 301 && resp.StatusCode <= 308 { + // ensure body is fully processed and closed + // for increased likelihood of transport reuse in HTTP/1.x. + _, _ = io.Copy(io.Discard, resp.Body) // errors can be safely ignored + + if err := resp.Body.Close(); err != nil { + return err + } + // The next try will go against the new destination self.req.URL, err = resp.Location() if err != nil { return err } + traceLog.Info("[http]: POST redirect", "URL", self.req.URL) continue } + // for HTTP 200, the response will be cleared up by Free() if resp.StatusCode == http.StatusOK { break } - io.Copy(io.Discard, resp.Body) - defer resp.Body.Close() + // ensure body is fully processed and closed + // for increased likelihood of transport reuse in HTTP/1.x. + _, _ = io.Copy(io.Discard, resp.Body) // errors can be safely ignored + if err := resp.Body.Close(); err != nil { + return err + } + return fmt.Errorf("Unhandled HTTP error %s", resp.Status) } diff --git a/pkg/git/libgit2/managed/init.go b/pkg/git/libgit2/managed/init.go index 8df4a9ae..d0cac956 100644 --- a/pkg/git/libgit2/managed/init.go +++ b/pkg/git/libgit2/managed/init.go @@ -19,6 +19,9 @@ package managed import ( "sync" "time" + + "github.com/fluxcd/pkg/runtime/logger" + "github.com/go-logr/logr" ) var ( @@ -34,6 +37,9 @@ var ( // regardless of the current operation (i.e. connection, // handshake, put/get). fullHttpClientTimeOut time.Duration = 10 * time.Minute + + debugLog logr.Logger + traceLog logr.Logger ) // InitManagedTransport initialises HTTP(S) and SSH managed transport @@ -47,9 +53,14 @@ var ( // // This function will only register managed transports once, subsequent calls // leads to no-op. -func InitManagedTransport() error { +func InitManagedTransport(log logr.Logger) error { var err error + once.Do(func() { + log.Info("Enabling experimental managed transport") + debugLog = log.V(logger.DebugLevel) + traceLog = log.V(logger.TraceLevel) + if err = registerManagedHTTP(); err != nil { return } diff --git a/pkg/git/libgit2/managed/managed_test.go b/pkg/git/libgit2/managed/managed_test.go index 1d858277..14c47385 100644 --- a/pkg/git/libgit2/managed/managed_test.go +++ b/pkg/git/libgit2/managed/managed_test.go @@ -27,6 +27,7 @@ import ( "github.com/fluxcd/pkg/gittestserver" "github.com/fluxcd/pkg/ssh" "github.com/fluxcd/source-controller/pkg/git" + "github.com/go-logr/logr" git2go "github.com/libgit2/git2go/v33" . "github.com/onsi/gomega" @@ -35,7 +36,7 @@ import ( func TestHttpAction_CreateClientRequest(t *testing.T) { tests := []struct { - description string + name string url string expectedUrl string expectedMethod string @@ -45,7 +46,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr error }{ { - description: "Uploadpack: no changes when no options found", + name: "Uploadpack: no changes when no options found", url: "https://sometarget/abc", expectedUrl: "https://sometarget/abc/git-upload-pack", expectedMethod: "POST", @@ -55,7 +56,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr: nil, }, { - description: "UploadpackLs: no changes when no options found", + name: "UploadpackLs: no changes when no options found", url: "https://sometarget/abc", expectedUrl: "https://sometarget/abc/info/refs?service=git-upload-pack", expectedMethod: "GET", @@ -65,7 +66,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr: nil, }, { - description: "Receivepack: no changes when no options found", + name: "Receivepack: no changes when no options found", url: "https://sometarget/abc", expectedUrl: "https://sometarget/abc/git-receive-pack", expectedMethod: "POST", @@ -75,7 +76,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr: nil, }, { - description: "ReceivepackLs: no changes when no options found", + name: "ReceivepackLs: no changes when no options found", url: "https://sometarget/abc", expectedUrl: "https://sometarget/abc/info/refs?service=git-receive-pack", expectedMethod: "GET", @@ -85,7 +86,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr: nil, }, { - description: "override URL via options", + name: "override URL via options", url: "https://initial-target/abc", expectedUrl: "https://final-target/git-upload-pack", expectedMethod: "POST", @@ -97,7 +98,7 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { wantedErr: nil, }, { - description: "error when no http.transport provided", + name: "error when no http.transport provided", url: "https://initial-target/abc", expectedUrl: "", expectedMethod: "", @@ -109,29 +110,31 @@ func TestHttpAction_CreateClientRequest(t *testing.T) { } for _, tt := range tests { - if tt.opts != nil { - AddTransportOptions(tt.url, *tt.opts) - } - - _, req, err := createClientRequest(tt.url, tt.action, tt.transport) - if tt.wantedErr != nil { - if tt.wantedErr.Error() != err.Error() { - t.Errorf("%s: wanted: %v got: %v", tt.description, tt.wantedErr, err) + t.Run(tt.name, func(t *testing.T) { + if tt.opts != nil { + AddTransportOptions(tt.url, *tt.opts) } - } else { - assert.Equal(t, req.URL.String(), tt.expectedUrl) - assert.Equal(t, req.Method, tt.expectedMethod) - } - if tt.opts != nil { - RemoveTransportOptions(tt.url) - } + _, req, err := createClientRequest(tt.url, tt.action, tt.transport) + if tt.wantedErr != nil { + if tt.wantedErr.Error() != err.Error() { + t.Errorf("wanted: %v got: %v", tt.wantedErr, err) + } + } else { + assert.Equal(t, req.URL.String(), tt.expectedUrl) + assert.Equal(t, req.Method, tt.expectedMethod) + } + + if tt.opts != nil { + RemoveTransportOptions(tt.url) + } + }) } } func TestOptions(t *testing.T) { tests := []struct { - description string + name string registerOpts bool url string opts TransportOptions @@ -139,7 +142,7 @@ func TestOptions(t *testing.T) { expectedOpts *TransportOptions }{ { - description: "return registered option", + name: "return registered option", registerOpts: true, url: "https://target/?123", opts: TransportOptions{}, @@ -147,7 +150,7 @@ func TestOptions(t *testing.T) { expectedOpts: &TransportOptions{}, }, { - description: "match registered options", + name: "match registered options", registerOpts: true, url: "https://target/?876", opts: TransportOptions{ @@ -161,7 +164,7 @@ func TestOptions(t *testing.T) { }, }, { - description: "ignore when options not registered", + name: "ignore when options not registered", registerOpts: false, url: "", opts: TransportOptions{}, @@ -171,28 +174,30 @@ func TestOptions(t *testing.T) { } for _, tt := range tests { - if tt.registerOpts { - AddTransportOptions(tt.url, tt.opts) - } - - opts, found := transportOptions(tt.url) - if tt.expectOpts != found { - t.Errorf("%s: wanted %v got %v", tt.description, tt.expectOpts, found) - } - - if tt.expectOpts { - if reflect.DeepEqual(opts, *tt.expectedOpts) { - t.Errorf("%s: wanted %v got %v", tt.description, *tt.expectedOpts, opts) + t.Run(tt.name, func(t *testing.T) { + if tt.registerOpts { + AddTransportOptions(tt.url, tt.opts) } - } - if tt.registerOpts { - RemoveTransportOptions(tt.url) - } + opts, found := transportOptions(tt.url) + if tt.expectOpts != found { + t.Errorf("%s: wanted %v got %v", tt.name, tt.expectOpts, found) + } - if _, found = transportOptions(tt.url); found { - t.Errorf("%s: option for %s was not removed", tt.description, tt.url) - } + if tt.expectOpts { + if reflect.DeepEqual(opts, *tt.expectedOpts) { + t.Errorf("%s: wanted %v got %v", tt.name, *tt.expectedOpts, opts) + } + } + + if tt.registerOpts { + RemoveTransportOptions(tt.url) + } + + if _, found = transportOptions(tt.url); found { + t.Errorf("%s: option for %s was not removed", tt.name, tt.url) + } + }) } } @@ -247,7 +252,7 @@ func TestManagedTransport_E2E(t *testing.T) { defer server.StopSSH() // Force managed transport to be enabled - InitManagedTransport() + InitManagedTransport(logr.Discard()) repoPath := "test.git" err = server.InitRepo("../testdata/git/repo", git.DefaultBranch, repoPath) @@ -312,7 +317,7 @@ func TestManagedTransport_HandleRedirect(t *testing.T) { defer os.RemoveAll(tmpDir) // Force managed transport to be enabled - InitManagedTransport() + InitManagedTransport(logr.Discard()) // GitHub will cause a 301 and redirect to https repo, err := git2go.Clone("http://github.com/stefanprodan/podinfo", tmpDir, &git2go.CloneOptions{ diff --git a/pkg/git/libgit2/managed/ssh.go b/pkg/git/libgit2/managed/ssh.go index 76833ac6..a6d41705 100644 --- a/pkg/git/libgit2/managed/ssh.go +++ b/pkg/git/libgit2/managed/ssh.go @@ -53,6 +53,8 @@ import ( "net/url" "runtime" "strings" + "sync" + "time" "golang.org/x/crypto/ssh" @@ -62,6 +64,17 @@ import ( // registerManagedSSH registers a Go-native implementation of // SSH transport that doesn't rely on any lower-level libraries // such as libssh2. +// +// The underlying SSH connections are kept open and are reused +// across several SSH sessions. This is due to upstream issues in +// which concurrent/parallel SSH connections may lead to instability. +// +// Connections are created on first attempt to use a given remote. The +// connection is removed from the cache on the first failed session related +// operation. +// +// https://github.com/golang/go/issues/51926 +// https://github.com/golang/go/issues/27140 func registerManagedSSH() error { for _, protocol := range []string{"ssh", "ssh+git", "git+ssh"} { _, err := git2go.NewRegisteredSmartTransport(protocol, false, sshSmartSubtransportFactory) @@ -89,6 +102,18 @@ type sshSmartSubtransport struct { currentStream *sshSmartSubtransportStream } +// aMux is the read-write mutex to control access to sshClients. +var aMux sync.RWMutex + +// sshClients stores active ssh clients/connections to be reused. +// +// Once opened, connections will be kept cached until an error occurs +// during SSH commands, by which point it will be discarded, leading to +// a follow-up cache miss. +// +// The key must be based on cacheKey, refer to that function's comments. +var sshClients map[string]*ssh.Client = make(map[string]*ssh.Client) + func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) { runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -135,7 +160,14 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi } defer cred.Free() - sshConfig, err := getSSHConfigFromCredential(cred) + var addr string + if u.Port() != "" { + addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port()) + } else { + addr = fmt.Sprintf("%s:22", u.Hostname()) + } + + ckey, sshConfig, err := cacheKeyAndConfig(addr, cred) if err != nil { return nil, err } @@ -156,34 +188,66 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi return t.transport.SmartCertificateCheck(cert, true, hostname) } - var addr string - if u.Port() != "" { - addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port()) - } else { - addr = fmt.Sprintf("%s:22", u.Hostname()) + aMux.RLock() + if c, ok := sshClients[ckey]; ok { + traceLog.Info("[ssh]: cache hit", "remoteAddress", addr) + t.client = c + } + aMux.RUnlock() + + if t.client == nil { + traceLog.Info("[ssh]: cache miss", "remoteAddress", addr) + + aMux.Lock() + defer aMux.Unlock() + + // In some scenarios the ssh handshake can hang indefinitely at + // golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop. + // + // xref: https://github.com/golang/go/issues/51926 + done := make(chan error, 1) + go func() { + t.client, err = ssh.Dial("tcp", addr, sshConfig) + done <- err + }() + + dialTimeout := sshConfig.Timeout + (30 * time.Second) + + select { + case doneErr := <-done: + if doneErr != nil { + err = fmt.Errorf("ssh.Dial: %w", doneErr) + } + case <-time.After(dialTimeout): + err = fmt.Errorf("timed out waiting for ssh.Dial after %s", dialTimeout) + } + + if err != nil { + return nil, err + } + + sshClients[ckey] = t.client } - t.client, err = ssh.Dial("tcp", addr, sshConfig) - if err != nil { + traceLog.Info("[ssh]: creating new ssh session") + if t.session, err = t.client.NewSession(); err != nil { + discardCachedSshClient(ckey) return nil, err } - t.session, err = t.client.NewSession() - if err != nil { + if t.stdin, err = t.session.StdinPipe(); err != nil { + discardCachedSshClient(ckey) return nil, err } - t.stdin, err = t.session.StdinPipe() - if err != nil { - return nil, err - } - - t.stdout, err = t.session.StdoutPipe() - if err != nil { + if t.stdout, err = t.session.StdoutPipe(); err != nil { + discardCachedSshClient(ckey) return nil, err } + traceLog.Info("[ssh]: run on remote", "cmd", cmd) if err := t.session.Start(cmd); err != nil { + discardCachedSshClient(ckey) return nil, err } @@ -196,17 +260,29 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi } func (t *sshSmartSubtransport) Close() error { + var returnErr error + + traceLog.Info("[ssh]: sshSmartSubtransport.Close()") t.currentStream = nil if t.client != nil { - t.stdin.Close() - t.session.Wait() - t.session.Close() + if err := t.stdin.Close(); err != nil { + returnErr = fmt.Errorf("cannot close stdin: %w", err) + } t.client = nil } - return nil + if t.session != nil { + traceLog.Info("[ssh]: skipping session.wait") + traceLog.Info("[ssh]: session.Close()") + if err := t.session.Close(); err != nil { + returnErr = fmt.Errorf("cannot close session: %w", err) + } + } + + return returnErr } func (t *sshSmartSubtransport) Free() { + traceLog.Info("[ssh]: sshSmartSubtransport.Free()") } type sshSmartSubtransportStream struct { @@ -222,21 +298,26 @@ func (stream *sshSmartSubtransportStream) Write(buf []byte) (int, error) { } func (stream *sshSmartSubtransportStream) Free() { + traceLog.Info("[ssh]: sshSmartSubtransportStream.Free()") } -func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, error) { +func cacheKeyAndConfig(remoteAddress string, cred *git2go.Credential) (string, *ssh.ClientConfig, error) { username, _, privatekey, passphrase, err := cred.GetSSHKey() if err != nil { - return nil, err + return "", nil, err } var pemBytes []byte if cred.Type() == git2go.CredentialTypeSSHMemory { pemBytes = []byte(privatekey) } else { - return nil, fmt.Errorf("file based SSH credential is not supported") + return "", nil, fmt.Errorf("file based SSH credential is not supported") } + // must include the passphrase, otherwise a caller that knows the private key, but + // not its passphrase would be able to bypass auth. + ck := cacheKey(remoteAddress, username, passphrase, pemBytes) + var key ssh.Signer if passphrase != "" { key, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase)) @@ -245,12 +326,44 @@ func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, err } if err != nil { - return nil, err + return "", nil, err } - return &ssh.ClientConfig{ + cfg := &ssh.ClientConfig{ User: username, Auth: []ssh.AuthMethod{ssh.PublicKeys(key)}, Timeout: sshConnectionTimeOut, - }, nil + } + + return ck, cfg, nil +} + +// cacheKey generates a cache key that is multi-tenancy safe. +// +// Stablishing multiple and concurrent ssh connections leads to stability +// issues documented above. However, the caching/sharing of already stablished +// connections could represent a vector for users to bypass the ssh authentication +// mechanism. +// +// cacheKey tries to ensure that connections are only shared by users that +// have the exact same remoteAddress and credentials. +func cacheKey(remoteAddress, userName, passphrase string, pubKey []byte) string { + h := sha256.New() + + v := fmt.Sprintf("%s-%s-%s-%v", remoteAddress, userName, passphrase, pubKey) + + h.Write([]byte(v)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// discardCachedSshClient discards the cached ssh client, forcing the next git operation +// to create a new one via ssh.Dial. +func discardCachedSshClient(key string) { + aMux.Lock() + defer aMux.Unlock() + + if _, found := sshClients[key]; found { + traceLog.Info("[ssh]: discard cached ssh client") + delete(sshClients, key) + } } diff --git a/pkg/git/libgit2/managed/ssh_test.go b/pkg/git/libgit2/managed/ssh_test.go new file mode 100644 index 00000000..0b28d519 --- /dev/null +++ b/pkg/git/libgit2/managed/ssh_test.go @@ -0,0 +1,124 @@ +/* +Copyright 2022 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package managed + +import ( + "testing" +) + +func TestCacheKey(t *testing.T) { + tests := []struct { + name string + remoteAddress1 string + user1 string + passphrase1 string + pubKey1 []byte + remoteAddress2 string + user2 string + passphrase2 string + pubKey2 []byte + expectMatch bool + }{ + { + name: "same remote addresses with no config", + remoteAddress1: "1.1.1.1", + remoteAddress2: "1.1.1.1", + expectMatch: true, + }, + { + name: "same remote addresses with different config", + remoteAddress1: "1.1.1.1", + user1: "joe", + remoteAddress2: "1.1.1.1", + user2: "another-joe", + expectMatch: false, + }, + { + name: "different remote addresses with no config", + remoteAddress1: "8.8.8.8", + remoteAddress2: "1.1.1.1", + expectMatch: false, + }, + { + name: "different remote addresses with same config", + remoteAddress1: "8.8.8.8", + user1: "legit", + remoteAddress2: "1.1.1.1", + user2: "legit", + expectMatch: false, + }, + { + name: "same remote addresses with same pubkey signers", + remoteAddress1: "1.1.1.1", + user1: "same-jane", + pubKey1: []byte{255, 123, 0}, + remoteAddress2: "1.1.1.1", + user2: "same-jane", + pubKey2: []byte{255, 123, 0}, + expectMatch: true, + }, + { + name: "same remote addresses with different pubkey signers", + remoteAddress1: "1.1.1.1", + user1: "same-jane", + pubKey1: []byte{255, 123, 0}, + remoteAddress2: "1.1.1.1", + user2: "same-jane", + pubKey2: []byte{0, 123, 0}, + expectMatch: false, + }, + { + name: "same remote addresses with pubkey signers and passphrases", + remoteAddress1: "1.1.1.1", + user1: "same-jane", + passphrase1: "same-pass", + pubKey1: []byte{255, 123, 0}, + remoteAddress2: "1.1.1.1", + user2: "same-jane", + passphrase2: "same-pass", + pubKey2: []byte{255, 123, 0}, + expectMatch: true, + }, + { + name: "same remote addresses with pubkey signers and different passphrases", + remoteAddress1: "1.1.1.1", + user1: "same-jane", + passphrase1: "same-pass", + pubKey1: []byte{255, 123, 0}, + remoteAddress2: "1.1.1.1", + user2: "same-jane", + passphrase2: "different-pass", + pubKey2: []byte{255, 123, 0}, + expectMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cacheKey1 := cacheKey(tt.remoteAddress1, tt.user1, tt.passphrase1, tt.pubKey1) + cacheKey2 := cacheKey(tt.remoteAddress2, tt.user2, tt.passphrase2, tt.pubKey2) + + if tt.expectMatch && cacheKey1 != cacheKey2 { + t.Errorf("cache keys '%s' and '%s' should match", cacheKey1, cacheKey2) + } + + if !tt.expectMatch && cacheKey1 == cacheKey2 { + t.Errorf("cache keys '%s' and '%s' should not match", cacheKey1, cacheKey2) + } + }) + } +}