dragonfly/scheduler/announcer/announcer.go

251 lines
6.1 KiB
Go

/*
* Copyright 2022 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//go:generate mockgen -destination mocks/announcer_mock.go -source announcer.go -package mocks
package announcer
import (
"context"
"fmt"
"io"
"time"
"golang.org/x/sync/errgroup"
managerv2 "d7y.io/api/pkg/apis/manager/v2"
trainerv1 "d7y.io/api/pkg/apis/trainer/v1"
logger "d7y.io/dragonfly/v2/internal/dflog"
managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client"
trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client"
"d7y.io/dragonfly/v2/scheduler/config"
"d7y.io/dragonfly/v2/scheduler/storage"
)
const (
	// UploadBufferSize is the size in bytes (1 MiB) of the buffer that holds
	// one dataset chunk read from storage before it is sent to the trainer.
	UploadBufferSize = 1 << 20
)
// Announcer is the interface used for announce service.
// It exposes the lifecycle of the announcer: Serve starts it and Stop ends it.
type Announcer interface {
// Serve starts the announcer server.
Serve()
// Stop stops the announcer server.
Stop()
}
// announcer provides announce function.
// It announces the scheduler to the manager and, when a trainer client is
// configured, uploads datasets to the trainer.
type announcer struct {
// config is the scheduler configuration; the server, host, manager and
// trainer sections are read from it.
config *config.Config
// managerClient registers the scheduler on the manager and keeps it alive.
managerClient managerclient.V2
// trainerClient is optional and set through WithTrainerClient; while it is
// nil nothing is announced to the trainer.
trainerClient trainerclient.V1
// storage opens the download and network topology data uploaded to the trainer.
storage storage.Storage
// done is closed by Stop to signal the announce loops to exit.
done chan struct{}
}
// WithTrainerClient returns an Option that installs the given trainer gRPC
// client on the announcer.
func WithTrainerClient(client trainerclient.V1) Option {
	return func(ann *announcer) {
		ann.trainerClient = client
	}
}
// Option is a functional option for configuring the announcer.
// It receives the announcer under construction and may set its fields.
type Option func(s *announcer)
// New builds an announcer from the scheduler configuration, applies the given
// options in order and then registers the scheduler with the manager.
//
// The registration is a single UpdateScheduler call made with
// context.Background(). When it fails, its error is returned unchanged and no
// Announcer is returned.
func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Storage, options ...Option) (Announcer, error) {
	ann := &announcer{
		config:        cfg,
		managerClient: managerClient,
		storage:       storage,
		done:          make(chan struct{}),
	}

	for i := range options {
		options[i](ann)
	}

	// Describe this scheduler instance to the manager.
	registration := &managerv2.UpdateSchedulerRequest{
		SourceType:         managerv2.SourceType_SCHEDULER_SOURCE,
		Hostname:           ann.config.Server.Host,
		Ip:                 ann.config.Server.AdvertiseIP.String(),
		Port:               int32(ann.config.Server.AdvertisePort),
		Idc:                ann.config.Host.IDC,
		Location:           ann.config.Host.Location,
		SchedulerClusterId: uint64(ann.config.Manager.SchedulerClusterID),
	}

	_, err := ann.managerClient.UpdateScheduler(context.Background(), registration)
	if err != nil {
		return nil, err
	}

	return ann, nil
}
// Serve starts announcing the scheduler.
//
// The manager announcement runs in its own goroutine. When a trainer client is
// configured, the trainer loop then runs on the calling goroutine, so Serve
// returns only after the done channel has been closed; without a trainer
// client it returns right away.
func (a *announcer) Serve() {
	logger.Info("announce scheduler to manager")
	go a.announceToManager()

	// Nothing more to do when no trainer is configured.
	if a.trainerClient == nil {
		return
	}

	logger.Info("announce scheduler to trainer")
	a.announceToTrainer()
}
// Stop stops the announcer server by closing the done channel, which ends the
// trainer loop and is the stop signal handed to the manager keepalive.
// It must be called at most once: closing an already closed channel panics.
func (a *announcer) Stop() {
close(a.done)
}
// announceToManager keeps the scheduler announced to the manager by calling
// the manager client's KeepAlive with the configured interval, this
// scheduler's identity and the done channel.
// NOTE(review): KeepAlive presumably blocks until done is closed, which would
// explain why Serve runs this method in a goroutine; confirm against the
// manager client.
func (a *announcer) announceToManager() {
	interval := a.config.Manager.KeepAlive.Interval
	keepAlive := &managerv2.KeepAliveRequest{
		SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
		Hostname:   a.config.Server.Host,
		Ip:         a.config.Server.AdvertiseIP.String(),
		ClusterId:  uint64(a.config.Manager.SchedulerClusterID),
	}

	a.managerClient.KeepAlive(interval, keepAlive, a.done)
}
// announceToTrainer periodically uploads the datasets to the trainer.
//
// It blocks until the done channel is closed. A training round is attempted on
// every tick of config.Trainer.Interval, so the first round starts one full
// interval after the call. A failed round is only logged; the loop carries on
// and tries again on the next tick.
func (a *announcer) announceToTrainer() {
	tick := time.NewTicker(a.config.Trainer.Interval)
	// Release the ticker once the loop exits, otherwise it outlives the
	// announcer.
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			if err := a.train(); err != nil {
				logger.Error(err)
			}
		case <-a.done:
			return
		}
	}
}
// train uploads dataset to trainer and trigger training.
func (a *announcer) train() error {
ctx, cancel := context.WithTimeout(context.Background(), a.config.Trainer.UploadTimeout)
defer cancel()
stream, err := a.trainerClient.Train(ctx)
if err != nil {
return err
}
eg := errgroup.Group{}
eg.Go(func() error {
if err := a.uploadDownloadToTrainer(stream); err != nil {
return fmt.Errorf("upload download: %w", err)
}
return nil
})
eg.Go(func() error {
if err := a.uploadNetworkTopologyToTrainer(stream); err != nil {
return fmt.Errorf("upload network topology: %w", err)
}
return nil
})
if err := eg.Wait(); err != nil {
return err
}
if _, err := stream.CloseAndRecv(); err != nil {
return err
}
return nil
}
// uploadDownloadToTrainer streams the download data held in storage to the
// trainer as a sequence of TrainMLPRequest messages.
//
// The data is read in chunks of at most UploadBufferSize bytes and every chunk
// is sent in its own request, tagged with this scheduler's hostname, IP and
// cluster ID. The chunk returned together with io.EOF is sent as well (it may
// be empty) before the function returns nil. Any other read error, or any
// send error, is returned immediately.
func (a *announcer) uploadDownloadToTrainer(stream trainerv1.Trainer_TrainClient) error {
	file, err := a.storage.OpenDownload()
	if err != nil {
		return err
	}
	defer file.Close()

	chunk := make([]byte, UploadBufferSize)
	for finished := false; !finished; {
		size, readErr := file.Read(chunk)
		switch readErr {
		case nil:
			// More data may follow.
		case io.EOF:
			// Send what was read with EOF, then stop.
			finished = true
		default:
			return readErr
		}

		req := &trainerv1.TrainRequest{
			Hostname:  a.config.Server.Host,
			Ip:        a.config.Server.AdvertiseIP.String(),
			ClusterId: uint64(a.config.Manager.SchedulerClusterID),
			Request: &trainerv1.TrainRequest_TrainMlpRequest{
				TrainMlpRequest: &trainerv1.TrainMLPRequest{
					Dataset: chunk[:size],
				},
			},
		}

		if sendErr := stream.Send(req); sendErr != nil {
			return sendErr
		}
	}

	return nil
}
// uploadNetworkTopologyToTrainer streams the network topology data held in
// storage to the trainer as a sequence of TrainGNNRequest messages.
//
// The data is read in chunks of at most UploadBufferSize bytes and every chunk
// is sent in its own request, tagged with this scheduler's hostname, IP and
// cluster ID. The chunk returned together with io.EOF is sent as well (it may
// be empty) before the function returns nil. Any other read error, or any
// send error, is returned immediately.
func (a *announcer) uploadNetworkTopologyToTrainer(stream trainerv1.Trainer_TrainClient) error {
	topology, err := a.storage.OpenNetworkTopology()
	if err != nil {
		return err
	}
	defer topology.Close()

	chunk := make([]byte, UploadBufferSize)
	for atEOF := false; !atEOF; {
		size, readErr := topology.Read(chunk)
		switch readErr {
		case nil:
			// More data may follow.
		case io.EOF:
			// Send what was read with EOF, then stop.
			atEOF = true
		default:
			return readErr
		}

		req := &trainerv1.TrainRequest{
			Hostname:  a.config.Server.Host,
			Ip:        a.config.Server.AdvertiseIP.String(),
			ClusterId: uint64(a.config.Manager.SchedulerClusterID),
			Request: &trainerv1.TrainRequest_TrainGnnRequest{
				TrainGnnRequest: &trainerv1.TrainGNNRequest{
					Dataset: chunk[:size],
				},
			},
		}

		if sendErr := stream.Send(req); sendErr != nil {
			return sendErr
		}
	}

	return nil
}