diff --git a/docker/api/client.py b/docker/api/client.py index 43e309b5..d9125349 100644 --- a/docker/api/client.py +++ b/docker/api/client.py @@ -11,7 +11,7 @@ from .. import auth from ..constants import (DEFAULT_NUM_POOLS, DEFAULT_NUM_POOLS_SSH, DEFAULT_TIMEOUT_SECONDS, DEFAULT_USER_AGENT, IS_WINDOWS_PLATFORM, MINIMUM_DOCKER_API_VERSION, - STREAM_HEADER_SIZE_BYTES) + STREAM_HEADER_SIZE_BYTES, DEFAULT_SSH_CLIENT) from ..errors import (DockerException, InvalidVersion, TLSParameterError, create_api_error_from_http_exception) from ..tls import TLSConfig @@ -161,7 +161,8 @@ class APIClient( elif base_url.startswith('ssh://'): try: self._custom_adapter = SSHHTTPAdapter( - base_url, timeout, pool_connections=num_pools + base_url, timeout, pool_connections=num_pools, + shell_out=DEFAULT_SSH_CLIENT ) except NameError: raise DockerException( diff --git a/docker/constants.py b/docker/constants.py index c09eedab..5ff549f9 100644 --- a/docker/constants.py +++ b/docker/constants.py @@ -40,3 +40,5 @@ DEFAULT_DATA_CHUNK_SIZE = 1024 * 2048 DEFAULT_SWARM_ADDR_POOL = ['10.0.0.0/8'] DEFAULT_SWARM_SUBNET_SIZE = 24 + +DEFAULT_SSH_CLIENT = True diff --git a/docker/transport/sshconn.py b/docker/transport/sshconn.py index 9cfd9980..37f1a7ba 100644 --- a/docker/transport/sshconn.py +++ b/docker/transport/sshconn.py @@ -3,6 +3,8 @@ import requests.adapters import six import logging import os +import socket +import subprocess from docker.transport.basehttpadapter import BaseHTTPAdapter from .. import constants @@ -20,30 +22,117 @@ except ImportError: RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer +def create_paramiko_client(base_url): + logging.getLogger("paramiko").setLevel(logging.WARNING) + ssh_client = paramiko.SSHClient() + base_url = six.moves.urllib_parse.urlparse(base_url) + ssh_params = { + "hostname": base_url.hostname, + "port": base_url.port, + "username": base_url.username + } + ssh_config_file = os.path.expanduser("~/.ssh/config") + if os.path.exists(ssh_config_file): + conf = paramiko.SSHConfig() + with open(ssh_config_file) as f: + conf.parse(f) + host_config = conf.lookup(base_url.hostname) + ssh_conf = host_config + if 'proxycommand' in host_config: + ssh_params["sock"] = paramiko.ProxyCommand( + ssh_conf['proxycommand'] + ) + if 'hostname' in host_config: + ssh_params['hostname'] = host_config['hostname'] + if 'identityfile' in host_config: + ssh_params['key_filename'] = host_config['identityfile'] + if base_url.port is None and 'port' in host_config: + ssh_params['port'] = ssh_conf['port'] + if base_url.username is None and 'user' in host_config: + ssh_params['username'] = ssh_conf['user'] + + ssh_client.load_system_host_keys() + ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) + return ssh_client + + +class SSHSocket(socket.socket): + def __init__(self, host): + super(SSHSocket, self).__init__( + socket.AF_INET, socket.SOCK_STREAM) + self.host = host + self.proc = None + + def connect(self, **kwargs): + args = [ + 'ssh', + self.host, + 'docker system dial-stdio' + ] + self.proc = subprocess.Popen( + ' '.join(args), + shell=True, + stdout=subprocess.PIPE, + stdin=subprocess.PIPE) + + def sendall(self, msg): + if not self.proc or self.proc.stdin.closed: + raise Exception('SSH subprocess not initiated.' + 'connect() must be called first.') + self.proc.stdin.write(msg) + self.proc.stdin.flush() + + def recv(self): + if not self.proc: + raise Exception('SSH subprocess not initiated.' + 'connect() must be called first.') + return self.proc.stdout.read() + + def makefile(self, mode): + return self.proc.stdout + + def close(self): + if not self.proc: + return + self.proc.stdin.write(b'\n\n') + self.proc.stdin.flush() + self.proc.terminate() + + class SSHConnection(httplib.HTTPConnection, object): - def __init__(self, ssh_transport, timeout=60): + def __init__(self, ssh_transport=None, timeout=60, host=None): super(SSHConnection, self).__init__( 'localhost', timeout=timeout ) self.ssh_transport = ssh_transport self.timeout = timeout + self.host = host def connect(self): - sock = self.ssh_transport.open_session() - sock.settimeout(self.timeout) - sock.exec_command('docker system dial-stdio') + if self.ssh_transport: + sock = self.ssh_transport.open_session() + sock.settimeout(self.timeout) + sock.exec_command('docker system dial-stdio') + else: + sock = SSHSocket(self.host) + sock.settimeout(self.timeout) + sock.connect() + self.sock = sock class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): scheme = 'ssh' - def __init__(self, ssh_client, timeout=60, maxsize=10): + def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None): super(SSHConnectionPool, self).__init__( 'localhost', timeout=timeout, maxsize=maxsize ) - self.ssh_transport = ssh_client.get_transport() + self.ssh_transport = None + if ssh_client: + self.ssh_transport = ssh_client.get_transport() self.timeout = timeout + self.host = host def _new_conn(self): return SSHConnection(self.ssh_transport, self.timeout) @@ -78,39 +167,14 @@ class SSHHTTPAdapter(BaseHTTPAdapter): ] def __init__(self, base_url, timeout=60, - pool_connections=constants.DEFAULT_NUM_POOLS): - logging.getLogger("paramiko").setLevel(logging.WARNING) - self.ssh_client = paramiko.SSHClient() - base_url = six.moves.urllib_parse.urlparse(base_url) - self.ssh_params = { - "hostname": base_url.hostname, - "port": base_url.port, - "username": base_url.username - } - ssh_config_file = os.path.expanduser("~/.ssh/config") - if os.path.exists(ssh_config_file): - conf = paramiko.SSHConfig() - with open(ssh_config_file) as f: - conf.parse(f) - host_config = conf.lookup(base_url.hostname) - self.ssh_conf = host_config - if 'proxycommand' in host_config: - self.ssh_params["sock"] = paramiko.ProxyCommand( - self.ssh_conf['proxycommand'] - ) - if 'hostname' in host_config: - self.ssh_params['hostname'] = host_config['hostname'] - if 'identityfile' in host_config: - self.ssh_params['key_filename'] = host_config['identityfile'] - if base_url.port is None and 'port' in host_config: - self.ssh_params['port'] = self.ssh_conf['port'] - if base_url.username is None and 'user' in host_config: - self.ssh_params['username'] = self.ssh_conf['user'] + pool_connections=constants.DEFAULT_NUM_POOLS, + shell_out=True): + self.ssh_client = None + if not shell_out: + self.ssh_client = create_paramiko_client(base_url) + self._connect() - self.ssh_client.load_system_host_keys() - self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) - - self._connect() + self.host = base_url self.timeout = timeout self.pools = RecentlyUsedContainer( pool_connections, dispose_func=lambda p: p.close() @@ -118,7 +182,8 @@ class SSHHTTPAdapter(BaseHTTPAdapter): super(SSHHTTPAdapter, self).__init__() def _connect(self): - self.ssh_client.connect(**self.ssh_params) + if self.ssh_client: + self.ssh_client.connect(**self.ssh_params) def get_connection(self, url, proxies=None): with self.pools.lock: @@ -127,11 +192,13 @@ class SSHHTTPAdapter(BaseHTTPAdapter): return pool # Connection is closed try a reconnect - if not self.ssh_client.get_transport(): + if self.ssh_client and not self.ssh_client.get_transport(): self._connect() pool = SSHConnectionPool( - self.ssh_client, self.timeout + ssh_client=self.ssh_client, + timeout=self.timeout, + host=self.host ) self.pools[url] = pool @@ -139,4 +206,5 @@ class SSHHTTPAdapter(BaseHTTPAdapter): def close(self): super(SSHHTTPAdapter, self).close() - self.ssh_client.close() + if self.ssh_client: + self.ssh_client.close()