diff --git a/discovery/zookeeper/zookeeper.go b/discovery/zookeeper/zookeeper.go index 8c990520af..49845f09c3 100644 --- a/discovery/zookeeper/zookeeper.go +++ b/discovery/zookeeper/zookeeper.go @@ -13,7 +13,7 @@ import ( type ZkDiscoveryService struct { conn *zk.Conn - path string + path []string heartbeat int } @@ -21,6 +21,24 @@ func init() { discovery.Register("zk", &ZkDiscoveryService{}) } +func (s *ZkDiscoveryService) fullpath() string { + return "/" + strings.Join(s.path, "/") +} + +func (s *ZkDiscoveryService) createFullpath() error { + for i := 1; i <= len(s.path); i++ { + newpath := "/" + strings.Join(s.path[:i], "/") + _, err := s.conn.Create(newpath, []byte{1}, 0, zk.WorldACL(zk.PermAll)) + if err != nil { + // It's OK if key already existed. Just skip. + if err != zk.ErrNodeExists { + return err + } + } + } + return nil +} + func (s *ZkDiscoveryService) Initialize(uris string, heartbeat int) error { var ( // split here because uris can contain multiples ips @@ -33,29 +51,29 @@ func (s *ZkDiscoveryService) Initialize(uris string, heartbeat int) error { return fmt.Errorf("invalid format %q, missing ", uris) } - conn, _, err := zk.Connect(ips, time.Second) + if strings.Contains(parts[1], "/") { + s.path = strings.Split(parts[1], "/") + } else { + s.path = []string{parts[1]} + } + conn, _, err := zk.Connect(ips, time.Second) if err != nil { return err } s.conn = conn - s.path = "/" + parts[1] s.heartbeat = heartbeat - - _, err = conn.Create(s.path, []byte{1}, 0, zk.WorldACL(zk.PermAll)) + err = s.createFullpath() if err != nil { - // if key already existed, then skip - if err != zk.ErrNodeExists { - return err - } + return err } return nil } func (s *ZkDiscoveryService) Fetch() ([]*discovery.Node, error) { - addrs, _, err := s.conn.Children(s.path) + addrs, _, err := s.conn.Children(s.fullpath()) if err != nil { return nil, err @@ -78,7 +96,7 @@ func (s *ZkDiscoveryService) createNodes(addrs []string) (nodes []*discovery.Nod func (s *ZkDiscoveryService) Watch(callback discovery.WatchCallback) { - addrs, _, eventChan, err := s.conn.ChildrenW(s.path) + addrs, _, eventChan, err := s.conn.ChildrenW(s.fullpath()) if err != nil { log.Debugf("[ZK] Watch aborted") return @@ -100,41 +118,37 @@ func (s *ZkDiscoveryService) Watch(callback discovery.WatchCallback) { } func (s *ZkDiscoveryService) Register(addr string) error { - newpath := path.Join(s.path, addr) + nodePath := "/" + path.Join(s.fullpath(), addr) // check existing for the parent path first - exist, _, err := s.conn.Exists(s.path) + exist, _, err := s.conn.Exists(s.fullpath()) if err != nil { return err } - // create parent first + // if the parent path does not exist yet if exist == false { - - _, err = s.conn.Create(s.path, []byte{1}, 0, zk.WorldACL(zk.PermAll)) + // create the parent first + err = s.createFullpath() if err != nil { return err } - _, err = s.conn.Create(newpath, []byte(addr), 0, zk.WorldACL(zk.PermAll)) - return err - } else { - - exist, _, err = s.conn.Exists(newpath) + // if node path exists + exist, _, err = s.conn.Exists(nodePath) if err != nil { return err } - + // delete it first if exist { - err = s.conn.Delete(newpath, -1) + err = s.conn.Delete(nodePath, -1) if err != nil { return err } } - - _, err = s.conn.Create(newpath, []byte(addr), 0, zk.WorldACL(zk.PermAll)) - return err } - return nil + // create the node path to store address information + _, err = s.conn.Create(nodePath, []byte(addr), 0, zk.WorldACL(zk.PermAll)) + return err } diff --git a/discovery/zookeeper/zookeeper_test.go b/discovery/zookeeper/zookeeper_test.go index c07a879aaa..681d830801 100644 --- a/discovery/zookeeper/zookeeper_test.go +++ b/discovery/zookeeper/zookeeper_test.go @@ -13,10 +13,13 @@ func TestInitialize(t *testing.T) { assert.Equal(t, service.Initialize("127.0.0.1", 0).Error(), "invalid format \"127.0.0.1\", missing ") assert.Error(t, service.Initialize("127.0.0.1/path", 0)) - assert.Equal(t, service.path, "/path") + assert.Equal(t, service.fullpath(), "/path") assert.Error(t, service.Initialize("127.0.0.1,127.0.0.2,127.0.0.3/path", 0)) - assert.Equal(t, service.path, "/path") + assert.Equal(t, service.fullpath(), "/path") + + assert.Error(t, service.Initialize("127.0.0.1,127.0.0.2,127.0.0.3/path/sub1/sub2", 0)) + assert.Equal(t, service.fullpath(), "/path/sub1/sub2") } func TestCreateNodes(t *testing.T) {