types: Add types to `api/client.py`

Signed-off-by: Victorien Plot <65306057+Viicos@users.noreply.github.com>
Signed-off-by: Viicos <65306057+Viicos@users.noreply.github.com>
This commit is contained in:
Victorien Plot 2023-01-03 12:21:07 +01:00 committed by Viicos
parent 6db8e694db
commit b1603f34b2
3 changed files with 82 additions and 32 deletions

View File

@ -1,10 +1,15 @@
import json import json
import struct import struct
import urllib import urllib
import ssl
from functools import partial from functools import partial
from typing import Any, AnyStr, Optional, Union, Dict, overload, NoReturn, Iterator
from typing_extensions import Literal
import requests import requests
import requests.exceptions import requests.exceptions
import requests.adapters
import websocket import websocket
from .. import auth from .. import auth
@ -20,6 +25,7 @@ from ..utils import check_resource, config, update_headers, utils
from ..utils.json_stream import json_stream from ..utils.json_stream import json_stream
from ..utils.proxy import ProxyConfig from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from ..utils.typing import BytesOrDict
from .build import BuildApiMixin from .build import BuildApiMixin
from .config import ConfigApiMixin from .config import ConfigApiMixin
from .container import ContainerApiMixin from .container import ContainerApiMixin
@ -102,11 +108,11 @@ class APIClient(
'base_url', 'base_url',
'timeout'] 'timeout']
def __init__(self, base_url=None, version=None, def __init__(self, base_url: Optional[str] = None, version: Optional[str] = None,
timeout=DEFAULT_TIMEOUT_SECONDS, tls=False, timeout: int = DEFAULT_TIMEOUT_SECONDS, tls: Optional[Union[bool, TLSConfig]] = False,
user_agent=DEFAULT_USER_AGENT, num_pools=None, user_agent: str = DEFAULT_USER_AGENT, num_pools: Optional[int] = None,
credstore_env=None, use_ssh_client=False, credstore_env: Optional[Dict[str, Any]] = None, use_ssh_client: bool = False,
max_pool_size=DEFAULT_MAX_POOL_SIZE): max_pool_size: int = DEFAULT_MAX_POOL_SIZE) -> None:
super().__init__() super().__init__()
if tls and not base_url: if tls and not base_url:
@ -208,7 +214,7 @@ class APIClient(
f'no longer supported by this library.' f'no longer supported by this library.'
) )
def _retrieve_server_version(self): def _retrieve_server_version(self) -> str:
try: try:
return self.version(api_version=False)["ApiVersion"] return self.version(api_version=False)["ApiVersion"]
except KeyError as ke: except KeyError as ke:
@ -221,29 +227,29 @@ class APIClient(
f'Error while fetching server API version: {e}' f'Error while fetching server API version: {e}'
) from e ) from e
def _set_request_timeout(self, kwargs): def _set_request_timeout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout """Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present.""" parameter, if not already present."""
kwargs.setdefault('timeout', self.timeout) kwargs.setdefault('timeout', self.timeout)
return kwargs return kwargs
@update_headers @update_headers
def _post(self, url, **kwargs): def _post(self, url: str, **kwargs: Any) -> requests.Response:
return self.post(url, **self._set_request_timeout(kwargs)) return self.post(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _get(self, url, **kwargs): def _get(self, url: str, **kwargs: Any) -> requests.Response:
return self.get(url, **self._set_request_timeout(kwargs)) return self.get(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _put(self, url, **kwargs): def _put(self, url: str, **kwargs: Any) -> requests.Response:
return self.put(url, **self._set_request_timeout(kwargs)) return self.put(url, **self._set_request_timeout(kwargs))
@update_headers @update_headers
def _delete(self, url, **kwargs): def _delete(self, url: str, **kwargs: Any) -> requests.Response:
return self.delete(url, **self._set_request_timeout(kwargs)) return self.delete(url, **self._set_request_timeout(kwargs))
def _url(self, pathfmt, *args, **kwargs): def _url(self, pathfmt: str, *args: str, **kwargs: Any) -> str:
for arg in args: for arg in args:
if not isinstance(arg, str): if not isinstance(arg, str):
raise ValueError( raise ValueError(
@ -259,14 +265,30 @@ class APIClient(
else: else:
return f'{self.base_url}{formatted_path}' return f'{self.base_url}{formatted_path}'
def _raise_for_status(self, response): def _raise_for_status(self, response: requests.Response) -> None:
"""Raises stored :class:`APIError`, if one occurred.""" """Raises stored :class:`APIError`, if one occurred."""
try: try:
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
raise create_api_error_from_http_exception(e) from e raise create_api_error_from_http_exception(e) from e
def _result(self, response, json=False, binary=False): @overload
def _result(self, response: requests.Response, json: Literal[True], binary: Literal[True]) -> NoReturn:
...
@overload
def _result(self, response: requests.Response, json: Literal[False] = ..., binary: Literal[False] = ...) -> str:
...
@overload
def _result(self, response: requests.Response, json: Literal[True], binary: bool = ...) -> Any:
...
@overload
def _result(self, response: requests.Response, json: bool = ..., binary: Literal[True] = ...) -> bytes:
...
def _result(self, response: requests.Response, json: bool = False, binary: bool = False) -> Any:
assert not (json and binary) assert not (json and binary)
self._raise_for_status(response) self._raise_for_status(response)
@ -276,7 +298,7 @@ class APIClient(
return response.content return response.content
return response.text return response.text
def _post_json(self, url, data, **kwargs): def _post_json(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
# Go <1.1 can't unserialize null to a string # Go <1.1 can't unserialize null to a string
# so we do this disgusting thing here. # so we do this disgusting thing here.
data2 = {} data2 = {}
@ -292,7 +314,7 @@ class APIClient(
kwargs['headers']['Content-Type'] = 'application/json' kwargs['headers']['Content-Type'] = 'application/json'
return self._post(url, data=json.dumps(data2), **kwargs) return self._post(url, data=json.dumps(data2), **kwargs)
def _attach_params(self, override=None): def _attach_params(self, override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
return override or { return override or {
'stdout': 1, 'stdout': 1,
'stderr': 1, 'stderr': 1,
@ -300,7 +322,7 @@ class APIClient(
} }
@check_resource('container') @check_resource('container')
def _attach_websocket(self, container, params=None): def _attach_websocket(self, container: str, params: Optional[Dict[str, Any]] = None) -> websocket.WebSocket:
url = self._url("/containers/{0}/attach/ws", container) url = self._url("/containers/{0}/attach/ws", container)
req = requests.Request("POST", url, params=self._attach_params(params)) req = requests.Request("POST", url, params=self._attach_params(params))
full_url = req.prepare().url full_url = req.prepare().url
@ -308,10 +330,10 @@ class APIClient(
full_url = full_url.replace("https://", "wss://", 1) full_url = full_url.replace("https://", "wss://", 1)
return self._create_websocket_connection(full_url) return self._create_websocket_connection(full_url)
def _create_websocket_connection(self, url): def _create_websocket_connection(self, url: str) -> websocket.WebSocket:
return websocket.create_connection(url) return websocket.create_connection(url)
def _get_raw_response_socket(self, response): def _get_raw_response_socket(self, response: requests.Response) -> ssl.SSLSocket:
self._raise_for_status(response) self._raise_for_status(response)
if self.base_url == "http+docker://localnpipe": if self.base_url == "http+docker://localnpipe":
sock = response.raw._fp.fp.raw.sock sock = response.raw._fp.fp.raw.sock
@ -333,7 +355,15 @@ class APIClient(
return sock return sock
def _stream_helper(self, response, decode=False): @overload
def _stream_helper(self, response: requests.Response, decode: Literal[True]) -> Iterator[Dict[str, Any]]:
...
@overload
def _stream_helper(self, response: requests.Response, decode: Literal[False] = ...) -> Iterator[bytes]:
...
def _stream_helper(self, response: requests.Response, decode: bool = False) -> Iterator[BytesOrDict]:
"""Generator for data coming from a chunked-encoded HTTP response.""" """Generator for data coming from a chunked-encoded HTTP response."""
if response.raw._fp.chunked: if response.raw._fp.chunked:
@ -354,7 +384,7 @@ class APIClient(
# encountered an error immediately # encountered an error immediately
yield self._result(response, json=decode) yield self._result(response, json=decode)
def _multiplexed_buffer_helper(self, response): def _multiplexed_buffer_helper(self, response: requests.Response) -> Iterator[bytes]:
"""A generator of multiplexed data blocks read from a buffered """A generator of multiplexed data blocks read from a buffered
response.""" response."""
buf = self._result(response, binary=True) buf = self._result(response, binary=True)
@ -370,7 +400,7 @@ class APIClient(
walker = end walker = end
yield buf[start:end] yield buf[start:end]
def _multiplexed_response_stream_helper(self, response): def _multiplexed_response_stream_helper(self, response: requests.Response) -> Iterator[bytes]:
"""A generator of multiplexed data blocks coming from a response """A generator of multiplexed data blocks coming from a response
stream.""" stream."""
@ -391,7 +421,15 @@ class APIClient(
break break
yield data yield data
def _stream_raw_result(self, response, chunk_size=1, decode=True): @overload
def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[False] = ...) -> Iterator[bytes]:
...
@overload
def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[True] = ...) -> Iterator[str]:
...
def _stream_raw_result(self, response: requests.Response, chunk_size: int = 1, decode: bool = True) -> Iterator[AnyStr]:
''' Stream result for TTY-enabled container and raw binary data''' ''' Stream result for TTY-enabled container and raw binary data'''
self._raise_for_status(response) self._raise_for_status(response)
@ -402,7 +440,7 @@ class APIClient(
yield from response.iter_content(chunk_size, decode) yield from response.iter_content(chunk_size, decode)
def _read_from_socket(self, response, stream, tty=True, demux=False): def _read_from_socket(self, response: requests.Response, stream: bool, tty: bool = True, demux: bool = False) -> Any:
"""Consume all data from the socket, close the response and return the """Consume all data from the socket, close the response and return the
data. If stream=True, then a generator is returned instead and the data. If stream=True, then a generator is returned instead and the
caller is responsible for closing the response. caller is responsible for closing the response.
@ -427,7 +465,7 @@ class APIClient(
finally: finally:
response.close() response.close()
def _disable_socket_timeout(self, socket): def _disable_socket_timeout(self, socket: ssl.SSLSocket) -> None:
""" Depending on the combination of python version and whether we're """ Depending on the combination of python version and whether we're
connecting over http or https, we might need to access _sock, which connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket may or may not exist; or we may need to just settimeout on socket
@ -456,14 +494,22 @@ class APIClient(
s.settimeout(None) s.settimeout(None)
@check_resource('container') @check_resource('container')
def _check_is_tty(self, container): def _check_is_tty(self, container: str) -> bool:
cont = self.inspect_container(container) cont = self.inspect_container(container)
return cont['Config']['Tty'] return cont['Config']['Tty']
def _get_result(self, container, stream, res): def _get_result(self, container: str, stream: bool, res: requests.Response):
return self._get_result_tty(stream, res, self._check_is_tty(container)) return self._get_result_tty(stream, res, self._check_is_tty(container))
def _get_result_tty(self, stream, res, is_tty): @overload
def _get_result_tty(self, stream: Literal[True], res: requests.Response, is_tty: bool) -> Iterator[bytes]:
...
@overload
def _get_result_tty(self, stream: Literal[False], res: requests.Response, is_tty: bool) -> bytes:
...
def _get_result_tty(self, stream: bool, res: requests.Response, is_tty: bool):
# We should also use raw streaming (without keep-alives) # We should also use raw streaming (without keep-alives)
# if we're dealing with a tty-enabled container. # if we're dealing with a tty-enabled container.
if is_tty: if is_tty:
@ -479,11 +525,11 @@ class APIClient(
list(self._multiplexed_buffer_helper(res)) list(self._multiplexed_buffer_helper(res))
) )
def _unmount(self, *args): def _unmount(self, *args: str) -> None:
for proto in args: for proto in args:
self.adapters.pop(proto) self.adapters.pop(proto)
def get_adapter(self, url): def get_adapter(self, url: str) -> requests.adapters.BaseAdapter:
try: try:
return super().get_adapter(url) return super().get_adapter(url)
except requests.exceptions.InvalidSchema as e: except requests.exceptions.InvalidSchema as e:
@ -493,10 +539,10 @@ class APIClient(
raise e raise e
@property @property
def api_version(self): def api_version(self) -> str:
return self._version return self._version
def reload_config(self, dockercfg_path=None): def reload_config(self, dockercfg_path: Optional[str] = None) -> None:
""" """
Force a reload of the auth configuration Force a reload of the auth configuration

3
docker/utils/typing.py Normal file
View File

@ -0,0 +1,3 @@
from typing import Any, Dict, TypeVar
BytesOrDict = TypeVar("BytesOrDict", bytes, Dict[str, Any])

View File

@ -4,3 +4,4 @@ pywin32==304; sys_platform == 'win32'
requests==2.31.0 requests==2.31.0
urllib3==1.26.11 urllib3==1.26.11
websocket-client==1.3.3 websocket-client==1.3.3
typing_extensions>=3.10.0.0