diff --git a/discovery/consul/consul.go b/discovery/consul/consul.go index 00290fa0dd..7d45147123 100644 --- a/discovery/consul/consul.go +++ b/discovery/consul/consul.go @@ -1,7 +1,7 @@ package consul import ( - "errors" + "fmt" "path" "strings" "time" @@ -25,7 +25,7 @@ func init() { func (s *ConsulDiscoveryService) Initialize(uris string, heartbeat int) error { parts := strings.SplitN(uris, "/", 2) if len(parts) < 2 { - return errors.New("missing consul prefix") + return fmt.Errorf("invalid format %q, missing ", uris) } addr := parts[0] path := parts[1] diff --git a/discovery/consul/consul_test.go b/discovery/consul/consul_test.go index 79bf873bdf..9a2ff1b3c0 100644 --- a/discovery/consul/consul_test.go +++ b/discovery/consul/consul_test.go @@ -8,6 +8,13 @@ import ( func TestInitialize(t *testing.T) { discovery := &ConsulDiscoveryService{} - discovery.Initialize("127.0.0.1:8500/path", 0) + + assert.Equal(t, discovery.Initialize("127.0.0.1", 0).Error(), "invalid format \"127.0.0.1\", missing ") + + assert.Error(t, discovery.Initialize("127.0.0.1/path", 0)) assert.Equal(t, discovery.prefix, "path/") + + assert.Error(t, discovery.Initialize("127.0.0.1,127.0.0.2,127.0.0.3/path", 0)) + assert.Equal(t, discovery.prefix, "path/") + } diff --git a/discovery/etcd/etcd.go b/discovery/etcd/etcd.go index 9e4efe4d0c..91641498d1 100644 --- a/discovery/etcd/etcd.go +++ b/discovery/etcd/etcd.go @@ -1,6 +1,7 @@ package etcd import ( + "fmt" "path" "strings" @@ -27,6 +28,11 @@ func (s *EtcdDiscoveryService) Initialize(uris string, heartbeat int) error { ips = strings.Split(parts[0], ",") machines []string ) + + if len(parts) != 2 { + return fmt.Errorf("invalid format %q, missing ", uris) + } + for _, ip := range ips { machines = append(machines, "http://"+ip) } diff --git a/discovery/etcd/etcd_test.go b/discovery/etcd/etcd_test.go index 2088890a63..3070e13e9c 100644 --- a/discovery/etcd/etcd_test.go +++ b/discovery/etcd/etcd_test.go @@ -8,6 +8,9 @@ import ( func TestInitialize(t *testing.T) { discovery := &EtcdDiscoveryService{} + + assert.Equal(t, discovery.Initialize("127.0.0.1", 0).Error(), "invalid format \"127.0.0.1\", missing ") + assert.Error(t, discovery.Initialize("127.0.0.1/path", 0)) assert.Equal(t, discovery.path, "/path/") diff --git a/discovery/zookeeper/zookeeper.go b/discovery/zookeeper/zookeeper.go index d133d9c227..8c990520af 100644 --- a/discovery/zookeeper/zookeeper.go +++ b/discovery/zookeeper/zookeeper.go @@ -1,6 +1,7 @@ package zookeeper import ( + "fmt" "path" "strings" "time" @@ -28,6 +29,10 @@ func (s *ZkDiscoveryService) Initialize(uris string, heartbeat int) error { ips = strings.Split(parts[0], ",") ) + if len(parts) != 2 { + return fmt.Errorf("invalid format %q, missing ", uris) + } + conn, _, err := zk.Connect(ips, time.Second) if err != nil { diff --git a/discovery/zookeeper/zookeeper_test.go b/discovery/zookeeper/zookeeper_test.go index 221d3a57f7..c07a879aaa 100644 --- a/discovery/zookeeper/zookeeper_test.go +++ b/discovery/zookeeper/zookeeper_test.go @@ -9,6 +9,9 @@ import ( func TestInitialize(t *testing.T) { service := &ZkDiscoveryService{} + + assert.Equal(t, service.Initialize("127.0.0.1", 0).Error(), "invalid format \"127.0.0.1\", missing ") + assert.Error(t, service.Initialize("127.0.0.1/path", 0)) assert.Equal(t, service.path, "/path")