feat: add tfserving service to rpc package (#2210)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2023-03-22 20:38:57 +08:00
parent f8934e701b
commit 11c68c5879
No known key found for this signature in database
GPG Key ID: 8B4E5D1290FA2FFB
4 changed files with 180 additions and 7 deletions

View File

@ -38,7 +38,7 @@ const (
contextTimeout = 2 * time.Second contextTimeout = 2 * time.Second
) )
// GetClient returns manager client. // GetClient returns health client.
func GetClient(ctx context.Context, target string, opts ...grpc.DialOption) (Client, error) { func GetClient(ctx context.Context, target string, opts ...grpc.DialOption) (Client, error) {
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
ctx, ctx,
@ -90,7 +90,7 @@ type Client interface {
Close() error Close() error
} }
// client provides manager grpc function. // client provides health grpc function.
type client struct { type client struct {
healthpb.HealthClient healthpb.HealthClient
*grpc.ClientConn *grpc.ClientConn

View File

@ -81,18 +81,18 @@ func GetV1(ctx context.Context, target string, opts ...grpc.DialOption) (V1, err
}, nil }, nil
} }
// GetClientV1ByAddr returns v1 version of the manager client with addresses. // GetClientV1ByAddr returns v1 version of the security client with addresses.
func GetV1ByAddr(ctx context.Context, netAddrs []dfnet.NetAddr, opts ...grpc.DialOption) (V1, error) { func GetV1ByAddr(ctx context.Context, netAddrs []dfnet.NetAddr, opts ...grpc.DialOption) (V1, error) {
for _, netAddr := range netAddrs { for _, netAddr := range netAddrs {
ipReachable := reachable.New(&reachable.Config{Address: netAddr.Addr}) ipReachable := reachable.New(&reachable.Config{Address: netAddr.Addr})
if err := ipReachable.Check(); err == nil { if err := ipReachable.Check(); err == nil {
logger.Infof("use %s address for manager grpc client", netAddr.Addr) logger.Infof("use %s address for security grpc client", netAddr.Addr)
return GetV1(ctx, netAddr.Addr, opts...) return GetV1(ctx, netAddr.Addr, opts...)
} }
logger.Warnf("%s manager address can not reachable", netAddr.Addr) logger.Warnf("%s security address can not reachable", netAddr.Addr)
} }
return nil, errors.New("can not find available manager addresses") return nil, errors.New("can not find available security addresses")
} }
// ClientV1 is the interface for v1 version of the grpc client. // ClientV1 is the interface for v1 version of the grpc client.
@ -104,7 +104,7 @@ type V1 interface {
Close() error Close() error
} }
// clientV1 provides v1 version of the manager grpc function. // clientV1 provides v1 version of the security grpc function.
type v1 struct { type v1 struct {
securityv1.CertificateClient securityv1.CertificateClient
*grpc.ClientConn *grpc.ClientConn

View File

@ -0,0 +1,102 @@
/*
* Copyright 2023 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//go:generate mockgen -destination mocks/client_v1_mock.go -source client_v1.go -package mocks
package client
import (
"context"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
tfservingv1 "d7y.io/api/pkg/apis/tfserving/v1"
logger "d7y.io/dragonfly/v2/internal/dflog"
)
const (
// contextTimeout is timeout of grpc invoke.
contextTimeout = 2 * time.Minute
// maxRetries is maximum number of retries.
maxRetries = 3
// backoffWaitBetween is waiting for a fixed period of
// time between calls in backoff linear.
backoffWaitBetween = 500 * time.Millisecond
)
// GetV1 returns v1 version of the prediction client.
func GetV1(ctx context.Context, target string, opts ...grpc.DialOption) (V1, error) {
conn, err := grpc.DialContext(
ctx,
target,
append([]grpc.DialOption{
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
otelgrpc.UnaryClientInterceptor(),
grpc_prometheus.UnaryClientInterceptor,
grpc_zap.UnaryClientInterceptor(logger.GrpcLogger.Desugar()),
grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(maxRetries),
grpc_retry.WithBackoff(grpc_retry.BackoffLinear(backoffWaitBetween)),
),
)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(
otelgrpc.StreamClientInterceptor(),
grpc_prometheus.StreamClientInterceptor,
grpc_zap.StreamClientInterceptor(logger.GrpcLogger.Desugar()),
)),
}, opts...)...,
)
if err != nil {
return nil, err
}
return &v1{
PredictionServiceClient: tfservingv1.NewPredictionServiceClient(conn),
ClientConn: conn,
}, nil
}
// ClientV1 is the interface for v1 version of the grpc client.
type V1 interface {
// Predict provides access to loaded TensorFlow model.
Predict(context.Context, *tfservingv1.PredictRequest, ...grpc.CallOption) (*tfservingv1.PredictResponse, error)
// Close tears down the ClientConn and all underlying connections.
Close() error
}
// clientV1 provides v1 version of the prediction grpc function.
type v1 struct {
tfservingv1.PredictionServiceClient
*grpc.ClientConn
}
// Predict provides access to loaded TensorFlow model.
func (v *v1) Predict(ctx context.Context, req *tfservingv1.PredictRequest, opts ...grpc.CallOption) (*tfservingv1.PredictResponse, error) {
ctx, cancel := context.WithTimeout(ctx, contextTimeout)
defer cancel()
return v.PredictionServiceClient.Predict(ctx, req, opts...)
}

View File

@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: client_v1.go
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
reflect "reflect"
tfserving "d7y.io/api/pkg/apis/tfserving/v1"
gomock "github.com/golang/mock/gomock"
grpc "google.golang.org/grpc"
)
// MockV1 is a mock of V1 interface.
type MockV1 struct {
ctrl *gomock.Controller
recorder *MockV1MockRecorder
}
// MockV1MockRecorder is the mock recorder for MockV1.
type MockV1MockRecorder struct {
mock *MockV1
}
// NewMockV1 creates a new mock instance.
func NewMockV1(ctrl *gomock.Controller) *MockV1 {
mock := &MockV1{ctrl: ctrl}
mock.recorder = &MockV1MockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockV1) EXPECT() *MockV1MockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockV1) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockV1MockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockV1)(nil).Close))
}
// Predict mocks base method.
func (m *MockV1) Predict(arg0 context.Context, arg1 *tfserving.PredictRequest, arg2 ...grpc.CallOption) (*tfserving.PredictResponse, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Predict", varargs...)
ret0, _ := ret[0].(*tfserving.PredictResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Predict indicates an expected call of Predict.
func (mr *MockV1MockRecorder) Predict(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Predict", reflect.TypeOf((*MockV1)(nil).Predict), varargs...)
}