Fix ssh connection - don't override the host and port of the http pool

Signed-off-by: aiordache <anca.iordache@docker.com>
This commit is contained in:
aiordache 2020-10-20 10:05:07 +02:00
parent 6da140e26c
commit f5531a94e1
2 changed files with 59 additions and 54 deletions

View File

@ -1,9 +1,9 @@
import io
import paramiko import paramiko
import requests.adapters import requests.adapters
import six import six
import logging import logging
import os import os
import signal
import socket import socket
import subprocess import subprocess
@ -23,40 +23,6 @@ except ImportError:
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
def create_paramiko_client(base_url):
logging.getLogger("paramiko").setLevel(logging.WARNING)
ssh_client = paramiko.SSHClient()
base_url = six.moves.urllib_parse.urlparse(base_url)
ssh_params = {
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file) as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
ssh_conf = host_config
if 'proxycommand' in host_config:
ssh_params["sock"] = paramiko.ProxyCommand(
ssh_conf['proxycommand']
)
if 'hostname' in host_config:
ssh_params['hostname'] = host_config['hostname']
if 'identityfile' in host_config:
ssh_params['key_filename'] = host_config['identityfile']
if base_url.port is None and 'port' in host_config:
ssh_params['port'] = ssh_conf['port']
if base_url.username is None and 'user' in host_config:
ssh_params['username'] = ssh_conf['user']
ssh_client.load_system_host_keys()
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
return ssh_client, ssh_params
class SSHSocket(socket.socket): class SSHSocket(socket.socket):
def __init__(self, host): def __init__(self, host):
super(SSHSocket, self).__init__( super(SSHSocket, self).__init__(
@ -80,7 +46,8 @@ class SSHSocket(socket.socket):
' '.join(args), ' '.join(args),
shell=True, shell=True,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stdin=subprocess.PIPE) stdin=subprocess.PIPE,
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))
def _write(self, data): def _write(self, data):
if not self.proc or self.proc.stdin.closed: if not self.proc or self.proc.stdin.closed:
@ -96,17 +63,18 @@ class SSHSocket(socket.socket):
def send(self, data): def send(self, data):
return self._write(data) return self._write(data)
def recv(self): def recv(self, n):
if not self.proc: if not self.proc:
raise Exception('SSH subprocess not initiated.' raise Exception('SSH subprocess not initiated.'
'connect() must be called first.') 'connect() must be called first.')
return self.proc.stdout.read() return self.proc.stdout.read(n)
def makefile(self, mode): def makefile(self, mode):
if not self.proc or self.proc.stdout.closed: if not self.proc:
buf = io.BytesIO() self.connect()
buf.write(b'\n\n') if six.PY3:
return buf self.proc.stdout.channel = self
return self.proc.stdout return self.proc.stdout
def close(self): def close(self):
@ -124,7 +92,7 @@ class SSHConnection(httplib.HTTPConnection, object):
) )
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.host = host self.ssh_host = host
def connect(self): def connect(self):
if self.ssh_transport: if self.ssh_transport:
@ -132,7 +100,7 @@ class SSHConnection(httplib.HTTPConnection, object):
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.exec_command('docker system dial-stdio') sock.exec_command('docker system dial-stdio')
else: else:
sock = SSHSocket(self.host) sock = SSHSocket(self.ssh_host)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.connect() sock.connect()
@ -147,16 +115,16 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
'localhost', timeout=timeout, maxsize=maxsize 'localhost', timeout=timeout, maxsize=maxsize
) )
self.ssh_transport = None self.ssh_transport = None
self.timeout = timeout
if ssh_client: if ssh_client:
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
self.timeout = timeout self.ssh_host = host
self.host = host self.ssh_port = None
self.port = None
if ':' in host: if ':' in host:
self.host, self.port = host.split(':') self.ssh_host, self.ssh_port = host.split(':')
def _new_conn(self): def _new_conn(self):
return SSHConnection(self.ssh_transport, self.timeout, self.host) return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
# When re-using connections, urllib3 calls fileno() on our # When re-using connections, urllib3 calls fileno() on our
# SSH channel instance, quickly overloading our fd limit. To avoid this, # SSH channel instance, quickly overloading our fd limit. To avoid this,
@ -193,10 +161,10 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
shell_out=True): shell_out=True):
self.ssh_client = None self.ssh_client = None
if not shell_out: if not shell_out:
self.ssh_client, self.ssh_params = create_paramiko_client(base_url) self._create_paramiko_client(base_url)
self._connect() self._connect()
base_url = base_url.lstrip('ssh://')
self.host = base_url self.ssh_host = base_url.lstrip('ssh://')
self.timeout = timeout self.timeout = timeout
self.max_pool_size = max_pool_size self.max_pool_size = max_pool_size
self.pools = RecentlyUsedContainer( self.pools = RecentlyUsedContainer(
@ -204,11 +172,48 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
) )
super(SSHHTTPAdapter, self).__init__() super(SSHHTTPAdapter, self).__init__()
def _create_paramiko_client(self, base_url):
logging.getLogger("paramiko").setLevel(logging.WARNING)
self.ssh_client = paramiko.SSHClient()
base_url = six.moves.urllib_parse.urlparse(base_url)
self.ssh_params = {
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file) as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
self.ssh_conf = host_config
if 'proxycommand' in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand(
self.ssh_conf['proxycommand']
)
if 'hostname' in host_config:
self.ssh_params['hostname'] = host_config['hostname']
if base_url.port is None and 'port' in host_config:
self.ssh_params['port'] = self.ssh_conf['port']
if base_url.username is None and 'user' in host_config:
self.ssh_params['username'] = self.ssh_conf['user']
self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
def _connect(self): def _connect(self):
if self.ssh_client: if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None): def get_connection(self, url, proxies=None):
if not self.ssh_client:
return SSHConnectionPool(
ssh_client=self.ssh_client,
timeout=self.timeout,
maxsize=self.max_pool_size,
host=self.ssh_host
)
with self.pools.lock: with self.pools.lock:
pool = self.pools.get(url) pool = self.pools.get(url)
if pool: if pool:
@ -222,7 +227,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
ssh_client=self.ssh_client, ssh_client=self.ssh_client,
timeout=self.timeout, timeout=self.timeout,
maxsize=self.max_pool_size, maxsize=self.max_pool_size,
host=self.host host=self.ssh_host
) )
self.pools[url] = pool self.pools[url] = pool

View File

@ -10,7 +10,7 @@ RUN apk add --no-cache \
RUN ssh-keygen -A RUN ssh-keygen -A
# copy the test SSH config # copy the test SSH config
RUN echo "IgnoreUserKnownHosts yes" >> /etc/ssh/sshd_config && \ RUN echo "IgnoreUserKnownHosts yes" > /etc/ssh/sshd_config && \
echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \ echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \
echo "PermitRootLogin yes" >> /etc/ssh/sshd_config echo "PermitRootLogin yes" >> /etc/ssh/sshd_config