163 lines
4.4 KiB
Go
163 lines
4.4 KiB
Go
/*
|
|
* Copyright 2023 The Dragonfly Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
trainerv1 "d7y.io/api/v2/pkg/apis/trainer/v1"
|
|
|
|
logger "d7y.io/dragonfly/v2/internal/dflog"
|
|
"d7y.io/dragonfly/v2/pkg/idgen"
|
|
"d7y.io/dragonfly/v2/trainer/config"
|
|
"d7y.io/dragonfly/v2/trainer/storage"
|
|
"d7y.io/dragonfly/v2/trainer/training"
|
|
)
|
|
|
|
// V1 is the interface for v1 version of the service.
|
|
type V1 struct {
|
|
// Trainer service config.
|
|
config *config.Config
|
|
|
|
// Storage Interface.
|
|
storage storage.Storage
|
|
|
|
// Training Interface.
|
|
training training.Training
|
|
}
|
|
|
|
// New v1 version of service instance.
|
|
func NewV1(
|
|
cfg *config.Config,
|
|
storage storage.Storage,
|
|
training training.Training,
|
|
) *V1 {
|
|
return &V1{cfg, storage, training}
|
|
}
|
|
|
|
// Train implements the Trainer.Train method.
|
|
func (v *V1) Train(stream trainerv1.Trainer_TrainServer) error {
|
|
var (
|
|
ip string
|
|
hostname string
|
|
hostID string
|
|
networkTopologyFile io.WriteCloser
|
|
downloadFile io.WriteCloser
|
|
req *trainerv1.TrainRequest
|
|
initialized bool
|
|
err error
|
|
)
|
|
|
|
for {
|
|
req, err = stream.Recv()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
|
|
logger.Errorf("receive failed: %s", err.Error())
|
|
return err
|
|
}
|
|
|
|
logger := logger.WithHostnameAndIP(req.Hostname, req.Ip)
|
|
if !initialized {
|
|
initialized = true
|
|
ip = req.Ip
|
|
hostname = req.Hostname
|
|
hostID = idgen.HostIDV2(req.Ip, req.Hostname)
|
|
|
|
// Open network topology file and store received data.
|
|
networkTopologyFile, err = v.storage.OpenNetworkTopology(hostID)
|
|
if err != nil {
|
|
msg := fmt.Sprintf("open network topology failed: %s", err.Error())
|
|
logger.Error(msg)
|
|
return status.Error(codes.Internal, msg)
|
|
}
|
|
defer func() {
|
|
networkTopologyFile.Close()
|
|
|
|
// If error occurred, clear network topology.
|
|
if err != nil && err != io.EOF {
|
|
if err := v.storage.ClearNetworkTopology(hostID); err != nil {
|
|
logger.Errorf("clear network topology failed: %s", err.Error())
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Open download file and store received data.
|
|
downloadFile, err = v.storage.OpenDownload(hostID)
|
|
if err != nil {
|
|
msg := fmt.Sprintf("open download failed: %s", err.Error())
|
|
logger.Error(msg)
|
|
return status.Error(codes.Internal, msg)
|
|
}
|
|
defer func() {
|
|
downloadFile.Close()
|
|
|
|
// If error occurred, clear download.
|
|
if err != nil && err != io.EOF {
|
|
if err := v.storage.ClearDownload(hostID); err != nil {
|
|
logger.Errorf("clear download failed: %s", err.Error())
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
switch trainRequest := req.GetRequest().(type) {
|
|
case *trainerv1.TrainRequest_TrainGnnRequest:
|
|
// Store network topology.
|
|
if _, err := networkTopologyFile.Write(trainRequest.TrainGnnRequest.Dataset); err != nil {
|
|
msg := fmt.Sprintf("write network topology failed: %s", err.Error())
|
|
logger.Error(msg)
|
|
return status.Error(codes.Internal, msg)
|
|
}
|
|
case *trainerv1.TrainRequest_TrainMlpRequest:
|
|
// Store download.
|
|
if _, err := downloadFile.Write(trainRequest.TrainMlpRequest.Dataset); err != nil {
|
|
msg := fmt.Sprintf("write download failed: %s", err.Error())
|
|
logger.Error(msg)
|
|
return status.Error(codes.Internal, msg)
|
|
}
|
|
default:
|
|
msg := fmt.Sprintf("receive unknown request: %#v", trainRequest)
|
|
logger.Error(msg)
|
|
return status.Error(codes.FailedPrecondition, msg)
|
|
}
|
|
}
|
|
|
|
// Send empty response and close stream.
|
|
if err := stream.SendAndClose(&emptypb.Empty{}); err != nil {
|
|
logger.Errorf("send and close failed: %s", err.Error())
|
|
return err
|
|
}
|
|
|
|
// If all dataset received, start training.
|
|
go func() {
|
|
if err := v.training.Train(context.Background(), ip, hostname); err != nil {
|
|
logger.Errorf("train failed: %s", err.Error())
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|