From 7c2ee7858baacb0925941a9d6ec2cb9037602893 Mon Sep 17 00:00:00 2001 From: Gaius Date: Tue, 26 Jul 2022 21:17:51 +0800 Subject: [PATCH] refactor: set and dag with generics (#1490) Signed-off-by: Gaius --- go.mod | 3 +- go.sum | 2 + pkg/container/list/sorted_list.go | 137 --- pkg/container/list/sorted_list_mock.go | 147 ---- pkg/container/list/sorted_list_test.go | 764 ----------------- pkg/container/list/sorted_unique_list.go | 108 --- pkg/container/list/sorted_unique_list_mock.go | 110 --- pkg/container/list/sorted_unique_list_test.go | 784 ------------------ pkg/container/set/mocks/safe_set_mock.go | 54 +- pkg/container/set/mocks/set_mock.go | 54 +- pkg/container/set/safe_set.go | 36 +- pkg/container/set/safe_set_test.go | 88 +- pkg/container/set/set.go | 32 +- pkg/container/set/set_test.go | 70 +- pkg/dag/dag.go | 158 ++-- pkg/dag/dag_test.go | 230 +++-- pkg/dag/mocks/dag_mock.go | 112 ++- pkg/dag/vertex.go | 48 +- pkg/dag/vertex_test.go | 24 +- scheduler/resource/peer.go | 45 +- scheduler/resource/task.go | 105 +-- scheduler/scheduler/mocks/scheduler_mock.go | 6 +- scheduler/scheduler/scheduler.go | 14 +- scheduler/scheduler/scheduler_test.go | 102 +-- scheduler/service/service.go | 4 +- scheduler/service/service_test.go | 22 +- 26 files changed, 630 insertions(+), 2629 deletions(-) delete mode 100644 pkg/container/list/sorted_list.go delete mode 100644 pkg/container/list/sorted_list_mock.go delete mode 100644 pkg/container/list/sorted_list_test.go delete mode 100644 pkg/container/list/sorted_unique_list.go delete mode 100644 pkg/container/list/sorted_unique_list_mock.go delete mode 100644 pkg/container/list/sorted_unique_list_test.go diff --git a/go.mod b/go.mod index 25a0e3b8b..322f2fb98 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/montanaflynn/stats v0.6.6 github.com/onsi/ginkgo/v2 v2.1.4 github.com/onsi/gomega v1.19.0 + github.com/orcaman/concurrent-map/v2 v2.0.0 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/prometheus/client_golang v1.12.2 github.com/schollz/progressbar/v3 v3.8.6 @@ -77,6 +78,7 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.3.4 + gorm.io/driver/postgres v1.3.7 gorm.io/gorm v1.23.6 gorm.io/plugin/soft_delete v1.1.0 k8s.io/apimachinery v0.24.2 @@ -201,7 +203,6 @@ require ( google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03 // indirect gopkg.in/ini.v1 v1.66.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gorm.io/driver/postgres v1.3.7 // indirect gorm.io/driver/sqlserver v1.3.2 // indirect gorm.io/plugin/dbresolver v1.2.1 // indirect k8s.io/klog/v2 v2.60.1 // indirect diff --git a/go.sum b/go.sum index b4267022d..0f25c4094 100644 --- a/go.sum +++ b/go.sum @@ -822,6 +822,8 @@ github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxS github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/orcaman/concurrent-map/v2 v2.0.0 h1:iSMwuBQvQ1nX5i9gYuGMiSy0fjWHmazdjF+NdSO9JzI= +github.com/orcaman/concurrent-map/v2 v2.0.0/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM= github.com/otiai10/copy v1.7.0/go.mod h1:rmRl6QPdJj6EiUqXQ/4Nn2lLXoNQjFCQbbNrxgc/t3U= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= diff --git a/pkg/container/list/sorted_list.go b/pkg/container/list/sorted_list.go deleted file mode 100644 index 8b0cb0bb7..000000000 --- a/pkg/container/list/sorted_list.go +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2020 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -//go:generate mockgen -destination sorted_list_mock.go -source sorted_list.go -package list - -package list - -import ( - "container/list" - "sync" -) - -type Item interface { - SortedValue() int -} - -type SortedList interface { - Len() int - Insert(Item) - Remove(Item) - Contains(Item) bool - Range(func(Item) bool) - ReverseRange(fn func(Item) bool) -} - -type sortedList struct { - mu *sync.RWMutex - container *list.List -} - -func NewSortedList() SortedList { - return &sortedList{ - mu: &sync.RWMutex{}, - container: list.New(), - } -} - -func (l *sortedList) Len() int { - l.mu.RLock() - defer l.mu.RUnlock() - - return l.container.Len() -} - -func (l *sortedList) Insert(item Item) { - l.mu.Lock() - defer l.mu.Unlock() - - for e := l.container.Front(); e != nil; e = e.Next() { - v, ok := e.Value.(Item) - if !ok { - continue - } - - if v.SortedValue() >= item.SortedValue() { - l.container.InsertBefore(item, e) - return - } - } - - l.container.PushBack(item) -} - -func (l *sortedList) Remove(item Item) { - l.mu.Lock() - defer l.mu.Unlock() - - for e := l.container.Front(); e != nil; e = e.Next() { - v, ok := e.Value.(Item) - if !ok { - continue - } - - if v == item { - l.container.Remove(e) - return - } - } -} - -func (l *sortedList) Contains(item Item) bool { - l.mu.RLock() - defer l.mu.RUnlock() - - for e := l.container.Front(); e != nil; e = e.Next() { - if v, ok := e.Value.(Item); ok && v == item { - return true - } - } - - return false -} - -func (l *sortedList) Range(fn func(Item) bool) { - l.mu.RLock() - defer l.mu.RUnlock() - - for e := l.container.Front(); e != nil; e = e.Next() { - v, ok := e.Value.(Item) - if !ok { - continue - } - - if !fn(v) { - return - } - } -} - -func (l *sortedList) ReverseRange(fn func(Item) bool) { - l.mu.RLock() - defer l.mu.RUnlock() - - for e := l.container.Back(); e != nil; e = e.Prev() { - v, ok := e.Value.(Item) - if !ok { - continue - } - - if !fn(v) { - return - } - } -} diff --git a/pkg/container/list/sorted_list_mock.go b/pkg/container/list/sorted_list_mock.go deleted file mode 100644 index 5eb6fc859..000000000 --- a/pkg/container/list/sorted_list_mock.go +++ /dev/null @@ -1,147 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: sorted_list.go - -// Package list is a generated GoMock package. -package list - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockItem is a mock of Item interface. -type MockItem struct { - ctrl *gomock.Controller - recorder *MockItemMockRecorder -} - -// MockItemMockRecorder is the mock recorder for MockItem. -type MockItemMockRecorder struct { - mock *MockItem -} - -// NewMockItem creates a new mock instance. -func NewMockItem(ctrl *gomock.Controller) *MockItem { - mock := &MockItem{ctrl: ctrl} - mock.recorder = &MockItemMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockItem) EXPECT() *MockItemMockRecorder { - return m.recorder -} - -// SortedValue mocks base method. -func (m *MockItem) SortedValue() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SortedValue") - ret0, _ := ret[0].(int) - return ret0 -} - -// SortedValue indicates an expected call of SortedValue. -func (mr *MockItemMockRecorder) SortedValue() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SortedValue", reflect.TypeOf((*MockItem)(nil).SortedValue)) -} - -// MockSortedList is a mock of SortedList interface. -type MockSortedList struct { - ctrl *gomock.Controller - recorder *MockSortedListMockRecorder -} - -// MockSortedListMockRecorder is the mock recorder for MockSortedList. -type MockSortedListMockRecorder struct { - mock *MockSortedList -} - -// NewMockSortedList creates a new mock instance. -func NewMockSortedList(ctrl *gomock.Controller) *MockSortedList { - mock := &MockSortedList{ctrl: ctrl} - mock.recorder = &MockSortedListMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSortedList) EXPECT() *MockSortedListMockRecorder { - return m.recorder -} - -// Contains mocks base method. -func (m *MockSortedList) Contains(arg0 Item) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Contains", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Contains indicates an expected call of Contains. -func (mr *MockSortedListMockRecorder) Contains(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSortedList)(nil).Contains), arg0) -} - -// Insert mocks base method. -func (m *MockSortedList) Insert(arg0 Item) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Insert", arg0) -} - -// Insert indicates an expected call of Insert. -func (mr *MockSortedListMockRecorder) Insert(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSortedList)(nil).Insert), arg0) -} - -// Len mocks base method. -func (m *MockSortedList) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockSortedListMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSortedList)(nil).Len)) -} - -// Range mocks base method. -func (m *MockSortedList) Range(arg0 func(Item) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Range", arg0) -} - -// Range indicates an expected call of Range. -func (mr *MockSortedListMockRecorder) Range(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Range", reflect.TypeOf((*MockSortedList)(nil).Range), arg0) -} - -// Remove mocks base method. -func (m *MockSortedList) Remove(arg0 Item) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove. -func (mr *MockSortedListMockRecorder) Remove(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSortedList)(nil).Remove), arg0) -} - -// ReverseRange mocks base method. -func (m *MockSortedList) ReverseRange(fn func(Item) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReverseRange", fn) -} - -// ReverseRange indicates an expected call of ReverseRange. -func (mr *MockSortedListMockRecorder) ReverseRange(fn interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReverseRange", reflect.TypeOf((*MockSortedList)(nil).ReverseRange), fn) -} diff --git a/pkg/container/list/sorted_list_test.go b/pkg/container/list/sorted_list_test.go deleted file mode 100644 index 56bc29988..000000000 --- a/pkg/container/list/sorted_list_test.go +++ /dev/null @@ -1,764 +0,0 @@ -/* - * Copyright 2020 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package list - -import ( - "math/rand" - "runtime" - "sync" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -const N = 1000 - -func TestSortedListInsert(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "insert values", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Len(), 1) - }, - }, - { - name: "insert multi value", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Contains(items[1]), true) - assert.Equal(l.Len(), 2) - }, - }, - { - name: "insert same values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(2), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[0]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Len(), 2) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedListInsert_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - nums := rand.Perm(N) - - var wg sync.WaitGroup - wg.Add(len(nums)) - for i := 0; i < len(nums); i++ { - go func(i int) { - l.Insert(mockItem) - wg.Done() - }(i) - } - - wg.Wait() - count := 0 - l.Range(func(item Item) bool { - count++ - return true - }) - if count != len(nums) { - t.Errorf("SortedList is missing element") - } -} - -func TestSortedListRemove(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "remove values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Contains(items[1]), true) - assert.Equal(l.Len(), 2) - l.Remove(items[0]) - assert.Equal(l.Contains(items[0]), false) - assert.Equal(l.Len(), 1) - l.Remove(items[1]) - assert.Equal(l.Contains(items[1]), false) - assert.Equal(l.Len(), 0) - }, - }, - { - name: "remove value dost not exits", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Len(), 1) - l.Remove(items[1]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Len(), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedListRemove_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - nums := rand.Perm(N) - - for i := 0; i < len(nums); i++ { - l.Insert(mockItem) - } - - var wg sync.WaitGroup - wg.Add(len(nums)) - for i := 0; i < len(nums); i++ { - go func(i int) { - l.Remove(mockItem) - wg.Done() - }(i) - } - - wg.Wait() - if l.Len() != 0 { - t.Errorf("SortedList is redundant elements") - } -} - -func TestSortedListContains(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "contains values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Contains(items[0]), true) - assert.Equal(l.Contains(items[1]), true) - }, - }, - { - name: "contains value dost not exits", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - assert.Equal(l.Contains(items[1]), false) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedListContains_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - nums := rand.Perm(N) - for range nums { - l.Insert(mockItem) - } - - var wg sync.WaitGroup - for range nums { - wg.Add(1) - go func() { - if ok := l.Contains(mockItem); !ok { - t.Error("SortedList contains error") - } - wg.Done() - }() - } - wg.Wait() -} - -func TestSortedListLen(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "get length", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Len(), 2) - }, - }, - { - name: "get empty list length", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - assert.Equal(l.Len(), 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedListLen_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - elems := l.Len() - for i := 0; i < N; i++ { - newElems := l.Len() - if newElems < elems { - t.Errorf("Len shrunk from %v to %v", elems, newElems) - } - } - wg.Done() - }() - - for i := 0; i < N; i++ { - l.Insert(mockItem) - } - wg.Wait() -} - -func TestSortedListRange(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "range values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Len(), 2) - - i := 0 - l.Range(func(item Item) bool { - assert.Equal(item, items[i]) - i++ - return true - }) - }, - }, - { - name: "range multi values", - mock: func(m ...*MockItemMockRecorder) { - for i := range m { - m[i].SortedValue().Return(i).AnyTimes() - } - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - for _, item := range items { - l.Insert(item) - } - assert.Equal(l.Len(), 10) - - i := 0 - l.Range(func(item Item) bool { - assert.Equal(item, items[i]) - i++ - return true - }) - }, - }, - { - name: "range stoped", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Len(), 2) - - l.Range(func(item Item) bool { - assert.Equal(item, items[0]) - return false - }) - }, - }, - { - name: "range same values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).AnyTimes(), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[0]) - l.Insert(items[0]) - assert.Equal(l.Len(), 3) - - count := 0 - l.Range(func(item Item) bool { - assert.Equal(item, items[0]) - count++ - return true - }) - assert.Equal(count, 3) - }, - }, - { - name: "range empty list", - mock: func(m ...*MockItemMockRecorder) { - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - count := 0 - l.Range(func(item Item) bool { - count++ - return true - }) - assert.Equal(count, 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - var mockItems []Item - var mockItemRecorders []*MockItemMockRecorder - for i := 0; i < 10; i++ { - mockItem := NewMockItem(ctl) - mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT()) - mockItems = append(mockItems, mockItem) - } - - tc.mock(mockItemRecorders...) - tc.expect(t, NewSortedList(), mockItems...) - }) - } -} - -func TestSortedListRange_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - var wg sync.WaitGroup - wg.Add(1) - go func() { - i := 0 - l.Range(func(_ Item) bool { - i++ - return true - }) - - j := 0 - l.Range(func(_ Item) bool { - j++ - return true - }) - if j < i { - t.Errorf("Values shrunk from %v to %v", i, j) - } - wg.Done() - }() - - for i := 0; i < N; i++ { - l.Insert(mockItem) - } - wg.Wait() -} - -func TestSortedListReverseRange(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, l SortedList, items ...Item) - }{ - { - name: "reverse range values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Len(), 2) - - i := 0 - l.ReverseRange(func(item Item) bool { - assert.Equal(item, items[i]) - i++ - return true - }) - }, - }, - { - name: "reverse range multi values", - mock: func(m ...*MockItemMockRecorder) { - for i := range m { - m[i].SortedValue().Return(i).AnyTimes() - } - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - for _, item := range items { - l.Insert(item) - } - assert.Equal(l.Len(), 10) - - i := 9 - l.ReverseRange(func(item Item) bool { - assert.Equal(item, items[i]) - i-- - return true - }) - }, - }, - { - name: "reverse range stoped", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[1]) - assert.Equal(l.Len(), 2) - - l.ReverseRange(func(item Item) bool { - assert.Equal(item, items[1]) - return false - }) - }, - }, - { - name: "reverse range same values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).AnyTimes(), - ) - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - l.Insert(items[0]) - l.Insert(items[0]) - l.Insert(items[0]) - assert.Equal(l.Len(), 3) - - count := 0 - l.ReverseRange(func(item Item) bool { - assert.Equal(item, items[0]) - count++ - return true - }) - assert.Equal(count, 3) - }, - }, - { - name: "reverse range empty list", - mock: func(m ...*MockItemMockRecorder) { - }, - expect: func(t *testing.T, l SortedList, items ...Item) { - assert := assert.New(t) - count := 0 - l.ReverseRange(func(item Item) bool { - count++ - return true - }) - assert.Equal(count, 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - var mockItems []Item - var mockItemRecorders []*MockItemMockRecorder - for i := 0; i < 10; i++ { - mockItem := NewMockItem(ctl) - mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT()) - mockItems = append(mockItems, mockItem) - } - - tc.mock(mockItemRecorders...) - tc.expect(t, NewSortedList(), mockItems...) - }) - } -} - -func TestSortedListReverseRange_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return rand.Intn(N) }).AnyTimes() - - l := NewSortedList() - var wg sync.WaitGroup - wg.Add(1) - go func() { - i := 0 - l.ReverseRange(func(_ Item) bool { - i++ - return true - }) - - j := 0 - l.ReverseRange(func(_ Item) bool { - j++ - return true - }) - if j < i { - t.Errorf("Values shrunk from %v to %v", i, j) - } - wg.Done() - }() - - for i := 0; i < N; i++ { - l.Insert(mockItem) - } - wg.Wait() -} - -type item struct{ id int } - -func (i *item) SortedValue() int { return rand.Intn(1000) } - -func BenchmarkSortedListInsert(b *testing.B) { - l := NewSortedList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - l.Insert(mockItem) - } -} - -func BenchmarkSortedListRemove(b *testing.B) { - l := NewSortedList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - for _, mockItem := range mockItems { - l.Insert(mockItem) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - l.Remove(mockItem) - } -} - -func BenchmarkSortedListContains(b *testing.B) { - l := NewSortedList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - for _, mockItem := range mockItems { - l.Insert(mockItem) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - l.Contains(mockItem) - } -} - -func BenchmarkSortedListRange(b *testing.B) { - l := NewSortedList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - for _, mockItem := range mockItems { - l.Insert(mockItem) - } - - b.ResetTimer() - l.Range(func(_ Item) bool { return true }) -} - -func BenchmarkSortedListReverseRange(b *testing.B) { - l := NewSortedList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - for _, mockItem := range mockItems { - l.Insert(mockItem) - } - - b.ResetTimer() - l.ReverseRange(func(_ Item) bool { return true }) -} diff --git a/pkg/container/list/sorted_unique_list.go b/pkg/container/list/sorted_unique_list.go deleted file mode 100644 index 32ce6d9b7..000000000 --- a/pkg/container/list/sorted_unique_list.go +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2020 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -//go:generate mockgen -destination sorted_unique_list_mock.go -source sorted_unique_list.go -package list - -package list - -import ( - "sync" - - "d7y.io/dragonfly/v2/pkg/container/set" -) - -type SortedUniqueList interface { - Len() int - Insert(Item) - Remove(Item) - Contains(Item) bool - Range(func(Item) bool) - ReverseRange(fn func(Item) bool) -} - -type sortedUniqueList struct { - mu *sync.RWMutex - container SortedList - data set.Set -} - -func NewSortedUniqueList() SortedUniqueList { - return &sortedUniqueList{ - mu: &sync.RWMutex{}, - container: NewSortedList(), - data: set.New(), - } -} - -func (ul *sortedUniqueList) Len() int { - ul.mu.RLock() - defer ul.mu.RUnlock() - - return ul.container.Len() -} - -func (ul *sortedUniqueList) Insert(item Item) { - ul.mu.Lock() - defer ul.mu.Unlock() - - if ok := ul.data.Contains(item); ok { - ul.container.Remove(item) - ul.container.Insert(item) - return - } - - ul.data.Add(item) - ul.container.Insert(item) -} - -func (ul *sortedUniqueList) Remove(item Item) { - ul.mu.Lock() - defer ul.mu.Unlock() - - ul.data.Delete(item) - ul.container.Remove(item) -} - -func (ul *sortedUniqueList) Contains(item Item) bool { - ul.mu.RLock() - defer ul.mu.RUnlock() - - return ul.data.Contains(item) -} - -func (ul *sortedUniqueList) Range(fn func(item Item) bool) { - ul.mu.RLock() - defer ul.mu.RUnlock() - - ul.container.Range(func(item Item) bool { - if !fn(item) { - return false - } - return true - }) -} - -func (ul *sortedUniqueList) ReverseRange(fn func(item Item) bool) { - ul.mu.RLock() - defer ul.mu.RUnlock() - - ul.container.ReverseRange(func(item Item) bool { - if !fn(item) { - return false - } - return true - }) -} diff --git a/pkg/container/list/sorted_unique_list_mock.go b/pkg/container/list/sorted_unique_list_mock.go deleted file mode 100644 index 58e96a7aa..000000000 --- a/pkg/container/list/sorted_unique_list_mock.go +++ /dev/null @@ -1,110 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: sorted_unique_list.go - -// Package list is a generated GoMock package. -package list - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockSortedUniqueList is a mock of SortedUniqueList interface. -type MockSortedUniqueList struct { - ctrl *gomock.Controller - recorder *MockSortedUniqueListMockRecorder -} - -// MockSortedUniqueListMockRecorder is the mock recorder for MockSortedUniqueList. -type MockSortedUniqueListMockRecorder struct { - mock *MockSortedUniqueList -} - -// NewMockSortedUniqueList creates a new mock instance. -func NewMockSortedUniqueList(ctrl *gomock.Controller) *MockSortedUniqueList { - mock := &MockSortedUniqueList{ctrl: ctrl} - mock.recorder = &MockSortedUniqueListMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSortedUniqueList) EXPECT() *MockSortedUniqueListMockRecorder { - return m.recorder -} - -// Contains mocks base method. -func (m *MockSortedUniqueList) Contains(arg0 Item) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Contains", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Contains indicates an expected call of Contains. -func (mr *MockSortedUniqueListMockRecorder) Contains(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSortedUniqueList)(nil).Contains), arg0) -} - -// Insert mocks base method. -func (m *MockSortedUniqueList) Insert(arg0 Item) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Insert", arg0) -} - -// Insert indicates an expected call of Insert. -func (mr *MockSortedUniqueListMockRecorder) Insert(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSortedUniqueList)(nil).Insert), arg0) -} - -// Len mocks base method. -func (m *MockSortedUniqueList) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockSortedUniqueListMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSortedUniqueList)(nil).Len)) -} - -// Range mocks base method. -func (m *MockSortedUniqueList) Range(arg0 func(Item) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Range", arg0) -} - -// Range indicates an expected call of Range. -func (mr *MockSortedUniqueListMockRecorder) Range(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Range", reflect.TypeOf((*MockSortedUniqueList)(nil).Range), arg0) -} - -// Remove mocks base method. -func (m *MockSortedUniqueList) Remove(arg0 Item) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove. -func (mr *MockSortedUniqueListMockRecorder) Remove(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSortedUniqueList)(nil).Remove), arg0) -} - -// ReverseRange mocks base method. -func (m *MockSortedUniqueList) ReverseRange(fn func(Item) bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReverseRange", fn) -} - -// ReverseRange indicates an expected call of ReverseRange. -func (mr *MockSortedUniqueListMockRecorder) ReverseRange(fn interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReverseRange", reflect.TypeOf((*MockSortedUniqueList)(nil).ReverseRange), fn) -} diff --git a/pkg/container/list/sorted_unique_list_test.go b/pkg/container/list/sorted_unique_list_test.go deleted file mode 100644 index 4c5f50ecd..000000000 --- a/pkg/container/list/sorted_unique_list_test.go +++ /dev/null @@ -1,784 +0,0 @@ -/* - * Copyright 2020 The Dragonfly Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package list - -import ( - "math/rand" - "runtime" - "sync" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func TestSortedUniqueListInsert(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "insert values", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Len(), 1) - }, - }, - { - name: "insert multi values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Contains(items[1]), true) - assert.Equal(ul.Len(), 2) - }, - }, - { - name: "insert same values", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[0]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Len(), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedUniqueListInsert_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - } - - var wg sync.WaitGroup - wg.Add(len(mockItems)) - for _, mockItem := range mockItems { - go func(item Item) { - ul.Insert(item) - wg.Done() - }(mockItem) - } - - wg.Wait() - count := 0 - ul.Range(func(item Item) bool { - count++ - return true - }) - if count != len(nums) { - t.Errorf("SortedUniqueList is missing element") - } -} - -func TestSortedUniqueListRemove(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "remove values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Contains(items[1]), true) - assert.Equal(ul.Len(), 2) - ul.Remove(items[0]) - assert.Equal(ul.Contains(items[0]), false) - assert.Equal(ul.Len(), 1) - ul.Remove(items[1]) - assert.Equal(ul.Contains(items[1]), false) - assert.Equal(ul.Len(), 0) - }, - }, - { - name: "remove value dost not exits", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Len(), 1) - ul.Remove(items[1]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Len(), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedUniqueListRemove_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - ul.Insert(mockItem) - } - - var wg sync.WaitGroup - wg.Add(len(mockItems)) - for _, mockItem := range mockItems { - go func(item Item) { - ul.Remove(item) - wg.Done() - }(mockItem) - } - - wg.Wait() - if ul.Len() != 0 { - t.Errorf("SortedUniqueList is redundant elements") - } -} - -func TestSortedUniqueListContains(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "contains values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Contains(items[0]), true) - assert.Equal(ul.Contains(items[1]), true) - }, - }, - { - name: "contains value dost not exits", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - assert.Equal(ul.Contains(items[1]), false) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedUniqueListContains_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - ul.Insert(mockItem) - } - - var wg sync.WaitGroup - wg.Add(len(mockItems)) - for _, mockItem := range mockItems { - go func(item Item) { - if ok := ul.Contains(item); !ok { - t.Error("SortedUniqueList contains error") - } - wg.Done() - }(mockItem) - } - wg.Wait() -} - -func TestSortedUniqueListLen(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "get length", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Len(), 2) - }, - }, - { - name: "get empty list length", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - assert.Equal(ul.Len(), 0) - }, - }, - { - name: "get same values length", - mock: func(m ...*MockItemMockRecorder) {}, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[0]) - assert.Equal(ul.Len(), 1) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - mockItems := []*MockItem{NewMockItem(ctl), NewMockItem(ctl)} - tc.mock(mockItems[0].EXPECT(), mockItems[1].EXPECT()) - tc.expect(t, NewSortedUniqueList(), mockItems[0], mockItems[1]) - }) - } -} - -func TestSortedUniqueListLen_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - elems := ul.Len() - for i := 0; i < N; i++ { - newElems := ul.Len() - if newElems < elems { - t.Errorf("Len shrunk from %v to %v", elems, newElems) - } - } - wg.Done() - }() - - for _, mockItem := range mockItems { - ul.Insert(mockItem) - } - wg.Wait() -} - -func TestSortedUniqueListRange(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "range values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Len(), 2) - - i := 0 - ul.Range(func(item Item) bool { - assert.Equal(item, items[i]) - i++ - return true - }) - }, - }, - { - name: "range multi values", - mock: func(m ...*MockItemMockRecorder) { - for i := range m { - m[i].SortedValue().Return(i).AnyTimes() - } - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - for _, item := range items { - ul.Insert(item) - } - ul.Insert(items[1]) - assert.Equal(ul.Len(), 10) - - i := 0 - ul.Range(func(item Item) bool { - assert.Equal(item, items[i]) - i++ - return true - }) - }, - }, - { - name: "range stoped", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Len(), 2) - - ul.Range(func(item Item) bool { - assert.Equal(item, items[0]) - return false - }) - }, - }, - { - name: "range same values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).AnyTimes(), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[0]) - ul.Insert(items[0]) - assert.Equal(ul.Len(), 1) - - count := 0 - ul.Range(func(item Item) bool { - assert.Equal(item, items[0]) - count++ - return true - }) - assert.Equal(count, 1) - }, - }, - { - name: "range empty list", - mock: func(m ...*MockItemMockRecorder) { - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - count := 0 - ul.Range(func(item Item) bool { - count++ - return true - }) - assert.Equal(count, 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - var mockItems []Item - var mockItemRecorders []*MockItemMockRecorder - for i := 0; i < 10; i++ { - mockItem := NewMockItem(ctl) - mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT()) - mockItems = append(mockItems, mockItem) - } - - tc.mock(mockItemRecorders...) - tc.expect(t, NewSortedUniqueList(), mockItems...) - }) - } -} - -func TestSortedUniqueListRange_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - i := 0 - ul.Range(func(_ Item) bool { - i++ - return true - }) - - j := 0 - ul.Range(func(_ Item) bool { - j++ - return true - }) - if j < i { - t.Errorf("Values shrunk from %v to %v", i, j) - } - wg.Done() - }() - - for _, mockItem := range mockItems { - ul.Insert(mockItem) - } - wg.Wait() -} - -func TestSortedUniqueListReverseRange(t *testing.T) { - tests := []struct { - name string - mock func(m ...*MockItemMockRecorder) - expect func(t *testing.T, ul SortedUniqueList, items ...Item) - }{ - { - name: "reverse range values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Len(), 2) - - i := 1 - ul.ReverseRange(func(item Item) bool { - assert.Equal(item, items[i]) - i-- - return true - }) - }, - }, - { - name: "reverse range multi values", - mock: func(m ...*MockItemMockRecorder) { - for i := range m { - m[i].SortedValue().Return(i).AnyTimes() - } - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - for _, item := range items { - ul.Insert(item) - } - ul.Insert(items[1]) - assert.Equal(ul.Len(), 10) - - i := 9 - ul.Range(func(item Item) bool { - assert.Equal(item, items[i]) - i-- - return true - }) - }, - }, - { - name: "reverse range stoped", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).Times(1), - m[1].SortedValue().Return(1).Times(1), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[1]) - assert.Equal(ul.Len(), 2) - - ul.ReverseRange(func(item Item) bool { - assert.Equal(item, items[0]) - return false - }) - }, - }, - { - name: "reverse range same values", - mock: func(m ...*MockItemMockRecorder) { - gomock.InOrder( - m[0].SortedValue().Return(0).AnyTimes(), - ) - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - ul.Insert(items[0]) - ul.Insert(items[0]) - ul.Insert(items[0]) - assert.Equal(ul.Len(), 1) - - count := 0 - ul.ReverseRange(func(item Item) bool { - assert.Equal(item, items[0]) - count++ - return true - }) - assert.Equal(count, 1) - }, - }, - { - name: "reverse range empty list", - mock: func(m ...*MockItemMockRecorder) { - }, - expect: func(t *testing.T, ul SortedUniqueList, items ...Item) { - assert := assert.New(t) - count := 0 - ul.ReverseRange(func(item Item) bool { - count++ - return true - }) - assert.Equal(count, 0) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - ctl := gomock.NewController(t) - defer ctl.Finish() - - var mockItems []Item - var mockItemRecorders []*MockItemMockRecorder - for i := 0; i < 10; i++ { - mockItem := NewMockItem(ctl) - mockItemRecorders = append(mockItemRecorders, mockItem.EXPECT()) - mockItems = append(mockItems, mockItem) - } - - tc.mock(mockItemRecorders...) - tc.expect(t, NewSortedUniqueList(), mockItems...) - }) - } -} - -func TestSortedUniqueListReverseRange_Concurrent(t *testing.T) { - runtime.GOMAXPROCS(2) - - ctl := gomock.NewController(t) - defer ctl.Finish() - - ul := NewSortedUniqueList() - nums := rand.Perm(N) - - var mockItems []Item - for _, v := range nums { - mockItem := NewMockItem(ctl) - mockItem.EXPECT().SortedValue().DoAndReturn(func() int { return v }).AnyTimes() - mockItems = append(mockItems, mockItem) - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - i := 0 - ul.ReverseRange(func(_ Item) bool { - i++ - return true - }) - - j := 0 - ul.ReverseRange(func(_ Item) bool { - j++ - return true - }) - if j < i { - t.Errorf("Values shrunk from %v to %v", i, j) - } - wg.Done() - }() - - for _, mockItem := range mockItems { - ul.Insert(mockItem) - } - wg.Wait() -} - -func BenchmarkSortedUniqueListInsert(b *testing.B) { - ul := NewSortedUniqueList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItems = append(mockItems, &item{id: i}) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - ul.Insert(mockItem) - } -} - -func BenchmarkSortedUniqueListRemove(b *testing.B) { - ul := NewSortedUniqueList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItem := &item{id: i} - ul.Insert(mockItem) - mockItems = append(mockItems, mockItem) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - ul.Remove(mockItem) - } -} - -func BenchmarkSortedUniqueListContains(b *testing.B) { - ul := NewSortedUniqueList() - - var mockItems []*item - for i := 0; i < b.N; i++ { - mockItem := &item{id: i} - ul.Insert(mockItem) - mockItems = append(mockItems, mockItem) - } - - b.ResetTimer() - for _, mockItem := range mockItems { - ul.Contains(mockItem) - } -} - -func BenchmarkSortedUniqueListRange(b *testing.B) { - ul := NewSortedUniqueList() - - for i := 0; i < b.N; i++ { - mockItem := item{id: i} - ul.Insert(&mockItem) - } - - b.ResetTimer() - ul.Range(func(_ Item) bool { return true }) -} - -func BenchmarkSortedUniqueListReverseRange(b *testing.B) { - ul := NewSortedUniqueList() - - for i := 0; i < b.N; i++ { - mockItem := item{id: i} - ul.Insert(&mockItem) - } - - b.ResetTimer() - ul.ReverseRange(func(item Item) bool { return true }) -} diff --git a/pkg/container/set/mocks/safe_set_mock.go b/pkg/container/set/mocks/safe_set_mock.go index 3a2e98f9b..a5582db0b 100644 --- a/pkg/container/set/mocks/safe_set_mock.go +++ b/pkg/container/set/mocks/safe_set_mock.go @@ -11,30 +11,30 @@ import ( ) // MockSafeSet is a mock of SafeSet interface. -type MockSafeSet struct { +type MockSafeSet[T comparable] struct { ctrl *gomock.Controller - recorder *MockSafeSetMockRecorder + recorder *MockSafeSetMockRecorder[T] } // MockSafeSetMockRecorder is the mock recorder for MockSafeSet. -type MockSafeSetMockRecorder struct { - mock *MockSafeSet +type MockSafeSetMockRecorder[T comparable] struct { + mock *MockSafeSet[T] } // NewMockSafeSet creates a new mock instance. -func NewMockSafeSet(ctrl *gomock.Controller) *MockSafeSet { - mock := &MockSafeSet{ctrl: ctrl} - mock.recorder = &MockSafeSetMockRecorder{mock} +func NewMockSafeSet[T comparable](ctrl *gomock.Controller) *MockSafeSet[T] { + mock := &MockSafeSet[T]{ctrl: ctrl} + mock.recorder = &MockSafeSetMockRecorder[T]{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSafeSet) EXPECT() *MockSafeSetMockRecorder { +func (m *MockSafeSet[T]) EXPECT() *MockSafeSetMockRecorder[T] { return m.recorder } // Add mocks base method. -func (m *MockSafeSet) Add(arg0 any) bool { +func (m *MockSafeSet[T]) Add(arg0 T) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Add", arg0) ret0, _ := ret[0].(bool) @@ -42,25 +42,25 @@ func (m *MockSafeSet) Add(arg0 any) bool { } // Add indicates an expected call of Add. -func (mr *MockSafeSetMockRecorder) Add(arg0 interface{}) *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Add(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSafeSet)(nil).Add), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSafeSet[T])(nil).Add), arg0) } // Clear mocks base method. -func (m *MockSafeSet) Clear() { +func (m *MockSafeSet[T]) Clear() { m.ctrl.T.Helper() m.ctrl.Call(m, "Clear") } // Clear indicates an expected call of Clear. -func (mr *MockSafeSetMockRecorder) Clear() *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Clear() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSafeSet)(nil).Clear)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSafeSet[T])(nil).Clear)) } // Contains mocks base method. -func (m *MockSafeSet) Contains(arg0 ...any) bool { +func (m *MockSafeSet[T]) Contains(arg0 ...T) bool { m.ctrl.T.Helper() varargs := []interface{}{} for _, a := range arg0 { @@ -72,25 +72,25 @@ func (m *MockSafeSet) Contains(arg0 ...any) bool { } // Contains indicates an expected call of Contains. -func (mr *MockSafeSetMockRecorder) Contains(arg0 ...interface{}) *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Contains(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSafeSet)(nil).Contains), arg0...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSafeSet[T])(nil).Contains), arg0...) } // Delete mocks base method. -func (m *MockSafeSet) Delete(arg0 any) { +func (m *MockSafeSet[T]) Delete(arg0 T) { m.ctrl.T.Helper() m.ctrl.Call(m, "Delete", arg0) } // Delete indicates an expected call of Delete. -func (mr *MockSafeSetMockRecorder) Delete(arg0 interface{}) *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Delete(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSafeSet)(nil).Delete), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSafeSet[T])(nil).Delete), arg0) } // Len mocks base method. -func (m *MockSafeSet) Len() uint { +func (m *MockSafeSet[T]) Len() uint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Len") ret0, _ := ret[0].(uint) @@ -98,21 +98,21 @@ func (m *MockSafeSet) Len() uint { } // Len indicates an expected call of Len. -func (mr *MockSafeSetMockRecorder) Len() *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Len() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSafeSet)(nil).Len)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSafeSet[T])(nil).Len)) } // Values mocks base method. -func (m *MockSafeSet) Values() []any { +func (m *MockSafeSet[T]) Values() []T { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Values") - ret0, _ := ret[0].([]any) + ret0, _ := ret[0].([]T) return ret0 } // Values indicates an expected call of Values. -func (mr *MockSafeSetMockRecorder) Values() *gomock.Call { +func (mr *MockSafeSetMockRecorder[T]) Values() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSafeSet)(nil).Values)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSafeSet[T])(nil).Values)) } diff --git a/pkg/container/set/mocks/set_mock.go b/pkg/container/set/mocks/set_mock.go index e523b6fd7..803d07d7f 100644 --- a/pkg/container/set/mocks/set_mock.go +++ b/pkg/container/set/mocks/set_mock.go @@ -11,30 +11,30 @@ import ( ) // MockSet is a mock of Set interface. -type MockSet struct { +type MockSet[T comparable] struct { ctrl *gomock.Controller - recorder *MockSetMockRecorder + recorder *MockSetMockRecorder[T] } // MockSetMockRecorder is the mock recorder for MockSet. -type MockSetMockRecorder struct { - mock *MockSet +type MockSetMockRecorder[T comparable] struct { + mock *MockSet[T] } // NewMockSet creates a new mock instance. -func NewMockSet(ctrl *gomock.Controller) *MockSet { - mock := &MockSet{ctrl: ctrl} - mock.recorder = &MockSetMockRecorder{mock} +func NewMockSet[T comparable](ctrl *gomock.Controller) *MockSet[T] { + mock := &MockSet[T]{ctrl: ctrl} + mock.recorder = &MockSetMockRecorder[T]{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSet) EXPECT() *MockSetMockRecorder { +func (m *MockSet[T]) EXPECT() *MockSetMockRecorder[T] { return m.recorder } // Add mocks base method. -func (m *MockSet) Add(arg0 any) bool { +func (m *MockSet[T]) Add(arg0 T) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Add", arg0) ret0, _ := ret[0].(bool) @@ -42,25 +42,25 @@ func (m *MockSet) Add(arg0 any) bool { } // Add indicates an expected call of Add. -func (mr *MockSetMockRecorder) Add(arg0 interface{}) *gomock.Call { +func (mr *MockSetMockRecorder[T]) Add(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSet)(nil).Add), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSet[T])(nil).Add), arg0) } // Clear mocks base method. -func (m *MockSet) Clear() { +func (m *MockSet[T]) Clear() { m.ctrl.T.Helper() m.ctrl.Call(m, "Clear") } // Clear indicates an expected call of Clear. -func (mr *MockSetMockRecorder) Clear() *gomock.Call { +func (mr *MockSetMockRecorder[T]) Clear() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSet)(nil).Clear)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockSet[T])(nil).Clear)) } // Contains mocks base method. -func (m *MockSet) Contains(arg0 ...any) bool { +func (m *MockSet[T]) Contains(arg0 ...T) bool { m.ctrl.T.Helper() varargs := []interface{}{} for _, a := range arg0 { @@ -72,25 +72,25 @@ func (m *MockSet) Contains(arg0 ...any) bool { } // Contains indicates an expected call of Contains. -func (mr *MockSetMockRecorder) Contains(arg0 ...interface{}) *gomock.Call { +func (mr *MockSetMockRecorder[T]) Contains(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSet)(nil).Contains), arg0...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSet[T])(nil).Contains), arg0...) } // Delete mocks base method. -func (m *MockSet) Delete(arg0 any) { +func (m *MockSet[T]) Delete(arg0 T) { m.ctrl.T.Helper() m.ctrl.Call(m, "Delete", arg0) } // Delete indicates an expected call of Delete. -func (mr *MockSetMockRecorder) Delete(arg0 interface{}) *gomock.Call { +func (mr *MockSetMockRecorder[T]) Delete(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSet)(nil).Delete), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSet[T])(nil).Delete), arg0) } // Len mocks base method. -func (m *MockSet) Len() uint { +func (m *MockSet[T]) Len() uint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Len") ret0, _ := ret[0].(uint) @@ -98,21 +98,21 @@ func (m *MockSet) Len() uint { } // Len indicates an expected call of Len. -func (mr *MockSetMockRecorder) Len() *gomock.Call { +func (mr *MockSetMockRecorder[T]) Len() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSet)(nil).Len)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSet[T])(nil).Len)) } // Values mocks base method. -func (m *MockSet) Values() []any { +func (m *MockSet[T]) Values() []T { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Values") - ret0, _ := ret[0].([]any) + ret0, _ := ret[0].([]T) return ret0 } // Values indicates an expected call of Values. -func (mr *MockSetMockRecorder) Values() *gomock.Call { +func (mr *MockSetMockRecorder[T]) Values() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSet)(nil).Values)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Values", reflect.TypeOf((*MockSet[T])(nil).Values)) } diff --git a/pkg/container/set/safe_set.go b/pkg/container/set/safe_set.go index ab01db259..cd8941455 100644 --- a/pkg/container/set/safe_set.go +++ b/pkg/container/set/safe_set.go @@ -22,32 +22,32 @@ import ( "sync" ) -type SafeSet interface { - Values() []any - Add(any) bool - Delete(any) - Contains(...any) bool +type SafeSet[T comparable] interface { + Values() []T + Add(T) bool + Delete(T) + Contains(...T) bool Len() uint Clear() } -type safeSet struct { +type safeSet[T comparable] struct { mu *sync.RWMutex - data map[any]struct{} + data map[T]struct{} } -func NewSafeSet() SafeSet { - return &safeSet{ +func NewSafeSet[T comparable]() SafeSet[T] { + return &safeSet[T]{ mu: &sync.RWMutex{}, - data: make(map[any]struct{}), + data: make(map[T]struct{}), } } -func (s *safeSet) Values() []any { +func (s *safeSet[T]) Values() []T { s.mu.RLock() defer s.mu.RUnlock() - var result []any + var result []T for k := range s.data { result = append(result, k) } @@ -55,7 +55,7 @@ func (s *safeSet) Values() []any { return result } -func (s *safeSet) Add(v any) bool { +func (s *safeSet[T]) Add(v T) bool { s.mu.RLock() _, found := s.data[v] if found { @@ -70,13 +70,13 @@ func (s *safeSet) Add(v any) bool { return true } -func (s *safeSet) Delete(v any) { +func (s *safeSet[T]) Delete(v T) { s.mu.Lock() defer s.mu.Unlock() delete(s.data, v) } -func (s *safeSet) Contains(vals ...any) bool { +func (s *safeSet[T]) Contains(vals ...T) bool { s.mu.RLock() defer s.mu.RUnlock() for _, v := range vals { @@ -88,14 +88,14 @@ func (s *safeSet) Contains(vals ...any) bool { return true } -func (s *safeSet) Len() uint { +func (s *safeSet[T]) Len() uint { s.mu.RLock() defer s.mu.RUnlock() return uint(len(s.data)) } -func (s *safeSet) Clear() { +func (s *safeSet[T]) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.data = make(map[any]struct{}) + s.data = make(map[T]struct{}) } diff --git a/pkg/container/set/safe_set_test.go b/pkg/container/set/safe_set_test.go index fad218de0..0b2252bf1 100644 --- a/pkg/container/set/safe_set_test.go +++ b/pkg/container/set/safe_set_test.go @@ -30,33 +30,33 @@ const N = 1000 func TestSafeSetAdd(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, ok bool, s SafeSet, value any) + value string + expect func(t *testing.T, ok bool, s SafeSet[string], value string) }{ { name: "add value", value: "foo", - expect: func(t *testing.T, ok bool, s SafeSet, value any) { + expect: func(t *testing.T, ok bool, s SafeSet[string], value string) { assert := assert.New(t) assert.Equal(ok, true) - assert.Equal(s.Values(), []any{value}) + assert.Equal(s.Values(), []string{value}) }, }, { name: "add value failed", value: "foo", - expect: func(t *testing.T, _ bool, s SafeSet, value any) { + expect: func(t *testing.T, _ bool, s SafeSet[string], value string) { assert := assert.New(t) ok := s.Add("foo") assert.Equal(ok, false) - assert.Equal(s.Values(), []any{value}) + assert.Equal(s.Values(), []string{value}) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() tc.expect(t, s.Add(tc.value), s, tc.value) }) } @@ -65,7 +65,7 @@ func TestSafeSetAdd(t *testing.T) { func TestSafeSetAdd_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() nums := rand.Perm(N) var wg sync.WaitGroup @@ -88,13 +88,13 @@ func TestSafeSetAdd_Concurrent(t *testing.T) { func TestSafeSetDelete(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, s SafeSet, value any) + value string + expect func(t *testing.T, s SafeSet[string], value string) }{ { name: "delete value", value: "foo", - expect: func(t *testing.T, s SafeSet, value any) { + expect: func(t *testing.T, s SafeSet[string], value string) { assert := assert.New(t) s.Delete(value) assert.Equal(s.Len(), uint(0)) @@ -103,7 +103,7 @@ func TestSafeSetDelete(t *testing.T) { { name: "delete value does not exist", value: "foo", - expect: func(t *testing.T, s SafeSet, _ any) { + expect: func(t *testing.T, s SafeSet[string], _ string) { assert := assert.New(t) s.Delete("bar") assert.Equal(s.Len(), uint(1)) @@ -113,7 +113,7 @@ func TestSafeSetDelete(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() s.Add(tc.value) tc.expect(t, s, tc.value) }) @@ -123,7 +123,7 @@ func TestSafeSetDelete(t *testing.T) { func TestSafeSetDelete_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() nums := rand.Perm(N) for _, v := range nums { s.Add(v) @@ -147,21 +147,21 @@ func TestSafeSetDelete_Concurrent(t *testing.T) { func TestSafeSetContains(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, s SafeSet, value any) + value string + expect func(t *testing.T, s SafeSet[string], value string) }{ { name: "contains value", value: "foo", - expect: func(t *testing.T, s SafeSet, value any) { + expect: func(t *testing.T, s SafeSet[string], value string) { assert := assert.New(t) - assert.Equal(s.Contains(value), true) + assert.Equal(s.Contains(string(value)), true) }, }, { name: "contains value does not exist", value: "foo", - expect: func(t *testing.T, s SafeSet, _ any) { + expect: func(t *testing.T, s SafeSet[string], _ string) { assert := assert.New(t) assert.Equal(s.Contains("bar"), false) }, @@ -170,7 +170,7 @@ func TestSafeSetContains(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() s.Add(tc.value) tc.expect(t, s, tc.value) }) @@ -180,9 +180,9 @@ func TestSafeSetContains(t *testing.T) { func TestSafeSetContains_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() nums := rand.Perm(N) - interfaces := make([]any, 0) + interfaces := make([]int, 0) for _, v := range nums { s.Add(v) interfaces = append(interfaces, v) @@ -202,11 +202,11 @@ func TestSafeSetContains_Concurrent(t *testing.T) { func TestSetSafeLen(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s SafeSet) + expect func(t *testing.T, s SafeSet[string]) }{ { name: "get length", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) s.Add("foo") assert.Equal(s.Len(), uint(1)) @@ -214,7 +214,7 @@ func TestSetSafeLen(t *testing.T) { }, { name: "get empty set length", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) assert.Equal(s.Len(), uint(0)) }, @@ -223,7 +223,7 @@ func TestSetSafeLen(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() tc.expect(t, s) }) } @@ -232,7 +232,7 @@ func TestSetSafeLen(t *testing.T) { func TestSafeSetLen_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() var wg sync.WaitGroup wg.Add(1) @@ -256,26 +256,26 @@ func TestSafeSetLen_Concurrent(t *testing.T) { func TestSafeSetValues(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s SafeSet) + expect func(t *testing.T, s SafeSet[string]) }{ { name: "get values", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) s.Add("foo") - assert.Equal(s.Values(), []any{"foo"}) + assert.Equal(s.Values(), []string{"foo"}) }, }, { name: "get empty values", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) }, }, { name: "get multi values", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) s.Add("foo") s.Add("bar") @@ -287,7 +287,7 @@ func TestSafeSetValues(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() tc.expect(t, s) }) } @@ -296,7 +296,7 @@ func TestSafeSetValues(t *testing.T) { func TestSafeSetValues_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() var wg sync.WaitGroup wg.Add(1) @@ -312,7 +312,7 @@ func TestSafeSetValues_Concurrent(t *testing.T) { }() for i := 0; i < N; i++ { - s.Add(rand.Int()) + s.Add(i) } wg.Wait() } @@ -320,32 +320,32 @@ func TestSafeSetValues_Concurrent(t *testing.T) { func TestSafeSetClear(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s SafeSet) + expect func(t *testing.T, s SafeSet[string]) }{ { name: "clear empty set", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) s.Clear() - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) }, }, { name: "clear set", - expect: func(t *testing.T, s SafeSet) { + expect: func(t *testing.T, s SafeSet[string]) { assert := assert.New(t) assert.Equal(s.Add("foo"), true) s.Clear() - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) assert.Equal(s.Add("foo"), true) - assert.Equal(s.Values(), []any{"foo"}) + assert.Equal(s.Values(), []string{"foo"}) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := NewSafeSet() + s := NewSafeSet[string]() tc.expect(t, s) }) } @@ -354,7 +354,7 @@ func TestSafeSetClear(t *testing.T) { func TestSafeSetClear_Concurrent(t *testing.T) { runtime.GOMAXPROCS(2) - s := NewSafeSet() + s := NewSafeSet[int]() nums := rand.Perm(N) var wg sync.WaitGroup diff --git a/pkg/container/set/set.go b/pkg/container/set/set.go index 6163dfe4c..853210b1a 100644 --- a/pkg/container/set/set.go +++ b/pkg/container/set/set.go @@ -18,23 +18,23 @@ package set -type Set interface { - Values() []any - Add(any) bool - Delete(any) - Contains(...any) bool +type Set[T comparable] interface { + Values() []T + Add(T) bool + Delete(T) + Contains(...T) bool Len() uint Clear() } -type set map[any]struct{} +type set[T comparable] map[T]struct{} -func New() Set { - return &set{} +func New[T comparable]() Set[T] { + return &set[T]{} } -func (s *set) Values() []any { - var result []any +func (s *set[T]) Values() []T { + var result []T for k := range *s { result = append(result, k) } @@ -42,7 +42,7 @@ func (s *set) Values() []any { return result } -func (s *set) Add(v any) bool { +func (s *set[T]) Add(v T) bool { _, found := (*s)[v] if found { return false @@ -52,11 +52,11 @@ func (s *set) Add(v any) bool { return true } -func (s *set) Delete(v any) { +func (s *set[T]) Delete(v T) { delete(*s, v) } -func (s *set) Contains(vals ...any) bool { +func (s *set[T]) Contains(vals ...T) bool { for _, v := range vals { if _, ok := (*s)[v]; !ok { return false @@ -66,10 +66,10 @@ func (s *set) Contains(vals ...any) bool { return true } -func (s *set) Len() uint { +func (s *set[T]) Len() uint { return uint(len(*s)) } -func (s *set) Clear() { - *s = set{} +func (s *set[T]) Clear() { + *s = set[T]{} } diff --git a/pkg/container/set/set_test.go b/pkg/container/set/set_test.go index bacecbef7..a783b2fc6 100644 --- a/pkg/container/set/set_test.go +++ b/pkg/container/set/set_test.go @@ -25,33 +25,33 @@ import ( func TestSetAdd(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, ok bool, s Set, value any) + value string + expect func(t *testing.T, ok bool, s Set[string], value string) }{ { name: "add value", value: "foo", - expect: func(t *testing.T, ok bool, s Set, value any) { + expect: func(t *testing.T, ok bool, s Set[string], value string) { assert := assert.New(t) assert.Equal(ok, true) - assert.Equal(s.Values(), []any{value}) + assert.Equal(s.Values(), []string{value}) }, }, { name: "add value failed", value: "foo", - expect: func(t *testing.T, _ bool, s Set, value any) { + expect: func(t *testing.T, _ bool, s Set[string], value string) { assert := assert.New(t) ok := s.Add("foo") assert.Equal(ok, false) - assert.Equal(s.Values(), []any{value}) + assert.Equal(s.Values(), []string{value}) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() tc.expect(t, s.Add(tc.value), s, tc.value) }) } @@ -60,13 +60,13 @@ func TestSetAdd(t *testing.T) { func TestSetDelete(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, s Set, value any) + value string + expect func(t *testing.T, s Set[string], value string) }{ { name: "delete value", value: "foo", - expect: func(t *testing.T, s Set, value any) { + expect: func(t *testing.T, s Set[string], value string) { assert := assert.New(t) s.Delete(value) assert.Equal(s.Len(), uint(0)) @@ -75,7 +75,7 @@ func TestSetDelete(t *testing.T) { { name: "delete value does not exist", value: "foo", - expect: func(t *testing.T, s Set, _ any) { + expect: func(t *testing.T, s Set[string], _ string) { assert := assert.New(t) s.Delete("bar") assert.Equal(s.Len(), uint(1)) @@ -85,7 +85,7 @@ func TestSetDelete(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() s.Add(tc.value) tc.expect(t, s, tc.value) }) @@ -95,13 +95,13 @@ func TestSetDelete(t *testing.T) { func TestSetContains(t *testing.T) { tests := []struct { name string - value any - expect func(t *testing.T, s Set, value any) + value string + expect func(t *testing.T, s Set[string], value string) }{ { name: "contains value", value: "foo", - expect: func(t *testing.T, s Set, value any) { + expect: func(t *testing.T, s Set[string], value string) { assert := assert.New(t) assert.Equal(s.Contains(value), true) }, @@ -109,7 +109,7 @@ func TestSetContains(t *testing.T) { { name: "contains value does not exist", value: "foo", - expect: func(t *testing.T, s Set, _ any) { + expect: func(t *testing.T, s Set[string], _ string) { assert := assert.New(t) assert.Equal(s.Contains("bar"), false) }, @@ -118,7 +118,7 @@ func TestSetContains(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() s.Add(tc.value) tc.expect(t, s, tc.value) }) @@ -128,11 +128,11 @@ func TestSetContains(t *testing.T) { func TestSetLen(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s Set) + expect func(t *testing.T, s Set[string]) }{ { name: "get length", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) s.Add("foo") assert.Equal(s.Len(), uint(1)) @@ -140,7 +140,7 @@ func TestSetLen(t *testing.T) { }, { name: "get empty set length", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) assert.Equal(s.Len(), uint(0)) }, @@ -149,7 +149,7 @@ func TestSetLen(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() tc.expect(t, s) }) } @@ -158,26 +158,26 @@ func TestSetLen(t *testing.T) { func TestSetValues(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s Set) + expect func(t *testing.T, s Set[string]) }{ { name: "get values", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) s.Add("foo") - assert.Equal(s.Values(), []any{"foo"}) + assert.Equal(s.Values(), []string{"foo"}) }, }, { name: "get empty values", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) }, }, { name: "get multi values", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) s.Add("foo") s.Add("bar") @@ -189,7 +189,7 @@ func TestSetValues(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() tc.expect(t, s) }) } @@ -198,32 +198,32 @@ func TestSetValues(t *testing.T) { func TestSetClear(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, s Set) + expect func(t *testing.T, s Set[string]) }{ { name: "clear empty set", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) s.Clear() - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) }, }, { name: "clear set", - expect: func(t *testing.T, s Set) { + expect: func(t *testing.T, s Set[string]) { assert := assert.New(t) assert.Equal(s.Add("foo"), true) s.Clear() - assert.Equal(s.Values(), []any(nil)) + assert.Equal(s.Values(), []string(nil)) assert.Equal(s.Add("foo"), true) - assert.Equal(s.Values(), []any{"foo"}) + assert.Equal(s.Values(), []string{"foo"}) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - s := New() + s := New[string]() tc.expect(t, s) }) } diff --git a/pkg/dag/dag.go b/pkg/dag/dag.go index 6969545ba..08b47c33f 100644 --- a/pkg/dag/dag.go +++ b/pkg/dag/dag.go @@ -20,7 +20,11 @@ package dag import ( "errors" + "math/rand" "sync" + "time" + + cmap "github.com/orcaman/concurrent-map/v2" ) var ( @@ -41,24 +45,30 @@ var ( ) // DAG is the interface used for directed acyclic graph. -type DAG interface { +type DAG[T comparable] interface { // AddVertex adds vertex to graph. - AddVertex(id string, value any) error + AddVertex(id string, value T) error // DeleteVertex deletes vertex graph. DeleteVertex(id string) // GetVertex gets vertex from graph. - GetVertex(id string) (*Vertex, error) + GetVertex(id string) (*Vertex[T], error) // GetVertices returns map of vertices. - GetVertices() map[string]*Vertex + GetVertices() map[string]*Vertex[T] + + // GetRandomVertices returns random map of vertices. + GetRandomVertices(n uint) map[string]*Vertex[T] + + // GetVertexKeys returns keys of vertices. + GetVertexKeys() []string // GetSourceVertices returns source vertices. - GetSourceVertices() map[string]*Vertex + GetSourceVertices() map[string]*Vertex[T] // GetSinkVertices returns sink vertices. - GetSinkVertices() map[string]*Vertex + GetSinkVertices() map[string]*Vertex[T] // VertexCount returns count of vertices. VertexCount() int @@ -74,69 +84,56 @@ type DAG interface { } // dag provides directed acyclic graph function. -type dag struct { +type dag[T comparable] struct { mu sync.RWMutex - vertices map[string]*Vertex + vertices cmap.ConcurrentMap[*Vertex[T]] } // New returns a new DAG interface. -func NewDAG() DAG { - return &dag{ - vertices: make(map[string]*Vertex), +func NewDAG[T comparable]() DAG[T] { + return &dag[T]{ + vertices: cmap.New[*Vertex[T]](), } } // AddVertex adds vertex to graph. -func (d *dag) AddVertex(id string, value any) error { +func (d *dag[T]) AddVertex(id string, value T) error { d.mu.Lock() defer d.mu.Unlock() - if _, ok := d.vertices[id]; ok { + if _, ok := d.vertices.Get(id); ok { return ErrVertexAlreadyExists } - d.vertices[id] = NewVertex(id, value) + d.vertices.Set(id, NewVertex(id, value)) return nil } // DeleteVertex deletes vertex graph. -func (d *dag) DeleteVertex(id string) { +func (d *dag[T]) DeleteVertex(id string) { d.mu.Lock() defer d.mu.Unlock() - vertex, ok := d.vertices[id] + vertex, ok := d.vertices.Get(id) if !ok { return } - for _, value := range vertex.Parents.Values() { - parent, ok := value.(*Vertex) - if !ok { - continue - } - + for _, parent := range vertex.Parents.Values() { parent.Children.Delete(vertex) } - for _, value := range vertex.Children.Values() { - child, ok := value.(*Vertex) - if !ok { - continue - } - + for _, child := range vertex.Children.Values() { child.Parents.Delete(vertex) continue } - delete(d.vertices, id) + d.vertices.Remove(id) } // GetVertex gets vertex from graph. -func (d *dag) GetVertex(id string) (*Vertex, error) { - d.mu.RLock() - defer d.mu.RUnlock() - - vertex, ok := d.vertices[id] +func (d *dag[T]) GetVertex(id string) (*Vertex[T], error) { + vertex, ok := d.vertices.Get(id) if !ok { return nil, ErrVertexNotFound } @@ -145,20 +142,44 @@ func (d *dag) GetVertex(id string) (*Vertex, error) { } // GetVertices returns map of vertices. -func (d *dag) GetVertices() map[string]*Vertex { +func (d *dag[T]) GetVertices() map[string]*Vertex[T] { + return d.vertices.Items() +} + +// GetRandomVertices returns random map of vertices. +func (d *dag[T]) GetRandomVertices(n uint) map[string]*Vertex[T] { d.mu.RLock() defer d.mu.RUnlock() - return d.vertices + keys := d.GetVertexKeys() + vertices := d.GetVertices() + if int(n) >= len(keys) { + return vertices + } + + rand.Seed(time.Now().Unix()) + permutation := rand.Perm(len(keys))[:n] + randomVertices := make(map[string]*Vertex[T]) + for _, v := range permutation { + key := keys[v] + randomVertices[key] = vertices[key] + } + + return randomVertices +} + +// GetVertexKeys returns keys of vertices. +func (d *dag[T]) GetVertexKeys() []string { + return d.vertices.Keys() } // VertexCount returns count of vertices. -func (d *dag) VertexCount() int { - return len(d.vertices) +func (d *dag[T]) VertexCount() int { + return d.vertices.Count() } // CanAddEdge finds whether there are circles through depth-first search. -func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool { +func (d *dag[T]) CanAddEdge(fromVertexID, toVertexID string) bool { d.mu.RLock() defer d.mu.RUnlock() @@ -166,22 +187,17 @@ func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool { return false } - fromVertex, ok := d.vertices[fromVertexID] + fromVertex, ok := d.vertices.Get(fromVertexID) if !ok { return false } - if _, ok := d.vertices[toVertexID]; !ok { + if _, ok := d.vertices.Get(toVertexID); !ok { return false } for _, child := range fromVertex.Children.Values() { - vertex, ok := child.(*Vertex) - if !ok { - continue - } - - if vertex.ID == toVertexID { + if child.ID == toVertexID { return false } } @@ -194,7 +210,7 @@ func (d *dag) CanAddEdge(fromVertexID, toVertexID string) bool { } // AddEdge adds edge between two vertices. -func (d *dag) AddEdge(fromVertexID, toVertexID string) error { +func (d *dag[T]) AddEdge(fromVertexID, toVertexID string) error { d.mu.Lock() defer d.mu.Unlock() @@ -202,23 +218,18 @@ func (d *dag) AddEdge(fromVertexID, toVertexID string) error { return ErrCycleBetweenVertices } - fromVertex, ok := d.vertices[fromVertexID] + fromVertex, ok := d.vertices.Get(fromVertexID) if !ok { return ErrVertexNotFound } - toVertex, ok := d.vertices[toVertexID] + toVertex, ok := d.vertices.Get(toVertexID) if !ok { return ErrVertexNotFound } for _, child := range fromVertex.Children.Values() { - vertex, ok := child.(*Vertex) - if !ok { - continue - } - - if vertex.ID == toVertexID { + if child.ID == toVertexID { return ErrCycleBetweenVertices } } @@ -239,16 +250,16 @@ func (d *dag) AddEdge(fromVertexID, toVertexID string) error { } // DeleteEdge deletes edge between two vertices. -func (d *dag) DeleteEdge(fromVertexID, toVertexID string) error { +func (d *dag[T]) DeleteEdge(fromVertexID, toVertexID string) error { d.mu.Lock() defer d.mu.Unlock() - fromVertex, ok := d.vertices[fromVertexID] + fromVertex, ok := d.vertices.Get(fromVertexID) if !ok { return ErrVertexNotFound } - toVertex, ok := d.vertices[toVertexID] + toVertex, ok := d.vertices.Get(toVertexID) if !ok { return ErrVertexNotFound } @@ -259,12 +270,12 @@ func (d *dag) DeleteEdge(fromVertexID, toVertexID string) error { } // GetSourceVertices returns source vertices. -func (d *dag) GetSourceVertices() map[string]*Vertex { +func (d *dag[T]) GetSourceVertices() map[string]*Vertex[T] { d.mu.RLock() defer d.mu.RUnlock() - sourceVertices := make(map[string]*Vertex) - for k, v := range d.vertices { + sourceVertices := make(map[string]*Vertex[T]) + for k, v := range d.vertices.Items() { if v.InDegree() == 0 { sourceVertices[k] = v } @@ -274,12 +285,12 @@ func (d *dag) GetSourceVertices() map[string]*Vertex { } // GetSinkVertices returns sink vertices. -func (d *dag) GetSinkVertices() map[string]*Vertex { +func (d *dag[T]) GetSinkVertices() map[string]*Vertex[T] { d.mu.RLock() defer d.mu.RUnlock() - sinkVertices := make(map[string]*Vertex) - for k, v := range d.vertices { + sinkVertices := make(map[string]*Vertex[T]) + for k, v := range d.vertices.Items() { if v.OutDegree() == 0 { sinkVertices[k] = v } @@ -289,7 +300,7 @@ func (d *dag) GetSinkVertices() map[string]*Vertex { } // depthFirstSearch is a depth-first search of the directed acyclic graph. -func (d *dag) depthFirstSearch(fromVertexID, toVertexID string) bool { +func (d *dag[T]) depthFirstSearch(fromVertexID, toVertexID string) bool { successors := make(map[string]struct{}) d.search(fromVertexID, successors) _, ok := successors[toVertexID] @@ -297,21 +308,16 @@ func (d *dag) depthFirstSearch(fromVertexID, toVertexID string) bool { } // depthFirstSearch finds successors of vertex. -func (d *dag) search(vertexID string, successors map[string]struct{}) { - vertex, ok := d.vertices[vertexID] +func (d *dag[T]) search(vertexID string, successors map[string]struct{}) { + vertex, ok := d.vertices.Get(vertexID) if !ok { return } for _, child := range vertex.Children.Values() { - vertex, ok := child.(*Vertex) - if !ok { - continue - } - - if _, ok := successors[vertex.ID]; !ok { - successors[vertex.ID] = struct{}{} - d.search(vertex.ID, successors) + if _, ok := successors[child.ID]; !ok { + successors[child.ID] = struct{}{} + d.search(child.ID, successors) } } } diff --git a/pkg/dag/dag_test.go b/pkg/dag/dag_test.go index 3097fa7a8..73ceb1ec2 100644 --- a/pkg/dag/dag_test.go +++ b/pkg/dag/dag_test.go @@ -17,6 +17,7 @@ package dag import ( + "errors" "fmt" "reflect" "testing" @@ -25,9 +26,9 @@ import ( ) func TestNewDAG(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() assert := assert.New(t) - assert.Equal(reflect.TypeOf(d).Elem().Name(), "dag") + assert.Equal(reflect.TypeOf(d).Elem().Name(), "dag[string]") } func TestDAGAddVertex(t *testing.T) { @@ -35,13 +36,13 @@ func TestDAGAddVertex(t *testing.T) { name string id string value any - expect func(t *testing.T, d DAG, err error) + expect func(t *testing.T, d DAG[string], err error) }{ { name: "add vertex", id: mockVertexID, value: mockVertexValue, - expect: func(t *testing.T, d DAG, err error) { + expect: func(t *testing.T, d DAG[string], err error) { assert := assert.New(t) assert.NoError(err) }, @@ -50,7 +51,7 @@ func TestDAGAddVertex(t *testing.T) { name: "vertex already exists", id: mockVertexID, value: mockVertexValue, - expect: func(t *testing.T, d DAG, err error) { + expect: func(t *testing.T, d DAG[string], err error) { assert := assert.New(t) assert.NoError(err) @@ -61,7 +62,7 @@ func TestDAGAddVertex(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d, d.AddVertex(tc.id, tc.name)) }) } @@ -70,11 +71,11 @@ func TestDAGAddVertex(t *testing.T) { func TestDAGDeleteVertex(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "delete vertex", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { assert.NoError(err) @@ -87,7 +88,7 @@ func TestDAGDeleteVertex(t *testing.T) { }, { name: "delete vertex with edges", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( @@ -119,7 +120,7 @@ func TestDAGDeleteVertex(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -128,11 +129,11 @@ func TestDAGDeleteVertex(t *testing.T) { func TestDAGGetVertex(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "get vertex", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { assert.NoError(err) @@ -148,7 +149,7 @@ func TestDAGGetVertex(t *testing.T) { }, { name: "vertex not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) _, err := d.GetVertex(mockVertexID) assert.EqualError(err, ErrVertexNotFound.Error()) @@ -158,7 +159,7 @@ func TestDAGGetVertex(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -167,11 +168,11 @@ func TestDAGGetVertex(t *testing.T) { func TestDAGVertexVertexCount(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "get length of vertex", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { assert.NoError(err) @@ -186,7 +187,7 @@ func TestDAGVertexVertexCount(t *testing.T) { }, { name: "empty dag", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) assert.Equal(d.VertexCount(), 0) }, @@ -195,7 +196,7 @@ func TestDAGVertexVertexCount(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -204,11 +205,11 @@ func TestDAGVertexVertexCount(t *testing.T) { func TestDAGGetVertices(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "get vertices", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { assert.NoError(err) @@ -220,6 +221,15 @@ func TestDAGGetVertices(t *testing.T) { assert.Equal(vertices[mockVertexID].Value, mockVertexValue) d.DeleteVertex(mockVertexID) + vertices = d.GetVertices() + assert.Equal(len(vertices), 0) + }, + }, + { + name: "dag is empty", + expect: func(t *testing.T, d DAG[string]) { + assert := assert.New(t) + vertices := d.GetVertices() assert.Equal(len(vertices), 0) }, }, @@ -227,7 +237,95 @@ func TestDAGGetVertices(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() + tc.expect(t, d) + }) + } +} + +func TestDAGGetRandomVertices(t *testing.T) { + tests := []struct { + name string + expect func(t *testing.T, d DAG[string]) + }{ + { + name: "get random vertices", + expect: func(t *testing.T, d DAG[string]) { + assert := assert.New(t) + var ( + mockVertexEID = "bae" + mockVertexFID = "baf" + ) + + if err := d.AddVertex(mockVertexEID, mockVertexValue); err != nil { + assert.NoError(err) + } + + if err := d.AddVertex(mockVertexFID, mockVertexValue); err != nil { + assert.NoError(err) + } + + vertices := d.GetRandomVertices(0) + assert.Equal(len(vertices), 0) + + vertices = d.GetRandomVertices(1) + assert.Equal(len(vertices), 1) + + vertices = d.GetRandomVertices(2) + assert.Equal(len(vertices), 2) + + vertices = d.GetRandomVertices(3) + assert.Equal(len(vertices), 2) + }, + }, + { + name: "dag is empty", + expect: func(t *testing.T, d DAG[string]) { + assert := assert.New(t) + vertices := d.GetRandomVertices(0) + assert.Equal(len(vertices), 0) + + vertices = d.GetRandomVertices(1) + assert.Equal(len(vertices), 0) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + d := NewDAG[string]() + tc.expect(t, d) + }) + } +} + +func TestDAGGetVertexKeys(t *testing.T) { + tests := []struct { + name string + expect func(t *testing.T, d DAG[string]) + }{ + { + name: "get keys of vertices", + expect: func(t *testing.T, d DAG[string]) { + assert := assert.New(t) + if err := d.AddVertex(mockVertexID, mockVertexValue); err != nil { + assert.NoError(err) + } + + keys := d.GetVertexKeys() + assert.Equal(len(keys), 1) + assert.Equal(keys[0], mockVertexID) + + d.DeleteVertex(mockVertexID) + keys = d.GetVertexKeys() + assert.Equal(len(keys), 0) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + d := NewDAG[string]() tc.expect(t, d) }) } @@ -236,11 +334,11 @@ func TestDAGGetVertices(t *testing.T) { func TestDAGAddEdge(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "add edge", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -293,7 +391,7 @@ func TestDAGAddEdge(t *testing.T) { }, { name: "cycle between vertices", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -358,7 +456,7 @@ func TestDAGAddEdge(t *testing.T) { }, { name: "vertex not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -382,7 +480,7 @@ func TestDAGAddEdge(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -391,11 +489,11 @@ func TestDAGAddEdge(t *testing.T) { func TestDAGCanAddEdge(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "can add edge", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -447,7 +545,7 @@ func TestDAGCanAddEdge(t *testing.T) { }, { name: "cycle between vertices", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -511,7 +609,7 @@ func TestDAGCanAddEdge(t *testing.T) { }, { name: "vertex not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -532,7 +630,7 @@ func TestDAGCanAddEdge(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -541,11 +639,11 @@ func TestDAGCanAddEdge(t *testing.T) { func TestDAGDeleteEdge(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "delete edge", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -589,7 +687,7 @@ func TestDAGDeleteEdge(t *testing.T) { }, { name: "vertex not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -613,7 +711,7 @@ func TestDAGDeleteEdge(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -622,11 +720,11 @@ func TestDAGDeleteEdge(t *testing.T) { func TestDAGSourceVertices(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "get source vertices", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -651,7 +749,7 @@ func TestDAGSourceVertices(t *testing.T) { }, { name: "source vertices not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) sourceVertices := d.GetSourceVertices() assert.Equal(len(sourceVertices), 0) @@ -660,7 +758,7 @@ func TestDAGSourceVertices(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -669,11 +767,11 @@ func TestDAGSourceVertices(t *testing.T) { func TestDAGSinkVertices(t *testing.T) { tests := []struct { name string - expect func(t *testing.T, d DAG) + expect func(t *testing.T, d DAG[string]) }{ { name: "get sink vertices", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) var ( mockVertexEID = "bae" @@ -698,7 +796,7 @@ func TestDAGSinkVertices(t *testing.T) { }, { name: "sink vertices not found", - expect: func(t *testing.T, d DAG) { + expect: func(t *testing.T, d DAG[string]) { assert := assert.New(t) sinkVertices := d.GetSinkVertices() assert.Equal(len(sinkVertices), 0) @@ -707,7 +805,7 @@ func TestDAGSinkVertices(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - d := NewDAG() + d := NewDAG[string]() tc.expect(t, d) }) } @@ -715,14 +813,14 @@ func TestDAGSinkVertices(t *testing.T) { func BenchmarkDAGAddVertex(b *testing.B) { var ids []string - d := NewDAG() + d := NewDAG[string]() for n := 0; n < b.N; n++ { ids = append(ids, fmt.Sprint(n)) } b.ResetTimer() for _, id := range ids { - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } } @@ -730,10 +828,10 @@ func BenchmarkDAGAddVertex(b *testing.B) { func BenchmarkDAGDeleteVertex(b *testing.B) { var ids []string - d := NewDAG() + d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } @@ -746,12 +844,30 @@ func BenchmarkDAGDeleteVertex(b *testing.B) { } } -func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) { - var ids []string - d := NewDAG() +func BenchmarkDAGGetRandomKeys(b *testing.B) { + d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { + b.Fatal(err) + } + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + vertices := d.GetRandomVertices(uint(n)) + if len(vertices) != n { + b.Fatal(errors.New("get random vertices failed")) + } + } +} + +func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) { + var ids []string + d := NewDAG[string]() + for n := 0; n < b.N; n++ { + id := fmt.Sprint(n) + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } @@ -779,10 +895,10 @@ func BenchmarkDAGDeleteVertexWithMultiEdges(b *testing.B) { func BenchmarkDAGAddEdge(b *testing.B) { var ids []string - d := NewDAG() + d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } @@ -803,10 +919,10 @@ func BenchmarkDAGAddEdge(b *testing.B) { func BenchmarkDAGAddEdgeWithMultiEdges(b *testing.B) { var ids []string - d := NewDAG() + d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } @@ -840,10 +956,10 @@ func BenchmarkDAGAddEdgeWithMultiEdges(b *testing.B) { func BenchmarkDAGDeleteEdge(b *testing.B) { var ids []string - d := NewDAG() + d := NewDAG[string]() for n := 0; n < b.N; n++ { id := fmt.Sprint(n) - if err := d.AddVertex(id, nil); err != nil { + if err := d.AddVertex(id, string(id)); err != nil { b.Fatal(err) } diff --git a/pkg/dag/mocks/dag_mock.go b/pkg/dag/mocks/dag_mock.go index 8e63e9294..d272f2678 100644 --- a/pkg/dag/mocks/dag_mock.go +++ b/pkg/dag/mocks/dag_mock.go @@ -12,30 +12,30 @@ import ( ) // MockDAG is a mock of DAG interface. -type MockDAG struct { +type MockDAG[T comparable] struct { ctrl *gomock.Controller - recorder *MockDAGMockRecorder + recorder *MockDAGMockRecorder[T] } // MockDAGMockRecorder is the mock recorder for MockDAG. -type MockDAGMockRecorder struct { - mock *MockDAG +type MockDAGMockRecorder[T comparable] struct { + mock *MockDAG[T] } // NewMockDAG creates a new mock instance. -func NewMockDAG(ctrl *gomock.Controller) *MockDAG { - mock := &MockDAG{ctrl: ctrl} - mock.recorder = &MockDAGMockRecorder{mock} +func NewMockDAG[T comparable](ctrl *gomock.Controller) *MockDAG[T] { + mock := &MockDAG[T]{ctrl: ctrl} + mock.recorder = &MockDAGMockRecorder[T]{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDAG) EXPECT() *MockDAGMockRecorder { +func (m *MockDAG[T]) EXPECT() *MockDAGMockRecorder[T] { return m.recorder } // AddEdge mocks base method. -func (m *MockDAG) AddEdge(fromVertexID, toVertexID string) error { +func (m *MockDAG[T]) AddEdge(fromVertexID, toVertexID string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddEdge", fromVertexID, toVertexID) ret0, _ := ret[0].(error) @@ -43,13 +43,13 @@ func (m *MockDAG) AddEdge(fromVertexID, toVertexID string) error { } // AddEdge indicates an expected call of AddEdge. -func (mr *MockDAGMockRecorder) AddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) AddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEdge", reflect.TypeOf((*MockDAG)(nil).AddEdge), fromVertexID, toVertexID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEdge", reflect.TypeOf((*MockDAG[T])(nil).AddEdge), fromVertexID, toVertexID) } // AddVertex mocks base method. -func (m *MockDAG) AddVertex(id string, value any) error { +func (m *MockDAG[T]) AddVertex(id string, value T) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddVertex", id, value) ret0, _ := ret[0].(error) @@ -57,13 +57,13 @@ func (m *MockDAG) AddVertex(id string, value any) error { } // AddVertex indicates an expected call of AddVertex. -func (mr *MockDAGMockRecorder) AddVertex(id, value interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) AddVertex(id, value interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddVertex", reflect.TypeOf((*MockDAG)(nil).AddVertex), id, value) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddVertex", reflect.TypeOf((*MockDAG[T])(nil).AddVertex), id, value) } // CanAddEdge mocks base method. -func (m *MockDAG) CanAddEdge(fromVertexID, toVertexID string) bool { +func (m *MockDAG[T]) CanAddEdge(fromVertexID, toVertexID string) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CanAddEdge", fromVertexID, toVertexID) ret0, _ := ret[0].(bool) @@ -71,13 +71,13 @@ func (m *MockDAG) CanAddEdge(fromVertexID, toVertexID string) bool { } // CanAddEdge indicates an expected call of CanAddEdge. -func (mr *MockDAGMockRecorder) CanAddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) CanAddEdge(fromVertexID, toVertexID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanAddEdge", reflect.TypeOf((*MockDAG)(nil).CanAddEdge), fromVertexID, toVertexID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanAddEdge", reflect.TypeOf((*MockDAG[T])(nil).CanAddEdge), fromVertexID, toVertexID) } // DeleteEdge mocks base method. -func (m *MockDAG) DeleteEdge(fromVertexID, toVertexID string) error { +func (m *MockDAG[T]) DeleteEdge(fromVertexID, toVertexID string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteEdge", fromVertexID, toVertexID) ret0, _ := ret[0].(error) @@ -85,82 +85,110 @@ func (m *MockDAG) DeleteEdge(fromVertexID, toVertexID string) error { } // DeleteEdge indicates an expected call of DeleteEdge. -func (mr *MockDAGMockRecorder) DeleteEdge(fromVertexID, toVertexID interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) DeleteEdge(fromVertexID, toVertexID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEdge", reflect.TypeOf((*MockDAG)(nil).DeleteEdge), fromVertexID, toVertexID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEdge", reflect.TypeOf((*MockDAG[T])(nil).DeleteEdge), fromVertexID, toVertexID) } // DeleteVertex mocks base method. -func (m *MockDAG) DeleteVertex(id string) { +func (m *MockDAG[T]) DeleteVertex(id string) { m.ctrl.T.Helper() m.ctrl.Call(m, "DeleteVertex", id) } // DeleteVertex indicates an expected call of DeleteVertex. -func (mr *MockDAGMockRecorder) DeleteVertex(id interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) DeleteVertex(id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVertex", reflect.TypeOf((*MockDAG)(nil).DeleteVertex), id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVertex", reflect.TypeOf((*MockDAG[T])(nil).DeleteVertex), id) +} + +// GetRandomVertices mocks base method. +func (m *MockDAG[T]) GetRandomVertices(n uint) map[string]*dag.Vertex[T] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRandomVertices", n) + ret0, _ := ret[0].(map[string]*dag.Vertex[T]) + return ret0 +} + +// GetRandomVertices indicates an expected call of GetRandomVertices. +func (mr *MockDAGMockRecorder[T]) GetRandomVertices(n interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRandomVertices", reflect.TypeOf((*MockDAG[T])(nil).GetRandomVertices), n) } // GetSinkVertices mocks base method. -func (m *MockDAG) GetSinkVertices() map[string]*dag.Vertex { +func (m *MockDAG[T]) GetSinkVertices() map[string]*dag.Vertex[T] { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSinkVertices") - ret0, _ := ret[0].(map[string]*dag.Vertex) + ret0, _ := ret[0].(map[string]*dag.Vertex[T]) return ret0 } // GetSinkVertices indicates an expected call of GetSinkVertices. -func (mr *MockDAGMockRecorder) GetSinkVertices() *gomock.Call { +func (mr *MockDAGMockRecorder[T]) GetSinkVertices() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSinkVertices", reflect.TypeOf((*MockDAG)(nil).GetSinkVertices)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSinkVertices", reflect.TypeOf((*MockDAG[T])(nil).GetSinkVertices)) } // GetSourceVertices mocks base method. -func (m *MockDAG) GetSourceVertices() map[string]*dag.Vertex { +func (m *MockDAG[T]) GetSourceVertices() map[string]*dag.Vertex[T] { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSourceVertices") - ret0, _ := ret[0].(map[string]*dag.Vertex) + ret0, _ := ret[0].(map[string]*dag.Vertex[T]) return ret0 } // GetSourceVertices indicates an expected call of GetSourceVertices. -func (mr *MockDAGMockRecorder) GetSourceVertices() *gomock.Call { +func (mr *MockDAGMockRecorder[T]) GetSourceVertices() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceVertices", reflect.TypeOf((*MockDAG)(nil).GetSourceVertices)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceVertices", reflect.TypeOf((*MockDAG[T])(nil).GetSourceVertices)) } // GetVertex mocks base method. -func (m *MockDAG) GetVertex(id string) (*dag.Vertex, error) { +func (m *MockDAG[T]) GetVertex(id string) (*dag.Vertex[T], error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVertex", id) - ret0, _ := ret[0].(*dag.Vertex) + ret0, _ := ret[0].(*dag.Vertex[T]) ret1, _ := ret[1].(error) return ret0, ret1 } // GetVertex indicates an expected call of GetVertex. -func (mr *MockDAGMockRecorder) GetVertex(id interface{}) *gomock.Call { +func (mr *MockDAGMockRecorder[T]) GetVertex(id interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG)(nil).GetVertex), id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertex", reflect.TypeOf((*MockDAG[T])(nil).GetVertex), id) +} + +// GetVertexKeys mocks base method. +func (m *MockDAG[T]) GetVertexKeys() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVertexKeys") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetVertexKeys indicates an expected call of GetVertexKeys. +func (mr *MockDAGMockRecorder[T]) GetVertexKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertexKeys", reflect.TypeOf((*MockDAG[T])(nil).GetVertexKeys)) } // GetVertices mocks base method. -func (m *MockDAG) GetVertices() map[string]*dag.Vertex { +func (m *MockDAG[T]) GetVertices() map[string]*dag.Vertex[T] { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVertices") - ret0, _ := ret[0].(map[string]*dag.Vertex) + ret0, _ := ret[0].(map[string]*dag.Vertex[T]) return ret0 } // GetVertices indicates an expected call of GetVertices. -func (mr *MockDAGMockRecorder) GetVertices() *gomock.Call { +func (mr *MockDAGMockRecorder[T]) GetVertices() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertices", reflect.TypeOf((*MockDAG)(nil).GetVertices)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVertices", reflect.TypeOf((*MockDAG[T])(nil).GetVertices)) } // VertexCount mocks base method. -func (m *MockDAG) VertexCount() int { +func (m *MockDAG[T]) VertexCount() int { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VertexCount") ret0, _ := ret[0].(int) @@ -168,7 +196,7 @@ func (m *MockDAG) VertexCount() int { } // VertexCount indicates an expected call of VertexCount. -func (mr *MockDAGMockRecorder) VertexCount() *gomock.Call { +func (mr *MockDAGMockRecorder[T]) VertexCount() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VertexCount", reflect.TypeOf((*MockDAG)(nil).VertexCount)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VertexCount", reflect.TypeOf((*MockDAG[T])(nil).VertexCount)) } diff --git a/pkg/dag/vertex.go b/pkg/dag/vertex.go index 96c06ab3d..86ee5c03c 100644 --- a/pkg/dag/vertex.go +++ b/pkg/dag/vertex.go @@ -19,62 +19,52 @@ package dag import "d7y.io/dragonfly/v2/pkg/container/set" // Vertex is a vertex of the directed acyclic graph. -type Vertex struct { +type Vertex[T comparable] struct { ID string - Value any - Parents set.SafeSet - Children set.SafeSet + Value T + Parents set.SafeSet[*Vertex[T]] + Children set.SafeSet[*Vertex[T]] } // New returns a new Vertex instance. -func NewVertex(id string, value any) *Vertex { - return &Vertex{ +func NewVertex[T comparable](id string, value T) *Vertex[T] { + return &Vertex[T]{ ID: id, Value: value, - Parents: set.NewSafeSet(), - Children: set.NewSafeSet(), + Parents: set.NewSafeSet[*Vertex[T]](), + Children: set.NewSafeSet[*Vertex[T]](), } } // Degree returns the degree of vertex. -func (v *Vertex) Degree() int { +func (v *Vertex[T]) Degree() int { return int(v.Parents.Len() + v.Children.Len()) } // InDegree returns the indegree of vertex. -func (v *Vertex) InDegree() int { +func (v *Vertex[T]) InDegree() int { return int(v.Parents.Len()) } // OutDegree returns the outdegree of vertex. -func (v *Vertex) OutDegree() int { +func (v *Vertex[T]) OutDegree() int { return int(v.Children.Len()) } // DeleteInEdges deletes inedges of vertex. -func (v *Vertex) DeleteInEdges() { - for _, value := range v.Parents.Values() { - vertex, ok := value.(*Vertex) - if !ok { - continue - } - - vertex.Children.Delete(v) +func (v *Vertex[T]) DeleteInEdges() { + for _, parent := range v.Parents.Values() { + parent.Children.Delete(v) } - v.Parents = set.NewSafeSet() + v.Parents = set.NewSafeSet[*Vertex[T]]() } // DeleteOutEdges deletes outedges of vertex. -func (v *Vertex) DeleteOutEdges() { - for _, value := range v.Children.Values() { - vertex, ok := value.(*Vertex) - if !ok { - continue - } - - vertex.Parents.Delete(v) +func (v *Vertex[T]) DeleteOutEdges() { + for _, child := range v.Children.Values() { + child.Parents.Delete(v) } - v.Children = set.NewSafeSet() + v.Children = set.NewSafeSet[*Vertex[T]]() } diff --git a/pkg/dag/vertex_test.go b/pkg/dag/vertex_test.go index 5c063b851..410d95517 100644 --- a/pkg/dag/vertex_test.go +++ b/pkg/dag/vertex_test.go @@ -43,16 +43,16 @@ func TestVertexDegree(t *testing.T) { assert.Equal(v.Value, mockVertexValue) assert.Equal(v.Degree(), 0) - v.Parents.Add(mockVertexID) + v.Parents.Add(v) assert.Equal(v.Degree(), 1) - v.Children.Add(mockVertexID) + v.Children.Add(v) assert.Equal(v.Degree(), 2) - v.Parents.Delete(mockVertexID) + v.Parents.Delete(v) assert.Equal(v.Degree(), 1) - v.Children.Delete(mockVertexID) + v.Children.Delete(v) assert.Equal(v.Degree(), 0) } @@ -63,16 +63,16 @@ func TestVertexInDegree(t *testing.T) { assert.Equal(v.Value, mockVertexValue) assert.Equal(v.InDegree(), 0) - v.Parents.Add(mockVertexID) + v.Parents.Add(v) assert.Equal(v.InDegree(), 1) - v.Children.Add(mockVertexID) + v.Children.Add(v) assert.Equal(v.InDegree(), 1) - v.Parents.Delete(mockVertexID) + v.Parents.Delete(v) assert.Equal(v.InDegree(), 0) - v.Children.Delete(mockVertexID) + v.Children.Delete(v) assert.Equal(v.InDegree(), 0) } @@ -83,16 +83,16 @@ func TestVertexOutDegree(t *testing.T) { assert.Equal(v.Value, mockVertexValue) assert.Equal(v.OutDegree(), 0) - v.Parents.Add(mockVertexID) + v.Parents.Add(v) assert.Equal(v.OutDegree(), 0) - v.Children.Add(mockVertexID) + v.Children.Add(v) assert.Equal(v.OutDegree(), 1) - v.Parents.Delete(mockVertexID) + v.Parents.Delete(v) assert.Equal(v.OutDegree(), 1) - v.Children.Delete(mockVertexID) + v.Children.Delete(v) assert.Equal(v.OutDegree(), 0) } diff --git a/scheduler/resource/peer.go b/scheduler/resource/peer.go index 61dd5c85f..4bed49585 100644 --- a/scheduler/resource/peer.go +++ b/scheduler/resource/peer.go @@ -31,7 +31,6 @@ import ( logger "d7y.io/dragonfly/v2/internal/dflog" "d7y.io/dragonfly/v2/pkg/container/set" - "d7y.io/dragonfly/v2/pkg/dag" "d7y.io/dragonfly/v2/pkg/rpc/scheduler" ) @@ -135,7 +134,7 @@ type Peer struct { Host *Host // BlockPeers is bad peer ids. - BlockPeers set.SafeSet + BlockPeers set.SafeSet[string] // NeedBackToSource needs downloaded from source. // @@ -171,7 +170,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer { Stream: &atomic.Value{}, Task: task, Host: host, - BlockPeers: set.NewSafeSet(), + BlockPeers: set.NewSafeSet[string](), NeedBackToSource: atomic.NewBool(false), IsBackToSource: atomic.NewBool(false), CreateAt: atomic.NewTime(time.Now()), @@ -222,7 +221,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer { }, PeerEventDownloadFromBackToSource: func(e *fsm.Event) { p.IsBackToSource.Store(true) - p.Task.BackToSourcePeers.Add(p) + p.Task.BackToSourcePeers.Add(p.ID) if err := p.Task.DeletePeerInEdges(p.ID); err != nil { p.Log.Errorf("delete peer inedges failed: %s", err.Error()) @@ -234,7 +233,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer { }, PeerEventDownloadSucceeded: func(e *fsm.Event) { if e.Src == PeerStateBackToSource { - p.Task.BackToSourcePeers.Delete(p) + p.Task.BackToSourcePeers.Delete(p.ID) } if err := p.Task.DeletePeerInEdges(p.ID); err != nil { @@ -249,7 +248,7 @@ func NewPeer(id string, task *Task, host *Host, options ...PeerOption) *Peer { PeerEventDownloadFailed: func(e *fsm.Event) { if e.Src == PeerStateBackToSource { p.Task.PeerFailedCount.Inc() - p.Task.BackToSourcePeers.Delete(p) + p.Task.BackToSourcePeers.Delete(p.ID) } if err := p.Task.DeletePeerInEdges(p.ID); err != nil { @@ -317,23 +316,12 @@ func (p *Peer) Parents() []*Peer { } var parents []*Peer - for _, value := range vertex.Parents.Values() { - vertex, ok := value.(*dag.Vertex) - if !ok { + for _, parent := range vertex.Parents.Values() { + if parent.Value == nil { continue } - vertexVal := vertex.Value - if vertexVal == nil { - continue - } - - parent, ok := vertexVal.(*Peer) - if !ok { - continue - } - - parents = append(parents, parent) + parents = append(parents, parent.Value) } return parents @@ -348,23 +336,12 @@ func (p *Peer) Children() []*Peer { } var children []*Peer - for _, value := range vertex.Children.Values() { - vertex, ok := value.(*dag.Vertex) - if !ok { + for _, child := range vertex.Children.Values() { + if child.Value == nil { continue } - vertexVal := vertex.Value - if vertexVal == nil { - continue - } - - child, ok := vertexVal.(*Peer) - if !ok { - continue - } - - children = append(children, child) + children = append(children, child.Value) } return children diff --git a/scheduler/resource/task.go b/scheduler/resource/task.go index 6dccc6eb2..d42ee64c1 100644 --- a/scheduler/resource/task.go +++ b/scheduler/resource/task.go @@ -18,8 +18,6 @@ package resource import ( "errors" - "math/rand" - reflect "reflect" "sort" "sync" "time" @@ -106,7 +104,7 @@ type Task struct { BackToSourceLimit *atomic.Int32 // BackToSourcePeers is back-to-source sync map. - BackToSourcePeers set.SafeSet + BackToSourcePeers set.SafeSet[string] // Task state machine. FSM *fsm.FSM @@ -115,7 +113,7 @@ type Task struct { Pieces *sync.Map // DAG is directed acyclic graph of peers. - DAG dag.DAG + DAG dag.DAG[*Peer] // PeerFailedCount is peer failed count, // if one peer succeeds, the value is reset to zero. @@ -141,9 +139,9 @@ func NewTask(id, url string, taskType base.TaskType, meta *base.UrlMeta, options ContentLength: atomic.NewInt64(0), TotalPieceCount: atomic.NewInt32(0), BackToSourceLimit: atomic.NewInt32(0), - BackToSourcePeers: set.NewSafeSet(), + BackToSourcePeers: set.NewSafeSet[string](), Pieces: &sync.Map{}, - DAG: dag.NewDAG(), + DAG: dag.NewDAG[*Peer](), PeerFailedCount: atomic.NewInt32(0), CreateAt: atomic.NewTime(time.Now()), UpdateAt: atomic.NewTime(time.Now()), @@ -189,54 +187,14 @@ func (t *Task) LoadPeer(key string) (*Peer, bool) { return nil, false } - value := vertex.Value - if value == nil { - return nil, false - } - - return value.(*Peer), true + return vertex.Value, true } // LoadRandomPeers return random peers. func (t *Task) LoadRandomPeers(n uint) []*Peer { var peers []*Peer - vertices := t.DAG.GetVertices() - keys := reflect.ValueOf(vertices).MapKeys() - if int(n) >= len(keys) { - for _, vertex := range vertices { - value := vertex.Value - if value == nil { - continue - } - - peer, ok := value.(*Peer) - if !ok { - continue - } - - peers = append(peers, peer) - } - - return peers - } - - rand.Seed(time.Now().Unix()) - permutation := rand.Perm(len(keys))[:n] - for _, v := range permutation { - key := keys[v].String() - - vertex := vertices[key] - value := vertex.Value - if value == nil { - continue - } - - peer, ok := value.(*Peer) - if !ok { - continue - } - - peers = append(peers, peer) + for _, vertex := range t.DAG.GetRandomVertices(n) { + peers = append(peers, vertex.Value) } return peers @@ -282,23 +240,12 @@ func (t *Task) DeletePeerInEdges(key string) error { return err } - for _, value := range vertex.Parents.Values() { - vertex, ok := value.(*dag.Vertex) - if !ok { + for _, parent := range vertex.Parents.Values() { + if parent.Value == nil { continue } - vertexVal := vertex.Value - if vertexVal == nil { - continue - } - - parent, ok := vertexVal.(*Peer) - if !ok { - continue - } - - parent.Host.UploadPeerCount.Dec() + parent.Value.Host.UploadPeerCount.Dec() } vertex.DeleteInEdges() @@ -312,16 +259,11 @@ func (t *Task) DeletePeerOutEdges(key string) error { return err } - value := vertex.Value - if value == nil { + peer := vertex.Value + if peer == nil { return errors.New("vertex value is nil") } - peer, ok := value.(*Peer) - if !ok { - return errors.New("vertex value is not peer") - } - peer.Host.UploadPeerCount.Sub(int32(vertex.Children.Len())) vertex.DeleteOutEdges() return nil @@ -366,13 +308,8 @@ func (t *Task) PeerOutDegree(key string) (int, error) { func (t *Task) HasAvailablePeer() bool { var hasAvailablePeer bool for _, vertex := range t.DAG.GetVertices() { - value := vertex.Value - if value == nil { - continue - } - - peer, ok := value.(*Peer) - if !ok { + peer := vertex.Value + if peer == nil { continue } @@ -389,13 +326,8 @@ func (t *Task) HasAvailablePeer() bool { func (t *Task) LoadSeedPeer() (*Peer, bool) { var peers []*Peer for _, vertex := range t.DAG.GetVertices() { - value := vertex.Value - if value == nil { - continue - } - - peer, ok := value.(*Peer) - if !ok { + peer := vertex.Value + if peer == nil { continue } @@ -473,12 +405,11 @@ func (t *Task) CanBackToSource() bool { // NotifyPeers notify all peers in the task with the state code. func (t *Task) NotifyPeers(peerPacket *rpcscheduler.PeerPacket, event string) { for _, vertex := range t.DAG.GetVertices() { - value := vertex.Value - if value == nil { + peer := vertex.Value + if peer == nil { continue } - peer := value.(*Peer) if peer.FSM.Is(PeerStateRunning) { stream, ok := peer.LoadStream() if !ok { diff --git a/scheduler/scheduler/mocks/scheduler_mock.go b/scheduler/scheduler/mocks/scheduler_mock.go index 82857f449..e087412b7 100644 --- a/scheduler/scheduler/mocks/scheduler_mock.go +++ b/scheduler/scheduler/mocks/scheduler_mock.go @@ -37,7 +37,7 @@ func (m *MockScheduler) EXPECT() *MockSchedulerMockRecorder { } // FindParent mocks base method. -func (m *MockScheduler) FindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) (*resource.Peer, bool) { +func (m *MockScheduler) FindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) (*resource.Peer, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FindParent", arg0, arg1, arg2) ret0, _ := ret[0].(*resource.Peer) @@ -52,7 +52,7 @@ func (mr *MockSchedulerMockRecorder) FindParent(arg0, arg1, arg2 interface{}) *g } // NotifyAndFindParent mocks base method. -func (m *MockScheduler) NotifyAndFindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) ([]*resource.Peer, bool) { +func (m *MockScheduler) NotifyAndFindParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) ([]*resource.Peer, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NotifyAndFindParent", arg0, arg1, arg2) ret0, _ := ret[0].([]*resource.Peer) @@ -67,7 +67,7 @@ func (mr *MockSchedulerMockRecorder) NotifyAndFindParent(arg0, arg1, arg2 interf } // ScheduleParent mocks base method. -func (m *MockScheduler) ScheduleParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet) { +func (m *MockScheduler) ScheduleParent(arg0 context.Context, arg1 *resource.Peer, arg2 set.SafeSet[string]) { m.ctrl.T.Helper() m.ctrl.Call(m, "ScheduleParent", arg0, arg1, arg2) } diff --git a/scheduler/scheduler/scheduler.go b/scheduler/scheduler/scheduler.go index ccd1872a7..2eb2813f6 100644 --- a/scheduler/scheduler/scheduler.go +++ b/scheduler/scheduler/scheduler.go @@ -33,13 +33,13 @@ import ( type Scheduler interface { // ScheduleParent schedule a parent and candidates to a peer. - ScheduleParent(context.Context, *resource.Peer, set.SafeSet) + ScheduleParent(context.Context, *resource.Peer, set.SafeSet[string]) // Find the parent that best matches the evaluation and notify peer. - NotifyAndFindParent(context.Context, *resource.Peer, set.SafeSet) ([]*resource.Peer, bool) + NotifyAndFindParent(context.Context, *resource.Peer, set.SafeSet[string]) ([]*resource.Peer, bool) // Find the parent that best matches the evaluation. - FindParent(context.Context, *resource.Peer, set.SafeSet) (*resource.Peer, bool) + FindParent(context.Context, *resource.Peer, set.SafeSet[string]) (*resource.Peer, bool) } type scheduler struct { @@ -62,7 +62,7 @@ func New(cfg *config.SchedulerConfig, dynconfig config.DynconfigInterface, plugi } // ScheduleParent schedule a parent and candidates to a peer. -func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) { +func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) { var n int for { select { @@ -141,7 +141,7 @@ func (s *scheduler) ScheduleParent(ctx context.Context, peer *resource.Peer, blo } // NotifyAndFindParent finds parent that best matches the evaluation and notify peer. -func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) ([]*resource.Peer, bool) { +func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) ([]*resource.Peer, bool) { // Only PeerStateRunning peers need to be rescheduled, // and other states including the PeerStateBackToSource indicate that // they have been scheduled. @@ -209,7 +209,7 @@ func (s *scheduler) NotifyAndFindParent(ctx context.Context, peer *resource.Peer } // FindParent finds parent that best matches the evaluation. -func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet) (*resource.Peer, bool) { +func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blocklist set.SafeSet[string]) (*resource.Peer, bool) { // Filter the candidate parent that can be scheduled. candidateParents := s.filterCandidateParents(peer, blocklist) if len(candidateParents) == 0 { @@ -231,7 +231,7 @@ func (s *scheduler) FindParent(ctx context.Context, peer *resource.Peer, blockli } // Filter the candidate parent that can be scheduled. -func (s *scheduler) filterCandidateParents(peer *resource.Peer, blocklist set.SafeSet) []*resource.Peer { +func (s *scheduler) filterCandidateParents(peer *resource.Peer, blocklist set.SafeSet[string]) []*resource.Peer { filterParentLimit := config.DefaultSchedulerFilterParentLimit if config, ok := s.dynconfig.GetSchedulerClusterConfig(); ok && filterParentLimit > 0 { filterParentLimit = int(config.FilterParentLimit) diff --git a/scheduler/scheduler/scheduler_test.go b/scheduler/scheduler/scheduler_test.go index 6365e94d1..5d04e5376 100644 --- a/scheduler/scheduler/scheduler_test.go +++ b/scheduler/scheduler/scheduler_test.go @@ -127,12 +127,12 @@ func TestScheduler_New(t *testing.T) { func TestScheduler_ScheduleParent(t *testing.T) { tests := []struct { name string - mock func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) + mock func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) expect func(t *testing.T, peer *resource.Peer) }{ { name: "context was done", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) cancel() }, @@ -143,7 +143,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "peer needs back-to-source and peer stream load failed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.NeedBackToSource.Store(true) @@ -156,7 +156,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "peer needs back-to-source and send Code_SchedNeedBackSource code failed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.NeedBackToSource.Store(true) @@ -173,7 +173,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "peer needs back-to-source and send Code_SchedNeedBackSource code success", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.NeedBackToSource.Store(true) @@ -191,7 +191,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "peer needs back-to-source and task state is TaskStateFailed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.NeedBackToSource.Store(true) @@ -210,7 +210,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "schedule exceeds RetryBackSourceLimit and peer stream load failed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateRunning) @@ -223,7 +223,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "schedule exceeds RetryLimit and peer stream load failed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateRunning) @@ -238,7 +238,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code failed", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateRunning) @@ -258,7 +258,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "schedule exceeds RetryLimit and send Code_SchedTaskStatusError code success", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateRunning) @@ -278,7 +278,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { }, { name: "schedule succeeded", - mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(cancel context.CancelFunc, peer *resource.Peer, seedPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, mr *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { task := peer.Task task.StorePeer(peer) task.StorePeer(seedPeer) @@ -313,7 +313,7 @@ func TestScheduler_ScheduleParent(t *testing.T) { peer := resource.NewPeer(mockPeerID, mockTask, mockHost) mockSeedHost := resource.NewHost(mockRawSeedHost, resource.WithHostType(resource.HostTypeSuperSeed)) seedPeer := resource.NewPeer(mockSeedPeerID, mockTask, mockSeedHost) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() tc.mock(cancel, peer, seedPeer, blocklist, stream, stream.EXPECT(), dynconfig.EXPECT()) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) @@ -326,12 +326,12 @@ func TestScheduler_ScheduleParent(t *testing.T) { func TestScheduler_NotifyAndFindParent(t *testing.T) { tests := []struct { name string - mock func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) + mock func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) expect func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) }{ { name: "peer state is PeerStatePending", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStatePending) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -341,7 +341,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateReceivedSmall", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateReceivedSmall) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -351,7 +351,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateReceivedNormal", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateReceivedNormal) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -361,7 +361,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateBackToSource", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateBackToSource) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -371,7 +371,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateSucceeded", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateSucceeded) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -381,7 +381,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateFailed", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateFailed) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -391,7 +391,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer state is PeerStateLeave", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateLeave) }, expect: func(t *testing.T, peer *resource.Peer, parents []*resource.Peer, ok bool) { @@ -401,7 +401,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "task peers is empty", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1) @@ -413,7 +413,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "task contains only one peer and peer is itself", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -426,7 +426,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer is in blocklist", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeer) @@ -441,7 +441,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer is bad node", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.FSM.SetState(resource.PeerStateFailed) peer.Task.StorePeer(mockPeer) @@ -453,7 +453,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "parent is peer's descendant", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -471,7 +471,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "parent is peer's ancestor", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -489,7 +489,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "parent free upload load is zero", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -505,7 +505,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer stream is empty", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -521,10 +521,10 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "peer stream send failed", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) - peer.Task.BackToSourcePeers.Add(mockPeer) + peer.Task.BackToSourcePeers.Add(mockPeer.ID) mockPeer.IsBackToSource.Store(true) peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeer) @@ -545,7 +545,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { }, { name: "schedule parent", - mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet, stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, blocklist set.SafeSet[string], stream rpcscheduler.Scheduler_ReportPieceResultServer, dynconfig config.DynconfigInterface, ms *rpcschedulermocks.MockScheduler_ReportPieceResultServerMockRecorder, md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeer.FSM.SetState(resource.PeerStateRunning) candidatePeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost) @@ -553,8 +553,8 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeer) peer.Task.StorePeer(candidatePeer) - peer.Task.BackToSourcePeers.Add(mockPeer) - peer.Task.BackToSourcePeers.Add(candidatePeer) + peer.Task.BackToSourcePeers.Add(mockPeer.ID) + peer.Task.BackToSourcePeers.Add(candidatePeer.ID) mockPeer.IsBackToSource.Store(true) candidatePeer.IsBackToSource.Store(true) mockPeer.Pieces.Set(0) @@ -585,7 +585,7 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { mockTask := resource.NewTask(mockTaskID, mockTaskURL, base.TaskType_Normal, mockTaskURLMeta, resource.WithBackToSourceLimit(mockTaskBackToSourceLimit)) peer := resource.NewPeer(mockPeerID, mockTask, mockHost) mockPeer := resource.NewPeer(idgen.PeerID("127.0.0.1"), mockTask, mockHost) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() tc.mock(peer, mockHost, mockTask, mockPeer, blocklist, stream, dynconfig, stream.EXPECT(), dynconfig.EXPECT()) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) @@ -598,12 +598,12 @@ func TestScheduler_NotifyAndFindParent(t *testing.T) { func TestScheduler_FindParent(t *testing.T) { tests := []struct { name string - mock func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) + mock func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) expect func(t *testing.T, peer *resource.Peer, mockPeers []*resource.Peer, parent *resource.Peer, ok bool) }{ { name: "task peers is empty", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) md.GetSchedulerClusterConfig().Return(types.SchedulerClusterConfig{}, false).Times(1) @@ -615,7 +615,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "task contains only one peer and peer is itself", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -628,7 +628,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "peer is in blocklist", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeers[0]) @@ -643,7 +643,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "peer is bad node", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateFailed) peer.Task.StorePeer(peer) @@ -658,7 +658,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "parent is peer's descendant", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -676,7 +676,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "parent free upload load is zero", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) @@ -692,15 +692,15 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "find back-to-source parent", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeers[0]) peer.Task.StorePeer(mockPeers[1]) - peer.Task.BackToSourcePeers.Add(mockPeers[0]) - peer.Task.BackToSourcePeers.Add(mockPeers[1]) + peer.Task.BackToSourcePeers.Add(mockPeers[0].ID) + peer.Task.BackToSourcePeers.Add(mockPeers[1].ID) mockPeers[0].IsBackToSource.Store(true) mockPeers[1].IsBackToSource.Store(true) mockPeers[0].Pieces.Set(0) @@ -718,7 +718,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "find seed peer parent", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning) @@ -743,7 +743,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "parent state is PeerStateSucceeded", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateSucceeded) mockPeers[1].FSM.SetState(resource.PeerStateSucceeded) @@ -765,7 +765,7 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "find parent with ancestor", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning) @@ -796,15 +796,15 @@ func TestScheduler_FindParent(t *testing.T) { }, { name: "find parent and fetch filterParentLimit from manager dynconfig", - mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet, md *configmocks.MockDynconfigInterfaceMockRecorder) { + mock: func(peer *resource.Peer, mockPeers []*resource.Peer, blocklist set.SafeSet[string], md *configmocks.MockDynconfigInterfaceMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) mockPeers[0].FSM.SetState(resource.PeerStateRunning) mockPeers[1].FSM.SetState(resource.PeerStateRunning) peer.Task.StorePeer(peer) peer.Task.StorePeer(mockPeers[0]) peer.Task.StorePeer(mockPeers[1]) - peer.Task.BackToSourcePeers.Add(mockPeers[0]) - peer.Task.BackToSourcePeers.Add(mockPeers[1]) + peer.Task.BackToSourcePeers.Add(mockPeers[0].ID) + peer.Task.BackToSourcePeers.Add(mockPeers[1].ID) mockPeers[0].IsBackToSource.Store(true) mockPeers[1].IsBackToSource.Store(true) mockPeers[0].Pieces.Set(0) @@ -839,7 +839,7 @@ func TestScheduler_FindParent(t *testing.T) { mockPeers = append(mockPeers, peer) } - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() tc.mock(peer, mockPeers, blocklist, dynconfig.EXPECT()) scheduler := New(mockSchedulerConfig, dynconfig, mockPluginDir) parent, ok := scheduler.FindParent(context.Background(), peer, blocklist) diff --git a/scheduler/service/service.go b/scheduler/service/service.go index d0dd925a8..2a3041f38 100644 --- a/scheduler/service/service.go +++ b/scheduler/service/service.go @@ -123,7 +123,7 @@ func (s *Service) RegisterPeerTask(ctx context.Context, req *rpcscheduler.PeerTa case base.SizeScope_SMALL: peer.Log.Info("task size scope is small") // There is no need to build a tree, just find the parent and return. - parent, ok := s.scheduler.FindParent(ctx, peer, set.NewSafeSet()) + parent, ok := s.scheduler.FindParent(ctx, peer, set.NewSafeSet[string]()) if !ok { peer.Log.Warn("task size scope is small and it can not select parent") if err := peer.FSM.Event(resource.PeerEventRegisterNormal); err != nil { @@ -626,7 +626,7 @@ func (s *Service) handleBeginOfPiece(ctx context.Context, peer *resource.Peer) { } peer.Log.Infof("schedule parent because of peer receive begin of piece") - s.scheduler.ScheduleParent(ctx, peer, set.NewSafeSet()) + s.scheduler.ScheduleParent(ctx, peer, set.NewSafeSet[string]()) default: peer.Log.Warnf("peer state is %s when receive the begin of piece", peer.FSM.Current()) } diff --git a/scheduler/service/service_test.go b/scheduler/service/service_test.go index 57c1bfa6c..060037375 100644 --- a/scheduler/service/service_test.go +++ b/scheduler/service/service_test.go @@ -1656,7 +1656,7 @@ func TestService_LeaveTask(t *testing.T) { gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), mp.Load(gomock.Any()).Return(peer, true).Times(1), - ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1), + ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1), mr.PeerManager().Return(peerManager).Times(1), mp.Delete(gomock.Eq(peer.ID)).Return().Times(1), ) @@ -1674,7 +1674,7 @@ func TestService_LeaveTask(t *testing.T) { peer.Task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateSucceeded) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(peer.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -1704,7 +1704,7 @@ func TestService_LeaveTask(t *testing.T) { gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), mp.Load(gomock.Any()).Return(peer, true).Times(1), - ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1), + ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1), mr.PeerManager().Return(peerManager).Times(1), mp.Delete(gomock.Eq(peer.ID)).Return().Times(1), ) @@ -1722,7 +1722,7 @@ func TestService_LeaveTask(t *testing.T) { peer.Task.StorePeer(peer) peer.FSM.SetState(resource.PeerStateFailed) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(peer.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2310,7 +2310,7 @@ func TestService_handleBeginOfPiece(t *testing.T) { name: "peer state is PeerStateReceivedNormal", mock: func(peer *resource.Peer, scheduler *mocks.MockSchedulerMockRecorder) { peer.FSM.SetState(resource.PeerStateReceivedNormal) - scheduler.ScheduleParent(gomock.Any(), gomock.Eq(peer), gomock.Eq(set.NewSafeSet())).Return().Times(1) + scheduler.ScheduleParent(gomock.Any(), gomock.Eq(peer), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1) }, expect: func(t *testing.T, peer *resource.Peer) { assert := assert.New(t) @@ -2537,7 +2537,7 @@ func TestService_handlePieceFail(t *testing.T) { parent: resource.NewPeer(mockSeedPeerID, mockTask, mockHost), run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(mockSeedPeerID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2566,7 +2566,7 @@ func TestService_handlePieceFail(t *testing.T) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(parent.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2596,7 +2596,7 @@ func TestService_handlePieceFail(t *testing.T) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) peer.Host.Type = resource.HostTypeNormal - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(parent.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2625,7 +2625,7 @@ func TestService_handlePieceFail(t *testing.T) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(parent.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2655,7 +2655,7 @@ func TestService_handlePieceFail(t *testing.T) { run: func(t *testing.T, svc *Service, peer *resource.Peer, parent *resource.Peer, piece *rpcscheduler.PieceResult, peerManager resource.PeerManager, seedPeer resource.SeedPeer, ms *mocks.MockSchedulerMockRecorder, mr *resource.MockResourceMockRecorder, mp *resource.MockPeerManagerMockRecorder, mc *resource.MockSeedPeerMockRecorder) { peer.FSM.SetState(resource.PeerStateRunning) parent.FSM.SetState(resource.PeerStateRunning) - blocklist := set.NewSafeSet() + blocklist := set.NewSafeSet[string]() blocklist.Add(parent.ID) gomock.InOrder( mr.PeerManager().Return(peerManager).Times(1), @@ -2844,7 +2844,7 @@ func TestService_handlePeerFail(t *testing.T) { peer.FSM.SetState(resource.PeerStateRunning) child.FSM.SetState(resource.PeerStateRunning) - ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet())).Return().Times(1) + ms.ScheduleParent(gomock.Any(), gomock.Eq(child), gomock.Eq(set.NewSafeSet[string]())).Return().Times(1) }, expect: func(t *testing.T, peer *resource.Peer, child *resource.Peer) { assert := assert.New(t)