From 4a752a47da433105d77f561b24b451909bcbfbe9 Mon Sep 17 00:00:00 2001 From: sunwp <244372610@qq.com> Date: Fri, 4 Mar 2022 18:07:01 +0800 Subject: [PATCH] fix concurrent piece map panic in cdn (#1121) Signed-off-by: sunwp <244372610@qq.com> --- cdn/rpcserver/rpcserver.go | 11 +-- cdn/rpcserver/rpcserver_test.go | 95 ++++++++++--------- cdn/supervisor/cdn/manager_test.go | 5 + cdn/supervisor/mocks/mock_cdn_service.go | 2 +- .../mocks/task/mock_task_manager.go | 4 +- cdn/supervisor/progress/manager.go | 20 ++-- cdn/supervisor/progress/progress.go | 33 +++---- cdn/supervisor/progress/progress_test.go | 11 ++- cdn/supervisor/service.go | 14 +-- cdn/supervisor/task/manager.go | 13 ++- cdn/supervisor/task/task.go | 18 ++-- 11 files changed, 110 insertions(+), 116 deletions(-) diff --git a/cdn/rpcserver/rpcserver.go b/cdn/rpcserver/rpcserver.go index 2a39b258f..1e5ff0af4 100644 --- a/cdn/rpcserver/rpcserver.go +++ b/cdn/rpcserver/rpcserver.go @@ -196,15 +196,15 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest span.RecordError(err) return nil, err } - pieces, err := css.service.GetSeedPieces(req.TaskId) + taskPieces, err := css.service.GetSeedPieces(req.TaskId) if err != nil { err = dferrors.Newf(base.Code_CDNError, "failed to get pieces of task(%s) from cdn: %v", seedTask.ID, err) span.RecordError(err) return nil, err } - pieceInfos := make([]*base.PieceInfo, 0, len(pieces)) + pieceInfos := make([]*base.PieceInfo, 0, len(taskPieces)) var count uint32 = 0 - for _, piece := range pieces { + for _, piece := range taskPieces { if piece.PieceNum >= req.StartNum && (count < req.Limit || req.Limit <= 0) { p := &base.PieceInfo{ PieceNum: int32(piece.PieceNum), @@ -220,11 +220,10 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest } } pieceMd5Sign := seedTask.PieceMd5Sign - if len(seedTask.Pieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" { - taskPieces := seedTask.Pieces + if len(taskPieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" { var pieceMd5s []string for i := 0; i < len(taskPieces); i++ { - pieceMd5s = append(pieceMd5s, taskPieces[uint32(i)].PieceMd5) + pieceMd5s = append(pieceMd5s, taskPieces[i].PieceMd5) } pieceMd5Sign = digestutils.Sha256(pieceMd5s...) } diff --git a/cdn/rpcserver/rpcserver_test.go b/cdn/rpcserver/rpcserver_test.go index 081f42109..edfa60953 100644 --- a/cdn/rpcserver/rpcserver_test.go +++ b/cdn/rpcserver/rpcserver_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "sort" + "sync" "testing" "github.com/golang/mock/gomock" @@ -225,58 +226,58 @@ func TestServer_GetPieceTasks(t *testing.T) { Range: "", Filter: "", Header: nil, - Pieces: map[uint32]*task.PieceInfo{ - 0: { - PieceNum: 0, - PieceMd5: "xxxx0", - PieceRange: &rangeutils.Range{ - StartIndex: 0, - EndIndex: 99, - }, - OriginRange: &rangeutils.Range{ - StartIndex: 0, - EndIndex: 99, - }, - PieceLen: 100, - PieceStyle: 0, - }, - 1: { - PieceNum: 1, - PieceMd5: "xxxx1", - PieceRange: &rangeutils.Range{ - StartIndex: 100, - EndIndex: 199, - }, - OriginRange: &rangeutils.Range{ - StartIndex: 100, - EndIndex: 199, - }, - PieceLen: 100, - PieceStyle: 0, - }, - 2: { - PieceNum: 2, - PieceMd5: "xxxx2", - PieceRange: &rangeutils.Range{ - StartIndex: 200, - EndIndex: 299, - }, - OriginRange: &rangeutils.Range{ - StartIndex: 200, - EndIndex: 249, - }, - PieceLen: 100, - PieceStyle: 0, - }, - }, + Pieces: new(sync.Map), } + testTask.Pieces.Store(0, &task.PieceInfo{ + PieceNum: 0, + PieceMd5: "xxxx0", + PieceRange: &rangeutils.Range{ + StartIndex: 0, + EndIndex: 99, + }, + OriginRange: &rangeutils.Range{ + StartIndex: 0, + EndIndex: 99, + }, + PieceLen: 100, + PieceStyle: 0, + }) + testTask.Pieces.Store(1, &task.PieceInfo{ + PieceNum: 1, + PieceMd5: "xxxx1", + PieceRange: &rangeutils.Range{ + StartIndex: 100, + EndIndex: 199, + }, + OriginRange: &rangeutils.Range{ + StartIndex: 100, + EndIndex: 199, + }, + PieceLen: 100, + PieceStyle: 0, + }) + testTask.Pieces.Store(2, &task.PieceInfo{ + PieceNum: 2, + PieceMd5: "xxxx2", + PieceRange: &rangeutils.Range{ + StartIndex: 200, + EndIndex: 299, + }, + OriginRange: &rangeutils.Range{ + StartIndex: 200, + EndIndex: 249, + }, + PieceLen: 100, + PieceStyle: 0, + }) cdnServiceMock.EXPECT().GetSeedTask(args.req.TaskId).DoAndReturn(func(taskID string) (seedTask *task.SeedTask, err error) { return testTask, nil }) cdnServiceMock.EXPECT().GetSeedPieces(args.req.TaskId).DoAndReturn(func(taskID string) (pieces []*task.PieceInfo, err error) { - for u := range testTask.Pieces { - pieces = append(pieces, testTask.Pieces[u]) - } + testTask.Pieces.Range(func(key, value interface{}) bool { + pieces = append(pieces, value.(*task.PieceInfo)) + return true + }) sort.Slice(pieces, func(i, j int) bool { return pieces[i].PieceNum < pieces[j].PieceNum }) diff --git a/cdn/supervisor/cdn/manager_test.go b/cdn/supervisor/cdn/manager_test.go index f850d1092..6384bb8db 100644 --- a/cdn/supervisor/cdn/manager_test.go +++ b/cdn/supervisor/cdn/manager_test.go @@ -22,6 +22,7 @@ import ( "io" "os" "strings" + "sync" "testing" "github.com/golang/mock/gomock" @@ -145,6 +146,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() { Digest: "md5:f1e2488bba4d1267948d9e2f7008571c", SourceRealDigest: "", PieceMd5Sign: "", + Pieces: new(sync.Map), }, targetTask: &task.SeedTask{ ID: md5TaskID, @@ -159,6 +161,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() { Digest: "md5:f1e2488bba4d1267948d9e2f7008571c", SourceRealDigest: "md5:f1e2488bba4d1267948d9e2f7008571c", PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f", + Pieces: new(sync.Map), }, }, { @@ -176,6 +179,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() { Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5", SourceRealDigest: "", PieceMd5Sign: "", + Pieces: new(sync.Map), }, targetTask: &task.SeedTask{ ID: sha256TaskID, @@ -190,6 +194,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() { Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5", SourceRealDigest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5", PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f", + Pieces: new(sync.Map), }, }, } diff --git a/cdn/supervisor/mocks/mock_cdn_service.go b/cdn/supervisor/mocks/mock_cdn_service.go index 411008272..bac361de4 100644 --- a/cdn/supervisor/mocks/mock_cdn_service.go +++ b/cdn/supervisor/mocks/mock_cdn_service.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. // Source: d7y.io/dragonfly/v2/cdn/supervisor (interfaces: CDNService) -// Package progress is a generated GoMock package. +// Package mocks is a generated GoMock package. package mocks import ( diff --git a/cdn/supervisor/mocks/task/mock_task_manager.go b/cdn/supervisor/mocks/task/mock_task_manager.go index 53ab7821d..c67c9e94d 100644 --- a/cdn/supervisor/mocks/task/mock_task_manager.go +++ b/cdn/supervisor/mocks/task/mock_task_manager.go @@ -92,10 +92,10 @@ func (mr *MockManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { } // GetProgress mocks base method. -func (m *MockManager) GetProgress(arg0 string) (map[uint32]*task.PieceInfo, error) { +func (m *MockManager) GetProgress(arg0 string) ([]*task.PieceInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetProgress", arg0) - ret0, _ := ret[0].(map[uint32]*task.PieceInfo) + ret0, _ := ret[0].([]*task.PieceInfo) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/cdn/supervisor/progress/manager.go b/cdn/supervisor/progress/manager.go index 1ba07f6f5..51b9a3605 100644 --- a/cdn/supervisor/progress/manager.go +++ b/cdn/supervisor/progress/manager.go @@ -75,6 +75,10 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas if err != nil { return nil, err } + pieces, err := pm.taskManager.GetProgress(taskID) + if err != nil { + return nil, err + } if seedTask.IsDone() { pieceChan := make(chan *task.PieceInfo) go func(pieceChan chan *task.PieceInfo) { @@ -82,22 +86,18 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas logger.Debugf("subscriber %s starts watching task %s seed progress", clientAddr, taskID) close(pieceChan) }() - pieceNums := make([]uint32, 0, len(seedTask.Pieces)) - for pieceNum := range seedTask.Pieces { - pieceNums = append(pieceNums, pieceNum) - } - sort.Slice(pieceNums, func(i, j int) bool { - return pieceNums[i] < pieceNums[j] + sort.Slice(pieces, func(i, j int) bool { + return pieces[i].PieceNum < pieces[j].PieceNum }) - for _, pieceNum := range pieceNums { - logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, pieceNum, taskID) - pieceChan <- seedTask.Pieces[pieceNum] + for _, piece := range pieces { + logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, piece.PieceNum, taskID) + pieceChan <- piece } }(pieceChan) return pieceChan, nil } var progressPublisher, _ = pm.seedTaskSubjects.LoadOrStore(taskID, newProgressPublisher(taskID)) - observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, seedTask.Pieces) + observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, pieces) progressPublisher.(*publisher).AddSubscriber(observer) return observer.Receiver(), nil } diff --git a/cdn/supervisor/progress/progress.go b/cdn/supervisor/progress/progress.go index df8d184df..9fb35ec4c 100644 --- a/cdn/supervisor/progress/progress.go +++ b/cdn/supervisor/progress/progress.go @@ -34,25 +34,21 @@ type subscriber struct { taskID string done chan struct{} once sync.Once - pieces map[uint32]*task.PieceInfo + pieces []*task.PieceInfo pieceChan chan *task.PieceInfo cond *sync.Cond closed *atomic.Bool } -func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces map[uint32]*task.PieceInfo) *subscriber { - pieces := make(map[uint32]*task.PieceInfo, len(taskPieces)) - for u, info := range taskPieces { - pieces[u] = info - } +func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces []*task.PieceInfo) *subscriber { sub := &subscriber{ ctx: ctx, scheduler: clientAddr, taskID: taskID, - pieces: pieces, done: make(chan struct{}), - pieceChan: make(chan *task.PieceInfo, 100), cond: sync.NewCond(&sync.Mutex{}), + pieces: taskPieces, + pieceChan: make(chan *task.PieceInfo, 100), closed: atomic.NewBool(false), } go sub.readLoop() @@ -89,24 +85,20 @@ func (sub *subscriber) readLoop() { } func (sub *subscriber) sendPieces() { - pieceNums := make([]uint32, 0, len(sub.pieces)) - for pieceNum := range sub.pieces { - pieceNums = append(pieceNums, pieceNum) - } - sort.Slice(pieceNums, func(i, j int) bool { - return pieceNums[i] < pieceNums[j] + sort.Slice(sub.pieces, func(i, j int) bool { + return sub.pieces[i].PieceNum < sub.pieces[j].PieceNum }) - for _, pieceNum := range pieceNums { - logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, pieceNum, sub.taskID) - sub.pieceChan <- sub.pieces[pieceNum] - delete(sub.pieces, pieceNum) + for _, piece := range sub.pieces { + logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, piece.PieceNum, sub.taskID) + sub.pieceChan <- piece } + sub.pieces = []*task.PieceInfo{} } func (sub *subscriber) Notify(seedPiece *task.PieceInfo) { logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", sub.scheduler, seedPiece.PieceNum, sub.taskID) sub.cond.L.Lock() - sub.pieces[seedPiece.PieceNum] = seedPiece + sub.pieces = append(sub.pieces, seedPiece) sub.cond.L.Unlock() sub.cond.Signal() } @@ -154,7 +146,8 @@ func (pub *publisher) RemoveSubscriber(sub *subscriber) { func (pub *publisher) NotifySubscribers(seedPiece *task.PieceInfo) { for e := pub.subscribers.Front(); e != nil; e = e.Next() { - e.Value.(*subscriber).Notify(seedPiece) + sub := e.Value.(*subscriber) + sub.Notify(seedPiece) } } diff --git a/cdn/supervisor/progress/progress_test.go b/cdn/supervisor/progress/progress_test.go index 2934a2171..24a295910 100644 --- a/cdn/supervisor/progress/progress_test.go +++ b/cdn/supervisor/progress/progress_test.go @@ -73,8 +73,8 @@ func Test_publisher_NotifySubscribers(t *testing.T) { PieceLen: 0, PieceStyle: 0, } - sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", map[uint32]*task.PieceInfo{ - 100: additionPieceInfo1, + sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", []*task.PieceInfo{ + additionPieceInfo1, }) additionPieceInfo2 := &task.PieceInfo{ PieceNum: 200, @@ -85,9 +85,9 @@ func Test_publisher_NotifySubscribers(t *testing.T) { PieceStyle: 0, } publisher.AddSubscriber(sub3) - sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", map[uint32]*task.PieceInfo{ - 100: additionPieceInfo1, - 200: additionPieceInfo2, + sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", []*task.PieceInfo{ + additionPieceInfo1, + additionPieceInfo2, }) publisher.AddSubscriber(sub4) chan1 := sub1.Receiver() @@ -144,6 +144,7 @@ func Test_publisher_NotifySubscribers(t *testing.T) { assert.Equal(4, pieceCount) }(chan4) + // notify all subscribers for i := range notifyPieces { publisher.NotifySubscribers(notifyPieces[i]) } diff --git a/cdn/supervisor/service.go b/cdn/supervisor/service.go index 3f967e925..be53bd3be 100644 --- a/cdn/supervisor/service.go +++ b/cdn/supervisor/service.go @@ -21,7 +21,6 @@ package supervisor import ( "context" "encoding/json" - "sort" "github.com/pkg/errors" @@ -124,18 +123,7 @@ func (service *cdnService) triggerCdnSyncAction(ctx context.Context, taskID stri } func (service *cdnService) GetSeedPieces(taskID string) ([]*task.PieceInfo, error) { - pieceMap, err := service.taskManager.GetProgress(taskID) - if err != nil { - return nil, err - } - pieces := make([]*task.PieceInfo, 0, len(pieceMap)) - for i := range pieceMap { - pieces = append(pieces, pieceMap[i]) - } - sort.Slice(pieces, func(i, j int) bool { - return pieces[i].PieceNum < pieces[j].PieceNum - }) - return pieces, nil + return service.taskManager.GetProgress(taskID) } func (service *cdnService) GetSeedTask(taskID string) (*task.SeedTask, error) { diff --git a/cdn/supervisor/task/manager.go b/cdn/supervisor/task/manager.go index 01cec2e93..506a12a69 100644 --- a/cdn/supervisor/task/manager.go +++ b/cdn/supervisor/task/manager.go @@ -55,7 +55,7 @@ type Manager interface { UpdateProgress(taskID string, piece *PieceInfo) (err error) // GetProgress returns the downloaded pieces belonging to the task - GetProgress(taskID string) (map[uint32]*PieceInfo, error) + GetProgress(taskID string) ([]*PieceInfo, error) // Exist check task existence with specified taskID. // returns the task info with specified taskID, or nil if no value is present. @@ -198,13 +198,13 @@ func (tm *manager) UpdateProgress(taskID string, info *PieceInfo) error { if !ok { return errTaskNotFound } - seedTask.Pieces[info.PieceNum] = info + seedTask.Pieces.Store(info.PieceNum, info) // only update access when update task success tm.accessTimeMap.Store(taskID, time.Now()) return nil } -func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) { +func (tm *manager) GetProgress(taskID string) ([]*PieceInfo, error) { synclock.Lock(taskID, false) defer synclock.UnLock(taskID, false) seedTask, ok := tm.getTask(taskID) @@ -212,7 +212,12 @@ func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) { return nil, errTaskNotFound } tm.accessTimeMap.Store(taskID, time.Now()) - return seedTask.Pieces, nil + var pieces []*PieceInfo + seedTask.Pieces.Range(func(key, value interface{}) bool { + pieces = append(pieces, value.(*PieceInfo)) + return true + }) + return pieces, nil } func (tm *manager) Exist(taskID string) (*SeedTask, bool) { diff --git a/cdn/supervisor/task/task.go b/cdn/supervisor/task/task.go index c60d4a9f1..fd06be592 100644 --- a/cdn/supervisor/task/task.go +++ b/cdn/supervisor/task/task.go @@ -18,6 +18,7 @@ package task import ( "strings" + "sync" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -83,7 +84,7 @@ type SeedTask struct { Header map[string]string `json:"header,omitempty"` // Pieces pieces of task - Pieces map[uint32]*PieceInfo `json:"-"` + Pieces *sync.Map `json:"-"` // map[uint32]*PieceInfo logger *logger.SugaredLoggerOnWith } @@ -124,7 +125,7 @@ func NewSeedTask(taskID string, rawURL string, urlMeta *base.UrlMeta) *SeedTask Range: urlMeta.Range, Filter: urlMeta.Filter, Header: urlMeta.Header, - Pieces: make(map[uint32]*PieceInfo), + Pieces: new(sync.Map), logger: logger.WithTaskID(taskID), } } @@ -137,11 +138,12 @@ func (task *SeedTask) Clone() *SeedTask { cloneTask.Header[key] = value } } - if len(task.Pieces) > 0 { - for pieceNum, piece := range task.Pieces { - cloneTask.Pieces[pieceNum] = piece - } - } + cloneTask.Pieces = new(sync.Map) + task.Pieces.Range(func(key, value interface{}) bool { + cloneTask.Pieces.Store(key, value) + return true + }) + return cloneTask } @@ -192,7 +194,7 @@ func (task *SeedTask) Log() *logger.SugaredLoggerOnWith { func (task *SeedTask) StartTrigger() { task.CdnStatus = StatusRunning - task.Pieces = make(map[uint32]*PieceInfo) + task.Pieces = new(sync.Map) } const (