diff --git a/config/config.go b/config/config.go index bd534f00..ee2f992a 100644 --- a/config/config.go +++ b/config/config.go @@ -1,5 +1,26 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package config +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + + "github.com/pkg/errors" +) + // Security is SSL configuration. type Security struct { SSLCA string `toml:"ssl-ca" json:"ssl-ca"` @@ -7,5 +28,40 @@ type Security struct { SSLKey string `toml:"ssl-key" json:"ssl-key"` } +// ToTLSConfig generates tls's config based on security section of the config. +func (s *Security) ToTLSConfig() (*tls.Config, error) { + var tlsConfig *tls.Config + if len(s.SSLCA) != 0 { + var certificates = make([]tls.Certificate, 0) + if len(s.SSLCert) != 0 && len(s.SSLKey) != 0 { + // Load the client certificates from disk + certificate, err := tls.LoadX509KeyPair(s.SSLCert, s.SSLKey) + if err != nil { + return nil, errors.Errorf("could not load client key pair: %s", err) + } + certificates = append(certificates, certificate) + } + + // Create a certificate pool from the certificate authority + certPool := x509.NewCertPool() + ca, err := ioutil.ReadFile(s.SSLCA) + if err != nil { + return nil, errors.Errorf("could not read ca certificate: %s", err) + } + + // Append the certificates from the CA + if !certPool.AppendCertsFromPEM(ca) { + return nil, errors.New("failed to append ca certs") + } + + tlsConfig = &tls.Config{ + Certificates: certificates, + RootCAs: certPool, + } + } + + return tlsConfig, nil +} + // EnableOpenTracing is the flag to enable open tracing. var EnableOpenTracing = false diff --git a/go.mod b/go.mod index 4c75c9d0..3522aed6 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,18 @@ require ( github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect github.com/golang/protobuf v1.2.0 // indirect github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c + github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/opentracing/opentracing-go v1.0.2 // indirect github.com/pingcap/errors v0.11.0 github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c github.com/pingcap/pd v2.1.0+incompatible - github.com/pkg/errors v0.8.0 // indirect + github.com/pkg/errors v0.8.0 github.com/prometheus/client_golang v0.9.1 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 // indirect github.com/prometheus/procfs v0.0.0-20181129180645-aa55a523dc0a // indirect github.com/sirupsen/logrus v1.2.0 + google.golang.org/grpc v0.0.0-20180607172857-7a6a684ca69e ) diff --git a/go.sum b/go.sum index e0531be5..e657edc8 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,10 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:BWIsLfhgKhV5g/oF34aRjniBHLTZe5DNekSjbAjIS6c= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= @@ -20,6 +24,7 @@ github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c h1:Qf5St5XGwKgKQLa github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c/go.mod h1:Ja9XPjot9q4/3JyCZodnWDGNXt4pKemhIYCvVJM7P24= github.com/pingcap/pd v2.1.0+incompatible h1:X0o443C/jXF6yJiiP1xTRxjKPHRe4gwQqjGcFTz1MuU= github.com/pingcap/pd v2.1.0+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= +github.com/pingcap/tipb v0.0.0-20170310053819-1043caee48da h1:DYBPt8ui5cxeiUB2rdLz0W2ptwH5cT95XN4aXGo/JEE= github.com/pingcap/tipb v0.0.0-20170310053819-1043caee48da/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/rpc/calls.go b/rpc/calls.go new file mode 100644 index 00000000..4b4dadc0 --- /dev/null +++ b/rpc/calls.go @@ -0,0 +1,558 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package rpc + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/tikvpb" +) + +// CmdType represents the concrete request type in Request or response type in Response. +type CmdType uint16 + +// CmdType values. +const ( + CmdGet CmdType = 1 + iota + CmdScan + CmdPrewrite + CmdCommit + CmdCleanup + CmdBatchGet + CmdBatchRollback + CmdScanLock + CmdResolveLock + CmdGC + CmdDeleteRange + + CmdRawGet CmdType = 256 + iota + CmdRawBatchGet + CmdRawPut + CmdRawBatchPut + CmdRawDelete + CmdRawBatchDelete + CmdRawDeleteRange + CmdRawScan + + CmdUnsafeDestroyRange + + CmdCop CmdType = 512 + iota + CmdCopStream + + CmdMvccGetByKey CmdType = 1024 + iota + CmdMvccGetByStartTs + CmdSplitRegion +) + +func (t CmdType) String() string { + switch t { + case CmdGet: + return "Get" + case CmdScan: + return "Scan" + case CmdPrewrite: + return "Prewrite" + case CmdCommit: + return "Commit" + case CmdCleanup: + return "Cleanup" + case CmdBatchGet: + return "BatchGet" + case CmdBatchRollback: + return "BatchRollback" + case CmdScanLock: + return "ScanLock" + case CmdResolveLock: + return "ResolveLock" + case CmdGC: + return "GC" + case CmdDeleteRange: + return "DeleteRange" + case CmdRawGet: + return "RawGet" + case CmdRawBatchGet: + return "RawBatchGet" + case CmdRawPut: + return "RawPut" + case CmdRawBatchPut: + return "RawBatchPut" + case CmdRawDelete: + return "RawDelete" + case CmdRawBatchDelete: + return "RawBatchDelete" + case CmdRawDeleteRange: + return "RawDeleteRange" + case CmdRawScan: + return "RawScan" + case CmdUnsafeDestroyRange: + return "UnsafeDestroyRange" + case CmdCop: + return "Cop" + case CmdCopStream: + return "CopStream" + case CmdMvccGetByKey: + return "MvccGetByKey" + case CmdMvccGetByStartTs: + return "MvccGetByStartTS" + case CmdSplitRegion: + return "SplitRegion" + } + return "Unknown" +} + +// Request wraps all kv/coprocessor requests. +type Request struct { + kvrpcpb.Context + Type CmdType + Get *kvrpcpb.GetRequest + Scan *kvrpcpb.ScanRequest + Prewrite *kvrpcpb.PrewriteRequest + Commit *kvrpcpb.CommitRequest + Cleanup *kvrpcpb.CleanupRequest + BatchGet *kvrpcpb.BatchGetRequest + BatchRollback *kvrpcpb.BatchRollbackRequest + ScanLock *kvrpcpb.ScanLockRequest + ResolveLock *kvrpcpb.ResolveLockRequest + GC *kvrpcpb.GCRequest + DeleteRange *kvrpcpb.DeleteRangeRequest + RawGet *kvrpcpb.RawGetRequest + RawBatchGet *kvrpcpb.RawBatchGetRequest + RawPut *kvrpcpb.RawPutRequest + RawBatchPut *kvrpcpb.RawBatchPutRequest + RawDelete *kvrpcpb.RawDeleteRequest + RawBatchDelete *kvrpcpb.RawBatchDeleteRequest + RawDeleteRange *kvrpcpb.RawDeleteRangeRequest + RawScan *kvrpcpb.RawScanRequest + UnsafeDestroyRange *kvrpcpb.UnsafeDestroyRangeRequest + Cop *coprocessor.Request + MvccGetByKey *kvrpcpb.MvccGetByKeyRequest + MvccGetByStartTs *kvrpcpb.MvccGetByStartTsRequest + SplitRegion *kvrpcpb.SplitRegionRequest +} + +// Response wraps all kv/coprocessor responses. +type Response struct { + Type CmdType + Get *kvrpcpb.GetResponse + Scan *kvrpcpb.ScanResponse + Prewrite *kvrpcpb.PrewriteResponse + Commit *kvrpcpb.CommitResponse + Cleanup *kvrpcpb.CleanupResponse + BatchGet *kvrpcpb.BatchGetResponse + BatchRollback *kvrpcpb.BatchRollbackResponse + ScanLock *kvrpcpb.ScanLockResponse + ResolveLock *kvrpcpb.ResolveLockResponse + GC *kvrpcpb.GCResponse + DeleteRange *kvrpcpb.DeleteRangeResponse + RawGet *kvrpcpb.RawGetResponse + RawBatchGet *kvrpcpb.RawBatchGetResponse + RawPut *kvrpcpb.RawPutResponse + RawBatchPut *kvrpcpb.RawBatchPutResponse + RawDelete *kvrpcpb.RawDeleteResponse + RawBatchDelete *kvrpcpb.RawBatchDeleteResponse + RawDeleteRange *kvrpcpb.RawDeleteRangeResponse + RawScan *kvrpcpb.RawScanResponse + UnsafeDestroyRange *kvrpcpb.UnsafeDestroyRangeResponse + Cop *coprocessor.Response + CopStream *CopStreamResponse + MvccGetByKey *kvrpcpb.MvccGetByKeyResponse + MvccGetByStartTS *kvrpcpb.MvccGetByStartTsResponse + SplitRegion *kvrpcpb.SplitRegionResponse +} + +// CopStreamResponse combinates tikvpb.Tikv_CoprocessorStreamClient and the first Recv() result together. +// In streaming API, get grpc stream client may not involve any network packet, then region error have +// to be handled in Recv() function. This struct facilitates the error handling. +type CopStreamResponse struct { + tikvpb.Tikv_CoprocessorStreamClient + *coprocessor.Response // The first result of Recv() + Timeout time.Duration + Lease // Shared by this object and a background goroutine. +} + +// SetContext set the Context field for the given req to the specified ctx. +func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error { + ctx := &req.Context + ctx.RegionId = region.Id + ctx.RegionEpoch = region.RegionEpoch + ctx.Peer = peer + + switch req.Type { + case CmdGet: + req.Get.Context = ctx + case CmdScan: + req.Scan.Context = ctx + case CmdPrewrite: + req.Prewrite.Context = ctx + case CmdCommit: + req.Commit.Context = ctx + case CmdCleanup: + req.Cleanup.Context = ctx + case CmdBatchGet: + req.BatchGet.Context = ctx + case CmdBatchRollback: + req.BatchRollback.Context = ctx + case CmdScanLock: + req.ScanLock.Context = ctx + case CmdResolveLock: + req.ResolveLock.Context = ctx + case CmdGC: + req.GC.Context = ctx + case CmdDeleteRange: + req.DeleteRange.Context = ctx + case CmdRawGet: + req.RawGet.Context = ctx + case CmdRawBatchGet: + req.RawBatchGet.Context = ctx + case CmdRawPut: + req.RawPut.Context = ctx + case CmdRawBatchPut: + req.RawBatchPut.Context = ctx + case CmdRawDelete: + req.RawDelete.Context = ctx + case CmdRawBatchDelete: + req.RawBatchDelete.Context = ctx + case CmdRawDeleteRange: + req.RawDeleteRange.Context = ctx + case CmdRawScan: + req.RawScan.Context = ctx + case CmdUnsafeDestroyRange: + req.UnsafeDestroyRange.Context = ctx + case CmdCop: + req.Cop.Context = ctx + case CmdCopStream: + req.Cop.Context = ctx + case CmdMvccGetByKey: + req.MvccGetByKey.Context = ctx + case CmdMvccGetByStartTs: + req.MvccGetByStartTs.Context = ctx + case CmdSplitRegion: + req.SplitRegion.Context = ctx + default: + return fmt.Errorf("invalid request type %v", req.Type) + } + return nil +} + +// GenRegionErrorResp returns corresponding Response with specified RegionError +// according to the given req. +func GenRegionErrorResp(req *Request, e *errorpb.Error) (*Response, error) { + resp := &Response{} + resp.Type = req.Type + switch req.Type { + case CmdGet: + resp.Get = &kvrpcpb.GetResponse{ + RegionError: e, + } + case CmdScan: + resp.Scan = &kvrpcpb.ScanResponse{ + RegionError: e, + } + case CmdPrewrite: + resp.Prewrite = &kvrpcpb.PrewriteResponse{ + RegionError: e, + } + case CmdCommit: + resp.Commit = &kvrpcpb.CommitResponse{ + RegionError: e, + } + case CmdCleanup: + resp.Cleanup = &kvrpcpb.CleanupResponse{ + RegionError: e, + } + case CmdBatchGet: + resp.BatchGet = &kvrpcpb.BatchGetResponse{ + RegionError: e, + } + case CmdBatchRollback: + resp.BatchRollback = &kvrpcpb.BatchRollbackResponse{ + RegionError: e, + } + case CmdScanLock: + resp.ScanLock = &kvrpcpb.ScanLockResponse{ + RegionError: e, + } + case CmdResolveLock: + resp.ResolveLock = &kvrpcpb.ResolveLockResponse{ + RegionError: e, + } + case CmdGC: + resp.GC = &kvrpcpb.GCResponse{ + RegionError: e, + } + case CmdDeleteRange: + resp.DeleteRange = &kvrpcpb.DeleteRangeResponse{ + RegionError: e, + } + case CmdRawGet: + resp.RawGet = &kvrpcpb.RawGetResponse{ + RegionError: e, + } + case CmdRawBatchGet: + resp.RawBatchGet = &kvrpcpb.RawBatchGetResponse{ + RegionError: e, + } + case CmdRawPut: + resp.RawPut = &kvrpcpb.RawPutResponse{ + RegionError: e, + } + case CmdRawBatchPut: + resp.RawBatchPut = &kvrpcpb.RawBatchPutResponse{ + RegionError: e, + } + case CmdRawDelete: + resp.RawDelete = &kvrpcpb.RawDeleteResponse{ + RegionError: e, + } + case CmdRawBatchDelete: + resp.RawBatchDelete = &kvrpcpb.RawBatchDeleteResponse{ + RegionError: e, + } + case CmdRawDeleteRange: + resp.RawDeleteRange = &kvrpcpb.RawDeleteRangeResponse{ + RegionError: e, + } + case CmdRawScan: + resp.RawScan = &kvrpcpb.RawScanResponse{ + RegionError: e, + } + case CmdUnsafeDestroyRange: + resp.UnsafeDestroyRange = &kvrpcpb.UnsafeDestroyRangeResponse{ + RegionError: e, + } + case CmdCop: + resp.Cop = &coprocessor.Response{ + RegionError: e, + } + case CmdCopStream: + resp.CopStream = &CopStreamResponse{ + Response: &coprocessor.Response{ + RegionError: e, + }, + } + case CmdMvccGetByKey: + resp.MvccGetByKey = &kvrpcpb.MvccGetByKeyResponse{ + RegionError: e, + } + case CmdMvccGetByStartTs: + resp.MvccGetByStartTS = &kvrpcpb.MvccGetByStartTsResponse{ + RegionError: e, + } + case CmdSplitRegion: + resp.SplitRegion = &kvrpcpb.SplitRegionResponse{ + RegionError: e, + } + default: + return nil, fmt.Errorf("invalid request type %v", req.Type) + } + return resp, nil +} + +// GetRegionError returns the RegionError of the underlying concrete response. +func (resp *Response) GetRegionError() (*errorpb.Error, error) { + var e *errorpb.Error + switch resp.Type { + case CmdGet: + e = resp.Get.GetRegionError() + case CmdScan: + e = resp.Scan.GetRegionError() + case CmdPrewrite: + e = resp.Prewrite.GetRegionError() + case CmdCommit: + e = resp.Commit.GetRegionError() + case CmdCleanup: + e = resp.Cleanup.GetRegionError() + case CmdBatchGet: + e = resp.BatchGet.GetRegionError() + case CmdBatchRollback: + e = resp.BatchRollback.GetRegionError() + case CmdScanLock: + e = resp.ScanLock.GetRegionError() + case CmdResolveLock: + e = resp.ResolveLock.GetRegionError() + case CmdGC: + e = resp.GC.GetRegionError() + case CmdDeleteRange: + e = resp.DeleteRange.GetRegionError() + case CmdRawGet: + e = resp.RawGet.GetRegionError() + case CmdRawBatchGet: + e = resp.RawBatchGet.GetRegionError() + case CmdRawPut: + e = resp.RawPut.GetRegionError() + case CmdRawBatchPut: + e = resp.RawBatchPut.GetRegionError() + case CmdRawDelete: + e = resp.RawDelete.GetRegionError() + case CmdRawBatchDelete: + e = resp.RawBatchDelete.GetRegionError() + case CmdRawDeleteRange: + e = resp.RawDeleteRange.GetRegionError() + case CmdRawScan: + e = resp.RawScan.GetRegionError() + case CmdUnsafeDestroyRange: + e = resp.UnsafeDestroyRange.GetRegionError() + case CmdCop: + e = resp.Cop.GetRegionError() + case CmdCopStream: + e = resp.CopStream.Response.GetRegionError() + case CmdMvccGetByKey: + e = resp.MvccGetByKey.GetRegionError() + case CmdMvccGetByStartTs: + e = resp.MvccGetByStartTS.GetRegionError() + case CmdSplitRegion: + e = resp.SplitRegion.GetRegionError() + default: + return nil, fmt.Errorf("invalid response type %v", resp.Type) + } + return e, nil +} + +// CallRPC launches a rpc call. +// ch is needed to implement timeout for coprocessor streaing, the stream object's +// cancel function will be sent to the channel, together with a lease checked by a background goroutine. +func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Response, error) { + resp := &Response{} + resp.Type = req.Type + var err error + switch req.Type { + case CmdGet: + resp.Get, err = client.KvGet(ctx, req.Get) + case CmdScan: + resp.Scan, err = client.KvScan(ctx, req.Scan) + case CmdPrewrite: + resp.Prewrite, err = client.KvPrewrite(ctx, req.Prewrite) + case CmdCommit: + resp.Commit, err = client.KvCommit(ctx, req.Commit) + case CmdCleanup: + resp.Cleanup, err = client.KvCleanup(ctx, req.Cleanup) + case CmdBatchGet: + resp.BatchGet, err = client.KvBatchGet(ctx, req.BatchGet) + case CmdBatchRollback: + resp.BatchRollback, err = client.KvBatchRollback(ctx, req.BatchRollback) + case CmdScanLock: + resp.ScanLock, err = client.KvScanLock(ctx, req.ScanLock) + case CmdResolveLock: + resp.ResolveLock, err = client.KvResolveLock(ctx, req.ResolveLock) + case CmdGC: + resp.GC, err = client.KvGC(ctx, req.GC) + case CmdDeleteRange: + resp.DeleteRange, err = client.KvDeleteRange(ctx, req.DeleteRange) + case CmdRawGet: + resp.RawGet, err = client.RawGet(ctx, req.RawGet) + case CmdRawBatchGet: + resp.RawBatchGet, err = client.RawBatchGet(ctx, req.RawBatchGet) + case CmdRawPut: + resp.RawPut, err = client.RawPut(ctx, req.RawPut) + case CmdRawBatchPut: + resp.RawBatchPut, err = client.RawBatchPut(ctx, req.RawBatchPut) + case CmdRawDelete: + resp.RawDelete, err = client.RawDelete(ctx, req.RawDelete) + case CmdRawBatchDelete: + resp.RawBatchDelete, err = client.RawBatchDelete(ctx, req.RawBatchDelete) + case CmdRawDeleteRange: + resp.RawDeleteRange, err = client.RawDeleteRange(ctx, req.RawDeleteRange) + case CmdRawScan: + resp.RawScan, err = client.RawScan(ctx, req.RawScan) + case CmdUnsafeDestroyRange: + resp.UnsafeDestroyRange, err = client.UnsafeDestroyRange(ctx, req.UnsafeDestroyRange) + case CmdCop: + resp.Cop, err = client.Coprocessor(ctx, req.Cop) + case CmdCopStream: + var streamClient tikvpb.Tikv_CoprocessorStreamClient + streamClient, err = client.CoprocessorStream(ctx, req.Cop) + resp.CopStream = &CopStreamResponse{ + Tikv_CoprocessorStreamClient: streamClient, + } + case CmdMvccGetByKey: + resp.MvccGetByKey, err = client.MvccGetByKey(ctx, req.MvccGetByKey) + case CmdMvccGetByStartTs: + resp.MvccGetByStartTS, err = client.MvccGetByStartTs(ctx, req.MvccGetByStartTs) + case CmdSplitRegion: + resp.SplitRegion, err = client.SplitRegion(ctx, req.SplitRegion) + default: + return nil, errors.Errorf("invalid request type: %v", req.Type) + } + if err != nil { + return nil, errors.Trace(err) + } + return resp, nil +} + +// Lease is used to implement grpc stream timeout. +type Lease struct { + Cancel context.CancelFunc + deadline int64 // A time.UnixNano value, if time.Now().UnixNano() > deadline, cancel() would be called. +} + +// Recv overrides the stream client Recv() function. +func (resp *CopStreamResponse) Recv() (*coprocessor.Response, error) { + deadline := time.Now().Add(resp.Timeout).UnixNano() + atomic.StoreInt64(&resp.Lease.deadline, deadline) + + ret, err := resp.Tikv_CoprocessorStreamClient.Recv() + + atomic.StoreInt64(&resp.Lease.deadline, 0) // Stop the lease check. + return ret, errors.Trace(err) +} + +// Close closes the CopStreamResponse object. +func (resp *CopStreamResponse) Close() { + atomic.StoreInt64(&resp.Lease.deadline, 1) +} + +// CheckStreamTimeoutLoop runs periodically to check is there any stream request timeouted. +// Lease is an object to track stream requests, call this function with "go CheckStreamTimeoutLoop()" +func CheckStreamTimeoutLoop(ch <-chan *Lease) { + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + array := make([]*Lease, 0, 1024) + + for { + select { + case item, ok := <-ch: + if !ok { + // This channel close means goroutine should return. + return + } + array = append(array, item) + case now := <-ticker.C: + array = keepOnlyActive(array, now.UnixNano()) + } + } +} + +// keepOnlyActive removes completed items, call cancel function for timeout items. +func keepOnlyActive(array []*Lease, now int64) []*Lease { + idx := 0 + for i := 0; i < len(array); i++ { + item := array[i] + deadline := atomic.LoadInt64(&item.deadline) + if deadline == 0 || deadline > now { + array[idx] = array[i] + idx++ + } else { + item.Cancel() + } + } + return array[:idx] +} diff --git a/rpc/client.go b/rpc/client.go index ab05c7c5..42b98b20 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -15,16 +15,20 @@ package rpc import ( "context" + "io" "strconv" "sync" "sync/atomic" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/tikvpb" - "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/store/tikv/tikvrpc" - "github.com/pingcap/tidb/terror" + log "github.com/sirupsen/logrus" + "github.com/tikv/client-go/config" "github.com/tikv/client-go/metrics" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -55,6 +59,12 @@ var MaxCallMsgSize = 1<<31 - 1 // Timeout durations. const ( dialTimeout = 5 * time.Second + readTimeoutShort = 20 * time.Second // For requests that read/write several key-values. + ReadTimeoutMedium = 60 * time.Second // For requests that may need scan region. + ReadTimeoutLong = 150 * time.Second // For requests that may need scan region multiple times. + GCTimeout = 5 * time.Minute + UnsafeDestroyRangeTimeout = 5 * time.Minute + grpcInitialWindowSize = 1 << 30 grpcInitialConnWindowSize = 1 << 30 ) @@ -65,18 +75,21 @@ type Client interface { // Close should release all data. Close() error // SendRequest sends Request. - SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) + SendRequest(ctx context.Context, addr string, req *Request, timeout time.Duration) (*Response, error) } type connArray struct { index uint32 v []*grpc.ClientConn + // Bind with a background goroutine to process coprocessor streaming timeout. + streamTimeout chan *Lease } func newConnArray(maxSize uint, addr string, security config.Security) (*connArray, error) { a := &connArray{ - index: 0, - v: make([]*grpc.ClientConn, maxSize), + index: 0, + v: make([]*grpc.ClientConn, maxSize), + streamTimeout: make(chan *Lease, 1024), } if err := a.Init(addr, security); err != nil { return nil, err @@ -86,7 +99,7 @@ func newConnArray(maxSize uint, addr string, security config.Security) (*connArr func (a *connArray) Init(addr string, security config.Security) error { opt := grpc.WithInsecure() - if len(security.ClusterSSLCA) != 0 { + if len(security.SSLCA) != 0 { tlsConfig, err := security.ToTLSConfig() if err != nil { return errors.Trace(err) @@ -96,8 +109,7 @@ func (a *connArray) Init(addr string, security config.Security) error { unaryInterceptor := grpc_prometheus.UnaryClientInterceptor streamInterceptor := grpc_prometheus.StreamClientInterceptor - cfg := config.GetGlobalConfig() - if cfg.OpenTracing.Enable { + if config.EnableOpenTracing { unaryInterceptor = grpc_middleware.ChainUnaryClient( unaryInterceptor, grpc_opentracing.UnaryClientInterceptor(), @@ -135,6 +147,7 @@ func (a *connArray) Init(addr string, security config.Security) error { } a.v[i] = conn } + go CheckStreamTimeoutLoop(a.streamTimeout) return nil } @@ -147,8 +160,7 @@ func (a *connArray) Get() *grpc.ClientConn { func (a *connArray) Close() { for i, c := range a.v { if c != nil { - err := c.Close() - terror.Log(errors.Trace(err)) + c.Close() a.v[i] = nil } } @@ -221,7 +233,7 @@ func (c *rpcClient) closeConns() { } // SendRequest sends a Request to server and receives Response. -func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { +func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *Request, timeout time.Duration) (*Response, error) { start := time.Now() reqType := req.Type.String() storeID := strconv.FormatUint(req.Context.GetPeer().GetStoreId(), 10) @@ -234,9 +246,40 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return nil, errors.Trace(err) } client := tikvpb.NewTikvClient(connArray.Get()) - ctx1, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - return tikvrpc.CallRPC(ctx1, client, req) + + if req.Type != CmdCopStream { + ctx1, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return CallRPC(ctx1, client, req) + } + + // Coprocessor streaming request. + // Use context to support timeout for grpc streaming client. + ctx1, cancel := context.WithCancel(ctx) + resp, err := CallRPC(ctx1, client, req) + if err != nil { + return nil, errors.Trace(err) + } + + // Put the lease object to the timeout channel, so it would be checked periodically. + copStream := resp.CopStream + copStream.Timeout = timeout + copStream.Lease.Cancel = cancel + connArray.streamTimeout <- &copStream.Lease + + // Read the first streaming response to get CopStreamResponse. + // This can make error handling much easier, because SendReq() retry on + // region error automatically. + var first *coprocessor.Response + first, err = copStream.Recv() + if err != nil { + if errors.Cause(err) != io.EOF { + return nil, errors.Trace(err) + } + log.Debug("copstream returns nothing for the request.") + } + copStream.Response = first + return resp, nil } func (c *rpcClient) Close() error { diff --git a/rpc/messages.go b/rpc/messages.go deleted file mode 100644 index 9ee21f81..00000000 --- a/rpc/messages.go +++ /dev/null @@ -1,19 +0,0 @@ -package rpc - -import ( - "reflect" - "unsafe" - - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" -) - -func setContext(req interface{}, region *metapb.Region, peer *metapb.Peer) { - ctx := kvrpcpb.Context{ - RegionId: region.Id, - RegionEpoch: region.RegionEpoch, - Peer: peer, - } - // Need generics in Go2. - reflect.ValueOf(req).FieldByName("Context").SetPointer(unsafe.Pointer(&ctx)) -}