fix concurrent piece map panic in cdn (#1121)

Signed-off-by: sunwp <244372610@qq.com>
This commit is contained in:
sunwp 2022-03-04 18:07:01 +08:00 committed by Gaius
parent 84eaf402c3
commit 4a752a47da
No known key found for this signature in database
GPG Key ID: 8B4E5D1290FA2FFB
11 changed files with 110 additions and 116 deletions

View File

@ -196,15 +196,15 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest
span.RecordError(err)
return nil, err
}
pieces, err := css.service.GetSeedPieces(req.TaskId)
taskPieces, err := css.service.GetSeedPieces(req.TaskId)
if err != nil {
err = dferrors.Newf(base.Code_CDNError, "failed to get pieces of task(%s) from cdn: %v", seedTask.ID, err)
span.RecordError(err)
return nil, err
}
pieceInfos := make([]*base.PieceInfo, 0, len(pieces))
pieceInfos := make([]*base.PieceInfo, 0, len(taskPieces))
var count uint32 = 0
for _, piece := range pieces {
for _, piece := range taskPieces {
if piece.PieceNum >= req.StartNum && (count < req.Limit || req.Limit <= 0) {
p := &base.PieceInfo{
PieceNum: int32(piece.PieceNum),
@ -220,11 +220,10 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest
}
}
pieceMd5Sign := seedTask.PieceMd5Sign
if len(seedTask.Pieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" {
taskPieces := seedTask.Pieces
if len(taskPieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" {
var pieceMd5s []string
for i := 0; i < len(taskPieces); i++ {
pieceMd5s = append(pieceMd5s, taskPieces[uint32(i)].PieceMd5)
pieceMd5s = append(pieceMd5s, taskPieces[i].PieceMd5)
}
pieceMd5Sign = digestutils.Sha256(pieceMd5s...)
}

View File

@ -20,6 +20,7 @@ import (
"context"
"fmt"
"sort"
"sync"
"testing"
"github.com/golang/mock/gomock"
@ -225,58 +226,58 @@ func TestServer_GetPieceTasks(t *testing.T) {
Range: "",
Filter: "",
Header: nil,
Pieces: map[uint32]*task.PieceInfo{
0: {
PieceNum: 0,
PieceMd5: "xxxx0",
PieceRange: &rangeutils.Range{
StartIndex: 0,
EndIndex: 99,
},
OriginRange: &rangeutils.Range{
StartIndex: 0,
EndIndex: 99,
},
PieceLen: 100,
PieceStyle: 0,
},
1: {
PieceNum: 1,
PieceMd5: "xxxx1",
PieceRange: &rangeutils.Range{
StartIndex: 100,
EndIndex: 199,
},
OriginRange: &rangeutils.Range{
StartIndex: 100,
EndIndex: 199,
},
PieceLen: 100,
PieceStyle: 0,
},
2: {
PieceNum: 2,
PieceMd5: "xxxx2",
PieceRange: &rangeutils.Range{
StartIndex: 200,
EndIndex: 299,
},
OriginRange: &rangeutils.Range{
StartIndex: 200,
EndIndex: 249,
},
PieceLen: 100,
PieceStyle: 0,
},
},
Pieces: new(sync.Map),
}
testTask.Pieces.Store(0, &task.PieceInfo{
PieceNum: 0,
PieceMd5: "xxxx0",
PieceRange: &rangeutils.Range{
StartIndex: 0,
EndIndex: 99,
},
OriginRange: &rangeutils.Range{
StartIndex: 0,
EndIndex: 99,
},
PieceLen: 100,
PieceStyle: 0,
})
testTask.Pieces.Store(1, &task.PieceInfo{
PieceNum: 1,
PieceMd5: "xxxx1",
PieceRange: &rangeutils.Range{
StartIndex: 100,
EndIndex: 199,
},
OriginRange: &rangeutils.Range{
StartIndex: 100,
EndIndex: 199,
},
PieceLen: 100,
PieceStyle: 0,
})
testTask.Pieces.Store(2, &task.PieceInfo{
PieceNum: 2,
PieceMd5: "xxxx2",
PieceRange: &rangeutils.Range{
StartIndex: 200,
EndIndex: 299,
},
OriginRange: &rangeutils.Range{
StartIndex: 200,
EndIndex: 249,
},
PieceLen: 100,
PieceStyle: 0,
})
cdnServiceMock.EXPECT().GetSeedTask(args.req.TaskId).DoAndReturn(func(taskID string) (seedTask *task.SeedTask, err error) {
return testTask, nil
})
cdnServiceMock.EXPECT().GetSeedPieces(args.req.TaskId).DoAndReturn(func(taskID string) (pieces []*task.PieceInfo, err error) {
for u := range testTask.Pieces {
pieces = append(pieces, testTask.Pieces[u])
}
testTask.Pieces.Range(func(key, value interface{}) bool {
pieces = append(pieces, value.(*task.PieceInfo))
return true
})
sort.Slice(pieces, func(i, j int) bool {
return pieces[i].PieceNum < pieces[j].PieceNum
})

View File

@ -22,6 +22,7 @@ import (
"io"
"os"
"strings"
"sync"
"testing"
"github.com/golang/mock/gomock"
@ -145,6 +146,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
Digest: "md5:f1e2488bba4d1267948d9e2f7008571c",
SourceRealDigest: "",
PieceMd5Sign: "",
Pieces: new(sync.Map),
},
targetTask: &task.SeedTask{
ID: md5TaskID,
@ -159,6 +161,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
Digest: "md5:f1e2488bba4d1267948d9e2f7008571c",
SourceRealDigest: "md5:f1e2488bba4d1267948d9e2f7008571c",
PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f",
Pieces: new(sync.Map),
},
},
{
@ -176,6 +179,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
SourceRealDigest: "",
PieceMd5Sign: "",
Pieces: new(sync.Map),
},
targetTask: &task.SeedTask{
ID: sha256TaskID,
@ -190,6 +194,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
SourceRealDigest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f",
Pieces: new(sync.Map),
},
},
}

View File

@ -1,7 +1,7 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: d7y.io/dragonfly/v2/cdn/supervisor (interfaces: CDNService)
// Package progress is a generated GoMock package.
// Package mocks is a generated GoMock package.
package mocks
import (

View File

@ -92,10 +92,10 @@ func (mr *MockManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
}
// GetProgress mocks base method.
func (m *MockManager) GetProgress(arg0 string) (map[uint32]*task.PieceInfo, error) {
func (m *MockManager) GetProgress(arg0 string) ([]*task.PieceInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProgress", arg0)
ret0, _ := ret[0].(map[uint32]*task.PieceInfo)
ret0, _ := ret[0].([]*task.PieceInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View File

@ -75,6 +75,10 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas
if err != nil {
return nil, err
}
pieces, err := pm.taskManager.GetProgress(taskID)
if err != nil {
return nil, err
}
if seedTask.IsDone() {
pieceChan := make(chan *task.PieceInfo)
go func(pieceChan chan *task.PieceInfo) {
@ -82,22 +86,18 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas
logger.Debugf("subscriber %s starts watching task %s seed progress", clientAddr, taskID)
close(pieceChan)
}()
pieceNums := make([]uint32, 0, len(seedTask.Pieces))
for pieceNum := range seedTask.Pieces {
pieceNums = append(pieceNums, pieceNum)
}
sort.Slice(pieceNums, func(i, j int) bool {
return pieceNums[i] < pieceNums[j]
sort.Slice(pieces, func(i, j int) bool {
return pieces[i].PieceNum < pieces[j].PieceNum
})
for _, pieceNum := range pieceNums {
logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, pieceNum, taskID)
pieceChan <- seedTask.Pieces[pieceNum]
for _, piece := range pieces {
logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, piece.PieceNum, taskID)
pieceChan <- piece
}
}(pieceChan)
return pieceChan, nil
}
var progressPublisher, _ = pm.seedTaskSubjects.LoadOrStore(taskID, newProgressPublisher(taskID))
observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, seedTask.Pieces)
observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, pieces)
progressPublisher.(*publisher).AddSubscriber(observer)
return observer.Receiver(), nil
}

View File

@ -34,25 +34,21 @@ type subscriber struct {
taskID string
done chan struct{}
once sync.Once
pieces map[uint32]*task.PieceInfo
pieces []*task.PieceInfo
pieceChan chan *task.PieceInfo
cond *sync.Cond
closed *atomic.Bool
}
func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces map[uint32]*task.PieceInfo) *subscriber {
pieces := make(map[uint32]*task.PieceInfo, len(taskPieces))
for u, info := range taskPieces {
pieces[u] = info
}
func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces []*task.PieceInfo) *subscriber {
sub := &subscriber{
ctx: ctx,
scheduler: clientAddr,
taskID: taskID,
pieces: pieces,
done: make(chan struct{}),
pieceChan: make(chan *task.PieceInfo, 100),
cond: sync.NewCond(&sync.Mutex{}),
pieces: taskPieces,
pieceChan: make(chan *task.PieceInfo, 100),
closed: atomic.NewBool(false),
}
go sub.readLoop()
@ -89,24 +85,20 @@ func (sub *subscriber) readLoop() {
}
func (sub *subscriber) sendPieces() {
pieceNums := make([]uint32, 0, len(sub.pieces))
for pieceNum := range sub.pieces {
pieceNums = append(pieceNums, pieceNum)
}
sort.Slice(pieceNums, func(i, j int) bool {
return pieceNums[i] < pieceNums[j]
sort.Slice(sub.pieces, func(i, j int) bool {
return sub.pieces[i].PieceNum < sub.pieces[j].PieceNum
})
for _, pieceNum := range pieceNums {
logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, pieceNum, sub.taskID)
sub.pieceChan <- sub.pieces[pieceNum]
delete(sub.pieces, pieceNum)
for _, piece := range sub.pieces {
logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, piece.PieceNum, sub.taskID)
sub.pieceChan <- piece
}
sub.pieces = []*task.PieceInfo{}
}
func (sub *subscriber) Notify(seedPiece *task.PieceInfo) {
logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", sub.scheduler, seedPiece.PieceNum, sub.taskID)
sub.cond.L.Lock()
sub.pieces[seedPiece.PieceNum] = seedPiece
sub.pieces = append(sub.pieces, seedPiece)
sub.cond.L.Unlock()
sub.cond.Signal()
}
@ -154,7 +146,8 @@ func (pub *publisher) RemoveSubscriber(sub *subscriber) {
func (pub *publisher) NotifySubscribers(seedPiece *task.PieceInfo) {
for e := pub.subscribers.Front(); e != nil; e = e.Next() {
e.Value.(*subscriber).Notify(seedPiece)
sub := e.Value.(*subscriber)
sub.Notify(seedPiece)
}
}

View File

@ -73,8 +73,8 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
PieceLen: 0,
PieceStyle: 0,
}
sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", map[uint32]*task.PieceInfo{
100: additionPieceInfo1,
sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", []*task.PieceInfo{
additionPieceInfo1,
})
additionPieceInfo2 := &task.PieceInfo{
PieceNum: 200,
@ -85,9 +85,9 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
PieceStyle: 0,
}
publisher.AddSubscriber(sub3)
sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", map[uint32]*task.PieceInfo{
100: additionPieceInfo1,
200: additionPieceInfo2,
sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", []*task.PieceInfo{
additionPieceInfo1,
additionPieceInfo2,
})
publisher.AddSubscriber(sub4)
chan1 := sub1.Receiver()
@ -144,6 +144,7 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
assert.Equal(4, pieceCount)
}(chan4)
// notify all subscribers
for i := range notifyPieces {
publisher.NotifySubscribers(notifyPieces[i])
}

View File

@ -21,7 +21,6 @@ package supervisor
import (
"context"
"encoding/json"
"sort"
"github.com/pkg/errors"
@ -124,18 +123,7 @@ func (service *cdnService) triggerCdnSyncAction(ctx context.Context, taskID stri
}
func (service *cdnService) GetSeedPieces(taskID string) ([]*task.PieceInfo, error) {
pieceMap, err := service.taskManager.GetProgress(taskID)
if err != nil {
return nil, err
}
pieces := make([]*task.PieceInfo, 0, len(pieceMap))
for i := range pieceMap {
pieces = append(pieces, pieceMap[i])
}
sort.Slice(pieces, func(i, j int) bool {
return pieces[i].PieceNum < pieces[j].PieceNum
})
return pieces, nil
return service.taskManager.GetProgress(taskID)
}
func (service *cdnService) GetSeedTask(taskID string) (*task.SeedTask, error) {

View File

@ -55,7 +55,7 @@ type Manager interface {
UpdateProgress(taskID string, piece *PieceInfo) (err error)
// GetProgress returns the downloaded pieces belonging to the task
GetProgress(taskID string) (map[uint32]*PieceInfo, error)
GetProgress(taskID string) ([]*PieceInfo, error)
// Exist check task existence with specified taskID.
// returns the task info with specified taskID, or nil if no value is present.
@ -198,13 +198,13 @@ func (tm *manager) UpdateProgress(taskID string, info *PieceInfo) error {
if !ok {
return errTaskNotFound
}
seedTask.Pieces[info.PieceNum] = info
seedTask.Pieces.Store(info.PieceNum, info)
// only update access when update task success
tm.accessTimeMap.Store(taskID, time.Now())
return nil
}
func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) {
func (tm *manager) GetProgress(taskID string) ([]*PieceInfo, error) {
synclock.Lock(taskID, false)
defer synclock.UnLock(taskID, false)
seedTask, ok := tm.getTask(taskID)
@ -212,7 +212,12 @@ func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) {
return nil, errTaskNotFound
}
tm.accessTimeMap.Store(taskID, time.Now())
return seedTask.Pieces, nil
var pieces []*PieceInfo
seedTask.Pieces.Range(func(key, value interface{}) bool {
pieces = append(pieces, value.(*PieceInfo))
return true
})
return pieces, nil
}
func (tm *manager) Exist(taskID string) (*SeedTask, bool) {

View File

@ -18,6 +18,7 @@ package task
import (
"strings"
"sync"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@ -83,7 +84,7 @@ type SeedTask struct {
Header map[string]string `json:"header,omitempty"`
// Pieces pieces of task
Pieces map[uint32]*PieceInfo `json:"-"`
Pieces *sync.Map `json:"-"` // map[uint32]*PieceInfo
logger *logger.SugaredLoggerOnWith
}
@ -124,7 +125,7 @@ func NewSeedTask(taskID string, rawURL string, urlMeta *base.UrlMeta) *SeedTask
Range: urlMeta.Range,
Filter: urlMeta.Filter,
Header: urlMeta.Header,
Pieces: make(map[uint32]*PieceInfo),
Pieces: new(sync.Map),
logger: logger.WithTaskID(taskID),
}
}
@ -137,11 +138,12 @@ func (task *SeedTask) Clone() *SeedTask {
cloneTask.Header[key] = value
}
}
if len(task.Pieces) > 0 {
for pieceNum, piece := range task.Pieces {
cloneTask.Pieces[pieceNum] = piece
}
}
cloneTask.Pieces = new(sync.Map)
task.Pieces.Range(func(key, value interface{}) bool {
cloneTask.Pieces.Store(key, value)
return true
})
return cloneTask
}
@ -192,7 +194,7 @@ func (task *SeedTask) Log() *logger.SugaredLoggerOnWith {
func (task *SeedTask) StartTrigger() {
task.CdnStatus = StatusRunning
task.Pieces = make(map[uint32]*PieceInfo)
task.Pieces = new(sync.Map)
}
const (