diff --git a/discovery/discovery.go b/discovery/discovery.go index 190021e778..0ae35705b5 100644 --- a/discovery/discovery.go +++ b/discovery/discovery.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/url" + "strings" "time" log "github.com/Sirupsen/logrus" @@ -11,8 +12,23 @@ import ( type InitFunc func(url string) (DiscoveryService, error) +type Node struct { + url string +} + +func NewNode(url string) *Node { + if !strings.Contains(url, "://") { + url = "http://" + url + } + return &Node{url: url} +} + +func (n Node) String() string { + return n.url +} + type DiscoveryService interface { - Fetch() ([]string, error) + Fetch() ([]*Node, error) Watch(int) <-chan time.Time Register(string) error } diff --git a/discovery/etcd/etcd.go b/discovery/etcd/etcd.go index a8cbfb4a53..a32c95197a 100644 --- a/discovery/etcd/etcd.go +++ b/discovery/etcd/etcd.go @@ -38,15 +38,16 @@ func Init(uris string) (discovery.DiscoveryService, error) { client.CreateDir(path, DEFAULT_TTL) // skip error check error because it might already exists return EtcdDiscoveryService{client: client, path: path}, nil } -func (s EtcdDiscoveryService) Fetch() ([]string, error) { +func (s EtcdDiscoveryService) Fetch() ([]*discovery.Node, error) { resp, err := s.client.Get(s.path, true, true) if err != nil { return nil, err } - nodes := []string{} + + var nodes []*discovery.Node for _, n := range resp.Node.Nodes { - nodes = append(nodes, n.Value) + nodes = append(nodes, discovery.NewNode(n.Value)) } return nodes, nil } diff --git a/discovery/file/file.go b/discovery/file/file.go index f9cce05587..9e037d963b 100644 --- a/discovery/file/file.go +++ b/discovery/file/file.go @@ -21,20 +21,20 @@ func Init(file string) (discovery.DiscoveryService, error) { return FileDiscoveryService{path: file}, nil } -func (s FileDiscoveryService) Fetch() ([]string, error) { +func (s FileDiscoveryService) Fetch() ([]*discovery.Node, error) { data, err := ioutil.ReadFile(s.path) if err != nil { return nil, err } - lines := []string{} + var nodes []*discovery.Node for _, line := range strings.Split(string(data), "\n") { if line != "" { - lines = append(lines, line) + nodes = append(nodes, discovery.NewNode(line)) } } - return lines, nil + return nodes, nil } func (s FileDiscoveryService) Watch(heartbeat int) <-chan time.Time { diff --git a/discovery/token/token.go b/discovery/token/token.go index e0379002c5..36c70aede5 100644 --- a/discovery/token/token.go +++ b/discovery/token/token.go @@ -37,7 +37,7 @@ func New(url string) *TokenDiscoveryService { } // FetchNodes returns the node for the discovery service at the specified endpoint -func (s TokenDiscoveryService) Fetch() ([]string, error) { +func (s TokenDiscoveryService) Fetch() ([]*discovery.Node, error) { resp, err := http.Get(fmt.Sprintf("%s/%s/%s", s.url, "clusters", s.token)) if err != nil { return nil, err @@ -54,7 +54,12 @@ func (s TokenDiscoveryService) Fetch() ([]string, error) { } } - return addrs, nil + var nodes []*discovery.Node + for _, addr := range addrs { + nodes = append(nodes, discovery.NewNode(addr)) + } + + return nodes, nil } func (s TokenDiscoveryService) Watch(heartbeat int) <-chan time.Time { diff --git a/discovery/token/token_test.go b/discovery/token/token_test.go index a15d118539..8b8a7ae945 100644 --- a/discovery/token/token_test.go +++ b/discovery/token/token_test.go @@ -28,7 +28,7 @@ func TestRegister(t *testing.T) { addrs, err := discovery.Fetch() assert.NoError(t, err) assert.Equal(t, len(addrs), 1) - assert.Equal(t, addrs[0], expected) + assert.Equal(t, addrs[0].String(), "http://"+expected) assert.NoError(t, discovery.Register(expected)) } diff --git a/manage.go b/manage.go index 8ac40b05ca..0ce8b847d8 100644 --- a/manage.go +++ b/manage.go @@ -5,7 +5,6 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "strings" log "github.com/Sirupsen/logrus" "github.com/codegangsta/cli" @@ -72,14 +71,11 @@ func manage(c *cli.Context) { } } - refresh := func(c *cluster.Cluster, nodes []string) { + refresh := func(c *cluster.Cluster, nodes []*discovery.Node) { for _, addr := range nodes { - go func(addr string) { - if !strings.Contains(addr, "://") { - addr = "http://" + addr - } - if c.Node(addr) == nil { - n := cluster.NewNode(addr) + go func(node *discovery.Node) { + if c.Node(node.String()) == nil { + n := cluster.NewNode(node.String()) if err := n.Connect(tlsConfig); err != nil { log.Error(err) return @@ -119,7 +115,11 @@ func manage(c *cli.Context) { } }() } else { - refresh(cluster, c.Args()) + var nodes []*discovery.Node + for _, arg := range c.Args() { + nodes = append(nodes, discovery.NewNode(arg)) + } + refresh(cluster, nodes) } }()