diff --git a/docker/api/client.py b/docker/api/client.py index a2cb459d..60971bdf 100644 --- a/docker/api/client.py +++ b/docker/api/client.py @@ -1,10 +1,15 @@ import json import struct import urllib +import ssl from functools import partial +from typing import Any, AnyStr, Optional, Union, Dict, overload, NoReturn, Iterator + +from typing_extensions import Literal import requests import requests.exceptions +import requests.adapters import websocket from .. import auth @@ -20,6 +25,7 @@ from ..utils import check_resource, config, update_headers, utils from ..utils.json_stream import json_stream from ..utils.proxy import ProxyConfig from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter +from ..utils.typing import BytesOrDict from .build import BuildApiMixin from .config import ConfigApiMixin from .container import ContainerApiMixin @@ -102,11 +108,11 @@ class APIClient( 'base_url', 'timeout'] - def __init__(self, base_url=None, version=None, - timeout=DEFAULT_TIMEOUT_SECONDS, tls=False, - user_agent=DEFAULT_USER_AGENT, num_pools=None, - credstore_env=None, use_ssh_client=False, - max_pool_size=DEFAULT_MAX_POOL_SIZE): + def __init__(self, base_url: Optional[str] = None, version: Optional[str] = None, + timeout: int = DEFAULT_TIMEOUT_SECONDS, tls: Optional[Union[bool, TLSConfig]] = False, + user_agent: str = DEFAULT_USER_AGENT, num_pools: Optional[int] = None, + credstore_env: Optional[Dict[str, Any]] = None, use_ssh_client: bool = False, + max_pool_size: int = DEFAULT_MAX_POOL_SIZE) -> None: super().__init__() if tls and not base_url: @@ -208,7 +214,7 @@ class APIClient( f'no longer supported by this library.' ) - def _retrieve_server_version(self): + def _retrieve_server_version(self) -> str: try: return self.version(api_version=False)["ApiVersion"] except KeyError as ke: @@ -221,29 +227,29 @@ class APIClient( f'Error while fetching server API version: {e}' ) from e - def _set_request_timeout(self, kwargs): + def _set_request_timeout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Prepare the kwargs for an HTTP request by inserting the timeout parameter, if not already present.""" kwargs.setdefault('timeout', self.timeout) return kwargs @update_headers - def _post(self, url, **kwargs): + def _post(self, url: str, **kwargs: Any) -> requests.Response: return self.post(url, **self._set_request_timeout(kwargs)) @update_headers - def _get(self, url, **kwargs): + def _get(self, url: str, **kwargs: Any) -> requests.Response: return self.get(url, **self._set_request_timeout(kwargs)) @update_headers - def _put(self, url, **kwargs): + def _put(self, url: str, **kwargs: Any) -> requests.Response: return self.put(url, **self._set_request_timeout(kwargs)) @update_headers - def _delete(self, url, **kwargs): + def _delete(self, url: str, **kwargs: Any) -> requests.Response: return self.delete(url, **self._set_request_timeout(kwargs)) - def _url(self, pathfmt, *args, **kwargs): + def _url(self, pathfmt: str, *args: str, **kwargs: Any) -> str: for arg in args: if not isinstance(arg, str): raise ValueError( @@ -259,14 +265,30 @@ class APIClient( else: return f'{self.base_url}{formatted_path}' - def _raise_for_status(self, response): + def _raise_for_status(self, response: requests.Response) -> None: """Raises stored :class:`APIError`, if one occurred.""" try: response.raise_for_status() except requests.exceptions.HTTPError as e: raise create_api_error_from_http_exception(e) from e - def _result(self, response, json=False, binary=False): + @overload + def _result(self, response: requests.Response, json: Literal[True], binary: Literal[True]) -> NoReturn: + ... + + @overload + def _result(self, response: requests.Response, json: Literal[False] = ..., binary: Literal[False] = ...) -> str: + ... + + @overload + def _result(self, response: requests.Response, json: Literal[True], binary: bool = ...) -> Any: + ... + + @overload + def _result(self, response: requests.Response, json: bool = ..., binary: Literal[True] = ...) -> bytes: + ... + + def _result(self, response: requests.Response, json: bool = False, binary: bool = False) -> Any: assert not (json and binary) self._raise_for_status(response) @@ -276,7 +298,7 @@ class APIClient( return response.content return response.text - def _post_json(self, url, data, **kwargs): + def _post_json(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: # Go <1.1 can't unserialize null to a string # so we do this disgusting thing here. data2 = {} @@ -292,7 +314,7 @@ class APIClient( kwargs['headers']['Content-Type'] = 'application/json' return self._post(url, data=json.dumps(data2), **kwargs) - def _attach_params(self, override=None): + def _attach_params(self, override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: return override or { 'stdout': 1, 'stderr': 1, @@ -300,7 +322,7 @@ class APIClient( } @check_resource('container') - def _attach_websocket(self, container, params=None): + def _attach_websocket(self, container: str, params: Optional[Dict[str, Any]] = None) -> websocket.WebSocket: url = self._url("/containers/{0}/attach/ws", container) req = requests.Request("POST", url, params=self._attach_params(params)) full_url = req.prepare().url @@ -308,10 +330,10 @@ class APIClient( full_url = full_url.replace("https://", "wss://", 1) return self._create_websocket_connection(full_url) - def _create_websocket_connection(self, url): + def _create_websocket_connection(self, url: str) -> websocket.WebSocket: return websocket.create_connection(url) - def _get_raw_response_socket(self, response): + def _get_raw_response_socket(self, response: requests.Response) -> ssl.SSLSocket: self._raise_for_status(response) if self.base_url == "http+docker://localnpipe": sock = response.raw._fp.fp.raw.sock @@ -333,7 +355,15 @@ class APIClient( return sock - def _stream_helper(self, response, decode=False): + @overload + def _stream_helper(self, response: requests.Response, decode: Literal[True]) -> Iterator[Dict[str, Any]]: + ... + + @overload + def _stream_helper(self, response: requests.Response, decode: Literal[False] = ...) -> Iterator[bytes]: + ... + + def _stream_helper(self, response: requests.Response, decode: bool = False) -> Iterator[BytesOrDict]: """Generator for data coming from a chunked-encoded HTTP response.""" if response.raw._fp.chunked: @@ -354,7 +384,7 @@ class APIClient( # encountered an error immediately yield self._result(response, json=decode) - def _multiplexed_buffer_helper(self, response): + def _multiplexed_buffer_helper(self, response: requests.Response) -> Iterator[bytes]: """A generator of multiplexed data blocks read from a buffered response.""" buf = self._result(response, binary=True) @@ -370,7 +400,7 @@ class APIClient( walker = end yield buf[start:end] - def _multiplexed_response_stream_helper(self, response): + def _multiplexed_response_stream_helper(self, response: requests.Response) -> Iterator[bytes]: """A generator of multiplexed data blocks coming from a response stream.""" @@ -391,7 +421,15 @@ class APIClient( break yield data - def _stream_raw_result(self, response, chunk_size=1, decode=True): + @overload + def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[False] = ...) -> Iterator[bytes]: + ... + + @overload + def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[True] = ...) -> Iterator[str]: + ... + + def _stream_raw_result(self, response: requests.Response, chunk_size: int = 1, decode: bool = True) -> Iterator[AnyStr]: ''' Stream result for TTY-enabled container and raw binary data''' self._raise_for_status(response) @@ -402,7 +440,7 @@ class APIClient( yield from response.iter_content(chunk_size, decode) - def _read_from_socket(self, response, stream, tty=True, demux=False): + def _read_from_socket(self, response: requests.Response, stream: bool, tty: bool = True, demux: bool = False) -> Any: """Consume all data from the socket, close the response and return the data. If stream=True, then a generator is returned instead and the caller is responsible for closing the response. @@ -427,7 +465,7 @@ class APIClient( finally: response.close() - def _disable_socket_timeout(self, socket): + def _disable_socket_timeout(self, socket: ssl.SSLSocket) -> None: """ Depending on the combination of python version and whether we're connecting over http or https, we might need to access _sock, which may or may not exist; or we may need to just settimeout on socket @@ -456,14 +494,22 @@ class APIClient( s.settimeout(None) @check_resource('container') - def _check_is_tty(self, container): + def _check_is_tty(self, container: str) -> bool: cont = self.inspect_container(container) return cont['Config']['Tty'] - def _get_result(self, container, stream, res): + def _get_result(self, container: str, stream: bool, res: requests.Response): return self._get_result_tty(stream, res, self._check_is_tty(container)) - def _get_result_tty(self, stream, res, is_tty): + @overload + def _get_result_tty(self, stream: Literal[True], res: requests.Response, is_tty: bool) -> Iterator[bytes]: + ... + + @overload + def _get_result_tty(self, stream: Literal[False], res: requests.Response, is_tty: bool) -> bytes: + ... + + def _get_result_tty(self, stream: bool, res: requests.Response, is_tty: bool): # We should also use raw streaming (without keep-alives) # if we're dealing with a tty-enabled container. if is_tty: @@ -479,11 +525,11 @@ class APIClient( list(self._multiplexed_buffer_helper(res)) ) - def _unmount(self, *args): + def _unmount(self, *args: str) -> None: for proto in args: self.adapters.pop(proto) - def get_adapter(self, url): + def get_adapter(self, url: str) -> requests.adapters.BaseAdapter: try: return super().get_adapter(url) except requests.exceptions.InvalidSchema as e: @@ -493,10 +539,10 @@ class APIClient( raise e @property - def api_version(self): + def api_version(self) -> str: return self._version - def reload_config(self, dockercfg_path=None): + def reload_config(self, dockercfg_path: Optional[str] = None) -> None: """ Force a reload of the auth configuration diff --git a/docker/utils/typing.py b/docker/utils/typing.py new file mode 100644 index 00000000..fe21dcdf --- /dev/null +++ b/docker/utils/typing.py @@ -0,0 +1,3 @@ +from typing import Any, Dict, TypeVar + +BytesOrDict = TypeVar("BytesOrDict", bytes, Dict[str, Any]) diff --git a/requirements.txt b/requirements.txt index 897cdbd5..9b248383 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pywin32==304; sys_platform == 'win32' requests==2.31.0 urllib3==1.26.11 websocket-client==1.3.3 +typing_extensions>=3.10.0.0