1148 lines
38 KiB
Go
1148 lines
38 KiB
Go
/*
|
|
* Copyright 2022 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package announcer
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
managerv2 "d7y.io/api/pkg/apis/manager/v2"
|
|
trainerv1 "d7y.io/api/pkg/apis/trainer/v1"
|
|
trainerv1mocks "d7y.io/api/pkg/apis/trainer/v1/mocks"
|
|
|
|
managerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
|
|
trainerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/trainer/client/mocks"
|
|
"d7y.io/dragonfly/v2/scheduler/config"
|
|
storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks"
|
|
)
|
|
|
|
// mockReadCloserWithReadError is a stub io.ReadCloser that simulates a storage
// reader which cannot be read: every Read yields zero bytes together with an
// error whose message is "foo", and Close is a no-op that always succeeds.
type mockReadCloserWithReadError struct{}

// Read ignores the destination buffer and always fails with "foo".
func (*mockReadCloserWithReadError) Read(_ []byte) (int, error) {
	readErr := errors.New("foo")
	return 0, readErr
}

// Close has nothing to release and therefore never fails.
func (*mockReadCloserWithReadError) Close() error {
	return nil
}
|
|
|
|
func TestAnnouncer_New(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
option []Option
|
|
mock func(m *managerclientmocks.MockV2MockRecorder)
|
|
expect func(t *testing.T, announcer Announcer, err error)
|
|
}{
|
|
{
|
|
name: "new announcer",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
option: []Option{},
|
|
mock: func(m *managerclientmocks.MockV2MockRecorder) {
|
|
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
|
|
},
|
|
expect: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
instance := a.(*announcer)
|
|
assert.NoError(err)
|
|
assert.NotNil(instance.config)
|
|
assert.NotNil(instance.managerClient)
|
|
},
|
|
},
|
|
{
|
|
name: "new announcer with trainer client",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
option: []Option{WithTrainerClient(mockTrainerClient)},
|
|
mock: func(m *managerclientmocks.MockV2MockRecorder) {
|
|
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
|
|
},
|
|
expect: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
instance := a.(*announcer)
|
|
assert.NoError(err)
|
|
assert.NotNil(instance.config)
|
|
assert.NotNil(instance.managerClient)
|
|
assert.NotNil(instance.trainerClient)
|
|
},
|
|
},
|
|
{
|
|
name: "update scheduler failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
option: []Option{},
|
|
mock: func(m *managerclientmocks.MockV2MockRecorder) {
|
|
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1)
|
|
},
|
|
expect: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(mockManagerClient.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, tc.option...)
|
|
tc.expect(t, a, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_Serve(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
data []byte
|
|
option []Option
|
|
sleep func()
|
|
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
|
|
except func(t *testing.T, a Announcer)
|
|
}{
|
|
{
|
|
name: "started announcer server success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
KeepAlive: config.KeepAliveConfig{
|
|
Interval: 50 * time.Millisecond,
|
|
},
|
|
SchedulerClusterID: 1,
|
|
},
|
|
Trainer: config.TrainerConfig{
|
|
Interval: 80 * time.Millisecond,
|
|
UploadTimeout: 10 * time.Second,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
option: []Option{WithTrainerClient(mockTrainerClient)},
|
|
sleep: func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(4)
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
ClusterId: uint64(1),
|
|
}), gomock.Any()).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1),
|
|
)
|
|
|
|
gomock.InOrder(
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
wg.Done()
|
|
return nil
|
|
}).Times(2),
|
|
)
|
|
|
|
gomock.InOrder(
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
wg.Done()
|
|
return nil
|
|
}).Times(2),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer) {
|
|
go a.Serve()
|
|
},
|
|
},
|
|
{
|
|
name: "started announcer server success without trainer client",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
KeepAlive: config.KeepAliveConfig{
|
|
Interval: 50 * time.Millisecond,
|
|
},
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
option: []Option{},
|
|
sleep: func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
ClusterId: uint64(1),
|
|
}), gomock.Any()).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer) {
|
|
go a.Serve()
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
|
|
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, tc.option...)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(t, a)
|
|
tc.sleep()
|
|
a.Stop()
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_announceToManager(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
sleep func()
|
|
mock func(m *managerclientmocks.MockV2MockRecorder)
|
|
except func(a Announcer)
|
|
}{
|
|
{
|
|
name: "announce to manager success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
KeepAlive: config.KeepAliveConfig{
|
|
Interval: 50 * time.Millisecond,
|
|
},
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
sleep: func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
mock: func(m *managerclientmocks.MockV2MockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
ClusterId: uint64(1),
|
|
}), gomock.Any()).Times(1),
|
|
)
|
|
},
|
|
except: func(a Announcer) {
|
|
a.(*announcer).announceToManager()
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(mockManagerClient.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(a)
|
|
tc.sleep()
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_announceToTrainer(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
data []byte
|
|
sleep func()
|
|
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
|
|
except func(a Announcer)
|
|
}{
|
|
{
|
|
name: "announce to trainer failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
Trainer: config.TrainerConfig{
|
|
Interval: 80 * time.Millisecond,
|
|
UploadTimeout: 1 * time.Second,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
sleep: func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1),
|
|
)
|
|
},
|
|
except: func(a Announcer) {
|
|
go a.(*announcer).announceToTrainer()
|
|
},
|
|
},
|
|
{
|
|
name: "announce to trainer success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
Trainer: config.TrainerConfig{
|
|
Interval: 80 * time.Millisecond,
|
|
UploadTimeout: 1 * time.Second,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
sleep: func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(4)
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1),
|
|
)
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
wg.Done()
|
|
return nil
|
|
}).Times(4)
|
|
},
|
|
except: func(a Announcer) {
|
|
go a.(*announcer).announceToTrainer()
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(a)
|
|
tc.sleep()
|
|
a.Stop()
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_train(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
data []byte
|
|
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
|
|
except func(t *testing.T, announcer Announcer, err error)
|
|
}{
|
|
{
|
|
name: "get stream failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "upload download failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
)
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
switch t.Request.(type) {
|
|
case *trainerv1.TrainRequest_TrainGnnRequest:
|
|
return nil
|
|
case *trainerv1.TrainRequest_TrainMlpRequest:
|
|
return errors.New("foo")
|
|
}
|
|
return nil
|
|
}).Times(3)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "upload download: foo")
|
|
},
|
|
},
|
|
{
|
|
name: "upload network topology failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
)
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
switch t.Request.(type) {
|
|
case *trainerv1.TrainRequest_TrainGnnRequest:
|
|
return errors.New("foo")
|
|
case *trainerv1.TrainRequest_TrainMlpRequest:
|
|
return nil
|
|
}
|
|
return nil
|
|
}).Times(3)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "upload network topology: foo")
|
|
},
|
|
},
|
|
{
|
|
name: "close stream failed",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(4)
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
mt.CloseAndRecv().Return(nil, errors.New("foo")).Do(func() { wg.Wait() }).Times(1),
|
|
)
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
wg.Done()
|
|
return nil
|
|
}).Times(4)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "train success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(4)
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
|
|
mt.CloseAndRecv().Return(nil, nil).Do(func() { wg.Wait() }).Times(1),
|
|
)
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
wg.Done()
|
|
return nil
|
|
}).Times(4)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.NoError(err)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(t, a, a.(*announcer).train())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_uploadDownloadToTrainer(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
data []byte
|
|
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
|
|
except func(t *testing.T, announcer Announcer, err error)
|
|
}{
|
|
{
|
|
name: "open download failed",
|
|
data: []byte{},
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenDownload().Return(nil, errors.New("foo")).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "read buffer failed",
|
|
data: []byte{},
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenDownload().Return(&mockReadCloserWithReadError{}, nil).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "send download request failed",
|
|
data: []byte("bar"),
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
return nil
|
|
}).Return(errors.New("foo")),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "send download request success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
return nil
|
|
}).Times(2),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.NoError(err)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(t, a, a.(*announcer).uploadDownloadToTrainer(stream))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnnouncer_uploadNetworkTopologyToTrainer(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *config.Config
|
|
data []byte
|
|
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
|
|
except func(t *testing.T, announcer Announcer, err error)
|
|
}{
|
|
{
|
|
name: "open networkTopology failed",
|
|
data: []byte{},
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenNetworkTopology().Return(nil, errors.New("foo")).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "read buffer failed",
|
|
data: []byte{},
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenNetworkTopology().Return(&mockReadCloserWithReadError{}, nil).Times(1),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "send network topology failed",
|
|
data: []byte("bar"),
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
return nil
|
|
}).Return(errors.New("foo")),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.EqualError(err, "foo")
|
|
},
|
|
},
|
|
{
|
|
name: "send network topology success",
|
|
config: &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
AdvertiseIP: net.ParseIP("127.0.0.1"),
|
|
AdvertisePort: 8004,
|
|
Port: 8080,
|
|
},
|
|
Host: config.HostConfig{
|
|
IDC: "foo",
|
|
Location: "bar",
|
|
},
|
|
Manager: config.ManagerConfig{
|
|
SchedulerClusterID: 1,
|
|
},
|
|
},
|
|
data: []byte("bar"),
|
|
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
|
|
gomock.InOrder(
|
|
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
|
|
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
|
|
Hostname: "localhost",
|
|
Ip: "127.0.0.1",
|
|
Port: int32(8004),
|
|
Idc: "foo",
|
|
Location: "bar",
|
|
SchedulerClusterId: uint64(1),
|
|
})).Times(1),
|
|
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
|
|
mt.Send(gomock.Any()).DoAndReturn(
|
|
func(t *trainerv1.TrainRequest) error {
|
|
return nil
|
|
}).Times(2),
|
|
)
|
|
},
|
|
except: func(t *testing.T, a Announcer, err error) {
|
|
assert := assert.New(t)
|
|
assert.NoError(err)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctl := gomock.NewController(t)
|
|
defer ctl.Finish()
|
|
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
|
|
mockManagerClient := managerclientmocks.NewMockV2(ctl)
|
|
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
|
|
mockStorage := storagemocks.NewMockStorage(ctl)
|
|
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
|
|
|
|
a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.except(t, a, a.(*announcer).uploadNetworkTopologyToTrainer(stream))
|
|
})
|
|
}
|
|
}
|