diff --git a/docker/client.py b/docker/client.py index 00ba0de9..55c835cf 100644 --- a/docker/client.py +++ b/docker/client.py @@ -26,7 +26,7 @@ import six from .auth import auth from .unixconn import unixconn from .ssladapter import ssladapter -from .utils import utils +from .utils import utils, check_resource from . import errors from .tls import TLSConfig @@ -154,6 +154,7 @@ class Client(requests.Session): 'stream': 1 } + @check_resource def _attach_websocket(self, container, params=None): if six.PY3: raise NotImplementedError("This method is not currently supported " @@ -249,6 +250,7 @@ class Client(requests.Session): def api_version(self): return self._version + @check_resource def attach(self, container, stdout=True, stderr=True, stream=False, logs=False): if isinstance(container, dict): @@ -285,6 +287,7 @@ class Client(requests.Session): [x for x in self._multiplexed_buffer_helper(response)] ) + @check_resource def attach_socket(self, container, params=None, ws=False): if params is None: params = { @@ -398,6 +401,7 @@ class Client(requests.Session): return None, output return match.group(1), output + @check_resource def commit(self, container, repository=None, tag=None, message=None, author=None, conf=None): params = { @@ -434,6 +438,7 @@ class Client(requests.Session): x['Id'] = x['Id'][:12] return res + @check_resource def copy(self, container, resource): if isinstance(container, dict): container = container.get('Id') @@ -478,6 +483,7 @@ class Client(requests.Session): res = self._post_json(u, data=config, params=params) return self._result(res, True) + @check_resource def diff(self, container): if isinstance(container, dict): container = container.get('Id') @@ -504,6 +510,7 @@ class Client(requests.Session): params=params, stream=True), decode=decode) + @check_resource def execute(self, container, cmd, detach=False, stdout=True, stderr=True, stream=False, tty=False): if utils.compare_version('1.15', self._version) < 0: @@ -546,6 +553,7 @@ class Client(requests.Session): [x for x in self._multiplexed_buffer_helper(res)] ) + @check_resource def export(self, container): if isinstance(container, dict): container = container.get('Id') @@ -554,12 +562,14 @@ class Client(requests.Session): self._raise_for_status(res) return res.raw + @check_resource def get_image(self, image): res = self._get(self._url("/images/{0}/get".format(image)), stream=True) self._raise_for_status(res) return res.raw + @check_resource def history(self, image): res = self._get(self._url("/images/{0}/history".format(image))) return self._result(res, True) @@ -669,18 +679,20 @@ class Client(requests.Session): return self._result(self._get(self._url("/info")), True) + @check_resource def insert(self, image, url, path): if utils.compare_version('1.12', self._version) >= 0: raise errors.DeprecatedMethod( 'insert is not available for API version >=1.12' ) - api_url = self._url("/images/" + image + "/insert") + api_url = self._url("/images/{0}/insert".fornat(image)) params = { 'url': url, 'path': path } return self._result(self._post(api_url, params=params)) + @check_resource def inspect_container(self, container): if isinstance(container, dict): container = container.get('Id') @@ -688,12 +700,14 @@ class Client(requests.Session): self._get(self._url("/containers/{0}/json".format(container))), True) - def inspect_image(self, image_id): + @check_resource + def inspect_image(self, image): return self._result( - self._get(self._url("/images/{0}/json".format(image_id))), + self._get(self._url("/images/{0}/json".format(image))), True ) + @check_resource def kill(self, container, signal=None): if isinstance(container, dict): container = container.get('Id') @@ -741,6 +755,7 @@ class Client(requests.Session): self._auth_configs[registry] = req_data return self._result(response, json=True) + @check_resource def logs(self, container, stdout=True, stderr=True, stream=False, timestamps=False, tail='all'): if isinstance(container, dict): @@ -775,6 +790,7 @@ class Client(requests.Session): logs=True ) + @check_resource def pause(self, container): if isinstance(container, dict): container = container.get('Id') @@ -785,6 +801,7 @@ class Client(requests.Session): def ping(self): return self._result(self._get(self._url('/_ping'))) + @check_resource def port(self, container, private_port): if isinstance(container, dict): container = container.get('Id') @@ -882,6 +899,7 @@ class Client(requests.Session): return stream and self._stream_helper(response) \ or self._result(response) + @check_resource def remove_container(self, container, v=False, link=False, force=False): if isinstance(container, dict): container = container.get('Id') @@ -890,6 +908,7 @@ class Client(requests.Session): params=params) self._raise_for_status(res) + @check_resource def remove_image(self, image, force=False, noprune=False): if isinstance(image, dict): image = image.get('Id') @@ -897,6 +916,7 @@ class Client(requests.Session): res = self._delete(self._url("/images/" + image), params=params) self._raise_for_status(res) + @check_resource def rename(self, container, name): if utils.compare_version('1.17', self._version) < 0: raise errors.InvalidVersion( @@ -909,6 +929,7 @@ class Client(requests.Session): res = self._post(url, params=params) self._raise_for_status(res) + @check_resource def resize(self, container, height, width): if isinstance(container, dict): container = container.get('Id') @@ -918,6 +939,7 @@ class Client(requests.Session): res = self._post(url, params=params) self._raise_for_status(res) + @check_resource def restart(self, container, timeout=10): if isinstance(container, dict): container = container.get('Id') @@ -931,6 +953,7 @@ class Client(requests.Session): params={'term': term}), True) + @check_resource def start(self, container, binds=None, port_bindings=None, lxc_conf=None, publish_all_ports=False, links=None, privileged=False, dns=None, dns_search=None, volumes_from=None, network_mode=None, @@ -993,6 +1016,7 @@ class Client(requests.Session): res = self._post_json(url, data=start_config) self._raise_for_status(res) + @check_resource def stats(self, container, decode=None): if utils.compare_version('1.17', self._version) < 0: raise errors.InvalidVersion( @@ -1003,6 +1027,7 @@ class Client(requests.Session): url = self._url("/containers/{0}/stats".format(container)) return self._stream_helper(self._get(url, stream=True), decode=decode) + @check_resource def stop(self, container, timeout=10): if isinstance(container, dict): container = container.get('Id') @@ -1013,6 +1038,7 @@ class Client(requests.Session): timeout=(timeout + self.timeout)) self._raise_for_status(res) + @check_resource def tag(self, image, repository, tag=None, force=False): params = { 'tag': tag, @@ -1024,7 +1050,10 @@ class Client(requests.Session): self._raise_for_status(res) return res.status_code == 201 + @check_resource def top(self, container): + if isinstance(container, dict): + container = container.get('Id') u = self._url("/containers/{0}/top".format(container)) return self._result(self._get(u), True) @@ -1032,6 +1061,7 @@ class Client(requests.Session): url = self._url("/version", versioned_api=api_version) return self._result(self._get(url), json=True) + @check_resource def unpause(self, container): if isinstance(container, dict): container = container.get('Id') @@ -1039,6 +1069,7 @@ class Client(requests.Session): res = self._post(url) self._raise_for_status(res) + @check_resource def wait(self, container, timeout=None): if isinstance(container, dict): container = container.get('Id') diff --git a/docker/utils/__init__.py b/docker/utils/__init__.py index 3594b9cd..81cc8a68 100644 --- a/docker/utils/__init__.py +++ b/docker/utils/__init__.py @@ -5,4 +5,5 @@ from .utils import ( create_container_config, parse_bytes, ping_registry ) # flake8: noqa -from .types import Ulimit, LogConfig # flake8: noqa \ No newline at end of file +from .types import Ulimit, LogConfig # flake8: noqa +from .decorators import check_resource #flake8: noqa diff --git a/docker/utils/decorators.py b/docker/utils/decorators.py new file mode 100644 index 00000000..897db327 --- /dev/null +++ b/docker/utils/decorators.py @@ -0,0 +1,10 @@ + + +def check_resource(f): + def wrapped(self, resource_id=None, *args, **kwargs): + if resource_id is None and ( + kwargs.get('container') is None and kwargs.get('image') is None + ): + raise ValueError('image or container param is None') + return f(self, resource_id, *args, **kwargs) + return wrapped diff --git a/tests/test.py b/tests/test.py index f2af58be..76c56381 100644 --- a/tests/test.py +++ b/tests/test.py @@ -669,6 +669,21 @@ class DockerClientTest(Cleanup, base.BaseTestCase): args[1]['timeout'], docker.client.DEFAULT_TIMEOUT_SECONDS ) + def test_start_container_none(self): + try: + self.client.start(container=None) + except ValueError as e: + self.assertEqual(str(e), 'image or container param is None') + else: + self.fail('Command should raise ValueError') + + try: + self.client.start(None) + except ValueError as e: + self.assertEqual(str(e), 'image or container param is None') + else: + self.fail('Command should raise ValueError') + def test_create_container_with_lxc_conf(self): try: self.client.create_container(