mirror of https://github.com/docker/docker-py.git
				
				
				
			Add device requests (#2471)
* Add DeviceRequest type Signed-off-by: Erwan Rouchet <rouchet@teklia.com> * Add device_requests kwarg in host config Signed-off-by: Erwan Rouchet <rouchet@teklia.com> * Add unit test for device requests Signed-off-by: Erwan Rouchet <rouchet@teklia.com> * Fix unit test Signed-off-by: Erwan Rouchet <rouchet@teklia.com> * Use parentheses for multiline import Signed-off-by: Erwan Rouchet <rouchet@teklia.com> * Create 1.40 client for device-requests test Signed-off-by: Laurie O <laurie_opperman@hotmail.com> Co-authored-by: Laurie O <laurie_opperman@hotmail.com> Co-authored-by: Bastien Abadie <abadie@teklia.com>
This commit is contained in:
		
							parent
							
								
									26d8045ffa
								
							
						
					
					
						commit
						dd0450a14c
					
				|  | @ -480,6 +480,9 @@ class ContainerApiMixin(object): | |||
|                 For example, ``/dev/sda:/dev/xvda:rwm`` allows the container | ||||
|                 to have read-write access to the host's ``/dev/sda`` via a | ||||
|                 node named ``/dev/xvda`` inside the container. | ||||
|             device_requests (:py:class:`list`): Expose host resources such as | ||||
|                 GPUs to the container, as a list of | ||||
|                 :py:class:`docker.types.DeviceRequest` instances. | ||||
|             dns (:py:class:`list`): Set custom DNS servers. | ||||
|             dns_opt (:py:class:`list`): Additional options to be added to the | ||||
|                 container's ``resolv.conf`` file | ||||
|  |  | |||
|  | @ -579,6 +579,9 @@ class ContainerCollection(Collection): | |||
|                 For example, ``/dev/sda:/dev/xvda:rwm`` allows the container | ||||
|                 to have read-write access to the host's ``/dev/sda`` via a | ||||
|                 node named ``/dev/xvda`` inside the container. | ||||
|             device_requests (:py:class:`list`): Expose host resources such as | ||||
|                 GPUs to the container, as a list of | ||||
|                 :py:class:`docker.types.DeviceRequest` instances. | ||||
|             dns (:py:class:`list`): Set custom DNS servers. | ||||
|             dns_opt (:py:class:`list`): Additional options to be added to the | ||||
|                 container's ``resolv.conf`` file. | ||||
|  | @ -998,6 +1001,7 @@ RUN_HOST_CONFIG_KWARGS = [ | |||
|     'device_write_bps', | ||||
|     'device_write_iops', | ||||
|     'devices', | ||||
|     'device_requests', | ||||
|     'dns_opt', | ||||
|     'dns_search', | ||||
|     'dns', | ||||
|  |  | |||
|  | @ -1,5 +1,7 @@ | |||
| # flake8: noqa | ||||
| from .containers import ContainerConfig, HostConfig, LogConfig, Ulimit | ||||
| from .containers import ( | ||||
|     ContainerConfig, HostConfig, LogConfig, Ulimit, DeviceRequest | ||||
| ) | ||||
| from .daemon import CancellableStream | ||||
| from .healthcheck import Healthcheck | ||||
| from .networks import EndpointConfig, IPAMConfig, IPAMPool, NetworkingConfig | ||||
|  |  | |||
|  | @ -154,6 +154,104 @@ class Ulimit(DictType): | |||
|         self['Hard'] = value | ||||
| 
 | ||||
| 
 | ||||
| class DeviceRequest(DictType): | ||||
|     """ | ||||
|     Create a device request to be used with | ||||
|     :py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`. | ||||
| 
 | ||||
|     Args: | ||||
| 
 | ||||
|         driver (str): Which driver to use for this device. Optional. | ||||
|         count (int): Number or devices to request. Optional. | ||||
|             Set to -1 to request all available devices. | ||||
|         device_ids (list): List of strings for device IDs. Optional. | ||||
|             Set either ``count`` or ``device_ids``. | ||||
|         capabilities (list): List of lists of strings to request | ||||
|             capabilities. Optional. The global list acts like an OR, | ||||
|             and the sub-lists are AND. The driver will try to satisfy | ||||
|             one of the sub-lists. | ||||
|             Available capabilities for the ``nvidia`` driver can be found | ||||
|             `here <https://github.com/NVIDIA/nvidia-container-runtime>`_. | ||||
|         options (dict): Driver-specific options. Optional. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, **kwargs): | ||||
|         driver = kwargs.get('driver', kwargs.get('Driver')) | ||||
|         count = kwargs.get('count', kwargs.get('Count')) | ||||
|         device_ids = kwargs.get('device_ids', kwargs.get('DeviceIDs')) | ||||
|         capabilities = kwargs.get('capabilities', kwargs.get('Capabilities')) | ||||
|         options = kwargs.get('options', kwargs.get('Options')) | ||||
| 
 | ||||
|         if driver is None: | ||||
|             driver = '' | ||||
|         elif not isinstance(driver, six.string_types): | ||||
|             raise ValueError('DeviceRequest.driver must be a string') | ||||
|         if count is None: | ||||
|             count = 0 | ||||
|         elif not isinstance(count, int): | ||||
|             raise ValueError('DeviceRequest.count must be an integer') | ||||
|         if device_ids is None: | ||||
|             device_ids = [] | ||||
|         elif not isinstance(device_ids, list): | ||||
|             raise ValueError('DeviceRequest.device_ids must be a list') | ||||
|         if capabilities is None: | ||||
|             capabilities = [] | ||||
|         elif not isinstance(capabilities, list): | ||||
|             raise ValueError('DeviceRequest.capabilities must be a list') | ||||
|         if options is None: | ||||
|             options = {} | ||||
|         elif not isinstance(options, dict): | ||||
|             raise ValueError('DeviceRequest.options must be a dict') | ||||
| 
 | ||||
|         super(DeviceRequest, self).__init__({ | ||||
|             'Driver': driver, | ||||
|             'Count': count, | ||||
|             'DeviceIDs': device_ids, | ||||
|             'Capabilities': capabilities, | ||||
|             'Options': options | ||||
|         }) | ||||
| 
 | ||||
|     @property | ||||
|     def driver(self): | ||||
|         return self['Driver'] | ||||
| 
 | ||||
|     @driver.setter | ||||
|     def driver(self, value): | ||||
|         self['Driver'] = value | ||||
| 
 | ||||
|     @property | ||||
|     def count(self): | ||||
|         return self['Count'] | ||||
| 
 | ||||
|     @count.setter | ||||
|     def count(self, value): | ||||
|         self['Count'] = value | ||||
| 
 | ||||
|     @property | ||||
|     def device_ids(self): | ||||
|         return self['DeviceIDs'] | ||||
| 
 | ||||
|     @device_ids.setter | ||||
|     def device_ids(self, value): | ||||
|         self['DeviceIDs'] = value | ||||
| 
 | ||||
|     @property | ||||
|     def capabilities(self): | ||||
|         return self['Capabilities'] | ||||
| 
 | ||||
|     @capabilities.setter | ||||
|     def capabilities(self, value): | ||||
|         self['Capabilities'] = value | ||||
| 
 | ||||
|     @property | ||||
|     def options(self): | ||||
|         return self['Options'] | ||||
| 
 | ||||
|     @options.setter | ||||
|     def options(self, value): | ||||
|         self['Options'] = value | ||||
| 
 | ||||
| 
 | ||||
| class HostConfig(dict): | ||||
|     def __init__(self, version, binds=None, port_bindings=None, | ||||
|                  lxc_conf=None, publish_all_ports=False, links=None, | ||||
|  | @ -176,7 +274,7 @@ class HostConfig(dict): | |||
|                  volume_driver=None, cpu_count=None, cpu_percent=None, | ||||
|                  nano_cpus=None, cpuset_mems=None, runtime=None, mounts=None, | ||||
|                  cpu_rt_period=None, cpu_rt_runtime=None, | ||||
|                  device_cgroup_rules=None): | ||||
|                  device_cgroup_rules=None, device_requests=None): | ||||
| 
 | ||||
|         if mem_limit is not None: | ||||
|             self['Memory'] = parse_bytes(mem_limit) | ||||
|  | @ -536,6 +634,19 @@ class HostConfig(dict): | |||
|                 ) | ||||
|             self['DeviceCgroupRules'] = device_cgroup_rules | ||||
| 
 | ||||
|         if device_requests is not None: | ||||
|             if version_lt(version, '1.40'): | ||||
|                 raise host_config_version_error('device_requests', '1.40') | ||||
|             if not isinstance(device_requests, list): | ||||
|                 raise host_config_type_error( | ||||
|                     'device_requests', device_requests, 'list' | ||||
|                 ) | ||||
|             self['DeviceRequests'] = [] | ||||
|             for req in device_requests: | ||||
|                 if not isinstance(req, DeviceRequest): | ||||
|                     req = DeviceRequest(**req) | ||||
|                 self['DeviceRequests'].append(req) | ||||
| 
 | ||||
| 
 | ||||
| def host_config_type_error(param, param_value, expected): | ||||
|     error_msg = 'Invalid type for {0} param: expected {1} but found {2}' | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ import json | |||
| import signal | ||||
| 
 | ||||
| import docker | ||||
| from docker.api import APIClient | ||||
| import pytest | ||||
| import six | ||||
| 
 | ||||
|  | @ -12,7 +13,7 @@ from . import fake_api | |||
| from ..helpers import requires_api_version | ||||
| from .api_test import ( | ||||
|     BaseAPIClientTest, url_prefix, fake_request, DEFAULT_TIMEOUT_SECONDS, | ||||
|     fake_inspect_container | ||||
|     fake_inspect_container, url_base | ||||
| ) | ||||
| 
 | ||||
| try: | ||||
|  | @ -767,6 +768,67 @@ class CreateContainerTest(BaseAPIClientTest): | |||
|         assert args[1]['headers'] == {'Content-Type': 'application/json'} | ||||
|         assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS | ||||
| 
 | ||||
|     def test_create_container_with_device_requests(self): | ||||
|         client = APIClient(version='1.40') | ||||
|         fake_api.fake_responses.setdefault( | ||||
|             '{0}/v1.40/containers/create'.format(fake_api.prefix), | ||||
|             fake_api.post_fake_create_container, | ||||
|         ) | ||||
|         client.create_container( | ||||
|             'busybox', 'true', host_config=client.create_host_config( | ||||
|                 device_requests=[ | ||||
|                     { | ||||
|                         'device_ids': [ | ||||
|                             '0', | ||||
|                             'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a' | ||||
|                         ] | ||||
|                     }, | ||||
|                     { | ||||
|                         'driver': 'nvidia', | ||||
|                         'Count': -1, | ||||
|                         'capabilities': [ | ||||
|                             ['gpu', 'utility'] | ||||
|                         ], | ||||
|                         'options': { | ||||
|                             'key': 'value' | ||||
|                         } | ||||
|                     } | ||||
|                 ] | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         args = fake_request.call_args | ||||
|         assert args[0][1] == url_base + 'v1.40/' + 'containers/create' | ||||
|         expected_payload = self.base_create_payload() | ||||
|         expected_payload['HostConfig'] = client.create_host_config() | ||||
|         expected_payload['HostConfig']['DeviceRequests'] = [ | ||||
|             { | ||||
|                 'Driver': '', | ||||
|                 'Count': 0, | ||||
|                 'DeviceIDs': [ | ||||
|                     '0', | ||||
|                     'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a' | ||||
|                 ], | ||||
|                 'Capabilities': [], | ||||
|                 'Options': {} | ||||
|             }, | ||||
|             { | ||||
|                 'Driver': 'nvidia', | ||||
|                 'Count': -1, | ||||
|                 'DeviceIDs': [], | ||||
|                 'Capabilities': [ | ||||
|                     ['gpu', 'utility'] | ||||
|                 ], | ||||
|                 'Options': { | ||||
|                     'key': 'value' | ||||
|                 } | ||||
|             } | ||||
|         ] | ||||
|         assert json.loads(args[1]['data']) == expected_payload | ||||
|         assert args[1]['headers']['Content-Type'] == 'application/json' | ||||
|         assert set(args[1]['headers']) <= {'Content-Type', 'User-Agent'} | ||||
|         assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS | ||||
| 
 | ||||
|     def test_create_container_with_labels_dict(self): | ||||
|         labels_dict = { | ||||
|             six.text_type('foo'): six.text_type('1'), | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue