client-go/tikvrpc/interceptor/interceptor.go

272 lines
8.5 KiB
Go

// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"sync/atomic"
"github.com/tikv/client-go/v2/tikvrpc"
)
// RPCInterceptor is used to decorate the RPC requests to TiKV.
//
// The definition of an interceptor is: Given an RPCInterceptorFunc, we will
// get the decorated RPCInterceptorFunc with additional logic before and after
// the execution of the given RPCInterceptorFunc.
//
// The decorated RPCInterceptorFunc will be executed before and after the real
// RPC request is initiated to TiKV.
//
// We can implement an RPCInterceptor like this:
// ```
//
// func LogInterceptor(next InterceptorFunc) RPCInterceptorFunc {
// return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
// log.Println("before")
// resp, err := next(target, req)
// log.Println("after")
// return resp, err
// }
// }
//
// txn.SetRPCInterceptor(NewRPCInterceptor("log", LogInterceptor))
// ```
//
// Or you want to inject some dependent modules:
// ```
//
// func GetLogInterceptor(lg *log.Logger) RPCInterceptor {
// return func(next RPCInterceptorFunc) RPCInterceptorFunc {
// return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
// lg.Println("before")
// resp, err := next(target, req)
// lg.Println("after")
// return resp, err
// }
// }
// }
//
// txn.SetRPCInterceptor(NewRPCInterceptor("log", GetLogInterceptor()))
// ```
//
// NOTE: Interceptor calls may not correspond one-to-one with the underlying gRPC requests.
// This is because there may be some exceptions, such as: request batched, no
// valid connection etc. If you have questions about the execution location of
// RPCInterceptor, please refer to:
//
// tikv/kv.go#NewKVStore()
// internal/client/client_interceptor.go#SendRequest.
type RPCInterceptor interface {
// Name returns the name of this interceptor
Name() string
// Wrap returns a callable interecpt function.
Wrap(next RPCInterceptorFunc) RPCInterceptorFunc
}
type rpcInterceptorWrapper struct {
name string
fn func(next RPCInterceptorFunc) RPCInterceptorFunc
}
func (i *rpcInterceptorWrapper) Name() string {
return i.name
}
func (i *rpcInterceptorWrapper) Wrap(next RPCInterceptorFunc) RPCInterceptorFunc {
return i.fn(next)
}
// NewRPCInterceptor build a RPCInterceptor by its name and intercept func.
func NewRPCInterceptor(name string, fn func(next RPCInterceptorFunc) RPCInterceptorFunc) RPCInterceptor {
return &rpcInterceptorWrapper{
name: name,
fn: fn,
}
}
// RPCInterceptorFunc is a callable function used to initiate a request to TiKV.
// It is mainly used as the parameter and return value of RPCInterceptor.
type RPCInterceptorFunc func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error)
// RPCInterceptorChain is used to combine multiple interceptors into one.
// Multiple interceptors will be executed in the order of link time, but are more
// similar to the onion model: The earlier the interceptor is executed, the later
// it will return.
//
// If multiple interceptors with the same name is added to the chain, only the last
// will be kept.
//
// We can use RPCInterceptorChain like this:
// ```
//
// func Interceptor1(next InterceptorFunc) RPCInterceptorFunc {
// return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
// fmt.Println("begin-interceptor-1")
// defer fmt.Println("end-interceptor-1")
// return next(target, req)
// }
// }
//
// func Interceptor2(next InterceptorFunc) RPCInterceptorFunc {
// return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
// fmt.Println("begin-interceptor-2")
// defer fmt.Println("end-interceptor-2")
// return next(target, req)
// }
// }
//
// txn.SetRPCInterceptor(NewRPCInterceptorChain().Link(NewRPCInterceptor("log1", Interceptor1)).Link(NewRPCInterceptor("log2", Interceptor2)).Build())
//
// ```
//
// Then every time an RPC request is initiated, the following text will be printed:
// ```
// begin-interceptor-1
// begin-interceptor-2
// /* do request & respond here */
// end-interceptor-2
// end-interceptor-1
// ```
type RPCInterceptorChain struct {
chain []RPCInterceptor
}
// return the number of sub interceptors, used for test.
func (c *RPCInterceptorChain) Len() int {
return len(c.chain)
}
// NewRPCInterceptorChain creates an empty RPCInterceptorChain.
func NewRPCInterceptorChain() *RPCInterceptorChain {
return &RPCInterceptorChain{}
}
// Link is used to link the next RPCInterceptor.
// Multiple interceptors will be executed in the order of link time.
// If multiple interceptors with the same name is added to the chain,
// only the last is kept.
func (c *RPCInterceptorChain) Link(it RPCInterceptor) *RPCInterceptorChain {
if chain, ok := it.(*RPCInterceptorChain); ok {
for _, i := range chain.chain {
c.Link(i)
}
return c
}
for i := range c.chain {
if c.chain[i].Name() == it.Name() {
c.chain = append(c.chain[:i], c.chain[i+1:]...)
break
}
}
c.chain = append(c.chain, it)
return c
}
func (c *RPCInterceptorChain) Name() string {
return "interceptor-chain"
}
func (c *RPCInterceptorChain) Wrap(next RPCInterceptorFunc) RPCInterceptorFunc {
for n := len(c.chain) - 1; n >= 0; n-- {
next = c.chain[n].Wrap(next)
}
return next
}
// ChainRPCInterceptors chains multiple RPCInterceptors into one.
// Multiple RPCInterceptors will be executed in the order of their parameters.
// See RPCInterceptorChain for more information.
func ChainRPCInterceptors(first RPCInterceptor, rest ...RPCInterceptor) RPCInterceptor {
// always build a new interceptor to avoid potential data race
chain := NewRPCInterceptorChain()
chain.Link(first)
for _, it := range rest {
chain.Link(it)
}
return chain
}
type interceptorCtxKeyType struct{}
var interceptorCtxKey = interceptorCtxKeyType{}
// WithRPCInterceptor is a helper function used to bind RPCInterceptor with ctx.
func WithRPCInterceptor(ctx context.Context, interceptor RPCInterceptor) context.Context {
if v := ctx.Value(interceptorCtxKey); v != nil {
interceptor = ChainRPCInterceptors(v.(RPCInterceptor), interceptor)
}
return context.WithValue(ctx, interceptorCtxKey, interceptor)
}
// GetRPCInterceptorFromCtx gets the RPCInterceptor bound by the previous call
// to WithRPCInterceptor, and returns nil if there is none.
func GetRPCInterceptorFromCtx(ctx context.Context) RPCInterceptor {
if v := ctx.Value(interceptorCtxKey); v != nil {
return v.(RPCInterceptor)
}
return nil
}
/* Suite for testing */
// MockInterceptorManager can be used to create Interceptor and record the
// number of executions of the created Interceptor.
type MockInterceptorManager struct {
begin int32
end int32
execLog []string // interceptor name list
}
// NewMockInterceptorManager creates an empty MockInterceptorManager.
func NewMockInterceptorManager() *MockInterceptorManager {
return &MockInterceptorManager{execLog: []string{}}
}
// CreateMockInterceptor creates an RPCInterceptor for testing.
func (m *MockInterceptorManager) CreateMockInterceptor(name string) RPCInterceptor {
fn := func(next RPCInterceptorFunc) RPCInterceptorFunc {
return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
m.execLog = append(m.execLog, name)
atomic.AddInt32(&m.begin, 1)
defer atomic.AddInt32(&m.end, 1)
return next(target, req)
}
}
return NewRPCInterceptor(name, fn)
}
// Reset clear all counters.
func (m *MockInterceptorManager) Reset() {
atomic.StoreInt32(&m.begin, 0)
atomic.StoreInt32(&m.end, 0)
m.execLog = []string{}
}
// BeginCount gets how many times the previously created Interceptor has been executed.
func (m *MockInterceptorManager) BeginCount() int {
return int(atomic.LoadInt32(&m.begin))
}
// EndCount gets how many times the previously created Interceptor has been returned.
func (m *MockInterceptorManager) EndCount() int {
return int(atomic.LoadInt32(&m.end))
}
// ExecLog gets execution log of all interceptors.
func (m *MockInterceptorManager) ExecLog() []string {
return m.execLog
}