diff --git a/pkg/rpc/scheduler/client/client_v2.go b/pkg/rpc/scheduler/client/client_v2.go index 1c7debd76..5daf422d5 100644 --- a/pkg/rpc/scheduler/client/client_v2.go +++ b/pkg/rpc/scheduler/client/client_v2.go @@ -53,7 +53,6 @@ func GetV2(ctx context.Context, dynconfig config.Dynconfig, opts ...grpc.DialOpt append([]grpc.DialOption{ grpc.WithDefaultServiceConfig(pkgbalancer.BalancerServiceConfig), grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( - rpc.ConvertErrorUnaryClientInterceptor, otelgrpc.UnaryClientInterceptor(), grpc_prometheus.UnaryClientInterceptor, grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), @@ -64,7 +63,6 @@ func GetV2(ctx context.Context, dynconfig config.Dynconfig, opts ...grpc.DialOpt rpc.RefresherUnaryClientInterceptor(dynconfig), )), grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( - rpc.ConvertErrorStreamClientInterceptor, otelgrpc.StreamClientInterceptor(), grpc_prometheus.StreamClientInterceptor, grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), @@ -93,7 +91,6 @@ func GetV2ByAddr(ctx context.Context, target string, opts ...grpc.DialOption) (V append([]grpc.DialOption{ grpc.WithDefaultServiceConfig(pkgbalancer.BalancerServiceConfig), grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( - rpc.ConvertErrorUnaryClientInterceptor, otelgrpc.UnaryClientInterceptor(), grpc_prometheus.UnaryClientInterceptor, grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), @@ -103,7 +100,6 @@ func GetV2ByAddr(ctx context.Context, target string, opts ...grpc.DialOption) (V ), )), grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( - rpc.ConvertErrorStreamClientInterceptor, otelgrpc.StreamClientInterceptor(), grpc_prometheus.StreamClientInterceptor, grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), diff --git a/pkg/rpc/trainer/client/client_v1.go b/pkg/rpc/trainer/client/client_v1.go new file mode 100644 index 000000000..56536f12f --- /dev/null +++ b/pkg/rpc/trainer/client/client_v1.go @@ -0,0 +1,96 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//go:generate mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks + +package client + +import ( + "context" + "time" + + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" + + trainerv1 "d7y.io/api/pkg/apis/trainer/v1" + + logger "d7y.io/dragonfly/v2/internal/dflog" +) + +const ( + // maxRetries is maximum number of retries. + maxRetries = 3 + + // backoffWaitBetween is waiting for a fixed period of + // time between calls in backoff linear. + backoffWaitBetween = 500 * time.Millisecond +) + +// GetV1ByAddr returns v1 version of the trainer client by address. +func GetV1ByAddr(ctx context.Context, target string, opts ...grpc.DialOption) (V1, error) { + conn, err := grpc.DialContext( + ctx, + target, + append([]grpc.DialOption{ + grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( + otelgrpc.UnaryClientInterceptor(), + grpc_prometheus.UnaryClientInterceptor, + grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), + grpc_retry.UnaryClientInterceptor( + grpc_retry.WithMax(maxRetries), + grpc_retry.WithBackoff(grpc_retry.BackoffLinear(backoffWaitBetween)), + ), + )), + grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( + otelgrpc.StreamClientInterceptor(), + grpc_prometheus.StreamClientInterceptor, + grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), + )), + }, opts...)..., + ) + if err != nil { + return nil, err + } + + return &v1{ + TrainerClient: trainerv1.NewTrainerClient(conn), + ClientConn: conn, + }, nil +} + +// V1 is the interface for v1 version of the grpc client. +type V1 interface { + // Train models of scheduler using dataset. + Train(context.Context, ...grpc.CallOption) (trainerv1.Trainer_TrainClient, error) + + // Close tears down the ClientConn and all underlying connections. + Close() error +} + +// v1 provides v1 version of the trainer grpc function. +type v1 struct { + trainerv1.TrainerClient + *grpc.ClientConn +} + +// Train models of scheduler using dataset. +func (v *v1) Train(ctx context.Context, opts ...grpc.CallOption) (trainerv1.Trainer_TrainClient, error) { + return v.TrainerClient.Train(ctx, opts...) +} diff --git a/pkg/rpc/trainer/client/mocks/client_v1_mock.go b/pkg/rpc/trainer/client/mocks/client_v1_mock.go new file mode 100644 index 000000000..80fbe31e0 --- /dev/null +++ b/pkg/rpc/trainer/client/mocks/client_v1_mock.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client_v1.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + trainer "d7y.io/api/pkg/apis/trainer/v1" + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" +) + +// MockV1 is a mock of V1 interface. +type MockV1 struct { + ctrl *gomock.Controller + recorder *MockV1MockRecorder +} + +// MockV1MockRecorder is the mock recorder for MockV1. +type MockV1MockRecorder struct { + mock *MockV1 +} + +// NewMockV1 creates a new mock instance. +func NewMockV1(ctrl *gomock.Controller) *MockV1 { + mock := &MockV1{ctrl: ctrl} + mock.recorder = &MockV1MockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockV1) EXPECT() *MockV1MockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockV1) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockV1MockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockV1)(nil).Close)) +} + +// Train mocks base method. +func (m *MockV1) Train(arg0 context.Context, arg1 ...grpc.CallOption) (trainer.Trainer_TrainClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Train", varargs...) + ret0, _ := ret[0].(trainer.Trainer_TrainClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Train indicates an expected call of Train. +func (mr *MockV1MockRecorder) Train(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Train", reflect.TypeOf((*MockV1)(nil).Train), varargs...) +} diff --git a/pkg/rpc/trainer/server/server.go b/pkg/rpc/trainer/server/server.go new file mode 100644 index 000000000..0be92fb3b --- /dev/null +++ b/pkg/rpc/trainer/server/server.go @@ -0,0 +1,95 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "time" + + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" + grpc_ratelimit "github.com/grpc-ecosystem/go-grpc-middleware/ratelimit" + grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" + grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/reflection" + + trainerv1 "d7y.io/api/pkg/apis/trainer/v1" + + logger "d7y.io/dragonfly/v2/internal/dflog" + "d7y.io/dragonfly/v2/pkg/rpc" +) + +const ( + // DefaultQPS is default qps of grpc server. + DefaultQPS = 10 * 1000 + + // DefaultBurst is default burst of grpc server. + DefaultBurst = 20 * 1000 + + // DefaultMaxConnectionIdle is default max connection idle of grpc keepalive. + DefaultMaxConnectionIdle = 10 * time.Minute + + // DefaultMaxConnectionAge is default max connection age of grpc keepalive. + DefaultMaxConnectionAge = 12 * time.Hour + + // DefaultMaxConnectionAgeGrace is default max connection age grace of grpc keepalive. + DefaultMaxConnectionAgeGrace = 5 * time.Minute +) + +// New returns grpc server instance and register service on grpc server. +func New(trainerServerV1 trainerv1.TrainerServer, opts ...grpc.ServerOption) *grpc.Server { + limiter := rpc.NewRateLimiterInterceptor(DefaultQPS, DefaultBurst) + + grpcServer := grpc.NewServer(append([]grpc.ServerOption{ + grpc.KeepaliveParams(keepalive.ServerParameters{ + MaxConnectionIdle: DefaultMaxConnectionIdle, + MaxConnectionAge: DefaultMaxConnectionAge, + MaxConnectionAgeGrace: DefaultMaxConnectionAgeGrace, + }), + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + grpc_ratelimit.UnaryServerInterceptor(limiter), + otelgrpc.UnaryServerInterceptor(), + grpc_prometheus.UnaryServerInterceptor, + grpc_zap.UnaryServerInterceptor(logger.GrpcLogger.Desugar()), + grpc_validator.UnaryServerInterceptor(), + grpc_recovery.UnaryServerInterceptor(), + )), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + grpc_ratelimit.StreamServerInterceptor(limiter), + otelgrpc.StreamServerInterceptor(), + grpc_prometheus.StreamServerInterceptor, + grpc_zap.StreamServerInterceptor(logger.GrpcLogger.Desugar()), + grpc_validator.StreamServerInterceptor(), + grpc_recovery.StreamServerInterceptor(), + )), + }, opts...)...) + + // Register servers on v1 version of the grpc server. + trainerv1.RegisterTrainerServer(grpcServer, trainerServerV1) + + // Register health on grpc server. + healthpb.RegisterHealthServer(grpcServer, health.NewServer()) + + // Register reflection on grpc server. + reflection.Register(grpcServer) + return grpcServer +}