// dragonfly/scheduler/announcer/announcer_test.go
//
// 1148 lines
// 38 KiB
// Go
/*
* Copyright 2022 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package announcer
import (
"bytes"
"errors"
"io"
"net"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
managerv2 "d7y.io/api/pkg/apis/manager/v2"
trainerv1 "d7y.io/api/pkg/apis/trainer/v1"
trainerv1mocks "d7y.io/api/pkg/apis/trainer/v1/mocks"
managerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
trainerclientmocks "d7y.io/dragonfly/v2/pkg/rpc/trainer/client/mocks"
"d7y.io/dragonfly/v2/scheduler/config"
storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks"
)
// mockReadCloserWithReadError is an io.ReadCloser whose Read always fails with
// the error "foo" and whose Close always succeeds. The tests return it from the
// storage mock's OpenDownload/OpenNetworkTopology expectations to exercise the
// read-error paths of the upload helpers.
type mockReadCloserWithReadError struct{}

// Compile-time check that mockReadCloserWithReadError satisfies io.ReadCloser.
var _ io.ReadCloser = (*mockReadCloserWithReadError)(nil)

// Read reports zero bytes read and the error "foo" on every call.
func (m *mockReadCloserWithReadError) Read(p []byte) (int, error) {
	return 0, errors.New("foo")
}

// Close is a no-op that always succeeds.
func (m *mockReadCloserWithReadError) Close() error {
	return nil
}
func TestAnnouncer_New(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
tests := []struct {
name string
config *config.Config
option []Option
mock func(m *managerclientmocks.MockV2MockRecorder)
expect func(t *testing.T, announcer Announcer, err error)
}{
{
name: "new announcer",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
SchedulerClusterID: 1,
},
},
option: []Option{},
mock: func(m *managerclientmocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
assert := assert.New(t)
instance := a.(*announcer)
assert.NoError(err)
assert.NotNil(instance.config)
assert.NotNil(instance.managerClient)
},
},
{
name: "new announcer with trainer client",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
SchedulerClusterID: 1,
},
},
option: []Option{WithTrainerClient(mockTrainerClient)},
mock: func(m *managerclientmocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
assert := assert.New(t)
instance := a.(*announcer)
assert.NoError(err)
assert.NotNil(instance.config)
assert.NotNil(instance.managerClient)
assert.NotNil(instance.trainerClient)
},
},
{
name: "update scheduler failed",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
SchedulerClusterID: 1,
},
},
option: []Option{},
mock: func(m *managerclientmocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
assert := assert.New(t)
assert.EqualError(err, "foo")
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockManagerClient := managerclientmocks.NewMockV2(ctl)
mockStorage := storagemocks.NewMockStorage(ctl)
tc.mock(mockManagerClient.EXPECT())
a, err := New(tc.config, mockManagerClient, mockStorage, tc.option...)
tc.expect(t, a, err)
})
}
}
func TestAnnouncer_Serve(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
tests := []struct {
name string
config *config.Config
data []byte
option []Option
sleep func()
mock func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
except func(t *testing.T, a Announcer)
}{
{
name: "started announcer server success",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
KeepAlive: config.KeepAliveConfig{
Interval: 50 * time.Millisecond,
},
SchedulerClusterID: 1,
},
Trainer: config.TrainerConfig{
Interval: 80 * time.Millisecond,
UploadTimeout: 10 * time.Second,
},
},
data: []byte("bar"),
option: []Option{WithTrainerClient(mockTrainerClient)},
sleep: func() {
time.Sleep(100 * time.Millisecond)
},
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
var wg sync.WaitGroup
wg.Add(4)
gomock.InOrder(
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: "localhost",
Ip: "127.0.0.1",
Port: int32(8004),
Idc: "foo",
Location: "bar",
SchedulerClusterId: uint64(1),
})).Times(1),
m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: "localhost",
Ip: "127.0.0.1",
ClusterId: uint64(1),
}), gomock.Any()).Times(1),
mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1),
)
gomock.InOrder(
ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
mt.Send(gomock.Any()).DoAndReturn(
func(t *trainerv1.TrainRequest) error {
wg.Done()
return nil
}).Times(2),
)
gomock.InOrder(
ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
mt.Send(gomock.Any()).DoAndReturn(
func(t *trainerv1.TrainRequest) error {
wg.Done()
return nil
}).Times(2),
)
},
except: func(t *testing.T, a Announcer) {
go a.Serve()
},
},
{
name: "started announcer server success without trainer client",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
KeepAlive: config.KeepAliveConfig{
Interval: 50 * time.Millisecond,
},
SchedulerClusterID: 1,
},
},
data: []byte("bar"),
option: []Option{},
sleep: func() {
time.Sleep(100 * time.Millisecond)
},
mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
gomock.InOrder(
m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: "localhost",
Ip: "127.0.0.1",
Port: int32(8004),
Idc: "foo",
Location: "bar",
SchedulerClusterId: uint64(1),
})).Times(1),
m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: "localhost",
Ip: "127.0.0.1",
ClusterId: uint64(1),
}), gomock.Any()).Times(1),
)
},
except: func(t *testing.T, a Announcer) {
go a.Serve()
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
mockManagerClient := managerclientmocks.NewMockV2(ctl)
mockStorage := storagemocks.NewMockStorage(ctl)
tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())
a, err := New(tc.config, mockManagerClient, mockStorage, tc.option...)
if err != nil {
t.Fatal(err)
}
tc.except(t, a)
tc.sleep()
a.Stop()
})
}
}
// TestAnnouncer_announceToManager verifies that announceToManager issues one
// KeepAlive to the manager with the configured interval and the scheduler's
// identity.
func TestAnnouncer_announceToManager(t *testing.T) {
	tests := []struct {
		name   string
		config *config.Config
		sleep  func()
		mock   func(m *managerclientmocks.MockV2MockRecorder)
		expect func(a Announcer)
	}{
		{
			name: "announce to manager success",
			config: &config.Config{
				Server: config.ServerConfig{
					Host:          "localhost",
					AdvertiseIP:   net.ParseIP("127.0.0.1"),
					AdvertisePort: 8004,
					Port:          8080,
				},
				Host: config.HostConfig{
					IDC:      "foo",
					Location: "bar",
				},
				Manager: config.ManagerConfig{
					KeepAlive: config.KeepAliveConfig{
						Interval: 50 * time.Millisecond,
					},
					SchedulerClusterID: 1,
				},
			},
			sleep: func() {
				time.Sleep(100 * time.Millisecond)
			},
			mock: func(m *managerclientmocks.MockV2MockRecorder) {
				gomock.InOrder(
					m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
						SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
						Hostname:           "localhost",
						Ip:                 "127.0.0.1",
						Port:               int32(8004),
						Idc:                "foo",
						Location:           "bar",
						SchedulerClusterId: uint64(1),
					})).Times(1),
					m.KeepAlive(gomock.Eq(50*time.Millisecond), gomock.Eq(&managerv2.KeepAliveRequest{
						SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
						Hostname:   "localhost",
						Ip:         "127.0.0.1",
						ClusterId:  uint64(1),
					}), gomock.Any()).Times(1),
				)
			},
			expect: func(a Announcer) {
				a.(*announcer).announceToManager()
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctl := gomock.NewController(t)
			defer ctl.Finish()

			mockManagerClient := managerclientmocks.NewMockV2(ctl)
			mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
			mockStorage := storagemocks.NewMockStorage(ctl)
			tc.mock(mockManagerClient.EXPECT())

			a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
			if err != nil {
				t.Fatal(err)
			}

			tc.expect(a)
			tc.sleep()
			// Stop the announcer before the controller verifies expectations,
			// as the other announce tests in this file do.
			a.Stop()
		})
	}
}
// TestAnnouncer_announceToTrainer runs announceToTrainer in the background and
// checks that one training round is attempted before the announcer is stopped,
// both when opening the Train stream fails and when the round succeeds.
func TestAnnouncer_announceToTrainer(t *testing.T) {
	// newConfig returns a fresh config. Timing-sensitive: the 80ms trainer
	// interval must elapse once within the 100ms each case sleeps before Stop
	// (Train is expected exactly once).
	newConfig := func() *config.Config {
		return &config.Config{
			Server: config.ServerConfig{
				Host:          "localhost",
				AdvertiseIP:   net.ParseIP("127.0.0.1"),
				AdvertisePort: 8004,
				Port:          8080,
			},
			Host: config.HostConfig{
				IDC:      "foo",
				Location: "bar",
			},
			Manager: config.ManagerConfig{
				SchedulerClusterID: 1,
			},
			Trainer: config.TrainerConfig{
				Interval:      80 * time.Millisecond,
				UploadTimeout: 1 * time.Second,
			},
		}
	}

	// expectUpdateScheduler registers the UpdateScheduler call made by New.
	expectUpdateScheduler := func(m *managerclientmocks.MockV2MockRecorder) *gomock.Call {
		return m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
			SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
			Hostname:           "localhost",
			Ip:                 "127.0.0.1",
			Port:               int32(8004),
			Idc:                "foo",
			Location:           "bar",
			SchedulerClusterId: uint64(1),
		})).Times(1)
	}

	tests := []struct {
		name   string
		config *config.Config
		data   []byte
		sleep  func()
		mock   func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
		expect func(a Announcer)
	}{
		{
			name:   "announce to trainer failed",
			config: newConfig(),
			data:   []byte("bar"),
			sleep: func() {
				time.Sleep(100 * time.Millisecond)
			},
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1),
				)
			},
			expect: func(a Announcer) {
				go a.(*announcer).announceToTrainer()
			},
		},
		{
			name:   "announce to trainer success",
			config: newConfig(),
			data:   []byte("bar"),
			sleep: func() {
				time.Sleep(100 * time.Millisecond)
			},
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				// CloseAndRecv blocks until all four Send calls have happened.
				var wg sync.WaitGroup
				wg.Add(4)
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
					mt.CloseAndRecv().Do(func() { wg.Wait() }).Return(nil, nil).Times(1),
				)
				ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				mt.Send(gomock.Any()).DoAndReturn(func(_ *trainerv1.TrainRequest) error {
					wg.Done()
					return nil
				}).Times(4)
			},
			expect: func(a Announcer) {
				go a.(*announcer).announceToTrainer()
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctl := gomock.NewController(t)
			defer ctl.Finish()

			stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
			mockManagerClient := managerclientmocks.NewMockV2(ctl)
			mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
			mockStorage := storagemocks.NewMockStorage(ctl)
			tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())

			a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
			if err != nil {
				t.Fatal(err)
			}

			tc.expect(a)
			tc.sleep()
			a.Stop()
		})
	}
}
// TestAnnouncer_train exercises a single training round: opening the Train
// stream, uploading the download and network topology records, and closing the
// stream. It covers the failure of each step and the success path.
func TestAnnouncer_train(t *testing.T) {
	// newConfig returns a fresh scheduler config; every case uses these values.
	newConfig := func() *config.Config {
		return &config.Config{
			Server: config.ServerConfig{
				Host:          "localhost",
				AdvertiseIP:   net.ParseIP("127.0.0.1"),
				AdvertisePort: 8004,
				Port:          8080,
			},
			Host: config.HostConfig{
				IDC:      "foo",
				Location: "bar",
			},
			Manager: config.ManagerConfig{
				SchedulerClusterID: 1,
			},
		}
	}

	// expectUpdateScheduler registers the UpdateScheduler call made by New.
	expectUpdateScheduler := func(m *managerclientmocks.MockV2MockRecorder) *gomock.Call {
		return m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
			SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
			Hostname:           "localhost",
			Ip:                 "127.0.0.1",
			Port:               int32(8004),
			Idc:                "foo",
			Location:           "bar",
			SchedulerClusterId: uint64(1),
		})).Times(1)
	}

	tests := []struct {
		name   string
		config *config.Config
		data   []byte
		mock   func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
		expect func(t *testing.T, a Announcer, err error)
	}{
		{
			name:   "get stream failed",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(nil, errors.New("foo")).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "upload download failed",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
				)
				ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				// GNN (network topology) requests succeed; MLP (download) requests fail.
				mt.Send(gomock.Any()).DoAndReturn(func(req *trainerv1.TrainRequest) error {
					switch req.Request.(type) {
					case *trainerv1.TrainRequest_TrainGnnRequest:
						return nil
					case *trainerv1.TrainRequest_TrainMlpRequest:
						return errors.New("foo")
					}
					return nil
				}).Times(3)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "upload download: foo")
			},
		},
		{
			name:   "upload network topology failed",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
				)
				ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				// GNN (network topology) requests fail; MLP (download) requests succeed.
				mt.Send(gomock.Any()).DoAndReturn(func(req *trainerv1.TrainRequest) error {
					switch req.Request.(type) {
					case *trainerv1.TrainRequest_TrainGnnRequest:
						return errors.New("foo")
					case *trainerv1.TrainRequest_TrainMlpRequest:
						return nil
					}
					return nil
				}).Times(3)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "upload network topology: foo")
			},
		},
		{
			name:   "close stream failed",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				// CloseAndRecv blocks until all four Send calls have happened.
				var wg sync.WaitGroup
				wg.Add(4)
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
					mt.CloseAndRecv().Return(nil, errors.New("foo")).Do(func() { wg.Wait() }).Times(1),
				)
				ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				mt.Send(gomock.Any()).DoAndReturn(func(_ *trainerv1.TrainRequest) error {
					wg.Done()
					return nil
				}).Times(4)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "train success",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				// CloseAndRecv blocks until all four Send calls have happened.
				var wg sync.WaitGroup
				wg.Add(4)
				gomock.InOrder(
					expectUpdateScheduler(m),
					mtc.Train(gomock.Any()).Return(stream, nil).Times(1),
					mt.CloseAndRecv().Return(nil, nil).Do(func() { wg.Wait() }).Times(1),
				)
				ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1)
				mt.Send(gomock.Any()).DoAndReturn(func(_ *trainerv1.TrainRequest) error {
					wg.Done()
					return nil
				}).Times(4)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.NoError(err)
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctl := gomock.NewController(t)
			defer ctl.Finish()

			stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
			mockManagerClient := managerclientmocks.NewMockV2(ctl)
			mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
			mockStorage := storagemocks.NewMockStorage(ctl)
			tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())

			a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
			if err != nil {
				t.Fatal(err)
			}

			tc.expect(t, a, a.(*announcer).train())
		})
	}
}
// TestAnnouncer_uploadDownloadToTrainer covers uploadDownloadToTrainer when
// opening the download storage fails, when reading it fails, when sending on
// the stream fails, and when every send succeeds.
func TestAnnouncer_uploadDownloadToTrainer(t *testing.T) {
	// newConfig returns a fresh scheduler config; every case uses these values.
	newConfig := func() *config.Config {
		return &config.Config{
			Server: config.ServerConfig{
				Host:          "localhost",
				AdvertiseIP:   net.ParseIP("127.0.0.1"),
				AdvertisePort: 8004,
				Port:          8080,
			},
			Host: config.HostConfig{
				IDC:      "foo",
				Location: "bar",
			},
			Manager: config.ManagerConfig{
				SchedulerClusterID: 1,
			},
		}
	}

	// expectUpdateScheduler registers the UpdateScheduler call made by New.
	expectUpdateScheduler := func(m *managerclientmocks.MockV2MockRecorder) *gomock.Call {
		return m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
			SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
			Hostname:           "localhost",
			Ip:                 "127.0.0.1",
			Port:               int32(8004),
			Idc:                "foo",
			Location:           "bar",
			SchedulerClusterId: uint64(1),
		})).Times(1)
	}

	tests := []struct {
		name   string
		config *config.Config
		data   []byte
		mock   func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
		expect func(t *testing.T, a Announcer, err error)
	}{
		{
			name:   "open download failed",
			data:   []byte{},
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenDownload().Return(nil, errors.New("foo")).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "read buffer failed",
			data:   []byte{},
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenDownload().Return(&mockReadCloserWithReadError{}, nil).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "send download request failed",
			data:   []byte("bar"),
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
					// The first Send fails with "foo".
					mt.Send(gomock.Any()).Return(errors.New("foo")).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "send download request success",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenDownload().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
					mt.Send(gomock.Any()).Return(nil).Times(2),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.NoError(err)
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctl := gomock.NewController(t)
			defer ctl.Finish()

			stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
			mockManagerClient := managerclientmocks.NewMockV2(ctl)
			mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
			mockStorage := storagemocks.NewMockStorage(ctl)
			tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())

			a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
			if err != nil {
				t.Fatal(err)
			}

			tc.expect(t, a, a.(*announcer).uploadDownloadToTrainer(stream))
		})
	}
}
// TestAnnouncer_uploadNetworkTopologyToTrainer covers
// uploadNetworkTopologyToTrainer when opening the network topology storage
// fails, when reading it fails, when sending on the stream fails, and when
// every send succeeds.
func TestAnnouncer_uploadNetworkTopologyToTrainer(t *testing.T) {
	// newConfig returns a fresh scheduler config; every case uses these values.
	newConfig := func() *config.Config {
		return &config.Config{
			Server: config.ServerConfig{
				Host:          "localhost",
				AdvertiseIP:   net.ParseIP("127.0.0.1"),
				AdvertisePort: 8004,
				Port:          8080,
			},
			Host: config.HostConfig{
				IDC:      "foo",
				Location: "bar",
			},
			Manager: config.ManagerConfig{
				SchedulerClusterID: 1,
			},
		}
	}

	// expectUpdateScheduler registers the UpdateScheduler call made by New.
	expectUpdateScheduler := func(m *managerclientmocks.MockV2MockRecorder) *gomock.Call {
		return m.UpdateScheduler(gomock.Any(), gomock.Eq(&managerv2.UpdateSchedulerRequest{
			SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
			Hostname:           "localhost",
			Ip:                 "127.0.0.1",
			Port:               int32(8004),
			Idc:                "foo",
			Location:           "bar",
			SchedulerClusterId: uint64(1),
		})).Times(1)
	}

	tests := []struct {
		name   string
		config *config.Config
		data   []byte
		mock   func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder)
		expect func(t *testing.T, a Announcer, err error)
	}{
		{
			name:   "open networkTopology failed",
			data:   []byte{},
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenNetworkTopology().Return(nil, errors.New("foo")).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "read buffer failed",
			data:   []byte{},
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenNetworkTopology().Return(&mockReadCloserWithReadError{}, nil).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "send network topology failed",
			data:   []byte("bar"),
			config: newConfig(),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
					// The first Send fails with "foo".
					mt.Send(gomock.Any()).Return(errors.New("foo")).Times(1),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.EqualError(err, "foo")
			},
		},
		{
			name:   "send network topology success",
			config: newConfig(),
			data:   []byte("bar"),
			mock: func(stream trainerv1.Trainer_TrainClient, data []byte, m *managerclientmocks.MockV2MockRecorder, mtc *trainerclientmocks.MockV1MockRecorder, ms *storagemocks.MockStorageMockRecorder, mt *trainerv1mocks.MockTrainer_TrainClientMockRecorder) {
				gomock.InOrder(
					expectUpdateScheduler(m),
					ms.OpenNetworkTopology().Return(io.NopCloser(bytes.NewBuffer(data)), nil).Times(1),
					mt.Send(gomock.Any()).Return(nil).Times(2),
				)
			},
			expect: func(t *testing.T, a Announcer, err error) {
				assert := assert.New(t)
				assert.NoError(err)
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctl := gomock.NewController(t)
			defer ctl.Finish()

			stream := trainerv1mocks.NewMockTrainer_TrainClient(ctl)
			mockManagerClient := managerclientmocks.NewMockV2(ctl)
			mockTrainerClient := trainerclientmocks.NewMockV1(ctl)
			mockStorage := storagemocks.NewMockStorage(ctl)
			tc.mock(stream, tc.data, mockManagerClient.EXPECT(), mockTrainerClient.EXPECT(), mockStorage.EXPECT(), stream.EXPECT())

			a, err := New(tc.config, mockManagerClient, mockStorage, WithTrainerClient(mockTrainerClient))
			if err != nil {
				t.Fatal(err)
			}

			tc.expect(t, a, a.(*announcer).uploadNetworkTopologyToTrainer(stream))
		})
	}
}