diff --git a/integration_tests/go.mod b/integration_tests/go.mod index 1938ae0f..1651cf0d 100644 --- a/integration_tests/go.mod +++ b/integration_tests/go.mod @@ -46,7 +46,6 @@ require ( github.com/google/btree v1.1.2 // indirect github.com/google/uuid v1.3.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect - github.com/influxdata/tdigest v0.0.1 // indirect github.com/klauspost/compress v1.16.5 // indirect github.com/klauspost/cpuid v1.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -106,6 +105,9 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/tikv/client-go/v2 => ../ - -replace github.com/go-ldap/ldap/v3 => github.com/YangKeao/ldap/v3 v3.4.5-0.20230421065457-369a3bab1117 +replace ( + github.com/go-ldap/ldap/v3 => github.com/YangKeao/ldap/v3 v3.4.5-0.20230421065457-369a3bab1117 + github.com/pingcap/tidb => github.com/glorv/tidb v1.1.0-beta.0.20230609065903-a93eafb17c59 + github.com/pingcap/tidb/parser => github.com/glorv/tidb/parser v0.0.0-20230609065903-a93eafb17c59 + github.com/tikv/client-go/v2 => ../ +) diff --git a/integration_tests/go.sum b/integration_tests/go.sum index f26a4d43..cf19253c 100644 --- a/integration_tests/go.sum +++ b/integration_tests/go.sum @@ -6,13 +6,12 @@ cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGB cloud.google.com/go/iam v0.13.0 h1:+CmB+K0J/33d0zSQ9SlFWUeCCEn5XJA0ZMZ3pHE9u8k= cloud.google.com/go/storage v1.30.1 h1:uOdMxAs8HExqBlnLtnQyP0YkvbiDpdGShGKtx6U/oNM= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 h1:QkAcEIAKbNL4KoFr4SathZPhDhF4mVwpBMFlYjyAqy8= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= -github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 h1:u/LLAOFgsMv7HmNL4Qufg58y+qElGOt5qv0z1mURkRY= +github.com/Azure/azure-sdk-for-go/sdk/azcore v0.20.0 h1:KQgdWmEOmaJKxaUUZwHAYh12t+b+ZJf8q3friycK1kA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.12.0 h1:VBvHGLJbaY0+c66NZHdS9cgjHVYSH6DDa0XJMyrblsI= +github.com/Azure/azure-sdk-for-go/sdk/internal v0.8.1 h1:BUYIbDf/mMZ8945v3QkG3OuqGVyS4Iek0AOLwdRAYoc= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.2.0 h1:62Ew5xXg5UCGIXDOM7+y4IL5/6mQJq1nenhBCJAeGX8= github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= -github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.3.0 h1:Ws8e5YmnrGEHzZEzg0YvK/7COGYtTC5PbaH9oSSbgfA= github.com/BurntSushi/toml v1.3.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -129,6 +128,10 @@ github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= +github.com/glorv/tidb v1.1.0-beta.0.20230609065903-a93eafb17c59 h1:jQiiLD+Bil+D7q6XVO/3HTOMtMVV4UDuKxWZ2NkUzOI= +github.com/glorv/tidb v1.1.0-beta.0.20230609065903-a93eafb17c59/go.mod h1:yILd0+97vhHjMBvFGVLiLLE+m6b6lsZXbJrKIPpWE1s= +github.com/glorv/tidb/parser v0.0.0-20230609065903-a93eafb17c59 h1:9MRpDN9FTpI+WHCCGW6SnXrLvczJzDP8UcuMT4RIYJY= +github.com/glorv/tidb/parser v0.0.0-20230609065903-a93eafb17c59/go.mod h1:F6gt/zER0apYbw9dx1esEW1jlHTHmRi0qRk75yWH7ak= github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A= github.com/go-asn1-ber/asn1-ber v1.5.4/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= @@ -161,7 +164,6 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY1WM= -github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.1.1 h1:jxpi2eWoU84wbX9iIEyAeeoac3FLuifZpY9tcNUD9kw= github.com/golang/glog v1.1.1/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= @@ -229,8 +231,6 @@ github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHL github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/influxdata/tdigest v0.0.1 h1:XpFptwYmnEKUqmkcDjrzffswZ3nvNeevbUSLPP/ZzIY= -github.com/influxdata/tdigest v0.0.1/go.mod h1:Z0kXnxzbTC2qrx4NaIzYkE1k66+6oEDQTvL95hQFh5Y= github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI= github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0= github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62A0xJL6I+umB2YTlFRwWXaDFA0jy+5HzGiJjqI= @@ -277,7 +277,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/labstack/echo/v4 v4.1.11/go.mod h1:i541M3Fj6f76NZtHSj7TXnyM8n2gaodfvfxNnFqi74g= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80= @@ -365,13 +364,9 @@ github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22 h1:2SOzvGvE8beiC1Y4g github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb v1.1.0-beta.0.20230613044126-1dd16717fefb h1:D0o8qu/ZqGup2ezoO2sQYtThJyA25PQqlaJXeiDEpjM= -github.com/pingcap/tidb v1.1.0-beta.0.20230613044126-1dd16717fefb/go.mod h1:Uq2EbRXlDCEuVrKeTSvbCsmltsaGcD//jh1Bw5bRg4Q= -github.com/pingcap/tidb/parser v0.0.0-20230613044126-1dd16717fefb h1:BVgC0QeZjZ6fIJMqrOOeJ39I4EsY/AAO5luSq2MIzhY= -github.com/pingcap/tidb/parser v0.0.0-20230613044126-1dd16717fefb/go.mod h1:F6gt/zER0apYbw9dx1esEW1jlHTHmRi0qRk75yWH7ak= github.com/pingcap/tipb v0.0.0-20230602100112-acb7942db1ca h1:J2HQyR5v1AcoBzx5/AYJW9XFSIl6si6YoC6yGI1W89c= github.com/pingcap/tipb v0.0.0-20230602100112-acb7942db1ca/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= -github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= +github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -575,7 +570,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20230519143937-03e91628a987 h1:3xJIFvzUFbu4ls0BTBYcgbCGhA63eAOEMxIHugyXJqA= golang.org/x/exp v0.0.0-20230519143937-03e91628a987/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= @@ -696,7 +690,6 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -726,9 +719,6 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= -gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= -gonum.org/v1/netlib v0.0.0-20181029234149-ec6d1f5cefe6/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= google.golang.org/api v0.114.0 h1:1xQPji6cO2E2vLiI+C/XiFAnsn1WV3mjaEwGLhi3grE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/internal/client/client_interceptor.go b/internal/client/client_interceptor.go index 3f3f6b62..e4b31b51 100644 --- a/internal/client/client_interceptor.go +++ b/internal/client/client_interceptor.go @@ -56,7 +56,7 @@ func (r interceptedClient) SendRequest(ctx context.Context, addr string, req *ti } } if finalInterceptor != nil { - return finalInterceptor(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { + return finalInterceptor.Wrap(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { return r.Client.SendRequest(ctx, target, req, timeout) })(addr, req) } @@ -102,7 +102,7 @@ func buildResourceControlInterceptor( // Make the request info. reqInfo := resourcecontrol.MakeRequestInfo(req) // Build the interceptor. - return func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { + interceptFn := func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { consumption, penalty, err := ResourceControlInterceptor.OnRequestWait(ctx, resourceGroupName, reqInfo) if err != nil { @@ -122,4 +122,5 @@ func buildResourceControlInterceptor( return resp, err } } + return interceptor.NewRPCInterceptor("resource_control", interceptFn) } diff --git a/internal/client/client_interceptor_test.go b/internal/client/client_interceptor_test.go index 88442fe2..128fea5e 100644 --- a/internal/client/client_interceptor_test.go +++ b/internal/client/client_interceptor_test.go @@ -16,6 +16,7 @@ package client import ( "context" + "fmt" "testing" "time" @@ -41,12 +42,55 @@ func (c emptyClient) CloseAddr(addr string) error { func TestInterceptedClient(t *testing.T) { executed := false client := NewInterceptedClient(emptyClient{}, nil) - ctx := interceptor.WithRPCInterceptor(context.Background(), func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { + ctx := interceptor.WithRPCInterceptor(context.Background(), interceptor.NewRPCInterceptor("test", func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { executed = true return next(target, req) } - }) + })) _, _ = client.SendRequest(ctx, "", &tikvrpc.Request{}, 0) assert.True(t, executed) } + +func TestAppendChainedInterceptor(t *testing.T) { + executed := make([]int, 0, 10) + client := NewInterceptedClient(emptyClient{}, nil) + + mkInterceptorFn := func(i int) interceptor.RPCInterceptor { + return interceptor.NewRPCInterceptor(fmt.Sprintf("%d", i), func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { + return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { + executed = append(executed, i) + return next(target, req) + } + }) + } + + checkChained := func(it interceptor.RPCInterceptor, count int, expected []int) { + chain, ok := it.(*interceptor.RPCInterceptorChain) + assert.True(t, ok) + assert.Equal(t, chain.Len(), count) + + executed = executed[:0] + ctx := interceptor.WithRPCInterceptor(context.Background(), it) + _, _ = client.SendRequest(ctx, "", &tikvrpc.Request{}, 0) + assert.Equal(t, executed, expected) + } + + it := mkInterceptorFn(0) + expected := []int{0} + for i := 1; i < 3; i++ { + it = interceptor.ChainRPCInterceptors(it, mkInterceptorFn(i)) + expected = append(expected, i) + checkChained(it, i+1, expected) + } + + it2 := interceptor.ChainRPCInterceptors(mkInterceptorFn(3), mkInterceptorFn(4)) + checkChained(it2, 2, []int{3, 4}) + + chain := interceptor.ChainRPCInterceptors(it, it2) + checkChained(chain, 5, []int{0, 1, 2, 3, 4}) + + // add duplciated + chain = interceptor.ChainRPCInterceptors(chain, mkInterceptorFn(1)) + checkChained(chain, 5, []int{0, 2, 3, 4, 1}) +} diff --git a/tikvrpc/interceptor/interceptor.go b/tikvrpc/interceptor/interceptor.go index 9d7e5515..723cc87c 100644 --- a/tikvrpc/interceptor/interceptor.go +++ b/tikvrpc/interceptor/interceptor.go @@ -42,7 +42,7 @@ import ( // } // } // -// txn.SetRPCInterceptor(LogInterceptor) +// txn.SetRPCInterceptor(NewRPCInterceptor("log", LogInterceptor)) // ``` // // Or you want to inject some dependent modules: @@ -59,7 +59,7 @@ import ( // } // } // -// txn.SetRPCInterceptor(GetLogInterceptor()) +// txn.SetRPCInterceptor(NewRPCInterceptor("log", GetLogInterceptor())) // ``` // // NOTE: Interceptor calls may not correspond one-to-one with the underlying gRPC requests. @@ -69,7 +69,33 @@ import ( // // tikv/kv.go#NewKVStore() // internal/client/client_interceptor.go#SendRequest. -type RPCInterceptor func(next RPCInterceptorFunc) RPCInterceptorFunc +type RPCInterceptor interface { + // Name returns the name of this interceptor + Name() string + // Wrap returns a callable interecpt function. + Wrap(next RPCInterceptorFunc) RPCInterceptorFunc +} + +type rpcInterceptorWrapper struct { + name string + fn func(next RPCInterceptorFunc) RPCInterceptorFunc +} + +func (i *rpcInterceptorWrapper) Name() string { + return i.name +} + +func (i *rpcInterceptorWrapper) Wrap(next RPCInterceptorFunc) RPCInterceptorFunc { + return i.fn(next) +} + +// NewRPCInterceptor build a RPCInterceptor by its name and intercept func. +func NewRPCInterceptor(name string, fn func(next RPCInterceptorFunc) RPCInterceptorFunc) RPCInterceptor { + return &rpcInterceptorWrapper{ + name: name, + fn: fn, + } +} // RPCInterceptorFunc is a callable function used to initiate a request to TiKV. // It is mainly used as the parameter and return value of RPCInterceptor. @@ -80,6 +106,9 @@ type RPCInterceptorFunc func(target string, req *tikvrpc.Request) (*tikvrpc.Resp // similar to the onion model: The earlier the interceptor is executed, the later // it will return. // +// If multiple interceptors with the same name is added to the chain, only the last +// will be kept. +// // We can use RPCInterceptorChain like this: // ``` // @@ -99,7 +128,8 @@ type RPCInterceptorFunc func(target string, req *tikvrpc.Request) (*tikvrpc.Resp // } // } // -// txn.SetRPCInterceptor(NewRPCInterceptorChain().Link(Interceptor1).Link(Interceptor2).Build()) +// txn.SetRPCInterceptor(NewRPCInterceptorChain().Link(NewRPCInterceptor("log1", Interceptor1)).Link(NewRPCInterceptor("log2", Interceptor2)).Build()) +// // ``` // // Then every time an RPC request is initiated, the following text will be printed: @@ -114,6 +144,11 @@ type RPCInterceptorChain struct { chain []RPCInterceptor } +// return the number of sub interceptors, used for test. +func (c *RPCInterceptorChain) Len() int { + return len(c.chain) +} + // NewRPCInterceptorChain creates an empty RPCInterceptorChain. func NewRPCInterceptorChain() *RPCInterceptorChain { return &RPCInterceptorChain{} @@ -121,30 +156,52 @@ func NewRPCInterceptorChain() *RPCInterceptorChain { // Link is used to link the next RPCInterceptor. // Multiple interceptors will be executed in the order of link time. +// If multiple interceptors with the same name is added to the chain, +// only the last is kept. func (c *RPCInterceptorChain) Link(it RPCInterceptor) *RPCInterceptorChain { + if chain, ok := it.(*RPCInterceptorChain); ok { + for _, i := range chain.chain { + c.Link(i) + } + return c + } + for i := range c.chain { + if c.chain[i].Name() == it.Name() { + c.chain = append(c.chain[:i], c.chain[i+1:]...) + break + } + } c.chain = append(c.chain, it) return c } -// Build merges the previously linked interceptors into one. -func (c *RPCInterceptorChain) Build() RPCInterceptor { - return func(next RPCInterceptorFunc) RPCInterceptorFunc { - for n := len(c.chain) - 1; n >= 0; n-- { - next = c.chain[n](next) - } - return next +func (c *RPCInterceptorChain) Name() string { + return "interceptor-chain" +} + +func (c *RPCInterceptorChain) Wrap(next RPCInterceptorFunc) RPCInterceptorFunc { + for n := len(c.chain) - 1; n >= 0; n-- { + next = c.chain[n].Wrap(next) } + return next } // ChainRPCInterceptors chains multiple RPCInterceptors into one. // Multiple RPCInterceptors will be executed in the order of their parameters. // See RPCInterceptorChain for more information. -func ChainRPCInterceptors(its ...RPCInterceptor) RPCInterceptor { - chain := NewRPCInterceptorChain() - for _, it := range its { +func ChainRPCInterceptors(first RPCInterceptor, rest ...RPCInterceptor) RPCInterceptor { + var chain *RPCInterceptorChain + if ch, ok := first.(*RPCInterceptorChain); ok { + chain = ch + } else { + chain = NewRPCInterceptorChain() + chain.Link(first) + } + + for _, it := range rest { chain.Link(it) } - return chain.Build() + return chain } type interceptorCtxKeyType struct{} @@ -182,7 +239,7 @@ func NewMockInterceptorManager() *MockInterceptorManager { // CreateMockInterceptor creates an RPCInterceptor for testing. func (m *MockInterceptorManager) CreateMockInterceptor(name string) RPCInterceptor { - return func(next RPCInterceptorFunc) RPCInterceptorFunc { + fn := func(next RPCInterceptorFunc) RPCInterceptorFunc { return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { m.execLog = append(m.execLog, name) atomic.AddInt32(&m.begin, 1) @@ -190,6 +247,7 @@ func (m *MockInterceptorManager) CreateMockInterceptor(name string) RPCIntercept return next(target, req) } } + return NewRPCInterceptor(name, fn) } // Reset clear all counters. diff --git a/tikvrpc/interceptor/interceptor_test.go b/tikvrpc/interceptor/interceptor_test.go index c54ea711..92530733 100644 --- a/tikvrpc/interceptor/interceptor_test.go +++ b/tikvrpc/interceptor/interceptor_test.go @@ -26,9 +26,8 @@ func TestInterceptor(t *testing.T) { manager := MockInterceptorManager{} it := chain. Link(manager.CreateMockInterceptor("INTERCEPTOR-1")). - Link(manager.CreateMockInterceptor("INTERCEPTOR-2")). - Build() - _, _ = it(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { + Link(manager.CreateMockInterceptor("INTERCEPTOR-2")) + _, _ = it.Wrap(func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { return nil, nil })("", nil) assert.Equal(t, 2, manager.BeginCount()) diff --git a/txnkv/txnsnapshot/snapshot.go b/txnkv/txnsnapshot/snapshot.go index b4671d6d..ef293e6b 100644 --- a/txnkv/txnsnapshot/snapshot.go +++ b/txnkv/txnsnapshot/snapshot.go @@ -903,11 +903,12 @@ func (s *KVSnapshot) SetRPCInterceptor(it interceptor.RPCInterceptor) { } // AddRPCInterceptor adds an interceptor, the order of addition is the order of execution. +// the chained interceptors will be dedupcated by its name. func (s *KVSnapshot) AddRPCInterceptor(it interceptor.RPCInterceptor) { s.mu.Lock() defer s.mu.Unlock() if s.mu.interceptor == nil { - s.SetRPCInterceptor(it) + s.mu.interceptor = it return } s.mu.interceptor = interceptor.ChainRPCInterceptors(s.mu.interceptor, it)