From 339e0e901b27bd3d16a40f4f8b7d5c58bf9309ab Mon Sep 17 00:00:00 2001 From: Gaius Date: Tue, 21 Mar 2023 21:21:08 +0800 Subject: [PATCH] feat: add definition of trainer service (#99) Signed-off-by: Gaius --- build.rs | 1 + hack/protoc.sh | 2 +- pkg/apis/dfdaemon/v2/dfdaemon.proto | 2 +- pkg/apis/scheduler/v2/scheduler.proto | 2 +- pkg/apis/trainer/v1/mocks/mocks.go | 19 + pkg/apis/trainer/v1/mocks/trainer_mock.go | 402 +++++++++++++++++ pkg/apis/trainer/v1/trainer.pb.go | 384 ++++++++++++++++ pkg/apis/trainer/v1/trainer.pb.validate.go | 488 +++++++++++++++++++++ pkg/apis/trainer/v1/trainer.proto | 59 +++ pkg/apis/trainer/v1/trainer_grpc.pb.go | 140 ++++++ proto/trainer.proto | 54 +++ src/trainer.rs | 287 ++++++++++++ 12 files changed, 1837 insertions(+), 3 deletions(-) create mode 100644 pkg/apis/trainer/v1/mocks/mocks.go create mode 100644 pkg/apis/trainer/v1/mocks/trainer_mock.go create mode 100644 pkg/apis/trainer/v1/trainer.pb.go create mode 100644 pkg/apis/trainer/v1/trainer.pb.validate.go create mode 100644 pkg/apis/trainer/v1/trainer.proto create mode 100644 pkg/apis/trainer/v1/trainer_grpc.pb.go create mode 100644 proto/trainer.proto create mode 100644 src/trainer.rs diff --git a/build.rs b/build.rs index 3893365..87a130d 100644 --- a/build.rs +++ b/build.rs @@ -10,6 +10,7 @@ fn main() -> Result<(), Box> { "proto/dfdaemon.proto", "proto/manager.proto", "proto/scheduler.proto", + "proto/trainer.proto", ], &["proto/"], )?; diff --git a/hack/protoc.sh b/hack/protoc.sh index e441192..5866c8a 100755 --- a/hack/protoc.sh +++ b/hack/protoc.sh @@ -4,7 +4,7 @@ PROTOC_ALL_IMAGE=${PROTOC_ALL_IMAGE:-"namely/protoc-all:1.51_1"} PROTO_PATH=pkg/apis LANGUAGE=go -proto_modules="common/v1 common/v2 cdnsystem/v1 dfdaemon/v1 dfdaemon/v2 errordetails/v1 manager/v1 manager/v2 scheduler/v1 scheduler/v2 security/v1" +proto_modules="common/v1 common/v2 cdnsystem/v1 dfdaemon/v1 dfdaemon/v2 errordetails/v1 manager/v1 manager/v2 scheduler/v1 scheduler/v2 security/v1 trainer/v1" echo "generate protos..." diff --git a/pkg/apis/dfdaemon/v2/dfdaemon.proto b/pkg/apis/dfdaemon/v2/dfdaemon.proto index 7ddef8d..dc62d74 100644 --- a/pkg/apis/dfdaemon/v2/dfdaemon.proto +++ b/pkg/apis/dfdaemon/v2/dfdaemon.proto @@ -90,7 +90,7 @@ message DeleteTaskRequest { } // Dfdaemon RPC Service. -service Dfdaemon{ +service Dfdaemon { // SyncPieces syncs pieces from the other peers. rpc SyncPieces(stream SyncPiecesRequest)returns(stream SyncPiecesResponse); diff --git a/pkg/apis/scheduler/v2/scheduler.proto b/pkg/apis/scheduler/v2/scheduler.proto index 99b29fe..eb72e87 100644 --- a/pkg/apis/scheduler/v2/scheduler.proto +++ b/pkg/apis/scheduler/v2/scheduler.proto @@ -320,7 +320,7 @@ message SyncNetworkTopologyRequest { } // Scheduler RPC Service. -service Scheduler{ +service Scheduler { // AnnouncePeer announces peer to scheduler. rpc AnnouncePeer(stream AnnouncePeerRequest) returns(stream AnnouncePeerResponse); diff --git a/pkg/apis/trainer/v1/mocks/mocks.go b/pkg/apis/trainer/v1/mocks/mocks.go new file mode 100644 index 0000000..f8c53be --- /dev/null +++ b/pkg/apis/trainer/v1/mocks/mocks.go @@ -0,0 +1,19 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mocks + +//go:generate mockgen -destination trainer_mock.go -source ../trainer_grpc.pb.go -package mocks diff --git a/pkg/apis/trainer/v1/mocks/trainer_mock.go b/pkg/apis/trainer/v1/mocks/trainer_mock.go new file mode 100644 index 0000000..942f0b6 --- /dev/null +++ b/pkg/apis/trainer/v1/mocks/trainer_mock.go @@ -0,0 +1,402 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../trainer_grpc.pb.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + trainer "d7y.io/api/pkg/apis/trainer/v1" + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// MockTrainerClient is a mock of TrainerClient interface. +type MockTrainerClient struct { + ctrl *gomock.Controller + recorder *MockTrainerClientMockRecorder +} + +// MockTrainerClientMockRecorder is the mock recorder for MockTrainerClient. +type MockTrainerClientMockRecorder struct { + mock *MockTrainerClient +} + +// NewMockTrainerClient creates a new mock instance. +func NewMockTrainerClient(ctrl *gomock.Controller) *MockTrainerClient { + mock := &MockTrainerClient{ctrl: ctrl} + mock.recorder = &MockTrainerClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrainerClient) EXPECT() *MockTrainerClientMockRecorder { + return m.recorder +} + +// Train mocks base method. +func (m *MockTrainerClient) Train(ctx context.Context, opts ...grpc.CallOption) (trainer.Trainer_TrainClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Train", varargs...) + ret0, _ := ret[0].(trainer.Trainer_TrainClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Train indicates an expected call of Train. +func (mr *MockTrainerClientMockRecorder) Train(ctx interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Train", reflect.TypeOf((*MockTrainerClient)(nil).Train), varargs...) +} + +// MockTrainer_TrainClient is a mock of Trainer_TrainClient interface. +type MockTrainer_TrainClient struct { + ctrl *gomock.Controller + recorder *MockTrainer_TrainClientMockRecorder +} + +// MockTrainer_TrainClientMockRecorder is the mock recorder for MockTrainer_TrainClient. +type MockTrainer_TrainClientMockRecorder struct { + mock *MockTrainer_TrainClient +} + +// NewMockTrainer_TrainClient creates a new mock instance. +func NewMockTrainer_TrainClient(ctrl *gomock.Controller) *MockTrainer_TrainClient { + mock := &MockTrainer_TrainClient{ctrl: ctrl} + mock.recorder = &MockTrainer_TrainClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrainer_TrainClient) EXPECT() *MockTrainer_TrainClientMockRecorder { + return m.recorder +} + +// CloseAndRecv mocks base method. +func (m *MockTrainer_TrainClient) CloseAndRecv() (*emptypb.Empty, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseAndRecv") + ret0, _ := ret[0].(*emptypb.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CloseAndRecv indicates an expected call of CloseAndRecv. +func (mr *MockTrainer_TrainClientMockRecorder) CloseAndRecv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseAndRecv", reflect.TypeOf((*MockTrainer_TrainClient)(nil).CloseAndRecv)) +} + +// CloseSend mocks base method. +func (m *MockTrainer_TrainClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. +func (mr *MockTrainer_TrainClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockTrainer_TrainClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockTrainer_TrainClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockTrainer_TrainClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockTrainer_TrainClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockTrainer_TrainClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockTrainer_TrainClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockTrainer_TrainClient)(nil).Header)) +} + +// RecvMsg mocks base method. +func (m_2 *MockTrainer_TrainClient) RecvMsg(m interface{}) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "RecvMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockTrainer_TrainClientMockRecorder) RecvMsg(m interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockTrainer_TrainClient)(nil).RecvMsg), m) +} + +// Send mocks base method. +func (m *MockTrainer_TrainClient) Send(arg0 *trainer.TrainRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockTrainer_TrainClientMockRecorder) Send(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTrainer_TrainClient)(nil).Send), arg0) +} + +// SendMsg mocks base method. +func (m_2 *MockTrainer_TrainClient) SendMsg(m interface{}) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "SendMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockTrainer_TrainClientMockRecorder) SendMsg(m interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockTrainer_TrainClient)(nil).SendMsg), m) +} + +// Trailer mocks base method. +func (m *MockTrainer_TrainClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockTrainer_TrainClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockTrainer_TrainClient)(nil).Trailer)) +} + +// MockTrainerServer is a mock of TrainerServer interface. +type MockTrainerServer struct { + ctrl *gomock.Controller + recorder *MockTrainerServerMockRecorder +} + +// MockTrainerServerMockRecorder is the mock recorder for MockTrainerServer. +type MockTrainerServerMockRecorder struct { + mock *MockTrainerServer +} + +// NewMockTrainerServer creates a new mock instance. +func NewMockTrainerServer(ctrl *gomock.Controller) *MockTrainerServer { + mock := &MockTrainerServer{ctrl: ctrl} + mock.recorder = &MockTrainerServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrainerServer) EXPECT() *MockTrainerServerMockRecorder { + return m.recorder +} + +// Train mocks base method. +func (m *MockTrainerServer) Train(arg0 trainer.Trainer_TrainServer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Train", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Train indicates an expected call of Train. +func (mr *MockTrainerServerMockRecorder) Train(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Train", reflect.TypeOf((*MockTrainerServer)(nil).Train), arg0) +} + +// MockUnsafeTrainerServer is a mock of UnsafeTrainerServer interface. +type MockUnsafeTrainerServer struct { + ctrl *gomock.Controller + recorder *MockUnsafeTrainerServerMockRecorder +} + +// MockUnsafeTrainerServerMockRecorder is the mock recorder for MockUnsafeTrainerServer. +type MockUnsafeTrainerServerMockRecorder struct { + mock *MockUnsafeTrainerServer +} + +// NewMockUnsafeTrainerServer creates a new mock instance. +func NewMockUnsafeTrainerServer(ctrl *gomock.Controller) *MockUnsafeTrainerServer { + mock := &MockUnsafeTrainerServer{ctrl: ctrl} + mock.recorder = &MockUnsafeTrainerServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUnsafeTrainerServer) EXPECT() *MockUnsafeTrainerServerMockRecorder { + return m.recorder +} + +// mustEmbedUnimplementedTrainerServer mocks base method. +func (m *MockUnsafeTrainerServer) mustEmbedUnimplementedTrainerServer() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "mustEmbedUnimplementedTrainerServer") +} + +// mustEmbedUnimplementedTrainerServer indicates an expected call of mustEmbedUnimplementedTrainerServer. +func (mr *MockUnsafeTrainerServerMockRecorder) mustEmbedUnimplementedTrainerServer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "mustEmbedUnimplementedTrainerServer", reflect.TypeOf((*MockUnsafeTrainerServer)(nil).mustEmbedUnimplementedTrainerServer)) +} + +// MockTrainer_TrainServer is a mock of Trainer_TrainServer interface. +type MockTrainer_TrainServer struct { + ctrl *gomock.Controller + recorder *MockTrainer_TrainServerMockRecorder +} + +// MockTrainer_TrainServerMockRecorder is the mock recorder for MockTrainer_TrainServer. +type MockTrainer_TrainServerMockRecorder struct { + mock *MockTrainer_TrainServer +} + +// NewMockTrainer_TrainServer creates a new mock instance. +func NewMockTrainer_TrainServer(ctrl *gomock.Controller) *MockTrainer_TrainServer { + mock := &MockTrainer_TrainServer{ctrl: ctrl} + mock.recorder = &MockTrainer_TrainServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrainer_TrainServer) EXPECT() *MockTrainer_TrainServerMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockTrainer_TrainServer) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockTrainer_TrainServerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockTrainer_TrainServer)(nil).Context)) +} + +// Recv mocks base method. +func (m *MockTrainer_TrainServer) Recv() (*trainer.TrainRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*trainer.TrainRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockTrainer_TrainServerMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockTrainer_TrainServer)(nil).Recv)) +} + +// RecvMsg mocks base method. +func (m_2 *MockTrainer_TrainServer) RecvMsg(m interface{}) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "RecvMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockTrainer_TrainServerMockRecorder) RecvMsg(m interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockTrainer_TrainServer)(nil).RecvMsg), m) +} + +// SendAndClose mocks base method. +func (m *MockTrainer_TrainServer) SendAndClose(arg0 *emptypb.Empty) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendAndClose", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendAndClose indicates an expected call of SendAndClose. +func (mr *MockTrainer_TrainServerMockRecorder) SendAndClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendAndClose", reflect.TypeOf((*MockTrainer_TrainServer)(nil).SendAndClose), arg0) +} + +// SendHeader mocks base method. +func (m *MockTrainer_TrainServer) SendHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendHeader indicates an expected call of SendHeader. +func (mr *MockTrainer_TrainServerMockRecorder) SendHeader(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockTrainer_TrainServer)(nil).SendHeader), arg0) +} + +// SendMsg mocks base method. +func (m_2 *MockTrainer_TrainServer) SendMsg(m interface{}) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "SendMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockTrainer_TrainServerMockRecorder) SendMsg(m interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockTrainer_TrainServer)(nil).SendMsg), m) +} + +// SetHeader mocks base method. +func (m *MockTrainer_TrainServer) SetHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetHeader indicates an expected call of SetHeader. +func (mr *MockTrainer_TrainServerMockRecorder) SetHeader(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockTrainer_TrainServer)(nil).SetHeader), arg0) +} + +// SetTrailer mocks base method. +func (m *MockTrainer_TrainServer) SetTrailer(arg0 metadata.MD) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTrailer", arg0) +} + +// SetTrailer indicates an expected call of SetTrailer. +func (mr *MockTrainer_TrainServerMockRecorder) SetTrailer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockTrainer_TrainServer)(nil).SetTrailer), arg0) +} diff --git a/pkg/apis/trainer/v1/trainer.pb.go b/pkg/apis/trainer/v1/trainer.pb.go new file mode 100644 index 0000000..e5720a9 --- /dev/null +++ b/pkg/apis/trainer/v1/trainer.pb.go @@ -0,0 +1,384 @@ +// +// Copyright 2023 The Dragonfly Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.6 +// source: pkg/apis/trainer/v1/trainer.proto + +package trainer + +import ( + _ "github.com/envoyproxy/protoc-gen-validate/validate" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// GNNRequest represents gnn model request of TrainRequest. +type GNNRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Dataset of training gnn. + Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` +} + +func (x *GNNRequest) Reset() { + *x = GNNRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GNNRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GNNRequest) ProtoMessage() {} + +func (x *GNNRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GNNRequest.ProtoReflect.Descriptor instead. +func (*GNNRequest) Descriptor() ([]byte, []int) { + return file_pkg_apis_trainer_v1_trainer_proto_rawDescGZIP(), []int{0} +} + +func (x *GNNRequest) GetDataset() []byte { + if x != nil { + return x.Dataset + } + return nil +} + +// MLPRequest represents mlp model request of TrainRequest. +type MLPRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Dataset of training mlp. + Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"` +} + +func (x *MLPRequest) Reset() { + *x = MLPRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MLPRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MLPRequest) ProtoMessage() {} + +func (x *MLPRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MLPRequest.ProtoReflect.Descriptor instead. +func (*MLPRequest) Descriptor() ([]byte, []int) { + return file_pkg_apis_trainer_v1_trainer_proto_rawDescGZIP(), []int{1} +} + +func (x *MLPRequest) GetDataset() []byte { + if x != nil { + return x.Dataset + } + return nil +} + +// TrainRequest represents request of Train. +type TrainRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Scheduler hostname. + Hostname string `protobuf:"bytes,1,opt,name=hostname,proto3" json:"hostname,omitempty"` + // Scheduler ip. + Ip string `protobuf:"bytes,2,opt,name=ip,proto3" json:"ip,omitempty"` + // Scheduler cluster id. + ClusterId uint64 `protobuf:"varint,3,opt,name=cluster_id,json=clusterId,proto3" json:"cluster_id,omitempty"` + // Types that are assignable to Request: + // + // *TrainRequest_GnnRequest + // *TrainRequest_MlpRequest + Request isTrainRequest_Request `protobuf_oneof:"request"` +} + +func (x *TrainRequest) Reset() { + *x = TrainRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TrainRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TrainRequest) ProtoMessage() {} + +func (x *TrainRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_apis_trainer_v1_trainer_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TrainRequest.ProtoReflect.Descriptor instead. +func (*TrainRequest) Descriptor() ([]byte, []int) { + return file_pkg_apis_trainer_v1_trainer_proto_rawDescGZIP(), []int{2} +} + +func (x *TrainRequest) GetHostname() string { + if x != nil { + return x.Hostname + } + return "" +} + +func (x *TrainRequest) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *TrainRequest) GetClusterId() uint64 { + if x != nil { + return x.ClusterId + } + return 0 +} + +func (m *TrainRequest) GetRequest() isTrainRequest_Request { + if m != nil { + return m.Request + } + return nil +} + +func (x *TrainRequest) GetGnnRequest() *GNNRequest { + if x, ok := x.GetRequest().(*TrainRequest_GnnRequest); ok { + return x.GnnRequest + } + return nil +} + +func (x *TrainRequest) GetMlpRequest() *MLPRequest { + if x, ok := x.GetRequest().(*TrainRequest_MlpRequest); ok { + return x.MlpRequest + } + return nil +} + +type isTrainRequest_Request interface { + isTrainRequest_Request() +} + +type TrainRequest_GnnRequest struct { + GnnRequest *GNNRequest `protobuf:"bytes,4,opt,name=gnn_request,json=gnnRequest,proto3,oneof"` +} + +type TrainRequest_MlpRequest struct { + MlpRequest *MLPRequest `protobuf:"bytes,5,opt,name=mlp_request,json=mlpRequest,proto3,oneof"` +} + +func (*TrainRequest_GnnRequest) isTrainRequest_Request() {} + +func (*TrainRequest_MlpRequest) isTrainRequest_Request() {} + +var File_pkg_apis_trainer_v1_trainer_proto protoreflect.FileDescriptor + +var file_pkg_apis_trainer_v1_trainer_proto_rawDesc = []byte{ + 0x0a, 0x21, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x69, 0x6e, + 0x65, 0x72, 0x2f, 0x76, 0x31, 0x2f, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x2e, 0x76, 0x31, 0x1a, + 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, + 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x2f, 0x0a, 0x0a, 0x47, 0x4e, 0x4e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x07, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0c, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x7a, 0x02, 0x10, 0x01, 0x52, 0x07, 0x64, + 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x22, 0x2f, 0x0a, 0x0a, 0x4d, 0x4c, 0x50, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x07, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0c, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x7a, 0x02, 0x10, 0x01, 0x52, 0x07, + 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x22, 0xfa, 0x01, 0x0a, 0x0c, 0x54, 0x72, 0x61, 0x69, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, + 0x02, 0x10, 0x01, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x17, 0x0a, + 0x02, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, 0x02, + 0x70, 0x01, 0x52, 0x02, 0x69, 0x70, 0x12, 0x26, 0x0a, 0x0a, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, + 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x32, + 0x02, 0x28, 0x01, 0x52, 0x09, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x49, 0x64, 0x12, 0x39, + 0x0a, 0x0b, 0x67, 0x6e, 0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x2e, 0x76, 0x31, + 0x2e, 0x47, 0x4e, 0x4e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x67, + 0x6e, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x39, 0x0a, 0x0b, 0x6d, 0x6c, 0x70, + 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x4c, 0x50, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x6d, 0x6c, 0x70, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x42, 0x0e, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x03, 0xf8, 0x42, 0x01, 0x32, 0x46, 0x0a, 0x07, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x12, + 0x3b, 0x0a, 0x05, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x12, 0x18, 0x2e, 0x74, 0x72, 0x61, 0x69, 0x6e, + 0x65, 0x72, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x61, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x28, 0x01, 0x42, 0x28, 0x5a, 0x26, + 0x64, 0x37, 0x79, 0x2e, 0x69, 0x6f, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, + 0x70, 0x69, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x2f, 0x76, 0x31, 0x3b, 0x74, + 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_apis_trainer_v1_trainer_proto_rawDescOnce sync.Once + file_pkg_apis_trainer_v1_trainer_proto_rawDescData = file_pkg_apis_trainer_v1_trainer_proto_rawDesc +) + +func file_pkg_apis_trainer_v1_trainer_proto_rawDescGZIP() []byte { + file_pkg_apis_trainer_v1_trainer_proto_rawDescOnce.Do(func() { + file_pkg_apis_trainer_v1_trainer_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_apis_trainer_v1_trainer_proto_rawDescData) + }) + return file_pkg_apis_trainer_v1_trainer_proto_rawDescData +} + +var file_pkg_apis_trainer_v1_trainer_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_pkg_apis_trainer_v1_trainer_proto_goTypes = []interface{}{ + (*GNNRequest)(nil), // 0: trainer.v1.GNNRequest + (*MLPRequest)(nil), // 1: trainer.v1.MLPRequest + (*TrainRequest)(nil), // 2: trainer.v1.TrainRequest + (*emptypb.Empty)(nil), // 3: google.protobuf.Empty +} +var file_pkg_apis_trainer_v1_trainer_proto_depIdxs = []int32{ + 0, // 0: trainer.v1.TrainRequest.gnn_request:type_name -> trainer.v1.GNNRequest + 1, // 1: trainer.v1.TrainRequest.mlp_request:type_name -> trainer.v1.MLPRequest + 2, // 2: trainer.v1.Trainer.Train:input_type -> trainer.v1.TrainRequest + 3, // 3: trainer.v1.Trainer.Train:output_type -> google.protobuf.Empty + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_pkg_apis_trainer_v1_trainer_proto_init() } +func file_pkg_apis_trainer_v1_trainer_proto_init() { + if File_pkg_apis_trainer_v1_trainer_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_apis_trainer_v1_trainer_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GNNRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_apis_trainer_v1_trainer_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MLPRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_apis_trainer_v1_trainer_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TrainRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pkg_apis_trainer_v1_trainer_proto_msgTypes[2].OneofWrappers = []interface{}{ + (*TrainRequest_GnnRequest)(nil), + (*TrainRequest_MlpRequest)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_apis_trainer_v1_trainer_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_apis_trainer_v1_trainer_proto_goTypes, + DependencyIndexes: file_pkg_apis_trainer_v1_trainer_proto_depIdxs, + MessageInfos: file_pkg_apis_trainer_v1_trainer_proto_msgTypes, + }.Build() + File_pkg_apis_trainer_v1_trainer_proto = out.File + file_pkg_apis_trainer_v1_trainer_proto_rawDesc = nil + file_pkg_apis_trainer_v1_trainer_proto_goTypes = nil + file_pkg_apis_trainer_v1_trainer_proto_depIdxs = nil +} diff --git a/pkg/apis/trainer/v1/trainer.pb.validate.go b/pkg/apis/trainer/v1/trainer.pb.validate.go new file mode 100644 index 0000000..eb18b5c --- /dev/null +++ b/pkg/apis/trainer/v1/trainer.pb.validate.go @@ -0,0 +1,488 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: pkg/apis/trainer/v1/trainer.proto + +package trainer + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "sort" + "strings" + "time" + "unicode/utf8" + + "google.golang.org/protobuf/types/known/anypb" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = anypb.Any{} + _ = sort.Sort +) + +// Validate checks the field values on GNNRequest with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *GNNRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on GNNRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in GNNRequestMultiError, or +// nil if none found. +func (m *GNNRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *GNNRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(m.GetDataset()) < 1 { + err := GNNRequestValidationError{ + field: "Dataset", + reason: "value length must be at least 1 bytes", + } + if !all { + return err + } + errors = append(errors, err) + } + + if len(errors) > 0 { + return GNNRequestMultiError(errors) + } + + return nil +} + +// GNNRequestMultiError is an error wrapping multiple validation errors +// returned by GNNRequest.ValidateAll() if the designated constraints aren't met. +type GNNRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m GNNRequestMultiError) Error() string { + var msgs []string + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m GNNRequestMultiError) AllErrors() []error { return m } + +// GNNRequestValidationError is the validation error returned by +// GNNRequest.Validate if the designated constraints aren't met. +type GNNRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e GNNRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e GNNRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e GNNRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e GNNRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e GNNRequestValidationError) ErrorName() string { return "GNNRequestValidationError" } + +// Error satisfies the builtin error interface +func (e GNNRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sGNNRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = GNNRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = GNNRequestValidationError{} + +// Validate checks the field values on MLPRequest with the rules defined in the +// proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *MLPRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on MLPRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in MLPRequestMultiError, or +// nil if none found. +func (m *MLPRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *MLPRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if len(m.GetDataset()) < 1 { + err := MLPRequestValidationError{ + field: "Dataset", + reason: "value length must be at least 1 bytes", + } + if !all { + return err + } + errors = append(errors, err) + } + + if len(errors) > 0 { + return MLPRequestMultiError(errors) + } + + return nil +} + +// MLPRequestMultiError is an error wrapping multiple validation errors +// returned by MLPRequest.ValidateAll() if the designated constraints aren't met. +type MLPRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m MLPRequestMultiError) Error() string { + var msgs []string + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m MLPRequestMultiError) AllErrors() []error { return m } + +// MLPRequestValidationError is the validation error returned by +// MLPRequest.Validate if the designated constraints aren't met. +type MLPRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e MLPRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e MLPRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e MLPRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e MLPRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e MLPRequestValidationError) ErrorName() string { return "MLPRequestValidationError" } + +// Error satisfies the builtin error interface +func (e MLPRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sMLPRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = MLPRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = MLPRequestValidationError{} + +// Validate checks the field values on TrainRequest with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *TrainRequest) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on TrainRequest with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in TrainRequestMultiError, or +// nil if none found. +func (m *TrainRequest) ValidateAll() error { + return m.validate(true) +} + +func (m *TrainRequest) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + if utf8.RuneCountInString(m.GetHostname()) < 1 { + err := TrainRequestValidationError{ + field: "Hostname", + reason: "value length must be at least 1 runes", + } + if !all { + return err + } + errors = append(errors, err) + } + + if ip := net.ParseIP(m.GetIp()); ip == nil { + err := TrainRequestValidationError{ + field: "Ip", + reason: "value must be a valid IP address", + } + if !all { + return err + } + errors = append(errors, err) + } + + if m.GetClusterId() < 1 { + err := TrainRequestValidationError{ + field: "ClusterId", + reason: "value must be greater than or equal to 1", + } + if !all { + return err + } + errors = append(errors, err) + } + + oneofRequestPresent := false + switch v := m.Request.(type) { + case *TrainRequest_GnnRequest: + if v == nil { + err := TrainRequestValidationError{ + field: "Request", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + oneofRequestPresent = true + + if all { + switch v := interface{}(m.GetGnnRequest()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, TrainRequestValidationError{ + field: "GnnRequest", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, TrainRequestValidationError{ + field: "GnnRequest", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetGnnRequest()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return TrainRequestValidationError{ + field: "GnnRequest", + reason: "embedded message failed validation", + cause: err, + } + } + } + + case *TrainRequest_MlpRequest: + if v == nil { + err := TrainRequestValidationError{ + field: "Request", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + oneofRequestPresent = true + + if all { + switch v := interface{}(m.GetMlpRequest()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, TrainRequestValidationError{ + field: "MlpRequest", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, TrainRequestValidationError{ + field: "MlpRequest", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetMlpRequest()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return TrainRequestValidationError{ + field: "MlpRequest", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + if !oneofRequestPresent { + err := TrainRequestValidationError{ + field: "Request", + reason: "value is required", + } + if !all { + return err + } + errors = append(errors, err) + } + + if len(errors) > 0 { + return TrainRequestMultiError(errors) + } + + return nil +} + +// TrainRequestMultiError is an error wrapping multiple validation errors +// returned by TrainRequest.ValidateAll() if the designated constraints aren't met. +type TrainRequestMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m TrainRequestMultiError) Error() string { + var msgs []string + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m TrainRequestMultiError) AllErrors() []error { return m } + +// TrainRequestValidationError is the validation error returned by +// TrainRequest.Validate if the designated constraints aren't met. +type TrainRequestValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e TrainRequestValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e TrainRequestValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e TrainRequestValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e TrainRequestValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e TrainRequestValidationError) ErrorName() string { return "TrainRequestValidationError" } + +// Error satisfies the builtin error interface +func (e TrainRequestValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sTrainRequest.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = TrainRequestValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = TrainRequestValidationError{} diff --git a/pkg/apis/trainer/v1/trainer.proto b/pkg/apis/trainer/v1/trainer.proto new file mode 100644 index 0000000..8c45f24 --- /dev/null +++ b/pkg/apis/trainer/v1/trainer.proto @@ -0,0 +1,59 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package trainer.v1; + +import "google/protobuf/empty.proto"; +import "validate/validate.proto"; + +option go_package = "d7y.io/api/pkg/apis/trainer/v1;trainer"; + +// GNNRequest represents gnn model request of TrainRequest. +message GNNRequest { + // Dataset of training gnn. + bytes dataset = 1 [(validate.rules).bytes.min_len = 1]; +} + +// MLPRequest represents mlp model request of TrainRequest. +message MLPRequest { + // Dataset of training mlp. + bytes dataset = 1 [(validate.rules).bytes.min_len = 1]; +} + +// TrainRequest represents request of Train. +message TrainRequest { + // Scheduler hostname. + string hostname = 1 [(validate.rules).string.min_len = 1]; + // Scheduler ip. + string ip = 2 [(validate.rules).string.ip = true]; + // Scheduler cluster id. + uint64 cluster_id = 3 [(validate.rules).uint64 = {gte: 1}]; + + oneof request { + option (validate.required) = true; + + GNNRequest gnn_request = 4; + MLPRequest mlp_request = 5; + } +} + +// Trainer RPC Service. +service Trainer { + // Train trains models of scheduler using dataset. + rpc Train(stream TrainRequest) returns(google.protobuf.Empty); +} diff --git a/pkg/apis/trainer/v1/trainer_grpc.pb.go b/pkg/apis/trainer/v1/trainer_grpc.pb.go new file mode 100644 index 0000000..ae47756 --- /dev/null +++ b/pkg/apis/trainer/v1/trainer_grpc.pb.go @@ -0,0 +1,140 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.21.6 +// source: pkg/apis/trainer/v1/trainer.proto + +package trainer + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// TrainerClient is the client API for Trainer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type TrainerClient interface { + // Train trains models of scheduler using dataset. + Train(ctx context.Context, opts ...grpc.CallOption) (Trainer_TrainClient, error) +} + +type trainerClient struct { + cc grpc.ClientConnInterface +} + +func NewTrainerClient(cc grpc.ClientConnInterface) TrainerClient { + return &trainerClient{cc} +} + +func (c *trainerClient) Train(ctx context.Context, opts ...grpc.CallOption) (Trainer_TrainClient, error) { + stream, err := c.cc.NewStream(ctx, &Trainer_ServiceDesc.Streams[0], "/trainer.v1.Trainer/Train", opts...) + if err != nil { + return nil, err + } + x := &trainerTrainClient{stream} + return x, nil +} + +type Trainer_TrainClient interface { + Send(*TrainRequest) error + CloseAndRecv() (*emptypb.Empty, error) + grpc.ClientStream +} + +type trainerTrainClient struct { + grpc.ClientStream +} + +func (x *trainerTrainClient) Send(m *TrainRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *trainerTrainClient) CloseAndRecv() (*emptypb.Empty, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(emptypb.Empty) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// TrainerServer is the server API for Trainer service. +// All implementations should embed UnimplementedTrainerServer +// for forward compatibility +type TrainerServer interface { + // Train trains models of scheduler using dataset. + Train(Trainer_TrainServer) error +} + +// UnimplementedTrainerServer should be embedded to have forward compatible implementations. +type UnimplementedTrainerServer struct { +} + +func (UnimplementedTrainerServer) Train(Trainer_TrainServer) error { + return status.Errorf(codes.Unimplemented, "method Train not implemented") +} + +// UnsafeTrainerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to TrainerServer will +// result in compilation errors. +type UnsafeTrainerServer interface { + mustEmbedUnimplementedTrainerServer() +} + +func RegisterTrainerServer(s grpc.ServiceRegistrar, srv TrainerServer) { + s.RegisterService(&Trainer_ServiceDesc, srv) +} + +func _Trainer_Train_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TrainerServer).Train(&trainerTrainServer{stream}) +} + +type Trainer_TrainServer interface { + SendAndClose(*emptypb.Empty) error + Recv() (*TrainRequest, error) + grpc.ServerStream +} + +type trainerTrainServer struct { + grpc.ServerStream +} + +func (x *trainerTrainServer) SendAndClose(m *emptypb.Empty) error { + return x.ServerStream.SendMsg(m) +} + +func (x *trainerTrainServer) Recv() (*TrainRequest, error) { + m := new(TrainRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Trainer_ServiceDesc is the grpc.ServiceDesc for Trainer service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Trainer_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "trainer.v1.Trainer", + HandlerType: (*TrainerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Train", + Handler: _Trainer_Train_Handler, + ClientStreams: true, + }, + }, + Metadata: "pkg/apis/trainer/v1/trainer.proto", +} diff --git a/proto/trainer.proto b/proto/trainer.proto new file mode 100644 index 0000000..c531523 --- /dev/null +++ b/proto/trainer.proto @@ -0,0 +1,54 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package trainer; + +import "google/protobuf/empty.proto"; + +// GNNRequest represents gnn model request of TrainRequest. +message GNNRequest { + // Dataset of training gnn. + bytes dataset = 1; +} + +// MLPRequest represents mlp model request of TrainRequest. +message MLPRequest { + // Dataset of training mlp. + bytes dataset = 1; +} + +// TrainRequest represents request of Train. +message TrainRequest { + // Scheduler hostname. + string hostname = 1; + // Scheduler ip. + string ip = 2; + // Scheduler cluster id. + uint64 cluster_id = 3; + + oneof request { + GNNRequest gnn_request = 4; + MLPRequest mlp_request = 5; + } +} + +// Trainer RPC Service. +service Trainer { + // Train trains models of scheduler using dataset. + rpc Train(stream TrainRequest) returns(google.protobuf.Empty); +} diff --git a/src/trainer.rs b/src/trainer.rs new file mode 100644 index 0000000..0dac8a8 --- /dev/null +++ b/src/trainer.rs @@ -0,0 +1,287 @@ +/// GNNRequest represents gnn model request of TrainRequest. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GnnRequest { + /// Dataset of training gnn. + #[prost(bytes = "vec", tag = "1")] + pub dataset: ::prost::alloc::vec::Vec, +} +/// MLPRequest represents mlp model request of TrainRequest. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MlpRequest { + /// Dataset of training mlp. + #[prost(bytes = "vec", tag = "1")] + pub dataset: ::prost::alloc::vec::Vec, +} +/// TrainRequest represents request of Train. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TrainRequest { + /// Scheduler hostname. + #[prost(string, tag = "1")] + pub hostname: ::prost::alloc::string::String, + /// Scheduler ip. + #[prost(string, tag = "2")] + pub ip: ::prost::alloc::string::String, + /// Scheduler cluster id. + #[prost(uint64, tag = "3")] + pub cluster_id: u64, + #[prost(oneof = "train_request::Request", tags = "4, 5")] + pub request: ::core::option::Option, +} +/// Nested message and enum types in `TrainRequest`. +pub mod train_request { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Request { + #[prost(message, tag = "4")] + GnnRequest(super::GnnRequest), + #[prost(message, tag = "5")] + MlpRequest(super::MlpRequest), + } +} +/// Generated client implementations. +pub mod trainer_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + /// Trainer RPC Service. + #[derive(Debug, Clone)] + pub struct TrainerClient { + inner: tonic::client::Grpc, + } + impl TrainerClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: std::convert::TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl TrainerClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> TrainerClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + TrainerClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Train trains models of scheduler using dataset. + pub async fn train( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/trainer.Trainer/Train"); + self.inner + .client_streaming(request.into_streaming_request(), path, codec) + .await + } + } +} +/// Generated server implementations. +pub mod trainer_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with TrainerServer. + #[async_trait] + pub trait Trainer: Send + Sync + 'static { + /// Train trains models of scheduler using dataset. + async fn train( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status>; + } + /// Trainer RPC Service. + #[derive(Debug)] + pub struct TrainerServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + } + struct _Inner(Arc); + impl TrainerServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + } + impl tonic::codegen::Service> for TrainerServer + where + T: Trainer, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/trainer.Trainer/Train" => { + #[allow(non_camel_case_types)] + struct TrainSvc(pub Arc); + impl< + T: Trainer, + > tonic::server::ClientStreamingService + for TrainSvc { + type Response = (); + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request< + tonic::Streaming, + >, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).train(request).await }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = TrainSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ); + let res = grpc.client_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for TrainerServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(self.0.clone()) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService for TrainerServer { + const NAME: &'static str = "trainer.Trainer"; + } +}