diff --git a/docker/transport/sshconn.py b/docker/transport/sshconn.py index 18706680..c75504ba 100644 --- a/docker/transport/sshconn.py +++ b/docker/transport/sshconn.py @@ -14,7 +14,18 @@ import urllib3.connection from .. import constants from .basehttpadapter import BaseHTTPAdapter -RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer + +class DontCloseStreamWrapper: + """Make close() noop for the wrapped object.""" + + def __init__(self, obj): + self.obj = obj + + def __getattr__(self, name): + def wrapper(*args, **kwargs): + if name != "close": + return getattr(self.obj, name)(*args, **kwargs) + return wrapper class SSHSocket(socket.socket): @@ -85,7 +96,7 @@ class SSHSocket(socket.socket): self.connect() self.proc.stdout.channel = self - return self.proc.stdout + return DontCloseStreamWrapper(self.proc.stdout) def close(self): if not self.proc or self.proc.stdin.closed: @@ -159,7 +170,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHHTTPAdapter(BaseHTTPAdapter): __attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [ - 'pools', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size' + 'pool', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size' ] def __init__(self, base_url, timeout=60, @@ -176,10 +187,8 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_host = base_url[len('ssh://'):] self.timeout = timeout + self.pool = None self.max_pool_size = max_pool_size - self.pools = RecentlyUsedContainer( - pool_connections, dispose_func=lambda p: p.close() - ) super().__init__() def _create_paramiko_client(self, base_url): @@ -218,31 +227,16 @@ class SSHHTTPAdapter(BaseHTTPAdapter): self.ssh_client.connect(**self.ssh_params) def get_connection(self, url, proxies=None): - if not self.ssh_client: - return SSHConnectionPool( - ssh_client=self.ssh_client, - timeout=self.timeout, - maxsize=self.max_pool_size, - host=self.ssh_host - ) - with self.pools.lock: - pool = self.pools.get(url) - if pool: - return pool - - # Connection is closed try a reconnect + if self.pool is None: if self.ssh_client and not self.ssh_client.get_transport(): self._connect() - - pool = SSHConnectionPool( + self.pool = SSHConnectionPool( ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, host=self.ssh_host ) - self.pools[url] = pool - - return pool + return self.pool def close(self): super().close()