mirror of https://github.com/docker/docker-py.git
Reuse ssh connection
Signed-off-by: coodyz <guo92820@gmail.com>
This commit is contained in:
parent
c38656dc78
commit
3e5d872ceb
|
|
@ -11,10 +11,25 @@ import subprocess
|
||||||
from docker.transport.basehttpadapter import BaseHTTPAdapter
|
from docker.transport.basehttpadapter import BaseHTTPAdapter
|
||||||
from .. import constants
|
from .. import constants
|
||||||
|
|
||||||
import urllib3
|
import http.client as httplib
|
||||||
import urllib3.connection
|
|
||||||
|
|
||||||
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
try:
|
||||||
|
import requests.packages.urllib3 as urllib3
|
||||||
|
except ImportError:
|
||||||
|
import urllib3
|
||||||
|
|
||||||
|
|
||||||
|
class DontCloseStreamWrapper:
|
||||||
|
"""Make close() noop for the wrapped object."""
|
||||||
|
|
||||||
|
def __init__(self, obj):
|
||||||
|
self.obj = obj
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if name != "close":
|
||||||
|
return getattr(self.obj, name)(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class SSHSocket(socket.socket):
|
class SSHSocket(socket.socket):
|
||||||
|
|
@ -54,11 +69,12 @@ class SSHSocket(socket.socket):
|
||||||
env.pop('SSL_CERT_FILE', None)
|
env.pop('SSL_CERT_FILE', None)
|
||||||
|
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
args,
|
' '.join(args),
|
||||||
env=env,
|
env=env,
|
||||||
|
shell=True,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stdin=subprocess.PIPE,
|
stdin=subprocess.PIPE,
|
||||||
preexec_fn=preexec_func)
|
preexec_fn=None if constants.IS_WINDOWS_PLATFORM else preexec_func)
|
||||||
|
|
||||||
def _write(self, data):
|
def _write(self, data):
|
||||||
if not self.proc or self.proc.stdin.closed:
|
if not self.proc or self.proc.stdin.closed:
|
||||||
|
|
@ -85,7 +101,7 @@ class SSHSocket(socket.socket):
|
||||||
self.connect()
|
self.connect()
|
||||||
self.proc.stdout.channel = self
|
self.proc.stdout.channel = self
|
||||||
|
|
||||||
return self.proc.stdout
|
return DontCloseStreamWrapper(self.proc.stdout)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if not self.proc or self.proc.stdin.closed:
|
if not self.proc or self.proc.stdin.closed:
|
||||||
|
|
@ -95,7 +111,7 @@ class SSHSocket(socket.socket):
|
||||||
self.proc.terminate()
|
self.proc.terminate()
|
||||||
|
|
||||||
|
|
||||||
class SSHConnection(urllib3.connection.HTTPConnection):
|
class SSHConnection(httplib.HTTPConnection):
|
||||||
def __init__(self, ssh_transport=None, timeout=60, host=None):
|
def __init__(self, ssh_transport=None, timeout=60, host=None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
'localhost', timeout=timeout
|
'localhost', timeout=timeout
|
||||||
|
|
@ -141,8 +157,8 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||||
try:
|
try:
|
||||||
conn = self.pool.get(block=self.block, timeout=timeout)
|
conn = self.pool.get(block=self.block, timeout=timeout)
|
||||||
|
|
||||||
except AttributeError as ae: # self.pool is None
|
except AttributeError: # self.pool is None
|
||||||
raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.") from ae
|
raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.")
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
if self.block:
|
if self.block:
|
||||||
|
|
@ -150,8 +166,8 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||||
self,
|
self,
|
||||||
"Pool reached maximum size and no more "
|
"Pool reached maximum size and no more "
|
||||||
"connections are allowed."
|
"connections are allowed."
|
||||||
) from None
|
)
|
||||||
# Oh well, we'll create a new connection then
|
pass # Oh well, we'll create a new connection then
|
||||||
|
|
||||||
return conn or self._new_conn()
|
return conn or self._new_conn()
|
||||||
|
|
||||||
|
|
@ -159,7 +175,7 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||||
class SSHHTTPAdapter(BaseHTTPAdapter):
|
class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||||
|
|
||||||
__attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [
|
__attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [
|
||||||
'pools', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size'
|
'pool', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size'
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, base_url, timeout=60,
|
def __init__(self, base_url, timeout=60,
|
||||||
|
|
@ -176,10 +192,8 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||||
self.ssh_host = base_url[len('ssh://'):]
|
self.ssh_host = base_url[len('ssh://'):]
|
||||||
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
self.pool = None
|
||||||
self.max_pool_size = max_pool_size
|
self.max_pool_size = max_pool_size
|
||||||
self.pools = RecentlyUsedContainer(
|
|
||||||
pool_connections, dispose_func=lambda p: p.close()
|
|
||||||
)
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def _create_paramiko_client(self, base_url):
|
def _create_paramiko_client(self, base_url):
|
||||||
|
|
@ -199,7 +213,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||||
host_config = conf.lookup(base_url.hostname)
|
host_config = conf.lookup(base_url.hostname)
|
||||||
if 'proxycommand' in host_config:
|
if 'proxycommand' in host_config:
|
||||||
self.ssh_params["sock"] = paramiko.ProxyCommand(
|
self.ssh_params["sock"] = paramiko.ProxyCommand(
|
||||||
host_config['proxycommand']
|
self.ssh_conf['proxycommand']
|
||||||
)
|
)
|
||||||
if 'hostname' in host_config:
|
if 'hostname' in host_config:
|
||||||
self.ssh_params['hostname'] = host_config['hostname']
|
self.ssh_params['hostname'] = host_config['hostname']
|
||||||
|
|
@ -211,38 +225,23 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
||||||
self.ssh_params['key_filename'] = host_config['identityfile']
|
self.ssh_params['key_filename'] = host_config['identityfile']
|
||||||
|
|
||||||
self.ssh_client.load_system_host_keys()
|
self.ssh_client.load_system_host_keys()
|
||||||
self.ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())
|
self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||||
|
|
||||||
def _connect(self):
|
def _connect(self):
|
||||||
if self.ssh_client:
|
if self.ssh_client:
|
||||||
self.ssh_client.connect(**self.ssh_params)
|
self.ssh_client.connect(**self.ssh_params)
|
||||||
|
|
||||||
def get_connection(self, url, proxies=None):
|
def get_connection(self, url, proxies=None):
|
||||||
if not self.ssh_client:
|
if self.pool is None:
|
||||||
return SSHConnectionPool(
|
|
||||||
ssh_client=self.ssh_client,
|
|
||||||
timeout=self.timeout,
|
|
||||||
maxsize=self.max_pool_size,
|
|
||||||
host=self.ssh_host
|
|
||||||
)
|
|
||||||
with self.pools.lock:
|
|
||||||
pool = self.pools.get(url)
|
|
||||||
if pool:
|
|
||||||
return pool
|
|
||||||
|
|
||||||
# Connection is closed try a reconnect
|
|
||||||
if self.ssh_client and not self.ssh_client.get_transport():
|
if self.ssh_client and not self.ssh_client.get_transport():
|
||||||
self._connect()
|
self._connect()
|
||||||
|
self.pool = SSHConnectionPool(
|
||||||
pool = SSHConnectionPool(
|
|
||||||
ssh_client=self.ssh_client,
|
ssh_client=self.ssh_client,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
maxsize=self.max_pool_size,
|
maxsize=self.max_pool_size,
|
||||||
host=self.ssh_host
|
host=self.ssh_host
|
||||||
)
|
)
|
||||||
self.pools[url] = pool
|
return self.pool
|
||||||
|
|
||||||
return pool
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super().close()
|
super().close()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue