diff --git a/pkg/rpc/health/client/client.go b/pkg/rpc/health/client/client.go index d40d92a97..686415da4 100644 --- a/pkg/rpc/health/client/client.go +++ b/pkg/rpc/health/client/client.go @@ -38,7 +38,7 @@ const ( contextTimeout = 2 * time.Second ) -// GetClient returns manager client. +// GetClient returns health client. func GetClient(ctx context.Context, target string, opts ...grpc.DialOption) (Client, error) { conn, err := grpc.DialContext( ctx, @@ -90,7 +90,7 @@ type Client interface { Close() error } -// client provides manager grpc function. +// client provides health grpc function. type client struct { healthpb.HealthClient *grpc.ClientConn diff --git a/pkg/rpc/security/client/client_v1.go b/pkg/rpc/security/client/client_v1.go index 73a33ac92..aad1f7905 100644 --- a/pkg/rpc/security/client/client_v1.go +++ b/pkg/rpc/security/client/client_v1.go @@ -81,18 +81,18 @@ func GetV1(ctx context.Context, target string, opts ...grpc.DialOption) (V1, err }, nil } -// GetClientV1ByAddr returns v1 version of the manager client with addresses. +// GetClientV1ByAddr returns v1 version of the security client with addresses. func GetV1ByAddr(ctx context.Context, netAddrs []dfnet.NetAddr, opts ...grpc.DialOption) (V1, error) { for _, netAddr := range netAddrs { ipReachable := reachable.New(&reachable.Config{Address: netAddr.Addr}) if err := ipReachable.Check(); err == nil { - logger.Infof("use %s address for manager grpc client", netAddr.Addr) + logger.Infof("use %s address for security grpc client", netAddr.Addr) return GetV1(ctx, netAddr.Addr, opts...) } - logger.Warnf("%s manager address can not reachable", netAddr.Addr) + logger.Warnf("%s security address can not reachable", netAddr.Addr) } - return nil, errors.New("can not find available manager addresses") + return nil, errors.New("can not find available security addresses") } // ClientV1 is the interface for v1 version of the grpc client. @@ -104,7 +104,7 @@ type V1 interface { Close() error } -// clientV1 provides v1 version of the manager grpc function. +// clientV1 provides v1 version of the security grpc function. type v1 struct { securityv1.CertificateClient *grpc.ClientConn diff --git a/pkg/rpc/tfserving/client/client_v1.go b/pkg/rpc/tfserving/client/client_v1.go new file mode 100644 index 000000000..f3115e837 --- /dev/null +++ b/pkg/rpc/tfserving/client/client_v1.go @@ -0,0 +1,102 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//go:generate mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks + +package client + +import ( + "context" + "time" + + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" + + tfservingv1 "d7y.io/api/pkg/apis/tfserving/v1" + + logger "d7y.io/dragonfly/v2/internal/dflog" +) + +const ( + // contextTimeout is timeout of grpc invoke. + contextTimeout = 2 * time.Minute + + // maxRetries is maximum number of retries. + maxRetries = 3 + + // backoffWaitBetween is waiting for a fixed period of + // time between calls in backoff linear. + backoffWaitBetween = 500 * time.Millisecond +) + +// GetV1 returns v1 version of the prediction client. +func GetV1(ctx context.Context, target string, opts ...grpc.DialOption) (V1, error) { + conn, err := grpc.DialContext( + ctx, + target, + append([]grpc.DialOption{ + grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient( + otelgrpc.UnaryClientInterceptor(), + grpc_prometheus.UnaryClientInterceptor, + grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()), + grpc_retry.UnaryClientInterceptor( + grpc_retry.WithMax(maxRetries), + grpc_retry.WithBackoff(grpc_retry.BackoffLinear(backoffWaitBetween)), + ), + )), + grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient( + otelgrpc.StreamClientInterceptor(), + grpc_prometheus.StreamClientInterceptor, + grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()), + )), + }, opts...)..., + ) + if err != nil { + return nil, err + } + + return &v1{ + PredictionServiceClient: tfservingv1.NewPredictionServiceClient(conn), + ClientConn: conn, + }, nil +} + +// ClientV1 is the interface for v1 version of the grpc client. +type V1 interface { + // Predict provides access to loaded TensorFlow model. + Predict(context.Context, *tfservingv1.PredictRequest, ...grpc.CallOption) (*tfservingv1.PredictResponse, error) + + // Close tears down the ClientConn and all underlying connections. + Close() error +} + +// clientV1 provides v1 version of the prediction grpc function. +type v1 struct { + tfservingv1.PredictionServiceClient + *grpc.ClientConn +} + +// Predict provides access to loaded TensorFlow model. +func (v *v1) Predict(ctx context.Context, req *tfservingv1.PredictRequest, opts ...grpc.CallOption) (*tfservingv1.PredictResponse, error) { + ctx, cancel := context.WithTimeout(ctx, contextTimeout) + defer cancel() + + return v.PredictionServiceClient.Predict(ctx, req, opts...) +} diff --git a/pkg/rpc/tfserving/client/mocks/client_v1_mock.go b/pkg/rpc/tfserving/client/mocks/client_v1_mock.go new file mode 100644 index 000000000..5929d96b4 --- /dev/null +++ b/pkg/rpc/tfserving/client/mocks/client_v1_mock.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client_v1.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + tfserving "d7y.io/api/pkg/apis/tfserving/v1" + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" +) + +// MockV1 is a mock of V1 interface. +type MockV1 struct { + ctrl *gomock.Controller + recorder *MockV1MockRecorder +} + +// MockV1MockRecorder is the mock recorder for MockV1. +type MockV1MockRecorder struct { + mock *MockV1 +} + +// NewMockV1 creates a new mock instance. +func NewMockV1(ctrl *gomock.Controller) *MockV1 { + mock := &MockV1{ctrl: ctrl} + mock.recorder = &MockV1MockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockV1) EXPECT() *MockV1MockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockV1) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockV1MockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockV1)(nil).Close)) +} + +// Predict mocks base method. +func (m *MockV1) Predict(arg0 context.Context, arg1 *tfserving.PredictRequest, arg2 ...grpc.CallOption) (*tfserving.PredictResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Predict", varargs...) + ret0, _ := ret[0].(*tfserving.PredictResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Predict indicates an expected call of Predict. +func (mr *MockV1MockRecorder) Predict(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Predict", reflect.TypeOf((*MockV1)(nil).Predict), varargs...) +}