diff --git a/cluster/swarm/cluster.go b/cluster/swarm/cluster.go index 32e0ed7946..ff6507afea 100644 --- a/cluster/swarm/cluster.go +++ b/cluster/swarm/cluster.go @@ -332,12 +332,10 @@ func (c *Cluster) Load(imageReader io.Reader, callback func(what, status string) c.RLock() pipeWriters := []*io.PipeWriter{} - pipeReaders := []*io.PipeReader{} for _, n := range c.engines { wg.Add(1) pipeReader, pipeWriter := io.Pipe() - pipeReaders = append(pipeReaders, pipeReader) pipeWriters = append(pipeWriters, pipeWriter) go func(reader *io.PipeReader, nn *cluster.Engine) { @@ -353,16 +351,17 @@ func (c *Cluster) Load(imageReader io.Reader, callback func(what, status string) } }(pipeReader, n) } + c.RUnlock() // create multi-writer listWriter := []io.Writer{} for _, pipeW := range pipeWriters { listWriter = append(listWriter, pipeW) } - mutiWriter := io.MultiWriter(listWriter...) + multiWriter := io.MultiWriter(listWriter...) - // copy image-reader to muti-writer - _, err := io.Copy(mutiWriter, imageReader) + // copy image-reader to multi-writer + _, err := io.Copy(multiWriter, imageReader) if err != nil { log.Error(err) } @@ -372,8 +371,6 @@ func (c *Cluster) Load(imageReader io.Reader, callback func(what, status string) pipeW.Close() } - c.RUnlock() - wg.Wait() } diff --git a/cluster/swarm/cluster_test.go b/cluster/swarm/cluster_test.go index aaab11a471..45f915e624 100644 --- a/cluster/swarm/cluster_test.go +++ b/cluster/swarm/cluster_test.go @@ -150,3 +150,46 @@ func TestImportImage(t *testing.T) { } c.Import("-", "testImageError", "latest", bytes.NewReader(nil), callback) } + +func TestLoadImage(t *testing.T) { + // create cluster + c := &Cluster{ + engines: make(map[string]*cluster.Engine), + } + + // create engione + id := "test-engine" + engine := cluster.NewEngine(id, 0) + engine.Name = id + engine.ID = id + + // create mock client + client := mockclient.NewMockClient() + client.On("Info").Return(mockInfo, nil) + client.On("StartMonitorEvents", mock.Anything, mock.Anything, mock.Anything).Return() + client.On("ListContainers", true, false, "").Return([]dockerclient.Container{}, nil).Once() + client.On("ListImages").Return([]*dockerclient.Image{}, nil) + + // connect client + engine.ConnectWithClient(client) + + // add engine to cluster + c.engines[engine.ID] = engine + + // load success + client.On("LoadImage", mock.AnythingOfType("*io.PipeReader")).Return(nil).Once() + callback := func(what, status string) { + //if load OK, will not come here + t.Fatalf("Load error") + } + c.Load(bytes.NewReader(nil), callback) + + // load error + err := fmt.Errorf("Load error") + client.On("LoadImage", mock.AnythingOfType("*io.PipeReader")).Return(err).Once() + callback = func(what, status string) { + // load error + assert.Equal(t, status, "Load error") + } + c.Load(bytes.NewReader(nil), callback) +}