Add device requests (#2471)

* Add DeviceRequest type

Signed-off-by: Erwan Rouchet <rouchet@teklia.com>

* Add device_requests kwarg in host config

Signed-off-by: Erwan Rouchet <rouchet@teklia.com>

* Add unit test for device requests

Signed-off-by: Erwan Rouchet <rouchet@teklia.com>

* Fix unit test

Signed-off-by: Erwan Rouchet <rouchet@teklia.com>

* Use parentheses for multiline import

Signed-off-by: Erwan Rouchet <rouchet@teklia.com>

* Create 1.40 client for device-requests test

Signed-off-by: Laurie O <laurie_opperman@hotmail.com>

Co-authored-by: Laurie O <laurie_opperman@hotmail.com>
Co-authored-by: Bastien Abadie <abadie@teklia.com>
This commit is contained in:
Lucidiot 2020-08-07 13:58:35 +02:00 committed by GitHub
parent 26d8045ffa
commit dd0450a14c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 185 additions and 3 deletions

View File

@ -480,6 +480,9 @@ class ContainerApiMixin(object):
For example, ``/dev/sda:/dev/xvda:rwm`` allows the container For example, ``/dev/sda:/dev/xvda:rwm`` allows the container
to have read-write access to the host's ``/dev/sda`` via a to have read-write access to the host's ``/dev/sda`` via a
node named ``/dev/xvda`` inside the container. node named ``/dev/xvda`` inside the container.
device_requests (:py:class:`list`): Expose host resources such as
GPUs to the container, as a list of
:py:class:`docker.types.DeviceRequest` instances.
dns (:py:class:`list`): Set custom DNS servers. dns (:py:class:`list`): Set custom DNS servers.
dns_opt (:py:class:`list`): Additional options to be added to the dns_opt (:py:class:`list`): Additional options to be added to the
container's ``resolv.conf`` file container's ``resolv.conf`` file

View File

@ -579,6 +579,9 @@ class ContainerCollection(Collection):
For example, ``/dev/sda:/dev/xvda:rwm`` allows the container For example, ``/dev/sda:/dev/xvda:rwm`` allows the container
to have read-write access to the host's ``/dev/sda`` via a to have read-write access to the host's ``/dev/sda`` via a
node named ``/dev/xvda`` inside the container. node named ``/dev/xvda`` inside the container.
device_requests (:py:class:`list`): Expose host resources such as
GPUs to the container, as a list of
:py:class:`docker.types.DeviceRequest` instances.
dns (:py:class:`list`): Set custom DNS servers. dns (:py:class:`list`): Set custom DNS servers.
dns_opt (:py:class:`list`): Additional options to be added to the dns_opt (:py:class:`list`): Additional options to be added to the
container's ``resolv.conf`` file. container's ``resolv.conf`` file.
@ -998,6 +1001,7 @@ RUN_HOST_CONFIG_KWARGS = [
'device_write_bps', 'device_write_bps',
'device_write_iops', 'device_write_iops',
'devices', 'devices',
'device_requests',
'dns_opt', 'dns_opt',
'dns_search', 'dns_search',
'dns', 'dns',

View File

@ -1,5 +1,7 @@
# flake8: noqa # flake8: noqa
from .containers import ContainerConfig, HostConfig, LogConfig, Ulimit from .containers import (
ContainerConfig, HostConfig, LogConfig, Ulimit, DeviceRequest
)
from .daemon import CancellableStream from .daemon import CancellableStream
from .healthcheck import Healthcheck from .healthcheck import Healthcheck
from .networks import EndpointConfig, IPAMConfig, IPAMPool, NetworkingConfig from .networks import EndpointConfig, IPAMConfig, IPAMPool, NetworkingConfig

View File

@ -154,6 +154,104 @@ class Ulimit(DictType):
self['Hard'] = value self['Hard'] = value
class DeviceRequest(DictType):
"""
Create a device request to be used with
:py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`.
Args:
driver (str): Which driver to use for this device. Optional.
count (int): Number or devices to request. Optional.
Set to -1 to request all available devices.
device_ids (list): List of strings for device IDs. Optional.
Set either ``count`` or ``device_ids``.
capabilities (list): List of lists of strings to request
capabilities. Optional. The global list acts like an OR,
and the sub-lists are AND. The driver will try to satisfy
one of the sub-lists.
Available capabilities for the ``nvidia`` driver can be found
`here <https://github.com/NVIDIA/nvidia-container-runtime>`_.
options (dict): Driver-specific options. Optional.
"""
def __init__(self, **kwargs):
driver = kwargs.get('driver', kwargs.get('Driver'))
count = kwargs.get('count', kwargs.get('Count'))
device_ids = kwargs.get('device_ids', kwargs.get('DeviceIDs'))
capabilities = kwargs.get('capabilities', kwargs.get('Capabilities'))
options = kwargs.get('options', kwargs.get('Options'))
if driver is None:
driver = ''
elif not isinstance(driver, six.string_types):
raise ValueError('DeviceRequest.driver must be a string')
if count is None:
count = 0
elif not isinstance(count, int):
raise ValueError('DeviceRequest.count must be an integer')
if device_ids is None:
device_ids = []
elif not isinstance(device_ids, list):
raise ValueError('DeviceRequest.device_ids must be a list')
if capabilities is None:
capabilities = []
elif not isinstance(capabilities, list):
raise ValueError('DeviceRequest.capabilities must be a list')
if options is None:
options = {}
elif not isinstance(options, dict):
raise ValueError('DeviceRequest.options must be a dict')
super(DeviceRequest, self).__init__({
'Driver': driver,
'Count': count,
'DeviceIDs': device_ids,
'Capabilities': capabilities,
'Options': options
})
@property
def driver(self):
return self['Driver']
@driver.setter
def driver(self, value):
self['Driver'] = value
@property
def count(self):
return self['Count']
@count.setter
def count(self, value):
self['Count'] = value
@property
def device_ids(self):
return self['DeviceIDs']
@device_ids.setter
def device_ids(self, value):
self['DeviceIDs'] = value
@property
def capabilities(self):
return self['Capabilities']
@capabilities.setter
def capabilities(self, value):
self['Capabilities'] = value
@property
def options(self):
return self['Options']
@options.setter
def options(self, value):
self['Options'] = value
class HostConfig(dict): class HostConfig(dict):
def __init__(self, version, binds=None, port_bindings=None, def __init__(self, version, binds=None, port_bindings=None,
lxc_conf=None, publish_all_ports=False, links=None, lxc_conf=None, publish_all_ports=False, links=None,
@ -176,7 +274,7 @@ class HostConfig(dict):
volume_driver=None, cpu_count=None, cpu_percent=None, volume_driver=None, cpu_count=None, cpu_percent=None,
nano_cpus=None, cpuset_mems=None, runtime=None, mounts=None, nano_cpus=None, cpuset_mems=None, runtime=None, mounts=None,
cpu_rt_period=None, cpu_rt_runtime=None, cpu_rt_period=None, cpu_rt_runtime=None,
device_cgroup_rules=None): device_cgroup_rules=None, device_requests=None):
if mem_limit is not None: if mem_limit is not None:
self['Memory'] = parse_bytes(mem_limit) self['Memory'] = parse_bytes(mem_limit)
@ -536,6 +634,19 @@ class HostConfig(dict):
) )
self['DeviceCgroupRules'] = device_cgroup_rules self['DeviceCgroupRules'] = device_cgroup_rules
if device_requests is not None:
if version_lt(version, '1.40'):
raise host_config_version_error('device_requests', '1.40')
if not isinstance(device_requests, list):
raise host_config_type_error(
'device_requests', device_requests, 'list'
)
self['DeviceRequests'] = []
for req in device_requests:
if not isinstance(req, DeviceRequest):
req = DeviceRequest(**req)
self['DeviceRequests'].append(req)
def host_config_type_error(param, param_value, expected): def host_config_type_error(param, param_value, expected):
error_msg = 'Invalid type for {0} param: expected {1} but found {2}' error_msg = 'Invalid type for {0} param: expected {1} but found {2}'

View File

@ -5,6 +5,7 @@ import json
import signal import signal
import docker import docker
from docker.api import APIClient
import pytest import pytest
import six import six
@ -12,7 +13,7 @@ from . import fake_api
from ..helpers import requires_api_version from ..helpers import requires_api_version
from .api_test import ( from .api_test import (
BaseAPIClientTest, url_prefix, fake_request, DEFAULT_TIMEOUT_SECONDS, BaseAPIClientTest, url_prefix, fake_request, DEFAULT_TIMEOUT_SECONDS,
fake_inspect_container fake_inspect_container, url_base
) )
try: try:
@ -767,6 +768,67 @@ class CreateContainerTest(BaseAPIClientTest):
assert args[1]['headers'] == {'Content-Type': 'application/json'} assert args[1]['headers'] == {'Content-Type': 'application/json'}
assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS
def test_create_container_with_device_requests(self):
client = APIClient(version='1.40')
fake_api.fake_responses.setdefault(
'{0}/v1.40/containers/create'.format(fake_api.prefix),
fake_api.post_fake_create_container,
)
client.create_container(
'busybox', 'true', host_config=client.create_host_config(
device_requests=[
{
'device_ids': [
'0',
'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a'
]
},
{
'driver': 'nvidia',
'Count': -1,
'capabilities': [
['gpu', 'utility']
],
'options': {
'key': 'value'
}
}
]
)
)
args = fake_request.call_args
assert args[0][1] == url_base + 'v1.40/' + 'containers/create'
expected_payload = self.base_create_payload()
expected_payload['HostConfig'] = client.create_host_config()
expected_payload['HostConfig']['DeviceRequests'] = [
{
'Driver': '',
'Count': 0,
'DeviceIDs': [
'0',
'GPU-3a23c669-1f69-c64e-cf85-44e9b07e7a2a'
],
'Capabilities': [],
'Options': {}
},
{
'Driver': 'nvidia',
'Count': -1,
'DeviceIDs': [],
'Capabilities': [
['gpu', 'utility']
],
'Options': {
'key': 'value'
}
}
]
assert json.loads(args[1]['data']) == expected_payload
assert args[1]['headers']['Content-Type'] == 'application/json'
assert set(args[1]['headers']) <= {'Content-Type', 'User-Agent'}
assert args[1]['timeout'] == DEFAULT_TIMEOUT_SECONDS
def test_create_container_with_labels_dict(self): def test_create_container_with_labels_dict(self):
labels_dict = { labels_dict = {
six.text_type('foo'): six.text_type('1'), six.text_type('foo'): six.text_type('1'),