From dd0450a14c407050db141af486cc2ed9639ffd8d Mon Sep 17 00:00:00 2001 From: Lucidiot Date: Fri, 7 Aug 2020 13:58:35 +0200 Subject: [PATCH] Add device requests (#2471) * Add DeviceRequest type Signed-off-by: Erwan Rouchet * Add device_requests kwarg in host config Signed-off-by: Erwan Rouchet * Add unit test for device requests Signed-off-by: Erwan Rouchet * Fix unit test Signed-off-by: Erwan Rouchet * Use parentheses for multiline import Signed-off-by: Erwan Rouchet * Create 1.40 client for device-requests test Signed-off-by: Laurie O Co-authored-by: Laurie O Co-authored-by: Bastien Abadie --- docker/api/container.py | 3 + docker/models/containers.py | 4 ++ docker/types/__init__.py | 4 +- docker/types/containers.py | 113 ++++++++++++++++++++++++++++++- tests/unit/api_container_test.py | 64 ++++++++++++++++- 5 files changed, 185 insertions(+), 3 deletions(-) diff --git a/docker/api/container.py b/docker/api/container.py index 9df22a52..2ba08e53 100644 --- a/docker/api/container.py +++ b/docker/api/container.py @@ -480,6 +480,9 @@ class ContainerApiMixin(object): For example, ``/dev/sda:/dev/xvda:rwm`` allows the container to have read-write access to the host's ``/dev/sda`` via a node named ``/dev/xvda`` inside the container. + device_requests (:py:class:`list`): Expose host resources such as + GPUs to the container, as a list of + :py:class:`docker.types.DeviceRequest` instances. dns (:py:class:`list`): Set custom DNS servers. dns_opt (:py:class:`list`): Additional options to be added to the container's ``resolv.conf`` file diff --git a/docker/models/containers.py b/docker/models/containers.py index d1f275f7..e8082ba4 100644 --- a/docker/models/containers.py +++ b/docker/models/containers.py @@ -579,6 +579,9 @@ class ContainerCollection(Collection): For example, ``/dev/sda:/dev/xvda:rwm`` allows the container to have read-write access to the host's ``/dev/sda`` via a node named ``/dev/xvda`` inside the container. + device_requests (:py:class:`list`): Expose host resources such as + GPUs to the container, as a list of + :py:class:`docker.types.DeviceRequest` instances. dns (:py:class:`list`): Set custom DNS servers. dns_opt (:py:class:`list`): Additional options to be added to the container's ``resolv.conf`` file. @@ -998,6 +1001,7 @@ RUN_HOST_CONFIG_KWARGS = [ 'device_write_bps', 'device_write_iops', 'devices', + 'device_requests', 'dns_opt', 'dns_search', 'dns', diff --git a/docker/types/__init__.py b/docker/types/__init__.py index 5db330e2..b425746e 100644 --- a/docker/types/__init__.py +++ b/docker/types/__init__.py @@ -1,5 +1,7 @@ # flake8: noqa -from .containers import ContainerConfig, HostConfig, LogConfig, Ulimit +from .containers import ( + ContainerConfig, HostConfig, LogConfig, Ulimit, DeviceRequest +) from .daemon import CancellableStream from .healthcheck import Healthcheck from .networks import EndpointConfig, IPAMConfig, IPAMPool, NetworkingConfig diff --git a/docker/types/containers.py b/docker/types/containers.py index fd8cab49..149b85df 100644 --- a/docker/types/containers.py +++ b/docker/types/containers.py @@ -154,6 +154,104 @@ class Ulimit(DictType): self['Hard'] = value +class DeviceRequest(DictType): + """ + Create a device request to be used with + :py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`. + + Args: + + driver (str): Which driver to use for this device. Optional. + count (int): Number or devices to request. Optional. + Set to -1 to request all available devices. + device_ids (list): List of strings for device IDs. Optional. + Set either ``count`` or ``device_ids``. + capabilities (list): List of lists of strings to request + capabilities. Optional. The global list acts like an OR, + and the sub-lists are AND. The driver will try to satisfy + one of the sub-lists. + Available capabilities for the ``nvidia`` driver can be found + `here `_. + options (dict): Driver-specific options. Optional. + """ + + def __init__(self, **kwargs): + driver = kwargs.get('driver', kwargs.get('Driver')) + count = kwargs.get('count', kwargs.get('Count')) + device_ids = kwargs.get('device_ids', kwargs.get('DeviceIDs')) + capabilities = kwargs.get('capabilities', kwargs.get('Capabilities')) + options = kwargs.get('options', kwargs.get('Options')) + + if driver is None: + driver = '' + elif not isinstance(driver, six.string_types): + raise ValueError('DeviceRequest.driver must be a string') + if count is None: + count = 0 + elif not isinstance(count, int): + raise ValueError('DeviceRequest.count must be an integer') + if device_ids is None: + device_ids = [] + elif not isinstance(device_ids, list): + raise ValueError('DeviceRequest.device_ids must be a list') + if capabilities is None: + capabilities = [] + elif not isinstance(capabilities, list): + raise ValueError('DeviceRequest.capabilities must be a list') + if options is None: + options = {} + elif not isinstance(options, dict): + raise ValueError('DeviceRequest.options must be a dict') + + super(DeviceRequest, self).__init__({ + 'Driver': driver, + 'Count': count, + 'DeviceIDs': device_ids, + 'Capabilities': capabilities, + 'Options': options + }) + + @property + def driver(self): + return self['Driver'] + + @driver.setter + def driver(self, value): + self['Driver'] = value + + @property + def count(self): + return self['Count'] + + @count.setter + def count(self, value): + self['Count'] = value + + @property + def device_ids(self): + return self['DeviceIDs'] + + @device_ids.setter + def device_ids(self, value): + self['DeviceIDs'] = value + + @property + def capabilities(self): + return self['Capabilities'] + + @capabilities.setter + def capabilities(self, value): + self['Capabilities'] = value + + @property + def options(self): + return self['Options'] + + @options.setter + def options(self, value): + self['Options'] = value + + class HostConfig(dict): def __init__(self, version, binds=None, port_bindings=None, lxc_conf=None, publish_all_ports=False, links=None, @@ -176,7 +274,7 @@ class HostConfig(dict): volume_driver=None, cpu_count=None, cpu_percent=None, nano_cpus=None, cpuset_mems=None, runtime=None, mounts=None, cpu_rt_period=None, cpu_rt_runtime=None, - device_cgroup_rules=None): + device_cgroup_rules=None, device_requests=None): if mem_limit is not None: self['Memory'] = parse_bytes(mem_limit) @@ -536,6 +634,19 @@ class HostConfig(dict): ) self['DeviceCgroupRules'] = device_cgroup_rules + if device_requests is not None: + if version_lt(version, '1.40'): + raise host_config_version_error('device_requests', '1.40') + if not isinstance(device_requests, list): + raise host_config_type_error( + 'device_requests', device_requests, 'list' + ) + self['DeviceRequests'] = [] + for req in device_requests: + if not isinstance(req, DeviceRequest): + req = DeviceRequest(**req) + self['DeviceRequests'].append(req) + def host_config_type_error(param, param_value, expected): error_msg = 'Invalid type for {0} param: expected {1} but found {2}' diff --git a/tests/unit/api_container_test.py b/tests/unit/api_container_test.py index a7e183c8..8a0577e7 100644 --- a/tests/unit/api_container_test.py +++ b/tests/unit/api_container_test.py @@ -5,6 +5,7 @@ import json import signal import docker +from docker.api import APIClient import pytest import six @@ -12,7 +13,7 @@ from . import fake_api from ..helpers import requires_api_version from .api_test import ( BaseAPIClientTest, url_prefix, fake_request, DEFAULT_TIMEOUT_SECONDS, - fake_inspect_container + fake_inspect_container, url_base ) try: @@ -767,6 +768,67 @@ class CreateContainerTest(BaseAPIClientTest): assert args[1]['headers'] == {'Content-Type': 'application/json'} assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS + def test_create_container_with_device_requests(self): + client = APIClient(version='1.40') + fake_api.fake_responses.setdefault( + '{0}/v1.40/containers/create'.format(fake_api.prefix), + fake_api.post_fake_create_container, + ) + client.create_container( + 'busybox', 'true', host_config=client.create_host_config( + device_requests=[ + { + 'device_ids': [ + '0', + 'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a' + ] + }, + { + 'driver': 'nvidia', + 'Count': -1, + 'capabilities': [ + ['gpu', 'utility'] + ], + 'options': { + 'key': 'value' + } + } + ] + ) + ) + + args = fake_request.call_args + assert args[0][1] == url_base + 'v1.40/' + 'containers/create' + expected_payload = self.base_create_payload() + expected_payload['HostConfig'] = client.create_host_config() + expected_payload['HostConfig']['DeviceRequests'] = [ + { + 'Driver': '', + 'Count': 0, + 'DeviceIDs': [ + '0', + 'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a' + ], + 'Capabilities': [], + 'Options': {} + }, + { + 'Driver': 'nvidia', + 'Count': -1, + 'DeviceIDs': [], + 'Capabilities': [ + ['gpu', 'utility'] + ], + 'Options': { + 'key': 'value' + } + } + ] + assert json.loads(args[1]['data']) == expected_payload + assert args[1]['headers']['Content-Type'] == 'application/json' + assert set(args[1]['headers']) <= {'Content-Type', 'User-Agent'} + assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS + def test_create_container_with_labels_dict(self): labels_dict = { six.text_type('foo'): six.text_type('1'),