Shell out to SSH client for an ssh connection

Signed-off-by: aiordache <anca.iordache@docker.com>
This commit is contained in:
aiordache 2020-09-22 10:20:18 +02:00
parent 9d8cd023e8
commit dd8b9b7f10
3 changed files with 115 additions and 44 deletions

View File

@ -11,7 +11,7 @@ from .. import auth
from ..constants import (DEFAULT_NUM_POOLS, DEFAULT_NUM_POOLS_SSH, from ..constants import (DEFAULT_NUM_POOLS, DEFAULT_NUM_POOLS_SSH,
DEFAULT_TIMEOUT_SECONDS, DEFAULT_USER_AGENT, DEFAULT_TIMEOUT_SECONDS, DEFAULT_USER_AGENT,
IS_WINDOWS_PLATFORM, MINIMUM_DOCKER_API_VERSION, IS_WINDOWS_PLATFORM, MINIMUM_DOCKER_API_VERSION,
STREAM_HEADER_SIZE_BYTES) STREAM_HEADER_SIZE_BYTES, DEFAULT_SSH_CLIENT)
from ..errors import (DockerException, InvalidVersion, TLSParameterError, from ..errors import (DockerException, InvalidVersion, TLSParameterError,
create_api_error_from_http_exception) create_api_error_from_http_exception)
from ..tls import TLSConfig from ..tls import TLSConfig
@ -161,7 +161,8 @@ class APIClient(
elif base_url.startswith('ssh://'): elif base_url.startswith('ssh://'):
try: try:
self._custom_adapter = SSHHTTPAdapter( self._custom_adapter = SSHHTTPAdapter(
base_url, timeout, pool_connections=num_pools base_url, timeout, pool_connections=num_pools,
shell_out=DEFAULT_SSH_CLIENT
) )
except NameError: except NameError:
raise DockerException( raise DockerException(

View File

@ -40,3 +40,5 @@ DEFAULT_DATA_CHUNK_SIZE = 1024 * 2048
DEFAULT_SWARM_ADDR_POOL = ['10.0.0.0/8'] DEFAULT_SWARM_ADDR_POOL = ['10.0.0.0/8']
DEFAULT_SWARM_SUBNET_SIZE = 24 DEFAULT_SWARM_SUBNET_SIZE = 24
DEFAULT_SSH_CLIENT = True

View File

@ -3,6 +3,8 @@ import requests.adapters
import six import six
import logging import logging
import os import os
import socket
import subprocess
from docker.transport.basehttpadapter import BaseHTTPAdapter from docker.transport.basehttpadapter import BaseHTTPAdapter
from .. import constants from .. import constants
@ -20,30 +22,117 @@ except ImportError:
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
def create_paramiko_client(base_url):
logging.getLogger("paramiko").setLevel(logging.WARNING)
ssh_client = paramiko.SSHClient()
base_url = six.moves.urllib_parse.urlparse(base_url)
ssh_params = {
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file) as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
ssh_conf = host_config
if 'proxycommand' in host_config:
ssh_params["sock"] = paramiko.ProxyCommand(
ssh_conf['proxycommand']
)
if 'hostname' in host_config:
ssh_params['hostname'] = host_config['hostname']
if 'identityfile' in host_config:
ssh_params['key_filename'] = host_config['identityfile']
if base_url.port is None and 'port' in host_config:
ssh_params['port'] = ssh_conf['port']
if base_url.username is None and 'user' in host_config:
ssh_params['username'] = ssh_conf['user']
ssh_client.load_system_host_keys()
ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
return ssh_client
class SSHSocket(socket.socket):
def __init__(self, host):
super(SSHSocket, self).__init__(
socket.AF_INET, socket.SOCK_STREAM)
self.host = host
self.proc = None
def connect(self, **kwargs):
args = [
'ssh',
self.host,
'docker system dial-stdio'
]
self.proc = subprocess.Popen(
' '.join(args),
shell=True,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE)
def sendall(self, msg):
if not self.proc or self.proc.stdin.closed:
raise Exception('SSH subprocess not initiated.'
'connect() must be called first.')
self.proc.stdin.write(msg)
self.proc.stdin.flush()
def recv(self):
if not self.proc:
raise Exception('SSH subprocess not initiated.'
'connect() must be called first.')
return self.proc.stdout.read()
def makefile(self, mode):
return self.proc.stdout
def close(self):
if not self.proc:
return
self.proc.stdin.write(b'\n\n')
self.proc.stdin.flush()
self.proc.terminate()
class SSHConnection(httplib.HTTPConnection, object): class SSHConnection(httplib.HTTPConnection, object):
def __init__(self, ssh_transport, timeout=60): def __init__(self, ssh_transport=None, timeout=60, host=None):
super(SSHConnection, self).__init__( super(SSHConnection, self).__init__(
'localhost', timeout=timeout 'localhost', timeout=timeout
) )
self.ssh_transport = ssh_transport self.ssh_transport = ssh_transport
self.timeout = timeout self.timeout = timeout
self.host = host
def connect(self): def connect(self):
if self.ssh_transport:
sock = self.ssh_transport.open_session() sock = self.ssh_transport.open_session()
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
sock.exec_command('docker system dial-stdio') sock.exec_command('docker system dial-stdio')
else:
sock = SSHSocket(self.host)
sock.settimeout(self.timeout)
sock.connect()
self.sock = sock self.sock = sock
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
scheme = 'ssh' scheme = 'ssh'
def __init__(self, ssh_client, timeout=60, maxsize=10): def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
super(SSHConnectionPool, self).__init__( super(SSHConnectionPool, self).__init__(
'localhost', timeout=timeout, maxsize=maxsize 'localhost', timeout=timeout, maxsize=maxsize
) )
self.ssh_transport = None
if ssh_client:
self.ssh_transport = ssh_client.get_transport() self.ssh_transport = ssh_client.get_transport()
self.timeout = timeout self.timeout = timeout
self.host = host
def _new_conn(self): def _new_conn(self):
return SSHConnection(self.ssh_transport, self.timeout) return SSHConnection(self.ssh_transport, self.timeout)
@ -78,39 +167,14 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
] ]
def __init__(self, base_url, timeout=60, def __init__(self, base_url, timeout=60,
pool_connections=constants.DEFAULT_NUM_POOLS): pool_connections=constants.DEFAULT_NUM_POOLS,
logging.getLogger("paramiko").setLevel(logging.WARNING) shell_out=True):
self.ssh_client = paramiko.SSHClient() self.ssh_client = None
base_url = six.moves.urllib_parse.urlparse(base_url) if not shell_out:
self.ssh_params = { self.ssh_client = create_paramiko_client(base_url)
"hostname": base_url.hostname,
"port": base_url.port,
"username": base_url.username
}
ssh_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(ssh_config_file):
conf = paramiko.SSHConfig()
with open(ssh_config_file) as f:
conf.parse(f)
host_config = conf.lookup(base_url.hostname)
self.ssh_conf = host_config
if 'proxycommand' in host_config:
self.ssh_params["sock"] = paramiko.ProxyCommand(
self.ssh_conf['proxycommand']
)
if 'hostname' in host_config:
self.ssh_params['hostname'] = host_config['hostname']
if 'identityfile' in host_config:
self.ssh_params['key_filename'] = host_config['identityfile']
if base_url.port is None and 'port' in host_config:
self.ssh_params['port'] = self.ssh_conf['port']
if base_url.username is None and 'user' in host_config:
self.ssh_params['username'] = self.ssh_conf['user']
self.ssh_client.load_system_host_keys()
self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
self._connect() self._connect()
self.host = base_url
self.timeout = timeout self.timeout = timeout
self.pools = RecentlyUsedContainer( self.pools = RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close() pool_connections, dispose_func=lambda p: p.close()
@ -118,6 +182,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
super(SSHHTTPAdapter, self).__init__() super(SSHHTTPAdapter, self).__init__()
def _connect(self): def _connect(self):
if self.ssh_client:
self.ssh_client.connect(**self.ssh_params) self.ssh_client.connect(**self.ssh_params)
def get_connection(self, url, proxies=None): def get_connection(self, url, proxies=None):
@ -127,11 +192,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
return pool return pool
# Connection is closed try a reconnect # Connection is closed try a reconnect
if not self.ssh_client.get_transport(): if self.ssh_client and not self.ssh_client.get_transport():
self._connect() self._connect()
pool = SSHConnectionPool( pool = SSHConnectionPool(
self.ssh_client, self.timeout ssh_client=self.ssh_client,
timeout=self.timeout,
host=self.host
) )
self.pools[url] = pool self.pools[url] = pool
@ -139,4 +206,5 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
def close(self): def close(self):
super(SSHHTTPAdapter, self).close() super(SSHHTTPAdapter, self).close()
if self.ssh_client:
self.ssh_client.close() self.ssh_client.close()