From c49a8d03890d99df50d42a0a9f9357a241e25623 Mon Sep 17 00:00:00 2001 From: BruceAko Date: Tue, 18 Mar 2025 13:07:16 +0800 Subject: [PATCH] fix: update peer and task manager methods to use correct redis key format Signed-off-by: BruceAko --- .../resource/persistentcache/peer_manager.go | 15 ++++++++-- .../persistentcache/peer_manager_test.go | 15 ++++++++-- .../resource/persistentcache/task_manager.go | 30 ++++++++++++++----- .../persistentcache/task_manager_mock.go | 12 ++++---- .../persistentcache/task_manager_test.go | 19 ++++++++---- scheduler/service/service_v2.go | 10 +++---- 6 files changed, 72 insertions(+), 29 deletions(-) diff --git a/scheduler/resource/persistentcache/peer_manager.go b/scheduler/resource/persistentcache/peer_manager.go index 0adfc8946..09979d5be 100644 --- a/scheduler/resource/persistentcache/peer_manager.go +++ b/scheduler/resource/persistentcache/peer_manager.go @@ -22,7 +22,9 @@ import ( "context" "encoding/json" "errors" + "fmt" "strconv" + "strings" "time" "github.com/bits-and-blooms/bitset" @@ -367,16 +369,23 @@ func (p *peerManager) LoadAll(ctx context.Context) ([]*Peer, error) { err error ) - peerKeys, cursor, err = p.rdb.Scan(ctx, cursor, pkgredis.MakePersistentCachePeersInScheduler(p.config.Manager.SchedulerClusterID), 10).Result() + prefix := fmt.Sprintf("%s:", pkgredis.MakePersistentCachePeersInScheduler(p.config.Manager.SchedulerClusterID)) + peerKeys, cursor, err = p.rdb.Scan(ctx, cursor, fmt.Sprintf("%s*", prefix), 10).Result() if err != nil { logger.Errorf("scan tasks failed: %v", err) return nil, err } for _, peerKey := range peerKeys { - peer, loaded := p.Load(ctx, peerKey) + peerID := strings.TrimPrefix(peerKey, prefix) + if peerID == "" { + logger.Error("invalid peer key") + continue + } + + peer, loaded := p.Load(ctx, peerID) if !loaded { - logger.WithPeerID(peerKey).Error("load peer failed") + logger.WithPeerID(peerID).Error("load peer failed") continue } diff --git a/scheduler/resource/persistentcache/peer_manager_test.go b/scheduler/resource/persistentcache/peer_manager_test.go index 84b60d007..62eb0d8b4 100644 --- a/scheduler/resource/persistentcache/peer_manager_test.go +++ b/scheduler/resource/persistentcache/peer_manager_test.go @@ -19,6 +19,7 @@ package persistentcache import ( "context" "errors" + "fmt" "strconv" "testing" "time" @@ -248,15 +249,23 @@ func TestPeerManager_LoadAll(t *testing.T) { { name: "redis scan error", mockRedis: func(mock redismock.ClientMock) { - mock.ExpectScan(0, pkgredis.MakePersistentCachePeersInScheduler(42), 10).SetErr(errors.New("redis scan error")) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCachePeersInScheduler(42)), 10).SetErr(errors.New("redis scan error")) }, expectedPeers: nil, expectedErr: true, }, + { + name: "invalid peer key", + mockRedis: func(mock redismock.ClientMock) { + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCachePeersInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:", pkgredis.MakePersistentCachePeersInScheduler(42))}, 0) + }, + expectedPeers: nil, + expectedErr: false, + }, { name: "load peer error", mockRedis: func(mock redismock.ClientMock) { - mock.ExpectScan(0, pkgredis.MakePersistentCachePeersInScheduler(42), 10).SetVal([]string{"peer1"}, 0) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCachePeersInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:peer1", pkgredis.MakePersistentCachePeersInScheduler(42))}, 0) mock.ExpectHGetAll(pkgredis.MakePersistentCachePeerKeyInScheduler(42, "peer1")).SetErr(errors.New("redis hgetall error")) }, expectedPeers: nil, @@ -270,7 +279,7 @@ func TestPeerManager_LoadAll(t *testing.T) { t.Fatalf("failed to marshal bitset: %v", err) } - mock.ExpectScan(0, pkgredis.MakePersistentCachePeersInScheduler(42), 10).SetVal([]string{"peer1"}, 0) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCachePeersInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:peer1", pkgredis.MakePersistentCachePeersInScheduler(42))}, 0) mock.ExpectHGetAll(pkgredis.MakePersistentCachePeerKeyInScheduler(42, "peer1")).SetVal(map[string]string{ "id": "peer1", "state": PeerStateSucceeded, diff --git a/scheduler/resource/persistentcache/task_manager.go b/scheduler/resource/persistentcache/task_manager.go index e5c2607f2..f7bcd3f58 100644 --- a/scheduler/resource/persistentcache/task_manager.go +++ b/scheduler/resource/persistentcache/task_manager.go @@ -20,7 +20,9 @@ package persistentcache import ( "context" + "fmt" "strconv" + "strings" "time" "github.com/redis/go-redis/v9" @@ -35,8 +37,8 @@ type TaskManager interface { // Load returns persistent cache task by a key. Load(context.Context, string) (*Task, bool) - // LoadCorrentReplicaCount returns current replica count of the persistent cache task. - LoadCorrentReplicaCount(context.Context, string) (uint64, error) + // LoadCurrentReplicaCount returns current replica count of the persistent cache task. + LoadCurrentReplicaCount(context.Context, string) (uint64, error) // LoadCurrentPersistentReplicaCount returns current persistent replica count of the persistent cache task. LoadCurrentPersistentReplicaCount(context.Context, string) (uint64, error) @@ -138,8 +140,8 @@ func (t *taskManager) Load(ctx context.Context, taskID string) (*Task, bool) { ), true } -// LoadCorrentReplicaCount returns current replica count of the persistent cache task. -func (t *taskManager) LoadCorrentReplicaCount(ctx context.Context, taskID string) (uint64, error) { +// LoadCurrentReplicaCount returns current replica count of the persistent cache task. +func (t *taskManager) LoadCurrentReplicaCount(ctx context.Context, taskID string) (uint64, error) { count, err := t.rdb.SCard(ctx, pkgredis.MakePersistentCachePeersOfPersistentCacheTaskInScheduler(t.config.Manager.SchedulerClusterID, taskID)).Result() return uint64(count), err } @@ -248,16 +250,30 @@ func (t *taskManager) LoadAll(ctx context.Context) ([]*Task, error) { err error ) - taskKeys, cursor, err = t.rdb.Scan(ctx, cursor, pkgredis.MakePersistentCacheTasksInScheduler(t.config.Manager.SchedulerClusterID), 10).Result() + prefix := fmt.Sprintf("%s:", pkgredis.MakePersistentCacheTasksInScheduler(t.config.Manager.SchedulerClusterID)) + taskKeys, cursor, err = t.rdb.Scan(ctx, cursor, fmt.Sprintf("%s*", prefix), 10).Result() if err != nil { logger.Error("scan tasks failed") return nil, err } + taskIDs := make(map[string]struct{}) for _, taskKey := range taskKeys { - task, loaded := t.Load(ctx, taskKey) + suffix := strings.TrimPrefix(taskKey, prefix) + if suffix == "" { + logger.Error("invalid task key") + continue + } + + taskID := strings.Split(suffix, ":")[0] + if _, ok := taskIDs[taskID]; ok { + continue + } + taskIDs[taskID] = struct{}{} + + task, loaded := t.Load(ctx, taskID) if !loaded { - logger.WithTaskID(taskKey).Error("load task failed") + logger.WithTaskID(taskID).Error("load task failed") continue } diff --git a/scheduler/resource/persistentcache/task_manager_mock.go b/scheduler/resource/persistentcache/task_manager_mock.go index 48b972f7f..906dd9196 100644 --- a/scheduler/resource/persistentcache/task_manager_mock.go +++ b/scheduler/resource/persistentcache/task_manager_mock.go @@ -84,19 +84,19 @@ func (mr *MockTaskManagerMockRecorder) LoadAll(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAll", reflect.TypeOf((*MockTaskManager)(nil).LoadAll), arg0) } -// LoadCorrentReplicaCount mocks base method. -func (m *MockTaskManager) LoadCorrentReplicaCount(arg0 context.Context, arg1 string) (uint64, error) { +// LoadCurrentReplicaCount mocks base method. +func (m *MockTaskManager) LoadCurrentReplicaCount(arg0 context.Context, arg1 string) (uint64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadCorrentReplicaCount", arg0, arg1) + ret := m.ctrl.Call(m, "LoadCurrentReplicaCount", arg0, arg1) ret0, _ := ret[0].(uint64) ret1, _ := ret[1].(error) return ret0, ret1 } -// LoadCorrentReplicaCount indicates an expected call of LoadCorrentReplicaCount. -func (mr *MockTaskManagerMockRecorder) LoadCorrentReplicaCount(arg0, arg1 any) *gomock.Call { +// LoadCurrentReplicaCount indicates an expected call of LoadCurrentReplicaCount. +func (mr *MockTaskManagerMockRecorder) LoadCurrentReplicaCount(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadCorrentReplicaCount", reflect.TypeOf((*MockTaskManager)(nil).LoadCorrentReplicaCount), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadCurrentReplicaCount", reflect.TypeOf((*MockTaskManager)(nil).LoadCurrentReplicaCount), arg0, arg1) } // LoadCurrentPersistentReplicaCount mocks base method. diff --git a/scheduler/resource/persistentcache/task_manager_test.go b/scheduler/resource/persistentcache/task_manager_test.go index 8feba1b21..70391ae38 100644 --- a/scheduler/resource/persistentcache/task_manager_test.go +++ b/scheduler/resource/persistentcache/task_manager_test.go @@ -19,6 +19,7 @@ package persistentcache import ( "context" "errors" + "fmt" "strconv" "testing" "time" @@ -183,7 +184,7 @@ func TestTaskManager_Load(t *testing.T) { } } -func TestTaskManager_LoadCorrentReplicaCount(t *testing.T) { +func TestTaskManager_LoadCurrentReplicaCount(t *testing.T) { type args struct { taskID string } @@ -229,7 +230,7 @@ func TestTaskManager_LoadCorrentReplicaCount(t *testing.T) { rdb: rdb, } - cnt, err := tm.LoadCorrentReplicaCount(context.Background(), tt.args.taskID) + cnt, err := tm.LoadCurrentReplicaCount(context.Background(), tt.args.taskID) assert.Equal(t, tt.expectedCount, cnt) assert.Equal(t, tt.expectedErr, err != nil, "error mismatch") assert.NoError(t, mock.ExpectationsWereMet()) @@ -351,15 +352,23 @@ func TestTaskManager_LoadAll(t *testing.T) { { name: "scan error", mockRedis: func(mock redismock.ClientMock) { - mock.ExpectScan(0, pkgredis.MakePersistentCacheTasksInScheduler(42), 10).SetErr(errors.New("scan error")) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCacheTasksInScheduler(42)), 10).SetErr(errors.New("scan error")) }, expectedErr: true, expectedLen: 0, }, + { + name: "invalid task key", + mockRedis: func(mock redismock.ClientMock) { + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCacheTasksInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:", pkgredis.MakePersistentCacheTasksInScheduler(42))}, 0) + }, + expectedErr: false, + expectedLen: 0, + }, { name: "load task error", mockRedis: func(mock redismock.ClientMock) { - mock.ExpectScan(0, pkgredis.MakePersistentCacheTasksInScheduler(42), 10).SetVal([]string{"task1"}, 0) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCacheTasksInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:task1", pkgredis.MakePersistentCacheTasksInScheduler(42))}, 0) mock.ExpectHGetAll(pkgredis.MakePersistentCacheTaskKeyInScheduler(42, "task1")).SetErr(errors.New("load error")) }, expectedErr: false, @@ -368,7 +377,7 @@ func TestTaskManager_LoadAll(t *testing.T) { { name: "successful load all", mockRedis: func(mock redismock.ClientMock) { - mock.ExpectScan(0, pkgredis.MakePersistentCacheTasksInScheduler(42), 10).SetVal([]string{"task1", "task2"}, 0) + mock.ExpectScan(0, fmt.Sprintf("%s:*", pkgredis.MakePersistentCacheTasksInScheduler(42)), 10).SetVal([]string{fmt.Sprintf("%s:task1", pkgredis.MakePersistentCacheTasksInScheduler(42)), fmt.Sprintf("%s:task2", pkgredis.MakePersistentCacheTasksInScheduler(42))}, 0) mockData := map[string]string{ "id": "task1", "tag": "tag_value", diff --git a/scheduler/service/service_v2.go b/scheduler/service/service_v2.go index e9398c549..d43ea5e8a 100644 --- a/scheduler/service/service_v2.go +++ b/scheduler/service/service_v2.go @@ -1856,7 +1856,7 @@ func (v *V2) handleRegisterPersistentCachePeerRequest(ctx context.Context, strea return status.Error(codes.Internal, err.Error()) } - currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCorrentReplicaCount(ctx, taskID) + currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCurrentReplicaCount(ctx, taskID) if err != nil { // Collect RegisterPersistentCachePeerFailureCount metrics. metrics.RegisterPersistentCachePeerFailureCount.WithLabelValues(peer.Host.Type.Name()).Inc() @@ -2051,7 +2051,7 @@ func (v *V2) handleReschedulePersistentCachePeerRequest(ctx context.Context, str return status.Error(codes.Internal, err.Error()) } - currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCorrentReplicaCount(ctx, taskID) + currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCurrentReplicaCount(ctx, taskID) if err != nil { // Collect RegisterPersistentCachePeerFailureCount metrics. metrics.RegisterPersistentCachePeerFailureCount.WithLabelValues(peer.Host.Type.Name()).Inc() @@ -2314,7 +2314,7 @@ func (v *V2) StatPersistentCachePeer(ctx context.Context, req *schedulerv2.StatP return nil, status.Error(codes.Internal, err.Error()) } - currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCorrentReplicaCount(ctx, peer.Task.ID) + currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCurrentReplicaCount(ctx, peer.Task.ID) if err != nil { log.Errorf("load current replica count failed %s", err.Error()) return nil, status.Error(codes.Internal, err.Error()) @@ -2569,7 +2569,7 @@ func (v *V2) UploadPersistentCacheTaskFinished(ctx context.Context, req *schedul return nil, status.Error(codes.Internal, err.Error()) } - currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCorrentReplicaCount(ctx, peer.Task.ID) + currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCurrentReplicaCount(ctx, peer.Task.ID) if err != nil { log.Errorf("load current replica count failed %s", err) return nil, status.Error(codes.Internal, err.Error()) @@ -2769,7 +2769,7 @@ func (v *V2) StatPersistentCacheTask(ctx context.Context, req *schedulerv2.StatP return nil, status.Error(codes.Internal, err.Error()) } - currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCorrentReplicaCount(ctx, task.ID) + currentReplicaCount, err := v.persistentCacheResource.TaskManager().LoadCurrentReplicaCount(ctx, task.ID) if err != nil { log.Errorf("load current replica count failed %s", err) return nil, status.Error(codes.Internal, err.Error())