fix concurrent piece map panic in cdn (#1121)

Signed-off-by: sunwp <244372610@qq.com>
sunwp 2022-03-04 18:07:01 +08:00 committed by Gaius
parent 84eaf402c3
commit 4a752a47da
11 changed files with 110 additions and 116 deletions
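
For context: Go's built-in map is not safe for concurrent use, so the runtime aborts with "fatal error: concurrent map writes" when one goroutine stores a piece (UpdateProgress) while another reads or iterates the same map (GetPieceTasks, progress publishing). This change swaps SeedTask.Pieces from map[uint32]*PieceInfo to *sync.Map. The standalone sketch below is not part of the diff (the pieceInfo type is illustrative); it only shows the pattern the commit adopts: concurrent Store calls, then a Range snapshot sorted by piece number.

package main

import (
    "fmt"
    "sort"
    "sync"
)

// pieceInfo is an illustrative stand-in for task.PieceInfo.
type pieceInfo struct {
    PieceNum uint32
    PieceMd5 string
}

func main() {
    // With a plain map[uint32]*pieceInfo, the concurrent writes below would
    // crash with "fatal error: concurrent map writes"; sync.Map serializes access.
    var pieces sync.Map

    var wg sync.WaitGroup
    for i := uint32(0); i < 100; i++ {
        wg.Add(1)
        go func(n uint32) {
            defer wg.Done()
            pieces.Store(n, &pieceInfo{PieceNum: n, PieceMd5: fmt.Sprintf("md5-%d", n)})
        }(i)
    }
    wg.Wait()

    // Range order is unspecified, so snapshot into a slice and sort by piece
    // number, as GetProgress callers do after this change.
    var snapshot []*pieceInfo
    pieces.Range(func(_, value interface{}) bool {
        snapshot = append(snapshot, value.(*pieceInfo))
        return true
    })
    sort.Slice(snapshot, func(i, j int) bool {
        return snapshot[i].PieceNum < snapshot[j].PieceNum
    })
    fmt.Println("pieces stored:", len(snapshot))
}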


@@ -196,15 +196,15 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest
         span.RecordError(err)
         return nil, err
     }
-    pieces, err := css.service.GetSeedPieces(req.TaskId)
+    taskPieces, err := css.service.GetSeedPieces(req.TaskId)
     if err != nil {
         err = dferrors.Newf(base.Code_CDNError, "failed to get pieces of task(%s) from cdn: %v", seedTask.ID, err)
         span.RecordError(err)
         return nil, err
     }
-    pieceInfos := make([]*base.PieceInfo, 0, len(pieces))
+    pieceInfos := make([]*base.PieceInfo, 0, len(taskPieces))
     var count uint32 = 0
-    for _, piece := range pieces {
+    for _, piece := range taskPieces {
         if piece.PieceNum >= req.StartNum && (count < req.Limit || req.Limit <= 0) {
             p := &base.PieceInfo{
                 PieceNum: int32(piece.PieceNum),
@@ -220,11 +220,10 @@ func (css *Server) GetPieceTasks(ctx context.Context, req *base.PieceTaskRequest
         }
     }
     pieceMd5Sign := seedTask.PieceMd5Sign
-    if len(seedTask.Pieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" {
-        taskPieces := seedTask.Pieces
+    if len(taskPieces) == int(seedTask.TotalPieceCount) && pieceMd5Sign == "" {
         var pieceMd5s []string
         for i := 0; i < len(taskPieces); i++ {
-            pieceMd5s = append(pieceMd5s, taskPieces[uint32(i)].PieceMd5)
+            pieceMd5s = append(pieceMd5s, taskPieces[i].PieceMd5)
         }
         pieceMd5Sign = digestutils.Sha256(pieceMd5s...)
     }


@@ -20,6 +20,7 @@ import (
     "context"
     "fmt"
     "sort"
+    "sync"
     "testing"

     "github.com/golang/mock/gomock"
@@ -225,58 +226,58 @@ func TestServer_GetPieceTasks(t *testing.T) {
            Range: "",
            Filter: "",
            Header: nil,
-           Pieces: map[uint32]*task.PieceInfo{
-               0: {
-                   PieceNum: 0,
-                   PieceMd5: "xxxx0",
-                   PieceRange: &rangeutils.Range{
-                       StartIndex: 0,
-                       EndIndex: 99,
-                   },
-                   OriginRange: &rangeutils.Range{
-                       StartIndex: 0,
-                       EndIndex: 99,
-                   },
-                   PieceLen: 100,
-                   PieceStyle: 0,
-               },
-               1: {
-                   PieceNum: 1,
-                   PieceMd5: "xxxx1",
-                   PieceRange: &rangeutils.Range{
-                       StartIndex: 100,
-                       EndIndex: 199,
-                   },
-                   OriginRange: &rangeutils.Range{
-                       StartIndex: 100,
-                       EndIndex: 199,
-                   },
-                   PieceLen: 100,
-                   PieceStyle: 0,
-               },
-               2: {
-                   PieceNum: 2,
-                   PieceMd5: "xxxx2",
-                   PieceRange: &rangeutils.Range{
-                       StartIndex: 200,
-                       EndIndex: 299,
-                   },
-                   OriginRange: &rangeutils.Range{
-                       StartIndex: 200,
-                       EndIndex: 249,
-                   },
-                   PieceLen: 100,
-                   PieceStyle: 0,
-               },
-           },
+           Pieces: new(sync.Map),
        }
+       testTask.Pieces.Store(0, &task.PieceInfo{
+           PieceNum: 0,
+           PieceMd5: "xxxx0",
+           PieceRange: &rangeutils.Range{
+               StartIndex: 0,
+               EndIndex: 99,
+           },
+           OriginRange: &rangeutils.Range{
+               StartIndex: 0,
+               EndIndex: 99,
+           },
+           PieceLen: 100,
+           PieceStyle: 0,
+       })
+       testTask.Pieces.Store(1, &task.PieceInfo{
+           PieceNum: 1,
+           PieceMd5: "xxxx1",
+           PieceRange: &rangeutils.Range{
+               StartIndex: 100,
+               EndIndex: 199,
+           },
+           OriginRange: &rangeutils.Range{
+               StartIndex: 100,
+               EndIndex: 199,
+           },
+           PieceLen: 100,
+           PieceStyle: 0,
+       })
+       testTask.Pieces.Store(2, &task.PieceInfo{
+           PieceNum: 2,
+           PieceMd5: "xxxx2",
+           PieceRange: &rangeutils.Range{
+               StartIndex: 200,
+               EndIndex: 299,
+           },
+           OriginRange: &rangeutils.Range{
+               StartIndex: 200,
+               EndIndex: 249,
+           },
+           PieceLen: 100,
+           PieceStyle: 0,
+       })
        cdnServiceMock.EXPECT().GetSeedTask(args.req.TaskId).DoAndReturn(func(taskID string) (seedTask *task.SeedTask, err error) {
            return testTask, nil
        })
        cdnServiceMock.EXPECT().GetSeedPieces(args.req.TaskId).DoAndReturn(func(taskID string) (pieces []*task.PieceInfo, err error) {
-           for u := range testTask.Pieces {
-               pieces = append(pieces, testTask.Pieces[u])
-           }
+           testTask.Pieces.Range(func(key, value interface{}) bool {
+               pieces = append(pieces, value.(*task.PieceInfo))
+               return true
+           })
            sort.Slice(pieces, func(i, j int) bool {
                return pieces[i].PieceNum < pieces[j].PieceNum
            })


@@ -22,6 +22,7 @@ import (
     "io"
     "os"
     "strings"
+    "sync"
     "testing"

     "github.com/golang/mock/gomock"
@@ -145,6 +146,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
            Digest: "md5:f1e2488bba4d1267948d9e2f7008571c",
            SourceRealDigest: "",
            PieceMd5Sign: "",
+           Pieces: new(sync.Map),
        },
        targetTask: &task.SeedTask{
            ID: md5TaskID,
@@ -159,6 +161,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
            Digest: "md5:f1e2488bba4d1267948d9e2f7008571c",
            SourceRealDigest: "md5:f1e2488bba4d1267948d9e2f7008571c",
            PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f",
+           Pieces: new(sync.Map),
        },
    },
    {
@@ -176,6 +179,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
            Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
            SourceRealDigest: "",
            PieceMd5Sign: "",
+           Pieces: new(sync.Map),
        },
        targetTask: &task.SeedTask{
            ID: sha256TaskID,
@@ -190,6 +194,7 @@ func (suite *CDNManagerTestSuite) TestTriggerCDN() {
            Digest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
            SourceRealDigest: "sha256:b9907b9a5ba2b0223868c201b9addfe2ec1da1b90325d57c34f192966b0a68c5",
            PieceMd5Sign: "bb138842f338fff90af737e4a6b2c6f8e2a7031ca9d5900bc9b646f6406d890f",
+           Pieces: new(sync.Map),
        },
    },
}


@@ -1,7 +1,7 @@
 // Code generated by MockGen. DO NOT EDIT.
 // Source: d7y.io/dragonfly/v2/cdn/supervisor (interfaces: CDNService)

-// Package progress is a generated GoMock package.
+// Package mocks is a generated GoMock package.
 package mocks

 import (


@@ -92,10 +92,10 @@ func (mr *MockManagerMockRecorder) Get(arg0 interface{}) *gomock.Call {
 }

 // GetProgress mocks base method.
-func (m *MockManager) GetProgress(arg0 string) (map[uint32]*task.PieceInfo, error) {
+func (m *MockManager) GetProgress(arg0 string) ([]*task.PieceInfo, error) {
     m.ctrl.T.Helper()
     ret := m.ctrl.Call(m, "GetProgress", arg0)
-    ret0, _ := ret[0].(map[uint32]*task.PieceInfo)
+    ret0, _ := ret[0].([]*task.PieceInfo)
     ret1, _ := ret[1].(error)
     return ret0, ret1
 }


@@ -75,6 +75,10 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas
     if err != nil {
         return nil, err
     }
+    pieces, err := pm.taskManager.GetProgress(taskID)
+    if err != nil {
+        return nil, err
+    }
     if seedTask.IsDone() {
         pieceChan := make(chan *task.PieceInfo)
         go func(pieceChan chan *task.PieceInfo) {
@@ -82,22 +86,18 @@ func (pm *manager) WatchSeedProgress(ctx context.Context, clientAddr string, tas
                 logger.Debugf("subscriber %s starts watching task %s seed progress", clientAddr, taskID)
                 close(pieceChan)
             }()
-            pieceNums := make([]uint32, 0, len(seedTask.Pieces))
-            for pieceNum := range seedTask.Pieces {
-                pieceNums = append(pieceNums, pieceNum)
-            }
-            sort.Slice(pieceNums, func(i, j int) bool {
-                return pieceNums[i] < pieceNums[j]
+            sort.Slice(pieces, func(i, j int) bool {
+                return pieces[i].PieceNum < pieces[j].PieceNum
             })
-            for _, pieceNum := range pieceNums {
-                logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, pieceNum, taskID)
-                pieceChan <- seedTask.Pieces[pieceNum]
+            for _, piece := range pieces {
+                logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", clientAddr, piece.PieceNum, taskID)
+                pieceChan <- piece
            }
        }(pieceChan)
        return pieceChan, nil
    }
    var progressPublisher, _ = pm.seedTaskSubjects.LoadOrStore(taskID, newProgressPublisher(taskID))
-   observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, seedTask.Pieces)
+   observer := newProgressSubscriber(ctx, clientAddr, seedTask.ID, pieces)
    progressPublisher.(*publisher).AddSubscriber(observer)
    return observer.Receiver(), nil
 }


@@ -34,25 +34,21 @@ type subscriber struct {
     taskID string
     done chan struct{}
     once sync.Once
-    pieces map[uint32]*task.PieceInfo
+    pieces []*task.PieceInfo
     pieceChan chan *task.PieceInfo
     cond *sync.Cond
     closed *atomic.Bool
 }

-func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces map[uint32]*task.PieceInfo) *subscriber {
-    pieces := make(map[uint32]*task.PieceInfo, len(taskPieces))
-    for u, info := range taskPieces {
-        pieces[u] = info
-    }
+func newProgressSubscriber(ctx context.Context, clientAddr, taskID string, taskPieces []*task.PieceInfo) *subscriber {
     sub := &subscriber{
         ctx: ctx,
         scheduler: clientAddr,
         taskID: taskID,
-        pieces: pieces,
         done: make(chan struct{}),
-        pieceChan: make(chan *task.PieceInfo, 100),
         cond: sync.NewCond(&sync.Mutex{}),
+        pieces: taskPieces,
+        pieceChan: make(chan *task.PieceInfo, 100),
         closed: atomic.NewBool(false),
     }
     go sub.readLoop()
@@ -89,24 +85,20 @@ func (sub *subscriber) readLoop() {
 }

 func (sub *subscriber) sendPieces() {
-    pieceNums := make([]uint32, 0, len(sub.pieces))
-    for pieceNum := range sub.pieces {
-        pieceNums = append(pieceNums, pieceNum)
-    }
-    sort.Slice(pieceNums, func(i, j int) bool {
-        return pieceNums[i] < pieceNums[j]
+    sort.Slice(sub.pieces, func(i, j int) bool {
+        return sub.pieces[i].PieceNum < sub.pieces[j].PieceNum
     })
-    for _, pieceNum := range pieceNums {
-        logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, pieceNum, sub.taskID)
-        sub.pieceChan <- sub.pieces[pieceNum]
-        delete(sub.pieces, pieceNum)
+    for _, piece := range sub.pieces {
+        logger.Debugf("subscriber %s send %d piece info of taskID %s", sub.scheduler, piece.PieceNum, sub.taskID)
+        sub.pieceChan <- piece
     }
+    sub.pieces = []*task.PieceInfo{}
 }

 func (sub *subscriber) Notify(seedPiece *task.PieceInfo) {
     logger.Debugf("notifies subscriber %s about %d piece info of taskID %s", sub.scheduler, seedPiece.PieceNum, sub.taskID)
     sub.cond.L.Lock()
-    sub.pieces[seedPiece.PieceNum] = seedPiece
+    sub.pieces = append(sub.pieces, seedPiece)
     sub.cond.L.Unlock()
     sub.cond.Signal()
 }
@@ -154,7 +146,8 @@ func (pub *publisher) RemoveSubscriber(sub *subscriber) {
 func (pub *publisher) NotifySubscribers(seedPiece *task.PieceInfo) {
     for e := pub.subscribers.Front(); e != nil; e = e.Next() {
-        e.Value.(*subscriber).Notify(seedPiece)
+        sub := e.Value.(*subscriber)
+        sub.Notify(seedPiece)
     }
 }
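
The subscriber above now buffers pieces in a slice guarded by a sync.Cond instead of a map: Notify appends under cond.L and signals, and sendPieces drains the slice in piece order before resetting it. Below is a minimal standalone sketch of that queue pattern, not project code; the names and the int payload (standing in for *task.PieceInfo) are illustrative.

package main

import (
    "fmt"
    "sync"
)

// queue is a toy stand-in for the subscriber: Notify appends under the
// condition's lock, readLoop drains the buffered items and forwards them.
type queue struct {
    cond  *sync.Cond
    items []int
    out   chan int
}

func newQueue() *queue {
    q := &queue{cond: sync.NewCond(&sync.Mutex{}), out: make(chan int, 16)}
    go q.readLoop()
    return q
}

// Notify mirrors subscriber.Notify: lock, append, unlock, signal.
func (q *queue) Notify(v int) {
    q.cond.L.Lock()
    q.items = append(q.items, v)
    q.cond.L.Unlock()
    q.cond.Signal()
}

// readLoop waits until items are buffered, takes the batch, and resets the
// slice, mirroring sub.pieces = []*task.PieceInfo{} in the diff above.
// In this sketch the loop runs for the life of the process.
func (q *queue) readLoop() {
    for {
        q.cond.L.Lock()
        for len(q.items) == 0 {
            q.cond.Wait()
        }
        batch := q.items
        q.items = nil
        q.cond.L.Unlock()
        for _, v := range batch {
            q.out <- v
        }
    }
}

func main() {
    q := newQueue()
    for i := 0; i < 5; i++ {
        q.Notify(i)
    }
    for i := 0; i < 5; i++ {
        fmt.Println(<-q.out)
    }
}

Taking the batch and clearing the slice while the lock is held keeps the channel sends outside the critical section, so a slow receiver cannot block Notify.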


@@ -73,8 +73,8 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
        PieceLen: 0,
        PieceStyle: 0,
    }
-   sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", map[uint32]*task.PieceInfo{
-       100: additionPieceInfo1,
+   sub3 := newProgressSubscriber(context.Background(), "client3", "taskTask", []*task.PieceInfo{
+       additionPieceInfo1,
    })
    additionPieceInfo2 := &task.PieceInfo{
        PieceNum: 200,
@@ -85,9 +85,9 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
        PieceStyle: 0,
    }
    publisher.AddSubscriber(sub3)
-   sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", map[uint32]*task.PieceInfo{
-       100: additionPieceInfo1,
-       200: additionPieceInfo2,
+   sub4 := newProgressSubscriber(context.Background(), "client4", "taskTask", []*task.PieceInfo{
+       additionPieceInfo1,
+       additionPieceInfo2,
    })
    publisher.AddSubscriber(sub4)
    chan1 := sub1.Receiver()
@@ -144,6 +144,7 @@ func Test_publisher_NotifySubscribers(t *testing.T) {
        assert.Equal(4, pieceCount)
    }(chan4)

+   // notify all subscribers
    for i := range notifyPieces {
        publisher.NotifySubscribers(notifyPieces[i])
    }


@@ -21,7 +21,6 @@ package supervisor

 import (
     "context"
     "encoding/json"
-    "sort"

     "github.com/pkg/errors"
@@ -124,18 +123,7 @@ func (service *cdnService) triggerCdnSyncAction(ctx context.Context, taskID stri
 }

 func (service *cdnService) GetSeedPieces(taskID string) ([]*task.PieceInfo, error) {
-    pieceMap, err := service.taskManager.GetProgress(taskID)
-    if err != nil {
-        return nil, err
-    }
-    pieces := make([]*task.PieceInfo, 0, len(pieceMap))
-    for i := range pieceMap {
-        pieces = append(pieces, pieceMap[i])
-    }
-    sort.Slice(pieces, func(i, j int) bool {
-        return pieces[i].PieceNum < pieces[j].PieceNum
-    })
-    return pieces, nil
+    return service.taskManager.GetProgress(taskID)
 }

 func (service *cdnService) GetSeedTask(taskID string) (*task.SeedTask, error) {


@@ -55,7 +55,7 @@ type Manager interface {
     UpdateProgress(taskID string, piece *PieceInfo) (err error)

     // GetProgress returns the downloaded pieces belonging to the task
-    GetProgress(taskID string) (map[uint32]*PieceInfo, error)
+    GetProgress(taskID string) ([]*PieceInfo, error)

     // Exist check task existence with specified taskID.
     // returns the task info with specified taskID, or nil if no value is present.
@@ -198,13 +198,13 @@ func (tm *manager) UpdateProgress(taskID string, info *PieceInfo) error {
     if !ok {
         return errTaskNotFound
     }
-    seedTask.Pieces[info.PieceNum] = info
+    seedTask.Pieces.Store(info.PieceNum, info)
     // only update access when update task success
     tm.accessTimeMap.Store(taskID, time.Now())
     return nil
 }

-func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) {
+func (tm *manager) GetProgress(taskID string) ([]*PieceInfo, error) {
     synclock.Lock(taskID, false)
     defer synclock.UnLock(taskID, false)
     seedTask, ok := tm.getTask(taskID)
@@ -212,7 +212,12 @@ func (tm *manager) GetProgress(taskID string) (map[uint32]*PieceInfo, error) {
         return nil, errTaskNotFound
     }
     tm.accessTimeMap.Store(taskID, time.Now())
-    return seedTask.Pieces, nil
+    var pieces []*PieceInfo
+    seedTask.Pieces.Range(func(key, value interface{}) bool {
+        pieces = append(pieces, value.(*PieceInfo))
+        return true
+    })
+    return pieces, nil
 }

 func (tm *manager) Exist(taskID string) (*SeedTask, bool) {


@@ -18,6 +18,7 @@ package task

 import (
     "strings"
+    "sync"

     "github.com/google/go-cmp/cmp"
     "github.com/google/go-cmp/cmp/cmpopts"
@@ -83,7 +84,7 @@ type SeedTask struct {
     Header map[string]string `json:"header,omitempty"`

     // Pieces pieces of task
-    Pieces map[uint32]*PieceInfo `json:"-"`
+    Pieces *sync.Map `json:"-"` // map[uint32]*PieceInfo

     logger *logger.SugaredLoggerOnWith
 }
@@ -124,7 +125,7 @@ func NewSeedTask(taskID string, rawURL string, urlMeta *base.UrlMeta) *SeedTask
        Range: urlMeta.Range,
        Filter: urlMeta.Filter,
        Header: urlMeta.Header,
-       Pieces: make(map[uint32]*PieceInfo),
+       Pieces: new(sync.Map),
        logger: logger.WithTaskID(taskID),
    }
 }
@@ -137,11 +138,12 @@ func (task *SeedTask) Clone() *SeedTask {
            cloneTask.Header[key] = value
        }
    }
-   if len(task.Pieces) > 0 {
-       for pieceNum, piece := range task.Pieces {
-           cloneTask.Pieces[pieceNum] = piece
-       }
-   }
+   cloneTask.Pieces = new(sync.Map)
+   task.Pieces.Range(func(key, value interface{}) bool {
+       cloneTask.Pieces.Store(key, value)
+       return true
+   })

    return cloneTask
 }
@@ -192,7 +194,7 @@ func (task *SeedTask) Log() *logger.SugaredLoggerOnWith {

 func (task *SeedTask) StartTrigger() {
     task.CdnStatus = StatusRunning
-    task.Pieces = make(map[uint32]*PieceInfo)
+    task.Pieces = new(sync.Map)
 }

 const (
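
Because sync.Map is untyped, every reader of SeedTask.Pieces now has to type-assert value.(*PieceInfo), as the hunks above do. A thin typed wrapper is one way to keep those assertions in a single place; the sketch below is not part of this change, and its names (pieceMap, the trimmed PieceInfo) are illustrative.

package main

import (
    "fmt"
    "sync"
)

// PieceInfo is an illustrative stand-in for task.PieceInfo.
type PieceInfo struct {
    PieceNum uint32
    PieceMd5 string
}

// pieceMap wraps sync.Map so the interface{} type assertions live in one place.
type pieceMap struct{ m sync.Map }

// Store records a piece by its number.
func (p *pieceMap) Store(num uint32, info *PieceInfo) { p.m.Store(num, info) }

// Load returns the piece with the given number, if present.
func (p *pieceMap) Load(num uint32) (*PieceInfo, bool) {
    v, ok := p.m.Load(num)
    if !ok {
        return nil, false
    }
    return v.(*PieceInfo), true
}

// Range visits every stored piece until fn returns false.
func (p *pieceMap) Range(fn func(num uint32, info *PieceInfo) bool) {
    p.m.Range(func(key, value interface{}) bool {
        return fn(key.(uint32), value.(*PieceInfo))
    })
}

func main() {
    var pieces pieceMap
    pieces.Store(0, &PieceInfo{PieceNum: 0, PieceMd5: "xxxx0"})
    if info, ok := pieces.Load(0); ok {
        fmt.Println(info.PieceMd5)
    }
}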