mirror of https://github.com/docker/docker-py.git
Fix ssh connection - don't override the host and port of the http pool
Signed-off-by: aiordache <anca.iordache@docker.com>
This commit is contained in:
parent
6da140e26c
commit
f5531a94e1
|
@ -1,9 +1,9 @@
|
|||
import io
|
||||
import paramiko
|
||||
import requests.adapters
|
||||
import six
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
|
||||
|
@ -23,40 +23,6 @@ except ImportError:
|
|||
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
|
||||
|
||||
|
||||
def create_paramiko_client(base_url):
|
||||
logging.getLogger("paramiko").setLevel(logging.WARNING)
|
||||
ssh_client = paramiko.SSHClient()
|
||||
base_url = six.moves.urllib_parse.urlparse(base_url)
|
||||
ssh_params = {
|
||||
"hostname": base_url.hostname,
|
||||
"port": base_url.port,
|
||||
"username": base_url.username
|
||||
}
|
||||
ssh_config_file = os.path.expanduser("~/.ssh/config")
|
||||
if os.path.exists(ssh_config_file):
|
||||
conf = paramiko.SSHConfig()
|
||||
with open(ssh_config_file) as f:
|
||||
conf.parse(f)
|
||||
host_config = conf.lookup(base_url.hostname)
|
||||
ssh_conf = host_config
|
||||
if 'proxycommand' in host_config:
|
||||
ssh_params["sock"] = paramiko.ProxyCommand(
|
||||
ssh_conf['proxycommand']
|
||||
)
|
||||
if 'hostname' in host_config:
|
||||
ssh_params['hostname'] = host_config['hostname']
|
||||
if 'identityfile' in host_config:
|
||||
ssh_params['key_filename'] = host_config['identityfile']
|
||||
if base_url.port is None and 'port' in host_config:
|
||||
ssh_params['port'] = ssh_conf['port']
|
||||
if base_url.username is None and 'user' in host_config:
|
||||
ssh_params['username'] = ssh_conf['user']
|
||||
|
||||
ssh_client.load_system_host_keys()
|
||||
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
return ssh_client, ssh_params
|
||||
|
||||
|
||||
class SSHSocket(socket.socket):
|
||||
def __init__(self, host):
|
||||
super(SSHSocket, self).__init__(
|
||||
|
@ -80,7 +46,8 @@ class SSHSocket(socket.socket):
|
|||
' '.join(args),
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE)
|
||||
stdin=subprocess.PIPE,
|
||||
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))
|
||||
|
||||
def _write(self, data):
|
||||
if not self.proc or self.proc.stdin.closed:
|
||||
|
@ -96,17 +63,18 @@ class SSHSocket(socket.socket):
|
|||
def send(self, data):
|
||||
return self._write(data)
|
||||
|
||||
def recv(self):
|
||||
def recv(self, n):
|
||||
if not self.proc:
|
||||
raise Exception('SSH subprocess not initiated.'
|
||||
'connect() must be called first.')
|
||||
return self.proc.stdout.read()
|
||||
return self.proc.stdout.read(n)
|
||||
|
||||
def makefile(self, mode):
|
||||
if not self.proc or self.proc.stdout.closed:
|
||||
buf = io.BytesIO()
|
||||
buf.write(b'\n\n')
|
||||
return buf
|
||||
if not self.proc:
|
||||
self.connect()
|
||||
if six.PY3:
|
||||
self.proc.stdout.channel = self
|
||||
|
||||
return self.proc.stdout
|
||||
|
||||
def close(self):
|
||||
|
@ -124,7 +92,7 @@ class SSHConnection(httplib.HTTPConnection, object):
|
|||
)
|
||||
self.ssh_transport = ssh_transport
|
||||
self.timeout = timeout
|
||||
self.host = host
|
||||
self.ssh_host = host
|
||||
|
||||
def connect(self):
|
||||
if self.ssh_transport:
|
||||
|
@ -132,7 +100,7 @@ class SSHConnection(httplib.HTTPConnection, object):
|
|||
sock.settimeout(self.timeout)
|
||||
sock.exec_command('docker system dial-stdio')
|
||||
else:
|
||||
sock = SSHSocket(self.host)
|
||||
sock = SSHSocket(self.ssh_host)
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect()
|
||||
|
||||
|
@ -147,16 +115,16 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
|||
'localhost', timeout=timeout, maxsize=maxsize
|
||||
)
|
||||
self.ssh_transport = None
|
||||
self.timeout = timeout
|
||||
if ssh_client:
|
||||
self.ssh_transport = ssh_client.get_transport()
|
||||
self.timeout = timeout
|
||||
self.host = host
|
||||
self.port = None
|
||||
self.ssh_host = host
|
||||
self.ssh_port = None
|
||||
if ':' in host:
|
||||
self.host, self.port = host.split(':')
|
||||
self.ssh_host, self.ssh_port = host.split(':')
|
||||
|
||||
def _new_conn(self):
|
||||
return SSHConnection(self.ssh_transport, self.timeout, self.host)
|
||||
return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
|
||||
|
||||
# When re-using connections, urllib3 calls fileno() on our
|
||||
# SSH channel instance, quickly overloading our fd limit. To avoid this,
|
||||
|
@ -193,10 +161,10 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
|||
shell_out=True):
|
||||
self.ssh_client = None
|
||||
if not shell_out:
|
||||
self.ssh_client, self.ssh_params = create_paramiko_client(base_url)
|
||||
self._create_paramiko_client(base_url)
|
||||
self._connect()
|
||||
base_url = base_url.lstrip('ssh://')
|
||||
self.host = base_url
|
||||
|
||||
self.ssh_host = base_url.lstrip('ssh://')
|
||||
self.timeout = timeout
|
||||
self.max_pool_size = max_pool_size
|
||||
self.pools = RecentlyUsedContainer(
|
||||
|
@ -204,11 +172,48 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
|||
)
|
||||
super(SSHHTTPAdapter, self).__init__()
|
||||
|
||||
def _create_paramiko_client(self, base_url):
|
||||
logging.getLogger("paramiko").setLevel(logging.WARNING)
|
||||
self.ssh_client = paramiko.SSHClient()
|
||||
base_url = six.moves.urllib_parse.urlparse(base_url)
|
||||
self.ssh_params = {
|
||||
"hostname": base_url.hostname,
|
||||
"port": base_url.port,
|
||||
"username": base_url.username
|
||||
}
|
||||
ssh_config_file = os.path.expanduser("~/.ssh/config")
|
||||
if os.path.exists(ssh_config_file):
|
||||
conf = paramiko.SSHConfig()
|
||||
with open(ssh_config_file) as f:
|
||||
conf.parse(f)
|
||||
host_config = conf.lookup(base_url.hostname)
|
||||
self.ssh_conf = host_config
|
||||
if 'proxycommand' in host_config:
|
||||
self.ssh_params["sock"] = paramiko.ProxyCommand(
|
||||
self.ssh_conf['proxycommand']
|
||||
)
|
||||
if 'hostname' in host_config:
|
||||
self.ssh_params['hostname'] = host_config['hostname']
|
||||
if base_url.port is None and 'port' in host_config:
|
||||
self.ssh_params['port'] = self.ssh_conf['port']
|
||||
if base_url.username is None and 'user' in host_config:
|
||||
self.ssh_params['username'] = self.ssh_conf['user']
|
||||
|
||||
self.ssh_client.load_system_host_keys()
|
||||
self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
def _connect(self):
|
||||
if self.ssh_client:
|
||||
self.ssh_client.connect(**self.ssh_params)
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
if not self.ssh_client:
|
||||
return SSHConnectionPool(
|
||||
ssh_client=self.ssh_client,
|
||||
timeout=self.timeout,
|
||||
maxsize=self.max_pool_size,
|
||||
host=self.ssh_host
|
||||
)
|
||||
with self.pools.lock:
|
||||
pool = self.pools.get(url)
|
||||
if pool:
|
||||
|
@ -222,7 +227,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
|
|||
ssh_client=self.ssh_client,
|
||||
timeout=self.timeout,
|
||||
maxsize=self.max_pool_size,
|
||||
host=self.host
|
||||
host=self.ssh_host
|
||||
)
|
||||
self.pools[url] = pool
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ RUN apk add --no-cache \
|
|||
RUN ssh-keygen -A
|
||||
|
||||
# copy the test SSH config
|
||||
RUN echo "IgnoreUserKnownHosts yes" >> /etc/ssh/sshd_config && \
|
||||
RUN echo "IgnoreUserKnownHosts yes" > /etc/ssh/sshd_config && \
|
||||
echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \
|
||||
echo "PermitRootLogin yes" >> /etc/ssh/sshd_config
|
||||
|
||||
|
|
Loading…
Reference in New Issue