feat: optimize graph based on sync.Map (#3152)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2024-03-28 18:04:23 +08:00 committed by GitHub
parent 6099a917fb
commit afc54df6b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 235 additions and 246 deletions

View File

@ -47,7 +47,11 @@ func (s *safeSet[T]) Values() []T {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
var result []T if len(s.data) == 0 {
return nil
}
result := make([]T, 0, len(s.data))
for k := range s.data { for k := range s.data {
result = append(result, k) result = append(result, k)
} }

View File

@ -34,7 +34,11 @@ func New[T comparable]() Set[T] {
} }
func (s *set[T]) Values() []T { func (s *set[T]) Values() []T {
var result []T if len(*s) == 0 {
return nil
}
result := make([]T, 0, len(*s))
for k := range *s { for k := range *s {
result = append(result, k) result = append(result, k)
} }

View File

@ -20,11 +20,9 @@ package dag
import ( import (
"errors" "errors"
"math/rand"
"sync" "sync"
"time"
cmap "github.com/orcaman/concurrent-map/v2" "go.uber.org/atomic"
"d7y.io/dragonfly/v2/pkg/container/set" "d7y.io/dragonfly/v2/pkg/container/set"
) )
@ -33,6 +31,9 @@ var (
// ErrVertexNotFound represents vertex not found. // ErrVertexNotFound represents vertex not found.
ErrVertexNotFound = errors.New("vertex not found") ErrVertexNotFound = errors.New("vertex not found")
// ErrVertexInvalid represents vertex invalid.
ErrVertexInvalid = errors.New("vertex invalid")
// ErrVertexAlreadyExists represents vertex already exists. // ErrVertexAlreadyExists represents vertex already exists.
ErrVertexAlreadyExists = errors.New("vertex already exists") ErrVertexAlreadyExists = errors.New("vertex already exists")
@ -63,9 +64,6 @@ type DAG[T comparable] interface {
// GetRandomVertices returns random map of vertices. // GetRandomVertices returns random map of vertices.
GetRandomVertices(n uint) []*Vertex[T] GetRandomVertices(n uint) []*Vertex[T]
// GetVertexKeys returns keys of vertices.
GetVertexKeys() []string
// GetSourceVertices returns source vertices. // GetSourceVertices returns source vertices.
GetSourceVertices() []*Vertex[T] GetSourceVertices() []*Vertex[T]
@ -73,7 +71,7 @@ type DAG[T comparable] interface {
GetSinkVertices() []*Vertex[T] GetSinkVertices() []*Vertex[T]
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
VertexCount() int VertexCount() uint64
// AddEdge adds edge between two vertices. // AddEdge adds edge between two vertices.
AddEdge(fromVertexID, toVertexID string) error AddEdge(fromVertexID, toVertexID string) error
@ -93,14 +91,17 @@ type DAG[T comparable] interface {
// dag provides directed acyclic graph function. // dag provides directed acyclic graph function.
type dag[T comparable] struct { type dag[T comparable] struct {
vertices *sync.Map
count *atomic.Uint64
mu sync.RWMutex mu sync.RWMutex
vertices cmap.ConcurrentMap[string, *Vertex[T]]
} }
// New returns a new DAG interface. // New returns a new DAG interface.
func NewDAG[T comparable]() DAG[T] { func NewDAG[T comparable]() DAG[T] {
return &dag[T]{ return &dag[T]{
vertices: cmap.New[*Vertex[T]](), vertices: &sync.Map{},
count: atomic.NewUint64(0),
mu: sync.RWMutex{},
} }
} }
@ -109,11 +110,11 @@ func (d *dag[T]) AddVertex(id string, value T) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if _, ok := d.vertices.Get(id); ok { if _, loaded := d.vertices.LoadOrStore(id, NewVertex(id, value)); loaded {
return ErrVertexAlreadyExists return ErrVertexAlreadyExists
} }
d.vertices.Set(id, NewVertex(id, value)) d.count.Inc()
return nil return nil
} }
@ -122,7 +123,12 @@ func (d *dag[T]) DeleteVertex(id string) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) rawVertex, loaded := d.vertices.Load(id)
if !loaded {
return
}
vertex, ok := rawVertex.(*Vertex[T])
if !ok { if !ok {
return return
} }
@ -136,22 +142,47 @@ func (d *dag[T]) DeleteVertex(id string) {
continue continue
} }
d.vertices.Remove(id) d.vertices.Delete(id)
d.count.Dec()
} }
// GetVertex gets vertex from graph. // GetVertex gets vertex from graph.
func (d *dag[T]) GetVertex(id string) (*Vertex[T], error) { func (d *dag[T]) GetVertex(id string) (*Vertex[T], error) {
vertex, ok := d.vertices.Get(id) rawVertex, loaded := d.vertices.Load(id)
if !ok { if !loaded {
return nil, ErrVertexNotFound return nil, ErrVertexNotFound
} }
vertex, ok := rawVertex.(*Vertex[T])
if !ok {
return nil, ErrVertexInvalid
}
return vertex, nil return vertex, nil
} }
// GetVertices returns map of vertices. // GetVertices returns map of vertices.
func (d *dag[T]) GetVertices() map[string]*Vertex[T] { func (d *dag[T]) GetVertices() map[string]*Vertex[T] {
return d.vertices.Items() d.mu.RLock()
defer d.mu.RUnlock()
vertices := make(map[string]*Vertex[T], d.count.Load())
d.vertices.Range(func(key, value interface{}) bool {
vertex, ok := value.(*Vertex[T])
if !ok {
return true
}
id, ok := key.(string)
if !ok {
return true
}
vertices[id] = vertex
return true
})
return vertices
} }
// GetRandomVertices returns random map of vertices. // GetRandomVertices returns random map of vertices.
@ -159,32 +190,27 @@ func (d *dag[T]) GetRandomVertices(n uint) []*Vertex[T] {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
keys := d.GetVertexKeys() if n == 0 {
if int(n) >= len(keys) { return nil
n = uint(len(keys))
} }
r := rand.New(rand.NewSource(time.Now().UnixNano()))
permutation := r.Perm(len(keys))[:n]
randomVertices := make([]*Vertex[T], 0, n) randomVertices := make([]*Vertex[T], 0, n)
for _, v := range permutation { d.vertices.Range(func(key, value interface{}) bool {
key := keys[v] vertex, ok := value.(*Vertex[T])
if vertex, err := d.GetVertex(key); err == nil { if !ok {
randomVertices = append(randomVertices, vertex) return true
} }
}
randomVertices = append(randomVertices, vertex)
return uint(len(randomVertices)) < n
})
return randomVertices return randomVertices
} }
// GetVertexKeys returns keys of vertices.
func (d *dag[T]) GetVertexKeys() []string {
return d.vertices.Keys()
}
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
func (d *dag[T]) VertexCount() int { func (d *dag[T]) VertexCount() uint64 {
return d.vertices.Count() return d.count.Load()
} }
// AddEdge adds edge between two vertices. // AddEdge adds edge between two vertices.
@ -196,14 +222,14 @@ func (d *dag[T]) AddEdge(fromVertexID, toVertexID string) error {
return ErrCycleBetweenVertices return ErrCycleBetweenVertices
} }
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
toVertex, ok := d.vertices.Get(toVertexID) toVertex, err := d.GetVertex(toVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, child := range fromVertex.Children.Values() { for _, child := range fromVertex.Children.Values() {
@ -232,14 +258,14 @@ func (d *dag[T]) DeleteEdge(fromVertexID, toVertexID string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
toVertex, ok := d.vertices.Get(toVertexID) toVertex, err := d.GetVertex(toVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
fromVertex.Children.Delete(toVertex) fromVertex.Children.Delete(toVertex)
@ -256,12 +282,12 @@ func (d *dag[T]) CanAddEdge(fromVertexID, toVertexID string) bool {
return false return false
} }
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return false return false
} }
if _, ok := d.vertices.Get(toVertexID); !ok { if _, err := d.GetVertex(toVertexID); err != nil {
return false return false
} }
@ -283,9 +309,9 @@ func (d *dag[T]) DeleteVertexInEdges(id string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) vertex, err := d.GetVertex(id)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, parent := range vertex.Parents.Values() { for _, parent := range vertex.Parents.Values() {
@ -301,9 +327,9 @@ func (d *dag[T]) DeleteVertexOutEdges(id string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) vertex, err := d.GetVertex(id)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, child := range vertex.Children.Values() { for _, child := range vertex.Children.Values() {
@ -320,7 +346,7 @@ func (d *dag[T]) GetSourceVertices() []*Vertex[T] {
defer d.mu.RUnlock() defer d.mu.RUnlock()
var sourceVertices []*Vertex[T] var sourceVertices []*Vertex[T]
for _, vertex := range d.vertices.Items() { for _, vertex := range d.GetVertices() {
if vertex.InDegree() == 0 { if vertex.InDegree() == 0 {
sourceVertices = append(sourceVertices, vertex) sourceVertices = append(sourceVertices, vertex)
} }
@ -335,7 +361,7 @@ func (d *dag[T]) GetSinkVertices() []*Vertex[T] {
defer d.mu.RUnlock() defer d.mu.RUnlock()
var sinkVertices []*Vertex[T] var sinkVertices []*Vertex[T]
for _, vertex := range d.vertices.Items() { for _, vertex := range d.GetVertices() {
if vertex.OutDegree() == 0 { if vertex.OutDegree() == 0 {
sinkVertices = append(sinkVertices, vertex) sinkVertices = append(sinkVertices, vertex)
} }
@ -354,8 +380,8 @@ func (d *dag[T]) depthFirstSearch(fromVertexID, toVertexID string) bool {
// search finds successors of vertex. // search finds successors of vertex.
func (d *dag[T]) search(vertexID string, successors map[string]struct{}) { func (d *dag[T]) search(vertexID string, successors map[string]struct{}) {
vertex, ok := d.vertices.Get(vertexID) vertex, err := d.GetVertex(vertexID)
if !ok { if err != nil {
return return
} }

View File

@ -179,17 +179,17 @@ func TestDAG_VertexCount(t *testing.T) {
} }
d.VertexCount() d.VertexCount()
assert.Equal(d.VertexCount(), 1) assert.Equal(d.VertexCount(), uint64(1))
d.DeleteVertex(mockVertexID) d.DeleteVertex(mockVertexID)
assert.Equal(d.VertexCount(), 0) assert.Equal(d.VertexCount(), uint64(0))
}, },
}, },
{ {
name: "empty dag", name: "empty dag",
expect: func(t *testing.T, d DAG[string]) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(d.VertexCount(), 0) assert.Equal(d.VertexCount(), uint64(0))
}, },
}, },
} }
@ -302,38 +302,6 @@ func TestDAG_GetRandomVertices(t *testing.T) {
} }
} }
func TestDAG_GetVertexKeys(t *testing.T) {
tests := []struct {
name string
expect func(t *testing.T, d DAG[string])
}{
{
name: "get keys of vertices",
expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err)
}
keys := d.GetVertexKeys()
assert.Equal(len(keys), 1)
assert.Equal(keys[0], mockVertexID)
d.DeleteVertex(mockVertexID)
keys = d.GetVertexKeys()
assert.Equal(len(keys), 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
d := NewDAG[string]()
tc.expect(t, d)
})
}
}
func TestDAG_AddEdge(t *testing.T) { func TestDAG_AddEdge(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -965,7 +933,25 @@ func BenchmarkDAG_DeleteVertex(b *testing.B) {
} }
} }
func BenchmarkDAG_GetRandomKeys(b *testing.B) { func BenchmarkDAG_GetVertices(b *testing.B) {
d := NewDAG[string]()
for n := 0; n < b.N; n++ {
id := fmt.Sprint(n)
if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err)
}
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
vertices := d.GetVertices()
if len(vertices) != b.N {
b.Fatal(errors.New("get vertices failed"))
}
}
}
func BenchmarkDAG_GetRandomVertices(b *testing.B) {
d := NewDAG[string]() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)

View File

@ -191,20 +191,6 @@ func (mr *MockDAGMockRecorder[T]) GetVertex(id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG[T])(nil).GetVertex), id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG[T])(nil).GetVertex), id)
} }
// GetVertexKeys mocks base method.
func (m *MockDAG[T]) GetVertexKeys() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVertexKeys")
ret0, _ := ret[0].([]string)
return ret0
}
// GetVertexKeys indicates an expected call of GetVertexKeys.
func (mr *MockDAGMockRecorder[T]) GetVertexKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDAG[T])(nil).GetVertexKeys))
}
// GetVertices mocks base method. // GetVertices mocks base method.
func (m *MockDAG[T]) GetVertices() map[string]*dag.Vertex[T] { func (m *MockDAG[T]) GetVertices() map[string]*dag.Vertex[T] {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -220,10 +206,10 @@ func (mr *MockDAGMockRecorder[T]) GetVertices() *gomock.Call {
} }
// VertexCount mocks base method. // VertexCount mocks base method.
func (m *MockDAG[T]) VertexCount() int { func (m *MockDAG[T]) VertexCount() uint64 {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VertexCount") ret := m.ctrl.Call(m, "VertexCount")
ret0, _ := ret[0].(int) ret0, _ := ret[0].(uint64)
return ret0 return ret0
} }

View File

@ -24,8 +24,8 @@ import (
type Vertex[T comparable] struct { type Vertex[T comparable] struct {
ID string ID string
Value T Value T
Parents set.SafeSet[*Vertex[T]] Parents set.Set[*Vertex[T]]
Children set.SafeSet[*Vertex[T]] Children set.Set[*Vertex[T]]
} }
// New returns a new Vertex instance. // New returns a new Vertex instance.
@ -33,8 +33,8 @@ func NewVertex[T comparable](id string, value T) *Vertex[T] {
return &Vertex[T]{ return &Vertex[T]{
ID: id, ID: id,
Value: value, Value: value,
Parents: set.NewSafeSet[*Vertex[T]](), Parents: set.New[*Vertex[T]](),
Children: set.NewSafeSet[*Vertex[T]](), Children: set.New[*Vertex[T]](),
} }
} }

View File

@ -20,11 +20,9 @@ package dg
import ( import (
"errors" "errors"
"math/rand"
"sync" "sync"
"time"
cmap "github.com/orcaman/concurrent-map/v2" "go.uber.org/atomic"
"d7y.io/dragonfly/v2/pkg/container/set" "d7y.io/dragonfly/v2/pkg/container/set"
) )
@ -33,6 +31,9 @@ var (
// ErrVertexNotFound represents vertex not found. // ErrVertexNotFound represents vertex not found.
ErrVertexNotFound = errors.New("vertex not found") ErrVertexNotFound = errors.New("vertex not found")
// ErrVertexInvalid represents vertex invalid.
ErrVertexInvalid = errors.New("vertex invalid")
// ErrVertexAlreadyExists represents vertex already exists. // ErrVertexAlreadyExists represents vertex already exists.
ErrVertexAlreadyExists = errors.New("vertex already exists") ErrVertexAlreadyExists = errors.New("vertex already exists")
@ -63,9 +64,6 @@ type DG[T comparable] interface {
// GetRandomVertices returns random map of vertices. // GetRandomVertices returns random map of vertices.
GetRandomVertices(n uint) []*Vertex[T] GetRandomVertices(n uint) []*Vertex[T]
// GetVertexKeys returns keys of vertices.
GetVertexKeys() []string
// GetSourceVertices returns source vertices. // GetSourceVertices returns source vertices.
GetSourceVertices() []*Vertex[T] GetSourceVertices() []*Vertex[T]
@ -73,7 +71,7 @@ type DG[T comparable] interface {
GetSinkVertices() []*Vertex[T] GetSinkVertices() []*Vertex[T]
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
VertexCount() int VertexCount() uint64
// AddEdge adds edge between two vertices. // AddEdge adds edge between two vertices.
AddEdge(fromVertexID, toVertexID string) error AddEdge(fromVertexID, toVertexID string) error
@ -81,7 +79,7 @@ type DG[T comparable] interface {
// DeleteEdge deletes edge between two vertices. // DeleteEdge deletes edge between two vertices.
DeleteEdge(fromVertexID, toVertexID string) error DeleteEdge(fromVertexID, toVertexID string) error
// CanAddEdge finds whether there are circles through depth-first search. // CanAddEdge indicates whether can add edge between two vertices.
CanAddEdge(fromVertexID, toVertexID string) bool CanAddEdge(fromVertexID, toVertexID string) bool
// DeleteVertexInEdges deletes inedges of vertex. // DeleteVertexInEdges deletes inedges of vertex.
@ -93,14 +91,17 @@ type DG[T comparable] interface {
// dg provides directed graph function. // dg provides directed graph function.
type dg[T comparable] struct { type dg[T comparable] struct {
vertices *sync.Map
count *atomic.Uint64
mu sync.RWMutex mu sync.RWMutex
vertices cmap.ConcurrentMap[string, *Vertex[T]]
} }
// New returns a new DG interface. // New returns a new DG interface.
func NewDG[T comparable]() DG[T] { func NewDG[T comparable]() DG[T] {
return &dg[T]{ return &dg[T]{
vertices: cmap.New[*Vertex[T]](), vertices: &sync.Map{},
count: atomic.NewUint64(0),
mu: sync.RWMutex{},
} }
} }
@ -109,11 +110,11 @@ func (d *dg[T]) AddVertex(id string, value T) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if _, ok := d.vertices.Get(id); ok { if _, loaded := d.vertices.LoadOrStore(id, NewVertex(id, value)); loaded {
return ErrVertexAlreadyExists return ErrVertexAlreadyExists
} }
d.vertices.Set(id, NewVertex(id, value)) d.count.Inc()
return nil return nil
} }
@ -122,7 +123,12 @@ func (d *dg[T]) DeleteVertex(id string) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) rawVertex, loaded := d.vertices.Load(id)
if !loaded {
return
}
vertex, ok := rawVertex.(*Vertex[T])
if !ok { if !ok {
return return
} }
@ -136,22 +142,47 @@ func (d *dg[T]) DeleteVertex(id string) {
continue continue
} }
d.vertices.Remove(id) d.vertices.Delete(id)
d.count.Dec()
} }
// GetVertex gets vertex from graph. // GetVertex gets vertex from graph.
func (d *dg[T]) GetVertex(id string) (*Vertex[T], error) { func (d *dg[T]) GetVertex(id string) (*Vertex[T], error) {
vertex, ok := d.vertices.Get(id) rawVertex, loaded := d.vertices.Load(id)
if !ok { if !loaded {
return nil, ErrVertexNotFound return nil, ErrVertexNotFound
} }
vertex, ok := rawVertex.(*Vertex[T])
if !ok {
return nil, ErrVertexInvalid
}
return vertex, nil return vertex, nil
} }
// GetVertices returns map of vertices. // GetVertices returns map of vertices.
func (d *dg[T]) GetVertices() map[string]*Vertex[T] { func (d *dg[T]) GetVertices() map[string]*Vertex[T] {
return d.vertices.Items() d.mu.RLock()
defer d.mu.RUnlock()
vertices := make(map[string]*Vertex[T], d.count.Load())
d.vertices.Range(func(key, value interface{}) bool {
vertex, ok := value.(*Vertex[T])
if !ok {
return true
}
id, ok := key.(string)
if !ok {
return true
}
vertices[id] = vertex
return true
})
return vertices
} }
// GetRandomVertices returns random map of vertices. // GetRandomVertices returns random map of vertices.
@ -159,32 +190,27 @@ func (d *dg[T]) GetRandomVertices(n uint) []*Vertex[T] {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
keys := d.GetVertexKeys() if n == 0 {
if int(n) >= len(keys) { return nil
n = uint(len(keys))
} }
r := rand.New(rand.NewSource(time.Now().UnixNano()))
permutation := r.Perm(len(keys))[:n]
randomVertices := make([]*Vertex[T], 0, n) randomVertices := make([]*Vertex[T], 0, n)
for _, v := range permutation { d.vertices.Range(func(key, value interface{}) bool {
key := keys[v] vertex, ok := value.(*Vertex[T])
if vertex, err := d.GetVertex(key); err == nil { if !ok {
randomVertices = append(randomVertices, vertex) return true
} }
}
randomVertices = append(randomVertices, vertex)
return uint(len(randomVertices)) < n
})
return randomVertices return randomVertices
} }
// GetVertexKeys returns keys of vertices.
func (d *dg[T]) GetVertexKeys() []string {
return d.vertices.Keys()
}
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
func (d *dg[T]) VertexCount() int { func (d *dg[T]) VertexCount() uint64 {
return d.vertices.Count() return d.count.Load()
} }
// AddEdge adds edge between two vertices. // AddEdge adds edge between two vertices.
@ -196,14 +222,14 @@ func (d *dg[T]) AddEdge(fromVertexID, toVertexID string) error {
return ErrCycleBetweenVertices return ErrCycleBetweenVertices
} }
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
toVertex, ok := d.vertices.Get(toVertexID) toVertex, err := d.GetVertex(toVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, child := range fromVertex.Children.Values() { for _, child := range fromVertex.Children.Values() {
@ -228,14 +254,14 @@ func (d *dg[T]) DeleteEdge(fromVertexID, toVertexID string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
toVertex, ok := d.vertices.Get(toVertexID) toVertex, err := d.GetVertex(toVertexID)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
fromVertex.Children.Delete(toVertex) fromVertex.Children.Delete(toVertex)
@ -243,7 +269,7 @@ func (d *dg[T]) DeleteEdge(fromVertexID, toVertexID string) error {
return nil return nil
} }
// CanAddEdge finds whether there are circles through depth-first search. // CanAddEdge indicates whether can add edge between two vertices.
func (d *dg[T]) CanAddEdge(fromVertexID, toVertexID string) bool { func (d *dg[T]) CanAddEdge(fromVertexID, toVertexID string) bool {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
@ -252,12 +278,12 @@ func (d *dg[T]) CanAddEdge(fromVertexID, toVertexID string) bool {
return false return false
} }
fromVertex, ok := d.vertices.Get(fromVertexID) fromVertex, err := d.GetVertex(fromVertexID)
if !ok { if err != nil {
return false return false
} }
if _, ok := d.vertices.Get(toVertexID); !ok { if _, err := d.GetVertex(toVertexID); err != nil {
return false return false
} }
@ -275,9 +301,9 @@ func (d *dg[T]) DeleteVertexInEdges(id string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) vertex, err := d.GetVertex(id)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, parent := range vertex.Parents.Values() { for _, parent := range vertex.Parents.Values() {
@ -293,9 +319,9 @@ func (d *dg[T]) DeleteVertexOutEdges(id string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices.Get(id) vertex, err := d.GetVertex(id)
if !ok { if err != nil {
return ErrVertexNotFound return err
} }
for _, child := range vertex.Children.Values() { for _, child := range vertex.Children.Values() {
@ -312,7 +338,7 @@ func (d *dg[T]) GetSourceVertices() []*Vertex[T] {
defer d.mu.RUnlock() defer d.mu.RUnlock()
var sourceVertices []*Vertex[T] var sourceVertices []*Vertex[T]
for _, vertex := range d.vertices.Items() { for _, vertex := range d.GetVertices() {
if vertex.InDegree() == 0 { if vertex.InDegree() == 0 {
sourceVertices = append(sourceVertices, vertex) sourceVertices = append(sourceVertices, vertex)
} }
@ -327,7 +353,7 @@ func (d *dg[T]) GetSinkVertices() []*Vertex[T] {
defer d.mu.RUnlock() defer d.mu.RUnlock()
var sinkVertices []*Vertex[T] var sinkVertices []*Vertex[T]
for _, vertex := range d.vertices.Items() { for _, vertex := range d.GetVertices() {
if vertex.OutDegree() == 0 { if vertex.OutDegree() == 0 {
sinkVertices = append(sinkVertices, vertex) sinkVertices = append(sinkVertices, vertex)
} }
@ -335,18 +361,3 @@ func (d *dg[T]) GetSinkVertices() []*Vertex[T] {
return sinkVertices return sinkVertices
} }
// search finds successors of vertex.
func (d *dg[T]) search(vertexID string, successors map[string]struct{}) {
vertex, ok := d.vertices.Get(vertexID)
if !ok {
return
}
for _, child := range vertex.Children.Values() {
if _, ok := successors[child.ID]; !ok {
successors[child.ID] = struct{}{}
d.search(child.ID, successors)
}
}
}

View File

@ -179,17 +179,17 @@ func TestDG_VertexCount(t *testing.T) {
} }
d.VertexCount() d.VertexCount()
assert.Equal(d.VertexCount(), 1) assert.Equal(d.VertexCount(), uint64(1))
d.DeleteVertex(mockVertexID) d.DeleteVertex(mockVertexID)
assert.Equal(d.VertexCount(), 0) assert.Equal(d.VertexCount(), uint64(0))
}, },
}, },
{ {
name: "empty dg", name: "empty dg",
expect: func(t *testing.T, d DG[string]) { expect: func(t *testing.T, d DG[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(d.VertexCount(), 0) assert.Equal(d.VertexCount(), uint64(0))
}, },
}, },
} }
@ -299,38 +299,6 @@ func TestDG_GetRandomVertices(t *testing.T) {
} }
} }
func TestDG_GetVertexKeys(t *testing.T) {
tests := []struct {
name string
expect func(t *testing.T, d DG[string])
}{
{
name: "get keys of vertices",
expect: func(t *testing.T, d DG[string]) {
assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err)
}
keys := d.GetVertexKeys()
assert.Equal(len(keys), 1)
assert.Equal(keys[0], mockVertexID)
d.DeleteVertex(mockVertexID)
keys = d.GetVertexKeys()
assert.Equal(len(keys), 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
d := NewDG[string]()
tc.expect(t, d)
})
}
}
func TestDG_AddEdge(t *testing.T) { func TestDG_AddEdge(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -897,7 +865,25 @@ func BenchmarkDG_DeleteVertex(b *testing.B) {
} }
} }
func BenchmarkDG_GetRandomKeys(b *testing.B) { func BenchmarkDAG_GetVertices(b *testing.B) {
d := NewDG[string]()
for n := 0; n < b.N; n++ {
id := fmt.Sprint(n)
if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err)
}
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
vertices := d.GetVertices()
if len(vertices) != b.N {
b.Fatal(errors.New("get vertices failed"))
}
}
}
func BenchmarkDG_GetRandomVertices(b *testing.B) {
d := NewDG[string]() d := NewDG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)

View File

@ -191,20 +191,6 @@ func (mr *MockDGMockRecorder[T]) GetVertex(id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDG[T])(nil).GetVertex), id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDG[T])(nil).GetVertex), id)
} }
// GetVertexKeys mocks base method.
func (m *MockDG[T]) GetVertexKeys() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVertexKeys")
ret0, _ := ret[0].([]string)
return ret0
}
// GetVertexKeys indicates an expected call of GetVertexKeys.
func (mr *MockDGMockRecorder[T]) GetVertexKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDG[T])(nil).GetVertexKeys))
}
// GetVertices mocks base method. // GetVertices mocks base method.
func (m *MockDG[T]) GetVertices() map[string]*dg.Vertex[T] { func (m *MockDG[T]) GetVertices() map[string]*dg.Vertex[T] {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -220,10 +206,10 @@ func (mr *MockDGMockRecorder[T]) GetVertices() *gomock.Call {
} }
// VertexCount mocks base method. // VertexCount mocks base method.
func (m *MockDG[T]) VertexCount() int { func (m *MockDG[T]) VertexCount() uint64 {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VertexCount") ret := m.ctrl.Call(m, "VertexCount")
ret0, _ := ret[0].(int) ret0, _ := ret[0].(uint64)
return ret0 return ret0
} }

View File

@ -24,8 +24,8 @@ import (
type Vertex[T comparable] struct { type Vertex[T comparable] struct {
ID string ID string
Value T Value T
Parents set.SafeSet[*Vertex[T]] Parents set.Set[*Vertex[T]]
Children set.SafeSet[*Vertex[T]] Children set.Set[*Vertex[T]]
} }
// New returns a new Vertex instance. // New returns a new Vertex instance.
@ -33,8 +33,8 @@ func NewVertex[T comparable](id string, value T) *Vertex[T] {
return &Vertex[T]{ return &Vertex[T]{
ID: id, ID: id,
Value: value, Value: value,
Parents: set.NewSafeSet[*Vertex[T]](), Parents: set.New[*Vertex[T]](),
Children: set.NewSafeSet[*Vertex[T]](), Children: set.New[*Vertex[T]](),
} }
} }

View File

@ -269,7 +269,7 @@ func (t *Task) DeletePeer(key string) {
// PeerCount returns count of peer. // PeerCount returns count of peer.
func (t *Task) PeerCount() int { func (t *Task) PeerCount() int {
return t.DAG.VertexCount() return int(t.DAG.VertexCount())
} }
// AddPeerEdge adds inedges between two peers. // AddPeerEdge adds inedges between two peers.