This commit is contained in:
VelorumS 2025-01-19 13:36:21 +01:00 committed by GitHub
commit cf1ffe0d6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 24 deletions

View File

@ -14,7 +14,18 @@ import urllib3.connection
from .. import constants
from .basehttpadapter import BaseHTTPAdapter
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
class DontCloseStreamWrapper:
"""Make close() noop for the wrapped object."""
def __init__(self, obj):
self.obj = obj
def __getattr__(self, name):
def wrapper(*args, **kwargs):
if name != "close":
return getattr(self.obj, name)(*args, **kwargs)
return wrapper
class SSHSocket(socket.socket):
@ -85,7 +96,7 @@ class SSHSocket(socket.socket):
self.connect()
self.proc.stdout.channel = self
return self.proc.stdout
return DontCloseStreamWrapper(self.proc.stdout)
def close(self):
if not self.proc or self.proc.stdin.closed:
@ -159,7 +170,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
class SSHHTTPAdapter(BaseHTTPAdapter):
__attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [
'pools', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size'
'pool', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size'
]
def __init__(self, base_url, timeout=60,
@ -176,10 +187,8 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_host = base_url[len('ssh://'):]
self.timeout = timeout
self.pool = None
self.max_pool_size = max_pool_size
self.pools = RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close()
)
super().__init__()
def _create_paramiko_client(self, base_url):
@ -218,31 +227,16 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None):
if not self.ssh_client:
return SSHConnectionPool(
ssh_client=self.ssh_client,
timeout=self.timeout,
maxsize=self.max_pool_size,
host=self.ssh_host
)
with self.pools.lock:
pool = self.pools.get(url)
if pool:
return pool
# Connection is closed try a reconnect
if self.pool is None:
if self.ssh_client and not self.ssh_client.get_transport():
self._connect()
pool = SSHConnectionPool(
self.pool = SSHConnectionPool(
ssh_client=self.ssh_client,
timeout=self.timeout,
maxsize=self.max_pool_size,
host=self.ssh_host
)
self.pools[url] = pool
return pool
return self.pool
def close(self):
super().close()