diff --git a/docker/transport/sshconn.py b/docker/transport/sshconn.py index 6e1d0ee7..070b6937 100644 --- a/docker/transport/sshconn.py +++ b/docker/transport/sshconn.py @@ -11,10 +11,25 @@ import subprocess from docker.transport.basehttpadapter import BaseHTTPAdapter from .. import constants -import urllib3 -import urllib3.connection +import http.client as httplib -RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer +try: + import requests.packages.urllib3 as urllib3 +except ImportError: + import urllib3 + + +class DontCloseStreamWrapper: + """Make close() noop for the wrapped object.""" + + def __init__(self, obj): + self.obj = obj + + def __getattr__(self, name): + def wrapper(*args, **kwargs): + if name != "close": + return getattr(self.obj, name)(*args, **kwargs) + return wrapper class SSHSocket(socket.socket): @@ -54,11 +69,12 @@ class SSHSocket(socket.socket): env.pop('SSL_CERT_FILE', None) self.proc = subprocess.Popen( - args, + ' '.join(args), env=env, + shell=True, stdout=subprocess.PIPE, stdin=subprocess.PIPE, - preexec_fn=preexec_func) + preexec_fn=None if constants.IS_WINDOWS_PLATFORM else preexec_func) def _write(self, data): if not self.proc or self.proc.stdin.closed: @@ -85,7 +101,7 @@ class SSHSocket(socket.socket): self.connect() self.proc.stdout.channel = self - return self.proc.stdout + return DontCloseStreamWrapper(self.proc.stdout) def close(self): if not self.proc or self.proc.stdin.closed: @@ -95,7 +111,7 @@ class SSHSocket(socket.socket): self.proc.terminate() -class SSHConnection(urllib3.connection.HTTPConnection): +class SSHConnection(httplib.HTTPConnection): def __init__(self, ssh_transport=None, timeout=60, host=None): super().__init__( 'localhost', timeout=timeout @@ -141,8 +157,8 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): try: conn = self.pool.get(block=self.block, timeout=timeout) - except AttributeError as ae: # self.pool is None - raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.") from ae + except AttributeError: # self.pool is None + raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.") except queue.Empty: if self.block: @@ -150,8 +166,8 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): self, "Pool reached maximum size and no more " "connections are allowed." - ) from None - # Oh well, we'll create a new connection then + ) + pass # Oh well, we'll create a new connection then return conn or self._new_conn() @@ -159,7 +175,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHHTTPAdapter(BaseHTTPAdapter): __attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [ - 'pools', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size' + 'pool', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size' ] def __init__(self, base_url, timeout=60, @@ -176,10 +192,8 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_host = base_url[len('ssh://'):] self.timeout = timeout + self.pool = None self.max_pool_size = max_pool_size - self.pools = RecentlyUsedContainer( - pool_connections, dispose_func=lambda p: p.close() - ) super().__init__() def _create_paramiko_client(self, base_url): @@ -199,7 +213,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter): host_config = conf.lookup(base_url.hostname) if 'proxycommand' in host_config: self.ssh_params["sock"] = paramiko.ProxyCommand( - host_config['proxycommand'] + self.ssh_conf['proxycommand'] ) if 'hostname' in host_config: self.ssh_params['hostname'] = host_config['hostname'] @@ -211,38 +225,23 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_params['key_filename'] = host_config['identityfile'] self.ssh_client.load_system_host_keys() - self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy()) + self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) def _connect(self): if self.ssh_client: self.ssh_client.connect(**self.ssh_params) def get_connection(self, url, proxies=None): - if not self.ssh_client: - return SSHConnectionPool( - ssh_client=self.ssh_client, - timeout=self.timeout, - maxsize=self.max_pool_size, - host=self.ssh_host - ) - with self.pools.lock: - pool = self.pools.get(url) - if pool: - return pool - - # Connection is closed try a reconnect + if self.pool is None: if self.ssh_client and not self.ssh_client.get_transport(): self._connect() - - pool = SSHConnectionPool( + self.pool = SSHConnectionPool( ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, host=self.ssh_host ) - self.pools[url] = pool - - return pool + return self.pool def close(self): super().close()