diff --git a/compose/project.py b/compose/project.py index 7df79685a7..abc3132a27 100644 --- a/compose/project.py +++ b/compose/project.py @@ -9,7 +9,10 @@ from .config import get_service_name_from_net, ConfigurationError from .const import DEFAULT_TIMEOUT, LABEL_PROJECT, LABEL_SERVICE, LABEL_ONE_OFF from .container import Container from .legacy import check_for_legacy_containers +from .service import ContainerNet +from .service import Net from .service import Service +from .service import ServiceNet from .utils import parallel_execute log = logging.getLogger(__name__) @@ -180,18 +183,18 @@ class Project(object): def get_net(self, service_dict): net = service_dict.pop('net', None) if not net: - return + return Net(None) net_name = get_service_name_from_net(net) if not net_name: - return net + return Net(net) try: - return self.get_service(net_name) + return ServiceNet(self.get_service(net_name)) except NoSuchService: pass try: - return Container.from_id(self.client, net_name) + return ContainerNet(Container.from_id(self.client, net_name)) except APIError: raise ConfigurationError( 'Service "%s" is trying to use the network of "%s", ' diff --git a/compose/service.py b/compose/service.py index a567a54c94..25584fe2cd 100644 --- a/compose/service.py +++ b/compose/service.py @@ -107,7 +107,7 @@ class Service(object): self.project = project self.links = links or [] self.volumes_from = volumes_from or [] - self.net = net or None + self.net = net or Net(None) self.options = options def containers(self, stopped=False, one_off=False): @@ -478,12 +478,12 @@ class Service(object): 'options': self.options, 'image_id': self.image()['Id'], 'links': [(service.name, alias) for service, alias in self.links], - 'net': self.get_net_name() or getattr(self.net, 'id', self.net), + 'net': self.net.id, 'volumes_from': self.get_volumes_from_names(), } def get_dependency_names(self): - net_name = self.get_net_name() + net_name = self.net.service_name return (self.get_linked_names() + self.get_volumes_from_names() + ([net_name] if net_name else [])) @@ -494,12 +494,6 @@ class Service(object): def get_volumes_from_names(self): return [s.name for s in self.volumes_from if isinstance(s, Service)] - def get_net_name(self): - if isinstance(self.net, Service): - return self.net.name - else: - return - def get_container_name(self, number, one_off=False): # TODO: Implement issue #652 here return build_container_name(self.project, self.name, number, one_off) @@ -551,25 +545,6 @@ class Service(object): return volumes_from - def _get_net(self): - if not self.net: - return None - - if isinstance(self.net, Service): - containers = self.net.containers() - if len(containers) > 0: - net = 'container:' + containers[0].id - else: - log.warning("Warning: Service %s is trying to use reuse the network stack " - "of another service that is not running." % (self.net.name)) - net = None - elif isinstance(self.net, Container): - net = 'container:' + self.net.id - else: - net = self.net - - return net - def _get_container_create_options( self, override_options, @@ -690,7 +665,7 @@ class Service(object): binds=options.get('binds'), volumes_from=self._get_volumes_from(), privileged=privileged, - network_mode=self._get_net(), + network_mode=self.net.mode, devices=devices, dns=dns, dns_search=dns_search, @@ -785,6 +760,61 @@ class Service(object): stream_output(output, sys.stdout) +class Net(object): + """A `standard` network mode (ex: host, bridge)""" + + service_name = None + + def __init__(self, net): + self.net = net + + @property + def id(self): + return self.net + + mode = id + + +class ContainerNet(object): + """A network mode that uses a containers network stack.""" + + service_name = None + + def __init__(self, container): + self.container = container + + @property + def id(self): + return self.container.id + + @property + def mode(self): + return 'container:' + self.container.id + + +class ServiceNet(object): + """A network mode that uses a service's network stack.""" + + def __init__(self, service): + self.service = service + + @property + def id(self): + return self.service.name + + service_name = id + + @property + def mode(self): + containers = self.service.containers() + if containers: + return 'container:' + containers[0].id + + log.warn("Warning: Service %s is trying to use reuse the network stack " + "of another service that is not running." % (self.id)) + return None + + # Names diff --git a/tests/unit/project_test.py b/tests/unit/project_test.py index 93bf12ff57..a66aaf5d27 100644 --- a/tests/unit/project_test.py +++ b/tests/unit/project_test.py @@ -220,7 +220,7 @@ class ProjectTest(unittest.TestCase): } ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), None) + self.assertEqual(service.net.id, None) self.assertNotIn('NetworkMode', service._get_container_host_config({})) def test_use_net_from_container(self): @@ -235,7 +235,7 @@ class ProjectTest(unittest.TestCase): } ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), 'container:' + container_id) + self.assertEqual(service.net.mode, 'container:' + container_id) def test_use_net_from_service(self): container_name = 'test_aaa_1' @@ -260,7 +260,7 @@ class ProjectTest(unittest.TestCase): ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), 'container:' + container_name) + self.assertEqual(service.net.mode, 'container:' + container_name) def test_container_without_name(self): self.mock_client.containers.return_value = [ diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index 9979c8f123..6ed3d981a3 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -7,21 +7,27 @@ import mock import docker from docker.utils import LogConfig -from compose.service import Service +from .. import mock +from .. import unittest +from compose.const import LABEL_CONFIG_HASH +from compose.const import LABEL_ONE_OFF +from compose.const import LABEL_PROJECT +from compose.const import LABEL_SERVICE from compose.container import Container -from compose.const import LABEL_SERVICE, LABEL_PROJECT, LABEL_ONE_OFF -from compose.service import ( - ConfigError, - NeedsBuildError, - NoSuchImageError, - build_port_bindings, - build_volume_binding, - get_container_data_volumes, - merge_volume_bindings, - parse_repository_tag, - parse_volume_spec, - split_port, -) +from compose.service import ConfigError +from compose.service import ContainerNet +from compose.service import NeedsBuildError +from compose.service import Net +from compose.service import NoSuchImageError +from compose.service import Service +from compose.service import ServiceNet +from compose.service import build_port_bindings +from compose.service import build_volume_binding +from compose.service import get_container_data_volumes +from compose.service import merge_volume_bindings +from compose.service import parse_repository_tag +from compose.service import parse_volume_spec +from compose.service import split_port class ServiceTest(unittest.TestCase): @@ -393,7 +399,7 @@ class ServiceTest(unittest.TestCase): 'foo', image='example.com/foo', client=self.mock_client, - net=Service('other'), + net=ServiceNet(Service('other')), links=[(Service('one'), 'one')], volumes_from=[Service('two')]) @@ -429,6 +435,49 @@ class ServiceTest(unittest.TestCase): self.assertEqual(config_dict, expected) +class NetTestCase(unittest.TestCase): + + def test_net(self): + net = Net('host') + self.assertEqual(net.id, 'host') + self.assertEqual(net.mode, 'host') + self.assertEqual(net.service_name, None) + + def test_net_container(self): + container_id = 'abcd' + net = ContainerNet(Container(None, {'Id': container_id})) + self.assertEqual(net.id, container_id) + self.assertEqual(net.mode, 'container:' + container_id) + self.assertEqual(net.service_name, None) + + def test_net_service(self): + container_id = 'bbbb' + service_name = 'web' + mock_client = mock.create_autospec(docker.Client) + mock_client.containers.return_value = [ + {'Id': container_id, 'Name': container_id, 'Image': 'abcd'}, + ] + + service = Service(name=service_name, client=mock_client) + net = ServiceNet(service) + + self.assertEqual(net.id, service_name) + self.assertEqual(net.mode, 'container:' + container_id) + self.assertEqual(net.service_name, service_name) + + def test_net_service_no_containers(self): + service_name = 'web' + mock_client = mock.create_autospec(docker.Client) + mock_client.containers.return_value = [] + + service = Service(name=service_name, client=mock_client) + net = ServiceNet(service) + + self.assertEqual(net.id, service_name) + self.assertEqual(net.mode, None) + self.assertEqual(net.service_name, service_name) + + def mock_get_image(images): if images: return images[0]