refactor: set and dag with generics (#1490)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2022-07-26 21:17:51 +08:00
parent 1d7c87627c
commit 7c2ee7858b
No known key found for this signature in database
GPG Key ID: 8B4E5D1290FA2FFB
26 changed files with 630 additions and 2629 deletions

3
go.mod
View File

@ -45,6 +45,7 @@ require (
github.com/montanaflynn/stats v0.6.6 github.com/montanaflynn/stats v0.6.6
github.com/onsi/ginkgo/v2 v2.1.4 github.com/onsi/ginkgo/v2 v2.1.4
github.com/onsi/gomega v1.19.0 github.com/onsi/gomega v1.19.0
github.com/orcaman/concurrent-map/v2 v2.0.0
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5
github.com/prometheus/client_golang v1.12.2 github.com/prometheus/client_golang v1.12.2
github.com/schollz/progressbar/v3 v3.8.6 github.com/schollz/progressbar/v3 v3.8.6
@ -77,6 +78,7 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.3.4 gorm.io/driver/mysql v1.3.4
gorm.io/driver/postgres v1.3.7
gorm.io/gorm v1.23.6 gorm.io/gorm v1.23.6
gorm.io/plugin/soft_delete v1.1.0 gorm.io/plugin/soft_delete v1.1.0
k8s.io/apimachinery v0.24.2 k8s.io/apimachinery v0.24.2
@ -201,7 +203,6 @@ require (
google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03 // indirect google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03 // indirect
gopkg.in/ini.v1 v1.66.6 // indirect gopkg.in/ini.v1 v1.66.6 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gorm.io/driver/postgres v1.3.7 // indirect
gorm.io/driver/sqlserver v1.3.2 // indirect gorm.io/driver/sqlserver v1.3.2 // indirect
gorm.io/plugin/dbresolver v1.2.1 // indirect gorm.io/plugin/dbresolver v1.2.1 // indirect
k8s.io/klog/v2 v2.60.1 // indirect k8s.io/klog/v2 v2.60.1 // indirect

2
go.sum
View File

@ -822,6 +822,8 @@ github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxS
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4=
github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4=
github.com/orcaman/concurrent-map/v2 v2.0.0 h1:iSMwuBQvQ1nX5i9gYuGMiSy0fjWHmazdjF+NdSO9JzI=
github.com/orcaman/concurrent-map/v2 v2.0.0/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
github.com/otiai10/copy v1.7.0/go.mod h1:rmRl6QPdJj6EiUqXQ/4Nn2lLXoNQjFCQbbNrxgc/t3U= github.com/otiai10/copy v1.7.0/go.mod h1:rmRl6QPdJj6EiUqXQ/4Nn2lLXoNQjFCQbbNrxgc/t3U=
github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE=
github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs=

View File

@ -1,137 +0,0 @@
/*
* Copyright 2020 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//go:generate mockgen -destination sorted_list_mock.go -source sorted_list.go -package list
package list
import (
"container/list"
"sync"
)
type Item interface {
SortedValue() int
}
type SortedList interface {
Len() int
Insert(Item)
Remove(Item)
Contains(Item) bool
Range(func(Item) bool)
ReverseRange(fn func(Item) bool)
}
type sortedList struct {
mu *sync.RWMutex
container *list.List
}
func NewSortedList() SortedList {
return &sortedList{
mu: &sync.RWMutex{},
container: list.New(),
}
}
func (l *sortedList) Len() int {
l.mu.RLock()
defer l.mu.RUnlock()
return l.container.Len()
}
func (l *sortedList) Insert(item Item) {
l.mu.Lock()
defer l.mu.Unlock()
for e := l.container.Front(); e != nil; e = e.Next() {
v, ok := e.Value.(Item)
if !ok {
continue
}
if v.SortedValue() >= item.SortedValue() {
l.container.InsertBefore(item, e)
return
}
}
l.container.PushBack(item)
}
func (l *sortedList) Remove(item Item) {
l.mu.Lock()
defer l.mu.Unlock()
for e := l.container.Front(); e != nil; e = e.Next() {
v, ok := e.Value.(Item)
if !ok {
continue
}
if v == item {
l.container.Remove(e)
return
}
}
}
func (l *sortedList) Contains(item Item) bool {
l.mu.RLock()
defer l.mu.RUnlock()
for e := l.container.Front(); e != nil; e = e.Next() {
if v, ok := e.Value.(Item); ok && v == item {
return true
}
}
return false
}
func (l *sortedList) Range(fn func(Item) bool) {
l.mu.RLock()
defer l.mu.RUnlock()
for e := l.container.Front(); e != nil; e = e.Next() {
v, ok := e.Value.(Item)
if !ok {
continue
}
if !fn(v) {
return
}
}
}
func (l *sortedList) ReverseRange(fn func(Item) bool) {
l.mu.RLock()
defer l.mu.RUnlock()
for e := l.container.Back(); e != nil; e = e.Prev() {
v, ok := e.Value.(Item)
if !ok {
continue
}
if !fn(v) {
return
}
}
}

View File

@ -1,147 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sorted_list.go
// Package list is a generated GoMock package.
package list
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockItem is a mock of Item interface.
type MockItem struct {
ctrl *gomock.Controller
recorder *MockItemMockRecorder
}
// MockItemMockRecorder is the mock recorder for MockItem.
type MockItemMockRecorder struct {
mock *MockItem
}
// NewMockItem creates a new mock instance.
func NewMockItem(ctrl *gomock.Controller) *MockItem {
mock := &MockItem{ctrl: ctrl}
mock.recorder = &MockItemMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockItem) EXPECT() *MockItemMockRecorder {
return m.recorder
}
// SortedValue mocks base method.
func (m *MockItem) SortedValue() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SortedValue")
ret0, _ := ret[0].(int)
return ret0
}
// SortedValue indicates an expected call of SortedValue.
func (mr *MockItemMockRecorder) SortedValue() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SortedValue", reflect.TypeOf((*MockItem)(nil).SortedValue))
}
// MockSortedList is a mock of SortedList interface.
type MockSortedList struct {
ctrl *gomock.Controller
recorder *MockSortedListMockRecorder
}
// MockSortedListMockRecorder is the mock recorder for MockSortedList.
type MockSortedListMockRecorder struct {
mock *MockSortedList
}
// NewMockSortedList creates a new mock instance.
func NewMockSortedList(ctrl *gomock.Controller) *MockSortedList {
mock := &MockSortedList{ctrl: ctrl}
mock.recorder = &MockSortedListMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSortedList) EXPECT() *MockSortedListMockRecorder {
return m.recorder
}
// Contains mocks base method.
func (m *MockSortedList) Contains(arg0 Item) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Contains", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Contains indicates an expected call of Contains.
func (mr *MockSortedListMockRecorder) Contains(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSortedList)(nil).Contains), arg0)
}
// Insert mocks base method.
func (m *MockSortedList) Insert(arg0 Item) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Insert", arg0)
}
// Insert indicates an expected call of Insert.
func (mr *MockSortedListMockRecorder) Insert(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSortedList)(nil).Insert), arg0)
}
// Len mocks base method.
func (m *MockSortedList) Len() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(int)
return ret0
}
// Len indicates an expected call of Len.
func (mr *MockSortedListMockRecorder) Len() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSortedList)(nil).Len))
}
// Range mocks base method.
func (m *MockSortedList) Range(arg0 func(Item) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Range", arg0)
}
// Range indicates an expected call of Range.
func (mr *MockSortedListMockRecorder) Range(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Range", reflect.TypeOf((*MockSortedList)(nil).Range), arg0)
}
// Remove mocks base method.
func (m *MockSortedList) Remove(arg0 Item) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0)
}
// Remove indicates an expected call of Remove.
func (mr *MockSortedListMockRecorder) Remove(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSortedList)(nil).Remove), arg0)
}
// ReverseRange mocks base method.
func (m *MockSortedList) ReverseRange(fn func(Item) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReverseRange", fn)
}
// ReverseRange indicates an expected call of ReverseRange.
func (mr *MockSortedListMockRecorder) ReverseRange(fn interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReverseRange", reflect.TypeOf((*MockSortedList)(nil).ReverseRange), fn)
}

View File

@ -1,764 +0,0 @@
/*
* Copyright 2020 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package list
import (
"math/rand"
"runtime"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
const N = 1000
func TestSortedListInsert(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "insert values",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Len(), 1)
},
},
{
name: "insert multi value",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Contains(items[1]), true)
assert.Equal(l.Len(), 2)
},
},
{
name: "insert same values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(2),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[0])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Len(), 2)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedListInsert_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
nums := rand.Perm(N)
var wg sync.WaitGroup
wg.Add(len(nums))
for i := 0; i < len(nums); i++ {
go func(i int) {
l.Insert(mockItem)
wg.Done()
}(i)
}
wg.Wait()
count := 0
l.Range(func(item Item) bool {
count++
return true
})
if count != len(nums) {
t.Errorf("SortedList is missing element")
}
}
func TestSortedListRemove(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "remove values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Contains(items[1]), true)
assert.Equal(l.Len(), 2)
l.Remove(items[0])
assert.Equal(l.Contains(items[0]), false)
assert.Equal(l.Len(), 1)
l.Remove(items[1])
assert.Equal(l.Contains(items[1]), false)
assert.Equal(l.Len(), 0)
},
},
{
name: "remove value dost not exits",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Len(), 1)
l.Remove(items[1])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Len(), 1)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedListRemove_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
nums := rand.Perm(N)
for i := 0; i < len(nums); i++ {
l.Insert(mockItem)
}
var wg sync.WaitGroup
wg.Add(len(nums))
for i := 0; i < len(nums); i++ {
go func(i int) {
l.Remove(mockItem)
wg.Done()
}(i)
}
wg.Wait()
if l.Len() != 0 {
t.Errorf("SortedList is redundant elements")
}
}
func TestSortedListContains(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "contains values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Contains(items[0]), true)
assert.Equal(l.Contains(items[1]), true)
},
},
{
name: "contains value dost not exits",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
assert.Equal(l.Contains(items[1]), false)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedListContains_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
nums := rand.Perm(N)
for range nums {
l.Insert(mockItem)
}
var wg sync.WaitGroup
for range nums {
wg.Add(1)
go func() {
if ok := l.Contains(mockItem); !ok {
t.Error("SortedList contains error")
}
wg.Done()
}()
}
wg.Wait()
}
func TestSortedListLen(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "get length",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Len(), 2)
},
},
{
name: "get empty list length",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
assert.Equal(l.Len(), 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedListLen_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
var wg sync.WaitGroup
wg.Add(1)
go func() {
elems := l.Len()
for i := 0; i < N; i++ {
newElems := l.Len()
if newElems < elems {
t.Errorf("Len shrunk from %v to %v", elems, newElems)
}
}
wg.Done()
}()
for i := 0; i < N; i++ {
l.Insert(mockItem)
}
wg.Wait()
}
func TestSortedListRange(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "range values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Len(), 2)
i := 0
l.Range(func(item Item) bool {
assert.Equal(item, items[i])
i++
return true
})
},
},
{
name: "range multi values",
mock: func(m ...*MockItemMockRecorder) {
for i := range m {
m[i].SortedValue().Return(i).AnyTimes()
}
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
for _, item := range items {
l.Insert(item)
}
assert.Equal(l.Len(), 10)
i := 0
l.Range(func(item Item) bool {
assert.Equal(item, items[i])
i++
return true
})
},
},
{
name: "range stoped",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Len(), 2)
l.Range(func(item Item) bool {
assert.Equal(item, items[0])
return false
})
},
},
{
name: "range same values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).AnyTimes(),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[0])
l.Insert(items[0])
assert.Equal(l.Len(), 3)
count := 0
l.Range(func(item Item) bool {
assert.Equal(item, items[0])
count++
return true
})
assert.Equal(count, 3)
},
},
{
name: "range empty list",
mock: func(m ...*MockItemMockRecorder) {
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
count := 0
l.Range(func(item Item) bool {
count++
return true
})
assert.Equal(count, 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
var mockItems []Item
var mockItemRecorders []*MockItemMockRecorder
for i := 0; i < 10; i++ {
mockItem := NewMockItem(ctl)
mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT())
mockItems = append(mockItems, mockItem)
}
tc.mock(mockItemRecorders...)
tc.expect(t, NewSortedList(), mockItems...)
})
}
}
func TestSortedListRange_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
var wg sync.WaitGroup
wg.Add(1)
go func() {
i := 0
l.Range(func(_ Item) bool {
i++
return true
})
j := 0
l.Range(func(_ Item) bool {
j++
return true
})
if j < i {
t.Errorf("Values shrunk from %v to %v", i, j)
}
wg.Done()
}()
for i := 0; i < N; i++ {
l.Insert(mockItem)
}
wg.Wait()
}
func TestSortedListReverseRange(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, l SortedList, items ...Item)
}{
{
name: "reverse range values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Len(), 2)
i := 0
l.ReverseRange(func(item Item) bool {
assert.Equal(item, items[i])
i++
return true
})
},
},
{
name: "reverse range multi values",
mock: func(m ...*MockItemMockRecorder) {
for i := range m {
m[i].SortedValue().Return(i).AnyTimes()
}
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
for _, item := range items {
l.Insert(item)
}
assert.Equal(l.Len(), 10)
i := 9
l.ReverseRange(func(item Item) bool {
assert.Equal(item, items[i])
i--
return true
})
},
},
{
name: "reverse range stoped",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[1])
assert.Equal(l.Len(), 2)
l.ReverseRange(func(item Item) bool {
assert.Equal(item, items[1])
return false
})
},
},
{
name: "reverse range same values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).AnyTimes(),
)
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
l.Insert(items[0])
l.Insert(items[0])
l.Insert(items[0])
assert.Equal(l.Len(), 3)
count := 0
l.ReverseRange(func(item Item) bool {
assert.Equal(item, items[0])
count++
return true
})
assert.Equal(count, 3)
},
},
{
name: "reverse range empty list",
mock: func(m ...*MockItemMockRecorder) {
},
expect: func(t *testing.T, l SortedList, items ...Item) {
assert := assert.New(t)
count := 0
l.ReverseRange(func(item Item) bool {
count++
return true
})
assert.Equal(count, 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
var mockItems []Item
var mockItemRecorders []*MockItemMockRecorder
for i := 0; i < 10; i++ {
mockItem := NewMockItem(ctl)
mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT())
mockItems = append(mockItems, mockItem)
}
tc.mock(mockItemRecorders...)
tc.expect(t, NewSortedList(), mockItems...)
})
}
}
func TestSortedListReverseRange_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes()
l := NewSortedList()
var wg sync.WaitGroup
wg.Add(1)
go func() {
i := 0
l.ReverseRange(func(_ Item) bool {
i++
return true
})
j := 0
l.ReverseRange(func(_ Item) bool {
j++
return true
})
if j < i {
t.Errorf("Values shrunk from %v to %v", i, j)
}
wg.Done()
}()
for i := 0; i < N; i++ {
l.Insert(mockItem)
}
wg.Wait()
}
type item struct{ id int }
func (i *item) SortedValue() int { return rand.Intn(1000) }
func BenchmarkSortedListInsert(b *testing.B) {
l := NewSortedList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
b.ResetTimer()
for _, mockItem := range mockItems {
l.Insert(mockItem)
}
}
func BenchmarkSortedListRemove(b *testing.B) {
l := NewSortedList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
for _, mockItem := range mockItems {
l.Insert(mockItem)
}
b.ResetTimer()
for _, mockItem := range mockItems {
l.Remove(mockItem)
}
}
func BenchmarkSortedListContains(b *testing.B) {
l := NewSortedList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
for _, mockItem := range mockItems {
l.Insert(mockItem)
}
b.ResetTimer()
for _, mockItem := range mockItems {
l.Contains(mockItem)
}
}
func BenchmarkSortedListRange(b *testing.B) {
l := NewSortedList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
for _, mockItem := range mockItems {
l.Insert(mockItem)
}
b.ResetTimer()
l.Range(func(_ Item) bool { return true })
}
func BenchmarkSortedListReverseRange(b *testing.B) {
l := NewSortedList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
for _, mockItem := range mockItems {
l.Insert(mockItem)
}
b.ResetTimer()
l.ReverseRange(func(_ Item) bool { return true })
}

View File

@ -1,108 +0,0 @@
/*
* Copyright 2020 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//go:generate mockgen -destination sorted_unique_list_mock.go -source sorted_unique_list.go -package list
package list
import (
"sync"
"d7y.io/dragonfly/v2/pkg/container/set"
)
type SortedUniqueList interface {
Len() int
Insert(Item)
Remove(Item)
Contains(Item) bool
Range(func(Item) bool)
ReverseRange(fn func(Item) bool)
}
type sortedUniqueList struct {
mu *sync.RWMutex
container SortedList
data set.Set
}
func NewSortedUniqueList() SortedUniqueList {
return &sortedUniqueList{
mu: &sync.RWMutex{},
container: NewSortedList(),
data: set.New(),
}
}
func (ul *sortedUniqueList) Len() int {
ul.mu.RLock()
defer ul.mu.RUnlock()
return ul.container.Len()
}
func (ul *sortedUniqueList) Insert(item Item) {
ul.mu.Lock()
defer ul.mu.Unlock()
if ok := ul.data.Contains(item); ok {
ul.container.Remove(item)
ul.container.Insert(item)
return
}
ul.data.Add(item)
ul.container.Insert(item)
}
func (ul *sortedUniqueList) Remove(item Item) {
ul.mu.Lock()
defer ul.mu.Unlock()
ul.data.Delete(item)
ul.container.Remove(item)
}
func (ul *sortedUniqueList) Contains(item Item) bool {
ul.mu.RLock()
defer ul.mu.RUnlock()
return ul.data.Contains(item)
}
func (ul *sortedUniqueList) Range(fn func(item Item) bool) {
ul.mu.RLock()
defer ul.mu.RUnlock()
ul.container.Range(func(item Item) bool {
if !fn(item) {
return false
}
return true
})
}
func (ul *sortedUniqueList) ReverseRange(fn func(item Item) bool) {
ul.mu.RLock()
defer ul.mu.RUnlock()
ul.container.ReverseRange(func(item Item) bool {
if !fn(item) {
return false
}
return true
})
}

View File

@ -1,110 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: sorted_unique_list.go
// Package list is a generated GoMock package.
package list
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockSortedUniqueList is a mock of SortedUniqueList interface.
type MockSortedUniqueList struct {
ctrl *gomock.Controller
recorder *MockSortedUniqueListMockRecorder
}
// MockSortedUniqueListMockRecorder is the mock recorder for MockSortedUniqueList.
type MockSortedUniqueListMockRecorder struct {
mock *MockSortedUniqueList
}
// NewMockSortedUniqueList creates a new mock instance.
func NewMockSortedUniqueList(ctrl *gomock.Controller) *MockSortedUniqueList {
mock := &MockSortedUniqueList{ctrl: ctrl}
mock.recorder = &MockSortedUniqueListMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSortedUniqueList) EXPECT() *MockSortedUniqueListMockRecorder {
return m.recorder
}
// Contains mocks base method.
func (m *MockSortedUniqueList) Contains(arg0 Item) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Contains", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Contains indicates an expected call of Contains.
func (mr *MockSortedUniqueListMockRecorder) Contains(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSortedUniqueList)(nil).Contains), arg0)
}
// Insert mocks base method.
func (m *MockSortedUniqueList) Insert(arg0 Item) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Insert", arg0)
}
// Insert indicates an expected call of Insert.
func (mr *MockSortedUniqueListMockRecorder) Insert(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSortedUniqueList)(nil).Insert), arg0)
}
// Len mocks base method.
func (m *MockSortedUniqueList) Len() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(int)
return ret0
}
// Len indicates an expected call of Len.
func (mr *MockSortedUniqueListMockRecorder) Len() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSortedUniqueList)(nil).Len))
}
// Range mocks base method.
func (m *MockSortedUniqueList) Range(arg0 func(Item) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Range", arg0)
}
// Range indicates an expected call of Range.
func (mr *MockSortedUniqueListMockRecorder) Range(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Range", reflect.TypeOf((*MockSortedUniqueList)(nil).Range), arg0)
}
// Remove mocks base method.
func (m *MockSortedUniqueList) Remove(arg0 Item) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0)
}
// Remove indicates an expected call of Remove.
func (mr *MockSortedUniqueListMockRecorder) Remove(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSortedUniqueList)(nil).Remove), arg0)
}
// ReverseRange mocks base method.
func (m *MockSortedUniqueList) ReverseRange(fn func(Item) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReverseRange", fn)
}
// ReverseRange indicates an expected call of ReverseRange.
func (mr *MockSortedUniqueListMockRecorder) ReverseRange(fn interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReverseRange", reflect.TypeOf((*MockSortedUniqueList)(nil).ReverseRange), fn)
}

View File

@ -1,784 +0,0 @@
/*
* Copyright 2020 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package list
import (
"math/rand"
"runtime"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestSortedUniqueListInsert(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "insert values",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Len(), 1)
},
},
{
name: "insert multi values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Contains(items[1]), true)
assert.Equal(ul.Len(), 2)
},
},
{
name: "insert same values",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[0])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Len(), 1)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedUniqueListInsert_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
}
var wg sync.WaitGroup
wg.Add(len(mockItems))
for _, mockItem := range mockItems {
go func(item Item) {
ul.Insert(item)
wg.Done()
}(mockItem)
}
wg.Wait()
count := 0
ul.Range(func(item Item) bool {
count++
return true
})
if count != len(nums) {
t.Errorf("SortedUniqueList is missing element")
}
}
func TestSortedUniqueListRemove(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "remove values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Contains(items[1]), true)
assert.Equal(ul.Len(), 2)
ul.Remove(items[0])
assert.Equal(ul.Contains(items[0]), false)
assert.Equal(ul.Len(), 1)
ul.Remove(items[1])
assert.Equal(ul.Contains(items[1]), false)
assert.Equal(ul.Len(), 0)
},
},
{
name: "remove value dost not exits",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Len(), 1)
ul.Remove(items[1])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Len(), 1)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedUniqueListRemove_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
ul.Insert(mockItem)
}
var wg sync.WaitGroup
wg.Add(len(mockItems))
for _, mockItem := range mockItems {
go func(item Item) {
ul.Remove(item)
wg.Done()
}(mockItem)
}
wg.Wait()
if ul.Len() != 0 {
t.Errorf("SortedUniqueList is redundant elements")
}
}
func TestSortedUniqueListContains(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "contains values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Contains(items[0]), true)
assert.Equal(ul.Contains(items[1]), true)
},
},
{
name: "contains value dost not exits",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
assert.Equal(ul.Contains(items[1]), false)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedUniqueListContains_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
ul.Insert(mockItem)
}
var wg sync.WaitGroup
wg.Add(len(mockItems))
for _, mockItem := range mockItems {
go func(item Item) {
if ok := ul.Contains(item); !ok {
t.Error("SortedUniqueList contains error")
}
wg.Done()
}(mockItem)
}
wg.Wait()
}
func TestSortedUniqueListLen(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "get length",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Len(), 2)
},
},
{
name: "get empty list length",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
assert.Equal(ul.Len(), 0)
},
},
{
name: "get same values length",
mock: func(m ...*MockItemMockRecorder) {},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[0])
assert.Equal(ul.Len(), 1)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)}
tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT())
tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1])
})
}
}
func TestSortedUniqueListLen_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
elems := ul.Len()
for i := 0; i < N; i++ {
newElems := ul.Len()
if newElems < elems {
t.Errorf("Len shrunk from %v to %v", elems, newElems)
}
}
wg.Done()
}()
for _, mockItem := range mockItems {
ul.Insert(mockItem)
}
wg.Wait()
}
func TestSortedUniqueListRange(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "range values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Len(), 2)
i := 0
ul.Range(func(item Item) bool {
assert.Equal(item, items[i])
i++
return true
})
},
},
{
name: "range multi values",
mock: func(m ...*MockItemMockRecorder) {
for i := range m {
m[i].SortedValue().Return(i).AnyTimes()
}
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
for _, item := range items {
ul.Insert(item)
}
ul.Insert(items[1])
assert.Equal(ul.Len(), 10)
i := 0
ul.Range(func(item Item) bool {
assert.Equal(item, items[i])
i++
return true
})
},
},
{
name: "range stoped",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Len(), 2)
ul.Range(func(item Item) bool {
assert.Equal(item, items[0])
return false
})
},
},
{
name: "range same values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).AnyTimes(),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[0])
ul.Insert(items[0])
assert.Equal(ul.Len(), 1)
count := 0
ul.Range(func(item Item) bool {
assert.Equal(item, items[0])
count++
return true
})
assert.Equal(count, 1)
},
},
{
name: "range empty list",
mock: func(m ...*MockItemMockRecorder) {
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
count := 0
ul.Range(func(item Item) bool {
count++
return true
})
assert.Equal(count, 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
var mockItems []Item
var mockItemRecorders []*MockItemMockRecorder
for i := 0; i < 10; i++ {
mockItem := NewMockItem(ctl)
mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT())
mockItems = append(mockItems, mockItem)
}
tc.mock(mockItemRecorders...)
tc.expect(t, NewSortedUniqueList(), mockItems...)
})
}
}
func TestSortedUniqueListRange_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
i := 0
ul.Range(func(_ Item) bool {
i++
return true
})
j := 0
ul.Range(func(_ Item) bool {
j++
return true
})
if j < i {
t.Errorf("Values shrunk from %v to %v", i, j)
}
wg.Done()
}()
for _, mockItem := range mockItems {
ul.Insert(mockItem)
}
wg.Wait()
}
func TestSortedUniqueListReverseRange(t *testing.T) {
tests := []struct {
name string
mock func(m ...*MockItemMockRecorder)
expect func(t *testing.T, ul SortedUniqueList, items ...Item)
}{
{
name: "reverse range values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Len(), 2)
i := 1
ul.ReverseRange(func(item Item) bool {
assert.Equal(item, items[i])
i--
return true
})
},
},
{
name: "reverse range multi values",
mock: func(m ...*MockItemMockRecorder) {
for i := range m {
m[i].SortedValue().Return(i).AnyTimes()
}
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
for _, item := range items {
ul.Insert(item)
}
ul.Insert(items[1])
assert.Equal(ul.Len(), 10)
i := 9
ul.Range(func(item Item) bool {
assert.Equal(item, items[i])
i--
return true
})
},
},
{
name: "reverse range stoped",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).Times(1),
m[1].SortedValue().Return(1).Times(1),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[1])
assert.Equal(ul.Len(), 2)
ul.ReverseRange(func(item Item) bool {
assert.Equal(item, items[0])
return false
})
},
},
{
name: "reverse range same values",
mock: func(m ...*MockItemMockRecorder) {
gomock.InOrder(
m[0].SortedValue().Return(0).AnyTimes(),
)
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
ul.Insert(items[0])
ul.Insert(items[0])
ul.Insert(items[0])
assert.Equal(ul.Len(), 1)
count := 0
ul.ReverseRange(func(item Item) bool {
assert.Equal(item, items[0])
count++
return true
})
assert.Equal(count, 1)
},
},
{
name: "reverse range empty list",
mock: func(m ...*MockItemMockRecorder) {
},
expect: func(t *testing.T, ul SortedUniqueList, items ...Item) {
assert := assert.New(t)
count := 0
ul.ReverseRange(func(item Item) bool {
count++
return true
})
assert.Equal(count, 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
var mockItems []Item
var mockItemRecorders []*MockItemMockRecorder
for i := 0; i < 10; i++ {
mockItem := NewMockItem(ctl)
mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT())
mockItems = append(mockItems, mockItem)
}
tc.mock(mockItemRecorders...)
tc.expect(t, NewSortedUniqueList(), mockItems...)
})
}
}
func TestSortedUniqueListReverseRange_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
ctl := gomock.NewController(t)
defer ctl.Finish()
ul := NewSortedUniqueList()
nums := rand.Perm(N)
var mockItems []Item
for _, v := range nums {
mockItem := NewMockItem(ctl)
mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes()
mockItems = append(mockItems, mockItem)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
i := 0
ul.ReverseRange(func(_ Item) bool {
i++
return true
})
j := 0
ul.ReverseRange(func(_ Item) bool {
j++
return true
})
if j < i {
t.Errorf("Values shrunk from %v to %v", i, j)
}
wg.Done()
}()
for _, mockItem := range mockItems {
ul.Insert(mockItem)
}
wg.Wait()
}
func BenchmarkSortedUniqueListInsert(b *testing.B) {
ul := NewSortedUniqueList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItems = append(mockItems, &item{id: i})
}
b.ResetTimer()
for _, mockItem := range mockItems {
ul.Insert(mockItem)
}
}
func BenchmarkSortedUniqueListRemove(b *testing.B) {
ul := NewSortedUniqueList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItem := &item{id: i}
ul.Insert(mockItem)
mockItems = append(mockItems, mockItem)
}
b.ResetTimer()
for _, mockItem := range mockItems {
ul.Remove(mockItem)
}
}
func BenchmarkSortedUniqueListContains(b *testing.B) {
ul := NewSortedUniqueList()
var mockItems []*item
for i := 0; i < b.N; i++ {
mockItem := &item{id: i}
ul.Insert(mockItem)
mockItems = append(mockItems, mockItem)
}
b.ResetTimer()
for _, mockItem := range mockItems {
ul.Contains(mockItem)
}
}
func BenchmarkSortedUniqueListRange(b *testing.B) {
ul := NewSortedUniqueList()
for i := 0; i < b.N; i++ {
mockItem := item{id: i}
ul.Insert(&mockItem)
}
b.ResetTimer()
ul.Range(func(_ Item) bool { return true })
}
func BenchmarkSortedUniqueListReverseRange(b *testing.B) {
ul := NewSortedUniqueList()
for i := 0; i < b.N; i++ {
mockItem := item{id: i}
ul.Insert(&mockItem)
}
b.ResetTimer()
ul.ReverseRange(func(item Item) bool { return true })
}

View File

@ -11,30 +11,30 @@ import (
) )
// MockSafeSet is a mock of SafeSet interface. // MockSafeSet is a mock of SafeSet interface.
type MockSafeSet struct { type MockSafeSet[T comparable] struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockSafeSetMockRecorder recorder *MockSafeSetMockRecorder[T]
} }
// MockSafeSetMockRecorder is the mock recorder for MockSafeSet. // MockSafeSetMockRecorder is the mock recorder for MockSafeSet.
type MockSafeSetMockRecorder struct { type MockSafeSetMockRecorder[T comparable] struct {
mock *MockSafeSet mock *MockSafeSet[T]
} }
// NewMockSafeSet creates a new mock instance. // NewMockSafeSet creates a new mock instance.
func NewMockSafeSet(ctrl *gomock.Controller) *MockSafeSet { func NewMockSafeSet[T comparable](ctrl *gomock.Controller) *MockSafeSet[T] {
mock := &MockSafeSet{ctrl: ctrl} mock := &MockSafeSet[T]{ctrl: ctrl}
mock.recorder = &MockSafeSetMockRecorder{mock} mock.recorder = &MockSafeSetMockRecorder[T]{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSafeSet) EXPECT() *MockSafeSetMockRecorder { func (m *MockSafeSet[T]) EXPECT() *MockSafeSetMockRecorder[T] {
return m.recorder return m.recorder
} }
// Add mocks base method. // Add mocks base method.
func (m *MockSafeSet) Add(arg0 any) bool { func (m *MockSafeSet[T]) Add(arg0 T) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0) ret := m.ctrl.Call(m, "Add", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
@ -42,25 +42,25 @@ func (m *MockSafeSet) Add(arg0 any) bool {
} }
// Add indicates an expected call of Add. // Add indicates an expected call of Add.
func (mr *MockSafeSetMockRecorder) Add(arg0 interface{}) *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSafeSet)(nil).Add), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSafeSet[T])(nil).Add), arg0)
} }
// Clear mocks base method. // Clear mocks base method.
func (m *MockSafeSet) Clear() { func (m *MockSafeSet[T]) Clear() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Clear") m.ctrl.Call(m, "Clear")
} }
// Clear indicates an expected call of Clear. // Clear indicates an expected call of Clear.
func (mr *MockSafeSetMockRecorder) Clear() *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Clear() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSafeSet)(nil).Clear)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSafeSet[T])(nil).Clear))
} }
// Contains mocks base method. // Contains mocks base method.
func (m *MockSafeSet) Contains(arg0 ...any) bool { func (m *MockSafeSet[T]) Contains(arg0 ...T) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []interface{}{} varargs := []interface{}{}
for _, a := range arg0 { for _, a := range arg0 {
@ -72,25 +72,25 @@ func (m *MockSafeSet) Contains(arg0 ...any) bool {
} }
// Contains indicates an expected call of Contains. // Contains indicates an expected call of Contains.
func (mr *MockSafeSetMockRecorder) Contains(arg0 ...interface{}) *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Contains(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSafeSet)(nil).Contains), arg0...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSafeSet[T])(nil).Contains), arg0...)
} }
// Delete mocks base method. // Delete mocks base method.
func (m *MockSafeSet) Delete(arg0 any) { func (m *MockSafeSet[T]) Delete(arg0 T) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Delete", arg0) m.ctrl.Call(m, "Delete", arg0)
} }
// Delete indicates an expected call of Delete. // Delete indicates an expected call of Delete.
func (mr *MockSafeSetMockRecorder) Delete(arg0 interface{}) *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSafeSet)(nil).Delete), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSafeSet[T])(nil).Delete), arg0)
} }
// Len mocks base method. // Len mocks base method.
func (m *MockSafeSet) Len() uint { func (m *MockSafeSet[T]) Len() uint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len") ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(uint) ret0, _ := ret[0].(uint)
@ -98,21 +98,21 @@ func (m *MockSafeSet) Len() uint {
} }
// Len indicates an expected call of Len. // Len indicates an expected call of Len.
func (mr *MockSafeSetMockRecorder) Len() *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Len() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSafeSet)(nil).Len)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSafeSet[T])(nil).Len))
} }
// Values mocks base method. // Values mocks base method.
func (m *MockSafeSet) Values() []any { func (m *MockSafeSet[T]) Values() []T {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Values") ret := m.ctrl.Call(m, "Values")
ret0, _ := ret[0].([]any) ret0, _ := ret[0].([]T)
return ret0 return ret0
} }
// Values indicates an expected call of Values. // Values indicates an expected call of Values.
func (mr *MockSafeSetMockRecorder) Values() *gomock.Call { func (mr *MockSafeSetMockRecorder[T]) Values() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSafeSet)(nil).Values)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSafeSet[T])(nil).Values))
} }

View File

@ -11,30 +11,30 @@ import (
) )
// MockSet is a mock of Set interface. // MockSet is a mock of Set interface.
type MockSet struct { type MockSet[T comparable] struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockSetMockRecorder recorder *MockSetMockRecorder[T]
} }
// MockSetMockRecorder is the mock recorder for MockSet. // MockSetMockRecorder is the mock recorder for MockSet.
type MockSetMockRecorder struct { type MockSetMockRecorder[T comparable] struct {
mock *MockSet mock *MockSet[T]
} }
// NewMockSet creates a new mock instance. // NewMockSet creates a new mock instance.
func NewMockSet(ctrl *gomock.Controller) *MockSet { func NewMockSet[T comparable](ctrl *gomock.Controller) *MockSet[T] {
mock := &MockSet{ctrl: ctrl} mock := &MockSet[T]{ctrl: ctrl}
mock.recorder = &MockSetMockRecorder{mock} mock.recorder = &MockSetMockRecorder[T]{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSet) EXPECT() *MockSetMockRecorder { func (m *MockSet[T]) EXPECT() *MockSetMockRecorder[T] {
return m.recorder return m.recorder
} }
// Add mocks base method. // Add mocks base method.
func (m *MockSet) Add(arg0 any) bool { func (m *MockSet[T]) Add(arg0 T) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0) ret := m.ctrl.Call(m, "Add", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
@ -42,25 +42,25 @@ func (m *MockSet) Add(arg0 any) bool {
} }
// Add indicates an expected call of Add. // Add indicates an expected call of Add.
func (mr *MockSetMockRecorder) Add(arg0 interface{}) *gomock.Call { func (mr *MockSetMockRecorder[T]) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSet)(nil).Add), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSet[T])(nil).Add), arg0)
} }
// Clear mocks base method. // Clear mocks base method.
func (m *MockSet) Clear() { func (m *MockSet[T]) Clear() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Clear") m.ctrl.Call(m, "Clear")
} }
// Clear indicates an expected call of Clear. // Clear indicates an expected call of Clear.
func (mr *MockSetMockRecorder) Clear() *gomock.Call { func (mr *MockSetMockRecorder[T]) Clear() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSet)(nil).Clear)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSet[T])(nil).Clear))
} }
// Contains mocks base method. // Contains mocks base method.
func (m *MockSet) Contains(arg0 ...any) bool { func (m *MockSet[T]) Contains(arg0 ...T) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []interface{}{} varargs := []interface{}{}
for _, a := range arg0 { for _, a := range arg0 {
@ -72,25 +72,25 @@ func (m *MockSet) Contains(arg0 ...any) bool {
} }
// Contains indicates an expected call of Contains. // Contains indicates an expected call of Contains.
func (mr *MockSetMockRecorder) Contains(arg0 ...interface{}) *gomock.Call { func (mr *MockSetMockRecorder[T]) Contains(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSet)(nil).Contains), arg0...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSet[T])(nil).Contains), arg0...)
} }
// Delete mocks base method. // Delete mocks base method.
func (m *MockSet) Delete(arg0 any) { func (m *MockSet[T]) Delete(arg0 T) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Delete", arg0) m.ctrl.Call(m, "Delete", arg0)
} }
// Delete indicates an expected call of Delete. // Delete indicates an expected call of Delete.
func (mr *MockSetMockRecorder) Delete(arg0 interface{}) *gomock.Call { func (mr *MockSetMockRecorder[T]) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSet)(nil).Delete), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSet[T])(nil).Delete), arg0)
} }
// Len mocks base method. // Len mocks base method.
func (m *MockSet) Len() uint { func (m *MockSet[T]) Len() uint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len") ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(uint) ret0, _ := ret[0].(uint)
@ -98,21 +98,21 @@ func (m *MockSet) Len() uint {
} }
// Len indicates an expected call of Len. // Len indicates an expected call of Len.
func (mr *MockSetMockRecorder) Len() *gomock.Call { func (mr *MockSetMockRecorder[T]) Len() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSet)(nil).Len)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSet[T])(nil).Len))
} }
// Values mocks base method. // Values mocks base method.
func (m *MockSet) Values() []any { func (m *MockSet[T]) Values() []T {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Values") ret := m.ctrl.Call(m, "Values")
ret0, _ := ret[0].([]any) ret0, _ := ret[0].([]T)
return ret0 return ret0
} }
// Values indicates an expected call of Values. // Values indicates an expected call of Values.
func (mr *MockSetMockRecorder) Values() *gomock.Call { func (mr *MockSetMockRecorder[T]) Values() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSet)(nil).Values)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSet[T])(nil).Values))
} }

View File

@ -22,32 +22,32 @@ import (
"sync" "sync"
) )
type SafeSet interface { type SafeSet[T comparable] interface {
Values() []any Values() []T
Add(any) bool Add(T) bool
Delete(any) Delete(T)
Contains(...any) bool Contains(...T) bool
Len() uint Len() uint
Clear() Clear()
} }
type safeSet struct { type safeSet[T comparable] struct {
mu *sync.RWMutex mu *sync.RWMutex
data map[any]struct{} data map[T]struct{}
} }
func NewSafeSet() SafeSet { func NewSafeSet[T comparable]() SafeSet[T] {
return &safeSet{ return &safeSet[T]{
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
data: make(map[any]struct{}), data: make(map[T]struct{}),
} }
} }
func (s *safeSet) Values() []any { func (s *safeSet[T]) Values() []T {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
var result []any var result []T
for k := range s.data { for k := range s.data {
result = append(result, k) result = append(result, k)
} }
@ -55,7 +55,7 @@ func (s *safeSet) Values() []any {
return result return result
} }
func (s *safeSet) Add(v any) bool { func (s *safeSet[T]) Add(v T) bool {
s.mu.RLock() s.mu.RLock()
_, found := s.data[v] _, found := s.data[v]
if found { if found {
@ -70,13 +70,13 @@ func (s *safeSet) Add(v any) bool {
return true return true
} }
func (s *safeSet) Delete(v any) { func (s *safeSet[T]) Delete(v T) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
delete(s.data, v) delete(s.data, v)
} }
func (s *safeSet) Contains(vals ...any) bool { func (s *safeSet[T]) Contains(vals ...T) bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
for _, v := range vals { for _, v := range vals {
@ -88,14 +88,14 @@ func (s *safeSet) Contains(vals ...any) bool {
return true return true
} }
func (s *safeSet) Len() uint { func (s *safeSet[T]) Len() uint {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return uint(len(s.data)) return uint(len(s.data))
} }
func (s *safeSet) Clear() { func (s *safeSet[T]) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.data = make(map[any]struct{}) s.data = make(map[T]struct{})
} }

View File

@ -30,33 +30,33 @@ const N = 1000
func TestSafeSetAdd(t *testing.T) { func TestSafeSetAdd(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, ok bool, s SafeSet, value any) expect func(t *testing.T, ok bool, s SafeSet[string], value string)
}{ }{
{ {
name: "add value", name: "add value",
value: "foo", value: "foo",
expect: func(t *testing.T, ok bool, s SafeSet, value any) { expect: func(t *testing.T, ok bool, s SafeSet[string], value string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(ok, true) assert.Equal(ok, true)
assert.Equal(s.Values(), []any{value}) assert.Equal(s.Values(), []string{value})
}, },
}, },
{ {
name: "add value failed", name: "add value failed",
value: "foo", value: "foo",
expect: func(t *testing.T, _ bool, s SafeSet, value any) { expect: func(t *testing.T, _ bool, s SafeSet[string], value string) {
assert := assert.New(t) assert := assert.New(t)
ok := s.Add("foo") ok := s.Add("foo")
assert.Equal(ok, false) assert.Equal(ok, false)
assert.Equal(s.Values(), []any{value}) assert.Equal(s.Values(), []string{value})
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
tc.expect(t, s.Add(tc.value), s, tc.value) tc.expect(t, s.Add(tc.value), s, tc.value)
}) })
} }
@ -65,7 +65,7 @@ func TestSafeSetAdd(t *testing.T) {
func TestSafeSetAdd_Concurrent(t *testing.T) { func TestSafeSetAdd_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
nums := rand.Perm(N) nums := rand.Perm(N)
var wg sync.WaitGroup var wg sync.WaitGroup
@ -88,13 +88,13 @@ func TestSafeSetAdd_Concurrent(t *testing.T) {
func TestSafeSetDelete(t *testing.T) { func TestSafeSetDelete(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, s SafeSet, value any) expect func(t *testing.T, s SafeSet[string], value string)
}{ }{
{ {
name: "delete value", name: "delete value",
value: "foo", value: "foo",
expect: func(t *testing.T, s SafeSet, value any) { expect: func(t *testing.T, s SafeSet[string], value string) {
assert := assert.New(t) assert := assert.New(t)
s.Delete(value) s.Delete(value)
assert.Equal(s.Len(), uint(0)) assert.Equal(s.Len(), uint(0))
@ -103,7 +103,7 @@ func TestSafeSetDelete(t *testing.T) {
{ {
name: "delete value does not exist", name: "delete value does not exist",
value: "foo", value: "foo",
expect: func(t *testing.T, s SafeSet, _ any) { expect: func(t *testing.T, s SafeSet[string], _ string) {
assert := assert.New(t) assert := assert.New(t)
s.Delete("bar") s.Delete("bar")
assert.Equal(s.Len(), uint(1)) assert.Equal(s.Len(), uint(1))
@ -113,7 +113,7 @@ func TestSafeSetDelete(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
s.Add(tc.value) s.Add(tc.value)
tc.expect(t, s, tc.value) tc.expect(t, s, tc.value)
}) })
@ -123,7 +123,7 @@ func TestSafeSetDelete(t *testing.T) {
func TestSafeSetDelete_Concurrent(t *testing.T) { func TestSafeSetDelete_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
nums := rand.Perm(N) nums := rand.Perm(N)
for _, v := range nums { for _, v := range nums {
s.Add(v) s.Add(v)
@ -147,21 +147,21 @@ func TestSafeSetDelete_Concurrent(t *testing.T) {
func TestSafeSetContains(t *testing.T) { func TestSafeSetContains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, s SafeSet, value any) expect func(t *testing.T, s SafeSet[string], value string)
}{ }{
{ {
name: "contains value", name: "contains value",
value: "foo", value: "foo",
expect: func(t *testing.T, s SafeSet, value any) { expect: func(t *testing.T, s SafeSet[string], value string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Contains(value), true) assert.Equal(s.Contains(string(value)), true)
}, },
}, },
{ {
name: "contains value does not exist", name: "contains value does not exist",
value: "foo", value: "foo",
expect: func(t *testing.T, s SafeSet, _ any) { expect: func(t *testing.T, s SafeSet[string], _ string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Contains("bar"), false) assert.Equal(s.Contains("bar"), false)
}, },
@ -170,7 +170,7 @@ func TestSafeSetContains(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
s.Add(tc.value) s.Add(tc.value)
tc.expect(t, s, tc.value) tc.expect(t, s, tc.value)
}) })
@ -180,9 +180,9 @@ func TestSafeSetContains(t *testing.T) {
func TestSafeSetContains_Concurrent(t *testing.T) { func TestSafeSetContains_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
nums := rand.Perm(N) nums := rand.Perm(N)
interfaces := make([]any, 0) interfaces := make([]int, 0)
for _, v := range nums { for _, v := range nums {
s.Add(v) s.Add(v)
interfaces = append(interfaces, v) interfaces = append(interfaces, v)
@ -202,11 +202,11 @@ func TestSafeSetContains_Concurrent(t *testing.T) {
func TestSetSafeLen(t *testing.T) { func TestSetSafeLen(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s SafeSet) expect func(t *testing.T, s SafeSet[string])
}{ }{
{ {
name: "get length", name: "get length",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
assert.Equal(s.Len(), uint(1)) assert.Equal(s.Len(), uint(1))
@ -214,7 +214,7 @@ func TestSetSafeLen(t *testing.T) {
}, },
{ {
name: "get empty set length", name: "get empty set length",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Len(), uint(0)) assert.Equal(s.Len(), uint(0))
}, },
@ -223,7 +223,7 @@ func TestSetSafeLen(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }
@ -232,7 +232,7 @@ func TestSetSafeLen(t *testing.T) {
func TestSafeSetLen_Concurrent(t *testing.T) { func TestSafeSetLen_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -256,26 +256,26 @@ func TestSafeSetLen_Concurrent(t *testing.T) {
func TestSafeSetValues(t *testing.T) { func TestSafeSetValues(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s SafeSet) expect func(t *testing.T, s SafeSet[string])
}{ }{
{ {
name: "get values", name: "get values",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
assert.Equal(s.Values(), []any{"foo"}) assert.Equal(s.Values(), []string{"foo"})
}, },
}, },
{ {
name: "get empty values", name: "get empty values",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
}, },
}, },
{ {
name: "get multi values", name: "get multi values",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
s.Add("bar") s.Add("bar")
@ -287,7 +287,7 @@ func TestSafeSetValues(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }
@ -296,7 +296,7 @@ func TestSafeSetValues(t *testing.T) {
func TestSafeSetValues_Concurrent(t *testing.T) { func TestSafeSetValues_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -312,7 +312,7 @@ func TestSafeSetValues_Concurrent(t *testing.T) {
}() }()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
s.Add(rand.Int()) s.Add(i)
} }
wg.Wait() wg.Wait()
} }
@ -320,32 +320,32 @@ func TestSafeSetValues_Concurrent(t *testing.T) {
func TestSafeSetClear(t *testing.T) { func TestSafeSetClear(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s SafeSet) expect func(t *testing.T, s SafeSet[string])
}{ }{
{ {
name: "clear empty set", name: "clear empty set",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Clear() s.Clear()
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
}, },
}, },
{ {
name: "clear set", name: "clear set",
expect: func(t *testing.T, s SafeSet) { expect: func(t *testing.T, s SafeSet[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Add("foo"), true) assert.Equal(s.Add("foo"), true)
s.Clear() s.Clear()
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
assert.Equal(s.Add("foo"), true) assert.Equal(s.Add("foo"), true)
assert.Equal(s.Values(), []any{"foo"}) assert.Equal(s.Values(), []string{"foo"})
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := NewSafeSet() s := NewSafeSet[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }
@ -354,7 +354,7 @@ func TestSafeSetClear(t *testing.T) {
func TestSafeSetClear_Concurrent(t *testing.T) { func TestSafeSetClear_Concurrent(t *testing.T) {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
s := NewSafeSet() s := NewSafeSet[int]()
nums := rand.Perm(N) nums := rand.Perm(N)
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@ -18,23 +18,23 @@
package set package set
type Set interface { type Set[T comparable] interface {
Values() []any Values() []T
Add(any) bool Add(T) bool
Delete(any) Delete(T)
Contains(...any) bool Contains(...T) bool
Len() uint Len() uint
Clear() Clear()
} }
type set map[any]struct{} type set[T comparable] map[T]struct{}
func New() Set { func New[T comparable]() Set[T] {
return &set{} return &set[T]{}
} }
func (s *set) Values() []any { func (s *set[T]) Values() []T {
var result []any var result []T
for k := range *s { for k := range *s {
result = append(result, k) result = append(result, k)
} }
@ -42,7 +42,7 @@ func (s *set) Values() []any {
return result return result
} }
func (s *set) Add(v any) bool { func (s *set[T]) Add(v T) bool {
_, found := (*s)[v] _, found := (*s)[v]
if found { if found {
return false return false
@ -52,11 +52,11 @@ func (s *set) Add(v any) bool {
return true return true
} }
func (s *set) Delete(v any) { func (s *set[T]) Delete(v T) {
delete(*s, v) delete(*s, v)
} }
func (s *set) Contains(vals ...any) bool { func (s *set[T]) Contains(vals ...T) bool {
for _, v := range vals { for _, v := range vals {
if _, ok := (*s)[v]; !ok { if _, ok := (*s)[v]; !ok {
return false return false
@ -66,10 +66,10 @@ func (s *set) Contains(vals ...any) bool {
return true return true
} }
func (s *set) Len() uint { func (s *set[T]) Len() uint {
return uint(len(*s)) return uint(len(*s))
} }
func (s *set) Clear() { func (s *set[T]) Clear() {
*s = set{} *s = set[T]{}
} }

View File

@ -25,33 +25,33 @@ import (
func TestSetAdd(t *testing.T) { func TestSetAdd(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, ok bool, s Set, value any) expect func(t *testing.T, ok bool, s Set[string], value string)
}{ }{
{ {
name: "add value", name: "add value",
value: "foo", value: "foo",
expect: func(t *testing.T, ok bool, s Set, value any) { expect: func(t *testing.T, ok bool, s Set[string], value string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(ok, true) assert.Equal(ok, true)
assert.Equal(s.Values(), []any{value}) assert.Equal(s.Values(), []string{value})
}, },
}, },
{ {
name: "add value failed", name: "add value failed",
value: "foo", value: "foo",
expect: func(t *testing.T, _ bool, s Set, value any) { expect: func(t *testing.T, _ bool, s Set[string], value string) {
assert := assert.New(t) assert := assert.New(t)
ok := s.Add("foo") ok := s.Add("foo")
assert.Equal(ok, false) assert.Equal(ok, false)
assert.Equal(s.Values(), []any{value}) assert.Equal(s.Values(), []string{value})
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
tc.expect(t, s.Add(tc.value), s, tc.value) tc.expect(t, s.Add(tc.value), s, tc.value)
}) })
} }
@ -60,13 +60,13 @@ func TestSetAdd(t *testing.T) {
func TestSetDelete(t *testing.T) { func TestSetDelete(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, s Set, value any) expect func(t *testing.T, s Set[string], value string)
}{ }{
{ {
name: "delete value", name: "delete value",
value: "foo", value: "foo",
expect: func(t *testing.T, s Set, value any) { expect: func(t *testing.T, s Set[string], value string) {
assert := assert.New(t) assert := assert.New(t)
s.Delete(value) s.Delete(value)
assert.Equal(s.Len(), uint(0)) assert.Equal(s.Len(), uint(0))
@ -75,7 +75,7 @@ func TestSetDelete(t *testing.T) {
{ {
name: "delete value does not exist", name: "delete value does not exist",
value: "foo", value: "foo",
expect: func(t *testing.T, s Set, _ any) { expect: func(t *testing.T, s Set[string], _ string) {
assert := assert.New(t) assert := assert.New(t)
s.Delete("bar") s.Delete("bar")
assert.Equal(s.Len(), uint(1)) assert.Equal(s.Len(), uint(1))
@ -85,7 +85,7 @@ func TestSetDelete(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
s.Add(tc.value) s.Add(tc.value)
tc.expect(t, s, tc.value) tc.expect(t, s, tc.value)
}) })
@ -95,13 +95,13 @@ func TestSetDelete(t *testing.T) {
func TestSetContains(t *testing.T) { func TestSetContains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
value any value string
expect func(t *testing.T, s Set, value any) expect func(t *testing.T, s Set[string], value string)
}{ }{
{ {
name: "contains value", name: "contains value",
value: "foo", value: "foo",
expect: func(t *testing.T, s Set, value any) { expect: func(t *testing.T, s Set[string], value string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Contains(value), true) assert.Equal(s.Contains(value), true)
}, },
@ -109,7 +109,7 @@ func TestSetContains(t *testing.T) {
{ {
name: "contains value does not exist", name: "contains value does not exist",
value: "foo", value: "foo",
expect: func(t *testing.T, s Set, _ any) { expect: func(t *testing.T, s Set[string], _ string) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Contains("bar"), false) assert.Equal(s.Contains("bar"), false)
}, },
@ -118,7 +118,7 @@ func TestSetContains(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
s.Add(tc.value) s.Add(tc.value)
tc.expect(t, s, tc.value) tc.expect(t, s, tc.value)
}) })
@ -128,11 +128,11 @@ func TestSetContains(t *testing.T) {
func TestSetLen(t *testing.T) { func TestSetLen(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s Set) expect func(t *testing.T, s Set[string])
}{ }{
{ {
name: "get length", name: "get length",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
assert.Equal(s.Len(), uint(1)) assert.Equal(s.Len(), uint(1))
@ -140,7 +140,7 @@ func TestSetLen(t *testing.T) {
}, },
{ {
name: "get empty set length", name: "get empty set length",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Len(), uint(0)) assert.Equal(s.Len(), uint(0))
}, },
@ -149,7 +149,7 @@ func TestSetLen(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }
@ -158,26 +158,26 @@ func TestSetLen(t *testing.T) {
func TestSetValues(t *testing.T) { func TestSetValues(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s Set) expect func(t *testing.T, s Set[string])
}{ }{
{ {
name: "get values", name: "get values",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
assert.Equal(s.Values(), []any{"foo"}) assert.Equal(s.Values(), []string{"foo"})
}, },
}, },
{ {
name: "get empty values", name: "get empty values",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
}, },
}, },
{ {
name: "get multi values", name: "get multi values",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Add("foo") s.Add("foo")
s.Add("bar") s.Add("bar")
@ -189,7 +189,7 @@ func TestSetValues(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }
@ -198,32 +198,32 @@ func TestSetValues(t *testing.T) {
func TestSetClear(t *testing.T) { func TestSetClear(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, s Set) expect func(t *testing.T, s Set[string])
}{ }{
{ {
name: "clear empty set", name: "clear empty set",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
s.Clear() s.Clear()
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
}, },
}, },
{ {
name: "clear set", name: "clear set",
expect: func(t *testing.T, s Set) { expect: func(t *testing.T, s Set[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(s.Add("foo"), true) assert.Equal(s.Add("foo"), true)
s.Clear() s.Clear()
assert.Equal(s.Values(), []any(nil)) assert.Equal(s.Values(), []string(nil))
assert.Equal(s.Add("foo"), true) assert.Equal(s.Add("foo"), true)
assert.Equal(s.Values(), []any{"foo"}) assert.Equal(s.Values(), []string{"foo"})
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := New() s := New[string]()
tc.expect(t, s) tc.expect(t, s)
}) })
} }

View File

@ -20,7 +20,11 @@ package dag
import ( import (
"errors" "errors"
"math/rand"
"sync" "sync"
"time"
cmap "github.com/orcaman/concurrent-map/v2"
) )
var ( var (
@ -41,24 +45,30 @@ var (
) )
// DAG is the interface used for directed acyclic graph. // DAG is the interface used for directed acyclic graph.
type DAG interface { type DAG[T comparable] interface {
// AddVertex adds vertex to graph. // AddVertex adds vertex to graph.
AddVertex(id string, value any) error AddVertex(id string, value T) error
// DeleteVertex deletes vertex graph. // DeleteVertex deletes vertex graph.
DeleteVertex(id string) DeleteVertex(id string)
// GetVertex gets vertex from graph. // GetVertex gets vertex from graph.
GetVertex(id string) (*Vertex, error) GetVertex(id string) (*Vertex[T], error)
// GetVertices returns map of vertices. // GetVertices returns map of vertices.
GetVertices() map[string]*Vertex GetVertices() map[string]*Vertex[T]
// GetRandomVertices returns random map of vertices.
GetRandomVertices(n uint) map[string]*Vertex[T]
// GetVertexKeys returns keys of vertices.
GetVertexKeys() []string
// GetSourceVertices returns source vertices. // GetSourceVertices returns source vertices.
GetSourceVertices() map[string]*Vertex GetSourceVertices() map[string]*Vertex[T]
// GetSinkVertices returns sink vertices. // GetSinkVertices returns sink vertices.
GetSinkVertices() map[string]*Vertex GetSinkVertices() map[string]*Vertex[T]
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
VertexCount() int VertexCount() int
@ -74,69 +84,56 @@ type DAG interface {
} }
// dag provides directed acyclic graph function. // dag provides directed acyclic graph function.
type dag struct { type dag[T comparable] struct {
mu sync.RWMutex mu sync.RWMutex
vertices map[string]*Vertex vertices cmap.ConcurrentMap[*Vertex[T]]
} }
// New returns a new DAG interface. // New returns a new DAG interface.
func NewDAG() DAG { func NewDAG[T comparable]() DAG[T] {
return &dag{ return &dag[T]{
vertices: make(map[string]*Vertex), vertices: cmap.New[*Vertex[T]](),
} }
} }
// AddVertex adds vertex to graph. // AddVertex adds vertex to graph.
func (d *dag) AddVertex(id string, value any) error { func (d *dag[T]) AddVertex(id string, value T) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if _, ok := d.vertices[id]; ok { if _, ok := d.vertices.Get(id); ok {
return ErrVertexAlreadyExists return ErrVertexAlreadyExists
} }
d.vertices[id] = NewVertex(id, value) d.vertices.Set(id, NewVertex(id, value))
return nil return nil
} }
// DeleteVertex deletes vertex graph. // DeleteVertex deletes vertex graph.
func (d *dag) DeleteVertex(id string) { func (d *dag[T]) DeleteVertex(id string) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
vertex, ok := d.vertices[id] vertex, ok := d.vertices.Get(id)
if !ok { if !ok {
return return
} }
for _, value := range vertex.Parents.Values() { for _, parent := range vertex.Parents.Values() {
parent, ok := value.(*Vertex)
if !ok {
continue
}
parent.Children.Delete(vertex) parent.Children.Delete(vertex)
} }
for _, value := range vertex.Children.Values() { for _, child := range vertex.Children.Values() {
child, ok := value.(*Vertex)
if !ok {
continue
}
child.Parents.Delete(vertex) child.Parents.Delete(vertex)
continue continue
} }
delete(d.vertices, id) d.vertices.Remove(id)
} }
// GetVertex gets vertex from graph. // GetVertex gets vertex from graph.
func (d *dag) GetVertex(id string) (*Vertex, error) { func (d *dag[T]) GetVertex(id string) (*Vertex[T], error) {
d.mu.RLock() vertex, ok := d.vertices.Get(id)
defer d.mu.RUnlock()
vertex, ok := d.vertices[id]
if !ok { if !ok {
return nil, ErrVertexNotFound return nil, ErrVertexNotFound
} }
@ -145,20 +142,44 @@ func (d *dag) GetVertex(id string) (*Vertex, error) {
} }
// GetVertices returns map of vertices. // GetVertices returns map of vertices.
func (d *dag) GetVertices() map[string]*Vertex { func (d *dag[T]) GetVertices() map[string]*Vertex[T] {
return d.vertices.Items()
}
// GetRandomVertices returns random map of vertices.
func (d *dag[T]) GetRandomVertices(n uint) map[string]*Vertex[T] {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
return d.vertices keys := d.GetVertexKeys()
vertices := d.GetVertices()
if int(n) >= len(keys) {
return vertices
}
rand.Seed(time.Now().Unix())
permutation := rand.Perm(len(keys))[:n]
randomVertices := make(map[string]*Vertex[T])
for _, v := range permutation {
key := keys[v]
randomVertices[key] = vertices[key]
}
return randomVertices
}
// GetVertexKeys returns keys of vertices.
func (d *dag[T]) GetVertexKeys() []string {
return d.vertices.Keys()
} }
// VertexCount returns count of vertices. // VertexCount returns count of vertices.
func (d *dag) VertexCount() int { func (d *dag[T]) VertexCount() int {
return len(d.vertices) return d.vertices.Count()
} }
// CanAddEdge finds whether there are circles through depth-first search. // CanAddEdge finds whether there are circles through depth-first search.
func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool { func (d *dag[T]) CanAddEdge(fromVertexID, toVertexID string) bool {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
@ -166,22 +187,17 @@ func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool {
return false return false
} }
fromVertex, ok := d.vertices[fromVertexID] fromVertex, ok := d.vertices.Get(fromVertexID)
if !ok { if !ok {
return false return false
} }
if _, ok := d.vertices[toVertexID]; !ok { if _, ok := d.vertices.Get(toVertexID); !ok {
return false return false
} }
for _, child := range fromVertex.Children.Values() { for _, child := range fromVertex.Children.Values() {
vertex, ok := child.(*Vertex) if child.ID == toVertexID {
if !ok {
continue
}
if vertex.ID == toVertexID {
return false return false
} }
} }
@ -194,7 +210,7 @@ func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool {
} }
// AddEdge adds edge between two vertices. // AddEdge adds edge between two vertices.
func (d *dag) AddEdge(fromVertexID, toVertexID string) error { func (d *dag[T]) AddEdge(fromVertexID, toVertexID string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@ -202,23 +218,18 @@ func (d *dag) AddEdge(fromVertexID, toVertexID string) error {
return ErrCycleBetweenVertices return ErrCycleBetweenVertices
} }
fromVertex, ok := d.vertices[fromVertexID] fromVertex, ok := d.vertices.Get(fromVertexID)
if !ok { if !ok {
return ErrVertexNotFound return ErrVertexNotFound
} }
toVertex, ok := d.vertices[toVertexID] toVertex, ok := d.vertices.Get(toVertexID)
if !ok { if !ok {
return ErrVertexNotFound return ErrVertexNotFound
} }
for _, child := range fromVertex.Children.Values() { for _, child := range fromVertex.Children.Values() {
vertex, ok := child.(*Vertex) if child.ID == toVertexID {
if !ok {
continue
}
if vertex.ID == toVertexID {
return ErrCycleBetweenVertices return ErrCycleBetweenVertices
} }
} }
@ -239,16 +250,16 @@ func (d *dag) AddEdge(fromVertexID, toVertexID string) error {
} }
// DeleteEdge deletes edge between two vertices. // DeleteEdge deletes edge between two vertices.
func (d *dag) DeleteEdge(fromVertexID, toVertexID string) error { func (d *dag[T]) DeleteEdge(fromVertexID, toVertexID string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
fromVertex, ok := d.vertices[fromVertexID] fromVertex, ok := d.vertices.Get(fromVertexID)
if !ok { if !ok {
return ErrVertexNotFound return ErrVertexNotFound
} }
toVertex, ok := d.vertices[toVertexID] toVertex, ok := d.vertices.Get(toVertexID)
if !ok { if !ok {
return ErrVertexNotFound return ErrVertexNotFound
} }
@ -259,12 +270,12 @@ func (d *dag) DeleteEdge(fromVertexID, toVertexID string) error {
} }
// GetSourceVertices returns source vertices. // GetSourceVertices returns source vertices.
func (d *dag) GetSourceVertices() map[string]*Vertex { func (d *dag[T]) GetSourceVertices() map[string]*Vertex[T] {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
sourceVertices := make(map[string]*Vertex) sourceVertices := make(map[string]*Vertex[T])
for k, v := range d.vertices { for k, v := range d.vertices.Items() {
if v.InDegree() == 0 { if v.InDegree() == 0 {
sourceVertices[k] = v sourceVertices[k] = v
} }
@ -274,12 +285,12 @@ func (d *dag) GetSourceVertices() map[string]*Vertex {
} }
// GetSinkVertices returns sink vertices. // GetSinkVertices returns sink vertices.
func (d *dag) GetSinkVertices() map[string]*Vertex { func (d *dag[T]) GetSinkVertices() map[string]*Vertex[T] {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
sinkVertices := make(map[string]*Vertex) sinkVertices := make(map[string]*Vertex[T])
for k, v := range d.vertices { for k, v := range d.vertices.Items() {
if v.OutDegree() == 0 { if v.OutDegree() == 0 {
sinkVertices[k] = v sinkVertices[k] = v
} }
@ -289,7 +300,7 @@ func (d *dag) GetSinkVertices() map[string]*Vertex {
} }
// depthFirstSearch is a depth-first search of the directed acyclic graph. // depthFirstSearch is a depth-first search of the directed acyclic graph.
func (d *dag) depthFirstSearch(fromVertexID, toVertexID string) bool { func (d *dag[T]) depthFirstSearch(fromVertexID, toVertexID string) bool {
successors := make(map[string]struct{}) successors := make(map[string]struct{})
d.search(fromVertexID, successors) d.search(fromVertexID, successors)
_, ok := successors[toVertexID] _, ok := successors[toVertexID]
@ -297,21 +308,16 @@ func (d *dag) depthFirstSearch(fromVertexID, toVertexID string) bool {
} }
// depthFirstSearch finds successors of vertex. // depthFirstSearch finds successors of vertex.
func (d *dag) search(vertexID string, successors map[string]struct{}) { func (d *dag[T]) search(vertexID string, successors map[string]struct{}) {
vertex, ok := d.vertices[vertexID] vertex, ok := d.vertices.Get(vertexID)
if !ok { if !ok {
return return
} }
for _, child := range vertex.Children.Values() { for _, child := range vertex.Children.Values() {
vertex, ok := child.(*Vertex) if _, ok := successors[child.ID]; !ok {
if !ok { successors[child.ID] = struct{}{}
continue d.search(child.ID, successors)
}
if _, ok := successors[vertex.ID]; !ok {
successors[vertex.ID] = struct{}{}
d.search(vertex.ID, successors)
} }
} }
} }

View File

@ -17,6 +17,7 @@
package dag package dag
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -25,9 +26,9 @@ import (
) )
func TestNewDAG(t *testing.T) { func TestNewDAG(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
assert := assert.New(t) assert := assert.New(t)
assert.Equal(reflect.TypeOf(d).Elem().Name(), "dag") assert.Equal(reflect.TypeOf(d).Elem().Name(), "dag[string]")
} }
func TestDAGAddVertex(t *testing.T) { func TestDAGAddVertex(t *testing.T) {
@ -35,13 +36,13 @@ func TestDAGAddVertex(t *testing.T) {
name string name string
id string id string
value any value any
expect func(t *testing.T, d DAG, err error) expect func(t *testing.T, d DAG[string], err error)
}{ }{
{ {
name: "add vertex", name: "add vertex",
id: mockVertexID, id: mockVertexID,
value: mockVertexValue, value: mockVertexValue,
expect: func(t *testing.T, d DAG, err error) { expect: func(t *testing.T, d DAG[string], err error) {
assert := assert.New(t) assert := assert.New(t)
assert.NoError(err) assert.NoError(err)
}, },
@ -50,7 +51,7 @@ func TestDAGAddVertex(t *testing.T) {
name: "vertex already exists", name: "vertex already exists",
id: mockVertexID, id: mockVertexID,
value: mockVertexValue, value: mockVertexValue,
expect: func(t *testing.T, d DAG, err error) { expect: func(t *testing.T, d DAG[string], err error) {
assert := assert.New(t) assert := assert.New(t)
assert.NoError(err) assert.NoError(err)
@ -61,7 +62,7 @@ func TestDAGAddVertex(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d, d.AddVertex(tc.id, tc.name)) tc.expect(t, d, d.AddVertex(tc.id, tc.name))
}) })
} }
@ -70,11 +71,11 @@ func TestDAGAddVertex(t *testing.T) {
func TestDAGDeleteVertex(t *testing.T) { func TestDAGDeleteVertex(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "delete vertex", name: "delete vertex",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err) assert.NoError(err)
@ -87,7 +88,7 @@ func TestDAGDeleteVertex(t *testing.T) {
}, },
{ {
name: "delete vertex with edges", name: "delete vertex with edges",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
@ -119,7 +120,7 @@ func TestDAGDeleteVertex(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -128,11 +129,11 @@ func TestDAGDeleteVertex(t *testing.T) {
func TestDAGGetVertex(t *testing.T) { func TestDAGGetVertex(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "get vertex", name: "get vertex",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err) assert.NoError(err)
@ -148,7 +149,7 @@ func TestDAGGetVertex(t *testing.T) {
}, },
{ {
name: "vertex not found", name: "vertex not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
_, err := d.GetVertex(mockVertexID) _, err := d.GetVertex(mockVertexID)
assert.EqualError(err, ErrVertexNotFound.Error()) assert.EqualError(err, ErrVertexNotFound.Error())
@ -158,7 +159,7 @@ func TestDAGGetVertex(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -167,11 +168,11 @@ func TestDAGGetVertex(t *testing.T) {
func TestDAGVertexVertexCount(t *testing.T) { func TestDAGVertexVertexCount(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "get length of vertex", name: "get length of vertex",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err) assert.NoError(err)
@ -186,7 +187,7 @@ func TestDAGVertexVertexCount(t *testing.T) {
}, },
{ {
name: "empty dag", name: "empty dag",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
assert.Equal(d.VertexCount(), 0) assert.Equal(d.VertexCount(), 0)
}, },
@ -195,7 +196,7 @@ func TestDAGVertexVertexCount(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -204,11 +205,11 @@ func TestDAGVertexVertexCount(t *testing.T) {
func TestDAGGetVertices(t *testing.T) { func TestDAGGetVertices(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "get vertices", name: "get vertices",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err) assert.NoError(err)
@ -220,6 +221,15 @@ func TestDAGGetVertices(t *testing.T) {
assert.Equal(vertices[mockVertexID].Value, mockVertexValue) assert.Equal(vertices[mockVertexID].Value, mockVertexValue)
d.DeleteVertex(mockVertexID) d.DeleteVertex(mockVertexID)
vertices = d.GetVertices()
assert.Equal(len(vertices), 0)
},
},
{
name: "dag is empty",
expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t)
vertices := d.GetVertices()
assert.Equal(len(vertices), 0) assert.Equal(len(vertices), 0)
}, },
}, },
@ -227,7 +237,95 @@ func TestDAGGetVertices(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d)
})
}
}
func TestDAGGetRandomVertices(t *testing.T) {
tests := []struct {
name string
expect func(t *testing.T, d DAG[string])
}{
{
name: "get random vertices",
expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t)
var (
mockVertexEID = "bae"
mockVertexFID = "baf"
)
if err := d.AddVertex(mockVertexEID, mockVertexValue); err != nil {
assert.NoError(err)
}
if err := d.AddVertex(mockVertexFID, mockVertexValue); err != nil {
assert.NoError(err)
}
vertices := d.GetRandomVertices(0)
assert.Equal(len(vertices), 0)
vertices = d.GetRandomVertices(1)
assert.Equal(len(vertices), 1)
vertices = d.GetRandomVertices(2)
assert.Equal(len(vertices), 2)
vertices = d.GetRandomVertices(3)
assert.Equal(len(vertices), 2)
},
},
{
name: "dag is empty",
expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t)
vertices := d.GetRandomVertices(0)
assert.Equal(len(vertices), 0)
vertices = d.GetRandomVertices(1)
assert.Equal(len(vertices), 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
d := NewDAG[string]()
tc.expect(t, d)
})
}
}
func TestDAGGetVertexKeys(t *testing.T) {
tests := []struct {
name string
expect func(t *testing.T, d DAG[string])
}{
{
name: "get keys of vertices",
expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t)
if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil {
assert.NoError(err)
}
keys := d.GetVertexKeys()
assert.Equal(len(keys), 1)
assert.Equal(keys[0], mockVertexID)
d.DeleteVertex(mockVertexID)
keys = d.GetVertexKeys()
assert.Equal(len(keys), 0)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -236,11 +334,11 @@ func TestDAGGetVertices(t *testing.T) {
func TestDAGAddEdge(t *testing.T) { func TestDAGAddEdge(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "add edge", name: "add edge",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -293,7 +391,7 @@ func TestDAGAddEdge(t *testing.T) {
}, },
{ {
name: "cycle between vertices", name: "cycle between vertices",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -358,7 +456,7 @@ func TestDAGAddEdge(t *testing.T) {
}, },
{ {
name: "vertex not found", name: "vertex not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -382,7 +480,7 @@ func TestDAGAddEdge(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -391,11 +489,11 @@ func TestDAGAddEdge(t *testing.T) {
func TestDAGCanAddEdge(t *testing.T) { func TestDAGCanAddEdge(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "can add edge", name: "can add edge",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -447,7 +545,7 @@ func TestDAGCanAddEdge(t *testing.T) {
}, },
{ {
name: "cycle between vertices", name: "cycle between vertices",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -511,7 +609,7 @@ func TestDAGCanAddEdge(t *testing.T) {
}, },
{ {
name: "vertex not found", name: "vertex not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -532,7 +630,7 @@ func TestDAGCanAddEdge(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -541,11 +639,11 @@ func TestDAGCanAddEdge(t *testing.T) {
func TestDAGDeleteEdge(t *testing.T) { func TestDAGDeleteEdge(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "delete edge", name: "delete edge",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -589,7 +687,7 @@ func TestDAGDeleteEdge(t *testing.T) {
}, },
{ {
name: "vertex not found", name: "vertex not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -613,7 +711,7 @@ func TestDAGDeleteEdge(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -622,11 +720,11 @@ func TestDAGDeleteEdge(t *testing.T) {
func TestDAGSourceVertices(t *testing.T) { func TestDAGSourceVertices(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "get source vertices", name: "get source vertices",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -651,7 +749,7 @@ func TestDAGSourceVertices(t *testing.T) {
}, },
{ {
name: "source vertices not found", name: "source vertices not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
sourceVertices := d.GetSourceVertices() sourceVertices := d.GetSourceVertices()
assert.Equal(len(sourceVertices), 0) assert.Equal(len(sourceVertices), 0)
@ -660,7 +758,7 @@ func TestDAGSourceVertices(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -669,11 +767,11 @@ func TestDAGSourceVertices(t *testing.T) {
func TestDAGSinkVertices(t *testing.T) { func TestDAGSinkVertices(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
expect func(t *testing.T, d DAG) expect func(t *testing.T, d DAG[string])
}{ }{
{ {
name: "get sink vertices", name: "get sink vertices",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
var ( var (
mockVertexEID = "bae" mockVertexEID = "bae"
@ -698,7 +796,7 @@ func TestDAGSinkVertices(t *testing.T) {
}, },
{ {
name: "sink vertices not found", name: "sink vertices not found",
expect: func(t *testing.T, d DAG) { expect: func(t *testing.T, d DAG[string]) {
assert := assert.New(t) assert := assert.New(t)
sinkVertices := d.GetSinkVertices() sinkVertices := d.GetSinkVertices()
assert.Equal(len(sinkVertices), 0) assert.Equal(len(sinkVertices), 0)
@ -707,7 +805,7 @@ func TestDAGSinkVertices(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
d := NewDAG() d := NewDAG[string]()
tc.expect(t, d) tc.expect(t, d)
}) })
} }
@ -715,14 +813,14 @@ func TestDAGSinkVertices(t *testing.T) {
func BenchmarkDAGAddVertex(b *testing.B) { func BenchmarkDAGAddVertex(b *testing.B) {
var ids []string var ids []string
d := NewDAG() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ids = append(ids, fmt.Sprint(n)) ids = append(ids, fmt.Sprint(n))
} }
b.ResetTimer() b.ResetTimer()
for _, id := range ids { for _, id := range ids {
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@ -730,10 +828,10 @@ func BenchmarkDAGAddVertex(b *testing.B) {
func BenchmarkDAGDeleteVertex(b *testing.B) { func BenchmarkDAGDeleteVertex(b *testing.B) {
var ids []string var ids []string
d := NewDAG() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -746,12 +844,30 @@ func BenchmarkDAGDeleteVertex(b *testing.B) {
} }
} }
func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) { func BenchmarkDAGGetRandomKeys(b *testing.B) {
var ids []string d := NewDAG[string]()
d := NewDAG()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err)
}
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
vertices := d.GetRandomVertices(uint(n))
if len(vertices) != n {
b.Fatal(errors.New("get random vertices failed"))
}
}
}
func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) {
var ids []string
d := NewDAG[string]()
for n := 0; n < b.N; n++ {
id := fmt.Sprint(n)
if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -779,10 +895,10 @@ func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) {
func BenchmarkDAGAddEdge(b *testing.B) { func BenchmarkDAGAddEdge(b *testing.B) {
var ids []string var ids []string
d := NewDAG() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -803,10 +919,10 @@ func BenchmarkDAGAddEdge(b *testing.B) {
func BenchmarkDAGAddEdgeWithMultiEdges(b *testing.B) { func BenchmarkDAGAddEdgeWithMultiEdges(b *testing.B) {
var ids []string var ids []string
d := NewDAG() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -840,10 +956,10 @@ func BenchmarkDAGAddEdgeWithMultiEdges(b *testing.B) {
func BenchmarkDAGDeleteEdge(b *testing.B) { func BenchmarkDAGDeleteEdge(b *testing.B) {
var ids []string var ids []string
d := NewDAG() d := NewDAG[string]()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
id := fmt.Sprint(n) id := fmt.Sprint(n)
if err := d.AddVertex(id, nil); err != nil { if err := d.AddVertex(id, string(id)); err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -12,30 +12,30 @@ import (
) )
// MockDAG is a mock of DAG interface. // MockDAG is a mock of DAG interface.
type MockDAG struct { type MockDAG[T comparable] struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockDAGMockRecorder recorder *MockDAGMockRecorder[T]
} }
// MockDAGMockRecorder is the mock recorder for MockDAG. // MockDAGMockRecorder is the mock recorder for MockDAG.
type MockDAGMockRecorder struct { type MockDAGMockRecorder[T comparable] struct {
mock *MockDAG mock *MockDAG[T]
} }
// NewMockDAG creates a new mock instance. // NewMockDAG creates a new mock instance.
func NewMockDAG(ctrl *gomock.Controller) *MockDAG { func NewMockDAG[T comparable](ctrl *gomock.Controller) *MockDAG[T] {
mock := &MockDAG{ctrl: ctrl} mock := &MockDAG[T]{ctrl: ctrl}
mock.recorder = &MockDAGMockRecorder{mock} mock.recorder = &MockDAGMockRecorder[T]{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDAG) EXPECT() *MockDAGMockRecorder { func (m *MockDAG[T]) EXPECT() *MockDAGMockRecorder[T] {
return m.recorder return m.recorder
} }
// AddEdge mocks base method. // AddEdge mocks base method.
func (m *MockDAG) AddEdge(fromVertexID, toVertexID string) error { func (m *MockDAG[T]) AddEdge(fromVertexID, toVertexID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddEdge", fromVertexID, toVertexID) ret := m.ctrl.Call(m, "AddEdge", fromVertexID, toVertexID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -43,13 +43,13 @@ func (m *MockDAG) AddEdge(fromVertexID, toVertexID string) error {
} }
// AddEdge indicates an expected call of AddEdge. // AddEdge indicates an expected call of AddEdge.
func (mr *MockDAGMockRecorder) AddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) AddEdge(fromVertexID, toVertexID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEdge", reflect.TypeOf((*MockDAG)(nil).AddEdge), fromVertexID, toVertexID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEdge", reflect.TypeOf((*MockDAG[T])(nil).AddEdge), fromVertexID, toVertexID)
} }
// AddVertex mocks base method. // AddVertex mocks base method.
func (m *MockDAG) AddVertex(id string, value any) error { func (m *MockDAG[T]) AddVertex(id string, value T) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddVertex", id, value) ret := m.ctrl.Call(m, "AddVertex", id, value)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -57,13 +57,13 @@ func (m *MockDAG) AddVertex(id string, value any) error {
} }
// AddVertex indicates an expected call of AddVertex. // AddVertex indicates an expected call of AddVertex.
func (mr *MockDAGMockRecorder) AddVertex(id, value interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) AddVertex(id, value interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddVertex", reflect.TypeOf((*MockDAG)(nil).AddVertex), id, value) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddVertex", reflect.TypeOf((*MockDAG[T])(nil).AddVertex), id, value)
} }
// CanAddEdge mocks base method. // CanAddEdge mocks base method.
func (m *MockDAG) CanAddEdge(fromVertexID, toVertexID string) bool { func (m *MockDAG[T]) CanAddEdge(fromVertexID, toVertexID string) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CanAddEdge", fromVertexID, toVertexID) ret := m.ctrl.Call(m, "CanAddEdge", fromVertexID, toVertexID)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
@ -71,13 +71,13 @@ func (m *MockDAG) CanAddEdge(fromVertexID, toVertexID string) bool {
} }
// CanAddEdge indicates an expected call of CanAddEdge. // CanAddEdge indicates an expected call of CanAddEdge.
func (mr *MockDAGMockRecorder) CanAddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) CanAddEdge(fromVertexID, toVertexID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanAddEdge", reflect.TypeOf((*MockDAG)(nil).CanAddEdge), fromVertexID, toVertexID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanAddEdge", reflect.TypeOf((*MockDAG[T])(nil).CanAddEdge), fromVertexID, toVertexID)
} }
// DeleteEdge mocks base method. // DeleteEdge mocks base method.
func (m *MockDAG) DeleteEdge(fromVertexID, toVertexID string) error { func (m *MockDAG[T]) DeleteEdge(fromVertexID, toVertexID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteEdge", fromVertexID, toVertexID) ret := m.ctrl.Call(m, "DeleteEdge", fromVertexID, toVertexID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -85,82 +85,110 @@ func (m *MockDAG) DeleteEdge(fromVertexID, toVertexID string) error {
} }
// DeleteEdge indicates an expected call of DeleteEdge. // DeleteEdge indicates an expected call of DeleteEdge.
func (mr *MockDAGMockRecorder) DeleteEdge(fromVertexID, toVertexID interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) DeleteEdge(fromVertexID, toVertexID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEdge", reflect.TypeOf((*MockDAG)(nil).DeleteEdge), fromVertexID, toVertexID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEdge", reflect.TypeOf((*MockDAG[T])(nil).DeleteEdge), fromVertexID, toVertexID)
} }
// DeleteVertex mocks base method. // DeleteVertex mocks base method.
func (m *MockDAG) DeleteVertex(id string) { func (m *MockDAG[T]) DeleteVertex(id string) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "DeleteVertex", id) m.ctrl.Call(m, "DeleteVertex", id)
} }
// DeleteVertex indicates an expected call of DeleteVertex. // DeleteVertex indicates an expected call of DeleteVertex.
func (mr *MockDAGMockRecorder) DeleteVertex(id interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) DeleteVertex(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVertex", reflect.TypeOf((*MockDAG)(nil).DeleteVertex), id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVertex", reflect.TypeOf((*MockDAG[T])(nil).DeleteVertex), id)
}
// GetRandomVertices mocks base method.
func (m *MockDAG[T]) GetRandomVertices(n uint) map[string]*dag.Vertex[T] {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRandomVertices", n)
ret0, _ := ret[0].(map[string]*dag.Vertex[T])
return ret0
}
// GetRandomVertices indicates an expected call of GetRandomVertices.
func (mr *MockDAGMockRecorder[T]) GetRandomVertices(n interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRandomVertices", reflect.TypeOf((*MockDAG[T])(nil).GetRandomVertices), n)
} }
// GetSinkVertices mocks base method. // GetSinkVertices mocks base method.
func (m *MockDAG) GetSinkVertices() map[string]*dag.Vertex { func (m *MockDAG[T]) GetSinkVertices() map[string]*dag.Vertex[T] {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSinkVertices") ret := m.ctrl.Call(m, "GetSinkVertices")
ret0, _ := ret[0].(map[string]*dag.Vertex) ret0, _ := ret[0].(map[string]*dag.Vertex[T])
return ret0 return ret0
} }
// GetSinkVertices indicates an expected call of GetSinkVertices. // GetSinkVertices indicates an expected call of GetSinkVertices.
func (mr *MockDAGMockRecorder) GetSinkVertices() *gomock.Call { func (mr *MockDAGMockRecorder[T]) GetSinkVertices() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSinkVertices", reflect.TypeOf((*MockDAG)(nil).GetSinkVertices)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSinkVertices", reflect.TypeOf((*MockDAG[T])(nil).GetSinkVertices))
} }
// GetSourceVertices mocks base method. // GetSourceVertices mocks base method.
func (m *MockDAG) GetSourceVertices() map[string]*dag.Vertex { func (m *MockDAG[T]) GetSourceVertices() map[string]*dag.Vertex[T] {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSourceVertices") ret := m.ctrl.Call(m, "GetSourceVertices")
ret0, _ := ret[0].(map[string]*dag.Vertex) ret0, _ := ret[0].(map[string]*dag.Vertex[T])
return ret0 return ret0
} }
// GetSourceVertices indicates an expected call of GetSourceVertices. // GetSourceVertices indicates an expected call of GetSourceVertices.
func (mr *MockDAGMockRecorder) GetSourceVertices() *gomock.Call { func (mr *MockDAGMockRecorder[T]) GetSourceVertices() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceVertices", reflect.TypeOf((*MockDAG)(nil).GetSourceVertices)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceVertices", reflect.TypeOf((*MockDAG[T])(nil).GetSourceVertices))
} }
// GetVertex mocks base method. // GetVertex mocks base method.
func (m *MockDAG) GetVertex(id string) (*dag.Vertex, error) { func (m *MockDAG[T]) GetVertex(id string) (*dag.Vertex[T], error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVertex", id) ret := m.ctrl.Call(m, "GetVertex", id)
ret0, _ := ret[0].(*dag.Vertex) ret0, _ := ret[0].(*dag.Vertex[T])
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetVertex indicates an expected call of GetVertex. // GetVertex indicates an expected call of GetVertex.
func (mr *MockDAGMockRecorder) GetVertex(id interface{}) *gomock.Call { func (mr *MockDAGMockRecorder[T]) GetVertex(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG)(nil).GetVertex), id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG[T])(nil).GetVertex), id)
}
// GetVertexKeys mocks base method.
func (m *MockDAG[T]) GetVertexKeys() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVertexKeys")
ret0, _ := ret[0].([]string)
return ret0
}
// GetVertexKeys indicates an expected call of GetVertexKeys.
func (mr *MockDAGMockRecorder[T]) GetVertexKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDAG[T])(nil).GetVertexKeys))
} }
// GetVertices mocks base method. // GetVertices mocks base method.
func (m *MockDAG) GetVertices() map[string]*dag.Vertex { func (m *MockDAG[T]) GetVertices() map[string]*dag.Vertex[T] {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetVertices") ret := m.ctrl.Call(m, "GetVertices")
ret0, _ := ret[0].(map[string]*dag.Vertex) ret0, _ := ret[0].(map[string]*dag.Vertex[T])
return ret0 return ret0
} }
// GetVertices indicates an expected call of GetVertices. // GetVertices indicates an expected call of GetVertices.
func (mr *MockDAGMockRecorder) GetVertices() *gomock.Call { func (mr *MockDAGMockRecorder[T]) GetVertices() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertices", reflect.TypeOf((*MockDAG)(nil).GetVertices)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertices", reflect.TypeOf((*MockDAG[T])(nil).GetVertices))
} }
// VertexCount mocks base method. // VertexCount mocks base method.
func (m *MockDAG) VertexCount() int { func (m *MockDAG[T]) VertexCount() int {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VertexCount") ret := m.ctrl.Call(m, "VertexCount")
ret0, _ := ret[0].(int) ret0, _ := ret[0].(int)
@ -168,7 +196,7 @@ func (m *MockDAG) VertexCount() int {
} }
// VertexCount indicates an expected call of VertexCount. // VertexCount indicates an expected call of VertexCount.
func (mr *MockDAGMockRecorder) VertexCount() *gomock.Call { func (mr *MockDAGMockRecorder[T]) VertexCount() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VertexCount", reflect.TypeOf((*MockDAG)(nil).VertexCount)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VertexCount", reflect.TypeOf((*MockDAG[T])(nil).VertexCount))
} }

View File

@ -19,62 +19,52 @@ package dag
import "d7y.io/dragonfly/v2/pkg/container/set" import "d7y.io/dragonfly/v2/pkg/container/set"
// Vertex is a vertex of the directed acyclic graph. // Vertex is a vertex of the directed acyclic graph.
type Vertex struct { type Vertex[T comparable] struct {
ID string ID string
Value any Value T
Parents set.SafeSet Parents set.SafeSet[*Vertex[T]]
Children set.SafeSet Children set.SafeSet[*Vertex[T]]
} }
// New returns a new Vertex instance. // New returns a new Vertex instance.
func NewVertex(id string, value any) *Vertex { func NewVertex[T comparable](id string, value T) *Vertex[T] {
return &Vertex{ return &Vertex[T]{
ID: id, ID: id,
Value: value, Value: value,
Parents: set.NewSafeSet(), Parents: set.NewSafeSet[*Vertex[T]](),
Children: set.NewSafeSet(), Children: set.NewSafeSet[*Vertex[T]](),
} }
} }
// Degree returns the degree of vertex. // Degree returns the degree of vertex.
func (v *Vertex) Degree() int { func (v *Vertex[T]) Degree() int {
return int(v.Parents.Len() + v.Children.Len()) return int(v.Parents.Len() + v.Children.Len())
} }
// InDegree returns the indegree of vertex. // InDegree returns the indegree of vertex.
func (v *Vertex) InDegree() int { func (v *Vertex[T]) InDegree() int {
return int(v.Parents.Len()) return int(v.Parents.Len())
} }
// OutDegree returns the outdegree of vertex. // OutDegree returns the outdegree of vertex.
func (v *Vertex) OutDegree() int { func (v *Vertex[T]) OutDegree() int {
return int(v.Children.Len()) return int(v.Children.Len())
} }
// DeleteInEdges deletes inedges of vertex. // DeleteInEdges deletes inedges of vertex.
func (v *Vertex) DeleteInEdges() { func (v *Vertex[T]) DeleteInEdges() {
for _, value := range v.Parents.Values() { for _, parent := range v.Parents.Values() {
vertex, ok := value.(*Vertex) parent.Children.Delete(v)
if !ok {
continue
}
vertex.Children.Delete(v)
} }
v.Parents = set.NewSafeSet() v.Parents = set.NewSafeSet[*Vertex[T]]()
} }
// DeleteOutEdges deletes outedges of vertex. // DeleteOutEdges deletes outedges of vertex.
func (v *Vertex) DeleteOutEdges() { func (v *Vertex[T]) DeleteOutEdges() {
for _, value := range v.Children.Values() { for _, child := range v.Children.Values() {
vertex, ok := value.(*Vertex) child.Parents.Delete(v)
if !ok {
continue
}
vertex.Parents.Delete(v)
} }
v.Children = set.NewSafeSet() v.Children = set.NewSafeSet[*Vertex[T]]()
} }

View File

@ -43,16 +43,16 @@ func TestVertexDegree(t *testing.T) {
assert.Equal(v.Value, mockVertexValue) assert.Equal(v.Value, mockVertexValue)
assert.Equal(v.Degree(), 0) assert.Equal(v.Degree(), 0)
v.Parents.Add(mockVertexID) v.Parents.Add(v)
assert.Equal(v.Degree(), 1) assert.Equal(v.Degree(), 1)
v.Children.Add(mockVertexID) v.Children.Add(v)
assert.Equal(v.Degree(), 2) assert.Equal(v.Degree(), 2)
v.Parents.Delete(mockVertexID) v.Parents.Delete(v)
assert.Equal(v.Degree(), 1) assert.Equal(v.Degree(), 1)
v.Children.Delete(mockVertexID) v.Children.Delete(v)
assert.Equal(v.Degree(), 0) assert.Equal(v.Degree(), 0)
} }
@ -63,16 +63,16 @@ func TestVertexInDegree(t *testing.T) {
assert.Equal(v.Value, mockVertexValue) assert.Equal(v.Value, mockVertexValue)
assert.Equal(v.InDegree(), 0) assert.Equal(v.InDegree(), 0)
v.Parents.Add(mockVertexID) v.Parents.Add(v)
assert.Equal(v.InDegree(), 1) assert.Equal(v.InDegree(), 1)
v.Children.Add(mockVertexID) v.Children.Add(v)
assert.Equal(v.InDegree(), 1) assert.Equal(v.InDegree(), 1)
v.Parents.Delete(mockVertexID) v.Parents.Delete(v)
assert.Equal(v.InDegree(), 0) assert.Equal(v.InDegree(), 0)
v.Children.Delete(mockVertexID) v.Children.Delete(v)
assert.Equal(v.InDegree(), 0) assert.Equal(v.InDegree(), 0)
} }
@ -83,16 +83,16 @@ func TestVertexOutDegree(t *testing.T) {
assert.Equal(v.Value, mockVertexValue) assert.Equal(v.Value, mockVertexValue)
assert.Equal(v.OutDegree(), 0) assert.Equal(v.OutDegree(), 0)
v.Parents.Add(mockVertexID) v.Parents.Add(v)
assert.Equal(v.OutDegree(), 0) assert.Equal(v.OutDegree(), 0)
v.Children.Add(mockVertexID) v.Children.Add(v)
assert.Equal(v.OutDegree(), 1) assert.Equal(v.OutDegree(), 1)
v.Parents.Delete(mockVertexID) v.Parents.Delete(v)
assert.Equal(v.OutDegree(), 1) assert.Equal(v.OutDegree(), 1)
v.Children.Delete(mockVertexID) v.Children.Delete(v)
assert.Equal(v.OutDegree(), 0) assert.Equal(v.OutDegree(), 0)
} }

View File

@ -31,7 +31,6 @@ import (
logger "d7y.io/dragonfly/v2/internal/dflog" logger "d7y.io/dragonfly/v2/internal/dflog"
"d7y.io/dragonfly/v2/pkg/container/set" "d7y.io/dragonfly/v2/pkg/container/set"
"d7y.io/dragonfly/v2/pkg/dag"
"d7y.io/dragonfly/v2/pkg/rpc/scheduler" "d7y.io/dragonfly/v2/pkg/rpc/scheduler"
) )
@ -135,7 +134,7 @@ type Peer struct {
Host *Host Host *Host
// BlockPeers is bad peer ids. // BlockPeers is bad peer ids.
BlockPeers set.SafeSet BlockPeers set.SafeSet[string]
// NeedBackToSource needs downloaded from source. // NeedBackToSource needs downloaded from source.
// //
@ -171,7 +170,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer {
Stream: &atomic.Value{}, Stream: &atomic.Value{},
Task: task, Task: task,
Host: host, Host: host,
BlockPeers: set.NewSafeSet(), BlockPeers: set.NewSafeSet[string](),
NeedBackToSource: atomic.NewBool(false), NeedBackToSource: atomic.NewBool(false),
IsBackToSource: atomic.NewBool(false), IsBackToSource: atomic.NewBool(false),
CreateAt: atomic.NewTime(time.Now()), CreateAt: atomic.NewTime(time.Now()),
@ -222,7 +221,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer {
}, },
PeerEventDownloadFromBackToSource: func(e *fsm.Event) { PeerEventDownloadFromBackToSource: func(e *fsm.Event) {
p.IsBackToSource.Store(true) p.IsBackToSource.Store(true)
p.Task.BackToSourcePeers.Add(p) p.Task.BackToSourcePeers.Add(p.ID)
if err := p.Task.DeletePeerInEdges(p.ID); err != nil { if err := p.Task.DeletePeerInEdges(p.ID); err != nil {
p.Log.Errorf("delete peer inedges failed: %s", err.Error()) p.Log.Errorf("delete peer inedges failed: %s", err.Error())
@ -234,7 +233,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer {
}, },
PeerEventDownloadSucceeded: func(e *fsm.Event) { PeerEventDownloadSucceeded: func(e *fsm.Event) {
if e.Src == PeerStateBackToSource { if e.Src == PeerStateBackToSource {
p.Task.BackToSourcePeers.Delete(p) p.Task.BackToSourcePeers.Delete(p.ID)
} }
if err := p.Task.DeletePeerInEdges(p.ID); err != nil { if err := p.Task.DeletePeerInEdges(p.ID); err != nil {
@ -249,7 +248,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer {
PeerEventDownloadFailed: func(e *fsm.Event) { PeerEventDownloadFailed: func(e *fsm.Event) {
if e.Src == PeerStateBackToSource { if e.Src == PeerStateBackToSource {
p.Task.PeerFailedCount.Inc() p.Task.PeerFailedCount.Inc()
p.Task.BackToSourcePeers.Delete(p) p.Task.BackToSourcePeers.Delete(p.ID)
} }
if err := p.Task.DeletePeerInEdges(p.ID); err != nil { if err := p.Task.DeletePeerInEdges(p.ID); err != nil {
@ -317,23 +316,12 @@ func (p *Peer) Parents() []*Peer {
} }
var parents []*Peer var parents []*Peer
for _, value := range vertex.Parents.Values() { for _, parent := range vertex.Parents.Values() {
vertex, ok := value.(*dag.Vertex) if parent.Value == nil {
if !ok {
continue continue
} }
vertexVal := vertex.Value parents = append(parents, parent.Value)
if vertexVal == nil {
continue
}
parent, ok := vertexVal.(*Peer)
if !ok {
continue
}
parents = append(parents, parent)
} }
return parents return parents
@ -348,23 +336,12 @@ func (p *Peer) Children() []*Peer {
} }
var children []*Peer var children []*Peer
for _, value := range vertex.Children.Values() { for _, child := range vertex.Children.Values() {
vertex, ok := value.(*dag.Vertex) if child.Value == nil {
if !ok {
continue continue
} }
vertexVal := vertex.Value children = append(children, child.Value)
if vertexVal == nil {
continue
}
child, ok := vertexVal.(*Peer)
if !ok {
continue
}
children = append(children, child)
} }
return children return children

View File

@ -18,8 +18,6 @@ package resource
import ( import (
"errors" "errors"
"math/rand"
reflect "reflect"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -106,7 +104,7 @@ type Task struct {
BackToSourceLimit *atomic.Int32 BackToSourceLimit *atomic.Int32
// BackToSourcePeers is back-to-source sync map. // BackToSourcePeers is back-to-source sync map.
BackToSourcePeers set.SafeSet BackToSourcePeers set.SafeSet[string]
// Task state machine. // Task state machine.
FSM *fsm.FSM FSM *fsm.FSM
@ -115,7 +113,7 @@ type Task struct {
Pieces *sync.Map Pieces *sync.Map
// DAG is directed acyclic graph of peers. // DAG is directed acyclic graph of peers.
DAG dag.DAG DAG dag.DAG[*Peer]
// PeerFailedCount is peer failed count, // PeerFailedCount is peer failed count,
// if one peer succeeds, the value is reset to zero. // if one peer succeeds, the value is reset to zero.
@ -141,9 +139,9 @@ func NewTask(id, url string, taskType base.TaskType, meta *base.UrlMeta, options
ContentLength: atomic.NewInt64(0), ContentLength: atomic.NewInt64(0),
TotalPieceCount: atomic.NewInt32(0), TotalPieceCount: atomic.NewInt32(0),
BackToSourceLimit: atomic.NewInt32(0), BackToSourceLimit: atomic.NewInt32(0),
BackToSourcePeers: set.NewSafeSet(), BackToSourcePeers: set.NewSafeSet[string](),
Pieces: &sync.Map{}, Pieces: &sync.Map{},
DAG: dag.NewDAG(), DAG: dag.NewDAG[*Peer](),
PeerFailedCount: atomic.NewInt32(0), PeerFailedCount: atomic.NewInt32(0),
CreateAt: atomic.NewTime(time.Now()), CreateAt: atomic.NewTime(time.Now()),
UpdateAt: atomic.NewTime(time.Now()), UpdateAt: atomic.NewTime(time.Now()),
@ -189,54 +187,14 @@ func (t *Task) LoadPeer(key string) (*Peer, bool) {
return nil, false return nil, false
} }
value := vertex.Value return vertex.Value, true
if value == nil {
return nil, false
}
return value.(*Peer), true
} }
// LoadRandomPeers return random peers. // LoadRandomPeers return random peers.
func (t *Task) LoadRandomPeers(n uint) []*Peer { func (t *Task) LoadRandomPeers(n uint) []*Peer {
var peers []*Peer var peers []*Peer
vertices := t.DAG.GetVertices() for _, vertex := range t.DAG.GetRandomVertices(n) {
keys := reflect.ValueOf(vertices).MapKeys() peers = append(peers, vertex.Value)
if int(n) >= len(keys) {
for _, vertex := range vertices {
value := vertex.Value
if value == nil {
continue
}
peer, ok := value.(*Peer)
if !ok {
continue
}
peers = append(peers, peer)
}
return peers
}
rand.Seed(time.Now().Unix())
permutation := rand.Perm(len(keys))[:n]
for _, v := range permutation {
key := keys[v].String()
vertex := vertices[key]
value := vertex.Value
if value == nil {
continue
}
peer, ok := value.(*Peer)
if !ok {
continue
}
peers = append(peers, peer)
} }
return peers return peers
@ -282,23 +240,12 @@ func (t *Task) DeletePeerInEdges(key string) error {
return err return err
} }
for _, value := range vertex.Parents.Values() { for _, parent := range vertex.Parents.Values() {
vertex, ok := value.(*dag.Vertex) if parent.Value == nil {
if !ok {
continue continue
} }
vertexVal := vertex.Value parent.Value.Host.UploadPeerCount.Dec()
if vertexVal == nil {
continue
}
parent, ok := vertexVal.(*Peer)
if !ok {
continue
}
parent.Host.UploadPeerCount.Dec()
} }
vertex.DeleteInEdges() vertex.DeleteInEdges()
@ -312,16 +259,11 @@ func (t *Task) DeletePeerOutEdges(key string) error {
return err return err
} }
value := vertex.Value peer := vertex.Value
if value == nil { if peer == nil {
return errors.New("vertex value is nil") return errors.New("vertex value is nil")
} }
peer, ok := value.(*Peer)
if !ok {
return errors.New("vertex value is not peer")
}
peer.Host.UploadPeerCount.Sub(int32(vertex.Children.Len())) peer.Host.UploadPeerCount.Sub(int32(vertex.Children.Len()))
vertex.DeleteOutEdges() vertex.DeleteOutEdges()
return nil return nil
@ -366,13 +308,8 @@ func (t *Task) PeerOutDegree(key string) (int, error) {
func (t *Task) HasAvailablePeer() bool { func (t *Task) HasAvailablePeer() bool {
var hasAvailablePeer bool var hasAvailablePeer bool
for _, vertex := range t.DAG.GetVertices() { for _, vertex := range t.DAG.GetVertices() {
value := vertex.Value peer := vertex.Value
if value == nil { if peer == nil {
continue
}
peer, ok := value.(*Peer)
if !ok {
continue continue
} }
@ -389,13 +326,8 @@ func (t *Task) HasAvailablePeer() bool {
func (t *Task) LoadSeedPeer() (*Peer, bool) { func (t *Task) LoadSeedPeer() (*Peer, bool) {
var peers []*Peer var peers []*Peer
for _, vertex := range t.DAG.GetVertices() { for _, vertex := range t.DAG.GetVertices() {
value := vertex.Value peer := vertex.Value
if value == nil { if peer == nil {
continue
}
peer, ok := value.(*Peer)
if !ok {
continue continue
} }
@ -473,12 +405,11 @@ func (t *Task) CanBackToSource() bool {
// NotifyPeers notify all peers in the task with the state code. // NotifyPeers notify all peers in the task with the state code.
func (t *Task) NotifyPeers(peerPacket *rpcscheduler.PeerPacket, event string) { func (t *Task) NotifyPeers(peerPacket *rpcscheduler.PeerPacket, event string) {
for _, vertex := range t.DAG.GetVertices() { for _, vertex := range t.DAG.GetVertices() {
value := vertex.Value peer := vertex.Value
if value == nil { if peer == nil {
continue continue
} }
peer := value.(*Peer)
if peer.FSM.Is(PeerStateRunning) { if peer.FSM.Is(PeerStateRunning) {
stream, ok := peer.LoadStream() stream, ok := peer.LoadStream()
if !ok { if !ok {

View File

@ -37,7 +37,7 @@ func (m *MockScheduler) EXPECT() *MockSchedulerMockRecorder {
} }
// FindParent mocks base method. // FindParent mocks base method.
func (m *MockScheduler) FindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) (*resource.Peer, bool) { func (m *MockScheduler) FindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) (*resource.Peer, bool) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindParent", arg0, arg1, arg2) ret := m.ctrl.Call(m, "FindParent", arg0, arg1, arg2)
ret0, _ := ret[0].(*resource.Peer) ret0, _ := ret[0].(*resource.Peer)
@ -52,7 +52,7 @@ func (mr *MockSchedulerMockRecorder) FindParent(arg0, arg1, arg2 interface{}) *g
} }
// NotifyAndFindParent mocks base method. // NotifyAndFindParent mocks base method.
func (m *MockScheduler) NotifyAndFindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) ([]*resource.Peer, bool) { func (m *MockScheduler) NotifyAndFindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) ([]*resource.Peer, bool) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotifyAndFindParent", arg0, arg1, arg2) ret := m.ctrl.Call(m, "NotifyAndFindParent", arg0, arg1, arg2)
ret0, _ := ret[0].([]*resource.Peer) ret0, _ := ret[0].([]*resource.Peer)
@ -67,7 +67,7 @@ func (mr *MockSchedulerMockRecorder) NotifyAndFindParent(arg0, arg1, arg2 interf
} }
// ScheduleParent mocks base method. // ScheduleParent mocks base method.
func (m *MockScheduler) ScheduleParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) { func (m *MockScheduler) ScheduleParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "ScheduleParent", arg0, arg1, arg2) m.ctrl.Call(m, "ScheduleParent", arg0, arg1, arg2)
} }

View File

@ -33,13 +33,13 @@ import (
type Scheduler interface { type Scheduler interface {
// ScheduleParent schedule a parent and candidates to a peer. // ScheduleParent schedule a parent and candidates to a peer.
ScheduleParent(context.Context, *resource.Peer, set.SafeSet) ScheduleParent(context.Context, *resource.Peer, set.SafeSet[string])
// Find the parent that best matches the evaluation and notify peer. // Find the parent that best matches the evaluation and notify peer.
NotifyAndFindParent(context.Context, *resource.Peer, set.SafeSet) ([]*resource.Peer, bool) NotifyAndFindParent(context.Context, *resource.Peer, set.SafeSet[string]) ([]*resource.Peer, bool)
// Find the parent that best matches the evaluation. // Find the parent that best matches the evaluation.
FindParent(context.Context, *resource.Peer, set.SafeSet) (*resource.Peer, bool) FindParent(context.Context, *resource.Peer, set.SafeSet[string]) (*resource.Peer, bool)
} }
type scheduler struct { type scheduler struct {
@ -62,7 +62,7 @@ func New(cfg *config.SchedulerConfig, dynconfig config.DynconfigInterface, plugi
} }
// ScheduleParent schedule a parent and candidates to a peer. // ScheduleParent schedule a parent and candidates to a peer.
func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) { func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) {
var n int var n int
for { for {
select { select {
@ -141,7 +141,7 @@ func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blo
} }
// NotifyAndFindParent finds parent that best matches the evaluation and notify peer. // NotifyAndFindParent finds parent that best matches the evaluation and notify peer.
func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) ([]*resource.Peer, bool) { func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) ([]*resource.Peer, bool) {
// Only PeerStateRunning peers need to be rescheduled, // Only PeerStateRunning peers need to be rescheduled,
// and other states including the PeerStateBackToSource indicate that // and other states including the PeerStateBackToSource indicate that
// they have been scheduled. // they have been scheduled.
@ -209,7 +209,7 @@ func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer
} }
// FindParent finds parent that best matches the evaluation. // FindParent finds parent that best matches the evaluation.
func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) (*resource.Peer, bool) { func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) (*resource.Peer, bool) {
// Filter the candidate parent that can be scheduled. // Filter the candidate parent that can be scheduled.
candidateParents := s.filterCandidateParents(peer, blocklist) candidateParents := s.filterCandidateParents(peer, blocklist)
if len(candidateParents) == 0 { if len(candidateParents) == 0 {
@ -231,7 +231,7 @@ func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blockli
} }
// Filter the candidate parent that can be scheduled. // Filter the candidate parent that can be scheduled.
func (s *scheduler) filterCandidateParents(peer *resource.Peer, blocklist set.SafeSet) []*resource.Peer { func (s *scheduler) filterCandidateParents(peer *resource.Peer, blocklist set.SafeSet[string]) []*resource.Peer {
filterParentLimit := config.DefaultSchedulerFilterParentLimit filterParentLimit := config.DefaultSchedulerFilterParentLimit
if config, ok := s.dynconfig.GetSchedulerClusterConfig(); ok && filterParentLimit > 0 { if config, ok := s.dynconfig.GetSchedulerClusterConfig(); ok && filterParentLimit > 0 {
filterParentLimit = int(config.FilterParentLimit) filterParentLimit = int(config.FilterParentLimit)

View File

@ -127,12 +127,12 @@ func TestScheduler_New(t *testing.T) {
func TestScheduler_ScheduleParent(t *testing.T) { func TestScheduler_ScheduleParent(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mock func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) mock func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder)
expect func(t *testing.T, peer *resource.Peer) expect func(t *testing.T, peer *resource.Peer)
}{ }{
{ {
name: "context was done", name: "context was done",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
cancel() cancel()
}, },
@ -143,7 +143,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "peer needs back-to-source and peer stream load failed", name: "peer needs back-to-source and peer stream load failed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.NeedBackToSource.Store(true) peer.NeedBackToSource.Store(true)
@ -156,7 +156,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "peer needs back-to-source and send Code_SchedNeedBackSource code failed", name: "peer needs back-to-source and send Code_SchedNeedBackSource code failed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.NeedBackToSource.Store(true) peer.NeedBackToSource.Store(true)
@ -173,7 +173,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "peer needs back-to-source and send Code_SchedNeedBackSource code success", name: "peer needs back-to-source and send Code_SchedNeedBackSource code success",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.NeedBackToSource.Store(true) peer.NeedBackToSource.Store(true)
@ -191,7 +191,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "peer needs back-to-source and task state is TaskStateFailed", name: "peer needs back-to-source and task state is TaskStateFailed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.NeedBackToSource.Store(true) peer.NeedBackToSource.Store(true)
@ -210,7 +210,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "schedule exceeds RetryBackSourceLimit and peer stream load failed", name: "schedule exceeds RetryBackSourceLimit and peer stream load failed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
@ -223,7 +223,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "schedule exceeds RetryLimit and peer stream load failed", name: "schedule exceeds RetryLimit and peer stream load failed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
@ -238,7 +238,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code failed", name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code failed",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
@ -258,7 +258,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code success", name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code success",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
@ -278,7 +278,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
}, },
{ {
name: "schedule succeeded", name: "schedule succeeded",
mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
task := peer.Task task := peer.Task
task.StorePeer(peer) task.StorePeer(peer)
task.StorePeer(seedPeer) task.StorePeer(seedPeer)
@ -313,7 +313,7 @@ func TestScheduler_ScheduleParent(t *testing.T) {
peer := resource.NewPeer(mockPeerID, mockTask, mockHost) peer := resource.NewPeer(mockPeerID, mockTask, mockHost)
mockSeedHost := resource.NewHost(mockRawSeedHost, resource.WithHostType(resource.HostTypeSuperSeed)) mockSeedHost := resource.NewHost(mockRawSeedHost, resource.WithHostType(resource.HostTypeSuperSeed))
seedPeer := resource.NewPeer(mockSeedPeerID, mockTask, mockSeedHost) seedPeer := resource.NewPeer(mockSeedPeerID, mockTask, mockSeedHost)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
tc.mock(cancel, peer, seedPeer, blocklist, stream, stream.EXPECT(), dynconfig.EXPECT()) tc.mock(cancel, peer, seedPeer, blocklist, stream, stream.EXPECT(), dynconfig.EXPECT())
scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir)
@ -326,12 +326,12 @@ func TestScheduler_ScheduleParent(t *testing.T) {
func TestScheduler_NotifyAndFindParent(t *testing.T) { func TestScheduler_NotifyAndFindParent(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mock func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) mock func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder)
expect func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) expect func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool)
}{ }{
{ {
name: "peer state is PeerStatePending", name: "peer state is PeerStatePending",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStatePending) peer.FSM.SetState(resource.PeerStatePending)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -341,7 +341,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateReceivedSmall", name: "peer state is PeerStateReceivedSmall",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateReceivedSmall) peer.FSM.SetState(resource.PeerStateReceivedSmall)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -351,7 +351,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateReceivedNormal", name: "peer state is PeerStateReceivedNormal",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateReceivedNormal) peer.FSM.SetState(resource.PeerStateReceivedNormal)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -361,7 +361,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateBackToSource", name: "peer state is PeerStateBackToSource",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateBackToSource) peer.FSM.SetState(resource.PeerStateBackToSource)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -371,7 +371,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateSucceeded", name: "peer state is PeerStateSucceeded",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateSucceeded) peer.FSM.SetState(resource.PeerStateSucceeded)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -381,7 +381,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateFailed", name: "peer state is PeerStateFailed",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateFailed) peer.FSM.SetState(resource.PeerStateFailed)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -391,7 +391,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer state is PeerStateLeave", name: "peer state is PeerStateLeave",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateLeave) peer.FSM.SetState(resource.PeerStateLeave)
}, },
expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) {
@ -401,7 +401,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "task peers is empty", name: "task peers is empty",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1) md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1)
@ -413,7 +413,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "task contains only one peer and peer is itself", name: "task contains only one peer and peer is itself",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -426,7 +426,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer is in blocklist", name: "peer is in blocklist",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeer) peer.Task.StorePeer(mockPeer)
@ -441,7 +441,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer is bad node", name: "peer is bad node",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.FSM.SetState(resource.PeerStateFailed) peer.FSM.SetState(resource.PeerStateFailed)
peer.Task.StorePeer(mockPeer) peer.Task.StorePeer(mockPeer)
@ -453,7 +453,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "parent is peer's descendant", name: "parent is peer's descendant",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -471,7 +471,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "parent is peer's ancestor", name: "parent is peer's ancestor",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -489,7 +489,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "parent free upload load is zero", name: "parent free upload load is zero",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -505,7 +505,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer stream is empty", name: "peer stream is empty",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -521,10 +521,10 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "peer stream send failed", name: "peer stream send failed",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
peer.Task.BackToSourcePeers.Add(mockPeer) peer.Task.BackToSourcePeers.Add(mockPeer.ID)
mockPeer.IsBackToSource.Store(true) mockPeer.IsBackToSource.Store(true)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeer) peer.Task.StorePeer(mockPeer)
@ -545,7 +545,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
}, },
{ {
name: "schedule parent", name: "schedule parent",
mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning)
candidatePeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost) candidatePeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost)
@ -553,8 +553,8 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeer) peer.Task.StorePeer(mockPeer)
peer.Task.StorePeer(candidatePeer) peer.Task.StorePeer(candidatePeer)
peer.Task.BackToSourcePeers.Add(mockPeer) peer.Task.BackToSourcePeers.Add(mockPeer.ID)
peer.Task.BackToSourcePeers.Add(candidatePeer) peer.Task.BackToSourcePeers.Add(candidatePeer.ID)
mockPeer.IsBackToSource.Store(true) mockPeer.IsBackToSource.Store(true)
candidatePeer.IsBackToSource.Store(true) candidatePeer.IsBackToSource.Store(true)
mockPeer.Pieces.Set(0) mockPeer.Pieces.Set(0)
@ -585,7 +585,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
mockTask := resource.NewTask(mockTaskID, mockTaskURL, base.TaskType_Normal, mockTaskURLMeta, resource.WithBackToSourceLimit(mockTaskBackToSourceLimit)) mockTask := resource.NewTask(mockTaskID, mockTaskURL, base.TaskType_Normal, mockTaskURLMeta, resource.WithBackToSourceLimit(mockTaskBackToSourceLimit))
peer := resource.NewPeer(mockPeerID, mockTask, mockHost) peer := resource.NewPeer(mockPeerID, mockTask, mockHost)
mockPeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost) mockPeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
tc.mock(peer, mockHost, mockTask, mockPeer, blocklist, stream, dynconfig, stream.EXPECT(), dynconfig.EXPECT()) tc.mock(peer, mockHost, mockTask, mockPeer, blocklist, stream, dynconfig, stream.EXPECT(), dynconfig.EXPECT())
scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir)
@ -598,12 +598,12 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) {
func TestScheduler_FindParent(t *testing.T) { func TestScheduler_FindParent(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mock func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) mock func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder)
expect func(t *testing.T, peer *resource.Peer, mockPeers []*resource.Peer, parent *resource.Peer, ok bool) expect func(t *testing.T, peer *resource.Peer, mockPeers []*resource.Peer, parent *resource.Peer, ok bool)
}{ }{
{ {
name: "task peers is empty", name: "task peers is empty",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1) md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1)
@ -615,7 +615,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "task contains only one peer and peer is itself", name: "task contains only one peer and peer is itself",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -628,7 +628,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "peer is in blocklist", name: "peer is in blocklist",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeers[0]) peer.Task.StorePeer(mockPeers[0])
@ -643,7 +643,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "peer is bad node", name: "peer is bad node",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateFailed) mockPeers[0].FSM.SetState(resource.PeerStateFailed)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -658,7 +658,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "parent is peer's descendant", name: "parent is peer's descendant",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -676,7 +676,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "parent free upload load is zero", name: "parent free upload load is zero",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
@ -692,15 +692,15 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "find back-to-source parent", name: "find back-to-source parent",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
mockPeers[1].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeers[0]) peer.Task.StorePeer(mockPeers[0])
peer.Task.StorePeer(mockPeers[1]) peer.Task.StorePeer(mockPeers[1])
peer.Task.BackToSourcePeers.Add(mockPeers[0]) peer.Task.BackToSourcePeers.Add(mockPeers[0].ID)
peer.Task.BackToSourcePeers.Add(mockPeers[1]) peer.Task.BackToSourcePeers.Add(mockPeers[1].ID)
mockPeers[0].IsBackToSource.Store(true) mockPeers[0].IsBackToSource.Store(true)
mockPeers[1].IsBackToSource.Store(true) mockPeers[1].IsBackToSource.Store(true)
mockPeers[0].Pieces.Set(0) mockPeers[0].Pieces.Set(0)
@ -718,7 +718,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "find seed peer parent", name: "find seed peer parent",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
mockPeers[1].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning)
@ -743,7 +743,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "parent state is PeerStateSucceeded", name: "parent state is PeerStateSucceeded",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateSucceeded) mockPeers[0].FSM.SetState(resource.PeerStateSucceeded)
mockPeers[1].FSM.SetState(resource.PeerStateSucceeded) mockPeers[1].FSM.SetState(resource.PeerStateSucceeded)
@ -765,7 +765,7 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "find parent with ancestor", name: "find parent with ancestor",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
mockPeers[1].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning)
@ -796,15 +796,15 @@ func TestScheduler_FindParent(t *testing.T) {
}, },
{ {
name: "find parent and fetch filterParentLimit from manager dynconfig", name: "find parent and fetch filterParentLimit from manager dynconfig",
mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning)
mockPeers[1].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning)
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.Task.StorePeer(mockPeers[0]) peer.Task.StorePeer(mockPeers[0])
peer.Task.StorePeer(mockPeers[1]) peer.Task.StorePeer(mockPeers[1])
peer.Task.BackToSourcePeers.Add(mockPeers[0]) peer.Task.BackToSourcePeers.Add(mockPeers[0].ID)
peer.Task.BackToSourcePeers.Add(mockPeers[1]) peer.Task.BackToSourcePeers.Add(mockPeers[1].ID)
mockPeers[0].IsBackToSource.Store(true) mockPeers[0].IsBackToSource.Store(true)
mockPeers[1].IsBackToSource.Store(true) mockPeers[1].IsBackToSource.Store(true)
mockPeers[0].Pieces.Set(0) mockPeers[0].Pieces.Set(0)
@ -839,7 +839,7 @@ func TestScheduler_FindParent(t *testing.T) {
mockPeers = append(mockPeers, peer) mockPeers = append(mockPeers, peer)
} }
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
tc.mock(peer, mockPeers, blocklist, dynconfig.EXPECT()) tc.mock(peer, mockPeers, blocklist, dynconfig.EXPECT())
scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir)
parent, ok := scheduler.FindParent(context.Background(), peer, blocklist) parent, ok := scheduler.FindParent(context.Background(), peer, blocklist)

View File

@ -123,7 +123,7 @@ func (s *Service) RegisterPeerTask(ctx context.Context, req *rpcscheduler.PeerTa
case base.SizeScope_SMALL: case base.SizeScope_SMALL:
peer.Log.Info("task size scope is small") peer.Log.Info("task size scope is small")
// There is no need to build a tree, just find the parent and return. // There is no need to build a tree, just find the parent and return.
parent, ok := s.scheduler.FindParent(ctx, peer, set.NewSafeSet()) parent, ok := s.scheduler.FindParent(ctx, peer, set.NewSafeSet[string]())
if !ok { if !ok {
peer.Log.Warn("task size scope is small and it can not select parent") peer.Log.Warn("task size scope is small and it can not select parent")
if err := peer.FSM.Event(resource.PeerEventRegisterNormal); err != nil { if err := peer.FSM.Event(resource.PeerEventRegisterNormal); err != nil {
@ -626,7 +626,7 @@ func (s *Service) handleBeginOfPiece(ctx context.Context, peer *resource.Peer) {
} }
peer.Log.Infof("schedule parent because of peer receive begin of piece") peer.Log.Infof("schedule parent because of peer receive begin of piece")
s.scheduler.ScheduleParent(ctx, peer, set.NewSafeSet()) s.scheduler.ScheduleParent(ctx, peer, set.NewSafeSet[string]())
default: default:
peer.Log.Warnf("peer state is %s when receive the begin of piece", peer.FSM.Current()) peer.Log.Warnf("peer state is %s when receive the begin of piece", peer.FSM.Current())
} }

View File

@ -1656,7 +1656,7 @@ func TestService_LeaveTask(t *testing.T) {
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
mp.Load(gomock.Any()).Return(peer, true).Times(1), mp.Load(gomock.Any()).Return(peer, true).Times(1),
ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1), ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1),
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
mp.Delete(gomock.Eq(peer.ID)).Return().Times(1), mp.Delete(gomock.Eq(peer.ID)).Return().Times(1),
) )
@ -1674,7 +1674,7 @@ func TestService_LeaveTask(t *testing.T) {
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateSucceeded) peer.FSM.SetState(resource.PeerStateSucceeded)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(peer.ID) blocklist.Add(peer.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -1704,7 +1704,7 @@ func TestService_LeaveTask(t *testing.T) {
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
mp.Load(gomock.Any()).Return(peer, true).Times(1), mp.Load(gomock.Any()).Return(peer, true).Times(1),
ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1), ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1),
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
mp.Delete(gomock.Eq(peer.ID)).Return().Times(1), mp.Delete(gomock.Eq(peer.ID)).Return().Times(1),
) )
@ -1722,7 +1722,7 @@ func TestService_LeaveTask(t *testing.T) {
peer.Task.StorePeer(peer) peer.Task.StorePeer(peer)
peer.FSM.SetState(resource.PeerStateFailed) peer.FSM.SetState(resource.PeerStateFailed)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(peer.ID) blocklist.Add(peer.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2310,7 +2310,7 @@ func TestService_handleBeginOfPiece(t *testing.T) {
name: "peer state is PeerStateReceivedNormal", name: "peer state is PeerStateReceivedNormal",
mock: func(peer *resource.Peer, scheduler *mocks.MockSchedulerMockRecorder) { mock: func(peer *resource.Peer, scheduler *mocks.MockSchedulerMockRecorder) {
peer.FSM.SetState(resource.PeerStateReceivedNormal) peer.FSM.SetState(resource.PeerStateReceivedNormal)
scheduler.ScheduleParent(gomock.Any(), gomock.Eq(peer), gomock.Eq(set.NewSafeSet())).Return().Times(1) scheduler.ScheduleParent(gomock.Any(), gomock.Eq(peer), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1)
}, },
expect: func(t *testing.T, peer *resource.Peer) { expect: func(t *testing.T, peer *resource.Peer) {
assert := assert.New(t) assert := assert.New(t)
@ -2537,7 +2537,7 @@ func TestService_handlePieceFail(t *testing.T) {
parent: resource.NewPeer(mockSeedPeerID, mockTask, mockHost), parent: resource.NewPeer(mockSeedPeerID, mockTask, mockHost),
run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(mockSeedPeerID) blocklist.Add(mockSeedPeerID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2566,7 +2566,7 @@ func TestService_handlePieceFail(t *testing.T) {
run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
parent.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(parent.ID) blocklist.Add(parent.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2596,7 +2596,7 @@ func TestService_handlePieceFail(t *testing.T) {
run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
peer.Host.Type = resource.HostTypeNormal peer.Host.Type = resource.HostTypeNormal
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(parent.ID) blocklist.Add(parent.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2625,7 +2625,7 @@ func TestService_handlePieceFail(t *testing.T) {
run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
parent.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(parent.ID) blocklist.Add(parent.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2655,7 +2655,7 @@ func TestService_handlePieceFail(t *testing.T) {
run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
parent.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning)
blocklist := set.NewSafeSet() blocklist := set.NewSafeSet[string]()
blocklist.Add(parent.ID) blocklist.Add(parent.ID)
gomock.InOrder( gomock.InOrder(
mr.PeerManager().Return(peerManager).Times(1), mr.PeerManager().Return(peerManager).Times(1),
@ -2844,7 +2844,7 @@ func TestService_handlePeerFail(t *testing.T) {
peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateRunning)
child.FSM.SetState(resource.PeerStateRunning) child.FSM.SetState(resource.PeerStateRunning)
ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1) ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1)
}, },
expect: func(t *testing.T, peer *resource.Peer, child *resource.Peer) { expect: func(t *testing.T, peer *resource.Peer, child *resource.Peer) {
assert := assert.New(t) assert := assert.New(t)