From afc54df6b39f748d11de1dab93db94cbb20c6b4a Mon Sep 17 00:00:00 2001 From: Gaius Date: Thu, 28 Mar 2024 18:04:23 +0800 Subject: [PATCH] feat: optimize graph based on sync.Map (#3152) Signed-off-by: Gaius --- pkg/container/set/safe_set.go | 6 +- pkg/container/set/set.go | 6 +- pkg/graph/dag/dag.go | 142 +++++++++++++++++------------ pkg/graph/dag/dag_test.go | 58 +++++------- pkg/graph/dag/mocks/dag_mock.go | 18 +--- pkg/graph/dag/vertex.go | 8 +- pkg/graph/dg/dg.go | 157 +++++++++++++++++--------------- pkg/graph/dg/dg_test.go | 58 +++++------- pkg/graph/dg/mocks/dg_mock.go | 18 +--- pkg/graph/dg/vertex.go | 8 +- scheduler/resource/task.go | 2 +- 11 files changed, 235 insertions(+), 246 deletions(-) diff --git a/pkg/container/set/safe_set.go b/pkg/container/set/safe_set.go index cd8941455..e6f26d6d1 100644 --- a/pkg/container/set/safe_set.go +++ b/pkg/container/set/safe_set.go @@ -47,7 +47,11 @@ func (s *safeSet[T]) Values() []T { s.mu.RLock() defer s.mu.RUnlock() - var result []T + if len(s.data) == 0 { + return nil + } + + result := make([]T, 0, len(s.data)) for k := range s.data { result = append(result, k) } diff --git a/pkg/container/set/set.go b/pkg/container/set/set.go index 853210b1a..41e3bdcc6 100644 --- a/pkg/container/set/set.go +++ b/pkg/container/set/set.go @@ -34,7 +34,11 @@ func New[T comparable]() Set[T] { } func (s *set[T]) Values() []T { - var result []T + if len(*s) == 0 { + return nil + } + + result := make([]T, 0, len(*s)) for k := range *s { result = append(result, k) } diff --git a/pkg/graph/dag/dag.go b/pkg/graph/dag/dag.go index dbb2db333..a9e22e787 100644 --- a/pkg/graph/dag/dag.go +++ b/pkg/graph/dag/dag.go @@ -20,11 +20,9 @@ package dag import ( "errors" - "math/rand" "sync" - "time" - cmap "github.com/orcaman/concurrent-map/v2" + "go.uber.org/atomic" "d7y.io/dragonfly/v2/pkg/container/set" ) @@ -33,6 +31,9 @@ var ( // ErrVertexNotFound represents vertex not found. ErrVertexNotFound = errors.New("vertex not found") + // ErrVertexInvalid represents vertex invalid. + ErrVertexInvalid = errors.New("vertex invalid") + // ErrVertexAlreadyExists represents vertex already exists. ErrVertexAlreadyExists = errors.New("vertex already exists") @@ -63,9 +64,6 @@ type DAG[T comparable] interface { // GetRandomVertices returns random map of vertices. GetRandomVertices(n uint) []*Vertex[T] - // GetVertexKeys returns keys of vertices. - GetVertexKeys() []string - // GetSourceVertices returns source vertices. GetSourceVertices() []*Vertex[T] @@ -73,7 +71,7 @@ type DAG[T comparable] interface { GetSinkVertices() []*Vertex[T] // VertexCount returns count of vertices. - VertexCount() int + VertexCount() uint64 // AddEdge adds edge between two vertices. AddEdge(fromVertexID, toVertexID string) error @@ -93,14 +91,17 @@ type DAG[T comparable] interface { // dag provides directed acyclic graph function. type dag[T comparable] struct { + vertices *sync.Map + count *atomic.Uint64 mu sync.RWMutex - vertices cmap.ConcurrentMap[string, *Vertex[T]] } // New returns a new DAG interface. func NewDAG[T comparable]() DAG[T] { return &dag[T]{ - vertices: cmap.New[*Vertex[T]](), + vertices: &sync.Map{}, + count: atomic.NewUint64(0), + mu: sync.RWMutex{}, } } @@ -109,11 +110,11 @@ func (d *dag[T]) AddVertex(id string, value T) error { d.mu.Lock() defer d.mu.Unlock() - if _, ok := d.vertices.Get(id); ok { + if _, loaded := d.vertices.LoadOrStore(id, NewVertex(id, value)); loaded { return ErrVertexAlreadyExists } - d.vertices.Set(id, NewVertex(id, value)) + d.count.Inc() return nil } @@ -122,7 +123,12 @@ func (d *dag[T]) DeleteVertex(id string) { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) + rawVertex, loaded := d.vertices.Load(id) + if !loaded { + return + } + + vertex, ok := rawVertex.(*Vertex[T]) if !ok { return } @@ -136,22 +142,47 @@ func (d *dag[T]) DeleteVertex(id string) { continue } - d.vertices.Remove(id) + d.vertices.Delete(id) + d.count.Dec() } // GetVertex gets vertex from graph. func (d *dag[T]) GetVertex(id string) (*Vertex[T], error) { - vertex, ok := d.vertices.Get(id) - if !ok { + rawVertex, loaded := d.vertices.Load(id) + if !loaded { return nil, ErrVertexNotFound } + vertex, ok := rawVertex.(*Vertex[T]) + if !ok { + return nil, ErrVertexInvalid + } + return vertex, nil } // GetVertices returns map of vertices. func (d *dag[T]) GetVertices() map[string]*Vertex[T] { - return d.vertices.Items() + d.mu.RLock() + defer d.mu.RUnlock() + + vertices := make(map[string]*Vertex[T], d.count.Load()) + d.vertices.Range(func(key, value interface{}) bool { + vertex, ok := value.(*Vertex[T]) + if !ok { + return true + } + + id, ok := key.(string) + if !ok { + return true + } + + vertices[id] = vertex + return true + }) + + return vertices } // GetRandomVertices returns random map of vertices. @@ -159,32 +190,27 @@ func (d *dag[T]) GetRandomVertices(n uint) []*Vertex[T] { d.mu.RLock() defer d.mu.RUnlock() - keys := d.GetVertexKeys() - if int(n) >= len(keys) { - n = uint(len(keys)) + if n == 0 { + return nil } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - permutation := r.Perm(len(keys))[:n] randomVertices := make([]*Vertex[T], 0, n) - for _, v := range permutation { - key := keys[v] - if vertex, err := d.GetVertex(key); err == nil { - randomVertices = append(randomVertices, vertex) + d.vertices.Range(func(key, value interface{}) bool { + vertex, ok := value.(*Vertex[T]) + if !ok { + return true } - } + + randomVertices = append(randomVertices, vertex) + return uint(len(randomVertices)) < n + }) return randomVertices } -// GetVertexKeys returns keys of vertices. -func (d *dag[T]) GetVertexKeys() []string { - return d.vertices.Keys() -} - // VertexCount returns count of vertices. -func (d *dag[T]) VertexCount() int { - return d.vertices.Count() +func (d *dag[T]) VertexCount() uint64 { + return d.count.Load() } // AddEdge adds edge between two vertices. @@ -196,14 +222,14 @@ func (d *dag[T]) AddEdge(fromVertexID, toVertexID string) error { return ErrCycleBetweenVertices } - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { - return ErrVertexNotFound + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { + return err } - toVertex, ok := d.vertices.Get(toVertexID) - if !ok { - return ErrVertexNotFound + toVertex, err := d.GetVertex(toVertexID) + if err != nil { + return err } for _, child := range fromVertex.Children.Values() { @@ -232,14 +258,14 @@ func (d *dag[T]) DeleteEdge(fromVertexID, toVertexID string) error { d.mu.Lock() defer d.mu.Unlock() - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { - return ErrVertexNotFound + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { + return err } - toVertex, ok := d.vertices.Get(toVertexID) - if !ok { - return ErrVertexNotFound + toVertex, err := d.GetVertex(toVertexID) + if err != nil { + return err } fromVertex.Children.Delete(toVertex) @@ -256,12 +282,12 @@ func (d *dag[T]) CanAddEdge(fromVertexID, toVertexID string) bool { return false } - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { return false } - if _, ok := d.vertices.Get(toVertexID); !ok { + if _, err := d.GetVertex(toVertexID); err != nil { return false } @@ -283,9 +309,9 @@ func (d *dag[T]) DeleteVertexInEdges(id string) error { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) - if !ok { - return ErrVertexNotFound + vertex, err := d.GetVertex(id) + if err != nil { + return err } for _, parent := range vertex.Parents.Values() { @@ -301,9 +327,9 @@ func (d *dag[T]) DeleteVertexOutEdges(id string) error { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) - if !ok { - return ErrVertexNotFound + vertex, err := d.GetVertex(id) + if err != nil { + return err } for _, child := range vertex.Children.Values() { @@ -320,7 +346,7 @@ func (d *dag[T]) GetSourceVertices() []*Vertex[T] { defer d.mu.RUnlock() var sourceVertices []*Vertex[T] - for _, vertex := range d.vertices.Items() { + for _, vertex := range d.GetVertices() { if vertex.InDegree() == 0 { sourceVertices = append(sourceVertices, vertex) } @@ -335,7 +361,7 @@ func (d *dag[T]) GetSinkVertices() []*Vertex[T] { defer d.mu.RUnlock() var sinkVertices []*Vertex[T] - for _, vertex := range d.vertices.Items() { + for _, vertex := range d.GetVertices() { if vertex.OutDegree() == 0 { sinkVertices = append(sinkVertices, vertex) } @@ -354,8 +380,8 @@ func (d *dag[T]) depthFirstSearch(fromVertexID, toVertexID string) bool { // search finds successors of vertex. func (d *dag[T]) search(vertexID string, successors map[string]struct{}) { - vertex, ok := d.vertices.Get(vertexID) - if !ok { + vertex, err := d.GetVertex(vertexID) + if err != nil { return } diff --git a/pkg/graph/dag/dag_test.go b/pkg/graph/dag/dag_test.go index 14a897934..1579fb9af 100644 --- a/pkg/graph/dag/dag_test.go +++ b/pkg/graph/dag/dag_test.go @@ -179,17 +179,17 @@ func TestDAG_VertexCount(t *testing.T) { } d.VertexCount() - assert.Equal(d.VertexCount(), 1) + assert.Equal(d.VertexCount(), uint64(1)) d.DeleteVertex(mockVertexID) - assert.Equal(d.VertexCount(), 0) + assert.Equal(d.VertexCount(), uint64(0)) }, }, { name: "empty dag", expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) - assert.Equal(d.VertexCount(), 0) + assert.Equal(d.VertexCount(), uint64(0)) }, }, } @@ -302,38 +302,6 @@ func TestDAG_GetRandomVertices(t *testing.T) { } } -func TestDAG_GetVertexKeys(t *testing.T) { - tests := []struct { - name string - expect func(t *testing.T, d DAG[string]) - }{ - { - name: "get keys of vertices", - expect: func(t *testing.T, d DAG[string]) { - assert := assert.New(t) - if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { - assert.NoError(err) - } - - keys := d.GetVertexKeys() - assert.Equal(len(keys), 1) - assert.Equal(keys[0], mockVertexID) - - d.DeleteVertex(mockVertexID) - keys = d.GetVertexKeys() - assert.Equal(len(keys), 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - d := NewDAG[string]() - tc.expect(t, d) - }) - } -} - func TestDAG_AddEdge(t *testing.T) { tests := []struct { name string @@ -965,7 +933,25 @@ func BenchmarkDAG_DeleteVertex(b *testing.B) { } } -func BenchmarkDAG_GetRandomKeys(b *testing.B) { +func BenchmarkDAG_GetVertices(b *testing.B) { + d := NewDAG[string]() + for n := 0; n < b.N; n++ { + id := fmt.Sprint(n) + if err := d.AddVertex(id, string(id)); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + vertices := d.GetVertices() + if len(vertices) != b.N { + b.Fatal(errors.New("get vertices failed")) + } + } +} + +func BenchmarkDAG_GetRandomVertices(b *testing.B) { d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) diff --git a/pkg/graph/dag/mocks/dag_mock.go b/pkg/graph/dag/mocks/dag_mock.go index a8e04bb71..5c16c417a 100644 --- a/pkg/graph/dag/mocks/dag_mock.go +++ b/pkg/graph/dag/mocks/dag_mock.go @@ -191,20 +191,6 @@ func (mr *MockDAGMockRecorder[T]) GetVertex(id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG[T])(nil).GetVertex), id) } -// GetVertexKeys mocks base method. -func (m *MockDAG[T]) GetVertexKeys() []string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVertexKeys") - ret0, _ := ret[0].([]string) - return ret0 -} - -// GetVertexKeys indicates an expected call of GetVertexKeys. -func (mr *MockDAGMockRecorder[T]) GetVertexKeys() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDAG[T])(nil).GetVertexKeys)) -} - // GetVertices mocks base method. func (m *MockDAG[T]) GetVertices() map[string]*dag.Vertex[T] { m.ctrl.T.Helper() @@ -220,10 +206,10 @@ func (mr *MockDAGMockRecorder[T]) GetVertices() *gomock.Call { } // VertexCount mocks base method. -func (m *MockDAG[T]) VertexCount() int { +func (m *MockDAG[T]) VertexCount() uint64 { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VertexCount") - ret0, _ := ret[0].(int) + ret0, _ := ret[0].(uint64) return ret0 } diff --git a/pkg/graph/dag/vertex.go b/pkg/graph/dag/vertex.go index bd28e3059..c331fa1dc 100644 --- a/pkg/graph/dag/vertex.go +++ b/pkg/graph/dag/vertex.go @@ -24,8 +24,8 @@ import ( type Vertex[T comparable] struct { ID string Value T - Parents set.SafeSet[*Vertex[T]] - Children set.SafeSet[*Vertex[T]] + Parents set.Set[*Vertex[T]] + Children set.Set[*Vertex[T]] } // New returns a new Vertex instance. @@ -33,8 +33,8 @@ func NewVertex[T comparable](id string, value T) *Vertex[T] { return &Vertex[T]{ ID: id, Value: value, - Parents: set.NewSafeSet[*Vertex[T]](), - Children: set.NewSafeSet[*Vertex[T]](), + Parents: set.New[*Vertex[T]](), + Children: set.New[*Vertex[T]](), } } diff --git a/pkg/graph/dg/dg.go b/pkg/graph/dg/dg.go index 2af79e120..c06033bc2 100644 --- a/pkg/graph/dg/dg.go +++ b/pkg/graph/dg/dg.go @@ -20,11 +20,9 @@ package dg import ( "errors" - "math/rand" "sync" - "time" - cmap "github.com/orcaman/concurrent-map/v2" + "go.uber.org/atomic" "d7y.io/dragonfly/v2/pkg/container/set" ) @@ -33,6 +31,9 @@ var ( // ErrVertexNotFound represents vertex not found. ErrVertexNotFound = errors.New("vertex not found") + // ErrVertexInvalid represents vertex invalid. + ErrVertexInvalid = errors.New("vertex invalid") + // ErrVertexAlreadyExists represents vertex already exists. ErrVertexAlreadyExists = errors.New("vertex already exists") @@ -63,9 +64,6 @@ type DG[T comparable] interface { // GetRandomVertices returns random map of vertices. GetRandomVertices(n uint) []*Vertex[T] - // GetVertexKeys returns keys of vertices. - GetVertexKeys() []string - // GetSourceVertices returns source vertices. GetSourceVertices() []*Vertex[T] @@ -73,7 +71,7 @@ type DG[T comparable] interface { GetSinkVertices() []*Vertex[T] // VertexCount returns count of vertices. - VertexCount() int + VertexCount() uint64 // AddEdge adds edge between two vertices. AddEdge(fromVertexID, toVertexID string) error @@ -81,7 +79,7 @@ type DG[T comparable] interface { // DeleteEdge deletes edge between two vertices. DeleteEdge(fromVertexID, toVertexID string) error - // CanAddEdge finds whether there are circles through depth-first search. + // CanAddEdge indicates whether can add edge between two vertices. CanAddEdge(fromVertexID, toVertexID string) bool // DeleteVertexInEdges deletes inedges of vertex. @@ -93,14 +91,17 @@ type DG[T comparable] interface { // dg provides directed graph function. type dg[T comparable] struct { + vertices *sync.Map + count *atomic.Uint64 mu sync.RWMutex - vertices cmap.ConcurrentMap[string, *Vertex[T]] } // New returns a new DG interface. func NewDG[T comparable]() DG[T] { return &dg[T]{ - vertices: cmap.New[*Vertex[T]](), + vertices: &sync.Map{}, + count: atomic.NewUint64(0), + mu: sync.RWMutex{}, } } @@ -109,11 +110,11 @@ func (d *dg[T]) AddVertex(id string, value T) error { d.mu.Lock() defer d.mu.Unlock() - if _, ok := d.vertices.Get(id); ok { + if _, loaded := d.vertices.LoadOrStore(id, NewVertex(id, value)); loaded { return ErrVertexAlreadyExists } - d.vertices.Set(id, NewVertex(id, value)) + d.count.Inc() return nil } @@ -122,7 +123,12 @@ func (d *dg[T]) DeleteVertex(id string) { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) + rawVertex, loaded := d.vertices.Load(id) + if !loaded { + return + } + + vertex, ok := rawVertex.(*Vertex[T]) if !ok { return } @@ -136,22 +142,47 @@ func (d *dg[T]) DeleteVertex(id string) { continue } - d.vertices.Remove(id) + d.vertices.Delete(id) + d.count.Dec() } // GetVertex gets vertex from graph. func (d *dg[T]) GetVertex(id string) (*Vertex[T], error) { - vertex, ok := d.vertices.Get(id) - if !ok { + rawVertex, loaded := d.vertices.Load(id) + if !loaded { return nil, ErrVertexNotFound } + vertex, ok := rawVertex.(*Vertex[T]) + if !ok { + return nil, ErrVertexInvalid + } + return vertex, nil } // GetVertices returns map of vertices. func (d *dg[T]) GetVertices() map[string]*Vertex[T] { - return d.vertices.Items() + d.mu.RLock() + defer d.mu.RUnlock() + + vertices := make(map[string]*Vertex[T], d.count.Load()) + d.vertices.Range(func(key, value interface{}) bool { + vertex, ok := value.(*Vertex[T]) + if !ok { + return true + } + + id, ok := key.(string) + if !ok { + return true + } + + vertices[id] = vertex + return true + }) + + return vertices } // GetRandomVertices returns random map of vertices. @@ -159,32 +190,27 @@ func (d *dg[T]) GetRandomVertices(n uint) []*Vertex[T] { d.mu.RLock() defer d.mu.RUnlock() - keys := d.GetVertexKeys() - if int(n) >= len(keys) { - n = uint(len(keys)) + if n == 0 { + return nil } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - permutation := r.Perm(len(keys))[:n] randomVertices := make([]*Vertex[T], 0, n) - for _, v := range permutation { - key := keys[v] - if vertex, err := d.GetVertex(key); err == nil { - randomVertices = append(randomVertices, vertex) + d.vertices.Range(func(key, value interface{}) bool { + vertex, ok := value.(*Vertex[T]) + if !ok { + return true } - } + + randomVertices = append(randomVertices, vertex) + return uint(len(randomVertices)) < n + }) return randomVertices } -// GetVertexKeys returns keys of vertices. -func (d *dg[T]) GetVertexKeys() []string { - return d.vertices.Keys() -} - // VertexCount returns count of vertices. -func (d *dg[T]) VertexCount() int { - return d.vertices.Count() +func (d *dg[T]) VertexCount() uint64 { + return d.count.Load() } // AddEdge adds edge between two vertices. @@ -196,14 +222,14 @@ func (d *dg[T]) AddEdge(fromVertexID, toVertexID string) error { return ErrCycleBetweenVertices } - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { - return ErrVertexNotFound + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { + return err } - toVertex, ok := d.vertices.Get(toVertexID) - if !ok { - return ErrVertexNotFound + toVertex, err := d.GetVertex(toVertexID) + if err != nil { + return err } for _, child := range fromVertex.Children.Values() { @@ -228,14 +254,14 @@ func (d *dg[T]) DeleteEdge(fromVertexID, toVertexID string) error { d.mu.Lock() defer d.mu.Unlock() - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { - return ErrVertexNotFound + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { + return err } - toVertex, ok := d.vertices.Get(toVertexID) - if !ok { - return ErrVertexNotFound + toVertex, err := d.GetVertex(toVertexID) + if err != nil { + return err } fromVertex.Children.Delete(toVertex) @@ -243,7 +269,7 @@ func (d *dg[T]) DeleteEdge(fromVertexID, toVertexID string) error { return nil } -// CanAddEdge finds whether there are circles through depth-first search. +// CanAddEdge indicates whether can add edge between two vertices. func (d *dg[T]) CanAddEdge(fromVertexID, toVertexID string) bool { d.mu.RLock() defer d.mu.RUnlock() @@ -252,12 +278,12 @@ func (d *dg[T]) CanAddEdge(fromVertexID, toVertexID string) bool { return false } - fromVertex, ok := d.vertices.Get(fromVertexID) - if !ok { + fromVertex, err := d.GetVertex(fromVertexID) + if err != nil { return false } - if _, ok := d.vertices.Get(toVertexID); !ok { + if _, err := d.GetVertex(toVertexID); err != nil { return false } @@ -275,9 +301,9 @@ func (d *dg[T]) DeleteVertexInEdges(id string) error { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) - if !ok { - return ErrVertexNotFound + vertex, err := d.GetVertex(id) + if err != nil { + return err } for _, parent := range vertex.Parents.Values() { @@ -293,9 +319,9 @@ func (d *dg[T]) DeleteVertexOutEdges(id string) error { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices.Get(id) - if !ok { - return ErrVertexNotFound + vertex, err := d.GetVertex(id) + if err != nil { + return err } for _, child := range vertex.Children.Values() { @@ -312,7 +338,7 @@ func (d *dg[T]) GetSourceVertices() []*Vertex[T] { defer d.mu.RUnlock() var sourceVertices []*Vertex[T] - for _, vertex := range d.vertices.Items() { + for _, vertex := range d.GetVertices() { if vertex.InDegree() == 0 { sourceVertices = append(sourceVertices, vertex) } @@ -327,7 +353,7 @@ func (d *dg[T]) GetSinkVertices() []*Vertex[T] { defer d.mu.RUnlock() var sinkVertices []*Vertex[T] - for _, vertex := range d.vertices.Items() { + for _, vertex := range d.GetVertices() { if vertex.OutDegree() == 0 { sinkVertices = append(sinkVertices, vertex) } @@ -335,18 +361,3 @@ func (d *dg[T]) GetSinkVertices() []*Vertex[T] { return sinkVertices } - -// search finds successors of vertex. -func (d *dg[T]) search(vertexID string, successors map[string]struct{}) { - vertex, ok := d.vertices.Get(vertexID) - if !ok { - return - } - - for _, child := range vertex.Children.Values() { - if _, ok := successors[child.ID]; !ok { - successors[child.ID] = struct{}{} - d.search(child.ID, successors) - } - } -} diff --git a/pkg/graph/dg/dg_test.go b/pkg/graph/dg/dg_test.go index 3b60e1df0..ce234dede 100644 --- a/pkg/graph/dg/dg_test.go +++ b/pkg/graph/dg/dg_test.go @@ -179,17 +179,17 @@ func TestDG_VertexCount(t *testing.T) { } d.VertexCount() - assert.Equal(d.VertexCount(), 1) + assert.Equal(d.VertexCount(), uint64(1)) d.DeleteVertex(mockVertexID) - assert.Equal(d.VertexCount(), 0) + assert.Equal(d.VertexCount(), uint64(0)) }, }, { name: "empty dg", expect: func(t *testing.T, d DG[string]) { assert := assert.New(t) - assert.Equal(d.VertexCount(), 0) + assert.Equal(d.VertexCount(), uint64(0)) }, }, } @@ -299,38 +299,6 @@ func TestDG_GetRandomVertices(t *testing.T) { } } -func TestDG_GetVertexKeys(t *testing.T) { - tests := []struct { - name string - expect func(t *testing.T, d DG[string]) - }{ - { - name: "get keys of vertices", - expect: func(t *testing.T, d DG[string]) { - assert := assert.New(t) - if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { - assert.NoError(err) - } - - keys := d.GetVertexKeys() - assert.Equal(len(keys), 1) - assert.Equal(keys[0], mockVertexID) - - d.DeleteVertex(mockVertexID) - keys = d.GetVertexKeys() - assert.Equal(len(keys), 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - d := NewDG[string]() - tc.expect(t, d) - }) - } -} - func TestDG_AddEdge(t *testing.T) { tests := []struct { name string @@ -897,7 +865,25 @@ func BenchmarkDG_DeleteVertex(b *testing.B) { } } -func BenchmarkDG_GetRandomKeys(b *testing.B) { +func BenchmarkDAG_GetVertices(b *testing.B) { + d := NewDG[string]() + for n := 0; n < b.N; n++ { + id := fmt.Sprint(n) + if err := d.AddVertex(id, string(id)); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + vertices := d.GetVertices() + if len(vertices) != b.N { + b.Fatal(errors.New("get vertices failed")) + } + } +} + +func BenchmarkDG_GetRandomVertices(b *testing.B) { d := NewDG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) diff --git a/pkg/graph/dg/mocks/dg_mock.go b/pkg/graph/dg/mocks/dg_mock.go index 81800ad8f..25033cb72 100644 --- a/pkg/graph/dg/mocks/dg_mock.go +++ b/pkg/graph/dg/mocks/dg_mock.go @@ -191,20 +191,6 @@ func (mr *MockDGMockRecorder[T]) GetVertex(id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDG[T])(nil).GetVertex), id) } -// GetVertexKeys mocks base method. -func (m *MockDG[T]) GetVertexKeys() []string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVertexKeys") - ret0, _ := ret[0].([]string) - return ret0 -} - -// GetVertexKeys indicates an expected call of GetVertexKeys. -func (mr *MockDGMockRecorder[T]) GetVertexKeys() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDG[T])(nil).GetVertexKeys)) -} - // GetVertices mocks base method. func (m *MockDG[T]) GetVertices() map[string]*dg.Vertex[T] { m.ctrl.T.Helper() @@ -220,10 +206,10 @@ func (mr *MockDGMockRecorder[T]) GetVertices() *gomock.Call { } // VertexCount mocks base method. -func (m *MockDG[T]) VertexCount() int { +func (m *MockDG[T]) VertexCount() uint64 { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VertexCount") - ret0, _ := ret[0].(int) + ret0, _ := ret[0].(uint64) return ret0 } diff --git a/pkg/graph/dg/vertex.go b/pkg/graph/dg/vertex.go index b9595e207..2e4c454b8 100644 --- a/pkg/graph/dg/vertex.go +++ b/pkg/graph/dg/vertex.go @@ -24,8 +24,8 @@ import ( type Vertex[T comparable] struct { ID string Value T - Parents set.SafeSet[*Vertex[T]] - Children set.SafeSet[*Vertex[T]] + Parents set.Set[*Vertex[T]] + Children set.Set[*Vertex[T]] } // New returns a new Vertex instance. @@ -33,8 +33,8 @@ func NewVertex[T comparable](id string, value T) *Vertex[T] { return &Vertex[T]{ ID: id, Value: value, - Parents: set.NewSafeSet[*Vertex[T]](), - Children: set.NewSafeSet[*Vertex[T]](), + Parents: set.New[*Vertex[T]](), + Children: set.New[*Vertex[T]](), } } diff --git a/scheduler/resource/task.go b/scheduler/resource/task.go index b0adee6a5..3690cb155 100644 --- a/scheduler/resource/task.go +++ b/scheduler/resource/task.go @@ -269,7 +269,7 @@ func (t *Task) DeletePeer(key string) { // PeerCount returns count of peer. func (t *Task) PeerCount() int { - return t.DAG.VertexCount() + return int(t.DAG.VertexCount()) } // AddPeerEdge adds inedges between two peers.